From 8caa785e83c24fdddfa23de6b3123a2ad890932c Mon Sep 17 00:00:00 2001 From: qingqing01 Date: Wed, 20 Mar 2019 23:03:28 +0800 Subject: [PATCH] Enhance affine_channel_op infer-shape check (#16317) * Enhance affine_channel_op infer-shape check --- paddle/fluid/operators/affine_channel_op.cc | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/paddle/fluid/operators/affine_channel_op.cc b/paddle/fluid/operators/affine_channel_op.cc index 8944a74967..8f1adab894 100644 --- a/paddle/fluid/operators/affine_channel_op.cc +++ b/paddle/fluid/operators/affine_channel_op.cc @@ -67,6 +67,22 @@ class AffineChannelOp : public framework::OperatorWithKernel { "Input(Bias) of AffineChannelOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) of AffineChannelOp should not be null."); + + auto x_dims = ctx->GetInputDim("X"); + auto scale_dims = ctx->GetInputDim("Scale"); + auto b_dims = ctx->GetInputDim("Bias"); + const framework::DataLayout data_layout = framework::StringToDataLayout( + ctx->Attrs().Get("data_layout")); + + const int64_t C = (data_layout == framework::DataLayout::kNCHW + ? x_dims[1] + : x_dims[x_dims.size() - 1]); + + PADDLE_ENFORCE_EQ(scale_dims.size(), 1UL); + PADDLE_ENFORCE_EQ(scale_dims[0], C); + PADDLE_ENFORCE_EQ(b_dims.size(), 1UL); + PADDLE_ENFORCE_EQ(b_dims[0], C); + ctx->SetOutputDim("Out", ctx->GetInputDim("X")); ctx->ShareLoD("X", "Out"); } -- GitLab