未验证 提交 a7e21cbe 编写于 作者: A Aurelius84 提交者: GitHub

Move input_size check into RunTime phrase of gru_unit_op and refine error message (#24776)

* Add IsRuntime judgement in GRUUnit test=develop

* add IsRuntime judgement is GradOp test=develop

* Refine Error Message of SelecteInput/Output test=develop

* refine Error Message of RNNMemoryHelperOp test=develop
上级 d160e57a
......@@ -25,19 +25,14 @@ class GRUUnitOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"),
"Input(%s) of GRUUnitOp should not be null.", "Input");
PADDLE_ENFORCE(ctx->HasInput("HiddenPrev"),
"Input(%s) of GRUUnitOp should not be null.", "HiddenPrev");
PADDLE_ENFORCE(ctx->HasInput("Weight"),
"Input(%s) of GRUUnitOp should not be null.", "Weight");
PADDLE_ENFORCE(ctx->HasOutput("Gate"),
"Output(%s) of GRUUnitOp should not be null.", "Gate");
PADDLE_ENFORCE(ctx->HasOutput("ResetHiddenPrev"),
"Output(%s) of GRUUnitOp should not be null.",
"ResetHiddenPrev");
PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
"Output(%s) of GRUUnitOp should not be null.", "Hidden");
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "GRUUnit");
OP_INOUT_CHECK(ctx->HasInput("HiddenPrev"), "Input", "HiddenPrev",
"GRUUnit");
OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "GRUUnit");
OP_INOUT_CHECK(ctx->HasOutput("Gate"), "Output", "Gate", "GRUUnit");
OP_INOUT_CHECK(ctx->HasOutput("ResetHiddenPrev"), "Output",
"ResetHiddenPrev", "GRUUnit");
OP_INOUT_CHECK(ctx->HasOutput("Hidden"), "Output", "Hidden", "GRUUnit");
auto input_dims = ctx->GetInputDim("Input");
auto hidden_prev_dims = ctx->GetInputDim("HiddenPrev");
auto weight_dims = ctx->GetInputDim("Weight");
......@@ -46,23 +41,45 @@ class GRUUnitOp : public framework::OperatorWithKernel {
int frame_size = hidden_prev_dims[1];
int weight_height = weight_dims[0];
int weight_width = weight_dims[1];
PADDLE_ENFORCE_EQ(
input_size, frame_size * 3,
"The input_size must be 3 times of frame_size in GRUUnitOp.");
if (ctx->IsRuntime() || input_size >= 0) {
PADDLE_ENFORCE_EQ(input_size, frame_size * 3,
platform::errors::InvalidArgument(
"The second dimension of Input(Input) must be 3 "
"times of frame_size in GRUUnitOp, but received %d "
"(Input) vs %d (frame_size).",
input_size, frame_size));
}
PADDLE_ENFORCE_EQ(
weight_height, frame_size,
"The shape of Weight matrix must be [frame_size, frame_size * 3].");
platform::errors::InvalidArgument(
"The shape of Input(Weight) matrix must be [frame_size, frame_size "
"* 3] in GRUUnitOp, but received [%d, %d] (Weight) vs [%d, %d] "
"(frame_size).",
weight_height, weight_width, frame_size, frame_size * 3));
PADDLE_ENFORCE_EQ(
weight_width, frame_size * 3,
"The shape of Weight matrix must be [frame_size, frame_size * 3].");
platform::errors::InvalidArgument(
"The shape of Input(Weight) matrix must be [frame_size, frame_size "
"* 3] in GRUUnitOp, but received [%d, %d] (Weight) vs [%d, %d] "
"(frame_size).",
weight_height, weight_width, frame_size, frame_size * 3));
if (ctx->HasInput("Bias")) {
auto bias_dims = ctx->GetInputDim("Bias");
int bias_height = bias_dims[0];
int bias_width = bias_dims[1];
PADDLE_ENFORCE_EQ(bias_height, 1,
"The shape of Bias must be [1, frame_size * 3].");
PADDLE_ENFORCE_EQ(bias_width, frame_size * 3,
"The shape of Bias must be [1, frame_size * 3].");
PADDLE_ENFORCE_EQ(
bias_height, 1,
platform::errors::InvalidArgument(
"The shape of Bias must be [1, frame_size * 3], but received "
"[%d, %d] (Bias) vs [1, %d] (frame_size * 3).",
bias_height, bias_width, frame_size * 3));
PADDLE_ENFORCE_EQ(
bias_width, frame_size * 3,
platform::errors::InvalidArgument(
"The shape of Bias must be [1, frame_size * 3], but received "
"[%d, %d] (Bias) vs [1, %d] (frame_size * 3).",
bias_height, bias_width, frame_size * 3));
}
ctx->SetOutputDim("Gate", {batch_size, frame_size * 3});
ctx->SetOutputDim("ResetHiddenPrev", {batch_size, frame_size});
......@@ -143,21 +160,16 @@ class GRUUnitGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"),
"Input(%s) of GRUUnitGradOp should not be null.", "Input");
PADDLE_ENFORCE(ctx->HasInput("HiddenPrev"),
"Input(%s) of GRUUnitGradOp should not be null.",
"HiddenPrev");
PADDLE_ENFORCE(ctx->HasInput("Weight"),
"Input(%s) of GRUUnitGradOp should not be null.", "Weight");
PADDLE_ENFORCE(ctx->HasInput("Gate"),
"Input(%s) of GRUUnitGradOp should not be null.", "Gate");
PADDLE_ENFORCE(ctx->HasInput("ResetHiddenPrev"),
"Input(%s) of GRUUnitGradOp should not be null.",
"ResetHiddenPrev");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Hidden")),
"Input(%s@GRAD) of GRUUnitGradOp should not be null.",
"Hidden");
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "GRUUnitGrad");
OP_INOUT_CHECK(ctx->HasInput("HiddenPrev"), "Input", "HiddenPrev",
"GRUUnitGrad");
OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "GRUUnitGrad");
OP_INOUT_CHECK(ctx->HasInput("Gate"), "Input", "Gate", "GRUUnitGrad");
OP_INOUT_CHECK(ctx->HasInput("ResetHiddenPrev"), "Input", "ResetHiddenPrev",
"GRUUnitGrad");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Hidden")), "Input",
"Hidden@GRAD", "GRUUnitGrad");
auto input_dims = ctx->GetInputDim("Input");
auto hidden_prev_dims = ctx->GetInputDim("HiddenPrev");
auto weight_dims = ctx->GetInputDim("Weight");
......@@ -166,23 +178,46 @@ class GRUUnitGradOp : public framework::OperatorWithKernel {
int frame_size = hidden_prev_dims[1];
int weight_height = weight_dims[0];
int weight_width = weight_dims[1];
if (ctx->IsRuntime() || input_size >= 0) {
PADDLE_ENFORCE_EQ(
input_size, frame_size * 3,
"The input_size must be 3 times of frame_size in GRUUnitOp.");
platform::errors::InvalidArgument(
"The second dimension of Input(Input) must be 3 "
"times of frame_size in GRUUnitGradOp, but received %d "
"(Input) vs %d (frame_size).",
input_size, frame_size));
}
PADDLE_ENFORCE_EQ(
weight_height, frame_size,
"The shape of Weight matrix must be [frame_size, frame_size * 3].");
platform::errors::InvalidArgument(
"The shape of Input(Weight) matrix must be [frame_size, frame_size "
"* 3] in GRUUnitGradOp, but received [%d, %d] (Weight) vs [%d, %d] "
"(frame_size).",
weight_height, weight_width, frame_size, frame_size * 3));
PADDLE_ENFORCE_EQ(
weight_width, frame_size * 3,
"The shape of Weight matrix must be [frame_size, frame_size * 3].");
platform::errors::InvalidArgument(
"The shape of Input(Weight) matrix must be [frame_size, frame_size "
"* 3] in GRUUnitGradOp, but received [%d, %d] (Weight) vs [%d, %d] "
"(frame_size).",
weight_height, weight_width, frame_size, frame_size * 3));
if (ctx->HasInput("Bias")) {
auto bias_dims = ctx->GetInputDim("Bias");
int bias_height = bias_dims[0];
int bias_width = bias_dims[1];
PADDLE_ENFORCE_EQ(bias_height, 1,
"The shape of Bias must be [1, frame_size * 3].");
PADDLE_ENFORCE_EQ(bias_width, frame_size * 3,
"The shape of Bias must be [1, frame_size * 3].");
PADDLE_ENFORCE_EQ(
bias_height, 1,
platform::errors::InvalidArgument(
"The shape of Bias must be [1, frame_size * 3], but received "
"[%d, %d] (Bias) vs [1, %d] (frame_size * 3).",
bias_height, bias_width, frame_size * 3));
PADDLE_ENFORCE_EQ(
bias_width, frame_size * 3,
platform::errors::InvalidArgument(
"The shape of Bias must be [1, frame_size * 3], but received "
"[%d, %d] (Bias) vs [1, %d] (frame_size * 3).",
bias_height, bias_width, frame_size * 3));
auto bias_grad_name = framework::GradVarName("Bias");
if (ctx->HasOutput(bias_grad_name))
ctx->SetOutputDim(bias_grad_name, bias_dims);
......
......@@ -30,15 +30,15 @@ class RNNMemoryHelperOp : public framework::OperatorBase {
const platform::Place &dev_place) const override {
auto mem_var_name = Input("X");
auto *mem_var = scope.FindVar(mem_var_name);
PADDLE_ENFORCE(mem_var != nullptr,
"Cannot find mem_var in scope, mem_var_name is %s",
mem_var_name);
PADDLE_ENFORCE_NOT_NULL(
mem_var, platform::errors::NotFound("Cannot find mem_var: %s in scope.",
mem_var_name));
auto out_name = this->Output("Out");
auto *out_var = scope.FindVar(out_name);
PADDLE_ENFORCE(out_var != nullptr,
"Cannot find out_var in scope, out_var_name is %s",
out_name);
PADDLE_ENFORCE_NOT_NULL(
out_var, platform::errors::NotFound("Cannot find out_var: %s in scope.",
out_name));
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(dev_place);
......@@ -53,10 +53,9 @@ class RNNMemoryHelperOp : public framework::OperatorBase {
class RNNMemoryHelperOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of rnn_memory_helper op should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output of rnn_memory_helper op should not be null.");
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "RNNMemoryHelper");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "RNNMemoryHelper");
ctx->ShareDim("X", /*->*/ "Out");
ctx->ShareLoD("X", /*->*/ "Out");
}
......@@ -91,10 +90,10 @@ class RNNMemoryHelperGradOp : public framework::OperatorBase {
auto in_grad_var_name = Output(framework::GradVarName("X"));
auto *in_grad_var = scope.FindVar(in_grad_var_name);
PADDLE_ENFORCE(in_grad_var != nullptr,
"Cannot find in_grad_var in scope, name is %s",
in_grad_var_name);
PADDLE_ENFORCE_NOT_NULL(
in_grad_var,
platform::errors::NotFound("Cannot find in_grad_var: %s in scope.",
in_grad_var_name));
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(dev_place);
......@@ -143,11 +142,9 @@ class RNNMemoryHelperGradOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
auto x_grad_name = framework::GradVarName("X");
PADDLE_ENFORCE(ctx->HasOutput(x_grad_name),
"Gradient of Input(X) in rnn_memory_helper_grad of should "
"not be null.");
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of rnn_memory_helper_grad of should not be null.");
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "RNNMemoryHelperGrad");
OP_INOUT_CHECK(ctx->HasOutput(x_grad_name), "Output", x_grad_name,
"RNNMemoryHelperGrad");
ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ x_grad_name);
}
......
......@@ -80,12 +80,9 @@ specifying the output branchi.
class SelectInputInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *context) const override {
PADDLE_ENFORCE_EQ(context->HasInputs("X"), true,
"SelectInputOp must have input X.");
PADDLE_ENFORCE_EQ(context->HasInput("Mask"), true,
"SelectInputOp must have input Mask.");
PADDLE_ENFORCE_EQ(context->HasOutput("Out"), true,
"SelectInputOp must have output Out.");
OP_INOUT_CHECK(context->HasInputs("X"), "Input", "X", "SelectInput");
OP_INOUT_CHECK(context->HasInput("Mask"), "Input", "Mask", "SelectInput");
OP_INOUT_CHECK(context->HasOutput("Out"), "Output", "Out", "SelectInput");
}
};
......
......@@ -78,12 +78,9 @@ specify which output branch should copy the input.
class SelectOutputInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *context) const override {
PADDLE_ENFORCE_EQ(context->HasInput("X"), true,
"SelectOutputOp must have input X.");
PADDLE_ENFORCE_EQ(context->HasInput("Mask"), true,
"SelectOutputOp must have input Mask.");
PADDLE_ENFORCE_EQ(context->HasOutputs("Out"), true,
"SelectOutputOp must have output Out.");
OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "SelectOutput");
OP_INOUT_CHECK(context->HasInput("Mask"), "Input", "Mask", "SelectOutput");
OP_INOUT_CHECK(context->HasOutputs("Out"), "Output", "Out", "SelectOutput");
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册