From ae2b218576b41aeb58c9be8144b1f08e82c4d81b Mon Sep 17 00:00:00 2001 From: lyuwenyu Date: Thu, 17 Jun 2021 15:53:05 +0800 Subject: [PATCH] update --- python/paddle/hapi/model.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/python/paddle/hapi/model.py b/python/paddle/hapi/model.py index 0cfe98cd5c9..d2493a74744 100644 --- a/python/paddle/hapi/model.py +++ b/python/paddle/hapi/model.py @@ -737,7 +737,7 @@ class DynamicGraphAdapter(object): if update: self.model._optimizer.minimize(final_loss) self.model.network.clear_gradients() - + metrics = [] for metric in self.model._metrics: metric_outs = metric.compute(*(to_list(outputs) + labels)) @@ -1538,7 +1538,7 @@ class Model(object): drop_last=False, shuffle=True, num_workers=0, - callbacks=None, + callbacks=None, accumulate=1, ): """ Trains the model for a fixed number of epochs. If `eval_data` is set, @@ -1703,8 +1703,9 @@ class Model(object): do_eval = eval_loader is not None self._test_dataloader = eval_loader + self._accumulate = accumulate - + steps = self._len_data_loader(train_loader) cbks = config_callbacks( callbacks, @@ -1742,7 +1743,7 @@ class Model(object): cbks.on_end('train', logs) self._test_dataloader = None - + def evaluate( self, eval_data, @@ -2009,7 +2010,12 @@ class Model(object): model_filename=model_filename, params_filename=params_filename) - def _run_one_epoch(self, data_loader, callbacks, mode, logs={},): + def _run_one_epoch( + self, + data_loader, + callbacks, + mode, + logs={}, ): outputs = [] for step, data in enumerate(data_loader): # data might come from different types of data_loader and have @@ -2033,13 +2039,13 @@ class Model(object): callbacks.on_batch_begin(mode, step, logs) if mode != 'predict': - + _inputs = [data[:len(self._inputs)], data[len(self._inputs):]] if mode == 'train': _inputs.append((step + 1) % self._accumulate == 0) - + outs = getattr(self, mode + '_batch')(*_inputs) - + # outs = getattr(self, mode + '_batch')(data[:len(self._inputs)], # data[len(self._inputs):]) -- GitLab