diff --git a/paddle/fluid/operators/split_selected_rows_op.cc b/paddle/fluid/operators/split_selected_rows_op.cc
index b0e21e01ecc55d666f43aee621e9b5a05b347e1d..e1ce3d0c1bf11e9a623e4e9adc8f08f5069f4d94 100644
--- a/paddle/fluid/operators/split_selected_rows_op.cc
+++ b/paddle/fluid/operators/split_selected_rows_op.cc
@@ -59,6 +59,16 @@ class SplitSelectedRowsOp : public framework::OperatorWithKernel {
   }
 };
 
+class SplitSelectedRowsOpInferVarType : public framework::VarTypeInference {
+ public:
+  void operator()(const framework::OpDesc &op_desc,
+                  framework::BlockDesc *block) const override {
+    for (auto &out_var : op_desc.Output("Out")) {
+      block->Var(out_var)->SetType(framework::proto::VarType::SELECTED_ROWS);
+    }
+  }
+};
+
 class SplitSelectedRowsGradMaker : public framework::SingleGradOpDescMaker {
  public:
   using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
@@ -80,7 +90,8 @@ class SplitSelectedRowsGradMaker : public framework::SingleGradOpDescMaker {
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(split_selected_rows, ops::SplitSelectedRowsOp,
                   ops::SplitSelectedRowsOpMaker,
-                  ops::SplitSelectedRowsGradMaker);
+                  ops::SplitSelectedRowsGradMaker,
+                  ops::SplitSelectedRowsOpInferVarType);
 REGISTER_OP_CPU_KERNEL(
     split_selected_rows,
     ops::SplitSelectedRowsOpKernel);
diff --git a/python/paddle/fluid/distribute_transpiler.py b/python/paddle/fluid/distribute_transpiler.py
index 8da9ca290b22ae69b1fd195d8614c31dc4e13e00..497bcf93a379e7eb0ce8a94d5702349b0547d14f 100644
--- a/python/paddle/fluid/distribute_transpiler.py
+++ b/python/paddle/fluid/distribute_transpiler.py
@@ -276,6 +276,7 @@ class DistributeTranspiler:
             pserver_program.global_block().create_var(
                 name=orig_var_name,
                 persistable=True,
+                type=v.type,
                 dtype=v.dtype,
                 shape=v.shape)
             print("create origin var: ", orig_var_name)
@@ -283,6 +284,7 @@ class DistributeTranspiler:
                 var = pserver_program.global_block().create_var(
                     name="%s.trainer_%d" % (orig_var_name, trainer_id),
                     persistable=False,
+                    type=v.type,
                     dtype=v.dtype,
                     shape=v.shape)
                 recv_inputs.append(var)
@@ -551,11 +553,12 @@ class DistributeTranspiler:
                         type="sum",
                         inputs={"X": vars2merge},
                         outputs={"Out": merged_var})
-                    optimize_block.append_op(
-                        type="scale",
-                        inputs={"X": merged_var},
-                        outputs={"Out": merged_var},
-                        attrs={"scale": 1.0 / float(self.trainers)})
+                    if not merged_var.type == core.VarDesc.VarType.SELECTED_ROWS:
+                        optimize_block.append_op(
+                            type="scale",
+                            inputs={"X": merged_var},
+                            outputs={"Out": merged_var},
+                            attrs={"scale": 1.0 / float(self.trainers)})
                 new_inputs[key] = merged_var
             elif key == "Param":
                 # param is already created on global program