云计算必备知识-基于PyTorch机器学习构建生成对抗网络
此时,可以看到 Amazon SageMaker 为您创建了一个名字类似 * 回到 Notebook instances 页面,您会看到 MySageMakerInstance 笔记本实例显示为 Pending 状态,这个将持续2分钟左右,直到转为 InService 状态。
点击 Open JupyterLab 链接,在新的页面里,您将看到熟悉的 Jupyter Notebook 加载界面。本文默认以 JupyterLab 笔记本作为工程环境,根据您的需要,可以选择使用传统的 Jupyter 笔记本。 您将通过点击 conda_pytorch_p36, 笔记本图标来创建一个叫做 Untitled.ipynb 的笔记本,您可以稍后更改它的名字。另外,您也可以通过 File > New > Notebook 菜单路径,并选择 conda_pytorch_p36 作为 Kernel 来创建这个笔记本。 在新建的 Untitled.ipynb 笔记本里,我们将输入第一行指令如下, import torch print(f"Hello PyTorch {torch.__version__}") 源代码下载 请在笔记本中输入如下指令,下载代码到实例本地文件系统。 下载完成后,您可以通过 File browser 浏览源代码结构。 本文涉及到的代码和笔记本均通过 Amazon SageMaker 托管的 Python 3.6、PyTorch 1.4 和 JupyterLab 验证。本文涉及到的代码和笔记本可以通过 这里获取。 生成对抗网络模型 算法原理 DCGAN模型的生成网络包含10层,它使用跨步转置卷积层来提高张量的分辨率,输入形状为 (batchsize, 100) ,输出形状为 (batchsize, 64, 64, 3)。换句话说,生成网络接受噪声向量,然后经过不断变换,直到生成最终的图像。 判别网络也包含10层,它接收 (64, 64, 3) 格式的图片,使用2D卷积层进行下采样,最后传递给全链接层进行分类,分类结果是 1 或 0,即真与假。 DCGAN 模型的训练过程大致可以分为三个子过程。 首先, Generator 网络以一个随机数作为输入,生成一张『假』图片;接下来,分别用『真』图片和『假』图片训练 Discriminator 网络,更新参数;最后,更新 Generator 网络参数。 代码分析 项目目录 byos-pytorch-gan 的文件结构如下, 文件 model.py 中包含 3 个类,分别是 生成网络 Generator 和 判别网络 Discriminator。 class Generator(nn.Module): ... class Discriminator(nn.Module): ... class DCGAN(object): """ A wrapper class for Generator and Discriminator, 'train_step' method is for single batch training. """ ... 文件 train.py 用于 Generator 和 Discriminator 两个神经网络的训练,主要包含以下几个方法, def parse_args(): ... def get_datasets(dataset_name, ...): ... def train(dataloader, hps, ...): ... 模型的调试 开发和调试阶段,可以从 Linux 命令行直接运行 train.py 脚本。超参数、输入数据通道、模型和其他训练产出物存放目录都可以通过命令行参数指定。 python dcgan/train.py --dataset qmnist --model-dir '/home/myhome/byom-pytorch-gan/model' --output-dir '/home/myhome/byom-pytorch-gan/tmp' --data-dir '/home/myhome/byom-pytorch-gan/data' --hps '{"beta1":0.5,"dataset":"qmnist","epochs":15,"learning-rate":0.0002,"log-interval":64,"nc":1,"nz":100,"sample-interval":100}' 这样的训练脚本参数设计,既提供了很好的调试方法,又是与 SageMaker Container 集成的规约和必要条件,很好的兼顾了模型开发的自由度和训练环境的可移植性。 模型的训练和验证 请查找并打开名为 dcgan.ipynb 的笔记本文件,训练过程将由这个笔记本介绍并执行,本节内容代码部分从略,请以笔记本代码为准。 互联网环境里有很多公开的数据集,对于机器学习的工程和科研很有帮助,比如算法学习和效果评价。我们将使用 QMNIST 这个手写字体数据集训练模型,最终生成逼真的『手写』字体效果图样。 数据准备 PyTorch 框架的 torchvision.datasets 包提供了QMNIST 数据集,您可以通过如下指令下载 QMNIST 数据集到本地备用。 from torchvision import datasets dataroot = './data' trainset = datasets.QMNIST(root=dataroot, train=True, download=True) testset = datasets.QMNIST(root=dataroot, train=False, download=True) Amazon SageMaker 为您创建了一个默认的 Amazon S3 桶,用来存取机器学习工作流程中可能需要的各种文件和数据。 我们可以通过 SageMaker SDK 中 sagemaker.session.Session 类的 default_bucket 方法获得这个桶的名字。 (编辑:应用网_阳江站长网) 【声明】本站内容均来自网络,其相关言论仅代表作者个人观点,不代表本站立场。若无意侵犯到您的权利,请及时与联系站长删除相关内容! |