未验证 提交 aa75f1e2 编写于 作者: Y Yancey 提交者: GitHub

Create tensor in recv op (#7286)

* create tensor in recv op

* static global function to global function
上级 2d10c75b
...@@ -35,7 +35,7 @@ const std::string kFetchOpType = "fetch"; ...@@ -35,7 +35,7 @@ const std::string kFetchOpType = "fetch";
Executor::Executor(const platform::Place& place) : place_(place) {} Executor::Executor(const platform::Place& place) : place_(place) {}
static void CreateTensor(Variable* var, proto::VarDesc::VarType var_type) { void CreateTensor(Variable* var, proto::VarDesc::VarType var_type) {
if (var_type == proto::VarDesc::LOD_TENSOR) { if (var_type == proto::VarDesc::LOD_TENSOR) {
var->GetMutable<LoDTensor>(); var->GetMutable<LoDTensor>();
} else if (var_type == proto::VarDesc::SELECTED_ROWS) { } else if (var_type == proto::VarDesc::SELECTED_ROWS) {
......
...@@ -45,5 +45,7 @@ class Executor { ...@@ -45,5 +45,7 @@ class Executor {
const platform::Place place_; const platform::Place place_;
}; };
void CreateTensor(Variable* var, proto::VarDesc::VarType var_type);
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -19,7 +19,6 @@ limitations under the License. */ ...@@ -19,7 +19,6 @@ limitations under the License. */
#include <unistd.h> #include <unistd.h>
#include "paddle/framework/data_type.h"
#include "paddle/framework/executor.h" #include "paddle/framework/executor.h"
#include "paddle/framework/framework.pb.h" #include "paddle/framework/framework.pb.h"
#include "paddle/framework/lod_tensor.h" #include "paddle/framework/lod_tensor.h"
...@@ -111,9 +110,11 @@ class RecvOp : public framework::OperatorBase { ...@@ -111,9 +110,11 @@ class RecvOp : public framework::OperatorBase {
<< " updating param: " << param_var_name; << " updating param: " << param_var_name;
auto *merged_grad = recv_scope.FindVar(grad_var_name); auto *merged_grad = recv_scope.FindVar(grad_var_name);
if (merged_grad == nullptr) { if (merged_grad == nullptr) {
// create output of merged var. auto *ptr = recv_scope.Var(grad_var_name);
auto merged_var = recv_scope.Var(grad_var_name); framework::CreateTensor(ptr,
merged_var->GetMutable<framework::LoDTensor>(); framework::ToVarType(merged_grad->Type()));
VLOG(3) << "Create Variable " << grad_var_name
<< " on recv scope, which pointer is " << ptr;
} }
if (trainer_count > 1) { if (trainer_count > 1) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册