未验证 提交 aa75f1e2 编写于 作者: Y Yancey 提交者: GitHub

Create tensor in recv op (#7286)

* create tensor in recv op

* static global function to global function
上级 2d10c75b
......@@ -35,7 +35,7 @@ const std::string kFetchOpType = "fetch";
Executor::Executor(const platform::Place& place) : place_(place) {}
static void CreateTensor(Variable* var, proto::VarDesc::VarType var_type) {
void CreateTensor(Variable* var, proto::VarDesc::VarType var_type) {
if (var_type == proto::VarDesc::LOD_TENSOR) {
var->GetMutable<LoDTensor>();
} else if (var_type == proto::VarDesc::SELECTED_ROWS) {
......
......@@ -45,5 +45,7 @@ class Executor {
const platform::Place place_;
};
void CreateTensor(Variable* var, proto::VarDesc::VarType var_type);
} // namespace framework
} // namespace paddle
......@@ -19,7 +19,6 @@ limitations under the License. */
#include <unistd.h>
#include "paddle/framework/data_type.h"
#include "paddle/framework/executor.h"
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/lod_tensor.h"
......@@ -111,9 +110,11 @@ class RecvOp : public framework::OperatorBase {
<< " updating param: " << param_var_name;
auto *merged_grad = recv_scope.FindVar(grad_var_name);
if (merged_grad == nullptr) {
// create output of merged var.
auto merged_var = recv_scope.Var(grad_var_name);
merged_var->GetMutable<framework::LoDTensor>();
auto *ptr = recv_scope.Var(grad_var_name);
framework::CreateTensor(ptr,
framework::ToVarType(merged_grad->Type()));
VLOG(3) << "Create Variable " << grad_var_name
<< " on recv scope, which pointer is " << ptr;
}
if (trainer_count > 1) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册