From 87d89dfe1464b1a361dbf6b77e0a253ad667709a Mon Sep 17 00:00:00 2001 From: ceci3 <592712189@qq.com> Date: Sun, 14 Apr 2019 04:01:56 +0000 Subject: [PATCH] fix batch_norm and cos_sim infer shape, test=develop --- paddle/fluid/operators/batch_norm_op.cc | 18 ++++++++++---- paddle/fluid/operators/cos_sim_op.cc | 32 ++++++++++++++++--------- 2 files changed, 35 insertions(+), 15 deletions(-) diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc index c0ad959309a..5f5a1f70f2a 100644 --- a/paddle/fluid/operators/batch_norm_op.cc +++ b/paddle/fluid/operators/batch_norm_op.cc @@ -65,11 +65,21 @@ void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const { (data_layout == DataLayout::kNCHW ? x_dims[1] : x_dims[x_dims.size() - 1]); - PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale").size(), 1UL); - PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale")[0], C); - PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias").size(), 1UL); - PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias")[0], C); + auto scale_dim = ctx->GetInputDim("Scale"); + auto bias_dim = ctx->GetInputDim("Bias"); + bool check = true; + if ((!ctx->IsRuntime()) && (framework::product(scale_dim) <= 0 || + framework::product(bias_dim) <= 0)) { + check = false; + } + + if (check) { + PADDLE_ENFORCE_EQ(scale_dim.size(), 1UL); + PADDLE_ENFORCE_EQ(scale_dim[0], C); + PADDLE_ENFORCE_EQ(scale_dim.size(), 1UL); + PADDLE_ENFORCE_EQ(scale_dim[0], C); + } ctx->SetOutputDim("Y", x_dims); ctx->SetOutputDim("MeanOut", {C}); ctx->SetOutputDim("VarianceOut", {C}); diff --git a/paddle/fluid/operators/cos_sim_op.cc b/paddle/fluid/operators/cos_sim_op.cc index 8f3644039f9..920d087e429 100644 --- a/paddle/fluid/operators/cos_sim_op.cc +++ b/paddle/fluid/operators/cos_sim_op.cc @@ -40,17 +40,27 @@ class CosSimOp : public framework::OperatorWithKernel { auto x_dims = ctx->GetInputDim("X"); auto y_dims = ctx->GetInputDim("Y"); - PADDLE_ENFORCE_EQ(x_dims.size(), y_dims.size(), - "Ranks of Input(X) and Input(Y) must be equal."); - PADDLE_ENFORCE_GE(x_dims.size(), 2, - "Rank of Input(X) must not be less than 2."); - PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 1, x_dims.size()), - framework::slice_ddim(y_dims, 1, y_dims.size()), - "All dimensions except the 1st of Input(X) and Input(Y) " - "must be equal."); - PADDLE_ENFORCE(x_dims[0] == y_dims[0] || y_dims[0] == 1, - "The 1st dimension of Input(Y) must be equal to Input(X) or" - " just 1 (which will be broadcasted to match Input(X))."); + bool check = true; + if ((!ctx->IsRuntime()) && + (framework::product(x_dims) <= 0 || framework::product(y_dims) <= 0)) { + check = false; + } + + if (check) { + PADDLE_ENFORCE_EQ(x_dims.size(), y_dims.size(), + "Ranks of Input(X) and Input(Y) must be equal."); + PADDLE_ENFORCE_GE(x_dims.size(), 2, + "Rank of Input(X) must not be less than 2."); + PADDLE_ENFORCE_EQ( + framework::slice_ddim(x_dims, 1, x_dims.size()), + framework::slice_ddim(y_dims, 1, y_dims.size()), + "All dimensions except the 1st of Input(X) and Input(Y) " + "must be equal."); + PADDLE_ENFORCE( + x_dims[0] == y_dims[0] || y_dims[0] == 1, + "The 1st dimension of Input(Y) must be equal to Input(X) or" + " just 1 (which will be broadcasted to match Input(X))."); + } // resize tensor ctx->SetOutputDim("Out", {x_dims[0], 1}); -- GitLab