# Patch summary (review notes; text ahead of the first "diff --git" header is
# not part of any hunk):
#   * lite/kernels/arm/gather_compute.{cc,h}: the two ARM `gather` kernel
#     registrations move from precision kAny to kFloat, and their aliases are
#     renamed def -> int32 (Index bound as kInt32) and def_int64_idx -> int64
#     (Index bound as kInt64).
#   * lite/kernels/host/compare_compute.cc: adds kInt64 host registrations for
#     `greater_than` and `greater_equal` (X/Y bound as kInt64, Out as kBool).
#   * lite/tests/kernels/gather_compute_test.cc: the reference copy now uses
#     sizeof(int64_t), X is filled from int64_t bounds, and the test requests
#     the "int64" alias instead of "def".
#
# Review fix: the new kInt64 `greater_equal` registration named
# greater_equal_float as its kernel class, leaving the greater_equal_int64
# alias declared just above it unused. It now names greater_equal_int64,
# matching how the greater_than hunk registers greater_than_int64. The
# post-image blob hash on compare_compute.cc's index line was not regenerated.
#
# NOTE(review): tokens such as `std::vector>{{...}}` and several textually
# identical -/+ line pairs suggest that `<...>` template argument lists were
# lost when this patch was extracted -- confirm against the repository
# before applying. Indentation inside hunks was reconstructed and should be
# confirmed the same way.
diff --git a/lite/kernels/arm/gather_compute.cc b/lite/kernels/arm/gather_compute.cc
index f5a87e5431955252e47143252ce13ba4056c4a7f..84e1b5dd5c7268337a5c0d50b53d209ecfbc73f2 100644
--- a/lite/kernels/arm/gather_compute.cc
+++ b/lite/kernels/arm/gather_compute.cc
@@ -73,10 +73,10 @@ void GatherCompute::Run() {
 
 REGISTER_LITE_KERNEL(gather,
                      kARM,
-                     kAny,
+                     kFloat,
                      kNCHW,
                      paddle::lite::kernels::arm::GatherCompute,
-                     def)
+                     int32)
     .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
     .BindInput("Index",
                {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
@@ -85,10 +85,10 @@ REGISTER_LITE_KERNEL(gather,
 
 REGISTER_LITE_KERNEL(gather,
                      kARM,
-                     kAny,
+                     kFloat,
                      kNCHW,
                      paddle::lite::kernels::arm::GatherCompute,
-                     def_int64_idx)
+                     int64)
     .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
     .BindInput("Index",
                {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
diff --git a/lite/kernels/arm/gather_compute.h b/lite/kernels/arm/gather_compute.h
index 0226e5f68eee3f23dbd945af6f4f455ab79190c5..fc68a982bee3357635bfd40bd83589bd1846a747 100644
--- a/lite/kernels/arm/gather_compute.h
+++ b/lite/kernels/arm/gather_compute.h
@@ -24,7 +24,7 @@ namespace kernels {
 namespace arm {
 
 template
-class GatherCompute : public KernelLite {
+class GatherCompute : public KernelLite {
  public:
   void Run() override;
 
diff --git a/lite/kernels/host/compare_compute.cc b/lite/kernels/host/compare_compute.cc
index b45cdc789ba18c6c5abb08dce73bce83990ee5ca..242c6c83d027a0ba8c8c7c8d6f028550f77af752 100644
--- a/lite/kernels/host/compare_compute.cc
+++ b/lite/kernels/host/compare_compute.cc
@@ -230,6 +230,21 @@ REGISTER_LITE_KERNEL(greater_than, kHost, kFloat, kAny, greater_than_float, def)
                     TARGET(kHost), PRECISION(kBool), DATALAYOUT(kAny), -1)})
     .Finalize();
 
+using greater_than_int64 = paddle::lite::kernels::host::CompareCompute<
+    PRECISION(kInt64),
+    paddle::lite::kernels::host::_GreaterThanFunctor>;
+REGISTER_LITE_KERNEL(greater_than, kHost, kInt64, kAny, greater_than_int64, def)
+    .BindInput("X",
+               {LiteType::GetTensorTy(
+                   TARGET(kHost), PRECISION(kInt64), DATALAYOUT(kAny), -1)})
+    .BindInput("Y",
+               {LiteType::GetTensorTy(
+                   TARGET(kHost), PRECISION(kInt64), DATALAYOUT(kAny), -1)})
+    .BindOutput("Out",
+                {LiteType::GetTensorTy(
+                    TARGET(kHost), PRECISION(kBool), DATALAYOUT(kAny), -1)})
+    .Finalize();
+
 using greater_equal_float = paddle::lite::kernels::host::CompareCompute<
     PRECISION(kFloat),
     paddle::lite::kernels::host::_GreaterEqualFunctor>;
@@ -245,3 +260,19 @@ REGISTER_LITE_KERNEL(
                 {LiteType::GetTensorTy(
                     TARGET(kHost), PRECISION(kBool), DATALAYOUT(kAny), -1)})
     .Finalize();
+
+using greater_equal_int64 = paddle::lite::kernels::host::CompareCompute<
+    PRECISION(kInt64),
+    paddle::lite::kernels::host::_GreaterEqualFunctor>;
+REGISTER_LITE_KERNEL(
+    greater_equal, kHost, kInt64, kAny, greater_equal_int64, def)
+    .BindInput("X",
+               {LiteType::GetTensorTy(
+                   TARGET(kHost), PRECISION(kInt64), DATALAYOUT(kAny), -1)})
+    .BindInput("Y",
+               {LiteType::GetTensorTy(
+                   TARGET(kHost), PRECISION(kInt64), DATALAYOUT(kAny), -1)})
+    .BindOutput("Out",
+                {LiteType::GetTensorTy(
+                    TARGET(kHost), PRECISION(kBool), DATALAYOUT(kAny), -1)})
+    .Finalize();
diff --git a/lite/tests/kernels/gather_compute_test.cc b/lite/tests/kernels/gather_compute_test.cc
index 59be5b973a46f17f924b4fb533eabe33534af93e..3f93627c03e3edfce3fc2511ae320571f68b8598 100644
--- a/lite/tests/kernels/gather_compute_test.cc
+++ b/lite/tests/kernels/gather_compute_test.cc
@@ -53,9 +53,9 @@ class GatherComputeTest : public arena::TestCase {
     out_dims[0] = batch_size;
     out->Resize(out_dims);
 
-    auto x_data = x->data();
-    auto index_data = index->data();
-    auto out_data = out->mutable_data();
+    auto x_data = x->data();
+    auto index_data = index->data();
+    auto out_data = out->mutable_data();
 
     auto slice_num = x_dims[0];
     auto slice_size = x_dims.Slice(1, x_dims.size()).production();
@@ -66,7 +66,7 @@ class GatherComputeTest : public arena::TestCase {
       CHECK_GE(index, 0) << "gather ids[i] expected >= 0 but got " << index;
       memcpy(out_data + i * slice_size,
              x_data + index * slice_size,
-             slice_size * sizeof(float));
+             slice_size * sizeof(int64_t));
     }
   }
 
@@ -78,11 +78,11 @@ class GatherComputeTest : public arena::TestCase {
   }
 
   void PrepareData() override {
-    std::vector x(x_dims_.production());
-    fill_data_rand(x.data(), -1.f, 1.f, x_dims_.production());
+    std::vector x(x_dims_.production());
+    fill_data_rand(x.data(), int64_t(-1), int64_t(1), x_dims_.production());
 
-    std::vector index(index_dims_.production());
-    fill_data_rand(
+    std::vector index(index_dims_.production());
+    fill_data_rand(
         index.data(), 0, x_dims_[0] - 1, index_dims_.production());
 
     SetCommonTensor(x_, x_dims_, x.data());
@@ -110,8 +110,8 @@ TEST(Gather, precision) {
   for (auto x_dims :
        std::vector>{{5, 2, 3, 4}, {8, 3, 5}, {12, 3}}) {
     for (auto index_dims : std::vector>{{3}, {7}, {10}}) {
-      std::unique_ptr tester(
-          new GatherComputeTest(place, "def", DDim(x_dims), DDim(index_dims)));
+      std::unique_ptr tester(new GatherComputeTest(
+          place, "int64", DDim(x_dims), DDim(index_dims)));
       arena::Arena arena(std::move(tester), place, abs_error);
       arena.TestPrecision();
     }