From e677a23d46eeed3923656e43cf75be420fd75b93 Mon Sep 17 00:00:00 2001 From: LielinJiang Date: Thu, 30 Apr 2020 03:09:44 +0000 Subject: [PATCH] refine code --- hapi/callbacks.py | 2 +- hapi/model.py | 21 +++++++-------------- 2 files changed, 8 insertions(+), 15 deletions(-) diff --git a/hapi/callbacks.py b/hapi/callbacks.py index bf6d565..f7bb878 100644 --- a/hapi/callbacks.py +++ b/hapi/callbacks.py @@ -296,7 +296,7 @@ class ProgBarLogger(Callback): self.tested_samples = 0 self.test_progbar = ProgressBar( num=self.test_steps, verbose=self.verbose) - if ParallelEnv().local_rank == 0: + if self._is_print(): print('Predict begin...') def on_test_batch_end(self, step, logs=None): diff --git a/hapi/model.py b/hapi/model.py index 94ac623..083d2ee 100644 --- a/hapi/model.py +++ b/hapi/model.py @@ -1303,22 +1303,19 @@ class Model(fluid.dygraph.Layer): cbks.on_begin('train') for epoch in range(epochs): - loader = train_loader - cbks.on_epoch_begin(epoch) - logs = self._run_one_epoch(loader, cbks, 'train') + logs = self._run_one_epoch(train_loader, cbks, 'train') cbks.on_epoch_end(epoch, logs) if do_eval and epoch % eval_freq == 0: - loader = eval_loader - eval_steps = self._len_data_loader(loader) + eval_steps = self._len_data_loader(eval_loader) cbks.on_begin('eval', { 'steps': eval_steps, 'metrics_name': self._metrics_name() }) - logs = self._run_one_epoch(loader, cbks, 'eval') + logs = self._run_one_epoch(eval_loader, cbks, 'eval') cbks.on_end('eval', logs) @@ -1414,15 +1411,13 @@ class Model(fluid.dygraph.Layer): verbose=verbose, metrics=self._metrics_name(), ) - loader = eval_loader - - eval_steps = self._len_data_loader(loader) + eval_steps = self._len_data_loader(eval_loader) cbks.on_begin('eval', { 'steps': eval_steps, 'metrics_name': self._metrics_name() }) - logs = self._run_one_epoch(loader, cbks, 'eval') + logs = self._run_one_epoch(eval_loader, cbks, 'eval') cbks.on_end('eval', logs) @@ -1522,18 +1517,16 @@ class Model(fluid.dygraph.Layer): self._test_dataloader = test_loader - loader = test_loader - cbks = config_callbacks(callbacks, model=self, verbose=1) - test_steps = self._len_data_loader(loader) + test_steps = self._len_data_loader(test_loader) logs = {'steps': test_steps} cbks.on_begin('test', logs) outputs = [] - logs, outputs = self._run_one_epoch(loader, cbks, 'test') + logs, outputs = self._run_one_epoch(test_loader, cbks, 'test') outputs = list(zip(*outputs)) -- GitLab