AWS SageMaker¶
- class torchx.schedulers.aws_sagemaker_scheduler.AWSSageMakerScheduler(session_name: str, client: Optional[Any] = None, docker_client: Optional[DockerClient] = None)[source]¶
基础类:
DockerWorkspaceMixin,Scheduler[AWSSageMakerOpts]AWSSageMakerScheduler 是一个 TorchX 调度接口,用于 AWS SageMaker。
$ torchx run -s aws_sagemaker utils.echo --image alpine:latest --msg hello aws_batch://torchx_user/1234 $ torchx status aws_batch://torchx_user/1234 ...
身份验证是通过使用
boto3凭证处理从环境中加载的。配置选项
usage: role=ROLE,instance_type=INSTANCE_TYPE,[instance_count=INSTANCE_COUNT],[user=USER],[keep_alive_period_in_seconds=KEEP_ALIVE_PERIOD_IN_SECONDS],[volume_size=VOLUME_SIZE],[volume_kms_key=VOLUME_KMS_KEY],[max_run=MAX_RUN],[input_mode=INPUT_MODE],[output_path=OUTPUT_PATH],[output_kms_key=OUTPUT_KMS_KEY],[base_job_name=BASE_JOB_NAME],[tags=TAGS],[subnets=SUBNETS],[security_group_ids=SECURITY_GROUP_IDS],[model_uri=MODEL_URI],[model_channel_name=MODEL_CHANNEL_NAME],[metric_definitions=METRIC_DEFINITIONS],[encrypt_inter_container_traffic=ENCRYPT_INTER_CONTAINER_TRAFFIC],[use_spot_instances=USE_SPOT_INSTANCES],[max_wait=MAX_WAIT],[checkpoint_s3_uri=CHECKPOINT_S3_URI],[checkpoint_local_path=CHECKPOINT_LOCAL_PATH],[debugger_hook_config=DEBUGGER_HOOK_CONFIG],[enable_sagemaker_metrics=ENABLE_SAGEMAKER_METRICS],[enable_network_isolation=ENABLE_NETWORK_ISOLATION],[disable_profiler=DISABLE_PROFILER],[environment=ENVIRONMENT],[max_retry_attempts=MAX_RETRY_ATTEMPTS],[source_dir=SOURCE_DIR],[git_config=GIT_CONFIG],[hyperparameters=HYPERPARAMETERS],[container_log_level=CONTAINER_LOG_LEVEL],[code_location=CODE_LOCATION],[dependencies=DEPENDENCIES],[training_repository_access_mode=TRAINING_REPOSITORY_ACCESS_MODE],[training_repository_credentials_provider_arn=TRAINING_REPOSITORY_CREDENTIALS_PROVIDER_ARN],[disable_output_compression=DISABLE_OUTPUT_COMPRESSION],[enable_infra_check=ENABLE_INFRA_CHECK],[image_repo=IMAGE_REPO],[quiet=QUIET] required arguments: role=ROLE (str) an AWS IAM role (either name or full ARN). The Amazon SageMaker training jobs and APIs that create Amazon SageMaker endpoints use this role to access training data and model artifacts. After the endpoint is created, the inference code might use the IAM role, if it needs to access an AWS resource. instance_type=INSTANCE_TYPE (str) type of EC2 instance to use for training, for example, 'ml.c4.xlarge' optional arguments: instance_count=INSTANCE_COUNT (int, 1) number of Amazon EC2 instances to use for training. Required if instance_groups is not set. user=USER (str, ec2-user) the username to tag the job with. `getpass.getuser()` if not specified. keep_alive_period_in_seconds=KEEP_ALIVE_PERIOD_IN_SECONDS (int, None) the duration of time in seconds to retain configured resources in a warm pool for subsequent training jobs. volume_size=VOLUME_SIZE (int, None) size in GB of the storage volume to use for storing input and output data during training (default: 30). volume_kms_key=VOLUME_KMS_KEY (str, None) KMS key ID for encrypting EBS volume attached to the training instance. max_run=MAX_RUN (int, None) timeout in seconds for training (default: 24 * 60 * 60). input_mode=INPUT_MODE (str, None) the input mode that the algorithm supports (default: ‘File’). output_path=OUTPUT_PATH (str, None) S3 location for saving the training result (model artifacts and output files). If not specified, results are stored to a default bucket. If the bucket with the specific name does not exist, the estimator creates the bucket during the fit() method execution. output_kms_key=OUTPUT_KMS_KEY (str, None) KMS key ID for encrypting the training output (default: Your IAM role’s KMS key for Amazon S3). base_job_name=BASE_JOB_NAME (str, None) prefix for training job name when the fit() method launches. If not specified, the estimator generates a default job name based on the training image name and current timestamp. tags=TAGS (typing.List[typing.Dict[str, str]], None) list of tags for labeling a training job. subnets=SUBNETS (typing.List[str], None) list of subnet ids. If not specified training job will be created without VPC config. security_group_ids=SECURITY_GROUP_IDS (typing.List[str], None) list of security group ids. If not specified training job will be created without VPC config. model_uri=MODEL_URI (str, None) URI where a pre-trained model is stored, either locally or in S3. model_channel_name=MODEL_CHANNEL_NAME (str, None) name of the channel where ‘model_uri’ will be downloaded (default: ‘model’). metric_definitions=METRIC_DEFINITIONS (typing.List[typing.Dict[str, str]], None) list of dictionaries that defines the metric(s) used to evaluate the training jobs. Each dictionary contains two keys: ‘Name’ for the name of the metric, and ‘Regex’ for the regular expression used to extract the metric from the logs. encrypt_inter_container_traffic=ENCRYPT_INTER_CONTAINER_TRAFFIC (bool, None) specifies whether traffic between training containers is encrypted for the training job (default: False). use_spot_instances=USE_SPOT_INSTANCES (bool, None) specifies whether to use SageMaker Managed Spot instances for training. If enabled then the max_wait arg should also be set. max_wait=MAX_WAIT (int, None) timeout in seconds waiting for spot training job. checkpoint_s3_uri=CHECKPOINT_S3_URI (str, None) S3 URI in which to persist checkpoints that the algorithm persists (if any) during training. checkpoint_local_path=CHECKPOINT_LOCAL_PATH (str, None) local path that the algorithm writes its checkpoints to. debugger_hook_config=DEBUGGER_HOOK_CONFIG (bool, None) configuration for how debugging information is emitted with SageMaker Debugger. If not specified, a default one is created using the estimator’s output_path, unless the region does not support SageMaker Debugger. To disable SageMaker Debugger, set this parameter to False. enable_sagemaker_metrics=ENABLE_SAGEMAKER_METRICS (bool, None) enable SageMaker Metrics Time Series. enable_network_isolation=ENABLE_NETWORK_ISOLATION (bool, None) specifies whether container will run in network isolation mode (default: False). disable_profiler=DISABLE_PROFILER (bool, None) specifies whether Debugger monitoring and profiling will be disabled (default: False). environment=ENVIRONMENT (typing.Dict[str, str], None) environment variables to be set for use during training job max_retry_attempts=MAX_RETRY_ATTEMPTS (int, None) number of times to move a job to the STARTING status. You can specify between 1 and 30 attempts. source_dir=SOURCE_DIR (str, None) absolute, relative, or S3 URI Path to a directory with any other training source code dependencies aside from the entry point file (default: current working directory) git_config=GIT_CONFIG (typing.Dict[str, str], None) git configurations used for cloning files, including repo, branch, commit, 2FA_enabled, username, password, and token. hyperparameters=HYPERPARAMETERS (typing.Dict[str, str], None) dictionary containing the hyperparameters to initialize this estimator with. container_log_level=CONTAINER_LOG_LEVEL (int, None) log level to use within the container (default: logging.INFO). code_location=CODE_LOCATION (str, None) S3 prefix URI where custom code is uploaded. dependencies=DEPENDENCIES (typing.List[str], None) list of absolute or relative paths to directories with any additional libraries that should be exported to the container. training_repository_access_mode=TRAINING_REPOSITORY_ACCESS_MODE (str, None) specifies how SageMaker accesses the Docker image that contains the training algorithm. training_repository_credentials_provider_arn=TRAINING_REPOSITORY_CREDENTIALS_PROVIDER_ARN (str, None) Amazon Resource Name (ARN) of an AWS Lambda function that provides credentials to authenticate to the private Docker registry where your training image is hosted. disable_output_compression=DISABLE_OUTPUT_COMPRESSION (bool, None) when set to true, Model is uploaded to Amazon S3 without compression after training finishes. enable_infra_check=ENABLE_INFRA_CHECK (bool, None) specifies whether it is running Sagemaker built-in infra check jobs. image_repo=IMAGE_REPO (str, None) (remote jobs) the image repository to use when pushing patched images, must have push access. Ex: example.com/your/container quiet=QUIET (bool, False) whether to suppress verbose output for image building. Defaults to ``False``.兼容性
功能
调度器支持
获取日志
❌
分布式作业
✔️
取消任务
✔️
描述工作
部分支持。SageMakerScheduler 将返回作业和副本状态,但不提供完整的原始 AppSpec。
工作区 / 补丁
✔️
挂载
❌
弹性
❌
- describe(app_id: str) Optional[DescribeAppResponse][source]¶
描述指定的应用程序。
- Returns:
AppDef 描述或如果应用程序不存在则为
None。
- list() List[ListAppResponse][source]¶
对于在调度程序上启动的应用程序,此API返回一个ListAppResponse对象列表,每个对象都包含应用程序ID及其状态。 注意:此API处于原型阶段,可能会发生变化。
- log_iter(app_id: str, role_name: str, k: int = 0, regex: Optional[str] = None, since: Optional[datetime] = None, until: Optional[datetime] = None, should_tail: bool = False, streams: Optional[Stream] = None) Iterable[str][source]¶
返回一个迭代器,用于遍历
k``th replica of the ``role的日志行。 当所有符合条件的日志行都被读取后,迭代器结束。如果调度程序支持基于时间的光标获取自定义时间范围内的日志行,则
since和until字段会被尊重,否则它们会被忽略。不指定since和until等同于获取所有可用的日志行。如果until为空,则迭代器的行为类似于tail -f,跟随日志输出直到作业达到终止状态。日志的确切定义取决于调度程序。一些调度程序可能将stderr或stdout视为日志,而其他调度程序可能从日志文件中读取日志。
行为和假设:
如果在一个不存在的应用程序上调用,会产生未定义行为 调用者应在调用此方法之前使用
exists(app_id)检查应用程序是否存在。不是有状态的,使用相同参数调用此方法两次会返回一个新的迭代器。之前的迭代进度会丢失。
并不总是支持日志尾部查看。并非所有调度器都支持实时日志迭代(例如,在应用程序运行时查看日志尾部)。请参考特定调度器的文档以了解迭代器的行为。
- 3.1 If the scheduler supports log-tailing, it should be controlled
通过
should_tail参数。
不保证日志的保留。当调用此方法时,底层调度程序可能已经清除了此应用程序的日志记录。如果是这样,此方法会引发任意异常。
如果
should_tail为 True,该方法仅在可访问的日志行完全耗尽且应用程序达到最终状态时引发StopIteration异常。例如,如果应用程序卡住并且不产生任何日志行,则迭代器会阻塞,直到应用程序最终被终止(无论是通过超时还是手动),此时它会引发StopIteration。如果
should_tail为 False,当没有更多日志时,该方法会引发StopIteration。不需要所有调度程序都支持。
一些调度器可能通过支持
__getitem__来支持行光标 (例如iter[50]跳转到第50条日志行)。- Whitespace is preserved, each new line should include
\n. To 支持交互式进度条,返回的行不需要包含
\n,但应在打印时不换行以正确处理\r回车符。
- Whitespace is preserved, each new line should include
- Parameters:
streams – 要选择的IO输出流。 选项之一:combined, stdout, stderr。 如果调度程序不支持所选的流,它将抛出一个ValueError。
- Returns:
一个
Iterator,用于遍历指定角色副本的日志行- Raises:
NotImplementedError – 如果调度程序不支持日志迭代
- schedule(dryrun_info: AppDryRunInfo[AWSSageMakerJob]) str[source]¶
与
submit相同,只是它接受一个AppDryRunInfo。 鼓励实现者实现此方法,而不是直接实现submit,因为submit可以通过以下方式轻松实现:dryrun_info = self.submit_dryrun(app, cfg) return schedule(dryrun_info)
- class torchx.schedulers.aws_sagemaker_scheduler.AWSSageMakerJob(job_name: str, job_def: Dict[str, Any], images_to_push: Dict[str, Tuple[str, str]])[source]¶
Jobs 定义了调度作业所需的关键值。这将是 AppDryRunInfo 对象中 request 的值。
job_name: 定义在SageMaker中显示的作业名称
job_def: 定义将用于在SageMaker上调度作业的作业描述
images_to_push: 由torchx使用,用于推送到image_repo
参考¶
- torchx.schedulers.aws_sagemaker_scheduler.create_scheduler(session_name: str, **kwargs: object) AWSSageMakerScheduler[source]¶