未验证 提交 f1c68a08 编写于 作者: G gongweibao 提交者: GitHub

add int64 support test=develop (#32736)

add int64 support
上级 9599c3b3
......@@ -29,11 +29,6 @@ class LookupTableV2NPUKernel : public framework::OpKernel<T> {
auto *output_t = ctx.Output<framework::LoDTensor>("Out"); // float tensor
auto *table_t = ctx.Input<framework::LoDTensor>("W");
// It seems cann 20.1 accepts int64, but cann 20.2+ not.
PADDLE_ENFORCE_EQ(ids_t->type(), framework::proto::VarType::INT32,
platform::errors::Unimplemented(
"The index of LookupTableV2 should be int32."));
auto *table_var = ctx.InputVar("W");
PADDLE_ENFORCE_EQ(
table_var->IsType<framework::LoDTensor>(), true,
......
......@@ -41,7 +41,7 @@ class TestLookupTableV2(OpTest):
vocab = 10
dim = 20
w = np.ones([vocab, dim]).astype(self.dtype)
x = np.random.randint(0, vocab, size=(bsz, seqlen)).astype(np.int32)
x = np.random.randint(0, vocab, size=(bsz, seqlen)).astype(np.int64)
out = np.ones([bsz, seqlen, dim]).astype(self.dtype)
self.inputs = {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册