From aa75f1e2c5d44230441aa4660c43500edd0753b9 Mon Sep 17 00:00:00 2001 From: Yancey Date: Mon, 8 Jan 2018 17:48:35 +0800 Subject: [PATCH] Create tensor in recv op (#7286) * create tensor in recv op * static global function to global function --- paddle/framework/executor.cc | 2 +- paddle/framework/executor.h | 2 ++ paddle/operators/recv_op.cc | 9 +++++---- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/paddle/framework/executor.cc b/paddle/framework/executor.cc index c0418c9266e..d8ef9a0fbaa 100644 --- a/paddle/framework/executor.cc +++ b/paddle/framework/executor.cc @@ -35,7 +35,7 @@ const std::string kFetchOpType = "fetch"; Executor::Executor(const platform::Place& place) : place_(place) {} -static void CreateTensor(Variable* var, proto::VarDesc::VarType var_type) { +void CreateTensor(Variable* var, proto::VarDesc::VarType var_type) { if (var_type == proto::VarDesc::LOD_TENSOR) { var->GetMutable(); } else if (var_type == proto::VarDesc::SELECTED_ROWS) { diff --git a/paddle/framework/executor.h b/paddle/framework/executor.h index d869e18901b..0b2b5780fed 100644 --- a/paddle/framework/executor.h +++ b/paddle/framework/executor.h @@ -45,5 +45,7 @@ class Executor { const platform::Place place_; }; +void CreateTensor(Variable* var, proto::VarDesc::VarType var_type); + } // namespace framework } // namespace paddle diff --git a/paddle/operators/recv_op.cc b/paddle/operators/recv_op.cc index 82fceb3da7d..6f65b87d3b0 100644 --- a/paddle/operators/recv_op.cc +++ b/paddle/operators/recv_op.cc @@ -19,7 +19,6 @@ limitations under the License. */ #include -#include "paddle/framework/data_type.h" #include "paddle/framework/executor.h" #include "paddle/framework/framework.pb.h" #include "paddle/framework/lod_tensor.h" @@ -111,9 +110,11 @@ class RecvOp : public framework::OperatorBase { << " updating param: " << param_var_name; auto *merged_grad = recv_scope.FindVar(grad_var_name); if (merged_grad == nullptr) { - // create output of merged var. - auto merged_var = recv_scope.Var(grad_var_name); - merged_var->GetMutable(); + auto *ptr = recv_scope.Var(grad_var_name); + framework::CreateTensor(ptr, + framework::ToVarType(merged_grad->Type())); + VLOG(3) << "Create Variable " << grad_var_name + << " on recv scope, which pointer is " << ptr; } if (trainer_count > 1) { -- GitLab