From aad3851aad605d9cf1c80d9b45c7279ef54cfaab Mon Sep 17 00:00:00 2001
From: qingqing01
Date: Mon, 27 Apr 2020 10:53:46 +0000
Subject: [PATCH] Add comments

---
 hapi/model.py | 271 +++++++++++++++++++++++++++++++++++++++++++++++---
 1 file changed, 255 insertions(+), 16 deletions(-)

diff --git a/hapi/model.py b/hapi/model.py
index cde4ba6..f2bb506 100644
--- a/hapi/model.py
+++ b/hapi/model.py
@@ -576,15 +576,14 @@ class DynamicGraphAdapter(object):
         if labels is not None:
             labels = [to_variable(l) for l in to_list(labels)]
         if self._nranks > 1:
-            outputs = self.ddp_model.forward(
-                * [to_variable(x) for x in inputs])
+            outputs = self.ddp_model.forward(*[to_variable(x) for x in inputs])
             losses = self.model._loss_function(outputs, labels)
             final_loss = fluid.layers.sum(losses)
             final_loss = self.ddp_model.scale_loss(final_loss)
             final_loss.backward()
             self.ddp_model.apply_collective_grads()
         else:
-            outputs = self.model.forward(* [to_variable(x) for x in inputs])
+            outputs = self.model.forward(*[to_variable(x) for x in inputs])
             losses = self.model._loss_function(outputs, labels)
             final_loss = fluid.layers.sum(losses)
             final_loss.backward()
@@ -593,9 +592,9 @@ class DynamicGraphAdapter(object):
         self.model.clear_gradients()
         metrics = []
         for metric in self.model._metrics:
-            metric_outs = metric.add_metric_op(*(to_list(outputs) + to_list(
-                labels)))
-            m = metric.update(* [to_numpy(m) for m in to_list(metric_outs)])
+            metric_outs = metric.add_metric_op(*(
+                to_list(outputs) + to_list(labels)))
+            m = metric.update(*[to_numpy(m) for m in to_list(metric_outs)])
             metrics.append(m)
 
         return ([to_numpy(l) for l in losses], metrics) \
@@ -607,7 +606,7 @@ class DynamicGraphAdapter(object):
         inputs = to_list(inputs)
         if labels is not None:
             labels = [to_variable(l) for l in to_list(labels)]
-        outputs = self.model.forward(* [to_variable(x) for x in inputs])
+        outputs = self.model.forward(*[to_variable(x) for x in inputs])
         if self.model._loss_function:
             losses = self.model._loss_function(outputs, labels)
         else:
@@ -633,9 +632,9 @@ class DynamicGraphAdapter(object):
             self._merge_count[self.mode + '_total'] += samples
             self._merge_count[self.mode + '_batch'] = samples
 
-            metric_outs = metric.add_metric_op(*(to_list(outputs) + to_list(
-                labels)))
-            m = metric.update(* [to_numpy(m) for m in to_list(metric_outs)])
+            metric_outs = metric.add_metric_op(*(
+                to_list(outputs) + to_list(labels)))
+            m = metric.update(*[to_numpy(m) for m in to_list(metric_outs)])
             metrics.append(m)
 
         # To be consistent with static graph
@@ -721,7 +720,47 @@ class DynamicGraphAdapter(object):
 
 class Model(fluid.dygraph.Layer):
     """
-    FIXME: add more comments and usage
+    A Model object is a network with training and inference features.
+    Dynamic graph and static graph are both supported, switched by
+    `fluid.enable_dygraph()`. Switching between the two modes must be
+    done before a Model is instantiated. The input description, i.e.
+    hapi.Input, is required for the static graph. The usage is as
+    follows.
+
+    Usage:
+        .. code-block:: python
+
+          import numpy as np
+          import paddle.fluid as fluid
+          import paddle.incubate.hapi as hapi
+
+          class MyModel(hapi.model.Model):
+              def __init__(self):
+                  super(MyModel, self).__init__()
+                  self._fc = fluid.dygraph.Linear(784, 10, act='softmax')
+              def forward(self, x):
+                  y = self._fc(x)
+                  return y
+
+          device = hapi.model.set_device('gpu')
+          # if using the static graph, do not call fluid.enable_dygraph
+          fluid.enable_dygraph(device)
+
+          model = MyModel()
+          optim = fluid.optimizer.SGD(learning_rate=1e-3,
+              parameter_list=model.parameters())
+
+          inputs = [hapi.model.Input([None, 784], 'float32', name='x')]
+          labels = [hapi.model.Input([None, 1], 'int64', name='label')]
+
+          mnist_data = hapi.datasets.MNIST(mode='train')
+          model.prepare(optim,
+                        hapi.model.CrossEntropy(),
+                        hapi.metrics.Accuracy(),
+                        inputs,
+                        labels,
+                        device=device)
+          model.fit(mnist_data, epochs=2, batch_size=32, verbose=1)
     """
 
     def __init__(self):
@@ -742,18 +781,195 @@ class Model(fluid.dygraph.Layer):
         else:
             self._adapter = StaticGraphAdapter(self)
 
-    def train_batch(self, *args, **kwargs):
+    def train_batch(self, inputs, labels=None):
+        """
+        Run one training step on a batch of data.
+
+        Args:
+            inputs (list): A list of numpy.ndarray, each is a batch of
+                input data.
+            labels (list): A list of numpy.ndarray, each is a batch of
+                input labels. If the model has no labels, set to None.
+                Default is None.
+
+        Returns:
+            A list of scalar training losses if the model has no metrics,
+            or a tuple (list of scalar losses, list of metrics) if the
+            model has set metrics.
+
+        Examples:
+
+            .. code-block:: python
+
+              import numpy as np
+              import paddle.fluid as fluid
+              import paddle.incubate.hapi as hapi
+
+              class MyModel(hapi.model.Model):
+                  def __init__(self):
+                      super(MyModel, self).__init__()
+                      self._fc = fluid.dygraph.Linear(784, 10, act='softmax')
+                  def forward(self, x):
+                      y = self._fc(x)
+                      return y
+
+              device = hapi.model.set_device('gpu')
+              fluid.enable_dygraph(device)
+
+              model = MyModel()
+              optim = fluid.optimizer.SGD(learning_rate=1e-3,
+                  parameter_list=model.parameters())
+
+              inputs = [hapi.model.Input([None, 784], 'float32', name='x')]
+              labels = [hapi.model.Input([None, 1], 'int64', name='label')]
+              model.prepare(optim,
+                            hapi.model.CrossEntropy(),
+                            inputs=inputs,
+                            labels=labels,
+                            device=device)
+              data = np.random.random(size=(4, 784)).astype(np.float32)
+              label = np.random.randint(0, 10, size=(4, 1)).astype(np.int64)
+              loss = model.train_batch([data], [label])
+              print(loss)
+        """
-        return self._adapter.train_batch(*args, **kwargs)
+        return self._adapter.train_batch(inputs, labels)
 
-    def eval_batch(self, *args, **kwargs):
+    def eval_batch(self, inputs, labels=None):
+        """
+        Run one evaluation step on a batch of data.
+
+        Args:
+            inputs (list): A list of numpy.ndarray, each is a batch of
+                input data.
+            labels (list): A list of numpy.ndarray, each is a batch of
+                input labels. If the model has no labels, set to None.
+                Default is None.
+
+        Returns:
+            A list of scalar evaluation losses if the model has no metrics,
+            or a tuple (list of scalar losses, list of metrics) if the
+            model has set metrics.
+
+        Examples:
+
+            .. code-block:: python
+
+              import numpy as np
+              import paddle.fluid as fluid
+              import paddle.incubate.hapi as hapi
+
+              class MyModel(hapi.model.Model):
+                  def __init__(self):
+                      super(MyModel, self).__init__()
+                      self._fc = fluid.dygraph.Linear(784, 10, act='softmax')
+                  def forward(self, x):
+                      y = self._fc(x)
+                      return y
+
+              device = hapi.model.set_device('gpu')
+              fluid.enable_dygraph(device)
+
+              model = MyModel()
+              optim = fluid.optimizer.SGD(learning_rate=1e-3,
+                  parameter_list=model.parameters())
+
+              inputs = [hapi.model.Input([None, 784], 'float32', name='x')]
+              labels = [hapi.model.Input([None, 1], 'int64', name='label')]
+              model.prepare(optim,
+                            hapi.model.CrossEntropy(),
+                            inputs=inputs,
+                            labels=labels,
+                            device=device)
+              data = np.random.random(size=(4, 784)).astype(np.float32)
+              label = np.random.randint(0, 10, size=(4, 1)).astype(np.int64)
+              loss = model.eval_batch([data], [label])
+              print(loss)
+        """
-        return self._adapter.eval_batch(*args, **kwargs)
+        return self._adapter.eval_batch(inputs, labels)
 
-    def test_batch(self, *args, **kwargs):
+    def test_batch(self, inputs):
+        """
+        Run one testing step on a batch of data.
+
+        Args:
+            inputs (list): A list of numpy.ndarray, each is a batch of
+                input data.
+
+        Returns:
+            A list of numpy.ndarray of predictions, i.e. the outputs of
+            the model forward.
+
+        Examples:
+
+            .. code-block:: python
+
+              import numpy as np
+              import paddle.fluid as fluid
+              import paddle.incubate.hapi as hapi
+
+              class MyModel(hapi.model.Model):
+                  def __init__(self):
+                      super(MyModel, self).__init__()
+                      self._fc = fluid.dygraph.Linear(784, 10, act='softmax')
+                  def forward(self, x):
+                      y = self._fc(x)
+                      return y
+
+              device = hapi.model.set_device('gpu')
+              fluid.enable_dygraph(device)
+
+              model = MyModel()
+              inputs = [hapi.model.Input([None, 784], 'float32', name='x')]
+              model.prepare(inputs=inputs,
+                            device=device)
+              data = np.random.random(size=(4, 784)).astype(np.float32)
+              out = model.test_batch([data])
+              print(out)
+        """
-        return self._adapter.test_batch(*args, **kwargs)
+        return self._adapter.test_batch(inputs)
 
-    def save(self, *args, **kwargs):
+    def save(self, path):
+        """
+        This function saves parameters and optimizer information to path.
+
+        The parameters contain all the trainable variables and will be
+        saved to a file with the suffix ".pdparams".
+        The optimizer information contains all the variables used by the
+        optimizer. For the Adam optimizer, it contains beta1, beta2,
+        momentum, etc. All the information will be saved to a file with
+        the suffix ".pdopt". (If the optimizer has no variables that need
+        to be saved, like SGD, the file will not be generated.)
+
+        This function will silently overwrite an existing file at the
+        target location.
+
+        Args:
+            path (str): The file prefix to save the model. The format is
+                'dirname/file_prefix' or 'file_prefix'. If it is an empty
+                string, an exception will be raised.
+
+        Returns:
+            None
+
+        Examples:
+
+            .. code-block:: python
+
+              import paddle.fluid as fluid
+              import paddle.incubate.hapi as hapi
+
+              class MyModel(hapi.model.Model):
+                  def __init__(self):
+                      super(MyModel, self).__init__()
+                      self._fc = fluid.dygraph.Linear(784, 10, act='softmax')
+                  def forward(self, x):
+                      y = self._fc(x)
+                      return y
+
+              device = hapi.model.set_device('cpu')
+              fluid.enable_dygraph(device)
+              model = MyModel()
+              model.save('checkpoint/test')
+        """
         if ParallelEnv().local_rank == 0:
-            return self._adapter.save(*args, **kwargs)
+            return self._adapter.save(path)
 
     def load(self, path, skip_mismatch=False, reset_optimizer=False):
         """
@@ -780,6 +996,29 @@ class Model(fluid.dygraph.Layer):
                 optimizer states and initialize optimizer states from scratch.
                 Otherwise, restore optimizer states from `path.pdopt` if
                 an optimizer has been set to the model. Default False.
+
+        Returns:
+            None
+
+        Examples:
+
+            .. code-block:: python
+
+              import paddle.fluid as fluid
+              import paddle.incubate.hapi as hapi
+
+              class MyModel(hapi.model.Model):
+                  def __init__(self):
+                      super(MyModel, self).__init__()
+                      self._fc = fluid.dygraph.Linear(784, 10, act='softmax')
+                  def forward(self, x):
+                      y = self._fc(x)
+                      return y
+
+              device = hapi.model.set_device('cpu')
+              fluid.enable_dygraph(device)
+              model = MyModel()
+              model.load('checkpoint/test')
         """
 
         def _load_state_from_path(path):
-- 
GitLab
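
Note: every docstring example in this patch calls fluid.enable_dygraph(). For contrast, below is a minimal static-graph sketch, assembled from the same pieces the docstrings document (the test_batch example's prepare/test_batch calls and the class docstring's statement that hapi.model.Input is required in static mode); it is a hedged illustration, not part of the patch itself:

    .. code-block:: python

      import numpy as np
      import paddle.fluid as fluid
      import paddle.incubate.hapi as hapi

      class MyModel(hapi.model.Model):
          def __init__(self):
              super(MyModel, self).__init__()
              self._fc = fluid.dygraph.Linear(784, 10, act='softmax')
          def forward(self, x):
              y = self._fc(x)
              return y

      # No fluid.enable_dygraph() call, so the Model runs in static graph
      # mode; the Input description below is therefore required, since the
      # graph is built before any data is fed.
      device = hapi.model.set_device('cpu')
      model = MyModel()
      inputs = [hapi.model.Input([None, 784], 'float32', name='x')]
      model.prepare(inputs=inputs, device=device)

      data = np.random.random(size=(4, 784)).astype(np.float32)
      out = model.test_batch([data])
      print(out)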