提交 56e758fc 编写于 作者: T typhoonzero

trainer ok

上级 f35c5606
......@@ -56,6 +56,8 @@ def split_dense_variable(var_list,
(block_id) * block_size))
block = VarBlock(var.name, block_id, curr_block_size)
blocks.append(str(block))
print("$$ splited var: ", var.name, var.shape, split_count, len(blocks),
block_size)
return blocks
......@@ -132,10 +134,12 @@ class DistributeTranspiler:
# step4
for varname, splited_var in param_var_mapping.iteritems():
if len(splited_var) <= 1:
continue
orig_param = program.global_block().vars[varname]
concat = program.global_block().append_op(
type="concat",
inputs={"X": send_outputs},
inputs={"X": splited_var},
outputs={"Out": orig_param},
attrs={"axis": 0})
......@@ -147,28 +151,29 @@ class DistributeTranspiler:
if not block_map.has_key(varname):
block_map[varname] = []
block_map[varname].append((long(offset), long(size)))
for varname, splited in block_map.iteritems():
orig_var = program.global_block().vars[varname]
var_mapping[varname] = []
if len(splited) == 1:
var_mapping[varname] = [orig_var]
continue
orig_shape = orig_var.shape
orig_dim1_flatten = 1
if len(orig_shape) >= 2:
orig_dim1_flatten = reduce(lambda x, y: x * y, orig_shape[1:])
var_list = []
for i, block in enumerate(splited):
size = block[1]
rows = size / orig_dim1_flatten
splited_shape = [rows]
if len(orig_shape) >= 2:
splited_shape.extend(orig_shape[1:])
print("block, splited shape:", block, splited_shape)
var = program.global_block().create_var(
name="%s.block%d" % (varname, i),
psersistable=False,
dtype=orig_var.dtype,
shape=splited_shape) # flattend splited var
var_list.append(var)
var_mapping[varname] = var_list
var_mapping[varname].append(var)
return var_mapping
def _clone_param(self, block, v):
......@@ -199,7 +204,8 @@ class DistributeTranspiler:
def _append_split_op(self, program, gradblocks):
var_mapping = self._create_vars_from_blocklist(program, gradblocks)
for varname, splited_vars in var_mapping.iteritems():
if len(splited_vars) == 1:
# variable that don't need to split have empty splited_vars
if len(splited_vars) <= 1:
continue
orig_var = program.global_block().vars[varname]
sections = []
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册