diff --git a/paddle/fluid/operators/lookup_table_v2_op_npu.cc b/paddle/fluid/operators/lookup_table_v2_op_npu.cc index 87618b954d232dcfe5d0ed0b8062db7c324c1290..9574b325ef77fd22c2baeea1bc45469b14c597a1 100644 --- a/paddle/fluid/operators/lookup_table_v2_op_npu.cc +++ b/paddle/fluid/operators/lookup_table_v2_op_npu.cc @@ -29,11 +29,6 @@ class LookupTableV2NPUKernel : public framework::OpKernel { auto *output_t = ctx.Output("Out"); // float tensor auto *table_t = ctx.Input("W"); - // It seems cann 20.1 accepts int64, but cann 20.2+ not. - PADDLE_ENFORCE_EQ(ids_t->type(), framework::proto::VarType::INT32, - platform::errors::Unimplemented( - "The index of LookupTableV2 should be int32.")); - auto *table_var = ctx.InputVar("W"); PADDLE_ENFORCE_EQ( table_var->IsType(), true, diff --git a/python/paddle/fluid/tests/unittests/npu/test_lookup_table_v2_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_lookup_table_v2_op_npu.py index 2463ddb7137acd683fde3ce2c5d09341a5c4a4d2..400ddd9d4aab0775af6007da36475db72561136f 100644 --- a/python/paddle/fluid/tests/unittests/npu/test_lookup_table_v2_op_npu.py +++ b/python/paddle/fluid/tests/unittests/npu/test_lookup_table_v2_op_npu.py @@ -41,7 +41,7 @@ class TestLookupTableV2(OpTest): vocab = 10 dim = 20 w = np.ones([vocab, dim]).astype(self.dtype) - x = np.random.randint(0, vocab, size=(bsz, seqlen)).astype(np.int32) + x = np.random.randint(0, vocab, size=(bsz, seqlen)).astype(np.int64) out = np.ones([bsz, seqlen, dim]).astype(self.dtype) self.inputs = {