提交 9a710229 编写于 作者: L LielinJiang

format code

上级 1b798365
...@@ -403,8 +403,10 @@ class StaticGraphAdapter(object): ...@@ -403,8 +403,10 @@ class StaticGraphAdapter(object):
metrics = [] metrics = []
with fluid.program_guard(prog, self._startup_prog): with fluid.program_guard(prog, self._startup_prog):
if isinstance(self.model._inputs, dict): if isinstance(self.model._inputs, dict):
ins = [self.model._inputs[n] ins = [
for n in extract_args(self.model.forward) if n != 'self'] self.model._inputs[n]
for n in extract_args(self.model.forward) if n != 'self'
]
else: else:
ins = self.model._inputs ins = self.model._inputs
lbls = self.model._labels if self.model._labels else [] lbls = self.model._labels if self.model._labels else []
...@@ -470,8 +472,8 @@ class StaticGraphAdapter(object): ...@@ -470,8 +472,8 @@ class StaticGraphAdapter(object):
if self._executor is None: if self._executor is None:
if self._nranks > 1 and device.lower() == 'gpu': if self._nranks > 1 and device.lower() == 'gpu':
gpu_id = int(ParallelEnv().dev_id) gpu_id = int(ParallelEnv().dev_id)
place = fluid.CUDAPlace( place = fluid.CUDAPlace(gpu_id) if device.lower(
gpu_id) if device.lower() == 'gpu' else fluid.CPUPlace() ) == 'gpu' else fluid.CPUPlace()
else: else:
place = places[0] place = places[0]
self._executor = fluid.Executor(place) self._executor = fluid.Executor(place)
...@@ -521,8 +523,8 @@ class DynamicGraphAdapter(object): ...@@ -521,8 +523,8 @@ class DynamicGraphAdapter(object):
stradegy.local_rank = ParallelEnv().local_rank stradegy.local_rank = ParallelEnv().local_rank
stradegy.trainer_endpoints = ParallelEnv().trainer_endpoints stradegy.trainer_endpoints = ParallelEnv().trainer_endpoints
stradegy.current_endpoint = ParallelEnv().current_endpoint stradegy.current_endpoint = ParallelEnv().current_endpoint
self.ddp_model = fluid.dygraph.parallel.DataParallel( self.ddp_model = fluid.dygraph.parallel.DataParallel(self.model,
self.model, stradegy) stradegy)
@property @property
def mode(self): def mode(self):
...@@ -1017,7 +1019,8 @@ class Model(fluid.dygraph.Layer): ...@@ -1017,7 +1019,8 @@ class Model(fluid.dygraph.Layer):
logs[k] = v logs[k] = v
logs['step'] = step logs['step'] = step
if mode == 'train' or self._adapter._merge_count.get(mode + '_batch', 0) <= 0: if mode == 'train' or self._adapter._merge_count.get(
mode + '_batch', 0) <= 0:
logs['batch_size'] = batch_size * ParallelEnv().nranks logs['batch_size'] = batch_size * ParallelEnv().nranks
else: else:
logs['batch_size'] = self._adapter._merge_count[mode + logs['batch_size'] = self._adapter._merge_count[mode +
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册