提交 6033c1a2 编写于 作者: J jerrywgz

Add error info & remove data sharing between input and output in rnn_memory_helper_op

上级 d361624c
......@@ -42,7 +42,7 @@ class RNNMemoryHelperOp : public framework::OperatorBase {
auto *out_tensor = out_var->GetMutable<framework::LoDTensor>();
auto &mem_tensor = mem_var->Get<framework::LoDTensor>();
out_tensor->ShareDataWith(mem_tensor);
framework::TensorCopySync(mem_tensor, dev_place, out_tensor);
out_tensor->set_lod(mem_tensor.lod());
}
};
......@@ -50,8 +50,10 @@ class RNNMemoryHelperOp : public framework::OperatorBase {
class RNNMemoryHelperOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "");
PADDLE_ENFORCE(ctx->HasOutput("Out"), "");
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of rnn_memory_helper op should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output of rnn_memory_helper op should not be null.");
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ "Out");
}
......@@ -107,7 +109,7 @@ class RNNMemoryHelperGradOp : public framework::OperatorBase {
} else {
auto &out_grad_tensor = out_grad_var->Get<framework::LoDTensor>();
auto *in_grad_tensor = in_grad_var->GetMutable<framework::LoDTensor>();
in_grad_tensor->ShareDataWith(out_grad_tensor);
framework::TensorCopySync(out_grad_tensor, dev_place, in_grad_tensor);
in_grad_tensor->set_lod(out_grad_tensor.lod());
}
}
......@@ -133,8 +135,11 @@ class RNNMemoryHelperGradOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
auto x_grad_name = framework::GradVarName("X");
PADDLE_ENFORCE(ctx->HasOutput(x_grad_name), "");
PADDLE_ENFORCE(ctx->HasInput("X"), "");
PADDLE_ENFORCE(ctx->HasOutput(x_grad_name),
"Gradient of Input(X) in rnn_memory_helper_grad of should "
"not be null.");
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of rnn_memory_helper_grad of should not be null.");
ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ x_grad_name);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册