加入收藏 | 设为首页 | 会员中心 | 我要投稿 应用网_阳江站长网 (https://www.0662zz.com/)- 科技、建站、经验、云计算、5G、大数据,站长网!
当前位置: 首页 > 创业 > 模式 > 正文

云计算必备知识-基于PyTorch机器学习构建生成对抗网络

发布时间:2020-05-28 16:46:22 所属栏目:模式 来源:51cto
导读:生成式对抗网络(GAN, Generative Adversarial Networks )是一种深度学习模型,是近年来复杂分布上无监督学习最具前景的方法之一。模型通过框架中(至少)两个模块:生成模型(Generative Model)和判别模型(Discriminative Model)的互相博弈学习产生相

from sagemaker.session import Session  sess = Session()  # S3 bucket for saving code and model artifacts. # Feel free to specify a different bucket here if you wish. bucket = sess.default_bucket() 

SageMaker SDK 提供了操作 Amazon S3 服务的包和类,其中 S3Downloader 类用于访问或下载 S3 里的对象,而 S3Uploader 则用于将本地文件上传至 S3。您将已经下载的数据上传至 Amazon S3,供模型训练使用。模型训练过程不要从互联网下载数据,避免通过互联网获取训练数据的产生的网络延迟,同时也规避了因直接访问互联网对模型训练可能产生的安全风险。

from sagemaker.s3 import S3Uploader as s3up  s3_data_location = s3up.upload(f"{dataroot}/QMNIST", f"s3://{bucket}/data/qmnist")  训练执行

通过 
sagemaker.getexecutionrole() 方法,当前笔记本可以得到预先分配给笔记本实例的角色,这个角色将被用来获取训练用的资源,比如下载训练用框架镜像、分配 Amazon EC2 计算资源等等。

训练模型用的超参数可以在笔记本里定义,实现与算法代码的分离,在创建训练任务时传入超参数,与训练任务动态结合。

hps = { "learning-rate": 0.0002, "epochs": 15, "dataset": "qmnist", "beta1": 0.5, "sample-interval": 200, "log-interval": 64 } 

sagemaker.pytorch 包里的 PyTorch 类是基于 PyTorch 框架的模型拟合器,可以用来创建、执行训练任务,还可以对训练完的模型进行部署。参数列表中, train_instance_type 用来指定CPU或者GPU实例类型,训练脚本和包括模型代码所在的目录通过 source_dir 指定,训练脚本文件名必须通过 entry_point 明确定义。这些参数将和其余参数一起被传递给训练任务,他们决定了训练任务的运行环境和模型训练时参数。

from sagemaker.pytorch import PyTorch  estimator = PyTorch(role=role, entry_point='train.py', source_dir='dcgan', output_path=s3_model_artifacts_location, code_location=s3_custom_code_upload_location, train_instance_count=1, train_instance_type='ml.c5.xlarge', train_use_spot_instances=True, train_max_wait=86400, framework_version='1.4.0', py_version='py3', hyperparameters=hps) 

请特别注意 train_use_spot_instances 参数,True 值代表您希望优先使用 SPOT 实例。由于机器学习训练工作通常需要大量计算资源长时间运行,善用 SPOT 可以帮助您实现有效的成本控制,SPOT 实例价格可能是按需实例价格的 20% 到 60%,依据选择实例类型、区域、时间不同实际价格有所不同。

您已经创建了 PyTorch 对象,下面可以用它来拟合预先存在 Amazon S3 上的数据了。下面的指令将执行训练任务,训练数据将以名为 QMNIST 的输入通道的方式导入训练环境。训练开始执行过程中,Amazon S3 上的训练数据将被下载到模型训练环境的本地文件系统,训练脚本 train.py 将从本地磁盘加载数据进行训练。

# Start training estimator.fit({'QMNIST': s3_data_location}, wait=False) 

根据您选择的训练实例不同,训练过程中可能持续几十分钟到几个小时不等。建议设置 wait 参数为 False ,这个选项将使笔记本与训练任务分离,在训练时间长、训练日志多的场景下,可以避免笔记本上下文因为网络中断或者会话超时而丢失。训练任务脱离笔记本后,输出将暂时不可见,可以执行如下代码,笔记本将获取并载入此前的训练回话,

%%time from sagemaker.estimator import Estimator  # Attaching previous training session training_job_name = estimator.latest_training_job.name attached_estimator = Estimator.attach(training_job_name) 

由于的模型设计考虑到了GPU对训练加速的能力,所以用GPU实例训练会比CPU实例快一些,例如,p3.2xlarge 实例大概需要15分钟左右,而 c5.xlarge 实例则可能需要6小时以上。目前模型不支持分布、并行训练,所以多实例、多CPU/GPU并不会带来更多的训练速度提升。

训练完成后,模型将被上传到 Amazon S3 里,上传位置由创建 PyTorch 对象时提供的 output_path 参数指定。

模型的验证

您将从 Amazon S3 下载经过训练的模型到笔记本所在实例的本地文件系统,下面的代码将载入模型,然后输入一个随机数,获得推理结果,以图片形式展现出来。执行如下指令加载训练好的模型,并通过这个模型产生一组『手写』数字字体。

from helper import * import matplotlib.pyplot as plt import numpy as np import torch from dcgan.model import Generator  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  params = {'nz': nz, 'nc': nc, 'ngf': ngf} model = load_model(Generator, params, "./model/generator_state.pth", device=device) img = generate_fake_handwriting(model, batch_size=batch_size, nz=nz, device=device)  plt.imshow(np.asarray(img))  轻松构建 PyTorch 生成对抗网络(GAN)

结论与总结

近些年成长快速的 PyTorch 框架正在得到广泛的认可和应用,越来越多的新模型采用 PyTorch 框架,也有模型被迁移到 PyTorch 上,或者基于 PyTorch 被完整再实现。生态环境持续丰富,应用领域不断拓展,PyTorch 已成为事实上的主流框架之一。Amazon SageMaker 与多种 AWS 服务紧密集成,比如,各种类型和尺寸的 Amazon EC2 计算实例、Amazon S3、Amazon ECR 等等,为机器学习工程实践提供了端到端的、一致的体验。Amazon SageMaker 持续支持主流机器学习框架,PyTorch 是这其中之一。用 PyTorch 开发的机器学习算法和模型,可以轻松移植到 Amazon SageMaker 的工程和服务环境里,进而利用 Amazon SageMaker 全托管的 Jupyter Notebook、训练容器镜像、服务容器镜像、训练任务管理、部署环境托管等功能,简化机器学习工程复杂度,提高生产效率,降低运维成本。

(编辑:应用网_阳江站长网)

【声明】本站内容均来自网络,其相关言论仅代表作者个人观点,不代表本站立场。若无意侵犯到您的权利,请及时与联系站长删除相关内容!

推荐文章
    热点阅读