未验证 提交 5f22478a 编写于 作者: W Wilber 提交者: GitHub

error message enhancement for repeated fc. test=develop (#23562)

error message enhancement for repeated fc
上级 a5bdf485
...@@ -22,37 +22,59 @@ namespace operators { ...@@ -22,37 +22,59 @@ namespace operators {
void FusionRepeatedFCReluOp::InferShape( void FusionRepeatedFCReluOp::InferShape(
framework::InferShapeContext* ctx) const { framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(ctx->HasInput("X"), OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusionRepeatedFCRelu");
"Input(X) of FusionRepeatedFCReluOp should not be null.");
auto sz = ctx->Inputs("W").size(); auto sz = ctx->Inputs("W").size();
PADDLE_ENFORCE_GT( PADDLE_ENFORCE_GT(sz, 1UL, platform::errors::InvalidArgument(
sz, 1UL, "Inputs(W) of FusionRepeatedFCReluOp should larger than 1."); "Inputs(W) of FusionRepeatedFCReluOp should "
PADDLE_ENFORCE_EQ(ctx->Inputs("Bias").size(), sz, "be greater than 1, but received value is %d.",
"Size of inputs(Bias) of FusionRepeatedFCReluOp should be " sz));
"equal to inputs size."); PADDLE_ENFORCE_EQ(
PADDLE_ENFORCE_EQ(ctx->Outputs("ReluOut").size(), sz - 1, ctx->Inputs("Bias").size(), sz,
"Size of output(ReluOut) of FusionRepeatedFCReluOp should " platform::errors::InvalidArgument(
"be equal to inputs size -1."); "Size of inputs(Bias) of FusionRepeatedFCReluOp should be "
PADDLE_ENFORCE(ctx->HasOutput("Out"), "equal to inputs size %d, but received value is %d.",
"Output(Out) of FusionRepeatedFCReluOp should not be null."); sz, ctx->Inputs("Bias").size()));
PADDLE_ENFORCE_EQ(
ctx->Outputs("ReluOut").size(), sz - 1,
platform::errors::InvalidArgument(
"Size of output(ReluOut) of FusionRepeatedFCReluOp should "
"be equal to inputs size minus one %d, but received value is %d",
sz - 1, ctx->Outputs("ReluOut").size()));
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out",
"FusionRepeatedFCRelu");
auto i_dims = ctx->GetInputDim("X"); auto i_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(i_dims.size(), 2, "Input shape size should be 2"); PADDLE_ENFORCE_EQ(
i_dims.size(), 2,
platform::errors::InvalidArgument(
"Input shape size should be 2, but received value is %d.",
i_dims.size()));
auto w_dims = ctx->GetInputsDim("W"); auto w_dims = ctx->GetInputsDim("W");
auto b_dims = ctx->GetInputsDim("Bias"); auto b_dims = ctx->GetInputsDim("Bias");
PADDLE_ENFORCE_EQ(w_dims.size(), b_dims.size(), PADDLE_ENFORCE_EQ(w_dims.size(), b_dims.size(),
"Shape size of weight and bias should be equal"); platform::errors::InvalidArgument(
PADDLE_ENFORCE_EQ(w_dims.size(), sz, "Shape size of weight and bias should be equal, but "
"Shape size of weight and bias should be equal"); "weight size is %d, bias size is %d.",
w_dims.size(), b_dims.size()));
PADDLE_ENFORCE_EQ(i_dims[1], w_dims[0][0], PADDLE_ENFORCE_EQ(i_dims[1], w_dims[0][0],
"inpute width should be equal with weight height"); platform::errors::InvalidArgument(
"input width should be equal to weight height, but "
"input width is %d, weight height is %d.",
i_dims[1], w_dims[0][0]));
for (size_t i = 1; i < sz; ++i) { for (size_t i = 1; i < sz; ++i) {
PADDLE_ENFORCE_EQ(w_dims[i].size(), 2, PADDLE_ENFORCE_EQ(w_dims[i].size(), 2,
"Every weight shape size should be 2."); platform::errors::InvalidArgument(
PADDLE_ENFORCE_EQ(framework::product(b_dims[i]), w_dims[i][1], "Every weight shape size should be 2., but received "
"The length of Bias must be equal with w_dims[1]."); "w_dims[%d].size() = %d.",
i, w_dims[i].size()));
PADDLE_ENFORCE_EQ(
framework::product(b_dims[i]), w_dims[i][1],
platform::errors::InvalidArgument(
"The length of Bias must be equal with w_dims[1], but received "
"product(b_dims[%d]) = %d, w_dims[%d][1] = %d.",
i, framework::product(b_dims[i]), i, w_dims[i][1]));
} }
ctx->SetOutputDim("Out", {i_dims[0], w_dims[sz - 1][1]}); ctx->SetOutputDim("Out", {i_dims[0], w_dims[sz - 1][1]});
ctx->ShareLoD("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册