未验证 提交 5f22478a 编写于 作者: W Wilber 提交者: GitHub

error message enhancement for repeated fc. test=develop (#23562)

error message enhancement for repeated fc
上级 a5bdf485
......@@ -22,37 +22,59 @@ namespace operators {
void FusionRepeatedFCReluOp::InferShape(
framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of FusionRepeatedFCReluOp should not be null.");
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusionRepeatedFCRelu");
auto sz = ctx->Inputs("W").size();
PADDLE_ENFORCE_GT(
sz, 1UL, "Inputs(W) of FusionRepeatedFCReluOp should larger than 1.");
PADDLE_ENFORCE_EQ(ctx->Inputs("Bias").size(), sz,
"Size of inputs(Bias) of FusionRepeatedFCReluOp should be "
"equal to inputs size.");
PADDLE_ENFORCE_EQ(ctx->Outputs("ReluOut").size(), sz - 1,
"Size of output(ReluOut) of FusionRepeatedFCReluOp should "
"be equal to inputs size -1.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of FusionRepeatedFCReluOp should not be null.");
PADDLE_ENFORCE_GT(sz, 1UL, platform::errors::InvalidArgument(
"Inputs(W) of FusionRepeatedFCReluOp should "
"be greater than 1, but received value is %d.",
sz));
PADDLE_ENFORCE_EQ(
ctx->Inputs("Bias").size(), sz,
platform::errors::InvalidArgument(
"Size of inputs(Bias) of FusionRepeatedFCReluOp should be "
"equal to inputs size %d, but received value is %d.",
sz, ctx->Inputs("Bias").size()));
PADDLE_ENFORCE_EQ(
ctx->Outputs("ReluOut").size(), sz - 1,
platform::errors::InvalidArgument(
"Size of output(ReluOut) of FusionRepeatedFCReluOp should "
"be equal to inputs size minus one %d, but received value is %d",
sz - 1, ctx->Outputs("ReluOut").size()));
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out",
"FusionRepeatedFCRelu");
auto i_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(i_dims.size(), 2, "Input shape size should be 2");
PADDLE_ENFORCE_EQ(
i_dims.size(), 2,
platform::errors::InvalidArgument(
"Input shape size should be 2, but received value is %d.",
i_dims.size()));
auto w_dims = ctx->GetInputsDim("W");
auto b_dims = ctx->GetInputsDim("Bias");
PADDLE_ENFORCE_EQ(w_dims.size(), b_dims.size(),
"Shape size of weight and bias should be equal");
PADDLE_ENFORCE_EQ(w_dims.size(), sz,
"Shape size of weight and bias should be equal");
platform::errors::InvalidArgument(
"Shape size of weight and bias should be equal, but "
"weight size is %d, bias size is %d.",
w_dims.size(), b_dims.size()));
PADDLE_ENFORCE_EQ(i_dims[1], w_dims[0][0],
"inpute width should be equal with weight height");
platform::errors::InvalidArgument(
"input width should be equal to weight height, but "
"input width is %d, weight height is %d.",
i_dims[1], w_dims[0][0]));
for (size_t i = 1; i < sz; ++i) {
PADDLE_ENFORCE_EQ(w_dims[i].size(), 2,
"Every weight shape size should be 2.");
PADDLE_ENFORCE_EQ(framework::product(b_dims[i]), w_dims[i][1],
"The length of Bias must be equal with w_dims[1].");
platform::errors::InvalidArgument(
"Every weight shape size should be 2., but received "
"w_dims[%d].size() = %d.",
i, w_dims[i].size()));
PADDLE_ENFORCE_EQ(
framework::product(b_dims[i]), w_dims[i][1],
platform::errors::InvalidArgument(
"The length of Bias must be equal with w_dims[1], but received "
"product(b_dims[%d]) = %d, w_dims[%d][1] = %d.",
i, framework::product(b_dims[i]), i, w_dims[i][1]));
}
ctx->SetOutputDim("Out", {i_dims[0], w_dims[sz - 1][1]});
ctx->ShareLoD("X", /*->*/ "Out");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册