From 6033c1a27839eadab7f84a5281905d3b23de86fc Mon Sep 17 00:00:00 2001 From: jerrywgz Date: Mon, 27 Aug 2018 03:03:58 +0000 Subject: [PATCH] Add error info & remove data sharing between input and output in rnn_memory_helper_op --- paddle/fluid/operators/rnn_memory_helper_op.cc | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/operators/rnn_memory_helper_op.cc b/paddle/fluid/operators/rnn_memory_helper_op.cc index 23e5fc1112d..13df1d4b4bb 100644 --- a/paddle/fluid/operators/rnn_memory_helper_op.cc +++ b/paddle/fluid/operators/rnn_memory_helper_op.cc @@ -42,7 +42,7 @@ class RNNMemoryHelperOp : public framework::OperatorBase { auto *out_tensor = out_var->GetMutable(); auto &mem_tensor = mem_var->Get(); - out_tensor->ShareDataWith(mem_tensor); + framework::TensorCopySync(mem_tensor, dev_place, out_tensor); out_tensor->set_lod(mem_tensor.lod()); } }; @@ -50,8 +50,10 @@ class RNNMemoryHelperOp : public framework::OperatorBase { class RNNMemoryHelperOpShapeInference : public framework::InferShapeBase { public: void operator()(framework::InferShapeContext *ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("X"), ""); - PADDLE_ENFORCE(ctx->HasOutput("Out"), ""); + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of rnn_memory_helper op should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output of rnn_memory_helper op should not be null."); ctx->SetOutputDim("Out", ctx->GetInputDim("X")); ctx->ShareLoD("X", /*->*/ "Out"); } @@ -107,7 +109,7 @@ class RNNMemoryHelperGradOp : public framework::OperatorBase { } else { auto &out_grad_tensor = out_grad_var->Get(); auto *in_grad_tensor = in_grad_var->GetMutable(); - in_grad_tensor->ShareDataWith(out_grad_tensor); + framework::TensorCopySync(out_grad_tensor, dev_place, in_grad_tensor); in_grad_tensor->set_lod(out_grad_tensor.lod()); } } @@ -133,8 +135,11 @@ class RNNMemoryHelperGradOpShapeInference : public framework::InferShapeBase { public: void operator()(framework::InferShapeContext *ctx) const override { auto x_grad_name = framework::GradVarName("X"); - PADDLE_ENFORCE(ctx->HasOutput(x_grad_name), ""); - PADDLE_ENFORCE(ctx->HasInput("X"), ""); + PADDLE_ENFORCE(ctx->HasOutput(x_grad_name), + "Gradient of Input(X) in rnn_memory_helper_grad of should " + "not be null."); + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of rnn_memory_helper_grad of should not be null."); ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("X")); ctx->ShareLoD("X", /*->*/ x_grad_name); } -- GitLab