From 6e83c0030148895e39e02093e385ceb639c61022 Mon Sep 17 00:00:00 2001
From: Yancey1989
Date: Wed, 28 Feb 2018 19:03:17 +0800
Subject: [PATCH] Registry var type infer in split_selected_rows op

---
 python/paddle/fluid/distribute_transpiler.py | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

diff --git a/python/paddle/fluid/distribute_transpiler.py b/python/paddle/fluid/distribute_transpiler.py
index 8da9ca290b2..497bcf93a37 100644
--- a/python/paddle/fluid/distribute_transpiler.py
+++ b/python/paddle/fluid/distribute_transpiler.py
@@ -276,6 +276,7 @@ class DistributeTranspiler:
             pserver_program.global_block().create_var(
                 name=orig_var_name,
                 persistable=True,
+                type=v.type,
                 dtype=v.dtype,
                 shape=v.shape)
             print("create origin var: ", orig_var_name)
@@ -283,6 +284,7 @@ class DistributeTranspiler:
                 var = pserver_program.global_block().create_var(
                     name="%s.trainer_%d" % (orig_var_name, trainer_id),
                     persistable=False,
+                    type=v.type,
                     dtype=v.dtype,
                     shape=v.shape)
                 recv_inputs.append(var)
@@ -551,11 +553,12 @@ class DistributeTranspiler:
                         type="sum",
                         inputs={"X": vars2merge},
                         outputs={"Out": merged_var})
-                    optimize_block.append_op(
-                        type="scale",
-                        inputs={"X": merged_var},
-                        outputs={"Out": merged_var},
-                        attrs={"scale": 1.0 / float(self.trainers)})
+                    if not merged_var.type == core.VarDesc.VarType.SELECTED_ROWS:
+                        optimize_block.append_op(
+                            type="scale",
+                            inputs={"X": merged_var},
+                            outputs={"Out": merged_var},
+                            attrs={"scale": 1.0 / float(self.trainers)})
                 new_inputs[key] = merged_var
             elif key == "Param":
                 # param is already created on global program
--
GitLab
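
What follows is a minimal, self-contained sketch (plain Python, no Paddle dependency) of the two behaviors this patch encodes: forwarding type=v.type so a SELECTED_ROWS gradient is not silently re-created as a dense LOD_TENSOR on the parameter server, and skipping the 1/trainers averaging for SELECTED_ROWS gradients. VarType, Var, and merge_grads are hypothetical stand-ins, not Paddle APIs, and the stated rationale for the skip is an inference from the diff rather than something the patch itself says.

from enum import Enum

class VarType(Enum):
    LOD_TENSOR = 0      # dense gradient tensor
    SELECTED_ROWS = 1   # sparse gradient, e.g. from an embedding lookup

class Var:
    def __init__(self, name, type=VarType.LOD_TENSOR, dtype="float32",
                 shape=None):
        # Before the patch, the pserver re-created received vars without
        # passing `type`, so a default like this one turned every
        # SELECTED_ROWS gradient into a dense LOD_TENSOR; forwarding
        # type=v.type preserves the original kind.
        self.name, self.type, self.dtype, self.shape = name, type, dtype, shape

def merge_grads(vars2merge, trainers):
    # Sum the per-trainer gradient copies (the "sum" op in the patch).
    merged = Var("merged", type=vars2merge[0].type)
    if merged.type != VarType.SELECTED_ROWS:
        # Only dense gradients are averaged by 1/trainers (the "scale" op);
        # the patch skips this step when the merged var is SELECTED_ROWS.
        merged.scale = 1.0 / float(trainers)
    return merged

# Usage: a sparse gradient keeps its type and is not rescaled.
grad = Var("emb@GRAD.trainer_0", type=VarType.SELECTED_ROWS)
recreated = Var(grad.name, type=grad.type, dtype=grad.dtype, shape=grad.shape)
assert recreated.type is VarType.SELECTED_ROWS  # type preserved, as in the fix
merged = merge_grads([recreated, recreated], trainers=2)
assert not hasattr(merged, "scale")             # sparse grads left as raw sum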