From a477fdc6bd1520851e8bce22848ae3614090c190 Mon Sep 17 00:00:00 2001
From: zhupengyang <1165938320@qq.com>
Date: Wed, 12 Jun 2019 12:48:49 +0000
Subject: [PATCH] add split op and arm kernel

test=develop
---
 paddle/fluid/lite/arm/math/CMakeLists.txt     |   1 +
 paddle/fluid/lite/arm/math/split.cc           |  82 +++++++++
 paddle/fluid/lite/arm/math/split.h            |  35 ++++
 paddle/fluid/lite/kernels/arm/CMakeLists.txt  |   3 +
 .../fluid/lite/kernels/arm/split_compute.cc   |  46 +++++
 paddle/fluid/lite/kernels/arm/split_compute.h |  35 ++++
 .../lite/kernels/arm/split_compute_test.cc    | 170 ++++++++++++++++++
 paddle/fluid/lite/operators/CMakeLists.txt    |   2 +
 paddle/fluid/lite/operators/op_params.h       |   9 +
 paddle/fluid/lite/operators/split_op.cc       |  82 +++++++++
 paddle/fluid/lite/operators/split_op.h        |  46 +++++
 11 files changed, 511 insertions(+)
 create mode 100644 paddle/fluid/lite/arm/math/split.cc
 create mode 100644 paddle/fluid/lite/arm/math/split.h
 create mode 100644 paddle/fluid/lite/kernels/arm/split_compute.cc
 create mode 100644 paddle/fluid/lite/kernels/arm/split_compute.h
 create mode 100644 paddle/fluid/lite/kernels/arm/split_compute_test.cc
 create mode 100644 paddle/fluid/lite/operators/split_op.cc
 create mode 100644 paddle/fluid/lite/operators/split_op.h

diff --git a/paddle/fluid/lite/arm/math/CMakeLists.txt b/paddle/fluid/lite/arm/math/CMakeLists.txt
index 7708fe80256..2a912e434ae 100644
--- a/paddle/fluid/lite/arm/math/CMakeLists.txt
+++ b/paddle/fluid/lite/arm/math/CMakeLists.txt
@@ -31,5 +31,6 @@ cc_library(math_arm SRCS
             conv_gemmlike.cc
             conv_winograd_3x3.cc
             conv_winograd.cc
+            split.cc
             DEPS ${lite_kernel_deps} eigen3)
diff --git a/paddle/fluid/lite/arm/math/split.cc b/paddle/fluid/lite/arm/math/split.cc
new file mode 100644
index 00000000000..6dd6de6242e
--- /dev/null
+++ b/paddle/fluid/lite/arm/math/split.cc
@@ -0,0 +1,82 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "paddle/fluid/lite/arm/math/split.h" +#include +#include "paddle/fluid/lite/arm/math/funcs.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +template <> +void split_cpy(const float* din, float* dout, int num) { + int cnt = num >> 4; + int remain = num % 16; +#pragma omp parallel for + for (int i = 0; i < cnt; i++) { + const float* din_ptr = din + (i << 4); + float* dout_ptr = dout + (i << 4); + + float32x4_t din0 = vld1q_f32(din_ptr); + float32x4_t din1 = vld1q_f32(din_ptr + 4); + float32x4_t din2 = vld1q_f32(din_ptr + 8); + float32x4_t din3 = vld1q_f32(din_ptr + 12); + + vst1q_f32(dout_ptr, din0); + vst1q_f32(dout_ptr + 4, din1); + vst1q_f32(dout_ptr + 8, din2); + vst1q_f32(dout_ptr + 12, din3); + } + if (remain > 0) { + const float* din_ptr = din + (cnt << 4); + float* dout_ptr = dout + (cnt << 4); + for (int i = 0; i < remain; i++) { + *dout_ptr = *din_ptr; + dout_ptr++; + din_ptr++; + } + } +} + +template <> +void split(const float* din, std::vector* dout, + const int axis, const std::vector& in_strides) { + int input_offset = 0; + for (auto out : *dout) { + auto out_dim = out->dims(); + std::vector out_strides(out_dim.size()); + out_strides[out_dim.size() - 1] = out_dim[out_dim.size() - 1]; + for (int i = out_dim.size() - 2; i >= 0; --i) { + out_strides[i] = out_strides[i + 1] * out_dim[i]; + } + + float* out_data = out->mutable_data(); + int before = out_strides[0] / out_strides[axis]; + int in_after = in_strides[axis]; + int out_after = out_strides[axis]; + + for (int i = 0; i < before; ++i) { + split_cpy(din + input_offset + i * in_after, out_data + i * out_after, + out_after); + } + input_offset += out_strides[axis]; + } +} + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/arm/math/split.h b/paddle/fluid/lite/arm/math/split.h new file mode 100644 index 00000000000..9b5651d81ff --- /dev/null +++ b/paddle/fluid/lite/arm/math/split.h @@ -0,0 +1,35 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#pragma once
+
+#include <vector>
+#include "paddle/fluid/lite/core/op_lite.h"
+
+namespace paddle {
+namespace lite {
+namespace arm {
+namespace math {
+
+template <typename T>
+void split_cpy(const T* din, T* dout, int num);
+
+template <typename T>
+void split(const T* din, std::vector<lite::Tensor*>* dout, const int axis,
+           const std::vector<int>& in_strides);
+
+}  // namespace math
+}  // namespace arm
+}  // namespace lite
+}  // namespace paddle
diff --git a/paddle/fluid/lite/kernels/arm/CMakeLists.txt b/paddle/fluid/lite/kernels/arm/CMakeLists.txt
index c0fa480f094..1cf66b0d266 100644
--- a/paddle/fluid/lite/kernels/arm/CMakeLists.txt
+++ b/paddle/fluid/lite/kernels/arm/CMakeLists.txt
@@ -12,6 +12,7 @@ cc_library(softmax_compute_arm SRCS softmax_compute.cc DEPS ${lite_kernel_deps}
 cc_library(conv_compute_arm SRCS conv_compute.cc DEPS ${lite_kernel_deps} math_arm)
 cc_library(elementwise_add_compute_arm SRCS elementwise_add_compute.cc DEPS ${lite_kernel_deps} math_arm)
 cc_library(pool_compute_arm SRCS pool_compute.cc DEPS ${lite_kernel_deps} math_arm)
+cc_library(split_compute_arm SRCS split_compute.cc DEPS ${lite_kernel_deps} math_arm)
 
 lite_cc_test(test_fc_compute_arm SRCS fc_compute_test.cc DEPS fc_compute_arm math_arm)
 lite_cc_test(test_scale_compute_arm SRCS scale_compute_test.cc DEPS scale_compute_arm)
@@ -19,6 +20,7 @@ lite_cc_test(test_softmax_compute_arm SRCS softmax_compute_test.cc DEPS softmax_
 lite_cc_test(test_conv_compute_arm SRCS conv_compute_test.cc DEPS conv_compute_arm)
 lite_cc_test(test_elementwise_add_compute_arm SRCS elementwise_add_compute_test.cc DEPS elementwise_add_compute_arm)
 lite_cc_test(test_pool_compute_arm SRCS pool_compute_test.cc DEPS pool_compute_arm)
+lite_cc_test(test_split_compute_arm SRCS split_compute_test.cc DEPS split_compute_arm)
 
 set(arm_kernels
     fc_compute_arm
@@ -29,6 +31,7 @@ set(arm_kernels
     conv_compute_arm
     elementwise_add_compute_arm
     pool_compute_arm
+    split_compute_arm
     )
 
 set(arm_kernels "${arm_kernels}" CACHE INTERNAL "arm kernels")
diff --git a/paddle/fluid/lite/kernels/arm/split_compute.cc b/paddle/fluid/lite/kernels/arm/split_compute.cc
new file mode 100644
index 00000000000..9da69894592
--- /dev/null
+++ b/paddle/fluid/lite/kernels/arm/split_compute.cc
@@ -0,0 +1,46 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "paddle/fluid/lite/kernels/arm/split_compute.h" +#include +#include "paddle/fluid/lite/arm/math/funcs.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +void SplitCompute::Run() { + auto& param = Param(); + const float* din = param.x->data(); + auto* dout = param.output; + auto in_dim = param.x->dims(); + std::vector in_strides(in_dim.size()); + in_strides[in_dim.size() - 1] = in_dim[in_dim.size() - 1]; + for (int i = in_dim.size() - 2; i >= 0; --i) { + in_strides[i] = in_strides[i + 1] * in_dim[i]; + } + lite::arm::math::split(din, dout, param.axis, in_strides); +} + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(split, kARM, kFloat, kNCHW, + paddle::lite::kernels::arm::SplitCompute, def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) + .Finalize(); diff --git a/paddle/fluid/lite/kernels/arm/split_compute.h b/paddle/fluid/lite/kernels/arm/split_compute.h new file mode 100644 index 00000000000..22701ba0fd9 --- /dev/null +++ b/paddle/fluid/lite/kernels/arm/split_compute.h @@ -0,0 +1,35 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include "paddle/fluid/lite/core/kernel.h" +#include "paddle/fluid/lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +class SplitCompute : public KernelLite { + public: + void Run() override; + + virtual ~SplitCompute() = default; +}; + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/kernels/arm/split_compute_test.cc b/paddle/fluid/lite/kernels/arm/split_compute_test.cc new file mode 100644 index 00000000000..808a1e2cdb7 --- /dev/null +++ b/paddle/fluid/lite/kernels/arm/split_compute_test.cc @@ -0,0 +1,170 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/lite/kernels/arm/split_compute.h" +#include +#include +#include +#include "paddle/fluid/lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +void splite_resize_out(const lite::Tensor* din, + std::vector* dout, int axis, int num, + const std::vector& sections) { + for (auto out : *dout) delete out; + dout->clear(); + auto in_dims = din->dims(); + int outs_number; + if (num > 0) { + outs_number = num; + } else { + outs_number = sections.size(); + } + for (int i = 0; i < outs_number; i++) { + dout->push_back(new lite::Tensor); + } + + std::vector outs_dims; + outs_dims.reserve(outs_number); + + if (num > 0) { + int out_axis_dim = in_dims[axis] / num; + for (int i = 0; i < outs_number; ++i) { + auto dim = in_dims; + dim[axis] = out_axis_dim; + outs_dims.push_back(dim); + } + } else if (sections.size() > 0) { + for (size_t i = 0; i < outs_number; ++i) { + auto dim = in_dims; + dim[axis] = sections[i]; + outs_dims.push_back(dim); + } + } + + for (int j = 0; j < outs_dims.size(); ++j) { + (*dout)[j]->Resize(outs_dims[j]); + } +} + +template +void split_compute_ref(const operators::SplitParam& param) { + const dtype* din = param.x->mutable_data(); + auto& dout = param.output; + auto in_dim = param.x->dims(); + int axis = param.axis; + std::vector in_strides(in_dim.size()); + in_strides[in_dim.size() - 1] = in_dim[in_dim.size() - 1]; + for (int i = in_dim.size() - 2; i >= 0; --i) { + in_strides[i] = in_strides[i + 1] * in_dim[i]; + } + + int input_offset = 0; + for (auto out : *dout) { + auto out_dim = out->dims(); + std::vector out_strides(out_dim.size()); + out_strides[out_dim.size() - 1] = out_dim[out_dim.size() - 1]; + for (int i = out_dim.size() - 2; i >= 0; --i) { + out_strides[i] = out_strides[i + 1] * out_dim[i]; + } + + dtype* out_data = out->mutable_data(); + int before = out_strides[0] / out_strides[axis]; + int in_after = in_strides[axis]; + int out_after = out_strides[axis]; + + for (int i = 0; i < before; ++i) { + std::memcpy(out_data + i * out_after, din + input_offset + i * in_after, + sizeof(dtype) * out_after); + } + input_offset += out_strides[axis]; + } +} + +TEST(split_arm, init) { + SplitCompute split; + ASSERT_EQ(split.precision(), PRECISION(kFloat)); + ASSERT_EQ(split.target(), TARGET(kARM)); +} + +TEST(split_arm, compute) { + SplitCompute split; + operators::SplitParam param; + + lite::Tensor x; + std::vector output; + std::vector output_ref; + + for (auto n : {1, 3, 4}) { + for (auto c : {1, 3, 4}) { + for (auto h : {1, 3, 4}) { + for (auto w : {1, 3, 4}) { + for (auto axis : {0, 1, 2, 3}) { + for (auto num : {0, 1, 2, 3}) { + for (auto sections : + {std::vector{1, 1, 1}, std::vector{2, 2}, + std::vector{1, 2}}) { + auto x_dim = DDim(std::vector({n, c, h, w})); + x.Resize(x_dim); + if ((num != 0 && x_dim[axis] % num != 0) || + (num == 0 && x_dim[axis] % sections.size() != 0)) + continue; + auto* x_data = x.mutable_data(); + for (int i = 0; i < x.dims().production(); i++) { + x_data[i] = i; + } + splite_resize_out(&x, &output, axis, num, sections); + splite_resize_out(&x, &output_ref, axis, num, sections); + param.x = &x; + param.axis = axis; + param.num = num; + param.sections = §ions; + param.output = &output; + split.SetParam(param); + split.Run(); + param.output = &output_ref; + split_compute_ref(param); + for (int i = 0; i < output.size(); i++) { + float* output_data = output[i]->mutable_data(); + float* output_ref_data = output_ref[i]->mutable_data(); + for (int j = 0; j < 
+                  for (int j = 0; j < output[i]->dims().production(); j++) {
+                    EXPECT_NEAR(output_data[j], output_ref_data[j], 1e-5);
+                  }
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+}
+
+TEST(split, retrive_op) {
+  auto split =
+      KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>("split");
+  ASSERT_FALSE(split.empty());
+  ASSERT_TRUE(split.front());
+}
+
+}  // namespace arm
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+USE_LITE_KERNEL(split, kARM, kFloat, kNCHW, def);
diff --git a/paddle/fluid/lite/operators/CMakeLists.txt b/paddle/fluid/lite/operators/CMakeLists.txt
index 4230abfae75..9a90666420e 100644
--- a/paddle/fluid/lite/operators/CMakeLists.txt
+++ b/paddle/fluid/lite/operators/CMakeLists.txt
@@ -19,6 +19,7 @@ cc_library(fill_constant_op_lite SRCS fill_constant_op.cc DEPS ${op_DEPS})
 cc_library(op_params_lite SRCS op_params.cc DEPS ${tensor_lite} any_lite framework_proto_lite)
 cc_library(dropout_op_lite SRCS dropout_op.cc DEPS ${op_DEPS})
 cc_library(concat_op_lite SRCS concat_op.cc DEPS ${op_DEPS})
+cc_library(split_op_lite SRCS split_op.cc DEPS ${op_DEPS})
 
 set(ops_lite
     conv_op_lite
@@ -38,6 +39,7 @@ set(ops_lite
     activation_ops_lite
     dropout_op_lite
     concat_op_lite
+    split_op_lite
     PARENT_SCOPE)
 
 lite_cc_test(test_fc_op_lite SRCS fc_op_test.cc
diff --git a/paddle/fluid/lite/operators/op_params.h b/paddle/fluid/lite/operators/op_params.h
index cd87a9d2d39..eee0d90dba2 100644
--- a/paddle/fluid/lite/operators/op_params.h
+++ b/paddle/fluid/lite/operators/op_params.h
@@ -174,6 +174,15 @@ struct DropoutParam {
   std::string dropout_implementation{"downgrade_in_infer"};
 };
 
+// For Split op
+struct SplitParam {
+  lite::Tensor* x{};
+  std::vector<lite::Tensor*>* output{};
+  int axis{-1};
+  int num{0};
+  std::vector<int>* sections{};
+};
+
 /// ----------------------- element wise operators ----------------------
 struct ElementwiseParam {
   const lite::Tensor* X{};
diff --git a/paddle/fluid/lite/operators/split_op.cc b/paddle/fluid/lite/operators/split_op.cc
new file mode 100644
index 00000000000..c788e9cf954
--- /dev/null
+++ b/paddle/fluid/lite/operators/split_op.cc
@@ -0,0 +1,82 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "paddle/fluid/lite/operators/split_op.h" +#include "paddle/fluid/lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators { + +bool SplitOp::CheckShape() const { + CHECK_OR_FALSE(param_.x); + CHECK_OR_FALSE(param_.output); + auto x_dims = param_.x->dims(); + auto x_rank = x_dims.size(); + CHECK_OR_FALSE(param_.axis >= -static_cast(x_rank) && + param_.axis < static_cast(x_rank)); + return true; +} + +bool SplitOp::InferShape() const { + const auto &outs = param_.output; + auto in_dims = param_.x.dims(); + int axis = param_.axis; + int num = param_.num; + const auto §ions = param_.sections; + + const int outs_number = outs.size(); + std::vector outs_dims; + outs_dims.reserve(outs_number); + + if (num > 0) { + int out_axis_dim = in_dims[axis] / num; + for (int i = 0; i < outs_number; ++i) { + auto dim = in_dims; + dim[axis] = out_axis_dim; + outs_dims.push_back(dim); + } + } else if (sections.size() > 0) { + for (size_t i = 0; i < outs_number; ++i) { + auto dim = in_dims; + dim[axis] = sections[i]; + outs_dims.push_back(dim); + } + } + + for (int j = 0; j < outs_dims.size(); ++j) { + outs[j]->Resize(outs_dims[j]); + } + + return true; +} + +bool SplitOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { + param_.axis = opdesc.GetAttr("axis"); + param_.num = opdesc.GetAttr("num"); + param_.sections = opdesc.GetAttr>("sections"); + param_.x = const_cast( + &scope->FindVar(opdesc.Input("X").front())->Get()); + auto outs = op_desc.Output("Out"); + for (auto var : outs) { + param_.output.push_back(scope->FindVar(var)->GetMutable()); + } + return true; +} + +} // namespace operators +} // namespace lite +} // namespace paddle + +REGISTER_LITE_OP(softmax, paddle::lite::operators::SoftmaxOp); diff --git a/paddle/fluid/lite/operators/split_op.h b/paddle/fluid/lite/operators/split_op.h new file mode 100644 index 00000000000..177c44171e6 --- /dev/null +++ b/paddle/fluid/lite/operators/split_op.h @@ -0,0 +1,46 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include "paddle/fluid/lite/core/op_lite.h" +#include "paddle/fluid/lite/core/scope.h" +#include "paddle/fluid/lite/utils/all.h" + +namespace paddle { +namespace lite { +namespace operators { + +class SoftmaxOp : public OpLite { + public: + SplitOp() {} + explicit SplitOp(const std::string &op_type) : OpLite(op_type) {} + + bool CheckShape() const override; + + bool InferShape() const override; + + bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + std::string DebugString() const override { return "split"; } + + private: + mutable SplitParam param_; +}; + +} // namespace operators +} // namespace lite +} // namespace paddle -- GitLab