From 87eb929f7d62b6b0ef209ec88ffb9ea07e584d42 Mon Sep 17 00:00:00 2001
From: lyuwenyu
Date: Thu, 17 Jun 2021 14:36:42 +0800
Subject: [PATCH] add gradient accumulation for dygraph

---
 python/paddle/hapi/model.py | 39 ++++++++++++++++++++++++---------------
 1 file changed, 24 insertions(+), 15 deletions(-)

diff --git a/python/paddle/hapi/model.py b/python/paddle/hapi/model.py
index c9b6c0098e..0cfe98cd5c 100644
--- a/python/paddle/hapi/model.py
+++ b/python/paddle/hapi/model.py
@@ -701,7 +701,7 @@ class DynamicGraphAdapter(object):
         self.model.mode = value

     # TODO multi device in dygraph mode not implemented at present time
-    def train_batch(self, inputs, labels=None):
+    def train_batch(self, inputs, labels=None, update=True):
         assert self.model._optimizer, \
             "model not ready, please call `model.prepare()` first"
         self.model.network.train()
@@ -729,13 +729,15 @@ class DynamicGraphAdapter(object):
         if self._amp_level != "O0":
             scaled = scaler.scale(final_loss)
             scaled.backward()
-            scaler.minimize(self.model._optimizer, scaled)
-            self.model.network.clear_gradients()
+            if update:
+                scaler.minimize(self.model._optimizer, scaled)
+                self.model.network.clear_gradients()
         else:
             final_loss.backward()
-            self.model._optimizer.minimize(final_loss)
-            self.model.network.clear_gradients()
-
+            if update:
+                self.model._optimizer.minimize(final_loss)
+                self.model.network.clear_gradients()
+
         metrics = []
         for metric in self.model._metrics:
             metric_outs = metric.compute(*(to_list(outputs) + labels))
@@ -1017,7 +1019,7 @@ class Model(object):
         else:
             self._adapter = StaticGraphAdapter(self)

-    def train_batch(self, inputs, labels=None):
+    def train_batch(self, inputs, labels=None, update=True):
         """
         Run one training step on a batch of data.

@@ -1062,7 +1064,7 @@
               loss = model.train_batch([data], [label])
               print(loss)
         """
-        loss = self._adapter.train_batch(inputs, labels)
+        loss = self._adapter.train_batch(inputs, labels, update)
         if fluid.in_dygraph_mode() and self._input_info is None:
             self._update_inputs()
         return loss
@@ -1536,7 +1538,8 @@
             drop_last=False,
             shuffle=True,
             num_workers=0,
-            callbacks=None, ):
+            callbacks=None,
+            accumulate=1, ):
         """
         Trains the model for a fixed number of epochs. If `eval_data` is set,
         evaluation will be done at the end of each epoch.
@@ -1579,7 +1582,8 @@
             callbacks (Callback|None): A list of `Callback` instances to
                 apply during training. If None, `ProgBarLogger` and
                 `ModelCheckpoint` are automatically inserted. Default: None.
-
+            accumulate (int): The number of steps over which gradients are accumulated during training before the optimizer performs an update. Use this to mimic a large batch size. Default: 1.
+
         Returns:
             None

@@ -1699,7 +1703,8 @@

         do_eval = eval_loader is not None
         self._test_dataloader = eval_loader
-
+        self._accumulate = accumulate
+
         steps = self._len_data_loader(train_loader)
         cbks = config_callbacks(
             callbacks,
@@ -1737,7 +1742,7 @@

         cbks.on_end('train', logs)
         self._test_dataloader = None
-
+
     def evaluate(
             self,
             eval_data,
@@ -2004,7 +2009,7 @@
             model_filename=model_filename,
             params_filename=params_filename)

-    def _run_one_epoch(self, data_loader, callbacks, mode, logs={}):
+    def _run_one_epoch(self, data_loader, callbacks, mode, logs={}, ):
         outputs = []
         for step, data in enumerate(data_loader):
             # data might come from different types of data_loader and have
@@ -2028,8 +2033,12 @@
             callbacks.on_batch_begin(mode, step, logs)

             if mode != 'predict':
-                outs = getattr(self, mode + '_batch')(data[:len(self._inputs)],
-                                                      data[len(self._inputs):])
+                _inputs = [data[:len(self._inputs)], data[len(self._inputs):]]
+                if mode == 'train':
+                    _inputs.append((step + 1) % self._accumulate == 0)
+
+                outs = getattr(self, mode + '_batch')(*_inputs)
+
                 if self._metrics and self._loss:
                     metrics = [[l[0] for l in outs[0]]]
                 elif self._loss:
--
GitLab
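For context, the new `accumulate` argument is consumed entirely inside `fit`: every batch still runs forward and backward, but the optimizer only steps on every `accumulate`-th batch. A minimal usage sketch follows, assuming the LeNet/MNIST setup and the hyperparameter values, which are illustrative and not part of this patch:

```python
import paddle
import paddle.vision.transforms as T
from paddle.vision.datasets import MNIST
from paddle.vision.models import LeNet

# Wrap a network in the hapi Model API (LeNet/MNIST are placeholders).
model = paddle.Model(LeNet())
model.prepare(
    optimizer=paddle.optimizer.Adam(
        learning_rate=1e-3, parameters=model.parameters()),
    loss=paddle.nn.CrossEntropyLoss())

train_dataset = MNIST(mode='train', transform=T.ToTensor())

# accumulate=4 sums gradients over 4 consecutive batches before a single
# optimizer update, mimicking an effective batch size of 4 * 16 = 64.
model.fit(train_dataset, epochs=1, batch_size=16, accumulate=4)
```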
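At the lower level, `train_batch(..., update=False)` runs forward and backward but skips `minimize()` and `clear_gradients()`, so dygraph gradients keep accumulating across calls. A hand-rolled loop equivalent to what the patched `_run_one_epoch` does might look like the sketch below; it continues from the previous snippet, and the batch size and accumulation window are again assumed values:

```python
accumulate = 4  # assumed accumulation window
loader = paddle.io.DataLoader(train_dataset, batch_size=16, shuffle=True)

for step, (image, label) in enumerate(loader):
    # Parameters are updated (and gradients cleared) only when update=True,
    # i.e. on every `accumulate`-th batch, matching the condition the patch
    # appends to _inputs in _run_one_epoch.
    update = (step + 1) % accumulate == 0
    loss = model.train_batch([image], [label], update=update)
```

Note that under this scheme, gradients left over from a final partial accumulation window are never applied, which matches the patched `_run_one_epoch` behavior.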