From c1a886873258d8d461567cefd60348ba01aba27e Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Mon, 7 Sep 2020 19:51:53 +0800 Subject: [PATCH] Change jit.save/load configs to config & update code examples (#27056) * change configs to config & update examples * fix deprecate decorator conflict --- python/paddle/fluid/dygraph/checkpoint.py | 16 +- python/paddle/fluid/dygraph/jit.py | 557 +++++++++++----------- 2 files changed, 298 insertions(+), 275 deletions(-) diff --git a/python/paddle/fluid/dygraph/checkpoint.py b/python/paddle/fluid/dygraph/checkpoint.py index f85b184f68..30ded1f7ed 100644 --- a/python/paddle/fluid/dygraph/checkpoint.py +++ b/python/paddle/fluid/dygraph/checkpoint.py @@ -24,7 +24,7 @@ from . import learning_rate_scheduler import warnings from .. import core from .base import guard -from paddle.fluid.dygraph.jit import SaveLoadConfig +from paddle.fluid.dygraph.jit import SaveLoadConfig, deprecate_save_load_configs from paddle.fluid.dygraph.io import _construct_program_holders, _construct_params_and_buffers __all__ = [ @@ -42,9 +42,9 @@ def deprecate_keep_name_table(func): warnings.warn( "The argument `keep_name_table` has deprecated, please use `SaveLoadConfig.keep_name_table`.", DeprecationWarning) - configs = SaveLoadConfig() - configs.keep_name_table = keep_name_table - return configs + config = SaveLoadConfig() + config.keep_name_table = keep_name_table + return config # deal with arg `keep_name_table` if len(args) > 1 and isinstance(args[1], bool): @@ -52,7 +52,7 @@ def deprecate_keep_name_table(func): args[1] = __warn_and_build_configs__(args[1]) # deal with kwargs elif 'keep_name_table' in kwargs: - kwargs['configs'] = __warn_and_build_configs__(kwargs[ + kwargs['config'] = __warn_and_build_configs__(kwargs[ 'keep_name_table']) kwargs.pop('keep_name_table') else: @@ -135,8 +135,9 @@ def save_dygraph(state_dict, model_path): # TODO(qingqing01): remove dygraph_only to support loading static model. # maybe need to unify the loading interface after 2.0 API is ready. # @dygraph_only +@deprecate_save_load_configs @deprecate_keep_name_table -def load_dygraph(model_path, configs=None): +def load_dygraph(model_path, config=None): ''' :api_attr: imperative @@ -151,7 +152,7 @@ def load_dygraph(model_path, configs=None): Args: model_path(str) : The file prefix store the state_dict. (The path should Not contain suffix '.pdparams') - configs (SaveLoadConfig, optional): :ref:`api_imperative_jit_saveLoadConfig` + config (SaveLoadConfig, optional): :ref:`api_imperative_jit_saveLoadConfig` object that specifies additional configuration options, these options are for compatibility with ``jit.save/io.save_inference_model`` formats. Default None. @@ -195,6 +196,7 @@ def load_dygraph(model_path, configs=None): opti_file_path = model_prefix + ".pdopt" # deal with argument `configs` + configs = config if configs is None: configs = SaveLoadConfig() diff --git a/python/paddle/fluid/dygraph/jit.py b/python/paddle/fluid/dygraph/jit.py index 9f4ec2b55b..d520fe6188 100644 --- a/python/paddle/fluid/dygraph/jit.py +++ b/python/paddle/fluid/dygraph/jit.py @@ -17,6 +17,7 @@ from __future__ import print_function import os import pickle import warnings +import functools import six import paddle @@ -228,63 +229,60 @@ class SaveLoadConfig(object): .. code-block:: python - import numpy as np - import paddle.fluid as fluid - from paddle.fluid.dygraph import Linear - from paddle.fluid.dygraph import declarative + import paddle + import paddle.nn as nn + import paddle.optimizer as opt - class SimpleNet(fluid.dygraph.Layer): + class SimpleNet(nn.Layer): def __init__(self, in_size, out_size): super(SimpleNet, self).__init__() - self._linear = Linear(in_size, out_size) + self._linear = nn.Linear(in_size, out_size) - @declarative + @paddle.jit.to_static def forward(self, x): y = self._linear(x) z = self._linear(y) return z # enable dygraph mode - fluid.enable_dygraph() + paddle.disable_static() # train model net = SimpleNet(8, 8) - adam = fluid.optimizer.AdamOptimizer(learning_rate=0.1, parameter_list=net.parameters()) - x = fluid.dygraph.to_variable(np.random.random((4, 8)).astype('float32')) + adam = opt.Adam(learning_rate=0.1, parameters=net.parameters()) + x = paddle.randn([4, 8], 'float32') for i in range(10): out = net(x) - loss = fluid.layers.mean(out) + loss = paddle.tensor.mean(out) loss.backward() - adam.minimize(loss) - net.clear_gradients() + adam.step() + adam.clear_grad() # use SaveLoadconfig when saving model model_path = "simplenet.example.model" - configs = fluid.dygraph.jit.SaveLoadConfig() - configs.model_filename = "__simplenet__" - fluid.dygraph.jit.save( + config = paddle.SaveLoadConfig() + config.model_filename = "__simplenet__" + paddle.jit.save( layer=net, model_path=model_path, - input_spec=[x], - configs=configs) + config=config) 2. Using ``SaveLoadConfig`` when loading model .. code-block:: python - import numpy as np - import paddle.fluid as fluid + import paddle # enable dygraph mode - fluid.enable_dygraph() + paddle.disable_static() # use SaveLoadconfig when loading model model_path = "simplenet.example.model" - configs = fluid.dygraph.jit.SaveLoadConfig() - configs.model_filename = "__simplenet__" - infer_net = fluid.dygraph.jit.load(model_path, configs=configs) + config = paddle.SaveLoadConfig() + config.model_filename = "__simplenet__" + infer_net = paddle.jit.load(model_path, config=config) # inference - x = fluid.dygraph.to_variable(np.random.random((4, 8)).astype('float32')) + x = paddle.randn([4, 8], 'float32') pred = infer_net(x) """ @@ -324,51 +322,46 @@ class SaveLoadConfig(object): Examples: .. code-block:: python - import numpy as np - import paddle.fluid as fluid - from paddle.fluid.dygraph import Linear - from paddle.fluid.dygraph import declarative + import paddle + import paddle.nn as nn + import paddle.optimizer as opt - class SimpleNet(fluid.dygraph.Layer): + class SimpleNet(nn.Layer): def __init__(self, in_size, out_size): super(SimpleNet, self).__init__() - self._linear = Linear(in_size, out_size) + self._linear = nn.Linear(in_size, out_size) - @declarative + @paddle.jit.to_static def forward(self, x): y = self._linear(x) z = self._linear(y) - loss = fluid.layers.mean(z) + loss = paddle.tensor.mean(z) return z, loss # enable dygraph mode - fluid.enable_dygraph() + paddle.disable_static() # train model net = SimpleNet(8, 8) - adam = fluid.optimizer.AdamOptimizer(learning_rate=0.1, parameter_list=net.parameters()) - x = fluid.dygraph.to_variable(np.random.random((4, 8)).astype('float32')) + adam = opt.Adam(learning_rate=0.1, parameters=net.parameters()) + x = paddle.randn([4, 8], 'float32') for i in range(10): out, loss = net(x) loss.backward() - adam.minimize(loss) - net.clear_gradients() + adam.step() + adam.clear_grad() # use SaveLoadconfig.output_spec model_path = "simplenet.example.model.output_spec" - configs = fluid.dygraph.jit.SaveLoadConfig() - # only keep the predicted output in saved model, discard loss - configs.output_spec = [out] - - fluid.dygraph.jit.save( + config = paddle.SaveLoadConfig() + config.output_spec = [out] + paddle.jit.save( layer=net, model_path=model_path, - input_spec=[x], - configs=configs) + config=config) - infer_net = fluid.dygraph.jit.load(model_path, configs=configs) - x = fluid.dygraph.to_variable(np.random.random((4, 8)).astype('float32')) - # only have the predicted output + infer_net = paddle.jit.load(model_path) + x = paddle.randn([4, 8], 'float32') pred = infer_net(x) """ return self._output_spec @@ -395,52 +388,47 @@ class SaveLoadConfig(object): Examples: .. code-block:: python - import numpy as np - import paddle.fluid as fluid - from paddle.fluid.dygraph import Linear - from paddle.fluid.dygraph import declarative + import paddle + import paddle.nn as nn + import paddle.optimizer as opt - class SimpleNet(fluid.dygraph.Layer): + class SimpleNet(nn.Layer): def __init__(self, in_size, out_size): super(SimpleNet, self).__init__() - self._linear = Linear(in_size, out_size) + self._linear = nn.Linear(in_size, out_size) - @declarative + @paddle.jit.to_static def forward(self, x): y = self._linear(x) z = self._linear(y) return z # enable dygraph mode - fluid.enable_dygraph() + paddle.disable_static() # train model net = SimpleNet(8, 8) - adam = fluid.optimizer.AdamOptimizer(learning_rate=0.1, parameter_list=net.parameters()) - x = fluid.dygraph.to_variable(np.random.random((4, 8)).astype('float32')) + adam = opt.Adam(learning_rate=0.1, parameters=net.parameters()) + x = paddle.randn([4, 8], 'float32') for i in range(10): out = net(x) - loss = fluid.layers.mean(out) + loss = paddle.tensor.mean(out) loss.backward() - adam.minimize(loss) - net.clear_gradients() - - model_path = "simplenet.example.model.model_filename" - configs = fluid.dygraph.jit.SaveLoadConfig() - configs.model_filename = "__simplenet__" + adam.step() + adam.clear_grad() # saving with configs.model_filename - fluid.dygraph.jit.save( + model_path = "simplenet.example.model.model_filename" + config = paddle.SaveLoadConfig() + config.model_filename = "__simplenet__" + paddle.jit.save( layer=net, model_path=model_path, - input_spec=[x], - configs=configs) - # [result] the saved model directory contains: - # __simplenet__ __variables__ __variables.info__ + config=config) # loading with configs.model_filename - infer_net = fluid.dygraph.jit.load(model_path, configs=configs) - x = fluid.dygraph.to_variable(np.random.random((4, 8)).astype('float32')) + infer_net = paddle.jit.load(model_path, config=config) + x = paddle.randn([4, 8], 'float32') pred = infer_net(x) """ return self._model_filename @@ -465,52 +453,48 @@ class SaveLoadConfig(object): Examples: .. code-block:: python - import numpy as np - import paddle.fluid as fluid - from paddle.fluid.dygraph import Linear - from paddle.fluid.dygraph import declarative + import paddle + import paddle.nn as nn + import paddle.optimizer as opt - class SimpleNet(fluid.dygraph.Layer): + class SimpleNet(nn.Layer): def __init__(self, in_size, out_size): super(SimpleNet, self).__init__() - self._linear = Linear(in_size, out_size) + self._linear = nn.Linear(in_size, out_size) - @declarative + @paddle.jit.to_static def forward(self, x): y = self._linear(x) z = self._linear(y) return z # enable dygraph mode - fluid.enable_dygraph() + paddle.disable_static() # train model net = SimpleNet(8, 8) - adam = fluid.optimizer.AdamOptimizer(learning_rate=0.1, parameter_list=net.parameters()) - x = fluid.dygraph.to_variable(np.random.random((4, 8)).astype('float32')) + adam = opt.Adam(learning_rate=0.1, parameters=net.parameters()) + x = paddle.randn([4, 8], 'float32') for i in range(10): out = net(x) - loss = fluid.layers.mean(out) + loss = paddle.tensor.mean(out) loss.backward() - adam.minimize(loss) - net.clear_gradients() + adam.step() + adam.clear_grad() model_path = "simplenet.example.model.params_filename" - configs = fluid.dygraph.jit.SaveLoadConfig() - configs.params_filename = "__params__" + config = paddle.SaveLoadConfig() + config.params_filename = "__params__" # saving with configs.params_filename - fluid.dygraph.jit.save( + paddle.jit.save( layer=net, model_path=model_path, - input_spec=[x], - configs=configs) - # [result] the saved model directory contains: - # __model__ __params__ __variables.info__ + config=config) # loading with configs.params_filename - infer_net = fluid.dygraph.jit.load(model_path, configs=configs) - x = fluid.dygraph.to_variable(np.random.random((4, 8)).astype('float32')) + infer_net = paddle.jit.load(model_path, config=config) + x = paddle.randn([4, 8], 'float32') pred = infer_net(x) """ return self._params_filename @@ -544,52 +528,50 @@ class SaveLoadConfig(object): Examples: .. code-block:: python - import numpy as np - import paddle.fluid as fluid - from paddle.fluid.dygraph import Linear - from paddle.fluid.dygraph import declarative + import paddle + import paddle.nn as nn + import paddle.optimizer as opt - class SimpleNet(fluid.dygraph.Layer): + class SimpleNet(nn.Layer): def __init__(self, in_size, out_size): super(SimpleNet, self).__init__() - self._linear = Linear(in_size, out_size) + self._linear = nn.Linear(in_size, out_size) - @declarative + @paddle.jit.to_static def forward(self, x): y = self._linear(x) z = self._linear(y) return z # enable dygraph mode - fluid.enable_dygraph() + paddle.disable_static() # train model net = SimpleNet(8, 8) - adam = fluid.optimizer.AdamOptimizer(learning_rate=0.1, parameter_list=net.parameters()) - x = fluid.dygraph.to_variable(np.random.random((4, 8)).astype('float32')) + adam = opt.Adam(learning_rate=0.1, parameters=net.parameters()) + x = paddle.randn([4, 8], 'float32') for i in range(10): out = net(x) - loss = fluid.layers.mean(out) + loss = paddle.tensor.mean(out) loss.backward() - adam.minimize(loss) - net.clear_gradients() + adam.step() + adam.clear_grad() model_path = "simplenet.example.model.separate_params" - configs = fluid.dygraph.jit.SaveLoadConfig() - configs.separate_params = True + config = paddle.jit.SaveLoadConfig() + config.separate_params = True # saving with configs.separate_params - fluid.dygraph.jit.save( + paddle.jit.save( layer=net, model_path=model_path, - input_spec=[x], - configs=configs) + config=config) # [result] the saved model directory contains: # linear_0.b_0 linear_0.w_0 __model__ __variables.info__ # loading with configs.params_filename - infer_net = fluid.dygraph.jit.load(model_path, configs=configs) - x = fluid.dygraph.to_variable(np.random.random((4, 8)).astype('float32')) + infer_net = paddle.jit.load(model_path, config=config) + x = paddle.randn([4, 8], 'float32') pred = infer_net(x) """ return self._separate_params @@ -651,8 +633,21 @@ class SaveLoadConfig(object): self._keep_name_table = value +# NOTE(chenweihang): change jit.save/load argument `configs` to `config` +def deprecate_save_load_configs(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + if 'configs' in kwargs: + kwargs['config'] = kwargs['configs'] + kwargs.pop('configs') + return func(*args, **kwargs) + + return wrapper + + +@deprecate_save_load_configs @switch_to_static_graph -def save(layer, model_path, input_spec=None, configs=None): +def save(layer, model_path, input_spec=None, config=None): """ Saves input declarative Layer as :ref:`api_imperative_TranslatedLayer` format model, which can be used for inference or fine-tuning after loading. @@ -677,7 +672,7 @@ def save(layer, model_path, input_spec=None, configs=None): It is the example inputs that will be passed to saved TranslatedLayer's forward function. If None, all input variables of the original Layer's forward function would be the inputs of the saved model. Default None. - configs (SaveLoadConfig, optional): :ref:`api_imperative_jit_saveLoadConfig` object + config (SaveLoadConfig, optional): :ref:`api_imperative_jit_saveLoadConfig` object that specifies additional configuration options. Default None. Returns: None @@ -686,65 +681,76 @@ def save(layer, model_path, input_spec=None, configs=None): .. code-block:: python import numpy as np - import paddle.fluid as fluid - from paddle.fluid.dygraph import Linear - from paddle.fluid.dygraph import declarative + import paddle + import paddle.nn as nn + import paddle.optimizer as opt - BATCH_SIZE = 32 - BATCH_NUM = 20 + BATCH_SIZE = 16 + BATCH_NUM = 4 + EPOCH_NUM = 4 - def random_batch_reader(): - def _get_random_images_and_labels(image_shape, label_shape): - image = np.random.random(size=image_shape).astype('float32') - label = np.random.random(size=label_shape).astype('int64') - return image, label + IMAGE_SIZE = 784 + CLASS_NUM = 10 - def __reader__(): - for _ in range(BATCH_NUM): - batch_image, batch_label = _get_random_images_and_labels( - [BATCH_SIZE, 784], [BATCH_SIZE, 1]) - yield batch_image, batch_label + # define a random dataset + class RandomDataset(paddle.io.Dataset): + def __init__(self, num_samples): + self.num_samples = num_samples - return __reader__ + def __getitem__(self, idx): + image = np.random.random([IMAGE_SIZE]).astype('float32') + label = np.random.randint(0, CLASS_NUM - 1, (1, )).astype('int64') + return image, label - class LinearNet(fluid.dygraph.Layer): - def __init__(self, in_size, out_size): + def __len__(self): + return self.num_samples + + class LinearNet(nn.Layer): + def __init__(self): super(LinearNet, self).__init__() - self._linear = Linear(in_size, out_size) + self._linear = nn.Linear(IMAGE_SIZE, CLASS_NUM) - @declarative + @paddle.jit.to_static def forward(self, x): return self._linear(x) + def train(layer, loader, loss_fn, opt): + for epoch_id in range(EPOCH_NUM): + for batch_id, (image, label) in enumerate(loader()): + out = layer(image) + loss = loss_fn(out, label) + loss.backward() + opt.step() + opt.clear_grad() + print("Epoch {} batch {}: loss = {}".format( + epoch_id, batch_id, np.mean(loss.numpy()))) + # enable dygraph mode - fluid.enable_dygraph() + place = paddle.CPUPlace() + paddle.disable_static(place) - # create network - net = LinearNet(784, 1) - adam = fluid.optimizer.AdamOptimizer(learning_rate=0.1, parameter_list=net.parameters()) - # create data loader - train_loader = fluid.io.DataLoader.from_generator(capacity=5) - train_loader.set_batch_generator(random_batch_reader()) - # train - for data in train_loader(): - img, label = data - label.stop_gradient = True + # 1. train & save model. - cost = net(img) + # create network + layer = LinearNet() + loss_fn = nn.CrossEntropyLoss() + adam = opt.Adam(learning_rate=0.001, parameters=layer.parameters()) - loss = fluid.layers.cross_entropy(cost, label) - avg_loss = fluid.layers.mean(loss) + # create data loader + dataset = RandomDataset(BATCH_NUM * BATCH_SIZE) + loader = paddle.io.DataLoader(dataset, + places=place, + batch_size=BATCH_SIZE, + shuffle=True, + drop_last=True, + num_workers=2) - avg_loss.backward() - adam.minimize(avg_loss) - net.clear_gradients() + # train + train(layer, loader, loss_fn, adam) - # save model + # save model_path = "linear.example.model" - fluid.dygraph.jit.save( - layer=net, - model_path=model_path, - input_spec=[img]) + paddle.jit.save(layer, model_path) """ def get_inout_spec(all_vars, target_vars, return_name=False): @@ -778,6 +784,7 @@ def save(layer, model_path, input_spec=None, configs=None): "The input layer of paddle.jit.save should be 'Layer', but received layer type is %s." % type(layer)) + configs = config if configs is None: configs = SaveLoadConfig() @@ -869,8 +876,9 @@ def save(layer, model_path, input_spec=None, configs=None): pickle.dump(extra_var_info, f, protocol=2) +@deprecate_save_load_configs @dygraph_only -def load(model_path, configs=None): +def load(model_path, config=None): """ :api_attr: imperative @@ -887,7 +895,7 @@ def load(model_path, configs=None): Args: model_path (str): The directory path where the model is saved. - configs (SaveLoadConfig, optional): :ref:`api_imperative_jit_saveLoadConfig` object that specifies + config (SaveLoadConfig, optional): :ref:`api_imperative_jit_saveLoadConfig` object that specifies additional configuration options. Default None. Returns: @@ -899,122 +907,126 @@ def load(model_path, configs=None): .. code-block:: python import numpy as np - import paddle.fluid as fluid - from paddle.fluid.dygraph import Linear - from paddle.fluid.dygraph import declarative + import paddle + import paddle.nn as nn + import paddle.optimizer as opt - BATCH_SIZE = 32 - BATCH_NUM = 20 + BATCH_SIZE = 16 + BATCH_NUM = 4 + EPOCH_NUM = 4 - def random_batch_reader(): - def _get_random_images_and_labels(image_shape, label_shape): - image = np.random.random(size=image_shape).astype('float32') - label = np.random.random(size=label_shape).astype('int64') - return image, label + IMAGE_SIZE = 784 + CLASS_NUM = 10 + + # define a random dataset + class RandomDataset(paddle.io.Dataset): + def __init__(self, num_samples): + self.num_samples = num_samples - def __reader__(): - for _ in range(BATCH_NUM): - batch_image, batch_label = _get_random_images_and_labels( - [BATCH_SIZE, 784], [BATCH_SIZE, 1]) - yield batch_image, batch_label + def __getitem__(self, idx): + image = np.random.random([IMAGE_SIZE]).astype('float32') + label = np.random.randint(0, CLASS_NUM - 1, (1, )).astype('int64') + return image, label - return __reader__ + def __len__(self): + return self.num_samples - class LinearNet(fluid.dygraph.Layer): - def __init__(self, in_size, out_size): + class LinearNet(nn.Layer): + def __init__(self): super(LinearNet, self).__init__() - self._linear = Linear(in_size, out_size) + self._linear = nn.Linear(IMAGE_SIZE, CLASS_NUM) - @declarative + @paddle.jit.to_static def forward(self, x): return self._linear(x) + def train(layer, loader, loss_fn, opt): + for epoch_id in range(EPOCH_NUM): + for batch_id, (image, label) in enumerate(loader()): + out = layer(image) + loss = loss_fn(out, label) + loss.backward() + opt.step() + opt.clear_grad() + print("Epoch {} batch {}: loss = {}".format( + epoch_id, batch_id, np.mean(loss.numpy()))) + # enable dygraph mode - fluid.enable_dygraph() + place = paddle.CPUPlace() + paddle.disable_static(place) # 1. train & save model. + # create network - net = LinearNet(784, 1) - adam = fluid.optimizer.AdamOptimizer(learning_rate=0.1, parameter_list=net.parameters()) + layer = LinearNet() + loss_fn = nn.CrossEntropyLoss() + adam = opt.Adam(learning_rate=0.001, parameters=layer.parameters()) + # create data loader - train_loader = fluid.io.DataLoader.from_generator(capacity=5) - train_loader.set_batch_generator(random_batch_reader()) - # train - for data in train_loader(): - img, label = data - label.stop_gradient = True + dataset = RandomDataset(BATCH_NUM * BATCH_SIZE) + loader = paddle.io.DataLoader(dataset, + places=place, + batch_size=BATCH_SIZE, + shuffle=True, + drop_last=True, + num_workers=2) - cost = net(img) + # train + train(layer, loader, loss_fn, adam) - loss = fluid.layers.cross_entropy(cost, label) - avg_loss = fluid.layers.mean(loss) + # save + model_path = "linear.example.model" + paddle.jit.save(layer, model_path) - avg_loss.backward() - adam.minimize(avg_loss) - net.clear_gradients() + # 2. load model - model_path = "linear.example.model" - fluid.dygraph.jit.save( - layer=net, - model_path=model_path, - input_spec=[img]) + # load + loaded_layer = paddle.jit.load(model_path) - # 2. load model & inference - # load model - infer_net = fluid.dygraph.jit.load(model_path) # inference - x = fluid.dygraph.to_variable(np.random.random((1, 784)).astype('float32')) - pred = infer_net(x) + loaded_layer.eval() + x = paddle.randn([1, IMAGE_SIZE], 'float32') + pred = loaded_layer(x) - # 3. load model & fine-tune - # load model - train_net = fluid.dygraph.jit.load(model_path) - train_net.train() - adam = fluid.optimizer.AdamOptimizer(learning_rate=0.1, parameter_list=train_net.parameters()) - # create data loader - train_loader = fluid.io.DataLoader.from_generator(capacity=5) - train_loader.set_batch_generator(random_batch_reader()) # fine-tune - for data in train_loader(): - img, label = data - label.stop_gradient = True - - cost = train_net(img) + loaded_layer.train() + adam = opt.Adam(learning_rate=0.001, parameters=loaded_layer.parameters()) + train(loaded_layer, loader, loss_fn, adam) - loss = fluid.layers.cross_entropy(cost, label) - avg_loss = fluid.layers.mean(loss) - - avg_loss.backward() - adam.minimize(avg_loss) - train_net.clear_gradients() 2. Load model saved by :ref:`api_fluid_io_save_inference_model` then performing and fine-tune training. .. code-block:: python import numpy as np + import paddle import paddle.fluid as fluid + import paddle.nn as nn + import paddle.optimizer as opt - BATCH_SIZE = 32 - BATCH_NUM = 20 + BATCH_SIZE = 16 + BATCH_NUM = 4 + EPOCH_NUM = 4 - def random_batch_reader(): - def _get_random_images_and_labels(image_shape, label_shape): - image = np.random.random(size=image_shape).astype('float32') - label = np.random.random(size=label_shape).astype('int64') - return image, label + IMAGE_SIZE = 784 + CLASS_NUM = 10 - def __reader__(): - for _ in range(BATCH_NUM): - batch_image, batch_label = _get_random_images_and_labels( - [BATCH_SIZE, 784], [BATCH_SIZE, 1]) - yield batch_image, batch_label + # define a random dataset + class RandomDataset(paddle.io.Dataset): + def __init__(self, num_samples): + self.num_samples = num_samples - return __reader__ + def __getitem__(self, idx): + image = np.random.random([IMAGE_SIZE]).astype('float32') + label = np.random.randint(0, CLASS_NUM - 1, (1, )).astype('int64') + return image, label + + def __len__(self): + return self.num_samples - img = fluid.data(name='img', shape=[None, 784], dtype='float32') + image = fluid.data(name='image', shape=[None, 784], dtype='float32') label = fluid.data(name='label', shape=[None, 1], dtype='int64') - pred = fluid.layers.fc(input=img, size=10, act='softmax') + pred = fluid.layers.fc(input=image, size=10, act='softmax') loss = fluid.layers.cross_entropy(input=pred, label=label) avg_loss = fluid.layers.mean(loss) @@ -1025,9 +1037,15 @@ def load(model_path, configs=None): exe = fluid.Executor(place) exe.run(fluid.default_startup_program()) - loader = fluid.io.DataLoader.from_generator( - feed_list=[img, label], capacity=5, iterable=True) - loader.set_batch_generator(random_batch_reader(), places=place) + # create data loader + dataset = RandomDataset(BATCH_NUM * BATCH_SIZE) + loader = paddle.io.DataLoader(dataset, + feed_list=[image, label], + places=place, + batch_size=BATCH_SIZE, + shuffle=True, + drop_last=True, + num_workers=2) # 1. train and save inference model for data in loader(): @@ -1038,39 +1056,42 @@ def load(model_path, configs=None): model_path = "fc.example.model" fluid.io.save_inference_model( - model_path, ["img"], [pred], exe) + model_path, ["image"], [pred], exe) + + # 2. load model # enable dygraph mode - fluid.enable_dygraph() + paddle.disable_static(place) - # 2. load model & inference - fc = fluid.dygraph.jit.load(model_path) - x = fluid.dygraph.to_variable(np.random.random((1, 784)).astype('float32')) + # load + fc = paddle.jit.load(model_path) + + # inference + fc.eval() + x = paddle.randn([1, IMAGE_SIZE], 'float32') pred = fc(x) - # 3. load model & fine-tune - fc = fluid.dygraph.jit.load(model_path) + # fine-tune fc.train() - sgd = fluid.optimizer.SGD(learning_rate=0.001, - parameter_list=fc.parameters()) - - train_loader = fluid.io.DataLoader.from_generator(capacity=5) - train_loader.set_batch_generator( - random_batch_reader(), places=place) - - for data in train_loader(): - img, label = data - label.stop_gradient = True - - cost = fc(img) - - loss = fluid.layers.cross_entropy(cost, label) - avg_loss = fluid.layers.mean(loss) - - avg_loss.backward() - sgd.minimize(avg_loss) + loss_fn = nn.CrossEntropyLoss() + adam = opt.Adam(learning_rate=0.001, parameters=fc.parameters()) + loader = paddle.io.DataLoader(dataset, + places=place, + batch_size=BATCH_SIZE, + shuffle=True, + drop_last=True, + num_workers=2) + for epoch_id in range(EPOCH_NUM): + for batch_id, (image, label) in enumerate(loader()): + out = fc(image) + loss = loss_fn(out, label) + loss.backward() + adam.step() + adam.clear_grad() + print("Epoch {} batch {}: loss = {}".format( + epoch_id, batch_id, np.mean(loss.numpy()))) """ - return TranslatedLayer._construct(model_path, configs) + return TranslatedLayer._construct(model_path, config) @dygraph_only -- GitLab