From 4b5b54031a5cf14a8193c049d73d60dcb7a52c55 Mon Sep 17 00:00:00 2001 From: huzhiqiang <912790387@qq.com> Date: Fri, 25 Sep 2020 00:38:55 +0800 Subject: [PATCH] [cherry-pick] [ARM] Add int64 implement for `gather` and `greater_than` (#4444) --- lite/kernels/arm/gather_compute.cc | 8 +++--- lite/kernels/arm/gather_compute.h | 2 +- lite/kernels/host/compare_compute.cc | 31 +++++++++++++++++++++++ lite/tests/kernels/gather_compute_test.cc | 20 +++++++-------- 4 files changed, 46 insertions(+), 15 deletions(-) diff --git a/lite/kernels/arm/gather_compute.cc b/lite/kernels/arm/gather_compute.cc index f5a87e5431..84e1b5dd5c 100644 --- a/lite/kernels/arm/gather_compute.cc +++ b/lite/kernels/arm/gather_compute.cc @@ -73,10 +73,10 @@ void GatherCompute::Run() { REGISTER_LITE_KERNEL(gather, kARM, - kAny, + kFloat, kNCHW, paddle::lite::kernels::arm::GatherCompute, - def) + int32) .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))}) .BindInput("Index", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))}) @@ -85,10 +85,10 @@ REGISTER_LITE_KERNEL(gather, REGISTER_LITE_KERNEL(gather, kARM, - kAny, + kFloat, kNCHW, paddle::lite::kernels::arm::GatherCompute, - def_int64_idx) + int64) .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))}) .BindInput("Index", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))}) diff --git a/lite/kernels/arm/gather_compute.h b/lite/kernels/arm/gather_compute.h index 0226e5f68e..fc68a982be 100644 --- a/lite/kernels/arm/gather_compute.h +++ b/lite/kernels/arm/gather_compute.h @@ -24,7 +24,7 @@ namespace kernels { namespace arm { template -class GatherCompute : public KernelLite { +class GatherCompute : public KernelLite { public: void Run() override; diff --git a/lite/kernels/host/compare_compute.cc b/lite/kernels/host/compare_compute.cc index b45cdc789b..242c6c83d0 100644 --- a/lite/kernels/host/compare_compute.cc +++ b/lite/kernels/host/compare_compute.cc @@ -230,6 +230,21 @@ 
REGISTER_LITE_KERNEL(greater_than, kHost, kFloat, kAny, greater_than_float, def) TARGET(kHost), PRECISION(kBool), DATALAYOUT(kAny), -1)}) .Finalize(); +using greater_than_int64 = paddle::lite::kernels::host::CompareCompute< + PRECISION(kInt64), + paddle::lite::kernels::host::_GreaterThanFunctor>; +REGISTER_LITE_KERNEL(greater_than, kHost, kInt64, kAny, greater_than_int64, def) + .BindInput("X", + {LiteType::GetTensorTy( + TARGET(kHost), PRECISION(kInt64), DATALAYOUT(kAny), -1)}) + .BindInput("Y", + {LiteType::GetTensorTy( + TARGET(kHost), PRECISION(kInt64), DATALAYOUT(kAny), -1)}) + .BindOutput("Out", + {LiteType::GetTensorTy( + TARGET(kHost), PRECISION(kBool), DATALAYOUT(kAny), -1)}) + .Finalize(); + using greater_equal_float = paddle::lite::kernels::host::CompareCompute< PRECISION(kFloat), paddle::lite::kernels::host::_GreaterEqualFunctor>; @@ -245,3 +260,19 @@ REGISTER_LITE_KERNEL( {LiteType::GetTensorTy( TARGET(kHost), PRECISION(kBool), DATALAYOUT(kAny), -1)}) .Finalize(); + +using greater_equal_int64 = paddle::lite::kernels::host::CompareCompute< + PRECISION(kInt64), + paddle::lite::kernels::host::_GreaterEqualFunctor>; +REGISTER_LITE_KERNEL( + greater_equal, kHost, kInt64, kAny, greater_equal_int64, def) + .BindInput("X", + {LiteType::GetTensorTy( + TARGET(kHost), PRECISION(kInt64), DATALAYOUT(kAny), -1)}) + .BindInput("Y", + {LiteType::GetTensorTy( + TARGET(kHost), PRECISION(kInt64), DATALAYOUT(kAny), -1)}) + .BindOutput("Out", + {LiteType::GetTensorTy( + TARGET(kHost), PRECISION(kBool), DATALAYOUT(kAny), -1)}) + .Finalize(); diff --git a/lite/tests/kernels/gather_compute_test.cc b/lite/tests/kernels/gather_compute_test.cc index 59be5b973a..3f93627c03 100644 --- a/lite/tests/kernels/gather_compute_test.cc +++ b/lite/tests/kernels/gather_compute_test.cc @@ -53,9 +53,9 @@ class GatherComputeTest : public arena::TestCase { out_dims[0] = batch_size; out->Resize(out_dims); - auto x_data = x->data(); - auto index_data = index->data(); - auto out_data = 
out->mutable_data(); + auto x_data = x->data(); + auto index_data = index->data(); + auto out_data = out->mutable_data(); auto slice_num = x_dims[0]; auto slice_size = x_dims.Slice(1, x_dims.size()).production(); @@ -66,7 +66,7 @@ class GatherComputeTest : public arena::TestCase { CHECK_GE(index, 0) << "gather ids[i] expected >= 0 but got " << index; memcpy(out_data + i * slice_size, x_data + index * slice_size, - slice_size * sizeof(float)); + slice_size * sizeof(int64_t)); } } @@ -78,11 +78,11 @@ class GatherComputeTest : public arena::TestCase { } void PrepareData() override { - std::vector x(x_dims_.production()); - fill_data_rand(x.data(), -1.f, 1.f, x_dims_.production()); + std::vector x(x_dims_.production()); + fill_data_rand(x.data(), int64_t(-1), int64_t(1), x_dims_.production()); - std::vector index(index_dims_.production()); - fill_data_rand( + std::vector index(index_dims_.production()); + fill_data_rand( index.data(), 0, x_dims_[0] - 1, index_dims_.production()); SetCommonTensor(x_, x_dims_, x.data()); @@ -110,8 +110,8 @@ TEST(Gather, precision) { for (auto x_dims : std::vector>{{5, 2, 3, 4}, {8, 3, 5}, {12, 3}}) { for (auto index_dims : std::vector>{{3}, {7}, {10}}) { - std::unique_ptr tester( - new GatherComputeTest(place, "def", DDim(x_dims), DDim(index_dims))); + std::unique_ptr tester(new GatherComputeTest( + place, "int64", DDim(x_dims), DDim(index_dims))); arena::Arena arena(std::move(tester), place, abs_error); arena.TestPrecision(); } -- GitLab