未验证 提交 e4ba5f86 编写于 作者: G gouzil 提交者: GitHub

[Divide by 0 Error] add DataNormKernel check (#51583)

上级 579fb5fd
......@@ -284,6 +284,16 @@ class DataNormKernel<phi::CPUContext, T> : public framework::OpKernel<T> {
const int C =
(data_layout == DataLayout::kNCHW ? x_dims[1]
: x_dims[x_dims.size() - 1]);
PADDLE_ENFORCE_LT(0,
N,
platform::errors::InvalidArgument(
"The dims of Input(X) should be greater than 0."));
PADDLE_ENFORCE_LT(0,
C,
platform::errors::InvalidArgument(
"The dims of Input(X) should be greater than 0."));
auto *y = ctx.Output<phi::DenseTensor>("Y");
auto *mean_out = ctx.Output<phi::DenseTensor>("Means");
auto *scales = ctx.Output<phi::DenseTensor>("Scales");
......
......@@ -114,6 +114,16 @@ class DataNormKernel<phi::GPUContext, T> : public framework::OpKernel<T> {
platform::errors::PreconditionNotMet("The Input dim size should be 2"));
const int N = x_dims[0];
const int C = x_dims[1];
PADDLE_ENFORCE_LT(0,
N,
platform::errors::InvalidArgument(
"The dims of Input(X) should be greater than 0."));
PADDLE_ENFORCE_LT(0,
C,
platform::errors::InvalidArgument(
"The dims of Input(X) should be greater than 0."));
const T *batch_size_in =
ctx.Input<phi::DenseTensor>("BatchSize")->data<T>();
const T *batch_sum_in = ctx.Input<phi::DenseTensor>("BatchSum")->data<T>();
......
......@@ -19,6 +19,7 @@ import numpy as np
from eager_op_test import OpTest
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid import Program, program_guard
from paddle.fluid.op import Operator
......@@ -538,6 +539,23 @@ class TestDataNormOpErrorr(unittest.TestCase):
x3 = paddle.static.data("", shape=[0], dtype="float32")
self.assertRaises(ValueError, paddle.static.nn.data_norm, x3)
# The size of input in data_norm should not be 0.
def test_0_size():
paddle.enable_static()
x = fluid.data(name='x', shape=[0, 3], dtype='float32')
out = paddle.static.nn.data_norm(x, slot_dim=1)
cpu = fluid.core.CPUPlace()
exe = fluid.Executor(cpu)
exe.run(fluid.default_startup_program())
test_program = fluid.default_main_program().clone(for_test=True)
exe.run(
test_program,
fetch_list=out,
feed={'x': np.ones([0, 3]).astype('float32')},
)
self.assertRaises(ValueError, test_0_size)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册