提交 6033c1a2 编写于 作者: J jerrywgz

Add error info & remove data sharing between input and output in rnn_memory_helper_op

上级 d361624c
...@@ -42,7 +42,7 @@ class RNNMemoryHelperOp : public framework::OperatorBase { ...@@ -42,7 +42,7 @@ class RNNMemoryHelperOp : public framework::OperatorBase {
auto *out_tensor = out_var->GetMutable<framework::LoDTensor>(); auto *out_tensor = out_var->GetMutable<framework::LoDTensor>();
auto &mem_tensor = mem_var->Get<framework::LoDTensor>(); auto &mem_tensor = mem_var->Get<framework::LoDTensor>();
out_tensor->ShareDataWith(mem_tensor); framework::TensorCopySync(mem_tensor, dev_place, out_tensor);
out_tensor->set_lod(mem_tensor.lod()); out_tensor->set_lod(mem_tensor.lod());
} }
}; };
...@@ -50,8 +50,10 @@ class RNNMemoryHelperOp : public framework::OperatorBase { ...@@ -50,8 +50,10 @@ class RNNMemoryHelperOp : public framework::OperatorBase {
class RNNMemoryHelperOpShapeInference : public framework::InferShapeBase { class RNNMemoryHelperOpShapeInference : public framework::InferShapeBase {
public: public:
void operator()(framework::InferShapeContext *ctx) const override { void operator()(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), ""); PADDLE_ENFORCE(ctx->HasInput("X"),
PADDLE_ENFORCE(ctx->HasOutput("Out"), ""); "Input(X) of rnn_memory_helper op should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output of rnn_memory_helper op should not be null.");
ctx->SetOutputDim("Out", ctx->GetInputDim("X")); ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out");
} }
...@@ -107,7 +109,7 @@ class RNNMemoryHelperGradOp : public framework::OperatorBase { ...@@ -107,7 +109,7 @@ class RNNMemoryHelperGradOp : public framework::OperatorBase {
} else { } else {
auto &out_grad_tensor = out_grad_var->Get<framework::LoDTensor>(); auto &out_grad_tensor = out_grad_var->Get<framework::LoDTensor>();
auto *in_grad_tensor = in_grad_var->GetMutable<framework::LoDTensor>(); auto *in_grad_tensor = in_grad_var->GetMutable<framework::LoDTensor>();
in_grad_tensor->ShareDataWith(out_grad_tensor); framework::TensorCopySync(out_grad_tensor, dev_place, in_grad_tensor);
in_grad_tensor->set_lod(out_grad_tensor.lod()); in_grad_tensor->set_lod(out_grad_tensor.lod());
} }
} }
...@@ -133,8 +135,11 @@ class RNNMemoryHelperGradOpShapeInference : public framework::InferShapeBase { ...@@ -133,8 +135,11 @@ class RNNMemoryHelperGradOpShapeInference : public framework::InferShapeBase {
public: public:
void operator()(framework::InferShapeContext *ctx) const override { void operator()(framework::InferShapeContext *ctx) const override {
auto x_grad_name = framework::GradVarName("X"); auto x_grad_name = framework::GradVarName("X");
PADDLE_ENFORCE(ctx->HasOutput(x_grad_name), ""); PADDLE_ENFORCE(ctx->HasOutput(x_grad_name),
PADDLE_ENFORCE(ctx->HasInput("X"), ""); "Gradient of Input(X) in rnn_memory_helper_grad of should "
"not be null.");
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of rnn_memory_helper_grad of should not be null.");
ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("X")); ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ x_grad_name); ctx->ShareLoD("X", /*->*/ x_grad_name);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册