提交 e0285eae 编写于 作者: Z Zhang Ting 提交者: Aurelius84

add check for input channels and Attr(groups), test=develop (#21095)

上级 dcf371b6
......@@ -82,6 +82,14 @@ class MaxOutOp : public framework::OperatorWithKernel {
// check groups > 1
PADDLE_ENFORCE_GT(groups, 1,
"Attr(groups) of Op(maxout) should be larger than 1.");
PADDLE_ENFORCE_EQ(
in_x_dims[axis] % groups, 0,
"ValueError: The number of input channels for Op(maxout) "
"should be divisible by Attr(groups). But received: the "
"input's channels is [%d], the shape of input is [%s], "
"the Attr(groups) is [%d], the Attr(axis) is [%d]. The "
"error may come from wrong Attr(groups) or Attr(axis) setting.",
in_x_dims[axis], in_x_dims, groups, axis);
std::vector<int64_t> output_shape(
{in_x_dims[0], in_x_dims[1], in_x_dims[2], in_x_dims[3]});
output_shape[axis] = in_x_dims[axis] / groups;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册