加入收藏 | 设为首页 | 会员中心 | 我要投稿 应用网_阳江站长网 (https://www.0662zz.com/)- 科技、建站、经验、云计算、5G、大数据,站长网!
当前位置: 首页 > 创业 > 模式 > 正文

云计算必备知识-基于PyTorch机器学习构建生成对抗网络

发布时间:2020-05-28 16:46:22 所属栏目:模式 来源:51cto
导读:生成式对抗网络(GAN, Generative Adversarial Networks )是一种深度学习模型,是近年来复杂分布上无监督学习最具前景的方法之一。模型通过框架中(至少)两个模块:生成模型(Generative Model)和判别模型(Discriminative Model)的互相博弈学习产生相

此时,可以看到 Amazon SageMaker 为您创建了一个名字类似 *
AmazonSageMaker-ExecutionRole-**** 的角色。对于其他字段,您可以使用默认值,请点击 Create notebook instance 按钮,创建实例。

轻松构建 PyTorch 生成对抗网络(GAN)

回到 Notebook instances 页面,您会看到 MySageMakerInstance 笔记本实例显示为 Pending 状态,这个将持续2分钟左右,直到转为 InService 状态。

轻松构建 PyTorch 生成对抗网络(GAN)     编写第一行代码

点击 Open JupyterLab 链接,在新的页面里,您将看到熟悉的 Jupyter Notebook 加载界面。本文默认以 JupyterLab 笔记本作为工程环境,根据您的需要,可以选择使用传统的 Jupyter 笔记本。

轻松构建 PyTorch 生成对抗网络(GAN)

您将通过点击 conda_pytorch_p36, 笔记本图标来创建一个叫做 Untitled.ipynb 的笔记本,您可以稍后更改它的名字。另外,您也可以通过 File > New > Notebook 菜单路径,并选择 conda_pytorch_p36 作为 Kernel 来创建这个笔记本。

在新建的 Untitled.ipynb 笔记本里,我们将输入第一行指令如下,

import torch  print(f"Hello PyTorch {torch.__version__}")  源代码下载

请在笔记本中输入如下指令,下载代码到实例本地文件系统。

下载完成后,您可以通过 File browser 浏览源代码结构。

轻松构建 PyTorch 生成对抗网络(GAN)

本文涉及到的代码和笔记本均通过 Amazon SageMaker 托管的 Python 3.6、PyTorch 1.4 和 JupyterLab 验证。本文涉及到的代码和笔记本可以通过 这里获取。

生成对抗网络模型 算法原理

DCGAN模型的生成网络包含10层,它使用跨步转置卷积层来提高张量的分辨率,输入形状为 (batchsize, 100) ,输出形状为 (batchsize, 64, 64, 3)。换句话说,生成网络接受噪声向量,然后经过不断变换,直到生成最终的图像。

判别网络也包含10层,它接收 (64, 64, 3) 格式的图片,使用2D卷积层进行下采样,最后传递给全链接层进行分类,分类结果是 1 或 0,即真与假。

轻松构建 PyTorch 生成对抗网络(GAN)

DCGAN 模型的训练过程大致可以分为三个子过程。

轻松构建 PyTorch 生成对抗网络(GAN)

首先, Generator 网络以一个随机数作为输入,生成一张『假』图片;接下来,分别用『真』图片和『假』图片训练 Discriminator 网络,更新参数;最后,更新 Generator 网络参数。

代码分析

项目目录 byos-pytorch-gan 的文件结构如下,

文件 model.py 中包含 3 个类,分别是 生成网络 Generator 和 判别网络 Discriminator。

class Generator(nn.Module): ...  class Discriminator(nn.Module): ...  class DCGAN(object): """ A wrapper class for Generator and Discriminator, 'train_step' method is for single batch training. """ ... 

文件 train.py 用于 Generator 和 Discriminator 两个神经网络的训练,主要包含以下几个方法,

def parse_args(): ...  def get_datasets(dataset_name, ...): ...  def train(dataloader, hps, ...): ...  模型的调试

开发和调试阶段,可以从 Linux 命令行直接运行 train.py 脚本。超参数、输入数据通道、模型和其他训练产出物存放目录都可以通过命令行参数指定。

python dcgan/train.py --dataset qmnist  --model-dir '/home/myhome/byom-pytorch-gan/model'  --output-dir '/home/myhome/byom-pytorch-gan/tmp'  --data-dir '/home/myhome/byom-pytorch-gan/data'  --hps '{"beta1":0.5,"dataset":"qmnist","epochs":15,"learning-rate":0.0002,"log-interval":64,"nc":1,"nz":100,"sample-interval":100}' 

这样的训练脚本参数设计,既提供了很好的调试方法,又是与 SageMaker Container 集成的规约和必要条件,很好的兼顾了模型开发的自由度和训练环境的可移植性。

模型的训练和验证

请查找并打开名为 dcgan.ipynb 的笔记本文件,训练过程将由这个笔记本介绍并执行,本节内容代码部分从略,请以笔记本代码为准。

互联网环境里有很多公开的数据集,对于机器学习的工程和科研很有帮助,比如算法学习和效果评价。我们将使用 QMNIST 这个手写字体数据集训练模型,最终生成逼真的『手写』字体效果图样。

数据准备

PyTorch 框架的 torchvision.datasets 包提供了QMNIST 数据集,您可以通过如下指令下载 QMNIST 数据集到本地备用。

from torchvision import datasets  dataroot = './data' trainset = datasets.QMNIST(root=dataroot, train=True, download=True) testset = datasets.QMNIST(root=dataroot, train=False, download=True) 

Amazon SageMaker 为您创建了一个默认的 Amazon S3 桶,用来存取机器学习工作流程中可能需要的各种文件和数据。 我们可以通过 SageMaker SDK 中 sagemaker.session.Session 类的 default_bucket 方法获得这个桶的名字。

(编辑:应用网_阳江站长网)

【声明】本站内容均来自网络,其相关言论仅代表作者个人观点,不代表本站立场。若无意侵犯到您的权利,请及时与联系站长删除相关内容!

推荐文章
    热点阅读