From 054c334f69c1ac2a0332393c66f779a10f810356 Mon Sep 17 00:00:00 2001
From: liu zhengxi <380185688@qq.com>
Date: Tue, 26 Nov 2019 15:48:12 +0800
Subject: [PATCH] Add gather op on x86 platform (#2419)

* add gather op on x86 platform and add its unittests, test=develop
---
 lite/kernels/x86/CMakeLists.txt         |   2 +
 lite/kernels/x86/gather_compute.cc      |  32 +++++
 lite/kernels/x86/gather_compute.h       |  99 +++++++++++++++
 lite/kernels/x86/gather_compute_test.cc | 159 ++++++++++++++++++++++++
 4 files changed, 292 insertions(+)
 create mode 100644 lite/kernels/x86/gather_compute.cc
 create mode 100644 lite/kernels/x86/gather_compute.h
 create mode 100644 lite/kernels/x86/gather_compute_test.cc

diff --git a/lite/kernels/x86/CMakeLists.txt b/lite/kernels/x86/CMakeLists.txt
index b2b3bb79a4..a1d3615115 100644
--- a/lite/kernels/x86/CMakeLists.txt
+++ b/lite/kernels/x86/CMakeLists.txt
@@ -29,6 +29,7 @@ add_kernel(sequence_expand_as_compute_x86 X86 basic SRCS sequence_expand_as_comp
 # lite_cc_test(test_fc_compute_x86 SRCS fc_compute_test.cc DEPS fc_compute_x86)
 # lite_cc_test(test_conv2d_compute_x86 SRCS conv_compute_test.cc DEPS conv_compute_x86)
+add_kernel(gather_compute_x86 X86 basic SRCS gather_compute.cc DEPS ${lite_kernel_deps})
 # lite_cc_test(test_scale_compute_x86 SRCS scale_compute_test.cc DEPS scale_compute_x86)
 # lite_cc_test(test_dropout_compute_x86 SRCS dropout_compute_test.cc DEPS dropout_compute_x86)
 # lite_cc_test(test_batch_norm_compute_x86 SRCS batch_norm_compute_test.cc DEPS batch_norm_compute_x86)
@@ -65,6 +66,7 @@ add_kernel(matmul_compute_x86 X86 basic SRCS matmul_compute.cc DEPS ${lite_kerne
 lite_cc_test(test_conv2d_compute_x86 SRCS conv_compute_test.cc DEPS conv_compute_x86)
 lite_cc_test(test_mul_compute_x86 SRCS mul_compute_test.cc DEPS mul_compute_x86)
+lite_cc_test(test_gather_compute_x86 SRCS gather_compute_test.cc DEPS gather_compute_x86)
 lite_cc_test(test_slice_compute_x86 SRCS slice_compute_test.cc DEPS slice_compute_x86)
 lite_cc_test(test_squeeze_compute_x86 SRCS squeeze_compute_test.cc DEPS squeeze_compute_x86)
 lite_cc_test(test_fill_constant_batch_size_like_compute_x86 SRCS fill_constant_batch_size_like_compute_test.cc DEPS fill_constant_batch_size_like_compute_x86)
diff --git a/lite/kernels/x86/gather_compute.cc b/lite/kernels/x86/gather_compute.cc
new file mode 100644
index 0000000000..836f336271
--- /dev/null
+++ b/lite/kernels/x86/gather_compute.cc
@@ -0,0 +1,32 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/kernels/x86/gather_compute.h"
+
+typedef paddle::lite::kernels::x86::GatherCompute<float, int32_t> GatherInt32;
+typedef paddle::lite::kernels::x86::GatherCompute<float, int64_t> GatherInt64;
+
+REGISTER_LITE_KERNEL(gather, kX86, kFloat, kNCHW, GatherInt32, def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
+    .BindInput("Index",
+               {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt32))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
+    .Finalize();
+
+REGISTER_LITE_KERNEL(gather, kX86, kFloat, kNCHW, GatherInt64, int64_in)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
+    .BindInput("Index",
+               {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
+    .Finalize();
diff --git a/lite/kernels/x86/gather_compute.h b/lite/kernels/x86/gather_compute.h
new file mode 100644
index 0000000000..6ee270647f
--- /dev/null
+++ b/lite/kernels/x86/gather_compute.h
@@ -0,0 +1,99 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <string.h>
+#include "lite/api/paddle_place.h"
+#include "lite/core/kernel.h"
+#include "lite/core/op_registry.h"
+#include "lite/core/types.h"
+#include "lite/fluid/data_type.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace x86 {
+
+/**
+ * A thin wrapper for gathering on a CPU tensor.
+ * Returns a new tensor gathered from the source tensor according to index.
+ * input[src]: type-T source Tensor
+ * input[index]: type-IndexT index Tensor (1-D)
+ * return: output tensor
+ */
+template <typename T, typename IndexT = int>
+void CPUGather(const lite::Tensor* src,
+               const lite::Tensor* index,
+               lite::Tensor* output) {
+  // check that index is of shape [N] or [N, 1]
+  if (index->dims().size() == 2) {
+    CHECK(index->dims()[1] == 1) << "Index(Input)'s dimension[1] should be 1 "
+                                    "when Index(Input)'s dimension's size "
+                                    "equals 2 in Gather(Op).";
+  } else {
+    CHECK(index->dims().size() == 1)
+        << "Index(Input)'s dimension's size() should be 1 or 2 in Gather(Op).";
+  }
+  int64_t index_size = index->dims()[0];
+
+  auto src_dims = src->dims();
+
+  const T* p_src = src->data<T>();
+  const IndexT* p_index = index->data<IndexT>();
+  T* p_output = output->mutable_data<T>();
+
+  // slice size: number of elements in one row of src
+  int slice_size = 1;
+  for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
+
+  // copy one whole slice per index entry
+  const size_t slice_bytes = slice_size * sizeof(T);
+  for (int64_t i = 0; i < index_size; ++i) {
+    int index_ = p_index[i];
+    memcpy(p_output + i * slice_size, p_src + index_ * slice_size, slice_bytes);
+  }
+}
+
+template <typename T, typename IndexT>
+class GatherCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
+ public:
+  using param_t = operators::GatherParam;
+
+  void Run() override {
+    auto& param = *param_.get_mutable<param_t>();
+
+    auto x = param.X;
+    auto index = param.Index;
+    auto out = param.Out;
+
+    out->mutable_data<T>();
+    if (x->dims().production() == 0) return;
+    /*
+     * Since there's no type defined for lite::Tensor in Paddle-Lite, the
+     * Index's value, which must be int32_t or int64_t, is converted to float;
+     * this is supposed to cause no precision difference during inference for
+     * now.
+     * Alternatively, defining the Tensor's type during registration may
+     * cause a redefinition error.
+     */
+    CPUGather<T, IndexT>(x, index, out);
+  }
+
+  virtual ~GatherCompute() = default;
+};
+
+}  // namespace x86
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/kernels/x86/gather_compute_test.cc b/lite/kernels/x86/gather_compute_test.cc
new file mode 100644
index 0000000000..286dfcb08a
--- /dev/null
+++ b/lite/kernels/x86/gather_compute_test.cc
@@ -0,0 +1,159 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/kernels/x86/gather_compute.h"
+#include <gtest/gtest.h>
+#include <memory>
+#include <utility>
+#include <vector>
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace x86 {
+
+TEST(gather_x86, retrive_op) {
+  auto gather =
+      KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
+          "gather");
+  ASSERT_FALSE(gather.empty());
+  int cnt = 0;
+  for (auto item = gather.begin(); item != gather.end(); ++item) {
+    cnt++;
+    ASSERT_TRUE(*item);
+  }
+  ASSERT_EQ(cnt, 2);
+}
+
+TEST(gather_x86, int32_init) {
+  GatherCompute<float, int32_t> gather;
+  ASSERT_EQ(gather.precision(), PRECISION(kFloat));
+  ASSERT_EQ(gather.target(), TARGET(kX86));
+}
+
+TEST(gather_x86, int64_init) {
+  GatherCompute<float, int64_t> gather;
+  ASSERT_EQ(gather.precision(), PRECISION(kFloat));
+  ASSERT_EQ(gather.target(), TARGET(kX86));
+}
+
+template <typename T>
+void test_case_1dims() {
+  lite::Tensor x, index, out;
+  std::vector<int64_t> x_shape{10};
+  x.Resize(lite::DDim(x_shape));
+  std::vector<int64_t> index_shape{3};
+  index.Resize(lite::DDim(index_shape));
+  std::vector<int64_t> out_shape{3};
+  out.Resize(lite::DDim(out_shape));
+
+  auto x_data = x.mutable_data<float>();
+  auto index_data = index.mutable_data<T>();
+  auto out_data = out.mutable_data<float>();
+
+  for (int64_t i = 0; i < x.dims().production(); ++i) {
+    x_data[i] = static_cast<float>(i);
+  }
+  std::vector<int64_t> index_value{1, 3, 5};
+  for (int i = 0; i < index.dims().production(); ++i) {
+    index_data[i] = static_cast<T>(index_value[i]);
+  }
+
+  GatherCompute<float, T> gather;
+  operators::GatherParam param;
+
+  param.X = &x;
+  param.Index = &index;
+  param.Out = &out;
+
+  std::unique_ptr<KernelContext> ctx(new KernelContext);
+  ctx->As<X86Context>();
+  gather.SetContext(std::move(ctx));
+  gather.SetParam(param);
+  gather.Run();
+
+  std::vector<float> ref_data{1, 3, 5};
+  for (int i = 0; i < out.dims().production(); i++) {
+    EXPECT_NEAR(out_data[i], ref_data[i], 1e-5);
+  }
+}
+
+template <typename T>
+void test_case_2dims() {
+  lite::Tensor x, index, out;
+  std::vector<int64_t> x_shape{10, 20};
+  x.Resize(lite::DDim(x_shape));
+  std::vector<int64_t> index_shape{3};
+  index.Resize(lite::DDim(index_shape));
+  std::vector<int64_t> out_shape{3, 20};
+  out.Resize(lite::DDim(out_shape));
+
+  auto x_data = x.mutable_data<float>();
+  auto index_data = index.mutable_data<T>();
+  auto out_data = out.mutable_data<float>();
+
+  for (int64_t i = 0; i < x.dims().production(); ++i) {
+    x_data[i] = static_cast<float>(i);
+  }
+  std::vector<int64_t> index_value{1, 3, 5};
+  for (int i = 0; i < index.dims().production(); ++i) {
+    index_data[i] = static_cast<T>(index_value[i]);
+  }
+
+  GatherCompute<float, T> gather;
+  operators::GatherParam param;
+
+  param.X = &x;
+  param.Index = &index;
+  param.Out = &out;
+
+  std::unique_ptr<KernelContext> ctx(new KernelContext);
+  ctx->As<X86Context>();
+  gather.SetContext(std::move(ctx));
+  gather.SetParam(param);
+  gather.Run();
+
+  std::vector<float> ref_data(60);
+  for (int i = 0; i < 20; ++i) {
+    ref_data[i] = static_cast<float>(20 + i);
+  }
+  for (int i = 20; i < 40; ++i) {
+    ref_data[i] = static_cast<float>(40 + i);
+  }
+  for (int i = 40; i < 60; ++i) {
+    ref_data[i] = static_cast<float>(60 + i);
+  }
+  for (int i = 0; i < out.dims().production(); i++) {
+    EXPECT_NEAR(out_data[i], ref_data[i], 1e-5);
+  }
+}
+
+TEST(gather_x86, run_test_1dims) {
+  test_case_1dims<int32_t>();
+  test_case_1dims<int64_t>();
+}
+
+TEST(gather_x86, run_test_2dims) {
+  test_case_2dims<int32_t>();
+  test_case_2dims<int64_t>();
+}
+
+}  // namespace x86
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+USE_LITE_KERNEL(gather, kX86, kFloat, kNCHW, def);
+USE_LITE_KERNEL(gather, kX86, kFloat, kNCHW, int64_in);
--
GitLab
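
For reference, below is a minimal standalone sketch of the row-wise gather semantics that CPUGather in this patch implements, written against plain std::vector instead of lite::Tensor. The helper name gather_rows is illustrative only, and the shapes and index values simply mirror test_case_2dims above; nothing in this sketch is part of the Paddle-Lite API.

#include <cassert>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

// Gather rows of a row-major [num_rows x row_size] matrix `src` according to
// `index`: output row i is a memcpy of src row index[i], just as CPUGather
// copies one slice_size-sized slice per index entry.
std::vector<float> gather_rows(const std::vector<float>& src,
                               int64_t row_size,
                               const std::vector<int32_t>& index) {
  std::vector<float> out(index.size() * row_size);
  const size_t slice_bytes = row_size * sizeof(float);
  for (size_t i = 0; i < index.size(); ++i) {
    // Bounds check that the kernel itself leaves to the caller.
    assert(index[i] >= 0 &&
           static_cast<int64_t>(index[i]) * row_size <
               static_cast<int64_t>(src.size()));
    std::memcpy(out.data() + i * row_size,
                src.data() + index[i] * row_size,
                slice_bytes);
  }
  return out;
}

int main() {
  // x has shape {10, 20} with x[i] = i, as in test_case_2dims.
  std::vector<float> x(10 * 20);
  for (size_t i = 0; i < x.size(); ++i) x[i] = static_cast<float>(i);

  // Gathering rows {1, 3, 5} yields rows whose first elements are 20, 60, 100,
  // matching the ref_data expected by the unit test.
  std::vector<float> out = gather_rows(x, 20, {1, 3, 5});
  std::cout << out[0] << " " << out[20] << " " << out[40] << "\n";  // 20 60 100
  return 0;
}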