Created by: songyouwei
dygraph.GRUUnit
Python API类型检查增强
当此动态图API在静态图下运行时:
检查input类型是否为Variable 检查数据类型是否为float32/float64 异常情况示例如下:
import paddle.fluid as fluid
import numpy as np
D = 5
layer = fluid.dygraph.nn.GRUUnit(size=D * 3)
# the input must be Variable.
x0 = fluid.create_lod_tensor(
np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.CPUPlace())
layer(x0, x0)
# TypeError: The type of 'input' in GRUUnit must be <class 'paddle.fluid.framework.Variable'>, but received <class 'paddle.fluid.core_avx.LoDTensor'>.
x = fluid.data(name='x', shape=[-1, D * 3], dtype='float16')
hidden = fluid.data(name='hidden', shape=[-1, D], dtype='float32')
layer(x, hidden)
# TypeError: The data type of 'input' in GRUUnit must be ['float32', 'float64'], but received float16.