From afbc435adf63f59a0863bb3424f9fb33d5e35131 Mon Sep 17 00:00:00 2001 From: xuezhong Date: Tue, 16 Apr 2019 08:13:55 +0000 Subject: [PATCH] fix infershape check bug test=develop --- paddle/fluid/operators/metrics/auc_op.cc | 3 ++- paddle/fluid/operators/smooth_l1_loss_op.cc | 21 +++++++++++++++++++-- 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/metrics/auc_op.cc b/paddle/fluid/operators/metrics/auc_op.cc index 4670eb23b3d..001d2693688 100644 --- a/paddle/fluid/operators/metrics/auc_op.cc +++ b/paddle/fluid/operators/metrics/auc_op.cc @@ -28,7 +28,8 @@ class AucOp : public framework::OperatorWithKernel { PADDLE_ENFORCE(ctx->HasInput("Label"), "Input of Label should not be null."); auto predict_width = ctx->GetInputDim("Predict")[1]; - PADDLE_ENFORCE_EQ(predict_width, 2, "Only support binary classification"); + PADDLE_INFERSHAPE_ENFORCE_EQ(ctx, predict_width, 2, + "Only support binary classification"); auto predict_height = ctx->GetInputDim("Predict")[0]; auto label_height = ctx->GetInputDim("Label")[0]; diff --git a/paddle/fluid/operators/smooth_l1_loss_op.cc b/paddle/fluid/operators/smooth_l1_loss_op.cc index 5282bcbc693..5af47b0f6d0 100644 --- a/paddle/fluid/operators/smooth_l1_loss_op.cc +++ b/paddle/fluid/operators/smooth_l1_loss_op.cc @@ -43,8 +43,25 @@ class SmoothL1LossOp : public framework::OperatorWithKernel { PADDLE_ENFORCE(ctx->HasInput("OutsideWeight"), "If weights are provided, must specify both " "inside and outside weights."); - PADDLE_ENFORCE_EQ(ctx->GetInputDim("InsideWeight"), x_dims); - PADDLE_ENFORCE_EQ(ctx->GetInputDim("OutsideWeight"), x_dims); + auto dims = ctx->GetInputDim("InsideWeight"); + bool check = true; + if ((!ctx->IsRuntime()) && + (framework::product(dims) <= 0 || framework::product(x_dims) <= 0)) { + check = false; + } + if (check) { + PADDLE_ENFORCE_EQ(dims, x_dims); + } + + dims = ctx->GetInputDim("OutsideWeight"); + check = true; + if ((!ctx->IsRuntime()) && + (framework::product(dims) <= 0 || framework::product(x_dims) <= 0)) { + check = false; + } + if (check) { + PADDLE_ENFORCE_EQ(dims, x_dims); + } } ctx->SetOutputDim("Diff", x_dims); -- GitLab