diff --git a/paddle/fluid/lite/arm/math/CMakeLists.txt b/paddle/fluid/lite/arm/math/CMakeLists.txt index a20b5fa842f37ac7b462b81f77dc7b6340db4bd3..883e7bc4609b09dcea485eb85607fe7e8f2136cf 100644 --- a/paddle/fluid/lite/arm/math/CMakeLists.txt +++ b/paddle/fluid/lite/arm/math/CMakeLists.txt @@ -14,6 +14,7 @@ cc_library(math_arm SRCS scale.cc pooling.cc elementwise.cc + concat.cc sgemv.cc type_trans.cpp conv_impl.cc diff --git a/paddle/fluid/lite/arm/math/concat.cc b/paddle/fluid/lite/arm/math/concat.cc new file mode 100644 index 0000000000000000000000000000000000000000..fd375ab0e7f7700b31013fa55d73ddb732fd2e97 --- /dev/null +++ b/paddle/fluid/lite/arm/math/concat.cc @@ -0,0 +1,59 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/arm/math/concat.h" +#include +#include +#include +#include "paddle/fluid/lite/arm/math/funcs.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +void concat_func(const std::vector &input, const int axis, + lite::Tensor *output) { + size_t num = input.size(); + int rows = 1; + auto dim_0 = input[0]->dims(); + for (int i = 0; i < axis; ++i) { + rows *= dim_0[i]; + } + int out_rows = rows, out_cols = 0; + + std::vector input_cols(input.size()); + for (int i = 0; i < num; ++i) { + int t_cols = input[i]->numel() / rows; + out_cols += t_cols; + input_cols[i] = t_cols; + } + + // computation + for (int k = 0; k < out_rows; ++k) { + float *dst_ptr = output->mutable_data() + k * out_cols; + int col_idx = 0; + for (int j = 0; j < num; ++j) { + int col_len = input_cols[j]; + const float *src_prt = input[j]->data() + k * col_len; + std::memcpy(dst_ptr + col_idx, src_prt, sizeof(float) * col_len); + col_idx += col_len; + } + } +} + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/arm/math/concat.h b/paddle/fluid/lite/arm/math/concat.h new file mode 100644 index 0000000000000000000000000000000000000000..bc67523a494559011e79b9d8c687b8521b5b669b --- /dev/null +++ b/paddle/fluid/lite/arm/math/concat.h @@ -0,0 +1,34 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include "paddle/fluid/lite/operators/op_params.h" +#include "paddle/fluid/lite/utils/cp_logging.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +void concat_func(const std::vector &input, const int axis, + lite::Tensor *output); + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/kernels/arm/CMakeLists.txt b/paddle/fluid/lite/kernels/arm/CMakeLists.txt index 7540d7e012df27c94de6c6398686310c4d59afad..d072f23776c5507658fd2cbe8b440b5fdcebecc9 100644 --- a/paddle/fluid/lite/kernels/arm/CMakeLists.txt +++ b/paddle/fluid/lite/kernels/arm/CMakeLists.txt @@ -14,6 +14,7 @@ cc_library(batch_norm_compute_arm SRCS batch_norm_compute.cc DEPS ${lite_kernel_ cc_library(elementwise_add_compute_arm SRCS elementwise_add_compute.cc DEPS ${lite_kernel_deps} math_arm) cc_library(pool_compute_arm SRCS pool_compute.cc DEPS ${lite_kernel_deps} math_arm) cc_library(split_compute_arm SRCS split_compute.cc DEPS ${lite_kernel_deps} math_arm) +cc_library(concat_compute_arm SRCS concat_compute.cc DEPS ${lite_kernel_deps} math_arm) cc_library(dropout_compute_arm SRCS dropout_compute.cc DEPS ${lite_kernel_deps} math_arm) lite_cc_test(test_fc_compute_arm SRCS fc_compute_test.cc DEPS fc_compute_arm math_arm) @@ -26,6 +27,7 @@ lite_cc_test(test_elementwise_add_compute_arm SRCS elementwise_add_compute_test. lite_cc_test(test_pool_compute_arm SRCS pool_compute_test.cc DEPS pool_compute_arm) lite_cc_test(test_mul_compute_arm SRCS mul_compute_test.cc DEPS mul_compute_arm) lite_cc_test(test_split_compute_arm SRCS split_compute_test.cc DEPS split_compute_arm) +lite_cc_test(test_concat_compute_arm SRCS concat_compute_test.cc DEPS concat_compute_arm) lite_cc_test(test_dropout_compute_arm SRCS dropout_compute_test.cc DEPS dropout_compute_arm) set(arm_kernels @@ -39,6 +41,7 @@ set(arm_kernels elementwise_add_compute_arm pool_compute_arm split_compute_arm + concat_compute_arm dropout_compute_arm ) diff --git a/paddle/fluid/lite/kernels/arm/concat_compute.cc b/paddle/fluid/lite/kernels/arm/concat_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..70adb8fc33ec0ab9c925f77748536f3372632b55 --- /dev/null +++ b/paddle/fluid/lite/kernels/arm/concat_compute.cc @@ -0,0 +1,87 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/kernels/arm/concat_compute.h" +#include +#include +#include "paddle/fluid/lite/arm/math/funcs.h" +#include "paddle/fluid/lite/core/compatible_tensor.h" +#include "paddle/fluid/lite/core/op_registry.h" +#include "paddle/fluid/lite/core/type_system.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +std::vector stride_numel(const DDim& ddim) { + std::vector strides(ddim.size()); + strides[ddim.size() - 1] = ddim[ddim.size() - 1]; + for (int i = ddim.size() - 2; i >= 0; --i) { + strides[i] = strides[i + 1] * ddim[i]; + } + return strides; +} + +void ConcatCompute::Run() { + auto& param = Param(); + std::vector inputs = param.x; + auto* out = param.output; + int axis = param.axis; + out->mutable_data(); + + /// Sometimes direct copies will be faster, this maybe need deeply analysis. + if (axis == 0 && inputs.size() < 10) { + size_t output_offset = 0; + for (auto* in : inputs) { + auto in_stride = stride_numel(in->dims()); + auto out_stride = stride_numel(out->dims()); + void* dst = out->mutable_data() + output_offset; + const void* src = in->data(); +#if 0 + LOG(INFO) << "out_stride.size():" << out_stride.size(); + LOG(INFO) << "out_stride[0]" << out_stride[0]; + for (int i=0; i < out_stride.size(); ++i) { + LOG(INFO) << "out_stride[" << i << "]:" << out_stride[i]; + } + LOG(INFO) << "in_stride.size():" << in_stride.size(); + for (int i=0; i < in_stride.size(); ++i) { + LOG(INFO) << "in_stride[" << i << "]:" << in_stride[i]; + } +#endif + // src and dst tensor should have the same dims size. + CHECK(in_stride.size() == out_stride.size()); + std::memcpy(dst, src, sizeof(float) * in_stride[0]); + output_offset += in_stride[0]; + } + } else { + std::vector inputs_concat(inputs.size()); + for (int j = 0; j < inputs.size(); ++j) { + inputs_concat[j] = inputs[j]; + } + lite::arm::math::concat_func(inputs_concat, axis, out); + } + return; +} + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(concat, kARM, kFloat, kNCHW, + paddle::lite::kernels::arm::ConcatCompute, def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) + .Finalize(); diff --git a/paddle/fluid/lite/kernels/arm/concat_compute.h b/paddle/fluid/lite/kernels/arm/concat_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..2e1ca89841fdcfef869143a9ac3833842dda527e --- /dev/null +++ b/paddle/fluid/lite/kernels/arm/concat_compute.h @@ -0,0 +1,37 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include "paddle/fluid/lite/core/kernel.h" +#include "paddle/fluid/lite/operators/concat_op.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +class ConcatCompute : public KernelLite { + public: + using param_t = operators::ConcatParam; + + void Run() override; + + virtual ~ConcatCompute() = default; +}; + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/kernels/arm/concat_compute_test.cc b/paddle/fluid/lite/kernels/arm/concat_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..664f4ed116735ceb2d24be2ead887f7680f29230 --- /dev/null +++ b/paddle/fluid/lite/kernels/arm/concat_compute_test.cc @@ -0,0 +1,235 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/kernels/arm/concat_compute.h" +#include +#include +#include +#include +#include "paddle/fluid/lite/arm/math/funcs.h" +#include "paddle/fluid/lite/core/lite_tensor.h" +#include "paddle/fluid/lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +bool infer_shape(const operators::ConcatParam& param) { + std::vector input_dims; + for (auto p : param.x) { + input_dims.push_back(p->dims()); + } + size_t axis = static_cast(param.axis); + const size_t n = input_dims.size(); + CHECK_GT_OR_FALSE(n, 0); + auto& out_dims = input_dims[0]; + size_t in_zero_dims_size = out_dims.size(); + for (size_t i = 1; i < n; i++) { + for (size_t j = 0; j < in_zero_dims_size; j++) { + if (j == axis) { + out_dims[axis] += input_dims[i][j]; + } else { + CHECK_EQ_OR_FALSE(out_dims[j], input_dims[i][j]); + } + } + } + if (out_dims[axis] < 0) { + out_dims[axis] = -1; + } + // Set output dims + param.output->Resize(lite::DDim(out_dims)); + return true; +} + +void concat_compute_ref(const operators::ConcatParam& param) { + std::vector input = param.x; + int axis = param.axis; + infer_shape(param); + + lite::Tensor* output = param.output; + int num = input.size(); + int rows = 1; + auto dim_0 = input[0]->dims(); + for (int i = 0; i < axis; ++i) { + rows *= dim_0[i]; + } + int out_rows = rows, out_cols = 0; + + std::vector input_cols(input.size()); + for (int i = 0; i < num; ++i) { + int input_i_numel = input[i]->dims().size() == 0 ? 0 : 1; + for (int didx = 0; didx < input[i]->dims().size(); ++didx) { + input_i_numel *= input[i]->dims()[didx]; + } + int t_cols = input_i_numel / rows; + out_cols += t_cols; + input_cols[i] = t_cols; + } + + // computation + auto output_data = output->mutable_data(); + int col_idx = 0; + for (int j = 0; j < num; ++j) { + int col_len = input_cols[j]; + auto input_data = input[j]->data(); + for (int k = 0; k < out_rows; ++k) { + memcpy(output_data + k * out_cols + col_idx, input_data + k * col_len, + sizeof(float) * col_len); + } + col_idx += col_len; + } +} + +TEST(concat_arm, init) { + ConcatCompute concat; + ASSERT_EQ(concat.precision(), PRECISION(kFloat)); + ASSERT_EQ(concat.target(), TARGET(kARM)); +} + +TEST(concat_arm, compute_input_single) { + ConcatCompute concat; + operators::ConcatParam param; + + LOG(INFO) << "test concat start"; + lite::Tensor output; + lite::Tensor output_ref; + lite::Tensor tensorA; + DDimLite ddimA({10, 4, 3, 2}); + tensorA.Resize(ddimA); + + for (int i = 0; i < ddimA.data()[0] * ddimA.data()[1] * ddimA.data()[2] * + ddimA.data()[3]; + i++) { + tensorA.mutable_data()[i] = i; + } + + param.x.push_back(&tensorA); + for (int cur_axis : {0, 1}) { + param.output = &output; + param.axis = cur_axis; + CHECK(infer_shape(param)); + concat.SetParam(param); + LOG(INFO) << "test concat start cur_axis:" << cur_axis; + + concat.Run(); + LOG(INFO) << "concat.Run end"; + param.output = &output_ref; + LOG(INFO) << "concat_compute_ref start"; + concat_compute_ref(param); + LOG(INFO) << "concat_compute_ref end"; + + auto* output_data = output.data(); + auto* output_ref_data = output_ref.data(); + for (int i = 0; i < (ddimA.data()[0]) * ddimA.data()[1] * ddimA.data()[2] * + ddimA.data()[3]; + i++) { + // LOG(INFO) << "output[" << i << "]:" << output_data[i] << " + // output_ref_data[" << i << "]:" << output_ref_data[i]; + EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-5); + } + } +} + +TEST(concat_arm, compute_input_multi) { + ConcatCompute concat; + operators::ConcatParam param; + + LOG(INFO) << "test concat start"; + // init param + // x: tensorA, tensorB, tensorC, tensorD + // axis: 0 + lite::Tensor output; + lite::Tensor output_ref; + lite::Tensor tensorA; + lite::Tensor tensorB; + lite::Tensor tensorC; + lite::Tensor tensorD; + + DDimLite ddimA({10, 4, 3, 2}); + DDimLite ddimB({20, 4, 3, 2}); + DDimLite ddimC({30, 4, 3, 2}); + DDimLite ddimD({40, 4, 3, 2}); + + tensorA.Resize(ddimA); + tensorB.Resize(ddimB); + tensorC.Resize(ddimC); + tensorD.Resize(ddimD); + + for (int i = 0; i < ddimA.data()[0] * ddimA.data()[1] * ddimA.data()[2] * + ddimA.data()[3]; + i++) { + tensorA.mutable_data()[i] = i; + } + for (int i = 0; i < ddimB.data()[0] * ddimB.data()[1] * ddimB.data()[2] * + ddimB.data()[3]; + i++) { + tensorB.mutable_data()[i] = i + 1; + } + for (int i = 0; i < ddimC.data()[0] * ddimC.data()[1] * ddimC.data()[2] * + ddimC.data()[3]; + i++) { + tensorC.mutable_data()[i] = i + 2; + } + for (int i = 0; i < ddimD.data()[0] * ddimD.data()[1] * ddimD.data()[2] * + ddimD.data()[3]; + i++) { + tensorD.mutable_data()[i] = i + 3; + } + + param.x.push_back(&tensorA); + param.x.push_back(&tensorB); + param.x.push_back(&tensorC); + param.x.push_back(&tensorD); + for (int cur_axis : {0}) { + param.output = &output; + param.axis = cur_axis; + CHECK(infer_shape(param)); + concat.SetParam(param); + LOG(INFO) << "test concat start cur_axis:" << cur_axis; + + concat.Run(); + LOG(INFO) << "concat.Run end"; + param.output = &output_ref; + LOG(INFO) << "concat_compute_ref start"; + concat_compute_ref(param); + LOG(INFO) << "concat_compute_ref end"; + + auto* output_data = output.data(); + auto* output_ref_data = output_ref.data(); + int elem_num = (ddimA.data()[0] + ddimB.data()[0] + ddimC.data()[0] + + ddimD.data()[0]) * + ddimA.data()[1] * ddimA.data()[2] * ddimA.data()[3]; + for (int i = 0; i < elem_num; i++) { + // LOG(INFO) << "output[" << i << "]:" << output_data[i] << " + // output_ref_data[" << i << "]:" << output_ref_data[i]; + EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-5); + } + } +} + +TEST(concat, retrive_op) { + auto concat = + KernelRegistry::Global().Create( + "concat"); + ASSERT_FALSE(concat.empty()); + ASSERT_TRUE(concat.front()); +} + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle + +USE_LITE_KERNEL(concat, kARM, kFloat, kNCHW, def); diff --git a/paddle/fluid/lite/kernels/arm/use_kernels.h b/paddle/fluid/lite/kernels/arm/use_kernels.h index 1f93a81aa94f09f8330aa385840adec559d7161d..1a6583f3f570e688080b1bb1a96217c25ca4bcc9 100644 --- a/paddle/fluid/lite/kernels/arm/use_kernels.h +++ b/paddle/fluid/lite/kernels/arm/use_kernels.h @@ -19,6 +19,7 @@ USE_LITE_KERNEL(fc, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(mul, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(scale, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(softmax, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(concat, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(pool, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(feed, kARM, kAny, kAny, def); USE_LITE_KERNEL(fetch, kARM, kAny, kAny, def);