Created by: chenwhql
dygraph.LayerNorm Python API类型检查增强
当此动态图API在静态图下运行时:
-
- 检查input类型是否为Variable
-
- 检查数据类型是否为float32, float64
两个异常情况示例如下(可手动执行):
- type error
import paddle.fluid as fluid
import numpy as np
layer_norm = fluid.LayerNorm([32, 32])
x1 = np.random.random((3, 32, 32)).astype('float32')
layer_norm(x1)
# TypeError: The type of 'input' in LayerNorm must be <class 'paddle.fluid.framework.Variable'>, but received <class 'numpy.ndarray'>.
- dtype error
import paddle.fluid as fluid
import numpy as np
layer_norm = fluid.LayerNorm([32, 32])
x2 = fluid.layers.data(name='x2', shape=[3, 32, 32], dtype="int32")
layer_norm(x2)
# TypeError: The data type of 'input' in LayerNorm must be ['float32', 'float64'], but received int32.