From 5fa1f05c9d7efb06c2495e8321164bf8ed750a4c Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Mon, 20 Aug 2018 12:48:07 +0800 Subject: [PATCH] code clean and rename --- python/paddle/fluid/framework.py | 2 +- .../tests/unittests/test_dist_transpiler.py | 16 ++++++------- .../fluid/transpiler/distribute_transpiler.py | 23 ++++++++----------- 3 files changed, 19 insertions(+), 22 deletions(-) diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index bb5b4d5360..62682d1032 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -1366,7 +1366,7 @@ class Program(object): # for distribute self._is_distributed = False self._is_chief = False - self._slice_vars_and_atts = [] + self._slice_vars_and_attrs = [] self._endpoints = [] self._distributed_lookup_table = None diff --git a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py index 1889ddad1f..9f04d290f7 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py +++ b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py @@ -641,21 +641,21 @@ class TestLoadSliceVar(TranspilerTest): pserver, _ = self.get_pserver(self.pserver1_ep) pserver2, _ = self.get_pserver(self.pserver2_ep) - self.assertTrue(pserver._slice_vars_and_atts) - self.assertTrue(pserver2._slice_vars_and_atts) + self.assertTrue(pserver._slice_vars_and_attrs) + self.assertTrue(pserver2._slice_vars_and_attrs) - for idx in xrange(len(pserver._slice_vars_and_atts)): - self.assertEqual(pserver._slice_vars_and_atts[idx][0], - pserver2._slice_vars_and_atts[idx][0]) + for idx in xrange(len(pserver._slice_vars_and_attrs)): + self.assertEqual(pserver._slice_vars_and_attrs[idx][0], + pserver2._slice_vars_and_attrs[idx][0]) total_numel = reduce(lambda x, y: x * y, - pserver._slice_vars_and_atts[idx][0].shape) + pserver._slice_vars_and_attrs[idx][0].shape) self.assertEqual( total_numel, reduce(lambda x, y: x * y, - pserver._slice_vars_and_atts[idx][2].shape) + reduce( + pserver._slice_vars_and_attrs[idx][2].shape) + reduce( lambda x, y: x * y, - pserver2._slice_vars_and_atts[idx][2].shape)) + pserver2._slice_vars_and_attrs[idx][2].shape)) if __name__ == "__main__": diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 4856442ac7..5cc447f1dd 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -626,7 +626,7 @@ class DistributeTranspiler(object): attrs=attrs) # add distributed attrs - pserver_program._slice_vars_and_atts = self._get_slice_vars_and_atts( + pserver_program._slice_vars_and_attrs = self._get_slice_vars_and_attrs( endpoint) pserver_program._sync_with_cpp() @@ -704,31 +704,28 @@ class DistributeTranspiler(object): attrs=op.all_attrs()) # add slice vars - s_prog._slice_vars_and_atts = self._get_slice_vars_and_atts(endpoint) + s_prog._slice_vars_and_attrs = self._get_slice_vars_and_attrs(endpoint) return s_prog - def _get_slice_vars_and_atts(self, endpoint): - slice_vars_and_atts = [] - block_suffix = ".block" + def _get_slice_vars_and_attrs(self, endpoint): + slice_vars_and_attrs = [] + block_suffix = "block" for param in self.param_grad_ep_mapping[endpoint]["params"]: - - suff_idx = param.name.find(block_suffix) - if suff_idx <= 0: + orig_var_name, block_name, _ = self._get_varname_parts(param) + if not block_name: continue - orig_var_name = param.name[:suff_idx] - block_idx = int(param.name[suff_idx + len(block_suffix):]) - + block_idx = int(block_name.split(block_suffix)[1]) orig_var = self.origin_program.global_block().vars[orig_var_name] skip_numel = 0 slice_vars = self.param_var_mapping[orig_var_name] for slice_var in slice_vars[:block_idx]: skip_numel += reduce(lambda x, y: x * y, slice_var.shape) - slice_vars_and_atts.append([orig_var, skip_numel, param]) + slice_vars_and_attrs.append([orig_var, skip_numel, param]) - return slice_vars_and_atts + return slice_vars_and_attrs # ====================== private transpiler functions ===================== -- GitLab