diff --git a/lite/api/_paddle_use_kernels.h b/lite/api/_paddle_use_kernels.h index 7e2ebf1011ba0d3b1d442734a7bd793226091397..fe522592370c778aacb6fc4b7c4668ef80e62728 100644 --- a/lite/api/_paddle_use_kernels.h +++ b/lite/api/_paddle_use_kernels.h @@ -106,6 +106,8 @@ USE_LITE_KERNEL(generate_proposals, kARM, kFloat, kNCHW, def) USE_LITE_KERNEL(squeeze, kARM, kFloat, kNCHW, def) // for x2paddle USE_LITE_KERNEL(squeeze2, kARM, kFloat, kNCHW, def) // for x2paddle USE_LITE_KERNEL(expand, kARM, kFloat, kNCHW, def) // for x2paddle +USE_LITE_KERNEL(reduce_mean, kARM, kFloat, kNCHW, def) +USE_LITE_KERNEL(stack, kARM, kFloat, kNCHW, def) USE_LITE_KERNEL(roi_align, kARM, kFloat, kNCHW, def) USE_LITE_KERNEL(calib, kARM, kInt8, kNCHW, fp32_to_int8); diff --git a/lite/api/_paddle_use_ops.h b/lite/api/_paddle_use_ops.h index e1a16f2f5697960e1076f672c3a9e984d7d194bf..c41e395408d30bcf699f88ca5c8be260a3cc0c30 100644 --- a/lite/api/_paddle_use_ops.h +++ b/lite/api/_paddle_use_ops.h @@ -69,6 +69,8 @@ USE_LITE_OP(shuffle_channel) USE_LITE_OP(yolo_box) USE_LITE_OP(bilinear_interp) USE_LITE_OP(nearest_interp) +USE_LITE_OP(reduce_mean) +USE_LITE_OP(stack) USE_LITE_OP(assign); USE_LITE_OP(crop) diff --git a/lite/arm/math/CMakeLists.txt b/lite/arm/math/CMakeLists.txt index deee5a64bc1a0769f27750b2b06f4cef4688362f..e228259da39f8f100dab0bc1d97732a03d88b431 100644 --- a/lite/arm/math/CMakeLists.txt +++ b/lite/arm/math/CMakeLists.txt @@ -103,6 +103,8 @@ if (NOT HAS_ARM_MATH_LIB_DIR) sequence_pool.cc sequence_expand.cc slice.cc + reduce_mean.cc + stack.cc affine_channel.cc anchor_generator.cc DEPS ${lite_kernel_deps}) diff --git a/lite/arm/math/funcs.h b/lite/arm/math/funcs.h index 64fc10b4fff0a825dce49dee4c499c7ddf25ac24..9b4a1ca9726eeab1e97088d4a930e6d7fbf07b52 100644 --- a/lite/arm/math/funcs.h +++ b/lite/arm/math/funcs.h @@ -48,6 +48,7 @@ #include "lite/arm/math/power.h" #include "lite/arm/math/prior_box.h" #include "lite/arm/math/reduce_max.h" +#include "lite/arm/math/reduce_mean.h" #include "lite/arm/math/scale.h" #include "lite/arm/math/sequence_expand.h" #include "lite/arm/math/sequence_pool.h" @@ -58,6 +59,7 @@ #include "lite/arm/math/slice.h" #include "lite/arm/math/softmax.h" #include "lite/arm/math/split.h" +#include "lite/arm/math/stack.h" #include "lite/arm/math/topk.h" #include "lite/arm/math/yolo_box.h" namespace paddle { diff --git a/lite/arm/math/reduce_mean.cc b/lite/arm/math/reduce_mean.cc new file mode 100644 index 0000000000000000000000000000000000000000..786f8320ee2fd154670f2d2bc1c368d66a021066 --- /dev/null +++ b/lite/arm/math/reduce_mean.cc @@ -0,0 +1,204 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "lite/arm/math/reduce_mean.h" +#include "lite/arm/math/funcs.h" +#include "lite/core/tensor.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +template <> +void reduce_mean_n(const float* src, + float* dst, + int num_in, + int channel_in, + int height_in, + int width_in) { + int hw_size = height_in * width_in; + int chw_size = channel_in * hw_size; + int data_index, src_index, src_index0; + for (int c = 0; c < channel_in; ++c) { + for (int h = 0; h < height_in; ++h) { + for (int w = 0; w < width_in; ++w) { + data_index = c * hw_size + h * width_in + w; + dst[data_index] = 0.0; + for (int n = 1; n < num_in; ++n) { + src_index = n * chw_size + data_index; + dst[data_index] += static_cast(src[src_index]) / num_in; + } + } + } + } +} + +template <> +void reduce_mean_c(const float* src, + float* dst, + int num_in, + int channel_in, + int height_in, + int width_in) { + int hw_size = height_in * width_in; + int chw_size = hw_size * channel_in; + int data_index, src_index0, src_index; + for (int n = 0; n < num_in; ++n) { + for (int h = 0; h < height_in; ++h) { + for (int w = 0; w < width_in; ++w) { + data_index = n * hw_size + h * width_in + w; + src_index0 = n * chw_size + h * width_in + w; + dst[data_index] = 0.0; + for (int c = 1; c < channel_in; ++c) { + src_index = src_index0 + c * hw_size; + dst[data_index] += static_cast(src[src_index]) / channel_in; + } + } + } + } +} + +template <> +void reduce_mean_h(const float* src, + float* dst, + int num_in, + int channel_in, + int height_in, + int width_in) { + int cw_size = channel_in * width_in; + int chw_size = cw_size * height_in; + int hw_size = height_in * width_in; + int data_index, src_index, src_index0; + for (int n = 0; n < num_in; ++n) { + for (int c = 0; c < channel_in; ++c) { + for (int w = 0; w < width_in; ++w) { + data_index = n * cw_size + c * width_in + w; + src_index0 = n * chw_size + c * hw_size + w; + dst[data_index] = 0.0; + for (int h = 1; h < height_in; ++h) { + src_index = src_index0 + h * width_in; + dst[data_index] += static_cast(src[src_index]) / height_in; + } + } + } + } +} + +template <> +void reduce_mean_w(const float* src, + float* dst, + int num_in, + int channel_in, + int height_in, + int width_in) { + int ch_size = channel_in * height_in; + int hw_size = height_in * width_in; + int chw_size = ch_size * width_in; + int data_index = 0; + int src_index0 = 0; + int src_index = 0; + for (int n = 0; n < num_in; ++n) { + for (int c = 0; c < channel_in; ++c) { + for (int h = 0; h < height_in; ++h) { + data_index = n * ch_size + c * height_in + h; + src_index0 = n * chw_size + c * hw_size + h * width_in; + dst[data_index] = 0.0; + for (int w = 1; w < width_in; ++w) { + src_index = src_index0 + w; + dst[data_index] += static_cast(src[src_index]) / width_in; + } + } + } + } +} + +template <> +void reduce_mean_all(const float* src, + float* dst, + int num_in, + int channel_in, + int height_in, + int width_in) { + float mean = 0.0; + int src_index; + int n_id, c_id; + int all = num_in * channel_in * height_in * width_in; + for (int n = 0; n < num_in; ++n) { + n_id = n * channel_in * height_in * width_in; + for (int c = 0; c < channel_in; ++c) { + c_id = c * height_in * width_in; + for (int h = 0; h < height_in; ++h) { + for (int w = 0; w < width_in; ++w) { + src_index = n_id + c_id + h * width_in + w; + mean = src[src_index] / all; + } + } + } + } + dst[0] = mean; +} + +template <> +void reduce_mean_nc(const float* src, + float* dst, + int num_in, + int channel_in, + int height_in, + 
int width_in) {
+  // reduce n first.
+  DDimLite ddimA({1, channel_in, height_in, width_in});
+  lite::Tensor tensor_tmp;
+  tensor_tmp.Resize(ddimA);
+  float* tmp_out = tensor_tmp.mutable_data<float>();
+  reduce_mean_n(src, tmp_out, num_in, channel_in, height_in, width_in);
+  reduce_mean_c(tmp_out, dst, 1, channel_in, height_in, width_in);
+}
+
+template <>
+void reduce_mean_ch<float>(const float* src,
+                           float* dst,
+                           int num_in,
+                           int channel_in,
+                           int height_in,
+                           int width_in) {
+  // reduce c first
+  DDimLite ddimA({num_in, 1, height_in, width_in});
+  lite::Tensor tensor_tmp;
+  tensor_tmp.Resize(ddimA);
+  float* tmp_out = tensor_tmp.mutable_data<float>();
+  reduce_mean_c(src, tmp_out, num_in, channel_in, height_in, width_in);
+  reduce_mean_h(tmp_out, dst, num_in, 1, height_in, width_in);
+}
+
+template <>
+void reduce_mean_hw<float>(const float* src,
+                           float* dst,
+                           int num_in,
+                           int channel_in,
+                           int height_in,
+                           int width_in) {
+  // reduce h first
+  DDimLite ddimA({num_in, channel_in, 1, width_in});
+  lite::Tensor tensor_tmp;
+  tensor_tmp.Resize(ddimA);
+  float* tmp_out = tensor_tmp.mutable_data<float>();
+  reduce_mean_h(src, tmp_out, num_in, channel_in, height_in, width_in);
+  reduce_mean_w(tmp_out, dst, num_in, channel_in, 1, width_in);
+}
+
+}  // namespace math
+}  // namespace arm
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/arm/math/reduce_mean.h b/lite/arm/math/reduce_mean.h
new file mode 100644
index 0000000000000000000000000000000000000000..277ed209c058b5b4be76ce18a00683610e6afb7a
--- /dev/null
+++ b/lite/arm/math/reduce_mean.h
@@ -0,0 +1,89 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
*/ + +#pragma once + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +template +void reduce_mean_n(const T* src, + T* dst, + int num_in, + int channel_in, + int height_in, + int width_in); + +template +void reduce_mean_c(const T* src, + T* dst, + int num_in, + int channel_in, + int height_in, + int width_in); + +template +void reduce_mean_h(const T* src, + T* dst, + int num_in, + int channel_in, + int height_in, + int width_in); + +template +void reduce_mean_w(const T* src, + T* dst, + int num_in, + int channel_in, + int height_in, + int width_in); + +template +void reduce_mean_nc(const T* src, + T* dst, + int num_in, + int channel_in, + int height_in, + int width_in); + +template +void reduce_mean_ch(const T* src, + T* dst, + int num_in, + int channel_in, + int height_in, + int width_in); + +template +void reduce_mean_hw(const T* src, + T* dst, + int num_in, + int channel_in, + int height_in, + int width_in); + +template +void reduce_mean_all(const T* src, + T* dst, + int num_in, + int channel_in, + int height_in, + int width_in); + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/arm/math/stack.cc b/lite/arm/math/stack.cc new file mode 100644 index 0000000000000000000000000000000000000000..1775e6048e39ce6ab33632a2e4a8991112b12b7f --- /dev/null +++ b/lite/arm/math/stack.cc @@ -0,0 +1,55 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/arm/math/stack.h" +#include +#include +#include +#include "lite/arm/math/funcs.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +void stack(std::vector x, lite::Tensor *y, int axis) { + if (axis < 0) axis += (x[0]->dims().size() + 1); + int n = x.size(); + auto *y_data = y->mutable_data(); + std::vector x_datas(n); + for (int i = 0; i < n; i++) x_datas[i] = x[i]->data(); + + int pre = 1, post = 1; + auto &dim = x[0]->dims(); + for (auto i = 0; i < axis; ++i) pre *= dim[i]; + for (auto i = axis; i < dim.size(); ++i) post *= dim[i]; + + auto x_data_arr = x_datas.data(); + + size_t x_offset = 0; + size_t y_offset = 0; + for (int i = 0; i < pre; i++) { + for (int j = 0; j < n; j++) { + std::memcpy( + y_data + y_offset, x_data_arr[j] + x_offset, post * sizeof(float)); + y_offset += post; + } + x_offset += post; + } +} + +} /* namespace math */ +} /* namespace arm */ +} /* namespace lite */ +} /* namespace paddle */ diff --git a/lite/arm/math/stack.h b/lite/arm/math/stack.h new file mode 100644 index 0000000000000000000000000000000000000000..2000b3da60870778c280e313b5372f0963027820 --- /dev/null +++ b/lite/arm/math/stack.h @@ -0,0 +1,30 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#pragma once
+
+#include <vector>
+#include "lite/core/tensor.h"
+
+namespace paddle {
+namespace lite {
+namespace arm {
+namespace math {
+
+void stack(std::vector<lite::Tensor*> x, lite::Tensor* out, int axis);
+
+} /* namespace math */
+} /* namespace arm */
+} /* namespace lite */
+} /* namespace paddle */
diff --git a/lite/kernels/arm/CMakeLists.txt b/lite/kernels/arm/CMakeLists.txt
index bfd59f134d08fb7db5e5623daf7f124ce158fb6c..19d3f3fd762267c1ef21e639ce62824531f85cef 100644
--- a/lite/kernels/arm/CMakeLists.txt
+++ b/lite/kernels/arm/CMakeLists.txt
@@ -45,6 +45,8 @@ add_kernel(reduce_max_compute_arm ARM basic SRCS reduce_max_compute.cc DEPS ${li
 add_kernel(sequence_expand_compute_arm ARM basic SRCS sequence_expand_compute.cc DEPS ${lite_kernel_deps} math_arm)
 add_kernel(im2sequence_compute_arm ARM basic SRCS im2sequence_compute.cc DEPS ${lite_kernel_deps} math_arm)
 add_kernel(sequence_pool_compute_arm ARM basic SRCS sequence_pool_compute.cc DEPS ${lite_kernel_deps} math_arm)
+add_kernel(reduce_mean_compute_arm ARM basic SRCS reduce_mean_compute.cc DEPS ${lite_kernel_deps} math_arm)
+add_kernel(stack_compute_arm ARM basic SRCS stack_compute.cc DEPS ${lite_kernel_deps} math_arm)
 add_kernel(assign_compute_arm ARM basic SRCS assign_compute.cc DEPS ${lite_kernel_deps} math_arm)
 add_kernel(affine_channel_compute_arm ARM basic SRCS affine_channel_compute.cc DEPS ${lite_kernel_deps} math_arm)
 add_kernel(anchor_generator_compute_arm ARM basic SRCS anchor_generator_compute.cc DEPS ${lite_kernel_deps} math_arm)
diff --git a/lite/kernels/arm/reduce_mean_compute.cc b/lite/kernels/arm/reduce_mean_compute.cc
new file mode 100644
index 0000000000000000000000000000000000000000..989550b8da5aca30141a74c5a3544243f1fcd6d4
--- /dev/null
+++ b/lite/kernels/arm/reduce_mean_compute.cc
@@ -0,0 +1,91 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
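Note on the copy pattern used by `lite::arm::math::stack` above: each input is walked as `pre` outer blocks of `post` contiguous floats, and the blocks coming from the `n` inputs are interleaved into the output, one `std::memcpy` per block. A minimal standalone sketch of the same idea, outside the Lite runtime (the helper name `naive_stack` and the `main` driver are illustrative only, not part of the patch):

```cpp
#include <cstdio>
#include <cstring>
#include <vector>

// Stack `inputs` (all of shape `dims`) along a new axis `axis`.
std::vector<float> naive_stack(const std::vector<std::vector<float>>& inputs,
                               const std::vector<int>& dims,
                               int axis) {
  int rank = static_cast<int>(dims.size());
  if (axis < 0) axis += rank + 1;  // same normalization as the kernel
  int pre = 1, post = 1;
  for (int i = 0; i < axis; ++i) pre *= dims[i];
  for (int i = axis; i < rank; ++i) post *= dims[i];
  int n = static_cast<int>(inputs.size());
  std::vector<float> out(static_cast<size_t>(pre) * n * post);
  size_t x_offset = 0, y_offset = 0;
  for (int i = 0; i < pre; ++i) {
    for (int j = 0; j < n; ++j) {
      // block i of input j lands right after block i of input j-1
      std::memcpy(out.data() + y_offset, inputs[j].data() + x_offset,
                  post * sizeof(float));
      y_offset += post;
    }
    x_offset += post;
  }
  return out;
}

int main() {
  // Two 2x3 inputs stacked along axis 1 -> shape 2x2x3.
  std::vector<float> a{0, 1, 2, 3, 4, 5};
  std::vector<float> b{10, 11, 12, 13, 14, 15};
  auto out = naive_stack({a, b}, {2, 3}, 1);
  for (float v : out) std::printf("%g ", v);  // 0 1 2 10 11 12 3 4 5 13 14 15
  std::printf("\n");
  return 0;
}
```

Because every copied block is contiguous in both source and destination, the kernel can rely on bulk copies instead of an element-wise loop.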
+ +#include "lite/kernels/arm/reduce_mean_compute.h" +#include +#include "lite/arm/math/funcs.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +void ReduceMeanCompute::Run() { + auto& param = Param(); + const float* input = param.X->data(); + auto x_dims = param.X->dims(); + int x_rank = x_dims.size(); + float* output = param.Out->mutable_data(); + bool keep_dim = param.keep_dim; + auto dim = param.dim; + + if (!dim.empty()) { + for (int i = 0; i < dim.size(); i++) { + if (dim[i] < 0) { + dim[i] += x_rank; + } + } + } + int n_in = x_dims[0]; + int c_in = x_dims[1]; + int h_in = x_dims[2]; + int w_in = x_dims[3]; + if (dim.size() == 0) { + lite::arm::math::reduce_mean_all(input, output, n_in, c_in, h_in, w_in); + } else if (dim.size() == 1) { + switch (dim[0]) { + case 0: + lite::arm::math::reduce_mean_n(input, output, n_in, c_in, h_in, w_in); + break; + case 1: + lite::arm::math::reduce_mean_c(input, output, n_in, c_in, h_in, w_in); + break; + case 2: + lite::arm::math::reduce_mean_h(input, output, n_in, c_in, h_in, w_in); + break; + case 3: + lite::arm::math::reduce_mean_w(input, output, n_in, c_in, h_in, w_in); + break; + default: + LOG(FATAL) << "error!!!"; + } + } else if (dim.size() == 2) { + if (dim[0] == 0 && dim[1] == 1) { + lite::arm::math::reduce_mean_nc(input, output, n_in, c_in, h_in, w_in); + } else if (dim[0] == 1 && dim[1] == 2) { + lite::arm::math::reduce_mean_ch(input, output, n_in, c_in, h_in, w_in); + } else if (dim[0] == 2 && dim[1] == 3) { + lite::arm::math::reduce_mean_hw(input, output, n_in, c_in, h_in, w_in); + } else { + LOG(FATAL) << "invalid dim!!"; + } + } else { + LOG(FATAL) << "dim's size over than 2, which is not supported now!!"; + } +} + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(reduce_mean, + kARM, + kFloat, + kNCHW, + paddle::lite::kernels::arm::ReduceMeanCompute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) + .Finalize(); diff --git a/lite/kernels/arm/reduce_mean_compute.h b/lite/kernels/arm/reduce_mean_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..5f4bbd46d54ee9a1f77a52f22e60be0409648ffb --- /dev/null +++ b/lite/kernels/arm/reduce_mean_compute.h @@ -0,0 +1,38 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#pragma once
+#include "lite/arm/math/type_trans.h"
+#include "lite/core/kernel.h"
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace arm {
+
+class ReduceMeanCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
+ public:
+  void Run() override;
+
+  virtual ~ReduceMeanCompute() = default;
+
+ private:
+};
+
+}  // namespace arm
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/kernels/arm/stack_compute.cc b/lite/kernels/arm/stack_compute.cc
new file mode 100644
index 0000000000000000000000000000000000000000..fb9fa85211cb57644bac22810e647fae997ada01
--- /dev/null
+++ b/lite/kernels/arm/stack_compute.cc
@@ -0,0 +1,42 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/kernels/arm/stack_compute.h"
+#include <vector>
+#include "lite/arm/math/funcs.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace arm {
+
+void StackCompute::Run() {
+  auto& param = Param<operators::StackParam>();
+  std::vector<lite::Tensor*> x = param.X;
+  lite::Tensor* out = param.Out;
+  int axis = param.axis;
+
+  lite::arm::math::stack(x, out, axis);
+}
+
+} /* namespace arm */
+} /* namespace kernels */
+} /* namespace lite */
+} /* namespace paddle */
+
+REGISTER_LITE_KERNEL(
+    stack, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::StackCompute, def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
+    .Finalize();
diff --git a/lite/kernels/arm/stack_compute.h b/lite/kernels/arm/stack_compute.h
new file mode 100644
index 0000000000000000000000000000000000000000..2dbb1d7a212297bc2f1bf1d9a73c7a5fd6a333b8
--- /dev/null
+++ b/lite/kernels/arm/stack_compute.h
@@ -0,0 +1,34 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
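For the stack kernel above, the output tensor is expected to have been resized by `StackOp::InferShape` (added further below): an axis of length `x.size()` is inserted into the common input shape, with a negative `axis` counted from the end. A small standalone illustration (the helper `stacked_shape` exists only for this sketch; the `{1, 5, 6, 7}` shape mirrors the one used in the stack test at the end of the patch):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Shape obtained by stacking `n` tensors of shape `dims` along a new `axis`.
std::vector<int64_t> stacked_shape(std::vector<int64_t> dims, int n, int axis) {
  int rank = static_cast<int>(dims.size());
  if (axis < 0) axis += rank + 1;      // negative axis counts from the end
  dims.insert(dims.begin() + axis, n); // the new axis has length n
  return dims;
}

int main() {
  for (int axis : {0, 1, 3, -1}) {
    auto s = stacked_shape({1, 5, 6, 7}, 2, axis);
    for (auto d : s) std::printf("%lld ", static_cast<long long>(d));
    std::printf("(axis=%d)\n", axis);
  }
  // axis=0 -> 2 1 5 6 7, axis=1 -> 1 2 5 6 7,
  // axis=3 -> 1 5 6 2 7, axis=-1 -> 1 5 6 7 2
  return 0;
}
```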
+
+#pragma once
+#include "lite/core/kernel.h"
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace arm {
+
+class StackCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
+ public:
+  void Run() override;
+
+  virtual ~StackCompute() = default;
+};
+
+} /* namespace arm */
+} /* namespace kernels */
+} /* namespace lite */
+} /* namespace paddle */
diff --git a/lite/operators/CMakeLists.txt b/lite/operators/CMakeLists.txt
index 14a1dfcfaf597e07888abf5d7295ac97e46079e8..c5cfedb84755d8c81f2ada8d20db814d38852386 100644
--- a/lite/operators/CMakeLists.txt
+++ b/lite/operators/CMakeLists.txt
@@ -59,6 +59,8 @@ add_operator(shape_op_lite basic SRCS shape_op.cc DEPS ${op_DEPS})
 add_operator(sequence_expand_op_lite basic SRCS sequence_expand_op.cc DEPS ${op_DEPS})
 add_operator(squeeze_op_lite basic SRCS squeeze_op.cc DEPS ${op_DEPS})
 add_operator(im2sequence_op basic SRCS im2sequence_op.cc DEPS ${op_DEPS})
+add_operator(reduce_mean_op basic SRCS reduce_mean_op.cc DEPS ${op_DEPS})
+add_operator(stack_op basic SRCS stack_op.cc DEPS ${op_DEPS})
 add_operator(cast_op_lite basic SRCS cast_op.cc DEPS ${op_DEPS})
 add_operator(assign_op basic SRCS assign_op.cc DEPS ${op_DEPS})
 
diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h
index eab5fb2bc12cecb2599535c010aba6738a7dfbec..6b2c7349d4ac907de89c45b25f0b644454203c22 100644
--- a/lite/operators/op_params.h
+++ b/lite/operators/op_params.h
@@ -121,6 +121,23 @@ struct MulGradParam {
   int y_num_col_dims{1};
 };
 
+// For ReduceMean Op
+struct ReduceMeanParam {
+  lite::Tensor* X{};
+  lite::Tensor* Out{};
+
+  std::vector<int> dim;
+  bool keep_dim{false};
+};
+
+// For Stack Op
+struct StackParam {
+  std::vector<lite::Tensor*> X;
+  lite::Tensor* Out{};
+
+  int axis{0};
+};
+
 // For Power Op
 struct PowerParam {
   const lite::Tensor* X{};
diff --git a/lite/operators/reduce_mean_op.cc b/lite/operators/reduce_mean_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..bce31c315c22e93d7758a05ecf2ace0668dd0cc1
--- /dev/null
+++ b/lite/operators/reduce_mean_op.cc
@@ -0,0 +1,112 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "lite/operators/reduce_mean_op.h" +#include +#include +#include +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators { + +bool ReduceMeanOp::CheckShape() const { + CHECK_OR_FALSE(param_.X); + CHECK_OR_FALSE(param_.Out); + auto dims = param_.dim; + auto x_dims = param_.X->dims(); + int x_rank = x_dims.size(); + if (dims.size() != 0) { + for (int i = 0; i < dims.size(); i++) { + if (dims[i] < 0) { + dims[i] = x_rank + dims[i]; + } + CHECK_OR_FALSE(dims[i] <= x_rank && dims[i] >= -x_rank); + } + } + return true; +} + +bool ReduceMeanOp::InferShape() const { + auto dims = param_.dim; + auto x_dims = param_.X->dims(); + bool reduce_all = false; + bool keep_dim = param_.keep_dim; + auto x_rank = x_dims.size(); + if (dims.size() != 0) { + for (int i = 0; i < dims.size(); i++) { + if (dims[i] < 0) { + dims[i] = x_rank + dims[i]; + } + } + } + sort(dims.begin(), dims.end()); + if (dims.size() == 0) { + reduce_all = true; + } + std::vector out_dims; + if (reduce_all) { + if (keep_dim) { + out_dims.push_back(x_rank); + out_dims.push_back(1); + } else { + out_dims.push_back(1); + } + } else { + for (int i = 0; i < x_dims.size(); i++) { + out_dims.push_back(x_dims[i]); + } + if (keep_dim) { + for (size_t i = 0; i < dims.size(); ++i) { + out_dims[dims[i]] = 1; + } + } else { + const int64_t kDelFlag = -2; + for (size_t i = 0; i < dims.size(); ++i) { + out_dims[dims[i]] = kDelFlag; + } + out_dims.erase(remove(out_dims.begin(), out_dims.end(), kDelFlag), + out_dims.end()); + } + param_.Out->Resize(DDim(out_dims)); + if (dims[0] != 0) { + // Only pass LoD when not reducing on the first dim. + *param_.Out->mutable_lod() = param_.X->lod(); + } + } + return true; +} + +bool ReduceMeanOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { + param_.X = const_cast( + &scope->FindVar(opdesc.Input("X").front())->Get()); + param_.Out = + scope->FindVar(opdesc.Output("Out").front())->GetMutable(); + param_.dim = opdesc.GetAttr>("dim"); + if (opdesc.HasAttr("keep_dim")) { + param_.keep_dim = opdesc.GetAttr("keep_dim"); + } else { + param_.keep_dim = false; + } + CHECK(param_.X); + CHECK(param_.Out); + return true; +} + +} // namespace operators +} // namespace lite +} // namespace paddle + +REGISTER_LITE_OP(reduce_mean, paddle::lite::operators::ReduceMeanOp); diff --git a/lite/operators/reduce_mean_op.h b/lite/operators/reduce_mean_op.h new file mode 100644 index 0000000000000000000000000000000000000000..e701a1132aa1260b5f169f89dec546a0d80fc916 --- /dev/null +++ b/lite/operators/reduce_mean_op.h @@ -0,0 +1,43 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once +#include +#include +#include "lite/core/op_lite.h" +#include "lite/core/scope.h" +#include "lite/utils/all.h" + +namespace paddle { +namespace lite { +namespace operators { + +class ReduceMeanOp : public OpLite { + public: + ReduceMeanOp() {} + explicit ReduceMeanOp(const std::string &op_type) : OpLite(op_type) {} + bool CheckShape() const override; + bool InferShape() const override; + bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + std::string DebugString() const override { return "reduce_mean"; } + + private: + mutable ReduceMeanParam param_; +}; + +} // namespace operators +} // namespace lite +} // namespace paddle diff --git a/lite/operators/stack_op.cc b/lite/operators/stack_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..bc8537117ec7505d2f038bfecd289ffbe83843b4 --- /dev/null +++ b/lite/operators/stack_op.cc @@ -0,0 +1,62 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/operators/stack_op.h" +#include +#include +#include "lite/core/op_lite.h" +#include "lite/core/op_registry.h" +#include "lite/core/tensor.h" + +namespace paddle { +namespace lite { +namespace operators { + +bool StackOp::CheckShape() const { + auto input = param_.X; + for (auto x : input) { + CHECK_OR_FALSE(x); + } + CHECK_OR_FALSE(param_.Out); + return true; +} + +bool StackOp::InferShape() const { + auto input = param_.X; + auto input_dims = input[0]->dims(); + int axis = param_.axis; + int rank = input_dims.size(); + if (axis < 0) axis += (rank + 1); + auto vec = input_dims.Vectorize(); + vec.insert(vec.begin() + axis, input.size()); + param_.Out->Resize(vec); + return true; +} + +bool StackOp::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) { + auto X = op_desc.Input("X"); + auto Out = op_desc.Output("Out").front(); + for (auto var : X) { + param_.X.emplace_back(scope->FindVar(var)->GetMutable()); + } + param_.Out = scope->FindVar(Out)->GetMutable(); + param_.axis = op_desc.GetAttr("axis"); + return true; +} + +} /* namespace operators */ +} /* namespace lite */ +} /* namespace paddle */ + +REGISTER_LITE_OP(stack, paddle::lite::operators::StackOp); diff --git a/lite/operators/stack_op.h b/lite/operators/stack_op.h new file mode 100644 index 0000000000000000000000000000000000000000..068d905338bde892b44630c64d3ec43771614f2a --- /dev/null +++ b/lite/operators/stack_op.h @@ -0,0 +1,47 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include "lite/core/op_lite.h" +#include "lite/core/scope.h" +#include "lite/utils/all.h" + +namespace paddle { +namespace lite { +namespace operators { + +class StackOp : public OpLite { + public: + StackOp() {} + + explicit StackOp(const std::string &op_type) : OpLite(op_type) {} + + bool CheckShape() const override; + + bool InferShape() const override; + + bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + std::string DebugString() const override { return "stack"; } + + private: + mutable StackParam param_; +}; + +} /* namespace operators */ +} /* namespace lite */ +} /* namespace paddle */ diff --git a/lite/tests/kernels/CMakeLists.txt b/lite/tests/kernels/CMakeLists.txt index b4b3a7d01b28e2126ae9f8ad3f0726f5afaa357a..a78424dede7a54f22de0a8c007c02c89f6982733 100644 --- a/lite/tests/kernels/CMakeLists.txt +++ b/lite/tests/kernels/CMakeLists.txt @@ -43,6 +43,8 @@ endif() lite_cc_test(test_kernel_squeeze_compute SRCS squeeze_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_expand_compute SRCS expand_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_matmul_compute SRCS matmul_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) + lite_cc_test(test_kernel_reduce_mean_compute SRCS reduce_mean_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) + lite_cc_test(test_kernel_stack_compute SRCS stack_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_affine_channel_compute SRCS affine_channel_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_anchor_generator_compute SRCS anchor_generator_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) #lite_cc_test(test_kernel_generate_proposals_compute SRCS generate_proposals_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) diff --git a/lite/tests/kernels/reduce_mean_compute_test.cc b/lite/tests/kernels/reduce_mean_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..e8733ef3eaed384d13a10d044033f527fc108874 --- /dev/null +++ b/lite/tests/kernels/reduce_mean_compute_test.cc @@ -0,0 +1,346 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "lite/api/paddle_use_kernels.h" +#include "lite/api/paddle_use_ops.h" +#include "lite/core/arena/framework.h" + +namespace paddle { +namespace lite { + +void reduce_mean_n(const float* src, + float* dst, + int num_in, + int channel_in, + int height_in, + int width_in) { + int hw_size = height_in * width_in; + int chw_size = channel_in * hw_size; + int data_index, src_index, src_index0; + for (int c = 0; c < channel_in; ++c) { + for (int h = 0; h < height_in; ++h) { + for (int w = 0; w < width_in; ++w) { + data_index = c * hw_size + h * width_in + w; + dst[data_index] = 0.0; + for (int n = 1; n < num_in; ++n) { + src_index = n * chw_size + data_index; + dst[data_index] += static_cast(src[src_index]) / num_in; + } + } + } + } +} + +void reduce_mean_c(const float* src, + float* dst, + int num_in, + int channel_in, + int height_in, + int width_in) { + int hw_size = height_in * width_in; + int chw_size = hw_size * channel_in; + int data_index, src_index0, src_index; + for (int n = 0; n < num_in; ++n) { + for (int h = 0; h < height_in; ++h) { + for (int w = 0; w < width_in; ++w) { + data_index = n * hw_size + h * width_in + w; + src_index0 = n * chw_size + h * width_in + w; + dst[data_index] = 0.0; + for (int c = 1; c < channel_in; ++c) { + src_index = src_index0 + c * hw_size; + dst[data_index] += static_cast(src[src_index]) / channel_in; + } + } + } + } +} + +void reduce_mean_h(const float* src, + float* dst, + int num_in, + int channel_in, + int height_in, + int width_in) { + int cw_size = channel_in * width_in; + int chw_size = cw_size * height_in; + int hw_size = height_in * width_in; + int data_index, src_index, src_index0; + for (int n = 0; n < num_in; ++n) { + for (int c = 0; c < channel_in; ++c) { + for (int w = 0; w < width_in; ++w) { + data_index = n * cw_size + c * width_in + w; + src_index0 = n * chw_size + c * hw_size + w; + dst[data_index] = 0.0; + for (int h = 1; h < height_in; ++h) { + src_index = src_index0 + h * width_in; + dst[data_index] += static_cast(src[src_index]) / height_in; + } + } + } + } +} + +void reduce_mean_w(const float* src, + float* dst, + int num_in, + int channel_in, + int height_in, + int width_in) { + int ch_size = channel_in * height_in; + int hw_size = height_in * width_in; + int chw_size = ch_size * width_in; + int data_index = 0; + int src_index0 = 0; + int src_index = 0; + for (int n = 0; n < num_in; ++n) { + for (int c = 0; c < channel_in; ++c) { + for (int h = 0; h < height_in; ++h) { + data_index = n * ch_size + c * height_in + h; + src_index0 = n * chw_size + c * hw_size + h * width_in; + dst[data_index] = 0.0; + for (int w = 1; w < width_in; ++w) { + src_index = src_index0 + w; + dst[data_index] += static_cast(src[src_index]) / width_in; + } + } + } + } +} + +void reduce_mean_all(const float* src, + float* dst, + int num_in, + int channel_in, + int height_in, + int width_in) { + float mean = 0.0; + int src_index; + int n_id, c_id; + int all = num_in * channel_in * height_in * width_in; + for (int n = 0; n < num_in; ++n) { + n_id = n * channel_in * height_in * width_in; + for (int c = 0; c < channel_in; ++c) { + c_id = c * height_in * width_in; + for (int h = 0; h < height_in; ++h) { + for (int w = 0; w < width_in; ++w) { + src_index = n_id + c_id + h * width_in + w; + mean = src[src_index] / all; + } + } + } + } + dst[0] = mean; +} + +void reduce_mean_nc(const float* src, + float* dst, + int num_in, + int channel_in, + 
int height_in, + int width_in) { + // reduce n first. + DDimLite ddimA({1, channel_in, height_in, width_in}); + lite::Tensor tensor_tmp; + tensor_tmp.Resize(ddimA); + float* tmp_out = tensor_tmp.mutable_data(); + reduce_mean_n(src, tmp_out, num_in, channel_in, height_in, width_in); + reduce_mean_c(tmp_out, dst, 1, channel_in, height_in, width_in); +} + +void reduce_mean_ch(const float* src, + float* dst, + int num_in, + int channel_in, + int height_in, + int width_in) { + // reduce c first + DDimLite ddimA({num_in, 1, height_in, width_in}); + lite::Tensor tensor_tmp; + tensor_tmp.Resize(ddimA); + float* tmp_out = tensor_tmp.mutable_data(); + reduce_mean_c(src, tmp_out, num_in, channel_in, height_in, width_in); + reduce_mean_h(tmp_out, dst, num_in, 1, height_in, width_in); +} + +void reduce_mean_hw(const float* src, + float* dst, + int num_in, + int channel_in, + int height_in, + int width_in) { + // reduce h first + DDimLite ddimA({num_in, channel_in, 1, width_in}); + lite::Tensor tensor_tmp; + tensor_tmp.Resize(ddimA); + float* tmp_out = tensor_tmp.mutable_data(); + reduce_mean_h(src, tmp_out, num_in, channel_in, height_in, width_in); + reduce_mean_w(tmp_out, dst, num_in, channel_in, 1, width_in); +} + +class ReduceMeanComputeTester : public arena::TestCase { + protected: + // common attributes for this op. + std::string input_ = "x"; + std::string output_ = "out"; + std::vector dim_{0}; + DDim x_dims_{{3, 2, 3, 4}}; + bool keep_dim_ = false; + bool reduce_all_ = false; + + public: + ReduceMeanComputeTester(const Place& place, + const std::string& alias, + std::vector dim, + bool keep_dim, + DDim x_dims) + : TestCase(place, alias), + dim_(dim), + keep_dim_(keep_dim), + x_dims_(x_dims) {} + + void RunBaseline(Scope* scope) override { + auto* x = scope->FindMutableTensor(input_); + const auto* x_data = x->data(); + auto* out = scope->NewTensor(output_); + auto x_rank = x_dims_.size(); + if (!dim_.empty()) { + for (int i = 0; i < dim_.size(); i++) { + if (dim_[i] < 0) { + dim_[i] += x_rank; + } + } + } + + sort(dim_.begin(), dim_.end()); + if (dim_.size() == 0) { + reduce_all_ = true; + } + std::vector out_dims; + if (reduce_all_) { + if (keep_dim_) { + out_dims.push_back(x_rank); + out_dims.push_back(1); + } else { + out_dims.push_back(1); + } + } else { + for (int i = 0; i < x_dims_.size(); i++) { + out_dims.push_back(x_dims_[i]); + } + if (keep_dim_) { + for (size_t i = 0; i < dim_.size(); ++i) { + out_dims[dim_[i]] = 1L; + } + } else { + int64_t kDelFlag = -2; + for (size_t i = 0; i < dim_.size(); ++i) { + out_dims[dim_[i]] = kDelFlag; + } + out_dims.erase(remove(out_dims.begin(), out_dims.end(), kDelFlag), + out_dims.end()); + } + out->Resize(DDim(out_dims)); + } + + auto* out_data = out->mutable_data(); + int in_n = x_dims_[0]; + int in_c = x_dims_[1]; + int in_h = x_dims_[2]; + int in_w = x_dims_[3]; + + if (dim_.size() == 0) { + reduce_mean_all(x_data, out_data, in_n, in_c, in_h, in_w); + } else if (dim_.size() == 1) { + switch (dim_[0]) { + case 0: + reduce_mean_n(x_data, out_data, in_n, in_c, in_h, in_w); + break; + case 1: + reduce_mean_c(x_data, out_data, in_n, in_c, in_h, in_w); + break; + case 2: + reduce_mean_h(x_data, out_data, in_n, in_c, in_h, in_w); + break; + case 3: + reduce_mean_w(x_data, out_data, in_n, in_c, in_h, in_w); + break; + default: + LOG(FATAL) << "error!!!"; + } + } else if (dim_.size() == 2) { + if (dim_[0] == 0 && dim_[1] == 1) { + reduce_mean_nc(x_data, out_data, in_n, in_c, in_h, in_w); + } else if (dim_[0] == 1 && dim_[1] == 2) { + 
reduce_mean_ch(x_data, out_data, in_n, in_c, in_h, in_w); + } else if (dim_[0] == 2 && dim_[1] == 3) { + reduce_mean_hw(x_data, out_data, in_n, in_c, in_h, in_w); + } else { + LOG(FATAL) << "invalid dims_!!"; + } + } + } + + void PrepareOpDesc(cpp::OpDesc* op_desc) { + op_desc->SetType("reduce_mean"); + op_desc->SetInput("X", {input_}); + op_desc->SetOutput("Out", {output_}); + op_desc->SetAttr("dim", dim_); + op_desc->SetAttr("keep_dim", keep_dim_); + } + + void PrepareData() override { + std::vector data(x_dims_.production()); + for (int i = 0; i < x_dims_.production(); i++) { + data[i] = i * 1.0; + } + SetCommonTensor(input_, x_dims_, data.data()); + } +}; + +void test_reduce_mean(Place place) { + std::vector> reduce_dim{ + {0}, {1}, {2}, {3}, {0, 1}, {1, 2}, {2, 3}, {-2, -1}}; + for (auto n : {1, 3}) { + for (auto c : {1, 2}) { + for (auto h : {1, 3}) { + for (auto w : {1, 3}) { + for (bool keep_dim : {false, true}) { + for (auto dim : reduce_dim) { + auto x_dims = DDim(std::vector({n, c, h, w})); + std::unique_ptr tester( + new ReduceMeanComputeTester( + place, "def", dim, keep_dim, x_dims)); + arena::Arena arena(std::move(tester), place, 2e-5); + arena.TestPrecision(); + } + } + } + } + } + } +} + +TEST(ReduceMean, precision) { +// #ifdef LITE_WITH_X86 +// Place place(TARGET(kX86)); +// #endif +#ifdef LITE_WITH_ARM + Place place(TARGET(kARM)); + test_reduce_mean(place); +#endif +} + +} // namespace lite +} // namespace paddle diff --git a/lite/tests/kernels/stack_compute_test.cc b/lite/tests/kernels/stack_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..7fcefad65e3244ffc6e3296ac28fb917dbab1920 --- /dev/null +++ b/lite/tests/kernels/stack_compute_test.cc @@ -0,0 +1,116 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "lite/api/paddle_use_kernels.h" +#include "lite/api/paddle_use_ops.h" +#include "lite/core/arena/framework.h" + +namespace paddle { +namespace lite { + +void stack(std::vector x, lite::Tensor* y, int axis) { + if (axis < 0) axis += (x[0]->dims().size() + 1); + int n = x.size(); + auto* y_data = y->mutable_data(); + std::vector x_datas(n); + for (int i = 0; i < n; i++) x_datas[i] = x[i]->data(); + + int pre = 1, post = 1; + auto& dim = x[0]->dims(); + for (auto i = 0; i < axis; ++i) pre *= dim[i]; + for (auto i = axis; i < dim.size(); ++i) post *= dim[i]; + + auto x_data_arr = x_datas.data(); + + size_t x_offset = 0; + size_t y_offset = 0; + for (int i = 0; i < pre; i++) { + for (int j = 0; j < n; j++) { + std::memcpy( + y_data + y_offset, x_data_arr[j] + x_offset, post * sizeof(float)); + y_offset += post; + } + x_offset += post; + } +} + +class StackComputeTester : public arena::TestCase { + protected: + // common attributes for this op. 
+ std::string input1_ = "X1"; + std::string input2_ = "X2"; + std::string output_ = "Out"; + int axis_ = 0; + DDim dims_{{1, 5, 6, 7}}; + + public: + StackComputeTester(const Place& place, const std::string& alias, float axis) + : TestCase(place, alias), axis_(axis) {} + + void RunBaseline(Scope* scope) override { + auto* out = scope->NewTensor(output_); + CHECK(out); + std::vector x; + x.emplace_back(scope->FindTensor(input1_)); + x.emplace_back(scope->FindTensor(input2_)); + auto input_dims = x[0]->dims(); + int rank = input_dims.size(); + if (axis_ < 0) axis_ += (rank + 1); + auto vec = input_dims.Vectorize(); + vec.insert(vec.begin() + axis_, x.size()); + out->Resize(vec); + stack(x, out, axis_); + } + + void PrepareOpDesc(cpp::OpDesc* op_desc) { + op_desc->SetType("stack"); + op_desc->SetInput("X", {input1_, input2_}); + op_desc->SetOutput("Out", {output_}); + op_desc->SetAttr("axis", axis_); + } + + void PrepareData() override { + std::vector data(dims_.production()); + + for (int i = 0; i < dims_.production(); i++) { + data[i] = i * 1.01; + } + + SetCommonTensor(input1_, dims_, data.data()); + SetCommonTensor(input2_, dims_, data.data()); + } +}; + +void test_stack(Place place) { + for (float axis : {0, 1, 3}) { + std::unique_ptr tester( + new StackComputeTester(place, "def", axis)); + arena::Arena arena(std::move(tester), place, 2e-4); + arena.TestPrecision(); + } +} + +TEST(Stack, precision) { +// #ifdef LITE_WITH_X86 +// Place place(TARGET(kX86)); +// #endif +#ifdef LITE_WITH_ARM + Place place(TARGET(kARM)); + test_stack(place); +#endif +} + +} // namespace lite +} // namespace paddle
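A property that ties the two new ops together and could serve as an extra sanity check beyond the arena tests above: stacking equally shaped tensors along a new axis and then taking `reduce_mean` over that axis is an element-wise average. A standalone sketch of that check, independent of the Lite runtime (all names below are illustrative, not part of the patch):

```cpp
#include <cassert>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  // Two flat "tensors" of identical shape.
  std::vector<float> a{1, 2, 3, 4};
  std::vector<float> b{5, 6, 7, 8};

  // stack along a new leading axis (axis = 0): the output is simply a then b.
  std::vector<float> stacked;
  stacked.insert(stacked.end(), a.begin(), a.end());
  stacked.insert(stacked.end(), b.begin(), b.end());

  // reduce_mean over that axis: average the two slices element-wise.
  std::vector<float> mean(a.size());
  for (size_t i = 0; i < a.size(); ++i)
    mean[i] = (stacked[i] + stacked[i + a.size()]) / 2.0f;

  for (size_t i = 0; i < a.size(); ++i) {
    assert(std::fabs(mean[i] - (a[i] + b[i]) / 2.0f) < 1e-6f);
    std::printf("%g ", mean[i]);  // 3 4 5 6
  }
  std::printf("\n");
  return 0;
}
```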