diff --git a/paddle/framework/executor.cc b/paddle/framework/executor.cc index c0418c9266e257bd7567861543e557f354451b17..d8ef9a0fbaa7ba18d78060bd5b9605458cd9b1a2 100644 --- a/paddle/framework/executor.cc +++ b/paddle/framework/executor.cc @@ -35,7 +35,7 @@ const std::string kFetchOpType = "fetch"; Executor::Executor(const platform::Place& place) : place_(place) {} -static void CreateTensor(Variable* var, proto::VarDesc::VarType var_type) { +void CreateTensor(Variable* var, proto::VarDesc::VarType var_type) { if (var_type == proto::VarDesc::LOD_TENSOR) { var->GetMutable(); } else if (var_type == proto::VarDesc::SELECTED_ROWS) { diff --git a/paddle/framework/executor.h b/paddle/framework/executor.h index d869e18901b82959a40cc296aa0844c20ea63ac1..0b2b5780fed1ef48ba78f44112fb0a88b477b796 100644 --- a/paddle/framework/executor.h +++ b/paddle/framework/executor.h @@ -45,5 +45,7 @@ class Executor { const platform::Place place_; }; +void CreateTensor(Variable* var, proto::VarDesc::VarType var_type); + } // namespace framework } // namespace paddle diff --git a/paddle/operators/recv_op.cc b/paddle/operators/recv_op.cc index 82fceb3da7d396bcfc1d95baccc4ee36b87f4d39..6f65b87d3b06c1d8d453f42194a277accdbc1164 100644 --- a/paddle/operators/recv_op.cc +++ b/paddle/operators/recv_op.cc @@ -19,7 +19,6 @@ limitations under the License. */ #include -#include "paddle/framework/data_type.h" #include "paddle/framework/executor.h" #include "paddle/framework/framework.pb.h" #include "paddle/framework/lod_tensor.h" @@ -111,9 +110,11 @@ class RecvOp : public framework::OperatorBase { << " updating param: " << param_var_name; auto *merged_grad = recv_scope.FindVar(grad_var_name); if (merged_grad == nullptr) { - // create output of merged var. - auto merged_var = recv_scope.Var(grad_var_name); - merged_var->GetMutable(); + auto *ptr = recv_scope.Var(grad_var_name); + framework::CreateTensor(ptr, + framework::ToVarType(merged_grad->Type())); + VLOG(3) << "Create Variable " << grad_var_name + << " on recv scope, which pointer is " << ptr; } if (trainer_count > 1) {