Created by: songyouwei
dygraph.BilinearTensorProduct
Python API类型检查增强
当此动态图API在静态图下运行时:
检查input类型是否为Variable 检查数据类型是否为float32, float64 异常情况示例如下:
import paddle.fluid as fluid
import numpy as np
layer = fluid.dygraph.nn.BilinearTensorProduct(
input1_dim=5, input2_dim=4, output_dim=1000)
# the input must be Variable.
x0 = fluid.create_lod_tensor(
np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.CPUPlace())
layer(x0, x0)
# TypeError: The type of 'x' in BilinearTensorProduct must be <class 'paddle.fluid.framework.Variable'>, but received <class 'paddle.fluid.core_avx.LoDTensor'>.
x1 = fluid.data(name='x1', shape=[-1, 5], dtype="float16")
x2 = fluid.data(name='x2', shape=[-1, 4], dtype="float32")
layer(x1, x2)
# TypeError: The data type of 'x' in BilinearTensorProduct must be ['float32', 'float64'], but received float16.