From 1dcd83b134cc83a0fcf0522cc1ede8f05c39d783 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E5=AF=85?= Date: Fri, 20 Apr 2018 14:32:15 +0800 Subject: [PATCH] Add transpose 2D --- mace/kernels/transpose.h | 53 ++++++++++++------- mace/ops/transpose.h | 8 +-- mace/ops/transpose_benchmark.cc | 93 +++++++++++++++++++++++++++++++++ mace/ops/transpose_test.cc | 23 ++++++++ 4 files changed, 153 insertions(+), 24 deletions(-) create mode 100644 mace/ops/transpose_benchmark.cc diff --git a/mace/kernels/transpose.h b/mace/kernels/transpose.h index 6854c5f9..b5e029ed 100644 --- a/mace/kernels/transpose.h +++ b/mace/kernels/transpose.h @@ -37,31 +37,44 @@ struct TransposeFunctor { const T *input_data = input->data(); T *output_data = output->mutable_data(); - std::vector - in_stride{input_shape[1] * input_shape[2] * input_shape[3], - input_shape[2] * input_shape[3], input_shape[3], 1}; - std::vector - out_stride{output_shape[1] * output_shape[2] * output_shape[3], - output_shape[2] * output_shape[3], output_shape[3], 1}; + if (input->dim_size() == 2) { + MACE_CHECK(dims_[0] == 1 && dims_[1] == 0, "no need transform"); + index_t stride_i = input_shape[0]; + index_t stride_j = input_shape[1]; + for (int i = 0; i < input_shape[0]; ++i) { + for (int j = 0; j < input_shape[1]; ++j) { + output_data[j * stride_i + i] = input_data[i * stride_j + j]; + } + } + } else if (input->dim_size() == 4) { + std::vector + in_stride{input_shape[1] * input_shape[2] * input_shape[3], + input_shape[2] * input_shape[3], input_shape[3], 1}; + std::vector + out_stride{output_shape[1] * output_shape[2] * output_shape[3], + output_shape[2] * output_shape[3], output_shape[3], 1}; - std::vector idim(4, 0); - std::vector odim(4, 0); - for (odim[0] = 0; odim[0] < output_shape[0]; ++odim[0]) { - for (odim[1] = 0; odim[1] < output_shape[1]; ++odim[1]) { - for (odim[2] = 0; odim[2] < output_shape[2]; ++odim[2]) { - for (odim[3] = 0; odim[3] < output_shape[3]; ++odim[3]) { - idim[dims_[0]] = odim[0]; - idim[dims_[1]] = odim[1]; - idim[dims_[2]] = odim[2]; - idim[dims_[3]] = odim[3]; + std::vector idim(4, 0); + std::vector odim(4, 0); + for (odim[0] = 0; odim[0] < output_shape[0]; ++odim[0]) { + for (odim[1] = 0; odim[1] < output_shape[1]; ++odim[1]) { + for (odim[2] = 0; odim[2] < output_shape[2]; ++odim[2]) { + for (odim[3] = 0; odim[3] < output_shape[3]; ++odim[3]) { + idim[dims_[0]] = odim[0]; + idim[dims_[1]] = odim[1]; + idim[dims_[2]] = odim[2]; + idim[dims_[3]] = odim[3]; - output_data[odim[0] * out_stride[0] + odim[1] * out_stride[1] - + odim[2] * out_stride[2] + odim[3]] = - input_data[idim[0] * in_stride[0] + idim[1] * in_stride[1] - + idim[2] * in_stride[2] + idim[3]]; + output_data[odim[0] * out_stride[0] + odim[1] * out_stride[1] + + odim[2] * out_stride[2] + odim[3]] = + input_data[idim[0] * in_stride[0] + idim[1] * in_stride[1] + + idim[2] * in_stride[2] + idim[3]]; + } } } } + } else { + MACE_NOT_IMPLEMENTED; } } diff --git a/mace/ops/transpose.h b/mace/ops/transpose.h index 45e36fa3..2ec9281c 100644 --- a/mace/ops/transpose.h +++ b/mace/ops/transpose.h @@ -28,16 +28,16 @@ class TransposeOp : public Operator { public: TransposeOp(const OperatorDef &operator_def, Workspace *ws) : Operator(operator_def, ws), - dims_(OperatorBase::GetRepeatedArgument( - "dims")), + dims_(OperatorBase::GetRepeatedArgument("dims")), functor_(dims_) {} bool Run(StatsFuture *future) override { const Tensor *input = this->Input(INPUT); Tensor *output = this->Output(OUTPUT); const std::vector &input_shape = input->shape(); - MACE_CHECK(input_shape.size() == 4 && dims_.size() == 4, - "rank should be 4"); + MACE_CHECK(input_shape.size() == 4 && dims_.size() == 4 + || input_shape.size() == 2 && dims_.size() == 2, + "rank should be 2 or 4"); std::vector output_shape; for (int i = 0; i < dims_.size(); ++i) { output_shape.push_back(input_shape[dims_[i]]); diff --git a/mace/ops/transpose_benchmark.cc b/mace/ops/transpose_benchmark.cc new file mode 100644 index 00000000..a86549ed --- /dev/null +++ b/mace/ops/transpose_benchmark.cc @@ -0,0 +1,93 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "mace/core/operator.h" +#include "mace/core/testing/test_benchmark.h" +#include "mace/ops/ops_test_util.h" + +namespace mace { +namespace ops { +namespace test { + +namespace { +template +void TransposeBenchmark(int iters, + std::vector shape, + std::vector dims) { + mace::testing::StopTiming(); + + OpsTestNet net; + + // Add input data + net.AddRandomInput("Input", shape); + + OpDefBuilder("Transpose", "TransposeBM") + .Input("Input") + .Output("Output") + .AddIntsArg("dims", dims) + .Finalize(net.NewOperatorDef()); + + // Warm-up + for (int i = 0; i < 5; ++i) { + net.RunOp(D); + } + net.Sync(); + + mace::testing::StartTiming(); + while (iters--) { + net.RunOp(D); + } + net.Sync(); +} +} // namespace + +#define BM_TRANSPOSE2D_MACRO(H, W, TYPE, DEVICE) \ + static void BM_TRANSPOSE2D_##H##_##W##_##TYPE##_##DEVICE( \ + int iters) { \ + const int64_t tot = static_cast(iters) * H * W; \ + mace::testing::MaccProcessed(tot); \ + mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \ + TransposeBenchmark(iters, {H, W}, {1, 0}); \ + } \ + BENCHMARK(BM_TRANSPOSE2D_##H##_##W##_##TYPE##_##DEVICE) + +#define BM_TRANSPOSE2D(H, W) \ + BM_TRANSPOSE2D_MACRO(H, W, float, CPU); + +#define BM_TRANSPOSE4D_MACRO(N, C, H, W, D0, D1, D2, D3, TYPE, DEVICE) \ + static void \ + BM_TRANSPOSE4D_##N##_##C##_##H##_##W##_##D0##D1##D2##D3##_##TYPE##_##DEVICE( \ + int iters) { \ + const int64_t tot = static_cast(iters) * N * C * H * W; \ + mace::testing::MaccProcessed(tot); \ + mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \ + TransposeBenchmark(iters, {N, C, H, W}, {D0, D1, D2, D3}); \ + } \ + BENCHMARK( \ + BM_TRANSPOSE4D_##N##_##C##_##H##_##W##_##D0##D1##D2##D3##_##TYPE##_##DEVICE) + +#define BM_TRANSPOSE4D(N, C, H, W, D0, D1, D2, D3) \ + BM_TRANSPOSE4D_MACRO(N, C, H, W, D0, D1, D2, D3, float, CPU); + +BM_TRANSPOSE4D(1, 64, 64, 512, 0, 3, 1, 2); +BM_TRANSPOSE4D(1, 512, 64, 64, 0, 2, 3, 1); +BM_TRANSPOSE2D(128, 128); +BM_TRANSPOSE2D(512, 512); + +} // namespace test +} // namespace ops +} // namespace mace diff --git a/mace/ops/transpose_test.cc b/mace/ops/transpose_test.cc index b1e8cd4e..0faacc91 100644 --- a/mace/ops/transpose_test.cc +++ b/mace/ops/transpose_test.cc @@ -49,6 +49,29 @@ TEST_F(TransposeOpTest, NCHW) { TransposeNCHWTest({1, 64, 48, 128}); } +TEST_F(TransposeOpTest, Rank2) { + // Construct graph + OpsTestNet net; + // Add input data + net.AddInputFromArray("Input", {2, 3}, {1, 2, 3, 4, 5, 6}); + + OpDefBuilder("Transpose", "TransposeNCHWTest") + .Input("Input") + .Output("Output") + .AddIntsArg("dims", {1, 0}) + .Finalize(net.NewOperatorDef()); + + // Run on cpu + net.RunOp(); + + net.AddInputFromArray("ExpectedOutput", + {3, 2}, + {1, 4, 2, 5, 3, 6}); + + ExpectTensorNear(*net.GetOutput("ExpectedOutput"), + *net.GetOutput("Output")); +} + } // namespace test } // namespace ops } // namespace mace -- GitLab