提交 56e758fc 编写于 作者: T typhoonzero

trainer ok

上级 f35c5606
...@@ -56,6 +56,8 @@ def split_dense_variable(var_list, ...@@ -56,6 +56,8 @@ def split_dense_variable(var_list,
(block_id) * block_size)) (block_id) * block_size))
block = VarBlock(var.name, block_id, curr_block_size) block = VarBlock(var.name, block_id, curr_block_size)
blocks.append(str(block)) blocks.append(str(block))
print("$$ splited var: ", var.name, var.shape, split_count, len(blocks),
block_size)
return blocks return blocks
...@@ -132,10 +134,12 @@ class DistributeTranspiler: ...@@ -132,10 +134,12 @@ class DistributeTranspiler:
# step4 # step4
for varname, splited_var in param_var_mapping.iteritems(): for varname, splited_var in param_var_mapping.iteritems():
if len(splited_var) <= 1:
continue
orig_param = program.global_block().vars[varname] orig_param = program.global_block().vars[varname]
concat = program.global_block().append_op( concat = program.global_block().append_op(
type="concat", type="concat",
inputs={"X": send_outputs}, inputs={"X": splited_var},
outputs={"Out": orig_param}, outputs={"Out": orig_param},
attrs={"axis": 0}) attrs={"axis": 0})
...@@ -147,28 +151,29 @@ class DistributeTranspiler: ...@@ -147,28 +151,29 @@ class DistributeTranspiler:
if not block_map.has_key(varname): if not block_map.has_key(varname):
block_map[varname] = [] block_map[varname] = []
block_map[varname].append((long(offset), long(size))) block_map[varname].append((long(offset), long(size)))
for varname, splited in block_map.iteritems(): for varname, splited in block_map.iteritems():
orig_var = program.global_block().vars[varname] orig_var = program.global_block().vars[varname]
var_mapping[varname] = []
if len(splited) == 1:
var_mapping[varname] = [orig_var]
continue
orig_shape = orig_var.shape orig_shape = orig_var.shape
orig_dim1_flatten = 1 orig_dim1_flatten = 1
if len(orig_shape) >= 2: if len(orig_shape) >= 2:
orig_dim1_flatten = reduce(lambda x, y: x * y, orig_shape[1:]) orig_dim1_flatten = reduce(lambda x, y: x * y, orig_shape[1:])
var_list = []
for i, block in enumerate(splited): for i, block in enumerate(splited):
size = block[1] size = block[1]
rows = size / orig_dim1_flatten rows = size / orig_dim1_flatten
splited_shape = [rows] splited_shape = [rows]
if len(orig_shape) >= 2: if len(orig_shape) >= 2:
splited_shape.extend(orig_shape[1:]) splited_shape.extend(orig_shape[1:])
print("block, splited shape:", block, splited_shape)
var = program.global_block().create_var( var = program.global_block().create_var(
name="%s.block%d" % (varname, i), name="%s.block%d" % (varname, i),
psersistable=False, psersistable=False,
dtype=orig_var.dtype, dtype=orig_var.dtype,
shape=splited_shape) # flattened split var shape=splited_shape) # flattened split var
var_list.append(var) var_mapping[varname].append(var)
var_mapping[varname] = var_list
return var_mapping return var_mapping
def _clone_param(self, block, v): def _clone_param(self, block, v):
...@@ -199,7 +204,8 @@ class DistributeTranspiler: ...@@ -199,7 +204,8 @@ class DistributeTranspiler:
def _append_split_op(self, program, gradblocks): def _append_split_op(self, program, gradblocks):
var_mapping = self._create_vars_from_blocklist(program, gradblocks) var_mapping = self._create_vars_from_blocklist(program, gradblocks)
for varname, splited_vars in var_mapping.iteritems(): for varname, splited_vars in var_mapping.iteritems():
if len(splited_vars) == 1: # variables that do not need splitting have at most one entry in splited_vars
if len(splited_vars) <= 1:
continue continue
orig_var = program.global_block().vars[varname] orig_var = program.global_block().vars[varname]
sections = [] sections = []
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册