diff --git a/python/paddle/fluid/dygraph/nn.py b/python/paddle/fluid/dygraph/nn.py index 8d6eae60be46ea1e06763f9a337e61f8b7edfe1f..51eaccc0583184b53e4be49a4b2fb15252b99364 100644 --- a/python/paddle/fluid/dygraph/nn.py +++ b/python/paddle/fluid/dygraph/nn.py @@ -20,6 +20,7 @@ from ..layers import utils from ..layers import nn from .. import dygraph_utils from . import layers +from ..data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype from ..framework import Variable, in_dygraph_mode, OpProtoHolder, Parameter, _dygraph_tracer, _varbase_creator from ..param_attr import ParamAttr from ..initializer import Normal, Constant, NumpyArrayInitializer @@ -1147,6 +1148,9 @@ class BatchNorm(layers.Layer): return dygraph_utils._append_activation_in_dygraph( batch_norm_out, act=self._act) + check_variable_and_dtype(input, 'input', + ['float16', 'float32', 'float64'], 'BatchNorm') + attrs = { "momentum": self._momentum, "epsilon": self._epsilon, diff --git a/python/paddle/fluid/tests/unittests/test_batch_norm_op.py b/python/paddle/fluid/tests/unittests/test_batch_norm_op.py index 519ec1e4ab026873a54f5e245d0bec49f9010149..7beea896da0549526027fedfdba0bd1ddfa78ba5 100644 --- a/python/paddle/fluid/tests/unittests/test_batch_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_batch_norm_op.py @@ -608,5 +608,20 @@ class TestBatchNormOpError(unittest.TestCase): self.assertRaises(TypeError, fluid.layers.batch_norm, x2) +class TestDygraphBatchNormAPIError(unittest.TestCase): + def test_errors(self): + with program_guard(Program(), Program()): + batch_norm = fluid.dygraph.BatchNorm(10) + # the input of BatchNorm must be Variable. + x1 = fluid.create_lod_tensor( + np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.CPUPlace()) + self.assertRaises(TypeError, batch_norm, x1) + + # the input dtype of BatchNorm must be float16 or float32 or float64 + # float16 only can be set on GPU place + x2 = fluid.layers.data(name='x2', shape=[3, 4, 5, 6], dtype="int32") + self.assertRaises(TypeError, batch_norm, x2) + + if __name__ == '__main__': unittest.main()