From a04f30e7cf777964221f26eef2cf4a837c80f622 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Wed, 27 Dec 2017 17:51:59 +0800 Subject: [PATCH] move ENFORCE position --- paddle/operators/conv_op.cc | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/paddle/operators/conv_op.cc b/paddle/operators/conv_op.cc index ab52a41b5..e65a5dce5 100644 --- a/paddle/operators/conv_op.cc +++ b/paddle/operators/conv_op.cc @@ -31,8 +31,6 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const { std::vector paddings = ctx->Attrs().Get>("paddings"); int groups = ctx->Attrs().Get("groups"); std::vector dilations = ctx->Attrs().Get>("dilations"); - int input_channels = in_dims[1]; - int output_channels = filter_dims[0]; PADDLE_ENFORCE(in_dims.size() == 4 || in_dims.size() == 5, "Conv intput should be 4-D or 5-D tensor."); @@ -45,9 +43,13 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const { PADDLE_ENFORCE_EQ( paddings.size(), strides.size(), "Conv paddings dimension and Conv strides dimension should be the same."); + + int input_channels = in_dims[1]; PADDLE_ENFORCE_EQ(input_channels, filter_dims[1] * groups, "The number of input channels should be equal to filter " "channels * groups."); + + int output_channels = filter_dims[0]; PADDLE_ENFORCE_EQ( output_channels % groups, 0, "The number of output channels should be divided by groups."); -- GitLab