提交 5fa1f05c 编写于 作者: T tangwei12

code clean and rename

上级 bf2f1599
...@@ -1366,7 +1366,7 @@ class Program(object): ...@@ -1366,7 +1366,7 @@ class Program(object):
# for distribute # for distribute
self._is_distributed = False self._is_distributed = False
self._is_chief = False self._is_chief = False
self._slice_vars_and_atts = [] self._slice_vars_and_attrs = []
self._endpoints = [] self._endpoints = []
self._distributed_lookup_table = None self._distributed_lookup_table = None
......
...@@ -641,21 +641,21 @@ class TestLoadSliceVar(TranspilerTest): ...@@ -641,21 +641,21 @@ class TestLoadSliceVar(TranspilerTest):
pserver, _ = self.get_pserver(self.pserver1_ep) pserver, _ = self.get_pserver(self.pserver1_ep)
pserver2, _ = self.get_pserver(self.pserver2_ep) pserver2, _ = self.get_pserver(self.pserver2_ep)
self.assertTrue(pserver._slice_vars_and_atts) self.assertTrue(pserver._slice_vars_and_attrs)
self.assertTrue(pserver2._slice_vars_and_atts) self.assertTrue(pserver2._slice_vars_and_attrs)
for idx in xrange(len(pserver._slice_vars_and_atts)): for idx in xrange(len(pserver._slice_vars_and_attrs)):
self.assertEqual(pserver._slice_vars_and_atts[idx][0], self.assertEqual(pserver._slice_vars_and_attrs[idx][0],
pserver2._slice_vars_and_atts[idx][0]) pserver2._slice_vars_and_attrs[idx][0])
total_numel = reduce(lambda x, y: x * y, total_numel = reduce(lambda x, y: x * y,
pserver._slice_vars_and_atts[idx][0].shape) pserver._slice_vars_and_attrs[idx][0].shape)
self.assertEqual( self.assertEqual(
total_numel, total_numel,
reduce(lambda x, y: x * y, reduce(lambda x, y: x * y,
pserver._slice_vars_and_atts[idx][2].shape) + reduce( pserver._slice_vars_and_attrs[idx][2].shape) + reduce(
lambda x, y: x * y, lambda x, y: x * y,
pserver2._slice_vars_and_atts[idx][2].shape)) pserver2._slice_vars_and_attrs[idx][2].shape))
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -626,7 +626,7 @@ class DistributeTranspiler(object): ...@@ -626,7 +626,7 @@ class DistributeTranspiler(object):
attrs=attrs) attrs=attrs)
# add distributed attrs # add distributed attrs
pserver_program._slice_vars_and_atts = self._get_slice_vars_and_atts( pserver_program._slice_vars_and_attrs = self._get_slice_vars_and_attrs(
endpoint) endpoint)
pserver_program._sync_with_cpp() pserver_program._sync_with_cpp()
...@@ -704,31 +704,28 @@ class DistributeTranspiler(object): ...@@ -704,31 +704,28 @@ class DistributeTranspiler(object):
attrs=op.all_attrs()) attrs=op.all_attrs())
# add slice vars # add slice vars
s_prog._slice_vars_and_atts = self._get_slice_vars_and_atts(endpoint) s_prog._slice_vars_and_attrs = self._get_slice_vars_and_attrs(endpoint)
return s_prog return s_prog
def _get_slice_vars_and_atts(self, endpoint): def _get_slice_vars_and_attrs(self, endpoint):
slice_vars_and_atts = [] slice_vars_and_attrs = []
block_suffix = ".block" block_suffix = "block"
for param in self.param_grad_ep_mapping[endpoint]["params"]: for param in self.param_grad_ep_mapping[endpoint]["params"]:
orig_var_name, block_name, _ = self._get_varname_parts(param)
suff_idx = param.name.find(block_suffix) if not block_name:
if suff_idx <= 0:
continue continue
orig_var_name = param.name[:suff_idx] block_idx = int(block_name.split(block_suffix)[1])
block_idx = int(param.name[suff_idx + len(block_suffix):])
orig_var = self.origin_program.global_block().vars[orig_var_name] orig_var = self.origin_program.global_block().vars[orig_var_name]
skip_numel = 0 skip_numel = 0
slice_vars = self.param_var_mapping[orig_var_name] slice_vars = self.param_var_mapping[orig_var_name]
for slice_var in slice_vars[:block_idx]: for slice_var in slice_vars[:block_idx]:
skip_numel += reduce(lambda x, y: x * y, slice_var.shape) skip_numel += reduce(lambda x, y: x * y, slice_var.shape)
slice_vars_and_atts.append([orig_var, skip_numel, param]) slice_vars_and_attrs.append([orig_var, skip_numel, param])
return slice_vars_and_atts return slice_vars_and_attrs
# ====================== private transpiler functions ===================== # ====================== private transpiler functions =====================
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册