提交 5fa1f05c 编写于 作者: T tangwei12

code clean and rename

上级 bf2f1599
......@@ -1366,7 +1366,7 @@ class Program(object):
# for distribute
self._is_distributed = False
self._is_chief = False
self._slice_vars_and_atts = []
self._slice_vars_and_attrs = []
self._endpoints = []
self._distributed_lookup_table = None
......
......@@ -641,21 +641,21 @@ class TestLoadSliceVar(TranspilerTest):
pserver, _ = self.get_pserver(self.pserver1_ep)
pserver2, _ = self.get_pserver(self.pserver2_ep)
self.assertTrue(pserver._slice_vars_and_atts)
self.assertTrue(pserver2._slice_vars_and_atts)
self.assertTrue(pserver._slice_vars_and_attrs)
self.assertTrue(pserver2._slice_vars_and_attrs)
for idx in xrange(len(pserver._slice_vars_and_atts)):
self.assertEqual(pserver._slice_vars_and_atts[idx][0],
pserver2._slice_vars_and_atts[idx][0])
for idx in xrange(len(pserver._slice_vars_and_attrs)):
self.assertEqual(pserver._slice_vars_and_attrs[idx][0],
pserver2._slice_vars_and_attrs[idx][0])
total_numel = reduce(lambda x, y: x * y,
pserver._slice_vars_and_atts[idx][0].shape)
pserver._slice_vars_and_attrs[idx][0].shape)
self.assertEqual(
total_numel,
reduce(lambda x, y: x * y,
pserver._slice_vars_and_atts[idx][2].shape) + reduce(
pserver._slice_vars_and_attrs[idx][2].shape) + reduce(
lambda x, y: x * y,
pserver2._slice_vars_and_atts[idx][2].shape))
pserver2._slice_vars_and_attrs[idx][2].shape))
if __name__ == "__main__":
......
......@@ -626,7 +626,7 @@ class DistributeTranspiler(object):
attrs=attrs)
# add distributed attrs
pserver_program._slice_vars_and_atts = self._get_slice_vars_and_atts(
pserver_program._slice_vars_and_attrs = self._get_slice_vars_and_attrs(
endpoint)
pserver_program._sync_with_cpp()
......@@ -704,31 +704,28 @@ class DistributeTranspiler(object):
attrs=op.all_attrs())
# add slice vars
s_prog._slice_vars_and_atts = self._get_slice_vars_and_atts(endpoint)
s_prog._slice_vars_and_attrs = self._get_slice_vars_and_attrs(endpoint)
return s_prog
def _get_slice_vars_and_atts(self, endpoint):
slice_vars_and_atts = []
block_suffix = ".block"
def _get_slice_vars_and_attrs(self, endpoint):
slice_vars_and_attrs = []
block_suffix = "block"
for param in self.param_grad_ep_mapping[endpoint]["params"]:
suff_idx = param.name.find(block_suffix)
if suff_idx <= 0:
orig_var_name, block_name, _ = self._get_varname_parts(param)
if not block_name:
continue
orig_var_name = param.name[:suff_idx]
block_idx = int(param.name[suff_idx + len(block_suffix):])
block_idx = int(block_name.split(block_suffix)[1])
orig_var = self.origin_program.global_block().vars[orig_var_name]
skip_numel = 0
slice_vars = self.param_var_mapping[orig_var_name]
for slice_var in slice_vars[:block_idx]:
skip_numel += reduce(lambda x, y: x * y, slice_var.shape)
slice_vars_and_atts.append([orig_var, skip_numel, param])
slice_vars_and_attrs.append([orig_var, skip_numel, param])
return slice_vars_and_atts
return slice_vars_and_attrs
# ====================== private transpiler functions =====================
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册