提交 afbc435a 编写于 作者: X xuezhong

fix infershape check bug

test=develop
上级 5663fbfb
......@@ -28,7 +28,8 @@ class AucOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasInput("Label"),
"Input of Label should not be null.");
auto predict_width = ctx->GetInputDim("Predict")[1];
PADDLE_ENFORCE_EQ(predict_width, 2, "Only support binary classification");
PADDLE_INFERSHAPE_ENFORCE_EQ(ctx, predict_width, 2,
"Only support binary classification");
auto predict_height = ctx->GetInputDim("Predict")[0];
auto label_height = ctx->GetInputDim("Label")[0];
......
......@@ -43,8 +43,25 @@ class SmoothL1LossOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasInput("OutsideWeight"),
"If weights are provided, must specify both "
"inside and outside weights.");
PADDLE_ENFORCE_EQ(ctx->GetInputDim("InsideWeight"), x_dims);
PADDLE_ENFORCE_EQ(ctx->GetInputDim("OutsideWeight"), x_dims);
auto dims = ctx->GetInputDim("InsideWeight");
bool check = true;
if ((!ctx->IsRuntime()) &&
(framework::product(dims) <= 0 || framework::product(x_dims) <= 0)) {
check = false;
}
if (check) {
PADDLE_ENFORCE_EQ(dims, x_dims);
}
dims = ctx->GetInputDim("OutsideWeight");
check = true;
if ((!ctx->IsRuntime()) &&
(framework::product(dims) <= 0 || framework::product(x_dims) <= 0)) {
check = false;
}
if (check) {
PADDLE_ENFORCE_EQ(dims, x_dims);
}
}
ctx->SetOutputDim("Diff", x_dims);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册