提交 ee74f7dc 编写于 作者: H huzhiqiang 提交者: GitHub

[BUG FIX] fix look_uptable OP (#2911)

上级 3e9d2196
...@@ -67,22 +67,22 @@ void LookupTableCompute::Run() { ...@@ -67,22 +67,22 @@ void LookupTableCompute::Run() {
REGISTER_LITE_KERNEL(lookup_table, REGISTER_LITE_KERNEL(lookup_table,
kARM, kARM,
kFloat, kAny,
kNCHW, kNCHW,
paddle::lite::kernels::arm::LookupTableCompute, paddle::lite::kernels::arm::LookupTableCompute,
def) def)
.BindInput("W", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("W", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Ids", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))}) .BindInput("Ids", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize(); .Finalize();
REGISTER_LITE_KERNEL(lookup_table_v2, REGISTER_LITE_KERNEL(lookup_table_v2,
kARM, kARM,
kFloat, kAny,
kNCHW, kNCHW,
paddle::lite::kernels::arm::LookupTableCompute, paddle::lite::kernels::arm::LookupTableCompute,
def) def)
.BindInput("W", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("W", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Ids", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))}) .BindInput("Ids", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize(); .Finalize();
...@@ -21,7 +21,7 @@ namespace lite { ...@@ -21,7 +21,7 @@ namespace lite {
namespace kernels { namespace kernels {
namespace arm { namespace arm {
class LookupTableCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> { class LookupTableCompute : public KernelLite<TARGET(kARM), PRECISION(kAny)> {
public: public:
using param_t = operators::LookupTableParam; using param_t = operators::LookupTableParam;
......
...@@ -53,7 +53,7 @@ void lookup_table_compute_ref(const operators::LookupTableParam &param) { ...@@ -53,7 +53,7 @@ void lookup_table_compute_ref(const operators::LookupTableParam &param) {
TEST(lookup_table_arm, retrieve_op) { TEST(lookup_table_arm, retrieve_op) {
auto lookup_table = auto lookup_table =
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>( KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kAny)>(
"lookup_table"); "lookup_table");
ASSERT_FALSE(lookup_table.empty()); ASSERT_FALSE(lookup_table.empty());
ASSERT_TRUE(lookup_table.front()); ASSERT_TRUE(lookup_table.front());
...@@ -61,7 +61,7 @@ TEST(lookup_table_arm, retrieve_op) { ...@@ -61,7 +61,7 @@ TEST(lookup_table_arm, retrieve_op) {
TEST(lookup_table_arm, init) { TEST(lookup_table_arm, init) {
LookupTableCompute lookup_table; LookupTableCompute lookup_table;
ASSERT_EQ(lookup_table.precision(), PRECISION(kFloat)); ASSERT_EQ(lookup_table.precision(), PRECISION(kAny));
ASSERT_EQ(lookup_table.target(), TARGET(kARM)); ASSERT_EQ(lookup_table.target(), TARGET(kARM));
} }
...@@ -112,4 +112,4 @@ TEST(lookup_table_arm, compute) { ...@@ -112,4 +112,4 @@ TEST(lookup_table_arm, compute) {
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
USE_LITE_KERNEL(lookup_table, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(lookup_table, kARM, kAny, kNCHW, def);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册