提交 9a710229 编写于 作者: L LielinJiang

format code

上级 1b798365
......@@ -403,8 +403,10 @@ class StaticGraphAdapter(object):
metrics = []
with fluid.program_guard(prog, self._startup_prog):
if isinstance(self.model._inputs, dict):
ins = [self.model._inputs[n]
for n in extract_args(self.model.forward) if n != 'self']
ins = [
self.model._inputs[n]
for n in extract_args(self.model.forward) if n != 'self'
]
else:
ins = self.model._inputs
lbls = self.model._labels if self.model._labels else []
......@@ -470,8 +472,8 @@ class StaticGraphAdapter(object):
if self._executor is None:
if self._nranks > 1 and device.lower() == 'gpu':
gpu_id = int(ParallelEnv().dev_id)
place = fluid.CUDAPlace(
gpu_id) if device.lower() == 'gpu' else fluid.CPUPlace()
place = fluid.CUDAPlace(gpu_id) if device.lower(
) == 'gpu' else fluid.CPUPlace()
else:
place = places[0]
self._executor = fluid.Executor(place)
......@@ -521,8 +523,8 @@ class DynamicGraphAdapter(object):
stradegy.local_rank = ParallelEnv().local_rank
stradegy.trainer_endpoints = ParallelEnv().trainer_endpoints
stradegy.current_endpoint = ParallelEnv().current_endpoint
self.ddp_model = fluid.dygraph.parallel.DataParallel(
self.model, stradegy)
self.ddp_model = fluid.dygraph.parallel.DataParallel(self.model,
stradegy)
@property
def mode(self):
......@@ -1017,7 +1019,8 @@ class Model(fluid.dygraph.Layer):
logs[k] = v
logs['step'] = step
if mode == 'train' or self._adapter._merge_count.get(mode + '_batch', 0) <= 0:
if mode == 'train' or self._adapter._merge_count.get(
mode + '_batch', 0) <= 0:
logs['batch_size'] = batch_size * ParallelEnv().nranks
else:
logs['batch_size'] = self._adapter._merge_count[mode +
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册