From d349a622f04dea7208b099dae7cd18f02c17f3c0 Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Wed, 8 Apr 2020 11:06:17 +0800 Subject: [PATCH] api dygraph batch norm type check, test=develop (#23525) --- python/paddle/fluid/dygraph/nn.py | 4 ++++ .../fluid/tests/unittests/test_batch_norm_op.py | 15 +++++++++++++++ 2 files changed, 19 insertions(+) diff --git a/python/paddle/fluid/dygraph/nn.py b/python/paddle/fluid/dygraph/nn.py index 8d6eae60be4..51eaccc0583 100644 --- a/python/paddle/fluid/dygraph/nn.py +++ b/python/paddle/fluid/dygraph/nn.py @@ -20,6 +20,7 @@ from ..layers import utils from ..layers import nn from .. import dygraph_utils from . import layers +from ..data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype from ..framework import Variable, in_dygraph_mode, OpProtoHolder, Parameter, _dygraph_tracer, _varbase_creator from ..param_attr import ParamAttr from ..initializer import Normal, Constant, NumpyArrayInitializer @@ -1147,6 +1148,9 @@ class BatchNorm(layers.Layer): return dygraph_utils._append_activation_in_dygraph( batch_norm_out, act=self._act) + check_variable_and_dtype(input, 'input', + ['float16', 'float32', 'float64'], 'BatchNorm') + attrs = { "momentum": self._momentum, "epsilon": self._epsilon, diff --git a/python/paddle/fluid/tests/unittests/test_batch_norm_op.py b/python/paddle/fluid/tests/unittests/test_batch_norm_op.py index 519ec1e4ab0..7beea896da0 100644 --- a/python/paddle/fluid/tests/unittests/test_batch_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_batch_norm_op.py @@ -608,5 +608,20 @@ class TestBatchNormOpError(unittest.TestCase): self.assertRaises(TypeError, fluid.layers.batch_norm, x2) +class TestDygraphBatchNormAPIError(unittest.TestCase): + def test_errors(self): + with program_guard(Program(), Program()): + batch_norm = fluid.dygraph.BatchNorm(10) + # the input of BatchNorm must be Variable. + x1 = fluid.create_lod_tensor( + np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.CPUPlace()) + self.assertRaises(TypeError, batch_norm, x1) + + # the input dtype of BatchNorm must be float16 or float32 or float64 + # float16 only can be set on GPU place + x2 = fluid.layers.data(name='x2', shape=[3, 4, 5, 6], dtype="int32") + self.assertRaises(TypeError, batch_norm, x2) + + if __name__ == '__main__': unittest.main() -- GitLab