From 45df118a62b9cddf60dbff43034ba11f57044a34 Mon Sep 17 00:00:00 2001 From: LielinJiang Date: Wed, 25 Mar 2020 03:58:03 +0000 Subject: [PATCH] Env->ParalleEnv --- callbacks.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/callbacks.py b/callbacks.py index e4b3fbb..a055940 100644 --- a/callbacks.py +++ b/callbacks.py @@ -16,7 +16,7 @@ import six import copy from progressbar import ProgressBar -from paddle.fluid.dygraph.parallel import Env +from paddle.fluid.dygraph.parallel import ParallelEnv def config_callbacks(callbacks=None, @@ -195,7 +195,7 @@ class ProgBarLogger(Callback): self.steps = self.params['steps'] self.epoch = epoch self.train_step = 0 - if self.verbose and self.epochs and Env().local_rank == 0: + if self.verbose and self.epochs and ParallelEnv().local_rank == 0: print('Epoch %d/%d' % (epoch + 1, self.epochs)) self.train_progbar = ProgressBar(num=self.steps, verbose=self.verbose) @@ -213,7 +213,8 @@ class ProgBarLogger(Callback): logs = logs or {} self.train_step += 1 - if self.train_step % self.log_freq == 0 and self.verbose and Env().local_rank == 0: + if self.train_step % self.log_freq == 0 and self.verbose and ParallelEnv( + ).local_rank == 0: # if steps is not None, last step will update in on_epoch_end if self.steps and self.train_step < self.steps: self._updates(logs, 'train') @@ -222,7 +223,7 @@ class ProgBarLogger(Callback): def on_epoch_end(self, epoch, logs=None): logs = logs or {} - if self.verbose and Env().local_rank == 0: + if self.verbose and ParallelEnv().local_rank == 0: self._updates(logs, 'train') def on_eval_begin(self, logs=None): @@ -232,7 +233,7 @@ class ProgBarLogger(Callback): self.evaled_samples = 0 self.eval_progbar = ProgressBar( num=self.eval_steps, verbose=self.verbose) - if Env().local_rank == 0: + if ParallelEnv().local_rank == 0: print('Eval begin...') def on_eval_batch_end(self, step, logs=None): @@ -243,7 +244,7 @@ class ProgBarLogger(Callback): def on_eval_end(self, logs=None): logs = logs or {} - if self.verbose and Env().local_rank == 0: + if self.verbose and ParallelEnv().local_rank == 0: self._updates(logs, 'eval') print('Eval samples: %d' % (self.evaled_samples)) @@ -257,7 +258,7 @@ class ModelCheckpoint(Callback): self.epoch = epoch def _is_save(self): - return self.model and self.save_dir and Env().local_rank == 0 + return self.model and self.save_dir and ParallelEnv().local_rank == 0 def on_epoch_end(self, epoch, logs=None): if self._is_save() and self.epoch % self.save_freq == 0: -- GitLab