diff --git a/python/paddle/v2/fluid/distribute_transpiler.py b/python/paddle/v2/fluid/distribute_transpiler.py index 58d32bac1257ab25f406ab37948d532b84d4ded8..7f3da674633884d5a275a952c07d5a4a8e0f138c 100644 --- a/python/paddle/v2/fluid/distribute_transpiler.py +++ b/python/paddle/v2/fluid/distribute_transpiler.py @@ -56,6 +56,8 @@ def split_dense_variable(var_list, (block_id) * block_size)) block = VarBlock(var.name, block_id, curr_block_size) blocks.append(str(block)) + print("$$ splited var: ", var.name, var.shape, split_count, len(blocks), + block_size) return blocks @@ -132,10 +134,12 @@ class DistributeTranspiler: # step4 for varname, splited_var in param_var_mapping.iteritems(): + if len(splited_var) <= 1: + continue orig_param = program.global_block().vars[varname] concat = program.global_block().append_op( type="concat", - inputs={"X": send_outputs}, + inputs={"X": splited_var}, outputs={"Out": orig_param}, attrs={"axis": 0}) @@ -147,28 +151,29 @@ class DistributeTranspiler: if not block_map.has_key(varname): block_map[varname] = [] block_map[varname].append((long(offset), long(size))) - for varname, splited in block_map.iteritems(): orig_var = program.global_block().vars[varname] + var_mapping[varname] = [] + if len(splited) == 1: + var_mapping[varname] = [orig_var] + continue orig_shape = orig_var.shape orig_dim1_flatten = 1 if len(orig_shape) >= 2: orig_dim1_flatten = reduce(lambda x, y: x * y, orig_shape[1:]) - var_list = [] + for i, block in enumerate(splited): size = block[1] rows = size / orig_dim1_flatten splited_shape = [rows] if len(orig_shape) >= 2: splited_shape.extend(orig_shape[1:]) - print("block, splited shape:", block, splited_shape) var = program.global_block().create_var( name="%s.block%d" % (varname, i), psersistable=False, dtype=orig_var.dtype, shape=splited_shape) # flattend splited var - var_list.append(var) - var_mapping[varname] = var_list + var_mapping[varname].append(var) return var_mapping def _clone_param(self, block, v): @@ -199,7 +204,8 @@ class DistributeTranspiler: def _append_split_op(self, program, gradblocks): var_mapping = self._create_vars_from_blocklist(program, gradblocks) for varname, splited_vars in var_mapping.iteritems(): - if len(splited_vars) == 1: + # variable that don't need to split have empty splited_vars + if len(splited_vars) <= 1: continue orig_var = program.global_block().vars[varname] sections = []