提交 45df118a 编写于 作者: L LielinJiang

Env->ParalleEnv

上级 9a710229
...@@ -16,7 +16,7 @@ import six ...@@ -16,7 +16,7 @@ import six
import copy import copy
from progressbar import ProgressBar from progressbar import ProgressBar
from paddle.fluid.dygraph.parallel import Env from paddle.fluid.dygraph.parallel import ParallelEnv
def config_callbacks(callbacks=None, def config_callbacks(callbacks=None,
...@@ -195,7 +195,7 @@ class ProgBarLogger(Callback): ...@@ -195,7 +195,7 @@ class ProgBarLogger(Callback):
self.steps = self.params['steps'] self.steps = self.params['steps']
self.epoch = epoch self.epoch = epoch
self.train_step = 0 self.train_step = 0
if self.verbose and self.epochs and Env().local_rank == 0: if self.verbose and self.epochs and ParallelEnv().local_rank == 0:
print('Epoch %d/%d' % (epoch + 1, self.epochs)) print('Epoch %d/%d' % (epoch + 1, self.epochs))
self.train_progbar = ProgressBar(num=self.steps, verbose=self.verbose) self.train_progbar = ProgressBar(num=self.steps, verbose=self.verbose)
...@@ -213,7 +213,8 @@ class ProgBarLogger(Callback): ...@@ -213,7 +213,8 @@ class ProgBarLogger(Callback):
logs = logs or {} logs = logs or {}
self.train_step += 1 self.train_step += 1
if self.train_step % self.log_freq == 0 and self.verbose and Env().local_rank == 0: if self.train_step % self.log_freq == 0 and self.verbose and ParallelEnv(
).local_rank == 0:
# if steps is not None, last step will update in on_epoch_end # if steps is not None, last step will update in on_epoch_end
if self.steps and self.train_step < self.steps: if self.steps and self.train_step < self.steps:
self._updates(logs, 'train') self._updates(logs, 'train')
...@@ -222,7 +223,7 @@ class ProgBarLogger(Callback): ...@@ -222,7 +223,7 @@ class ProgBarLogger(Callback):
def on_epoch_end(self, epoch, logs=None): def on_epoch_end(self, epoch, logs=None):
logs = logs or {} logs = logs or {}
if self.verbose and Env().local_rank == 0: if self.verbose and ParallelEnv().local_rank == 0:
self._updates(logs, 'train') self._updates(logs, 'train')
def on_eval_begin(self, logs=None): def on_eval_begin(self, logs=None):
...@@ -232,7 +233,7 @@ class ProgBarLogger(Callback): ...@@ -232,7 +233,7 @@ class ProgBarLogger(Callback):
self.evaled_samples = 0 self.evaled_samples = 0
self.eval_progbar = ProgressBar( self.eval_progbar = ProgressBar(
num=self.eval_steps, verbose=self.verbose) num=self.eval_steps, verbose=self.verbose)
if Env().local_rank == 0: if ParallelEnv().local_rank == 0:
print('Eval begin...') print('Eval begin...')
def on_eval_batch_end(self, step, logs=None): def on_eval_batch_end(self, step, logs=None):
...@@ -243,7 +244,7 @@ class ProgBarLogger(Callback): ...@@ -243,7 +244,7 @@ class ProgBarLogger(Callback):
def on_eval_end(self, logs=None): def on_eval_end(self, logs=None):
logs = logs or {} logs = logs or {}
if self.verbose and Env().local_rank == 0: if self.verbose and ParallelEnv().local_rank == 0:
self._updates(logs, 'eval') self._updates(logs, 'eval')
print('Eval samples: %d' % (self.evaled_samples)) print('Eval samples: %d' % (self.evaled_samples))
...@@ -257,7 +258,7 @@ class ModelCheckpoint(Callback): ...@@ -257,7 +258,7 @@ class ModelCheckpoint(Callback):
self.epoch = epoch self.epoch = epoch
def _is_save(self): def _is_save(self):
return self.model and self.save_dir and Env().local_rank == 0 return self.model and self.save_dir and ParallelEnv().local_rank == 0
def on_epoch_end(self, epoch, logs=None): def on_epoch_end(self, epoch, logs=None):
if self._is_save() and self.epoch % self.save_freq == 0: if self._is_save() and self.epoch % self.save_freq == 0:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册