diff --git a/python/paddle/v2/fluid/distribute_transpiler.py b/python/paddle/v2/fluid/distribute_transpiler.py
index 76e8734f1371d0874d8b10f597a9bb6989de9297..009f079e839ae64572cb80c82cd8a44bfa89a54a 100644
--- a/python/paddle/v2/fluid/distribute_transpiler.py
+++ b/python/paddle/v2/fluid/distribute_transpiler.py
@@ -56,8 +56,6 @@ def split_dense_variable(var_list,
                 (block_id) * block_size))
             block = VarBlock(var.name, block_id, curr_block_size)
             blocks.append(str(block))
-            print("$$ splited var: ", var.name, var.shape, split_count, len(blocks),
-                  block_size)
     return blocks
 
 
@@ -126,7 +124,7 @@ class DistributeTranspiler:
         # let send_op know which endpoint to send which var, eplist is of the same
         # order of send_inputs.
         eplist = split_method(send_inputs, pserver_endpoints)
-        # create mapping of endpoint -> var to create pserver side program
+        # create mapping of endpoint -> split var to create pserver side program
         self.param_grad_ep_mapping = dict()
         for i, ep in enumerate(eplist):
             param = send_outputs[i]
@@ -142,7 +140,6 @@ class DistributeTranspiler:
             outputs={"Out": send_outputs},
             attrs={"endpoints": pserver_endpoints,
                    "epmap": eplist})
-        # step4
         for varname, splited_var in param_var_mapping.iteritems():
             if len(splited_var) <= 1:
@@ -187,21 +184,6 @@ class DistributeTranspiler:
             var_mapping[varname].append(var)
         return var_mapping
 
-    def _clone_param(self, block, v):
-        assert isinstance(v, Parameter)
-        new_p = Parameter(
-            block=block,
-            shape=v.shape,
-            dtype=v.dtype,
-            type=v.type,
-            lod_level=v.lod_level,
-            stop_gradient=v.stop_gradient,
-            trainable=v.trainable,
-            optimize_attr=v.optimize_attr,
-            regularizer=v.regularizer,
-            name=v.name)
-        block.vars[new_p.name] = new_p
-
     def _clone_var(self, block, var):
         assert isinstance(var, Variable)
         return block.create_var(
@@ -210,7 +192,9 @@ class DistributeTranspiler:
             dtype=var.dtype,
             type=var.type,
             lod_level=var.lod_level,
-            persistable=var.persistable)
+            # HACK: let all params in pserver be persistable so the child
+            # program in recv can get them
+            persistable=True)
 
     def _append_split_op(self, program, gradblocks):
         var_mapping = self._create_vars_from_blocklist(program, gradblocks)
@@ -318,9 +302,10 @@ class DistributeTranspiler:
                 return
             tmpvar = program.global_block().create_var(
                 name=param_block.name,
-                persistable=param_block.persistable,
+                persistable=True,
                 dtype=param_block.dtype,
                 shape=param_block.shape)
+
             new_inputs[key] = tmpvar
 
         for key, var in opt_op.inputs.iteritems():
@@ -330,7 +315,6 @@ class DistributeTranspiler:
             param_shape = new_inputs["Param"].shape
             new_shape = self._get_optimizer_input_shape(opt_op.type, key,
                                                         var.shape, param_shape)
-            print("var, new shape", key, var.name, new_shape)
             tmpvar = program.global_block().create_var(
                 name=var.name,
                 persistable=var.persistable,
@@ -338,7 +322,8 @@ class DistributeTranspiler:
                 shape=new_shape)
             new_inputs[key] = tmpvar
 
-        # FIXME: change outputs ParamOut
+        # change the ParamOut output variable
+        opt_op.outputs["ParamOut"] = new_inputs["Param"]
         program.global_block().append_op(
             type=opt_op.type,
             inputs=new_inputs,
@@ -380,6 +365,7 @@ class DistributeTranspiler:
             else:
                 self._append_pserver_non_opt_ops(optimize_sub_program, opt_op)
 
+        print("####", optimize_sub_program)
         pserver_program.global_block().append_op(
             type="recv",
             inputs={"RX": self.param_grad_ep_mapping[endpoint]["grads"]
@@ -400,3 +386,53 @@ class DistributeTranspiler:
             })
         pserver_program.sync_with_cpp()
         return pserver_program
+
+    def get_startup_program(self, endpoint):
+        """
+        Get startup program for the current parameter server.
+        Modify operator output variables if the corresponding
+        parameters were split into several blocks.
+        """
+        s_prog = Program()
+        orig_s_prog = framework.default_startup_program()
+        params = self.param_grad_ep_mapping[endpoint]["params"]
+
+        def _get_splited_name_and_shape(varname):
+            for idx, splited_param in enumerate(params):
+                pname = splited_param.name
+                if pname.startswith(varname) and varname != pname:
+                    return pname, splited_param.shape
+            return "", []
+
+        # 1. create vars
+        created_var_map = dict()
+        for var in params:
+            print("%%%% append var", var.name, var.shape)
+            tmpvar = s_prog.global_block().create_var(
+                name=var.name,
+                persistable=True,
+                dtype=var.dtype,
+                shape=var.shape)
+            created_var_map[var.name] = tmpvar
+
+        # 2. rename op outputs
+        for op in orig_s_prog.global_block().ops:
+            new_outputs = dict()
+            for key, var in op.outputs.iteritems():
+                newname, _ = _get_splited_name_and_shape(var.name)
+                if newname:
+                    new_outputs[key] = created_var_map[newname]
+                else:
+                    new_outputs[key] = var
+            # do not append startup op if var is not on this pserver
+            var_on_pserver = False
+            for _, var in new_outputs.iteritems():
+                if var.name in created_var_map:
+                    var_on_pserver = True
+            if var_on_pserver:
+                s_prog.global_block().append_op(
+                    type=op.type,
+                    inputs=op.inputs,
+                    outputs=new_outputs,
+                    attrs=op.attrs)
+        return s_prog
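
For context, a minimal pserver-side usage sketch of the new method follows; it is not part of the patch. It assumes the transpiler is exposed as fluid.DistributeTranspiler and that transpile/get_pserver_program keep the call pattern used in the fluid book examples of this branch; the endpoint string, trainer count, and the optimize_ops/params_grads names are placeholders coming from the caller's optimizer.minimize() call.

# Hypothetical usage sketch (not part of this patch): run the per-endpoint
# startup program before the pserver main program so that the split
# parameter blocks created by get_startup_program() are initialized here.
import paddle.v2.fluid as fluid

t = fluid.DistributeTranspiler()
# optimize_ops / params_grads are assumed to come from optimizer.minimize().
t.transpile(optimize_ops, params_grads, pservers="127.0.0.1:6174", trainers=2)

current_endpoint = "127.0.0.1:6174"  # this pserver's own endpoint (placeholder)
pserver_prog = t.get_pserver_program(current_endpoint, optimize_ops)
pserver_startup = t.get_startup_program(current_endpoint)  # new in this change

exe = fluid.Executor(fluid.CPUPlace())
exe.run(pserver_startup)  # initialize only the parameter blocks placed on this node
exe.run(pserver_prog)     # then enter the recv/optimize loop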