Created by: songyouwei
dygraph.Embedding
Python API类型检查增强
当此动态图API在静态图下运行时:
检查input类型是否为Variable 检查数据类型是否为int64 异常情况示例如下:
import paddle.fluid as fluid
import numpy as np
dict_size = 20
layer = fluid.dygraph.nn.Embedding(
size=[dict_size, 32], param_attr='emb.w', is_sparse=False)
x0 = fluid.create_lod_tensor(
np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.CPUPlace())
layer(x0)
# TypeError: The type of 'input' in Embedding must be <class 'paddle.fluid.framework.Variable'>, but received <class 'paddle.fluid.core_avx.LoDTensor'>.
data_t = fluid.data(name='word', shape=[1], dtype='int32')
layer(data_t)
# TypeError: The data type of 'input' in Embedding must be ['int64'], but received int32.