From ed5ec09e32461104427cb361dc7de788d8547713 Mon Sep 17 00:00:00 2001 From: yuanshuai Date: Tue, 18 Jun 2019 00:52:15 +0000 Subject: [PATCH] [LITE][ARM] Add transpose, transpose2 operator, kernel of arm cpu. test=develop --- paddle/fluid/lite/kernels/arm/CMakeLists.txt | 3 + .../lite/kernels/arm/transpose_compute.cc | 173 +++++++++++++++ .../lite/kernels/arm/transpose_compute.h | 48 ++++ .../kernels/arm/transpose_compute_test.cc | 205 ++++++++++++++++++ paddle/fluid/lite/operators/CMakeLists.txt | 3 + paddle/fluid/lite/operators/op_params.h | 9 + paddle/fluid/lite/operators/transpose_op.cc | 165 ++++++++++++++ paddle/fluid/lite/operators/transpose_op.h | 66 ++++++ .../fluid/lite/operators/transpose_op_test.cc | 93 ++++++++ paddle/fluid/lite/tools/build.sh | 2 +- 10 files changed, 766 insertions(+), 1 deletion(-) create mode 100644 paddle/fluid/lite/kernels/arm/transpose_compute.cc create mode 100644 paddle/fluid/lite/kernels/arm/transpose_compute.h create mode 100644 paddle/fluid/lite/kernels/arm/transpose_compute_test.cc create mode 100644 paddle/fluid/lite/operators/transpose_op.cc create mode 100644 paddle/fluid/lite/operators/transpose_op.h create mode 100644 paddle/fluid/lite/operators/transpose_op_test.cc diff --git a/paddle/fluid/lite/kernels/arm/CMakeLists.txt b/paddle/fluid/lite/kernels/arm/CMakeLists.txt index d072f2377..040b80c11 100644 --- a/paddle/fluid/lite/kernels/arm/CMakeLists.txt +++ b/paddle/fluid/lite/kernels/arm/CMakeLists.txt @@ -16,6 +16,7 @@ cc_library(pool_compute_arm SRCS pool_compute.cc DEPS ${lite_kernel_deps} math_a cc_library(split_compute_arm SRCS split_compute.cc DEPS ${lite_kernel_deps} math_arm) cc_library(concat_compute_arm SRCS concat_compute.cc DEPS ${lite_kernel_deps} math_arm) cc_library(dropout_compute_arm SRCS dropout_compute.cc DEPS ${lite_kernel_deps} math_arm) +cc_library(transpose_compute_arm SRCS transpose_compute.cc DEPS ${lite_kernel_deps} math_arm) lite_cc_test(test_fc_compute_arm SRCS fc_compute_test.cc DEPS fc_compute_arm math_arm) lite_cc_test(test_activation_compute_arm SRCS activation_compute_test.cc DEPS activation_compute_arm) @@ -29,6 +30,7 @@ lite_cc_test(test_mul_compute_arm SRCS mul_compute_test.cc DEPS mul_compute_arm) lite_cc_test(test_split_compute_arm SRCS split_compute_test.cc DEPS split_compute_arm) lite_cc_test(test_concat_compute_arm SRCS concat_compute_test.cc DEPS concat_compute_arm) lite_cc_test(test_dropout_compute_arm SRCS dropout_compute_test.cc DEPS dropout_compute_arm) +lite_cc_test(test_transpose_compute_arm SRCS transpose_compute_test.cc DEPS transpose_compute_arm) set(arm_kernels fc_compute_arm @@ -43,6 +45,7 @@ set(arm_kernels split_compute_arm concat_compute_arm dropout_compute_arm + transpose_compute_arm ) set(arm_kernels "${arm_kernels}" CACHE INTERNAL "arm kernels") diff --git a/paddle/fluid/lite/kernels/arm/transpose_compute.cc b/paddle/fluid/lite/kernels/arm/transpose_compute.cc new file mode 100644 index 000000000..368716c36 --- /dev/null +++ b/paddle/fluid/lite/kernels/arm/transpose_compute.cc @@ -0,0 +1,173 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/kernels/arm/transpose_compute.h" +#include +#include +#include "paddle/fluid/lite/arm/math/funcs.h" +#include "paddle/fluid/lite/core/compatible_tensor.h" +#include "paddle/fluid/lite/core/op_registry.h" +#include "paddle/fluid/lite/core/type_system.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +bool IsShuffleChannel(const std::vector &axis) { + bool is_shuffle_channel = true; + if (axis.size() > 2 && axis[0] == 0 && axis[1] == 2 && axis[2] == 1) { + for (int i = 3; i < axis.size(); ++i) { + if (axis[i] != i) { + is_shuffle_channel = false; + break; + } + } + } else { + return false; + } + return is_shuffle_channel; +} + +template +void ShuffleChannelCompute(const std::vector &axis, + const lite::Tensor *input, lite::Tensor *output) { + const Dtype *input_ptr = input->data(); + Dtype *output_ptr = output->mutable_data(); + // input and output's shape dimension must >= 2 && <= 6. + const DDim &in_dim = input->dims(); + const DDim &out_dim = output->dims(); + size_t offset = 1; + for (int i = 3; i < axis.size(); ++i) { + offset *= in_dim[i]; + } + +#pragma omp parallel for collapse(3) + for (int batch = 0; batch < out_dim[0]; ++batch) { + for (int c1 = 0; c1 < out_dim[1]; ++c1) { + for (int c2 = 0; c2 < out_dim[2]; ++c2) { + size_t out_offset = + ((batch * out_dim[1] + c1) * out_dim[2] + c2) * offset; + size_t in_offset = ((batch * in_dim[1] + c2) * in_dim[2] + c1) * offset; + memcpy(output_ptr + out_offset, input_ptr + in_offset, + offset * sizeof(Dtype)); + } + } + } +} + +template +void TransposeCompute_(const std::vector &axis, const lite::Tensor *input, + lite::Tensor *output) { + // const Dtype *input_ptr = input->data(); + const Dtype *input_ptr = input->data(); + Dtype *output_ptr = output->mutable_data(); + + // input and output's shape dimension must >= 2 && <= 6. + const DDim &in_dim = input->dims(); + const DDim &out_dim = output->dims(); + + // precompute inverted output dim and strides + size_t rout_dim[6], strides[6]; + int permute = axis.size(); // permute must >=2 && <= 6. + for (int i = 0; i < permute; ++i) { + int k = permute - 1 - i; + strides[k] = 1; + for (int j = axis[i] + 1; j < permute; ++j) { + strides[k] *= in_dim[j]; + } + rout_dim[k] = out_dim[i]; + } + + // unroll the first 2 dimensions + int reamin_dim = 1; + for (int i = 2; i < out_dim.size(); ++i) { + reamin_dim *= out_dim[i]; + } + +#pragma omp parallel for collapse(2) + for (int batch = 0; batch < out_dim[0]; ++batch) { + for (int j = 0; j < out_dim[1]; ++j) { + size_t offset = batch * strides[permute - 1] + j * strides[permute - 2]; + Dtype *out_ptr = output_ptr + (batch * out_dim[1] + j) * reamin_dim; + int indics[4] = {0, 0, 0, 0}; + for (int k = 0; k < reamin_dim; ++k) { + out_ptr[k] = input_ptr[offset]; + indics[0] += 1; + offset += strides[0]; + for (int p = 0; p < permute - 3; ++p) { + if (indics[p] == rout_dim[p]) { + indics[p + 1] += 1; + indics[p] = 0; + offset += strides[p + 1]; + offset -= rout_dim[p] * strides[p]; + } else { + break; + } + } + } + } + } +} + +// Transpose +void TransposeCompute::Run() { + auto ¶m = Param(); + auto *input = param.x; + auto *output = param.output; + const std::vector axis = param.axis; + + bool shuffle_channel = IsShuffleChannel(axis); + if (shuffle_channel) { + ShuffleChannelCompute(axis, input, output); + } else { + TransposeCompute_(axis, input, output); + } + return; +} + +// Transpose2 +void Transpose2Compute::Run() { + auto ¶m = Param(); + auto *input = param.x; + auto *output = param.output; + const std::vector axis = param.axis; + + bool shuffle_channel = IsShuffleChannel(axis); + if (shuffle_channel) { + ShuffleChannelCompute(axis, input, output); + } else { + TransposeCompute_(axis, input, output); + } + return; +} + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle + +// Transpose +REGISTER_LITE_KERNEL(transpose, kARM, kFloat, kNCHW, + paddle::lite::kernels::arm::TransposeCompute, def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) + .Finalize(); + +// Transpose2 +REGISTER_LITE_KERNEL(transpose2, kARM, kFloat, kNCHW, + paddle::lite::kernels::arm::Transpose2Compute, def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) + .Finalize(); diff --git a/paddle/fluid/lite/kernels/arm/transpose_compute.h b/paddle/fluid/lite/kernels/arm/transpose_compute.h new file mode 100644 index 000000000..d8ebb761e --- /dev/null +++ b/paddle/fluid/lite/kernels/arm/transpose_compute.h @@ -0,0 +1,48 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include "paddle/fluid/lite/core/kernel.h" +#include "paddle/fluid/lite/operators/transpose_op.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +// Transpose +class TransposeCompute : public KernelLite { + public: + using param_t = operators::TransposeParam; + + void Run() override; + + virtual ~TransposeCompute() = default; +}; + +// Transpose2 +class Transpose2Compute : public KernelLite { + public: + using param_t = operators::TransposeParam; + + void Run() override; + + virtual ~Transpose2Compute() = default; +}; + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/kernels/arm/transpose_compute_test.cc b/paddle/fluid/lite/kernels/arm/transpose_compute_test.cc new file mode 100644 index 000000000..1315556e3 --- /dev/null +++ b/paddle/fluid/lite/kernels/arm/transpose_compute_test.cc @@ -0,0 +1,205 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/kernels/arm/transpose_compute.h" +#include +#include +#include +#include +#include "paddle/fluid/lite/arm/math/funcs.h" +#include "paddle/fluid/lite/core/lite_tensor.h" +#include "paddle/fluid/lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +#define IN(n, c, h, w) \ + input_data[w + h * input_w + c * input_h * input_w + \ + n * input_c * input_h * input_w] +#define OUT(n, c, h, w) \ + output_data[w + h * output_w + c * output_h * output_w + \ + n * output_c * output_h * output_w] +void transpose_compute_ref(const operators::TransposeParam& param) { + const lite::Tensor* input = param.x; + lite::Tensor* output = param.output; + std::vector axis = param.axis; + + auto* input_data = input->data(); + auto* output_data = output->mutable_data(); + + int input_n = input->dims()[0]; + int input_c = input->dims()[1]; + int input_h = input->dims()[2]; + int input_w = input->dims()[3]; + int output_n = output->dims()[0]; + int output_c = output->dims()[1]; + int output_h = output->dims()[2]; + int output_w = output->dims()[3]; + + for (int n = 0; n < input_n; ++n) { + for (int c = 0; c < input_c; ++c) { + for (int h = 0; h < input_h; ++h) { + for (int w = 0; w < input_w; ++w) { + OUT(n, h, w, c) = IN(n, c, h, w); + } + } + } + } +} + +// Transpose +TEST(transpose_arm, init) { + TransposeCompute transpose; + ASSERT_EQ(transpose.precision(), PRECISION(kFloat)); + ASSERT_EQ(transpose.target(), TARGET(kARM)); +} + +TEST(transpose_arm, compute_shape_nchw) { + TransposeCompute transpose; + operators::TransposeParam param; + + std::vector axis{0, 2, 3, 1}; + param.axis = axis; + + lite::Tensor input; + lite::Tensor output; + lite::Tensor output_ref; + + const std::vector input_shape{1, 24, 2, 2}; + const std::vector output_shape{1, 2, 2, 24}; + + DDimLite ddimInput(input_shape); + DDimLite ddimOutput(output_shape); + + input.Resize(ddimInput); + output.Resize(ddimOutput); + output_ref.Resize(ddimOutput); + + for (int i = 0; + i < input_shape[0] * input_shape[1] * input_shape[2] * input_shape[3]; + i += 4) { + input.mutable_data()[i] = i; + input.mutable_data()[i + 1] = i + 1; + input.mutable_data()[i + 2] = i + 2; + input.mutable_data()[i + 3] = i + 3; + } + for (int i = 0; + i < input_shape[0] * input_shape[1] * input_shape[2] * input_shape[3]; + i += 4) { + } + param.x = &input; + param.output = &output; + + // run transpose_compute + transpose.SetParam(param); + transpose.Run(); + + // run transpose_compute_ref + param.output = &output_ref; + transpose_compute_ref(param); + + auto* output_data = output.data(); + auto* output_ref_data = output_ref.data(); + for (int i = 0; + i < input_shape[0] * input_shape[1] * input_shape[2] * input_shape[3]; + i += 4) { + EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-5); + } +} + +TEST(transpose, retrive_op) { + auto transpose = + KernelRegistry::Global().Create( + "transpose"); + ASSERT_FALSE(transpose.empty()); + ASSERT_TRUE(transpose.front()); +} + +// Transpose2 +TEST(transpose2_arm, init) { + Transpose2Compute transpose2; + ASSERT_EQ(transpose2.precision(), PRECISION(kFloat)); + ASSERT_EQ(transpose2.target(), TARGET(kARM)); +} + +TEST(transpose2_arm, compute_shape_nchw) { + Transpose2Compute transpose2; + operators::TransposeParam param; + + std::vector axis{0, 2, 3, 1}; + param.axis = axis; + + lite::Tensor input; + lite::Tensor output; + lite::Tensor output_ref; + + const std::vector input_shape{1, 24, 2, 2}; + const std::vector output_shape{1, 2, 2, 24}; + + DDimLite ddimInput(input_shape); + DDimLite ddimOutput(output_shape); + + input.Resize(ddimInput); + output.Resize(ddimOutput); + output_ref.Resize(ddimOutput); + + for (int i = 0; + i < input_shape[0] * input_shape[1] * input_shape[2] * input_shape[3]; + i += 4) { + input.mutable_data()[i] = i; + input.mutable_data()[i + 1] = i + 1; + input.mutable_data()[i + 2] = i + 2; + input.mutable_data()[i + 3] = i + 3; + } + for (int i = 0; + i < input_shape[0] * input_shape[1] * input_shape[2] * input_shape[3]; + i += 4) { + } + param.x = &input; + param.output = &output; + + // run transpose_compute + transpose2.SetParam(param); + transpose2.Run(); + + // run transpose_compute_ref + param.output = &output_ref; + transpose_compute_ref(param); + + auto* output_data = output.data(); + auto* output_ref_data = output_ref.data(); + for (int i = 0; + i < input_shape[0] * input_shape[1] * input_shape[2] * input_shape[3]; + i += 4) { + EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-5); + } +} + +TEST(transpose2, retrive_op) { + auto transpose2 = + KernelRegistry::Global().Create( + "transpose2"); + ASSERT_FALSE(transpose2.empty()); + ASSERT_TRUE(transpose2.front()); +} + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle + +USE_LITE_KERNEL(transpose, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(transpose2, kARM, kFloat, kNCHW, def); diff --git a/paddle/fluid/lite/operators/CMakeLists.txt b/paddle/fluid/lite/operators/CMakeLists.txt index ba2318138..004e86175 100644 --- a/paddle/fluid/lite/operators/CMakeLists.txt +++ b/paddle/fluid/lite/operators/CMakeLists.txt @@ -22,6 +22,7 @@ cc_library(op_params_lite SRCS op_params.cc DEPS ${tensor_lite} any_lite framewo cc_library(dropout_op_lite SRCS dropout_op.cc DEPS ${op_DEPS}) cc_library(concat_op_lite SRCS concat_op.cc DEPS ${op_DEPS}) cc_library(split_op_lite SRCS split_op.cc DEPS ${op_DEPS}) +cc_library(transpose_op_lite SRCS transpose_op.cc DEPS ${op_DEPS}) set(ops_lite conv_op_lite @@ -44,6 +45,7 @@ set(ops_lite dropout_op_lite concat_op_lite split_op_lite + transpose_op_lite PARENT_SCOPE) lite_cc_test(test_fc_op_lite SRCS fc_op_test.cc @@ -61,3 +63,4 @@ lite_cc_test(test_concat_op_lite SRCS concat_op_test.cc DEPS concat_op_lite memo lite_cc_test(test_fusion_elementwise_activation_ops_lite SRCS fusion_elementwise_activation_ops_test.cc DEPS fusion_elementwise_activation_ops_lite memory_lite) +lite_cc_test(test_transpose_op_lite SRCS transpose_op_test.cc DEPS transpose_op_lite memory_lite) diff --git a/paddle/fluid/lite/operators/op_params.h b/paddle/fluid/lite/operators/op_params.h index 25d27f036..0cc1e6b78 100644 --- a/paddle/fluid/lite/operators/op_params.h +++ b/paddle/fluid/lite/operators/op_params.h @@ -203,6 +203,15 @@ struct SplitParam { std::vector sections; }; +// For Transpose op +struct TransposeParam { + const lite::Tensor* x{}; + lite::Tensor* output{}; + std::vector axis; + bool use_mkldnn{false}; + std::string data_format{"AnyLayout"}; +}; + /// ----------------------- element wise operators ---------------------- struct ElementwiseParam { const lite::Tensor* X{}; diff --git a/paddle/fluid/lite/operators/transpose_op.cc b/paddle/fluid/lite/operators/transpose_op.cc new file mode 100644 index 000000000..6b422bbb2 --- /dev/null +++ b/paddle/fluid/lite/operators/transpose_op.cc @@ -0,0 +1,165 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/operators/transpose_op.h" +#include "paddle/fluid/lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators { + +// Transpose +bool TransposeOp::CheckShape() const { + CHECK_OR_FALSE(param_.x); + CHECK_OR_FALSE(param_.output); + auto x_dims = param_.x->dims(); + auto x_rank = x_dims.size(); + std::vector axis = param_.axis; + size_t axis_size = axis.size(); + // "The input tensor's rank(%d) should be equal to the axis's size(%d)", + // x_rank, axis_size + CHECK_OR_FALSE(x_rank == axis_size); + + std::vector count(axis_size, 0); + for (size_t i = 0; i < axis_size; i++) { + // Each element of Attribute axis should be a unique value + // range from 0 to (dims - 1), + // where the dims is the axis's size + CHECK_OR_FALSE(axis[i] < static_cast(axis_size) && + ++count[axis[i]] == 1); + } + return true; +} + +bool TransposeOp::InferShape() const { + CHECK_OR_FALSE(param_.x); + CHECK_OR_FALSE(param_.output); + auto x_dims = param_.x->dims(); + auto x_rank = x_dims.size(); + std::vector axis = param_.axis; + size_t axis_size = axis.size(); + // "The input tensor's rank(%d) should be equal to the axis's size(%d)", + // x_rank, axis_size + CHECK_OR_FALSE(x_rank == axis_size); + + std::vector count(axis_size, 0); + for (size_t i = 0; i < axis_size; i++) { + // Each element of Attribute axis should be a unique value + // range from 0 to (dims - 1), + // where the dims is the axis's size + CHECK_OR_FALSE(axis[i] < static_cast(axis_size) && + ++count[axis[i]] == 1); + } + lite::DDim out_dims(x_dims); + for (size_t i = 0; i < axis_size; i++) { + out_dims[i] = x_dims[axis[i]]; + } + param_.output->Resize(out_dims); + return true; +} + +bool TransposeOp::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) { + auto x = op_desc.Input("X").front(); + auto out = op_desc.Output("Out").front(); + + CHECK(scope->FindVar(x)); + CHECK(scope->FindVar(out)); + param_.x = GetVar(scope, x); + param_.output = GetMutableVar(scope, out); + + param_.axis = op_desc.GetAttr>("axis"); + if (op_desc.HasAttr("use_mkldnn")) { + param_.use_mkldnn = op_desc.GetAttr("use_mkldnn"); + } + if (op_desc.HasAttr("data_format")) { + param_.data_format = op_desc.GetAttr("data_format"); + } + return true; +} + +// Transpose2 +bool Transpose2Op::CheckShape() const { + CHECK_OR_FALSE(param_.x); + CHECK_OR_FALSE(param_.output); + auto x_dims = param_.x->dims(); + auto x_rank = x_dims.size(); + std::vector axis = param_.axis; + size_t axis_size = axis.size(); + // "The input tensor's rank(%d) should be equal to the axis's size(%d)", + // x_rank, axis_size + CHECK_OR_FALSE(x_rank == axis_size); + + std::vector count(axis_size, 0); + for (size_t i = 0; i < axis_size; i++) { + // Each element of Attribute axis should be a unique value + // range from 0 to (dims - 1), + // where the dims is the axis's size + CHECK_OR_FALSE(axis[i] < static_cast(axis_size) && + ++count[axis[i]] == 1); + } + return true; +} + +bool Transpose2Op::InferShape() const { + CHECK_OR_FALSE(param_.x); + CHECK_OR_FALSE(param_.output); + auto x_dims = param_.x->dims(); + auto x_rank = x_dims.size(); + std::vector axis = param_.axis; + size_t axis_size = axis.size(); + // "The input tensor's rank(%d) should be equal to the axis's size(%d)", + // x_rank, axis_size + CHECK_OR_FALSE(x_rank == axis_size); + + std::vector count(axis_size, 0); + for (size_t i = 0; i < axis_size; i++) { + // Each element of Attribute axis should be a unique value + // range from 0 to (dims - 1), + // where the dims is the axis's size + CHECK_OR_FALSE(axis[i] < static_cast(axis_size) && + ++count[axis[i]] == 1); + } + lite::DDim out_dims(x_dims); + for (size_t i = 0; i < axis_size; i++) { + out_dims[i] = x_dims[axis[i]]; + } + param_.output->Resize(out_dims); + return true; +} + +bool Transpose2Op::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) { + auto x = op_desc.Input("X").front(); + auto out = op_desc.Output("Out").front(); + + CHECK(scope->FindVar(x)); + CHECK(scope->FindVar(out)); + param_.x = GetVar(scope, x); + param_.output = GetMutableVar(scope, out); + + param_.axis = op_desc.GetAttr>("axis"); + if (op_desc.HasAttr("use_mkldnn")) { + param_.use_mkldnn = op_desc.GetAttr("use_mkldnn"); + } + if (op_desc.HasAttr("data_format")) { + param_.data_format = op_desc.GetAttr("data_format"); + } + return true; +} + +} // namespace operators +} // namespace lite +} // namespace paddle + +REGISTER_LITE_OP(transpose, paddle::lite::operators::TransposeOp); +REGISTER_LITE_OP(transpose2, paddle::lite::operators::Transpose2Op); diff --git a/paddle/fluid/lite/operators/transpose_op.h b/paddle/fluid/lite/operators/transpose_op.h new file mode 100644 index 000000000..f51acb61e --- /dev/null +++ b/paddle/fluid/lite/operators/transpose_op.h @@ -0,0 +1,66 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include "paddle/fluid/lite/core/op_lite.h" +#include "paddle/fluid/lite/core/scope.h" +#include "paddle/fluid/lite/utils/all.h" + +namespace paddle { +namespace lite { +namespace operators { + +// Transpose +class TransposeOp : public OpLite { + public: + TransposeOp() {} + explicit TransposeOp(const std::string &op_type) : OpLite(op_type) {} + + bool CheckShape() const override; + + bool InferShape() const override; + + bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + std::string DebugString() const override { return "transpose"; } + + private: + mutable TransposeParam param_; +}; + +// Transpose2 +class Transpose2Op : public OpLite { + public: + Transpose2Op() {} + explicit Transpose2Op(const std::string &op_type) : OpLite(op_type) {} + + bool CheckShape() const override; + + bool InferShape() const override; + + bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + std::string DebugString() const override { return "transpose2"; } + + private: + mutable TransposeParam param_; +}; + +} // namespace operators +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/operators/transpose_op_test.cc b/paddle/fluid/lite/operators/transpose_op_test.cc new file mode 100644 index 000000000..8962c1e49 --- /dev/null +++ b/paddle/fluid/lite/operators/transpose_op_test.cc @@ -0,0 +1,93 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/operators/transpose_op.h" +#include +#include "paddle/fluid/lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators { + +// Transpose +TEST(transpose_op_lite, test) { + // prepare variables + Scope scope; + auto* x = scope.Var("x")->GetMutable(); + auto* output = scope.Var("output")->GetMutable(); + const int h = 10; + const int w = 20; + x->Resize(DDim(std::vector({h, w}))); + output->Resize(DDim(std::vector{w, h})); + + // set data + for (int i = 0; i < h * w; i++) { + x->mutable_data()[i] = i; + } + for (int i = 0; i < w * h; i++) { + output->mutable_data()[i] = 0.; + } + + // prepare op desc + cpp::OpDesc desc; + desc.SetType("transpose"); + desc.SetInput("X", {"x"}); + desc.SetOutput("Out", {"output"}); + // axis change for shape in mobilenetssd: [1, 24, 2, 2] => [1, 2, 2, 24] + std::vector axis{0, 2, 3, 1}; + desc.SetAttr("axis", axis); + + TransposeOp transpose("transpose"); + + transpose.SetValidPlaces({Place{TARGET(kARM), PRECISION(kFloat)}}); + transpose.Attach(desc, &scope); +} + +// Transpose2 +TEST(transpose2_op_lite, test) { + // prepare variables + Scope scope; + auto* x = scope.Var("x")->GetMutable(); + auto* output = scope.Var("output")->GetMutable(); + const int h = 10; + const int w = 20; + x->Resize(DDim(std::vector({h, w}))); + output->Resize(DDim(std::vector{w, h})); + + // set data + for (int i = 0; i < h * w; i++) { + x->mutable_data()[i] = i; + } + for (int i = 0; i < w * h; i++) { + output->mutable_data()[i] = 0.; + } + + // prepare op desc + cpp::OpDesc desc; + desc.SetType("transpose2"); + desc.SetInput("X", {"x"}); + desc.SetOutput("Out", {"output"}); + // axis change for shape in mobilenetssd: [1, 24, 2, 2] => [1, 2, 2, 24] + std::vector axis{0, 2, 3, 1}; + desc.SetAttr("axis", axis); + + Transpose2Op transpose2("transpose2"); + + transpose2.SetValidPlaces({Place{TARGET(kARM), PRECISION(kFloat)}}); + transpose2.Attach(desc, &scope); +} + +} // namespace operators +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/tools/build.sh b/paddle/fluid/lite/tools/build.sh index f197d8882..f1c58489c 100755 --- a/paddle/fluid/lite/tools/build.sh +++ b/paddle/fluid/lite/tools/build.sh @@ -75,7 +75,7 @@ function build_single { } function build { - make lite_compile_deps -j $NUM_CORES_FOR_COMPILE + make lite_compile_deps -j$NUM_CORES_FOR_COMPILE } # It will eagerly test all lite related unittests. -- GitLab