未验证 提交 d349a622 编写于 作者: C Chen Weihang 提交者: GitHub

api dygraph batch norm type check, test=develop (#23525)

上级 8d95a109
...@@ -20,6 +20,7 @@ from ..layers import utils ...@@ -20,6 +20,7 @@ from ..layers import utils
from ..layers import nn from ..layers import nn
from .. import dygraph_utils from .. import dygraph_utils
from . import layers from . import layers
from ..data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype
from ..framework import Variable, in_dygraph_mode, OpProtoHolder, Parameter, _dygraph_tracer, _varbase_creator from ..framework import Variable, in_dygraph_mode, OpProtoHolder, Parameter, _dygraph_tracer, _varbase_creator
from ..param_attr import ParamAttr from ..param_attr import ParamAttr
from ..initializer import Normal, Constant, NumpyArrayInitializer from ..initializer import Normal, Constant, NumpyArrayInitializer
...@@ -1147,6 +1148,9 @@ class BatchNorm(layers.Layer): ...@@ -1147,6 +1148,9 @@ class BatchNorm(layers.Layer):
return dygraph_utils._append_activation_in_dygraph( return dygraph_utils._append_activation_in_dygraph(
batch_norm_out, act=self._act) batch_norm_out, act=self._act)
check_variable_and_dtype(input, 'input',
['float16', 'float32', 'float64'], 'BatchNorm')
attrs = { attrs = {
"momentum": self._momentum, "momentum": self._momentum,
"epsilon": self._epsilon, "epsilon": self._epsilon,
......
...@@ -608,5 +608,20 @@ class TestBatchNormOpError(unittest.TestCase): ...@@ -608,5 +608,20 @@ class TestBatchNormOpError(unittest.TestCase):
self.assertRaises(TypeError, fluid.layers.batch_norm, x2) self.assertRaises(TypeError, fluid.layers.batch_norm, x2)
class TestDygraphBatchNormAPIError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
batch_norm = fluid.dygraph.BatchNorm(10)
# the input of BatchNorm must be Variable.
x1 = fluid.create_lod_tensor(
np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.CPUPlace())
self.assertRaises(TypeError, batch_norm, x1)
# the input dtype of BatchNorm must be float16 or float32 or float64
# float16 only can be set on GPU place
x2 = fluid.layers.data(name='x2', shape=[3, 4, 5, 6], dtype="int32")
self.assertRaises(TypeError, batch_norm, x2)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册