From 9a710229676f267eb5a9a60b75bf971bbf7428bc Mon Sep 17 00:00:00 2001 From: LielinJiang Date: Wed, 25 Mar 2020 03:54:50 +0000 Subject: [PATCH] format code --- model.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/model.py b/model.py index 090a323..6f4f3b4 100644 --- a/model.py +++ b/model.py @@ -403,8 +403,10 @@ class StaticGraphAdapter(object): metrics = [] with fluid.program_guard(prog, self._startup_prog): if isinstance(self.model._inputs, dict): - ins = [self.model._inputs[n] - for n in extract_args(self.model.forward) if n != 'self'] + ins = [ + self.model._inputs[n] + for n in extract_args(self.model.forward) if n != 'self' + ] else: ins = self.model._inputs lbls = self.model._labels if self.model._labels else [] @@ -470,8 +472,8 @@ class StaticGraphAdapter(object): if self._executor is None: if self._nranks > 1 and device.lower() == 'gpu': gpu_id = int(ParallelEnv().dev_id) - place = fluid.CUDAPlace( - gpu_id) if device.lower() == 'gpu' else fluid.CPUPlace() + place = fluid.CUDAPlace(gpu_id) if device.lower( + ) == 'gpu' else fluid.CPUPlace() else: place = places[0] self._executor = fluid.Executor(place) @@ -521,8 +523,8 @@ class DynamicGraphAdapter(object): stradegy.local_rank = ParallelEnv().local_rank stradegy.trainer_endpoints = ParallelEnv().trainer_endpoints stradegy.current_endpoint = ParallelEnv().current_endpoint - self.ddp_model = fluid.dygraph.parallel.DataParallel( - self.model, stradegy) + self.ddp_model = fluid.dygraph.parallel.DataParallel(self.model, + stradegy) @property def mode(self): @@ -1017,7 +1019,8 @@ class Model(fluid.dygraph.Layer): logs[k] = v logs['step'] = step - if mode == 'train' or self._adapter._merge_count.get(mode + '_batch', 0) <= 0: + if mode == 'train' or self._adapter._merge_count.get( + mode + '_batch', 0) <= 0: logs['batch_size'] = batch_size * ParallelEnv().nranks else: logs['batch_size'] = self._adapter._merge_count[mode + -- GitLab