diff --git a/lite/backends/arm/math/CMakeLists.txt b/lite/backends/arm/math/CMakeLists.txt
index 244467d62492bc3017ebdb6144b49ccb9fcd30c1..67fc64ab9dfda75b64d181bcf142b41c13f91da5 100644
--- a/lite/backends/arm/math/CMakeLists.txt
+++ b/lite/backends/arm/math/CMakeLists.txt
@@ -130,5 +130,6 @@ if (NOT HAS_ARM_MATH_LIB_DIR)
             lstm.cc
             clip.cc
             pixel_shuffle.cc
+            scatter.cc
             DEPS ${lite_kernel_deps} context tensor)
 endif()
diff --git a/lite/backends/arm/math/funcs.h b/lite/backends/arm/math/funcs.h
index 2e52bd1e285b7493148a5a779bffcfcfd1336722..131c1dbd37f65be7318492d2e953cb4aab19f9f6 100644
--- a/lite/backends/arm/math/funcs.h
+++ b/lite/backends/arm/math/funcs.h
@@ -54,6 +54,7 @@
 #include "lite/backends/arm/math/reduce_mean.h"
 #include "lite/backends/arm/math/reduce_prod.h"
 #include "lite/backends/arm/math/scale.h"
+#include "lite/backends/arm/math/scatter.h"
 #include "lite/backends/arm/math/sequence_expand.h"
 #include "lite/backends/arm/math/sequence_pool.h"
 #include "lite/backends/arm/math/sequence_pool_grad.h"
diff --git a/lite/backends/arm/math/scatter.cc b/lite/backends/arm/math/scatter.cc
new file mode 100644
index 0000000000000000000000000000000000000000..c9250a9bfa3fcfbdac2a8942aeff3bd28b4bc381
--- /dev/null
+++ b/lite/backends/arm/math/scatter.cc
@@ -0,0 +1,72 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "lite/backends/arm/math/scatter.h"
+#include "lite/backends/arm/math/funcs.h"
+#include "lite/core/tensor.h"
+
+namespace paddle {
+namespace lite {
+namespace arm {
+namespace math {
+
+template <>
+void scatter(const int64_t* indexs,
+             const float* src,
+             float* dst,
+             int index_size,
+             int num,
+             int size,
+             bool overwrite) {
+  for (int i = 0; i < num; i++) {
+    const float* din = src + indexs[i] * size;
+    memcpy(dst, din, sizeof(float) * size);
+    dst += size;
+  }
+  if (overwrite) {
+    for (int i = num; i < index_size; i++) {
+      const float* din = src + indexs[i] * size;
+      float* dout = dst + indexs[i] * size;
+      memcpy(dout, din, sizeof(float) * size);
+    }
+  } else {
+    int cnt = size >> 3;
+    int rem = size & 7;
+    for (int i = num; i < index_size; i++) {
+      const float* din = src + indexs[i] * size;
+      float* dout = dst + indexs[i] * size;
+      for (int j = 0; j < cnt; j++) {
+        float32x4_t va0 = vld1q_f32(din);
+        float32x4_t vb0 = vld1q_f32(dout);
+        float32x4_t va1 = vld1q_f32(din + 4);
+        float32x4_t vb1 = vld1q_f32(dout + 4);
+        vb0 = vaddq_f32(va0, vb0);
+        vb1 = vaddq_f32(va1, vb1);
+        din += 8;
+        vst1q_f32(dout, vb0);
+        vst1q_f32(dout + 4, vb1);
+        dout += 8;
+      }
+      for (int j = 0; j < rem; j++) {
+        dout[0] += *din++;
+        dout++;
+      }
+    }
+  }
+}
+
+}  // namespace math
+}  // namespace arm
+}  // namespace lite
+}  // namespace paddle
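Note on the routine above: the first num rows of src are copied into consecutive rows of dst, and rows num through index_size - 1 are then scattered to the row given by indexs[i], either overwriting that row (memcpy) or accumulating into it (the path vectorized eight floats at a time, with a scalar tail for the size & 7 leftovers). Below is a minimal standalone scalar sketch of the same update rule on hypothetical concrete values; the driver is illustrative only and is not part of the patch.

#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
  // Hypothetical values: 3 index entries, rows of 2 floats, num = 1.
  const int64_t indexs[] = {0, 2, 2};
  const float src[] = {1, 1, 2, 2, 3, 3};  // index_size rows of `size` floats
  float dst[6] = {0};
  const int index_size = 3, num = 1, size = 2;
  const bool overwrite = false;
  // Rows [0, num): copied into consecutive rows of dst.
  for (int i = 0; i < num; i++) {
    std::memcpy(dst + i * size, src + indexs[i] * size, sizeof(float) * size);
  }
  // Rows [num, index_size): scattered to row indexs[i]; added when
  // overwrite == false (the path the patch vectorizes with NEON).
  for (int i = num; i < index_size; i++) {
    for (int j = 0; j < size; j++) {
      if (overwrite) {
        dst[indexs[i] * size + j] = src[indexs[i] * size + j];
      } else {
        dst[indexs[i] * size + j] += src[indexs[i] * size + j];
      }
    }
  }
  // Row 2 accumulates src row 2 twice: dst == {1, 1, 0, 0, 6, 6}.
  for (int k = 0; k < 6; k++) {
    std::printf("%g ", dst[k]);
  }
  std::printf("\n");
  return 0;
}

With overwrite set to true, the duplicated index 2 would simply be written twice, leaving {1, 1, 0, 0, 3, 3} instead.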
diff --git a/lite/backends/arm/math/scatter.h b/lite/backends/arm/math/scatter.h
new file mode 100644
index 0000000000000000000000000000000000000000..3d145367189eb61e7fdfbd5b20a55f5397ae702b
--- /dev/null
+++ b/lite/backends/arm/math/scatter.h
@@ -0,0 +1,34 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include <stdint.h>
+
+namespace paddle {
+namespace lite {
+namespace arm {
+namespace math {
+
+template <typename T>
+void scatter(const int64_t* indexs,
+             const T* updates,
+             T* dst,
+             int index_size,
+             int num,
+             int size,
+             bool overwrite);
+}  // namespace math
+}  // namespace arm
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/kernels/arm/CMakeLists.txt b/lite/kernels/arm/CMakeLists.txt
index ad5988c10bd7650f3fcb9c759c73117954d22dd7..83789070ccc9db18a6299a9c2b14b3248dbcc28f 100644
--- a/lite/kernels/arm/CMakeLists.txt
+++ b/lite/kernels/arm/CMakeLists.txt
@@ -79,8 +79,10 @@ add_kernel(collect_fpn_proposals_compute_arm ARM extra SRCS collect_fpn_proposal
 add_kernel(distribute_fpn_proposals_compute_arm ARM extra SRCS distribute_fpn_proposals_compute.cc DEPS ${lite_kernel_deps} math_arm)
 add_kernel(clip_compute_arm ARM extra SRCS clip_compute.cc DEPS ${lite_kernel_deps} math_arm)
 add_kernel(pixel_shuffle_compute_arm ARM extra SRCS pixel_shuffle_compute.cc DEPS ${lite_kernel_deps} math_arm)
+add_kernel(scatter_compute_arm ARM extra SRCS scatter_compute.cc DEPS ${lite_kernel_deps} math_arm)
 add_kernel(sequence_expand_as_compute_arm ARM extra SRCS sequence_expand_as_compute.cc DEPS ${lite_kernel_deps} math_arm)
+
 
 # for OCR specific
 add_kernel(gru_unit_compute_arm ARM extra SRCS gru_unit_compute.cc DEPS ${lite_kernel_deps} math_arm)
 add_kernel(gru_compute_arm ARM extra SRCS gru_compute.cc DEPS ${lite_kernel_deps} math_arm)
diff --git a/lite/kernels/arm/scatter_compute.cc b/lite/kernels/arm/scatter_compute.cc
new file mode 100644
index 0000000000000000000000000000000000000000..8d3a512975c26d356405deb8ae9ff58093507425
--- /dev/null
+++ b/lite/kernels/arm/scatter_compute.cc
@@ -0,0 +1,63 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/kernels/arm/scatter_compute.h"
+#include "lite/backends/arm/math/funcs.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace arm {
+
+void ScatterCompute::Run() {
+  auto& param = this->template Param<operators::ScatterParam>();
+  const float* updates_data = param.updates->template data<float>();
+  const int64_t* indexs_data = param.indexs->template data<int64_t>();
+  float* output_data = param.output->template mutable_data<float>();
+  bool overwrite = param.overwrite;
+  int index_size = param.indexs->dims()[0];
+  auto in_dims = param.x->dims();
+  int num = 1;
+  for (int i = 1; i < in_dims.size(); i++) {
+    num *= in_dims[i];
+  }
+  lite::arm::math::scatter(indexs_data,
+                           updates_data,
+                           output_data,
+                           index_size,
+                           in_dims[0],
+                           num,
+                           overwrite);
+  if (!param.x->lod().empty()) {
+    param.output->set_lod(param.x->lod());
+  }
+}
+
+}  // namespace arm
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_KERNEL(scatter,
+                     kARM,
+                     kFloat,
+                     kNCHW,
+                     paddle::lite::kernels::arm::ScatterCompute,
+                     def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
+    .BindInput("Ids", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
+    .BindInput("Updates",
+               {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
+    .Finalize();
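Note on the kernel above: index_size comes from the first dimension of Ids, and the local num is the product of every X dimension after axis 0, i.e. the element count of one row. Mind the naming, since this local num is what the math routine receives as its size argument, while in_dims[0] is passed as the routine's num. A small standalone sketch of that shape arithmetic with hypothetical dimensions (plain std::vector in place of lite::DDim; not part of the patch):

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Hypothetical shapes: X = {4, 2, 3, 3}, Ids = {6}.
  std::vector<int64_t> x_dims = {4, 2, 3, 3};
  std::vector<int64_t> ids_dims = {6};
  int index_size = static_cast<int>(ids_dims[0]);  // routine's index_size = 6
  int row_len = 1;  // the kernel's local `num`, handed to scatter() as `size`
  for (size_t i = 1; i < x_dims.size(); i++) {
    row_len *= static_cast<int>(x_dims[i]);  // 2 * 3 * 3 = 18
  }
  // x_dims[0] (= 4) is what scatter() receives as its `num` argument.
  std::printf("index_size = %d, size = %d, num = %d\n",
              index_size, row_len, static_cast<int>(x_dims[0]));
  return 0;
}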
diff --git a/lite/kernels/arm/scatter_compute.h b/lite/kernels/arm/scatter_compute.h
new file mode 100644
index 0000000000000000000000000000000000000000..5ee37cf55dd3e9f81582ffdcc5bdf96fa8cc25a8
--- /dev/null
+++ b/lite/kernels/arm/scatter_compute.h
@@ -0,0 +1,34 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "lite/core/kernel.h"
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace arm {
+
+class ScatterCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
+ public:
+  void Run() override;
+
+  virtual ~ScatterCompute() = default;
+};
+
+}  // namespace arm
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/operators/CMakeLists.txt b/lite/operators/CMakeLists.txt
index 02377aad498a47cff50c3a595f6fb1634a56b5ff..6cdf815a6f03f0e36b95acc4f8e6f15dc64b4de2 100644
--- a/lite/operators/CMakeLists.txt
+++ b/lite/operators/CMakeLists.txt
@@ -121,6 +121,7 @@ add_operator(max_pool_with_index_op extra SRCS max_pool_with_index_op.cc DEPS ${
 add_operator(pixel_shuffle_op extra SRCS pixel_shuffle_op.cc DEPS ${op_DEPS})
 add_operator(clip_op extra SRCS clip_op.cc DEPS ${op_DEPS})
 add_operator(print_op extra SRCS print_op.cc DEPS ${op_DEPS})
+add_operator(scatter extra SRCS scatter_op.cc DEPS ${op_DEPS})
 
 # for OCR specific
 add_operator(while_op extra SRCS while_op.cc DEPS ${op_DEPS})
diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h
index 494ee823827fc6d71f0c41824ee7f9e52bdbb3f4..33da913d2e13d290ef42a40955c7cdc13fd855b3 100644
--- a/lite/operators/op_params.h
+++ b/lite/operators/op_params.h
@@ -294,6 +294,16 @@ struct ScaleParam : ParamBase {
   }
 };
 
+// For Scatter OP
+struct ScatterParam : ParamBase {
+  lite::Tensor* x{};
+  lite::Tensor* indexs{};
+  lite::Tensor* updates{};
+  lite::Tensor* output{};
+
+  bool overwrite{true};
+};
+
 // For Softmax op
 struct SoftmaxParam : ParamBase {
   lite::Tensor* x{};
diff --git a/lite/operators/scatter_op.cc b/lite/operators/scatter_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..20a0dcb6be409c87e828e168321716adf69011e4
--- /dev/null
+++ b/lite/operators/scatter_op.cc
@@ -0,0 +1,66 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/operators/scatter_op.h"
+#include "lite/core/op_registry.h"
+namespace paddle {
+namespace lite {
+namespace operators {
+
+bool ScatterOp::CheckShape() const {
+  CHECK_OR_FALSE(param_.x);
+  CHECK_OR_FALSE(param_.output);
+  return true;
+}
+
+bool ScatterOp::InferShapeImpl() const {
+  auto index_dims = param_.indexs->dims();
+  auto update_dims = param_.updates->dims();
+  auto input_dims = param_.x->dims();
+  for (int i = 1; i < update_dims.size(); i++) {
+    CHECK_EQ_OR_FALSE(update_dims[i], input_dims[i]);
+  }
+  CHECK_EQ_OR_FALSE(index_dims.size(), 1L);
+  param_.output->Resize(input_dims);
+  return true;
+}
+
+bool ScatterOp::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
+  AttachParam(&param_);
+  auto x = op_desc.Input("X").front();
+  auto indexs = op_desc.Input("Ids").front();
+  auto updates = op_desc.Input("Updates").front();
+  auto output = op_desc.Output("Out").front();
+  if (op_desc.HasAttr("overwrite")) {
+    param_.overwrite = op_desc.GetAttr<bool>("overwrite");
+  } else {
+    param_.overwrite = true;
+  }
+  param_.x = scope->FindVar(x)->GetMutable<lite::Tensor>();
+  param_.indexs = scope->FindVar(indexs)->GetMutable<lite::Tensor>();
+  param_.updates = scope->FindVar(updates)->GetMutable<lite::Tensor>();
+  param_.output = scope->FindMutableTensor(output);
+
+  CHECK(param_.x);
+  CHECK(param_.indexs);
+  CHECK(param_.updates);
+  CHECK(param_.output);
+  return true;
+}
+
+}  // namespace operators
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_OP(scatter, paddle::lite::operators::ScatterOp);
diff --git a/lite/operators/scatter_op.h b/lite/operators/scatter_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..419a5308ef76ee99987945dffb50549ca6bd4842
--- /dev/null
+++ b/lite/operators/scatter_op.h
@@ -0,0 +1,55 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <string>
+#include <vector>
+#include "lite/core/op_lite.h"
+#include "lite/core/scope.h"
+#include "lite/utils/all.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {
+
+class ScatterOp : public OpLite {
+ public:
+  ScatterOp() {}
+  explicit ScatterOp(const std::string &op_type) : OpLite(op_type) {}
+
+  bool CheckShape() const override;
+
+  bool InferShapeImpl() const override;
+
+  bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
+
+  void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
+
+  std::string DebugString() const override { return "Scatter"; }
+
+#ifdef LITE_WITH_PROFILE
+  void GetOpRuntimeInfo(paddle::lite::profile::OpCharacter *ch) {
+    ch->input_shape = ch->DimToStr(param_.x->dims());
+    ch->output_shape = ch->DimToStr(param_.output->dims());
+    ch->macs = param_.x->numel() * 1.f;
+  }
+#endif
+
+ private:
+  mutable ScatterParam param_;
+};
+
+}  // namespace operators
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/tests/cv/anakin/bgra_to_tensor_hwc.cc b/lite/tests/cv/anakin/bgra_to_tensor_hwc.cc
index daab2f3ce559cd9583839918c79bf50109275d71..4e24f87a1d8bbddf00f898185b71b8bd312f902c 100644
--- a/lite/tests/cv/anakin/bgra_to_tensor_hwc.cc
+++ b/lite/tests/cv/anakin/bgra_to_tensor_hwc.cc
@@ -30,7 +30,7 @@ void bgra_to_tensor_hwc(const uint8_t* bgr,
   float b_scales = scales[2];
 
   int dim8 = width >> 3;
-  int remain = wwidth - (dim8 << 3);
+  int remain = width - (dim8 << 3);
 
   float32x4_t vrmean = vdupq_n_f32(r_means);
   float32x4_t vgmean = vdupq_n_f32(g_means);
diff --git a/lite/tests/kernels/CMakeLists.txt b/lite/tests/kernels/CMakeLists.txt
index b5ffe94cee83c5a51ccaf9e1d98b53bae2a49020..00fec722eb926e27492ad9c2dbeb4bff754a56de 100644
--- a/lite/tests/kernels/CMakeLists.txt
+++ b/lite/tests/kernels/CMakeLists.txt
@@ -66,6 +66,7 @@ if(LITE_BUILD_EXTRA)
     lite_cc_test(test_kernel_ctc_align_compute SRCS ctc_align_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${bm_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
     lite_cc_test(test_kernel_clip_compute SRCS clip_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${bm_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
     lite_cc_test(test_kernel_pixel_shuffle_compute SRCS pixel_shuffle_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${bm_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
+    lite_cc_test(test_kernel_scatter_compute SRCS scatter_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${bm_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
     lite_cc_test(test_kernel_sequence_expand_as_compute SRCS sequence_expand_as_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${bm_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
 
     # for training kernel
diff --git a/lite/tests/kernels/scatter_compute_test.cc b/lite/tests/kernels/scatter_compute_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..a2d82b38d986deafb619d61e97e20be759c48b98
--- /dev/null
+++ b/lite/tests/kernels/scatter_compute_test.cc
@@ -0,0 +1,161 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gtest/gtest.h>
+#include "lite/api/paddle_use_kernels.h"
+#include "lite/api/paddle_use_ops.h"
+#include "lite/core/arena/framework.h"
+
+namespace paddle {
+namespace lite {
+
+void scatter_basic(const int64_t* indexs,
+                   const float* src,
+                   float* dst,
+                   int index_size,
+                   int num,
+                   int size,
+                   bool overwrite) {
+  for (int i = 0; i < num; i++) {
+    const float* din = src + indexs[i] * size;
+    memcpy(dst, din, sizeof(float) * size);
+    dst += size;
+  }
+  if (overwrite) {
+    for (int i = num; i < index_size; i++) {
+      const float* din = src + indexs[i] * size;
+      float* dout = dst + indexs[i] * size;
+      memcpy(dout, din, sizeof(float) * size);
+    }
+  } else {
+    for (int i = num; i < index_size; i++) {
+      const float* din = src + indexs[i] * size;
+      float* dout = dst + indexs[i] * size;
+      for (int j = 0; j < size; j++) {
+        dout[j] += din[j];
+      }
+    }
+  }
+}
+
+class ScatterComputeTester : public arena::TestCase {
+ protected:
+  // common attributes for this op.
+  std::string input_ = "x";
+  std::string indexs_ = "indexs";
+  std::string updates_ = "updates";
+  std::string output_ = "out";
+  DDim up_dims_{{1}};
+  DDim id_dims_{{1}};
+  DDim x_dims_{{1}};
+  int index_size_ = 0;
+  bool overwrite_ = false;
+
+ public:
+  ScatterComputeTester(const Place& place,
+                       const std::string& alias,
+                       DDim up_dims,
+                       DDim id_dims,
+                       DDim x_dims,
+                       bool overwrite,
+                       int index_size)
+      : TestCase(place, alias),
+        up_dims_(up_dims),
+        id_dims_(id_dims),
+        x_dims_(x_dims),
+        index_size_(index_size),
+        overwrite_(overwrite) {}
+
+  void RunBaseline(Scope* scope) override {
+    auto* indexs_t = scope->FindMutableTensor(indexs_);
+    auto* updates_t = scope->FindMutableTensor(updates_);
+    const auto* indexs_data = indexs_t->data<int64_t>();
+    const auto* updates_data = updates_t->data<float>();
+    auto* out = scope->NewTensor(output_);
+
+    out->Resize(x_dims_);
+
+    auto* out_data = out->mutable_data<float>();
+    int in_n = x_dims_[0];
+    int in_c = x_dims_[1];
+    int in_h = x_dims_[2];
+    int in_w = x_dims_[3];
+    int size = in_c * in_h * in_w;
+
+    scatter_basic(indexs_data,
+                  updates_data,
+                  out_data,
+                  index_size_,
+                  in_n,
+                  size,
+                  overwrite_);
+  }
+
+  void PrepareOpDesc(cpp::OpDesc* op_desc) {
+    op_desc->SetType("scatter");
+    op_desc->SetInput("X", {input_});
+    op_desc->SetInput("Ids", {indexs_});
+    op_desc->SetInput("Updates", {updates_});
+    op_desc->SetOutput("Out", {output_});
+    op_desc->SetAttr("overwrite", overwrite_);
+  }
+
+  void PrepareData() override {
+    std::vector<float> data(x_dims_.production());
+    for (int i = 0; i < x_dims_.production(); i++) {
+      data[i] = i * 1.0;
+    }
+    SetCommonTensor(input_, x_dims_, data.data());
+    std::vector<float> update(up_dims_.production());
+    for (int i = 0; i < up_dims_.production(); i++) {
+      update[i] = i * 1.0;
+    }
+    SetCommonTensor(updates_, up_dims_, update.data());
+    std::vector<int64_t> index(id_dims_.production());
+    for (int i = 0; i < id_dims_.production(); i++) {
+      index[i] = i;
+    }
+    SetCommonTensor(indexs_, id_dims_, index.data());
+  }
+};
+
+void test_scatter(Place place) {
+  for (auto n : {1, 3}) {
+    for (auto c : {1, 2}) {
+      for (auto h : {1, 3}) {
+        for (auto w : {1, 3}) {
+          for (bool overwrite : {false, true}) {
+            auto x_dims = DDim(std::vector<int64_t>({n, c, h, w}));
+            auto up_dims = DDim(std::vector<int64_t>({n, c, h, w}));
+            auto id_dims = DDim(std::vector<int64_t>({n}));
+            std::unique_ptr<arena::TestCase> tester(new ScatterComputeTester(
+                place, "def", up_dims, id_dims, x_dims, overwrite, n));
+            arena::Arena arena(std::move(tester), place, 2e-5);
+            arena.TestPrecision();
+          }
+        }
+      }
+    }
+  }
+}
+
+TEST(Scatter, precision) {
+#ifdef LITE_WITH_ARM
+  Place place(TARGET(kARM));
+  test_scatter(place);
+#endif
+}
+
+}  // namespace lite
+}  // namespace paddle
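For reference on the NEON accumulate path added in lite/backends/arm/math/scatter.cc near the top of this patch: each row of size floats is processed as size >> 3 blocks of eight (two four-lane float32x4_t halves per block), followed by a scalar tail of size & 7 elements. Below is a scalar-equivalent standalone sketch of that blocking, using hypothetical values and no NEON intrinsics; it is illustrative only and not part of the patch.

#include <cstdio>

// Same blocking as the vectorized accumulate loop, written with scalar code.
void accumulate_row(const float* din, float* dout, int size) {
  int cnt = size >> 3;  // number of 8-float blocks (two 4-lane vectors each)
  int rem = size & 7;   // scalar tail
  for (int j = 0; j < cnt; j++) {
    for (int k = 0; k < 8; k++) {  // one vld1q/vaddq/vst1q pair per 4 lanes
      dout[k] += din[k];
    }
    din += 8;
    dout += 8;
  }
  for (int j = 0; j < rem; j++) {
    *dout++ += *din++;
  }
}

int main() {
  // For size = 19: cnt = 2 (sixteen elements blocked), rem = 3 (scalar tail).
  float src[19], dst[19];
  for (int i = 0; i < 19; i++) {
    src[i] = 1.0f;
    dst[i] = static_cast<float>(i);
  }
  accumulate_row(src, dst, 19);
  std::printf("%g %g %g\n", dst[0], dst[8], dst[18]);  // 1 9 19
  return 0;
}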