未验证 提交 a7e21cbe 编写于 作者: A Aurelius84 提交者: GitHub

Move input_size check into RunTime phrase of gru_unit_op and refine error message (#24776)

* Add IsRuntime judgement in GRUUnit test=develop

* add IsRuntime judgement is GradOp test=develop

* Refine Error Message of SelecteInput/Output test=develop

* refine Error Message of RNNMemoryHelperOp test=develop
上级 d160e57a
...@@ -25,19 +25,14 @@ class GRUUnitOp : public framework::OperatorWithKernel { ...@@ -25,19 +25,14 @@ class GRUUnitOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"), OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "GRUUnit");
"Input(%s) of GRUUnitOp should not be null.", "Input"); OP_INOUT_CHECK(ctx->HasInput("HiddenPrev"), "Input", "HiddenPrev",
PADDLE_ENFORCE(ctx->HasInput("HiddenPrev"), "GRUUnit");
"Input(%s) of GRUUnitOp should not be null.", "HiddenPrev"); OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "GRUUnit");
PADDLE_ENFORCE(ctx->HasInput("Weight"), OP_INOUT_CHECK(ctx->HasOutput("Gate"), "Output", "Gate", "GRUUnit");
"Input(%s) of GRUUnitOp should not be null.", "Weight"); OP_INOUT_CHECK(ctx->HasOutput("ResetHiddenPrev"), "Output",
PADDLE_ENFORCE(ctx->HasOutput("Gate"), "ResetHiddenPrev", "GRUUnit");
"Output(%s) of GRUUnitOp should not be null.", "Gate"); OP_INOUT_CHECK(ctx->HasOutput("Hidden"), "Output", "Hidden", "GRUUnit");
PADDLE_ENFORCE(ctx->HasOutput("ResetHiddenPrev"),
"Output(%s) of GRUUnitOp should not be null.",
"ResetHiddenPrev");
PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
"Output(%s) of GRUUnitOp should not be null.", "Hidden");
auto input_dims = ctx->GetInputDim("Input"); auto input_dims = ctx->GetInputDim("Input");
auto hidden_prev_dims = ctx->GetInputDim("HiddenPrev"); auto hidden_prev_dims = ctx->GetInputDim("HiddenPrev");
auto weight_dims = ctx->GetInputDim("Weight"); auto weight_dims = ctx->GetInputDim("Weight");
...@@ -46,23 +41,45 @@ class GRUUnitOp : public framework::OperatorWithKernel { ...@@ -46,23 +41,45 @@ class GRUUnitOp : public framework::OperatorWithKernel {
int frame_size = hidden_prev_dims[1]; int frame_size = hidden_prev_dims[1];
int weight_height = weight_dims[0]; int weight_height = weight_dims[0];
int weight_width = weight_dims[1]; int weight_width = weight_dims[1];
PADDLE_ENFORCE_EQ( if (ctx->IsRuntime() || input_size >= 0) {
input_size, frame_size * 3, PADDLE_ENFORCE_EQ(input_size, frame_size * 3,
"The input_size must be 3 times of frame_size in GRUUnitOp."); platform::errors::InvalidArgument(
"The second dimension of Input(Input) must be 3 "
"times of frame_size in GRUUnitOp, but received %d "
"(Input) vs %d (frame_size).",
input_size, frame_size));
}
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
weight_height, frame_size, weight_height, frame_size,
"The shape of Weight matrix must be [frame_size, frame_size * 3]."); platform::errors::InvalidArgument(
"The shape of Input(Weight) matrix must be [frame_size, frame_size "
"* 3] in GRUUnitOp, but received [%d, %d] (Weight) vs [%d, %d] "
"(frame_size).",
weight_height, weight_width, frame_size, frame_size * 3));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
weight_width, frame_size * 3, weight_width, frame_size * 3,
"The shape of Weight matrix must be [frame_size, frame_size * 3]."); platform::errors::InvalidArgument(
"The shape of Input(Weight) matrix must be [frame_size, frame_size "
"* 3] in GRUUnitOp, but received [%d, %d] (Weight) vs [%d, %d] "
"(frame_size).",
weight_height, weight_width, frame_size, frame_size * 3));
if (ctx->HasInput("Bias")) { if (ctx->HasInput("Bias")) {
auto bias_dims = ctx->GetInputDim("Bias"); auto bias_dims = ctx->GetInputDim("Bias");
int bias_height = bias_dims[0]; int bias_height = bias_dims[0];
int bias_width = bias_dims[1]; int bias_width = bias_dims[1];
PADDLE_ENFORCE_EQ(bias_height, 1, PADDLE_ENFORCE_EQ(
"The shape of Bias must be [1, frame_size * 3]."); bias_height, 1,
PADDLE_ENFORCE_EQ(bias_width, frame_size * 3, platform::errors::InvalidArgument(
"The shape of Bias must be [1, frame_size * 3]."); "The shape of Bias must be [1, frame_size * 3], but received "
"[%d, %d] (Bias) vs [1, %d] (frame_size * 3).",
bias_height, bias_width, frame_size * 3));
PADDLE_ENFORCE_EQ(
bias_width, frame_size * 3,
platform::errors::InvalidArgument(
"The shape of Bias must be [1, frame_size * 3], but received "
"[%d, %d] (Bias) vs [1, %d] (frame_size * 3).",
bias_height, bias_width, frame_size * 3));
} }
ctx->SetOutputDim("Gate", {batch_size, frame_size * 3}); ctx->SetOutputDim("Gate", {batch_size, frame_size * 3});
ctx->SetOutputDim("ResetHiddenPrev", {batch_size, frame_size}); ctx->SetOutputDim("ResetHiddenPrev", {batch_size, frame_size});
...@@ -143,21 +160,16 @@ class GRUUnitGradOp : public framework::OperatorWithKernel { ...@@ -143,21 +160,16 @@ class GRUUnitGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"), OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "GRUUnitGrad");
"Input(%s) of GRUUnitGradOp should not be null.", "Input"); OP_INOUT_CHECK(ctx->HasInput("HiddenPrev"), "Input", "HiddenPrev",
PADDLE_ENFORCE(ctx->HasInput("HiddenPrev"), "GRUUnitGrad");
"Input(%s) of GRUUnitGradOp should not be null.", OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "GRUUnitGrad");
"HiddenPrev"); OP_INOUT_CHECK(ctx->HasInput("Gate"), "Input", "Gate", "GRUUnitGrad");
PADDLE_ENFORCE(ctx->HasInput("Weight"), OP_INOUT_CHECK(ctx->HasInput("ResetHiddenPrev"), "Input", "ResetHiddenPrev",
"Input(%s) of GRUUnitGradOp should not be null.", "Weight"); "GRUUnitGrad");
PADDLE_ENFORCE(ctx->HasInput("Gate"), OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Hidden")), "Input",
"Input(%s) of GRUUnitGradOp should not be null.", "Gate"); "Hidden@GRAD", "GRUUnitGrad");
PADDLE_ENFORCE(ctx->HasInput("ResetHiddenPrev"),
"Input(%s) of GRUUnitGradOp should not be null.",
"ResetHiddenPrev");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Hidden")),
"Input(%s@GRAD) of GRUUnitGradOp should not be null.",
"Hidden");
auto input_dims = ctx->GetInputDim("Input"); auto input_dims = ctx->GetInputDim("Input");
auto hidden_prev_dims = ctx->GetInputDim("HiddenPrev"); auto hidden_prev_dims = ctx->GetInputDim("HiddenPrev");
auto weight_dims = ctx->GetInputDim("Weight"); auto weight_dims = ctx->GetInputDim("Weight");
...@@ -166,23 +178,46 @@ class GRUUnitGradOp : public framework::OperatorWithKernel { ...@@ -166,23 +178,46 @@ class GRUUnitGradOp : public framework::OperatorWithKernel {
int frame_size = hidden_prev_dims[1]; int frame_size = hidden_prev_dims[1];
int weight_height = weight_dims[0]; int weight_height = weight_dims[0];
int weight_width = weight_dims[1]; int weight_width = weight_dims[1];
if (ctx->IsRuntime() || input_size >= 0) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
input_size, frame_size * 3, input_size, frame_size * 3,
"The input_size must be 3 times of frame_size in GRUUnitOp."); platform::errors::InvalidArgument(
"The second dimension of Input(Input) must be 3 "
"times of frame_size in GRUUnitGradOp, but received %d "
"(Input) vs %d (frame_size).",
input_size, frame_size));
}
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
weight_height, frame_size, weight_height, frame_size,
"The shape of Weight matrix must be [frame_size, frame_size * 3]."); platform::errors::InvalidArgument(
"The shape of Input(Weight) matrix must be [frame_size, frame_size "
"* 3] in GRUUnitGradOp, but received [%d, %d] (Weight) vs [%d, %d] "
"(frame_size).",
weight_height, weight_width, frame_size, frame_size * 3));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
weight_width, frame_size * 3, weight_width, frame_size * 3,
"The shape of Weight matrix must be [frame_size, frame_size * 3]."); platform::errors::InvalidArgument(
"The shape of Input(Weight) matrix must be [frame_size, frame_size "
"* 3] in GRUUnitGradOp, but received [%d, %d] (Weight) vs [%d, %d] "
"(frame_size).",
weight_height, weight_width, frame_size, frame_size * 3));
if (ctx->HasInput("Bias")) { if (ctx->HasInput("Bias")) {
auto bias_dims = ctx->GetInputDim("Bias"); auto bias_dims = ctx->GetInputDim("Bias");
int bias_height = bias_dims[0]; int bias_height = bias_dims[0];
int bias_width = bias_dims[1]; int bias_width = bias_dims[1];
PADDLE_ENFORCE_EQ(bias_height, 1,
"The shape of Bias must be [1, frame_size * 3]."); PADDLE_ENFORCE_EQ(
PADDLE_ENFORCE_EQ(bias_width, frame_size * 3, bias_height, 1,
"The shape of Bias must be [1, frame_size * 3]."); platform::errors::InvalidArgument(
"The shape of Bias must be [1, frame_size * 3], but received "
"[%d, %d] (Bias) vs [1, %d] (frame_size * 3).",
bias_height, bias_width, frame_size * 3));
PADDLE_ENFORCE_EQ(
bias_width, frame_size * 3,
platform::errors::InvalidArgument(
"The shape of Bias must be [1, frame_size * 3], but received "
"[%d, %d] (Bias) vs [1, %d] (frame_size * 3).",
bias_height, bias_width, frame_size * 3));
auto bias_grad_name = framework::GradVarName("Bias"); auto bias_grad_name = framework::GradVarName("Bias");
if (ctx->HasOutput(bias_grad_name)) if (ctx->HasOutput(bias_grad_name))
ctx->SetOutputDim(bias_grad_name, bias_dims); ctx->SetOutputDim(bias_grad_name, bias_dims);
......
...@@ -30,15 +30,15 @@ class RNNMemoryHelperOp : public framework::OperatorBase { ...@@ -30,15 +30,15 @@ class RNNMemoryHelperOp : public framework::OperatorBase {
const platform::Place &dev_place) const override { const platform::Place &dev_place) const override {
auto mem_var_name = Input("X"); auto mem_var_name = Input("X");
auto *mem_var = scope.FindVar(mem_var_name); auto *mem_var = scope.FindVar(mem_var_name);
PADDLE_ENFORCE(mem_var != nullptr, PADDLE_ENFORCE_NOT_NULL(
"Cannot find mem_var in scope, mem_var_name is %s", mem_var, platform::errors::NotFound("Cannot find mem_var: %s in scope.",
mem_var_name); mem_var_name));
auto out_name = this->Output("Out"); auto out_name = this->Output("Out");
auto *out_var = scope.FindVar(out_name); auto *out_var = scope.FindVar(out_name);
PADDLE_ENFORCE(out_var != nullptr, PADDLE_ENFORCE_NOT_NULL(
"Cannot find out_var in scope, out_var_name is %s", out_var, platform::errors::NotFound("Cannot find out_var: %s in scope.",
out_name); out_name));
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(dev_place); auto &dev_ctx = *pool.Get(dev_place);
...@@ -53,10 +53,9 @@ class RNNMemoryHelperOp : public framework::OperatorBase { ...@@ -53,10 +53,9 @@ class RNNMemoryHelperOp : public framework::OperatorBase {
class RNNMemoryHelperOpShapeInference : public framework::InferShapeBase { class RNNMemoryHelperOpShapeInference : public framework::InferShapeBase {
public: public:
void operator()(framework::InferShapeContext *ctx) const override { void operator()(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "RNNMemoryHelper");
"Input(X) of rnn_memory_helper op should not be null."); OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "RNNMemoryHelper");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output of rnn_memory_helper op should not be null.");
ctx->ShareDim("X", /*->*/ "Out"); ctx->ShareDim("X", /*->*/ "Out");
ctx->ShareLoD("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out");
} }
...@@ -91,10 +90,10 @@ class RNNMemoryHelperGradOp : public framework::OperatorBase { ...@@ -91,10 +90,10 @@ class RNNMemoryHelperGradOp : public framework::OperatorBase {
auto in_grad_var_name = Output(framework::GradVarName("X")); auto in_grad_var_name = Output(framework::GradVarName("X"));
auto *in_grad_var = scope.FindVar(in_grad_var_name); auto *in_grad_var = scope.FindVar(in_grad_var_name);
PADDLE_ENFORCE_NOT_NULL(
PADDLE_ENFORCE(in_grad_var != nullptr, in_grad_var,
"Cannot find in_grad_var in scope, name is %s", platform::errors::NotFound("Cannot find in_grad_var: %s in scope.",
in_grad_var_name); in_grad_var_name));
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(dev_place); auto &dev_ctx = *pool.Get(dev_place);
...@@ -143,11 +142,9 @@ class RNNMemoryHelperGradOpShapeInference : public framework::InferShapeBase { ...@@ -143,11 +142,9 @@ class RNNMemoryHelperGradOpShapeInference : public framework::InferShapeBase {
public: public:
void operator()(framework::InferShapeContext *ctx) const override { void operator()(framework::InferShapeContext *ctx) const override {
auto x_grad_name = framework::GradVarName("X"); auto x_grad_name = framework::GradVarName("X");
PADDLE_ENFORCE(ctx->HasOutput(x_grad_name), OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "RNNMemoryHelperGrad");
"Gradient of Input(X) in rnn_memory_helper_grad of should " OP_INOUT_CHECK(ctx->HasOutput(x_grad_name), "Output", x_grad_name,
"not be null."); "RNNMemoryHelperGrad");
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of rnn_memory_helper_grad of should not be null.");
ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("X")); ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ x_grad_name); ctx->ShareLoD("X", /*->*/ x_grad_name);
} }
......
...@@ -80,12 +80,9 @@ specifying the output branchi. ...@@ -80,12 +80,9 @@ specifying the output branchi.
class SelectInputInferShape : public framework::InferShapeBase { class SelectInputInferShape : public framework::InferShapeBase {
public: public:
void operator()(framework::InferShapeContext *context) const override { void operator()(framework::InferShapeContext *context) const override {
PADDLE_ENFORCE_EQ(context->HasInputs("X"), true, OP_INOUT_CHECK(context->HasInputs("X"), "Input", "X", "SelectInput");
"SelectInputOp must have input X."); OP_INOUT_CHECK(context->HasInput("Mask"), "Input", "Mask", "SelectInput");
PADDLE_ENFORCE_EQ(context->HasInput("Mask"), true, OP_INOUT_CHECK(context->HasOutput("Out"), "Output", "Out", "SelectInput");
"SelectInputOp must have input Mask.");
PADDLE_ENFORCE_EQ(context->HasOutput("Out"), true,
"SelectInputOp must have output Out.");
} }
}; };
......
...@@ -78,12 +78,9 @@ specify which output branch should copy the input. ...@@ -78,12 +78,9 @@ specify which output branch should copy the input.
class SelectOutputInferShape : public framework::InferShapeBase { class SelectOutputInferShape : public framework::InferShapeBase {
public: public:
void operator()(framework::InferShapeContext *context) const override { void operator()(framework::InferShapeContext *context) const override {
PADDLE_ENFORCE_EQ(context->HasInput("X"), true, OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "SelectOutput");
"SelectOutputOp must have input X."); OP_INOUT_CHECK(context->HasInput("Mask"), "Input", "Mask", "SelectOutput");
PADDLE_ENFORCE_EQ(context->HasInput("Mask"), true, OP_INOUT_CHECK(context->HasOutputs("Out"), "Output", "Out", "SelectOutput");
"SelectOutputOp must have input Mask.");
PADDLE_ENFORCE_EQ(context->HasOutputs("Out"), true,
"SelectOutputOp must have output Out.");
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册