diff --git a/mace/kernels/transpose.h b/mace/kernels/transpose.h index 3f49ee9c4f6198548f2178b70a70a64efa340086..5f42b8ad211a9841c5e225e15013f658b34cd5b7 100644 --- a/mace/kernels/transpose.h +++ b/mace/kernels/transpose.h @@ -15,6 +15,10 @@ #ifndef MACE_KERNELS_TRANSPOSE_H_ #define MACE_KERNELS_TRANSPOSE_H_ +#if defined(MACE_ENABLE_NEON) +#include +#endif + #include #include "mace/core/future.h" @@ -25,6 +29,64 @@ namespace mace { namespace kernels { +static void TransposeNHWCToNCHWC3(const float *input, + float *output, + const index_t height, + const index_t width) { + index_t image_size = height * width; + +#pragma omp parallel for + for (index_t h = 0; h < height; ++h) { + index_t in_offset = h * width * 3; + index_t out_offset = h * width; + + index_t w; + for (w = 0; w + 3 < width; w += 4) { + float32x4x3_t vi = vld3q_f32(input + in_offset); + vst1q_f32(output + out_offset, vi.val[0]); + vst1q_f32(output + out_offset + image_size, vi.val[1]); + vst1q_f32(output + out_offset + image_size * 2, vi.val[2]); + + in_offset += 12; + out_offset += 4; + } + for (; w < width; ++w) { + for (index_t c = 0; c < 3; ++c) { + output[h * width + image_size * c + w] = + input[h * width * 3 + w * 3 + c]; + } + } + } +} + +static void TransposeNCHWToNHWCC2(const float *input, + float *output, + const index_t height, + const index_t width) { + index_t image_size = height * width; +#pragma omp parallel for + for (index_t h = 0; h < height; ++h) { + index_t in_offset = h * width; + index_t out_offset = h * width * 2; + + index_t w; + for (w = 0; w + 3 < width; w += 4) { + float32x4_t vi0 = vld1q_f32(input + in_offset); + float32x4_t vi1 = vld1q_f32(input + in_offset + image_size); + vst2q_f32(output + out_offset, {vi0, vi1}); + + in_offset += 4; + out_offset += 8; + } + for (; w < width; ++w) { + for (index_t c = 0; c < 2; ++c) { + output[h * width * 2 + w * 2 + c] = + input[h * width + image_size * c + w]; + } + } + } +} + template struct TransposeFunctor { explicit TransposeFunctor(const std::vector &dims) : dims_(dims) {} @@ -48,28 +110,48 @@ struct TransposeFunctor { } } } else if (input->dim_size() == 4) { - std::vector - in_stride{input_shape[1] * input_shape[2] * input_shape[3], - input_shape[2] * input_shape[3], input_shape[3], 1}; - std::vector - out_stride{output_shape[1] * output_shape[2] * output_shape[3], - output_shape[2] * output_shape[3], output_shape[3], 1}; - - std::vector idim(4, 0); - std::vector odim(4, 0); - for (odim[0] = 0; odim[0] < output_shape[0]; ++odim[0]) { - for (odim[1] = 0; odim[1] < output_shape[1]; ++odim[1]) { - for (odim[2] = 0; odim[2] < output_shape[2]; ++odim[2]) { - for (odim[3] = 0; odim[3] < output_shape[3]; ++odim[3]) { - idim[dims_[0]] = odim[0]; - idim[dims_[1]] = odim[1]; - idim[dims_[2]] = odim[2]; - idim[dims_[3]] = odim[3]; - - output_data[odim[0] * out_stride[0] + odim[1] * out_stride[1] - + odim[2] * out_stride[2] + odim[3]] = - input_data[idim[0] * in_stride[0] + idim[1] * in_stride[1] - + idim[2] * in_stride[2] + idim[3]]; + std::vector transpose_order_from_NHWC_to_NCHW{0, 3, 1, 2}; + std::vector transpose_order_from_NCHW_to_NHWC{0, 2, 3, 1}; + index_t batch_size = input->dim(1) * input->dim(2) * input->dim(3); + if (dims_ == transpose_order_from_NHWC_to_NCHW && input->dim(3) == 3) { + for (index_t b = 0; b < input->dim(0); ++b) { + TransposeNHWCToNCHWC3(input_data + b * batch_size, + output_data + b * batch_size, + input->dim(1), + input->dim(2)); + } + } else if (dims_ == transpose_order_from_NCHW_to_NHWC + && input->dim(1) == 2) { + for (index_t b = 0; b < input->dim(0); ++b) { + TransposeNCHWToNHWCC2(input_data + b * batch_size, + output_data + b * batch_size, + input->dim(2), + input->dim(3)); + } + } else { + std::vector + in_stride{input_shape[1] * input_shape[2] * input_shape[3], + input_shape[2] * input_shape[3], input_shape[3], 1}; + std::vector + out_stride{output_shape[1] * output_shape[2] * output_shape[3], + output_shape[2] * output_shape[3], output_shape[3], 1}; + + std::vector idim(4, 0); + std::vector odim(4, 0); + for (odim[0] = 0; odim[0] < output_shape[0]; ++odim[0]) { + for (odim[1] = 0; odim[1] < output_shape[1]; ++odim[1]) { + for (odim[2] = 0; odim[2] < output_shape[2]; ++odim[2]) { + for (odim[3] = 0; odim[3] < output_shape[3]; ++odim[3]) { + idim[dims_[0]] = odim[0]; + idim[dims_[1]] = odim[1]; + idim[dims_[2]] = odim[2]; + idim[dims_[3]] = odim[3]; + + output_data[odim[0] * out_stride[0] + odim[1] * out_stride[1] + + odim[2] * out_stride[2] + odim[3]] = + input_data[idim[0] * in_stride[0] + idim[1] * in_stride[1] + + idim[2] * in_stride[2] + idim[3]]; + } } } } diff --git a/mace/ops/transpose_benchmark.cc b/mace/ops/transpose_benchmark.cc index a86549ed9cc4206b00d9276df524e95d491acad7..24e6f2ffe44499de11da8fd2eb22a1010401c6b6 100644 --- a/mace/ops/transpose_benchmark.cc +++ b/mace/ops/transpose_benchmark.cc @@ -83,6 +83,9 @@ void TransposeBenchmark(int iters, #define BM_TRANSPOSE4D(N, C, H, W, D0, D1, D2, D3) \ BM_TRANSPOSE4D_MACRO(N, C, H, W, D0, D1, D2, D3, float, CPU); + +BM_TRANSPOSE4D(1, 512, 512, 3, 0, 3, 1, 2); +BM_TRANSPOSE4D(1, 2, 512, 512, 0, 2, 3, 1); BM_TRANSPOSE4D(1, 64, 64, 512, 0, 3, 1, 2); BM_TRANSPOSE4D(1, 512, 64, 64, 0, 2, 3, 1); BM_TRANSPOSE2D(128, 128); diff --git a/mace/ops/transpose_test.cc b/mace/ops/transpose_test.cc index 0faacc9111c4e904a6bd2a95b44b835376ae9987..3a4b5729e349889f0ad725255f46fa0cbf5a90b9 100644 --- a/mace/ops/transpose_test.cc +++ b/mace/ops/transpose_test.cc @@ -37,16 +37,51 @@ void TransposeNCHWTest(const std::vector &input_shape) { // Run on cpu net.RunOp(); - net.FillNHWCInputToNCHWInput("InputNCHW", "Input"); + net.TransformDataFormat("Input", + DataFormat::NHWC, + "InputNCHW", + DataFormat::NCHW); ExpectTensorNear(*net.GetOutput("InputNCHW"), *net.GetOutput("Output")); } + +void TransposeNHWCTest(const std::vector &input_shape) { + // Construct graph + OpsTestNet net; + // Add input data + net.AddRandomInput("Input", input_shape); + + OpDefBuilder("Transpose", "TransposeNHWCTest") + .Input("Input") + .Output("Output") + .AddIntsArg("dims", {0, 2, 3, 1}) + .Finalize(net.NewOperatorDef()); + + // Run on cpu + net.RunOp(); + + net.TransformDataFormat("Input", + DataFormat::NCHW, + "InputNHWC", + DataFormat::NHWC); + + ExpectTensorNear(*net.GetOutput("InputNHWC"), + *net.GetOutput("Output")); +} } // namespace -TEST_F(TransposeOpTest, NCHW) { +TEST_F(TransposeOpTest, NHWC_to_NCHW) { TransposeNCHWTest({3, 64, 64, 128}); TransposeNCHWTest({1, 64, 48, 128}); + TransposeNCHWTest({1, 512, 512, 3}); + TransposeNCHWTest({2, 512, 512, 3}); +} + +TEST_F(TransposeOpTest, NCHW_to_NHWC) { + TransposeNHWCTest({1, 2, 512, 512}); + TransposeNHWCTest({1, 3, 512, 512}); + TransposeNHWCTest({2, 2, 512, 512}); } TEST_F(TransposeOpTest, Rank2) {