diff --git a/paddle/fluid/operators/affine_channel_op.cc b/paddle/fluid/operators/affine_channel_op.cc index 8944a749674c3ba6c83526e4d66f449075716f43..8f1adab894c2c20107a0c2494a0a57fff0e26dbd 100644 --- a/paddle/fluid/operators/affine_channel_op.cc +++ b/paddle/fluid/operators/affine_channel_op.cc @@ -67,6 +67,22 @@ class AffineChannelOp : public framework::OperatorWithKernel { "Input(Bias) of AffineChannelOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) of AffineChannelOp should not be null."); + + auto x_dims = ctx->GetInputDim("X"); + auto scale_dims = ctx->GetInputDim("Scale"); + auto b_dims = ctx->GetInputDim("Bias"); + const framework::DataLayout data_layout = framework::StringToDataLayout( + ctx->Attrs().Get("data_layout")); + + const int64_t C = (data_layout == framework::DataLayout::kNCHW + ? x_dims[1] + : x_dims[x_dims.size() - 1]); + + PADDLE_ENFORCE_EQ(scale_dims.size(), 1UL); + PADDLE_ENFORCE_EQ(scale_dims[0], C); + PADDLE_ENFORCE_EQ(b_dims.size(), 1UL); + PADDLE_ENFORCE_EQ(b_dims[0], C); + ctx->SetOutputDim("Out", ctx->GetInputDim("X")); ctx->ShareLoD("X", "Out"); }