diff --git a/python/paddle/fluid/dygraph/io.py b/python/paddle/fluid/dygraph/io.py index 7c17bb07c0c24a15bd5faf93ab1cfafef83b0d6e..1d2ea142c7d5f2e653e446986a39d1bc155006f0 100644 --- a/python/paddle/fluid/dygraph/io.py +++ b/python/paddle/fluid/dygraph/io.py @@ -556,89 +556,92 @@ class TranslatedLayer(layers.Layer): .. code-block:: python import numpy as np - import paddle.fluid as fluid - from paddle.fluid.dygraph import Linear - from paddle.fluid.dygraph import declarative + import paddle + import paddle.nn as nn + import paddle.optimizer as opt - BATCH_SIZE = 32 - BATCH_NUM = 20 + BATCH_SIZE = 16 + BATCH_NUM = 4 + EPOCH_NUM = 4 - def random_batch_reader(): - def _get_random_images_and_labels(image_shape, label_shape): - image = np.random.random(size=image_shape).astype('float32') - label = np.random.random(size=label_shape).astype('int64') - return image, label + IMAGE_SIZE = 784 + CLASS_NUM = 10 + + # define a random dataset + class RandomDataset(paddle.io.Dataset): + def __init__(self, num_samples): + self.num_samples = num_samples - def __reader__(): - for _ in range(BATCH_NUM): - batch_image, batch_label = _get_random_images_and_labels( - [BATCH_SIZE, 784], [BATCH_SIZE, 1]) - yield batch_image, batch_label + def __getitem__(self, idx): + image = np.random.random([IMAGE_SIZE]).astype('float32') + label = np.random.randint(0, CLASS_NUM - 1, (1, )).astype('int64') + return image, label - return __reader__ + def __len__(self): + return self.num_samples - class LinearNet(fluid.dygraph.Layer): - def __init__(self, in_size, out_size): + class LinearNet(nn.Layer): + def __init__(self): super(LinearNet, self).__init__() - self._linear = Linear(in_size, out_size) + self._linear = nn.Linear(IMAGE_SIZE, CLASS_NUM) - @declarative + @paddle.jit.to_static def forward(self, x): return self._linear(x) + def train(layer, loader, loss_fn, opt): + for epoch_id in range(EPOCH_NUM): + for batch_id, (image, label) in enumerate(loader()): + out = layer(image) + loss = loss_fn(out, label) + loss.backward() + opt.step() + opt.clear_grad() + print("Epoch {} batch {}: loss = {}".format( + epoch_id, batch_id, np.mean(loss.numpy()))) + # enable dygraph mode - fluid.enable_dygraph() + place = paddle.CPUPlace() + paddle.disable_static(place) # 1. train & save model. - # create network - net = LinearNet(784, 1) - adam = fluid.optimizer.AdamOptimizer(learning_rate=0.1, parameter_list=net.parameters()) - # create data loader - train_loader = fluid.io.DataLoader.from_generator(capacity=5) - train_loader.set_batch_generator(random_batch_reader()) - # train - for data in train_loader(): - img, label = data - label.stop_gradient = True - cost = net(img) + # create network + layer = LinearNet() + loss_fn = nn.CrossEntropyLoss() + adam = opt.Adam(learning_rate=0.001, parameters=layer.parameters()) - loss = fluid.layers.cross_entropy(cost, label) - avg_loss = fluid.layers.mean(loss) + # create data loader + dataset = RandomDataset(BATCH_NUM * BATCH_SIZE) + loader = paddle.io.DataLoader(dataset, + places=place, + batch_size=BATCH_SIZE, + shuffle=True, + drop_last=True, + num_workers=2) - avg_loss.backward() - adam.minimize(avg_loss) - net.clear_gradients() + # train + train(layer, loader, loss_fn, adam) + # save model_path = "linear.example.model" - fluid.dygraph.jit.save( - layer=net, - model_path=model_path, - input_spec=[img]) + paddle.jit.save(layer, model_path) # 2. load model as TranslatedLayer - translated_layer = fluid.dygraph.jit.load(model_path) + + # load + translated_layer = paddle.jit.load(model_path) + # inference translated_layer.eval() - x = fluid.dygraph.to_variable(np.random.random((1, 784)).astype('float32')) + x = paddle.randn([1, IMAGE_SIZE], 'float32') pred = translated_layer(x) + # fine-tune translated_layer.train() - adam = fluid.optimizer.AdamOptimizer(learning_rate=0.1, parameter_list=translated_layer.parameters()) - train_loader = fluid.io.DataLoader.from_generator(capacity=5) - train_loader.set_batch_generator(random_batch_reader()) - for data in train_loader(): - img, label = data - label.stop_gradient = True - - cost = translated_layer(img) + adam = opt.Adam(learning_rate=0.001, parameters=translated_layer.parameters()) + train(translated_layer, loader, loss_fn, adam) - loss = fluid.layers.cross_entropy(cost, label) - avg_loss = fluid.layers.mean(loss) - - avg_loss.backward() - adam.minimize(avg_loss) - translated_layer.clear_gradients() """ def __init__(self, programs, persistable_vars): @@ -814,3 +817,107 @@ class TranslatedLayer(layers.Layer): def eval(self): self._is_test = True + + def program(self, method_name='forward'): + """ + Gets translated program of specified method. + + Args: + - method_name (string): mehtod name corresponding to the program + to be obtained. Default: 'forward'. + + Returns: + Program + + Examples: + .. code-block:: python + + import numpy as np + import paddle + import paddle.nn as nn + import paddle.optimizer as opt + + BATCH_SIZE = 16 + BATCH_NUM = 4 + EPOCH_NUM = 4 + + IMAGE_SIZE = 784 + CLASS_NUM = 10 + + # define a random dataset + class RandomDataset(paddle.io.Dataset): + def __init__(self, num_samples): + self.num_samples = num_samples + + def __getitem__(self, idx): + image = np.random.random([IMAGE_SIZE]).astype('float32') + label = np.random.randint(0, CLASS_NUM - 1, (1, )).astype('int64') + return image, label + + def __len__(self): + return self.num_samples + + class LinearNet(nn.Layer): + def __init__(self): + super(LinearNet, self).__init__() + self._linear = nn.Linear(IMAGE_SIZE, CLASS_NUM) + + @paddle.jit.to_static + def forward(self, x): + return self._linear(x) + + def train(layer, loader, loss_fn, opt): + for epoch_id in range(EPOCH_NUM): + for batch_id, (image, label) in enumerate(loader()): + out = layer(image) + loss = loss_fn(out, label) + loss.backward() + opt.step() + opt.clear_grad() + print("Epoch {} batch {}: loss = {}".format( + epoch_id, batch_id, np.mean(loss.numpy()))) + + # enable dygraph mode + place = paddle.CPUPlace() + paddle.disable_static(place) + + # create network + layer = LinearNet() + loss_fn = nn.CrossEntropyLoss() + adam = opt.Adam(learning_rate=0.001, parameters=layer.parameters()) + + # create data loader + dataset = RandomDataset(BATCH_NUM * BATCH_SIZE) + loader = paddle.io.DataLoader(dataset, + places=place, + batch_size=BATCH_SIZE, + shuffle=True, + drop_last=True, + num_workers=2) + + # train + train(layer, loader, loss_fn, adam) + + # save + model_path = "linear.example.model" + paddle.jit.save(layer, model_path) + + # load + translated_layer = paddle.jit.load(model_path) + + # get program + program = translated_layer.program() + """ + # 1. get program holder + program_holder = self._program_holder_dict.get(method_name, None) + if program_holder is None: + raise ValueError( + "The method `%s` is not exists in loaded TranslatedLayer." % + method_name) + + # 2. get inference program desc + program_desc = program_holder.infer_program + + # 3. construct program + program = _build_program_by_desc(program_desc) + return program diff --git a/python/paddle/fluid/tests/unittests/test_translated_layer.py b/python/paddle/fluid/tests/unittests/test_translated_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..20c51b9afbafac9ba1fa032aea446383bc2b9796 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_translated_layer.py @@ -0,0 +1,157 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import paddle +import paddle.nn as nn +import paddle.optimizer as opt + +BATCH_SIZE = 16 +BATCH_NUM = 4 +EPOCH_NUM = 4 +SEED = 10 + +IMAGE_SIZE = 784 +CLASS_NUM = 10 + + +# define a random dataset +class RandomDataset(paddle.io.Dataset): + def __init__(self, num_samples): + self.num_samples = num_samples + + def __getitem__(self, idx): + np.random.seed(SEED) + image = np.random.random([IMAGE_SIZE]).astype('float32') + label = np.random.randint(0, CLASS_NUM - 1, (1, )).astype('int64') + return image, label + + def __len__(self): + return self.num_samples + + +class LinearNet(nn.Layer): + def __init__(self): + super(LinearNet, self).__init__() + self._linear = nn.Linear(IMAGE_SIZE, CLASS_NUM) + + @paddle.jit.to_static + def forward(self, x): + return self._linear(x) + + +def train(layer, loader, loss_fn, opt): + for epoch_id in range(EPOCH_NUM): + for batch_id, (image, label) in enumerate(loader()): + out = layer(image) + loss = loss_fn(out, label) + loss.backward() + opt.step() + opt.clear_grad() + print("Epoch {} batch {}: loss = {}".format(epoch_id, batch_id, + np.mean(loss.numpy()))) + return loss + + +class TestTranslatedLayer(unittest.TestCase): + def setUp(self): + # enable dygraph mode + place = paddle.CPUPlace() + paddle.disable_static(place) + + # config seed + paddle.manual_seed(SEED) + paddle.framework.random._manual_program_seed(SEED) + + # create network + self.layer = LinearNet() + self.loss_fn = nn.CrossEntropyLoss() + self.sgd = opt.SGD(learning_rate=0.001, + parameters=self.layer.parameters()) + + # create data loader + dataset = RandomDataset(BATCH_NUM * BATCH_SIZE) + self.loader = paddle.io.DataLoader( + dataset, + places=place, + batch_size=BATCH_SIZE, + shuffle=True, + drop_last=True, + num_workers=2) + + # train + train(self.layer, self.loader, self.loss_fn, self.sgd) + + # save + self.model_path = "linear.example.model" + paddle.jit.save(self.layer, self.model_path) + + def test_inference_and_fine_tuning(self): + self.load_and_inference() + self.load_and_fine_tuning() + + def load_and_inference(self): + # load + translated_layer = paddle.jit.load(self.model_path) + + # inference + x = paddle.randn([1, IMAGE_SIZE], 'float32') + + self.layer.eval() + orig_pred = self.layer(x) + + translated_layer.eval() + pred = translated_layer(x) + + self.assertTrue(np.array_equal(orig_pred.numpy(), pred.numpy())) + + def load_and_fine_tuning(self): + # load + translated_layer = paddle.jit.load(self.model_path) + + # train original layer continue + self.layer.train() + orig_loss = train(self.layer, self.loader, self.loss_fn, self.sgd) + + # fine-tuning + translated_layer.train() + sgd = opt.SGD(learning_rate=0.001, + parameters=translated_layer.parameters()) + loss = train(translated_layer, self.loader, self.loss_fn, sgd) + + self.assertTrue( + np.array_equal(orig_loss.numpy(), loss.numpy()), + msg="original loss:\n{}\nnew loss:\n{}\n".format(orig_loss.numpy(), + loss.numpy())) + + def test_get_program(self): + # load + translated_layer = paddle.jit.load(self.model_path) + + program = translated_layer.program() + self.assertTrue(isinstance(program, paddle.static.Program)) + + def test_get_program_method_not_exists(self): + # load + translated_layer = paddle.jit.load(self.model_path) + + with self.assertRaises(ValueError): + program = translated_layer.program('not_exists') + + +if __name__ == '__main__': + unittest.main()