From 5faebab375d5e039f5f7cc3169b8de8167494d31 Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Fri, 12 Jan 2018 14:18:55 +0800 Subject: [PATCH] Done, need support selectedrows --- .../paddle/v2/fluid/distribute_transpiler.py | 20 ++++--------------- 1 file changed, 4 insertions(+), 16 deletions(-) diff --git a/python/paddle/v2/fluid/distribute_transpiler.py b/python/paddle/v2/fluid/distribute_transpiler.py index 59e74e0d6f2..d17f9815cca 100644 --- a/python/paddle/v2/fluid/distribute_transpiler.py +++ b/python/paddle/v2/fluid/distribute_transpiler.py @@ -148,7 +148,7 @@ class DistributeTranspiler: concat = program.global_block().append_op( type="concat", inputs={"X": splited_var}, - outputs={"Out": orig_param}, + outputs={"Out": [orig_param]}, attrs={"axis": 0}) def _create_vars_from_blocklist(self, program, block_list): @@ -420,7 +420,6 @@ class DistributeTranspiler: else: self._append_pserver_non_opt_ops(optimize_sub_program, pserver_program, opt_op) - print("****subprogram", optimize_sub_program) pserver_program.global_block().append_op( type="recv", inputs={"RX": self.param_grad_ep_mapping[endpoint]["grads"] @@ -463,7 +462,6 @@ class DistributeTranspiler: pserver_vars = pserver_program.global_block().vars created_var_map = dict() for _, var in pserver_vars.iteritems(): - print("create var for startup", var.name, var.shape) tmpvar = s_prog.global_block().create_var( name=var.name, persistable=var.persistable, @@ -485,21 +483,11 @@ class DistributeTranspiler: op_on_pserver = True new_outputs[key] = pserver_vars[var.name] - # newname, _ = _get_splited_name_and_shape(var.name) - # if newname: - # print("updating output", newname, created_var_map[newname]) - # new_outputs[key] = created_var_map[newname] - # else: - # print("no update output", key, var) - # new_outputs[key] = var - # if var.name in created_var_map or \ - # newname: - # op_on_pserver = True - if op_on_pserver: - if op.type in ["gaussian_random", "fill_constant"]: + if op.type in [ + "gaussian_random", "fill_constant", "uniform_random" + ]: op.attrs["shape"] = new_outputs["Out"].shape - print("updated shape", op.attrs["shape"]) s_prog.global_block().append_op( type=op.type, inputs=op.inputs, -- GitLab