From 64c766acd01b6ebdcc246f71ddc5e7e6a163f209 Mon Sep 17 00:00:00 2001
From: LielinJiang
Date: Wed, 16 Sep 2020 10:13:26 +0000
Subject: [PATCH] adapt to 2.0 api

---
 handwritten_number_recognition/mnist.py  | 32 +++++-----
 image_classification/imagenet_dataset.py |  5 +-
 image_classification/main.py             | 41 ++++++-------
 style-transfer/README.md                 | 61 ++++++++++---------
 style-transfer/style-transfer.ipynb      | 75 +++++++++++++-----------
 style-transfer/style_transfer.py         | 46 ++++++++-------
 6 files changed, 134 insertions(+), 126 deletions(-)

diff --git a/handwritten_number_recognition/mnist.py b/handwritten_number_recognition/mnist.py
index a3b77da..a32f041 100644
--- a/handwritten_number_recognition/mnist.py
+++ b/handwritten_number_recognition/mnist.py
@@ -16,38 +16,36 @@ from __future__ import division
 from __future__ import print_function
 
 import argparse
+import paddle
 from paddle import fluid
 from paddle.fluid.optimizer import Momentum
-from paddle.incubate.hapi.datasets.mnist import MNIST as MnistDataset
+from paddle.vision.datasets.mnist import MNIST
 
-from paddle.incubate.hapi.model import Input, set_device
-from paddle.incubate.hapi.loss import CrossEntropy
-from paddle.incubate.hapi.metrics import Accuracy
-from paddle.incubate.hapi.vision.models import LeNet
+from paddle.vision.models import LeNet
+from paddle.static import InputSpec as Input
 
 
 def main():
-    device = set_device(FLAGS.device)
-    fluid.enable_dygraph(device) if FLAGS.dynamic else None
+    device = paddle.set_device(FLAGS.device)
+    paddle.disable_static(device) if FLAGS.dynamic else None
 
-    train_dataset = MnistDataset(mode='train')
-    val_dataset = MnistDataset(mode='test')
+    train_dataset = MNIST(mode='train')
+    val_dataset = MNIST(mode='test')
 
-    inputs = [Input([None, 1, 28, 28], 'float32', name='image')]
-    labels = [Input([None, 1], 'int64', name='label')]
+    inputs = [Input(shape=[None, 1, 28, 28], dtype='float32', name='image')]
+    labels = [Input(shape=[None, 1], dtype='int64', name='label')]
+
+    net = LeNet()
+    model = paddle.Model(net, inputs, labels)
 
-    model = LeNet()
     optim = Momentum(
         learning_rate=FLAGS.lr, momentum=.9, parameter_list=model.parameters())
 
     model.prepare(
         optim,
-        CrossEntropy(),
-        Accuracy(topk=(1, 2)),
-        inputs,
-        labels,
-        device=FLAGS.device)
+        paddle.nn.CrossEntropyLoss(),
+        paddle.metric.Accuracy(topk=(1, 2)))
 
     if FLAGS.resume is not None:
         model.load(FLAGS.resume)
diff --git a/image_classification/imagenet_dataset.py b/image_classification/imagenet_dataset.py
index 27a9009..af5ccca 100644
--- a/image_classification/imagenet_dataset.py
+++ b/image_classification/imagenet_dataset.py
@@ -18,9 +18,8 @@ import math
 import random
 import numpy as np
 
-from paddle.incubate.hapi.datasets import DatasetFolder
-from paddle.incubate.hapi.vision.transforms import transforms
-from paddle import fluid
+from paddle.vision.datasets import DatasetFolder
+from paddle.vision.transforms import transforms
 
 
 class ImageNetDataset(DatasetFolder):
diff --git a/image_classification/main.py b/image_classification/main.py
index ff0c95b..33bc236 100644
--- a/image_classification/main.py
+++ b/image_classification/main.py
@@ -15,25 +15,19 @@ from __future__ import division
 from __future__ import print_function
 
-import argparse
-import contextlib
 import os
-
 import time
-import math
+import argparse
 
 import numpy as np
 
+import paddle
 import paddle.fluid as fluid
-from paddle.fluid.dygraph.parallel import ParallelEnv
-from paddle.io import BatchSampler, DataLoader
-
-from paddle.incubate.hapi.model import Input, set_device
-from paddle.incubate.hapi.loss import CrossEntropy
-from paddle.incubate.hapi.distributed import DistributedBatchSampler
-from paddle.incubate.hapi.metrics import Accuracy
-import paddle.incubate.hapi.vision.models as models
+import paddle.vision.models as models
+from paddle.static import InputSpec as Input
 
 from imagenet_dataset import ImageNetDataset
+from paddle.distributed import ParallelEnv
+from paddle.io import BatchSampler, DataLoader, DistributedBatchSampler
 
 
 def make_optimizer(step_per_epoch, parameter_list=None):
@@ -72,21 +66,23 @@ def make_optimizer(step_per_epoch, parameter_list=None):
 
 
 def main():
-    device = set_device(FLAGS.device)
-    fluid.enable_dygraph(device) if FLAGS.dynamic else None
+    device = paddle.set_device(FLAGS.device)
+    paddle.disable_static(device) if FLAGS.dynamic else None
 
     model_list = [x for x in models.__dict__["__all__"]]
     assert FLAGS.arch in model_list, "Expected FLAGS.arch in {}, but received {}".format(
         model_list, FLAGS.arch)
-    model = models.__dict__[FLAGS.arch](pretrained=FLAGS.eval_only and
-                                        not FLAGS.resume)
-
-    if FLAGS.resume is not None:
-        model.load(FLAGS.resume)
+    net = models.__dict__[FLAGS.arch](pretrained=FLAGS.eval_only and
+                                      not FLAGS.resume)
 
     inputs = [Input([None, 3, 224, 224], 'float32', name='image')]
     labels = [Input([None, 1], 'int64', name='label')]
 
+    model = paddle.Model(net, inputs, labels)
+
+    if FLAGS.resume is not None:
+        model.load(FLAGS.resume)
+
     train_dataset = ImageNetDataset(
         os.path.join(FLAGS.data, 'train'),
         mode='train',
@@ -106,11 +102,8 @@ def main():
 
     model.prepare(
         optim,
-        CrossEntropy(),
-        Accuracy(topk=(1, 5)),
-        inputs,
-        labels,
-        FLAGS.device)
+        paddle.nn.CrossEntropyLoss(),
+        paddle.metric.Accuracy(topk=(1, 5)))
 
     if FLAGS.eval_only:
         model.evaluate(
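All three training scripts above converge on the same 2.0 pattern: a plain `paddle.nn.Layer` is wrapped in `paddle.Model` together with `InputSpec` shape descriptions, and the loss, metrics, and device handling move out of the model and into `paddle.set_device` and `model.prepare(...)`. A minimal runnable sketch of that pattern, assuming Paddle 2.0 is installed (it uses the 2.0-native `paddle.optimizer.Momentum` rather than the fluid optimizer the patch keeps; `run_mnist` and the hyperparameters are illustrative):

```python
import paddle
from paddle.static import InputSpec
from paddle.vision.datasets import MNIST
from paddle.vision.models import LeNet


def run_mnist():
    paddle.set_device('cpu')  # or 'gpu'

    # Static input descriptions replace the old hapi Input class.
    inputs = [InputSpec([None, 1, 28, 28], 'float32', 'image')]
    labels = [InputSpec([None, 1], 'int64', 'label')]

    # paddle.Model wraps a plain Layer and owns fit/evaluate/save/load.
    model = paddle.Model(LeNet(), inputs, labels)

    optim = paddle.optimizer.Momentum(
        learning_rate=0.01, momentum=0.9, parameters=model.parameters())

    # Loss and metric objects now go into prepare(), not the Model itself.
    model.prepare(optim,
                  paddle.nn.CrossEntropyLoss(),
                  paddle.metric.Accuracy(topk=(1, 2)))

    model.fit(MNIST(mode='train'), MNIST(mode='test'),
              epochs=1, batch_size=64)
```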
diff --git a/style-transfer/README.md b/style-transfer/README.md
index 46b2b30..b9d8c17 100644
--- a/style-transfer/README.md
+++ b/style-transfer/README.md
@@ -15,9 +15,9 @@
 ```python
 # tensor shape is [1, c, h, w]
 _, c, h, w = tensor.shape
-tensor = fluid.layers.reshape(tensor, [c, h * w])
+tensor = paddle.reshape(tensor, [c, h * w])
 # gram matrix with shape: [c, c]
-gram_matrix = fluid.layers.matmul(tensor, fluid.layers.transpose(tensor, [1, 0]))
+gram_matrix = paddle.matmul(tensor, paddle.transpose(tensor, [1, 0]))
 ```
 
 In the end, style transfer reduces to optimizing the two Euclidean distances above. Note that we use a vgg16 model pretrained on imagenet with its parameters frozen; the optimizer only updates the values of the generated input image.
@@ -32,12 +32,11 @@ gram_matrix = fluid.layers.matmul(tensor, fluid.layers.transpose(tensor, [1, 0]))
 import numpy as np
 import matplotlib.pyplot as plt
 
-from paddle.incubate.hapi.model import Model, Loss
+import paddle
 
-from paddle.incubate.hapi.vision.models import vgg16
-from paddle.incubate.hapi.vision.transforms import transforms
+from paddle.vision.models import vgg16
+from paddle.vision.transforms import transforms
 from paddle import fluid
-from paddle.fluid.io import Dataset
 
 import cv2
 import copy
@@ -49,7 +48,7 @@ from .style_transfer import load_image, image_restore
 
 ```python
 # Enable dynamic graph mode
-fluid.enable_dygraph()
+paddle.disable_static()
 ```
 
 ```python
@@ -77,22 +76,23 @@ ax2.imshow(image_restore(style))
 
 ```python
 # Define the style transfer model, using vgg16 pretrained on imagenet as the base model
-class StyleTransferModel(Model):
+class StyleTransferModel(paddle.nn.Layer):
     def __init__(self):
         super(StyleTransferModel, self).__init__()
         # With pretrained=True, the imagenet pretrained weights are downloaded and loaded automatically
         vgg = vgg16(pretrained=True)
         self.base_model = vgg.features
+
         for p in self.base_model.parameters():
-            p.stop_gradient=True
+            p.stop_gradient = True
         self.layers = {
-            '0': 'conv1_1',
-            '3': 'conv2_1',
-            '6': 'conv3_1',
-            '10': 'conv4_1',
-            '11': 'conv4_2',  ## content representation
-            '14': 'conv5_1'
-        }
+            '0': 'conv1_1',
+            '5': 'conv2_1',
+            '10': 'conv3_1',
+            '17': 'conv4_1',
+            '19': 'conv4_2',  ## content representation
+            '24': 'conv5_1'
+        }
 
     def forward(self, image):
         outputs = []
@@ -106,27 +106,33 @@ class StyleTransferModel(Model):
 
 ```python
 # Define the style transfer loss function
-class StyleTransferLoss(Loss):
-    def __init__(self, content_loss_weight=1, style_loss_weight=1e5, style_weights=[1.0, 0.8, 0.5, 0.3, 0.1]):
+class StyleTransferLoss(paddle.nn.Layer):
+    def __init__(self,
+                 content_loss_weight=1,
+                 style_loss_weight=1e5,
+                 style_weights=[1.0, 0.8, 0.5, 0.3, 0.1]):
         super(StyleTransferLoss, self).__init__()
         self.content_loss_weight = content_loss_weight
         self.style_loss_weight = style_loss_weight
         self.style_weights = style_weights
 
-    def forward(self, outputs, labels):
+    def forward(self, *features):
+        outputs = features[:6]
+        labels = features[6:]
         content_features = labels[-1]
         style_features = labels[:-1]
 
         # Compute the content-similarity loss
-        content_loss = fluid.layers.mean((outputs[-2] - content_features)**2)
+        content_loss = paddle.mean((outputs[-2] - content_features)**2)
 
         # Compute the style-similarity loss
         style_loss = 0
-        style_grams = [self.gram_matrix(feat) for feat in style_features ]
+        style_grams = [self.gram_matrix(feat) for feat in style_features]
         style_weights = self.style_weights
         for i, weight in enumerate(style_weights):
             target_gram = self.gram_matrix(outputs[i])
-            layer_loss = weight * fluid.layers.mean((target_gram - style_grams[i])**2)
+            layer_loss = weight * paddle.mean((target_gram - style_grams[
+                i])**2)
             b, d, h, w = outputs[i].shape
             style_loss += layer_loss / (d * h * w)
 
@@ -135,9 +141,9 @@ class StyleTransferLoss(Loss):
 
     def gram_matrix(self, A):
         if len(A.shape) == 4:
-            batch_size, c, h, w = A.shape
-            A = fluid.layers.reshape(A, (c, h*w))
-            GA = fluid.layers.matmul(A, fluid.layers.transpose(A, [1, 0]))
+            _, c, h, w = A.shape
+            A = paddle.reshape(A, (c, h * w))
+            GA = paddle.matmul(A, paddle.transpose(A, [1, 0]))
 
         return GA
 ```
@@ -145,7 +151,8 @@
 
 ```python
 # Create the model
-model = StyleTransferModel()
+net = StyleTransferModel()
+model = paddle.Model(net)
 ```
 
 ```python
@@ -157,7 +164,7 @@ style_loss = StyleTransferLoss()
 ```
 
 ```python
 # Initialize the image to be generated from the content image
-target = Model.create_parameter(model, shape=content.shape)
+target = net.create_parameter(shape=content.shape)
 target.set_value(content.numpy())
 ```
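Two details of this README migration are easy to miss. The VGG16 tap indices change ('3' to '5', '11' to '19', and so on) apparently because the 2.0 `vgg.features` is a flat container in which conv, ReLU, and pooling sublayers are each counted separately; the taps still name the classic conv1_1 through conv5_1 layers. The gram-matrix computation itself only changes namespace, from `fluid.layers` to `paddle`. A toy shape check of that step, assuming Paddle 2.0 (sizes are illustrative):

```python
import paddle

feat = paddle.rand([1, 64, 32, 32])      # [n, c, h, w] feature map
_, c, h, w = feat.shape
flat = paddle.reshape(feat, [c, h * w])  # flatten spatial dims: [c, h*w]
gram = paddle.matmul(flat, paddle.transpose(flat, [1, 0]))
print(gram.shape)                        # [64, 64], i.e. [c, c]
```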
diff --git a/style-transfer/style-transfer.ipynb b/style-transfer/style-transfer.ipynb
index eca85f5..4603c1e 100644
--- a/style-transfer/style-transfer.ipynb
+++ b/style-transfer/style-transfer.ipynb
@@ -36,12 +36,11 @@
     "import numpy as np\n",
     "import matplotlib.pyplot as plt\n",
     "\n",
-    "from hapi.model import Model, Loss\n",
+    "import paddle\n",
     "\n",
-    "from hapi.vision.models import vgg16\n",
-    "from hapi.vision.transforms import transforms\n",
+    "from paddle.vision.models import vgg16\n",
+    "from paddle.vision.transforms import transforms\n",
     "from paddle import fluid\n",
-    "from paddle.fluid.io import Dataset\n",
     "\n",
     "import cv2\n",
     "import copy"
@@ -54,7 +53,7 @@
    "outputs": [],
    "source": [
     "# Enable dynamic graph mode\n",
-    "fluid.enable_dygraph()"
+    "paddle.disable_static()"
    ]
   },
   {
@@ -67,9 +66,9 @@
    "```python\n",
    "# tensor shape is [1, c, h, w]\n",
    "_, c, h, w = tensor.shape\n",
-   "tensor = fluid.layers.reshape(tensor, [c, h * w])\n",
+   "tensor = paddle.reshape(tensor, [c, h * w])\n",
    "# gram matrix with shape: [c, c]\n",
-   "gram_matrix = fluid.layers.matmul(tensor, fluid.layers.transpose(tensor, [1, 0]))\n",
+   "gram_matrix = paddle.matmul(tensor, paddle.transpose(tensor, [1, 0]))\n",
    "```\n",
    "\n",
    "In the end, style transfer reduces to optimizing the two Euclidean distances above. Note that we use a vgg16 model pretrained on imagenet with its parameters frozen; the optimizer only updates the values of the generated input image."
@@ -176,23 +175,24 @@
    "outputs": [],
    "source": [
     "# Define the style transfer model, using vgg16 pretrained on imagenet as the base model\n",
-    "class StyleTransferModel(Model):\n",
+    "class StyleTransferModel(paddle.nn.Layer):\n",
     "    def __init__(self):\n",
     "        super(StyleTransferModel, self).__init__()\n",
     "        # With pretrained=True, the imagenet pretrained weights are downloaded and loaded automatically\n",
     "        vgg = vgg16(pretrained=True)\n",
     "        self.base_model = vgg.features\n",
+    "\n",
     "        for p in self.base_model.parameters():\n",
-    "            p.stop_gradient=True\n",
+    "            p.stop_gradient = True\n",
     "        self.layers = {\n",
-    "            '0': 'conv1_1',\n",
-    "            '3': 'conv2_1', \n",
-    "            '6': 'conv3_1', \n",
-    "            '10': 'conv4_1',\n",
-    "            '11': 'conv4_2', ## content representation\n",
-    "            '14': 'conv5_1'\n",
-    "        }\n",
-    "        \n",
+    "            '0': 'conv1_1',\n",
+    "            '5': 'conv2_1',\n",
+    "            '10': 'conv3_1',\n",
+    "            '17': 'conv4_1',\n",
+    "            '19': 'conv4_2', ## content representation\n",
+    "            '24': 'conv5_1'\n",
+    "        }\n",
+    "\n",
     "    def forward(self, image):\n",
     "        outputs = []\n",
     "        for name, layer in self.base_model.named_sublayers():\n",
@@ -208,38 +208,44 @@
   "metadata": {},
   "outputs": [],
   "source": [
-    "class StyleTransferLoss(Loss):\n",
-    "    def __init__(self, content_loss_weight=1, style_loss_weight=1e5, style_weights=[1.0, 0.8, 0.5, 0.3, 0.1]):\n",
+    "class StyleTransferLoss(paddle.nn.Layer):\n",
+    "    def __init__(self,\n",
+    "                 content_loss_weight=1,\n",
+    "                 style_loss_weight=1e5,\n",
+    "                 style_weights=[1.0, 0.8, 0.5, 0.3, 0.1]):\n",
     "        super(StyleTransferLoss, self).__init__()\n",
     "        self.content_loss_weight = content_loss_weight\n",
     "        self.style_loss_weight = style_loss_weight\n",
     "        self.style_weights = style_weights\n",
-    "    \n",
-    "    def forward(self, outputs, labels):\n",
+    "\n",
+    "    def forward(self, *features):\n",
+    "        outputs = features[:6]\n",
+    "        labels = features[6:]\n",
     "        content_features = labels[-1]\n",
     "        style_features = labels[:-1]\n",
-    "        \n",
+    "\n",
    "        # Compute the content-similarity loss\n",
-    "        content_loss = fluid.layers.mean((outputs[-2] - content_features)**2)\n",
-    "        \n",
+    "        content_loss = paddle.mean((outputs[-2] - content_features)**2)\n",
+    "\n",
    "        # Compute the style-similarity loss\n",
     "        style_loss = 0\n",
-    "        style_grams = [self.gram_matrix(feat) for feat in style_features ]\n",
+    "        style_grams = [self.gram_matrix(feat) for feat in style_features]\n",
     "        style_weights = self.style_weights\n",
     "        for i, weight in enumerate(style_weights):\n",
     "            target_gram = self.gram_matrix(outputs[i])\n",
-    "            layer_loss = weight * fluid.layers.mean((target_gram - style_grams[i])**2)\n",
+    "            layer_loss = weight * paddle.mean((target_gram - style_grams[\n",
+    "                i])**2)\n",
     "            b, d, h, w = outputs[i].shape\n",
     "            style_loss += layer_loss / (d * h * w)\n",
-    "        \n",
+    "\n",
     "        total_loss = self.content_loss_weight * content_loss + self.style_loss_weight * style_loss\n",
     "        return total_loss\n",
-    "    \n",
+    "\n",
     "    def gram_matrix(self, A):\n",
     "        if len(A.shape) == 4:\n",
-    "            batch_size, c, h, w = A.shape\n",
-    "            A = fluid.layers.reshape(A, (c, h*w))\n",
-    "            GA = fluid.layers.matmul(A, fluid.layers.transpose(A, [1, 0]))\n",
+    "            _, c, h, w = A.shape\n",
+    "            A = paddle.reshape(A, (c, h * w))\n",
+    "            GA = paddle.matmul(A, paddle.transpose(A, [1, 0]))\n",
     "\n",
     "        return GA"
@@ -260,7 +266,8 @@
    ],
    "source": [
     "# Create the model\n",
-    "model = StyleTransferModel()"
+    "net = StyleTransferModel()\n",
+    "model = paddle.Model(net)"
    ]
   },
   {
@@ -280,7 +287,7 @@
    "outputs": [],
    "source": [
     "# Initialize the image to be generated from the content image\n",
-    "target = Model.create_parameter(model, shape=content.shape)\n",
+    "target = net.create_parameter(shape=content.shape)\n",
     "target.set_value(content.numpy())"
    ]
   },
@@ -586,7 +593,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-   "version": "3.7.6"
+   "version": "3.7.5"
  }
 },
 "nbformat": 4,
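The reworked loss signature deserves a note: hapi's dedicated `Loss` class is gone, so the loss becomes an ordinary `paddle.nn.Layer`, and `paddle.Model` appears to call it with the network outputs and the labels flattened into a single argument list. That is why `forward(self, *features)` slices off the first six entries (the feature taps of the image being optimized) from the remaining six (five precomputed style features plus one content feature). A sketch of a direct call, assuming the `StyleTransferLoss` class above is in scope and using random stand-in tensors:

```python
import paddle

# Six fake "outputs" followed by six fake "labels" (5 style + 1 content
# targets); the [n, c, h, w] shape here is illustrative only.
loss_fn = StyleTransferLoss()
outputs = [paddle.rand([1, 8, 16, 16]) for _ in range(6)]
labels = [paddle.rand([1, 8, 16, 16]) for _ in range(6)]
total = loss_fn(*(outputs + labels))  # forward(*features) sees all 12
print(float(total))
```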
diff --git a/style-transfer/style_transfer.py b/style-transfer/style_transfer.py
index 854ca9c..05f1884 100644
--- a/style-transfer/style_transfer.py
+++ b/style-transfer/style_transfer.py
@@ -3,12 +3,11 @@ import argparse
 import numpy as np
 import matplotlib.pyplot as plt
 
-from paddle.incubate.hapi.model import Model, Loss
+import paddle
 
-from paddle.incubate.hapi.vision.models import vgg16
-from paddle.incubate.hapi.vision.transforms import transforms
+from paddle.vision.models import vgg16
+from paddle.vision.transforms import transforms
 from paddle import fluid
-from paddle.fluid.io import Dataset
 
 import cv2
 import copy
@@ -25,7 +24,7 @@ def load_image(image_path, max_size=400, shape=None):
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
     ])
     image = transform(image)[np.newaxis, :3, :, :]
-    image = fluid.dygraph.to_variable(image)
+    image = paddle.to_tensor(image)
     return image
 
 
@@ -39,21 +38,22 @@ def image_restore(image):
     return image
 
 
-class StyleTransferModel(Model):
+class StyleTransferModel(paddle.nn.Layer):
     def __init__(self):
         super(StyleTransferModel, self).__init__()
         # With pretrained=True, the imagenet pretrained weights are downloaded and loaded automatically
         vgg = vgg16(pretrained=True)
         self.base_model = vgg.features
+
         for p in self.base_model.parameters():
             p.stop_gradient = True
         self.layers = {
             '0': 'conv1_1',
-            '3': 'conv2_1',
-            '6': 'conv3_1',
-            '10': 'conv4_1',
-            '11': 'conv4_2',  ## content representation
-            '14': 'conv5_1'
+            '5': 'conv2_1',
+            '10': 'conv3_1',
+            '17': 'conv4_1',
+            '19': 'conv4_2',  ## content representation
+            '24': 'conv5_1'
         }
 
     def forward(self, image):
@@ -65,7 +65,7 @@ class StyleTransferModel(Model):
         return outputs
 
 
-class StyleTransferLoss(Loss):
+class StyleTransferLoss(paddle.nn.Layer):
     def __init__(self,
                  content_loss_weight=1,
                  style_loss_weight=1e5,
@@ -75,12 +75,14 @@ class StyleTransferLoss(Loss):
         self.style_loss_weight = style_loss_weight
         self.style_weights = style_weights
 
-    def forward(self, outputs, labels):
+    def forward(self, *features):
+        outputs = features[:6]
+        labels = features[6:]
         content_features = labels[-1]
         style_features = labels[:-1]
 
         # Compute the content-similarity loss
-        content_loss = fluid.layers.mean((outputs[-2] - content_features)**2)
+        content_loss = paddle.mean((outputs[-2] - content_features)**2)
 
         # Compute the style-similarity loss
         style_loss = 0
@@ -88,8 +90,8 @@ class StyleTransferLoss(Loss):
         style_weights = self.style_weights
         for i, weight in enumerate(style_weights):
             target_gram = self.gram_matrix(outputs[i])
-            layer_loss = weight * fluid.layers.mean((target_gram - style_grams[
-                i])**2)
+            layer_loss = weight * paddle.mean((target_gram - style_grams[i])**
+                                              2)
             b, d, h, w = outputs[i].shape
             style_loss += layer_loss / (d * h * w)
 
@@ -99,24 +101,26 @@ class StyleTransferLoss(Loss):
     def gram_matrix(self, A):
         if len(A.shape) == 4:
             _, c, h, w = A.shape
-            A = fluid.layers.reshape(A, (c, h * w))
-            GA = fluid.layers.matmul(A, fluid.layers.transpose(A, [1, 0]))
+            A = paddle.reshape(A, (c, h * w))
+            GA = paddle.matmul(A, paddle.transpose(A, [1, 0]))
 
         return GA
 
 
 def main():
     # Enable dynamic graph mode
-    fluid.enable_dygraph()
+    paddle.disable_static()
 
     content = load_image(FLAGS.content_image)
     style = load_image(FLAGS.style_image, shape=tuple(content.shape[-2:]))
 
-    model = StyleTransferModel()
+    net = StyleTransferModel()
+    model = paddle.Model(net)
+
     style_loss = StyleTransferLoss()
 
     # Initialize the image to be generated from the content image
-    target = Model.create_parameter(model, shape=content.shape)
+    target = net.create_parameter(shape=content.shape)
     target.set_value(content.numpy())
 
     optimizer = fluid.optimizer.Adam(
-- 
GitLab
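The patch is cut off inside `main()` at the `fluid.optimizer.Adam(` call, so the rest of the training loop is not shown. For orientation only, one plausible continuation in the style the patch itself uses (fluid's `Adam` with `parameter_list`, manual backward/minimize in dygraph); the learning rate, step count, and feature selection are assumptions, not taken from the patch:

```python
    # Only the generated image is optimized; the vgg16 weights stay frozen.
    optimizer = fluid.optimizer.Adam(
        learning_rate=0.003, parameter_list=[target])

    # Fixed targets: style features of the style image plus the content
    # feature of the content image, following StyleTransferLoss's indexing.
    labels = list(net(style))[:-1] + [net(content)[-2]]

    for step in range(500):
        outputs = net(target)                     # taps of the current image
        loss = style_loss(*(list(outputs) + labels))
        loss.backward()
        optimizer.minimize(loss)
        net.clear_gradients()
```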