From a7869e1617952b8753a8e2afb6260d823c631be7 Mon Sep 17 00:00:00 2001
From: liutuo
Date: Wed, 5 Jun 2019 19:24:51 +0800
Subject: [PATCH] add support 3d-input transpose

---
 mace/ops/common/transpose.h            | 90 +++++++++++++++++++++++++-
 mace/ops/transpose.cc                  |  5 +-
 test/ccunit/mace/ops/transpose_test.cc | 58 +++++++++++++++++
 3 files changed, 148 insertions(+), 5 deletions(-)

diff --git a/mace/ops/common/transpose.h b/mace/ops/common/transpose.h
index b7b42490..6a70133c 100644
--- a/mace/ops/common/transpose.h
+++ b/mace/ops/common/transpose.h
@@ -162,8 +162,9 @@ MaceStatus Transpose(utils::ThreadPool *thread_pool,
                      const std::vector<int> &dst_dims,
                      T *output) {
   MACE_CHECK((input_shape.size() == 2 && dst_dims.size() == 2) ||
-                 (input_shape.size() == 4 && dst_dims.size() == 4),
-             "Only support 2D or 4D transpose");
+                 (input_shape.size() == 3 && dst_dims.size() == 3) ||
+                 (input_shape.size() == 4 && dst_dims.size() == 4),
+             "Only support 2D, 3D or 4D transpose");
 
   std::vector<index_t> output_shape;
   for (size_t i = 0; i < dst_dims.size(); ++i) {
@@ -219,7 +220,7 @@ MaceStatus Transpose(utils::ThreadPool *thread_pool,
     index_t height = input_shape[1];
     index_t width = input_shape[2];
     index_t channel = input_shape[3];
-    index_t channel_raw_size = channel * sizeof(T);
+    size_t channel_raw_size = channel * sizeof(T);
     index_t stride_i = height;
     index_t stride_j = width;
     index_t tile_size = std::max(static_cast<index_t>(1),
@@ -266,6 +267,89 @@ MaceStatus Transpose(utils::ThreadPool *thread_pool,
         }
       }
     }
+  } else if (input_shape.size() == 3) {
+    const index_t batch = input_shape[0];
+    const index_t height = input_shape[1];
+    const index_t width = input_shape[2];
+    if (dst_dims == std::vector<int>{0, 2, 1}) {
+      index_t stride_h = height;
+      index_t stride_w = width;
+      index_t stride_hw = height * width;
+      index_t tile_size = height > 512 || width > 512 ? 64 : 32;
+
+      thread_pool->Compute3D([=](index_t start, index_t end, index_t step,
+                                 index_t start0, index_t end0, index_t step0,
+                                 index_t start1, index_t end1, index_t step1) {
+        for (index_t b = start; b < end; b += step) {
+          for (index_t i = start0; i < end0; i += step0) {
+            for (index_t j = start1; j < end1; j += step1) {
+              index_t end_i = std::min(i + tile_size, height);
+              index_t end_j = std::min(j + tile_size, width);
+              for (index_t tile_i = i; tile_i < end_i; ++tile_i) {
+                for (index_t tile_j = j; tile_j < end_j; ++tile_j) {
+                  output[b * stride_hw + tile_j * stride_h + tile_i] =
+                      input[b * stride_hw + tile_i * stride_w + tile_j];
+                }
+              }
+            }
+          }
+        }
+      }, 0, batch, 1, 0, height, tile_size, 0, width, tile_size);
+    } else if (dst_dims == std::vector<int>{1, 0, 2}) {
+      size_t width_raw_size = width * sizeof(T);
+      thread_pool->Compute2D([=](index_t start0, index_t end0, index_t step0,
+                                 index_t start1, index_t end1, index_t step1) {
+        for (int i = start0; i < end0; i += step0) {
+          for (int j = start1; j < end1; j += step1) {
+            memcpy(output + (j * batch + i) * width,
+                   input + (i * height + j) * width,
+                   width_raw_size);
+          }
+        }
+      }, 0, batch, 1, 0, height, 1);
+    } else if (dst_dims == std::vector<int>{1, 2, 0}) {
+      thread_pool->Compute3D([=](index_t start, index_t end, index_t step,
+                                 index_t start0, index_t end0, index_t step0,
+                                 index_t start1, index_t end1, index_t step1) {
+        for (int i = start; i < end; i += step) {
+          for (int j = start0; j < end0; j += step0) {
+            for (int k = start1; k < end1; k += step1) {
+              output[(j * width + k) * batch + i] =
+                  input[(i * height + j) * width + k];
+            }
+          }
+        }
+      }, 0, batch, 1, 0, height, 1, 0, width, 1);
+    } else if (dst_dims == std::vector<int>{2, 0, 1}) {
+      thread_pool->Compute3D([=](index_t start, index_t end, index_t step,
+                                 index_t start0, index_t end0, index_t step0,
+                                 index_t start1, index_t end1, index_t step1) {
+        for (int i = start; i < end; i += step) {
+          for (int j = start0; j < end0; j += step0) {
+            for (int k = start1; k < end1; k += step1) {
+              output[(k * batch + i) * height + j] =
+                  input[(i * height + j) * width + k];
+            }
+          }
+        }
+      }, 0, batch, 1, 0, height, 1, 0, width, 1);
+    } else if (dst_dims == std::vector<int>{2, 1, 0}) {
+      thread_pool->Compute3D([=](index_t start, index_t end, index_t step,
+                                 index_t start0, index_t end0, index_t step0,
+                                 index_t start1, index_t end1, index_t step1) {
+        for (int i = start; i < end; i += step) {
+          for (int j = start0; j < end0; j += step0) {
+            for (int k = start1; k < end1; k += step1) {
+              output[(k * height + j) * batch + i] =
+                  input[(i * height + j) * width + k];
+            }
+          }
+        }
+      }, 0, batch, 1, 0, height, 1, 0, width, 1);
+    } else {
+      MACE_CHECK(false, "no need to transform.");
+    }
+
   } else {
     MACE_NOT_IMPLEMENTED;
   }
diff --git a/mace/ops/transpose.cc b/mace/ops/transpose.cc
index 6c6993e0..4eb41e5b 100644
--- a/mace/ops/transpose.cc
+++ b/mace/ops/transpose.cc
@@ -42,8 +42,9 @@ class TransposeOp : public Operation {
     Tensor *output = this->Output(0);
     const std::vector<index_t> &input_shape = input->shape();
     MACE_CHECK((input_shape.size() == 4 && dims_.size() == 4) ||
-                   (input_shape.size() == 2 && dims_.size() == 2),
-               "rank should be 2 or 4");
+                   (input_shape.size() == 3 && dims_.size() == 3) ||
+                   (input_shape.size() == 2 && dims_.size() == 2),
+               "rank should be 2, 3 or 4");
     std::vector<index_t> output_shape;
     for (size_t i = 0; i < dims_.size(); ++i) {
       output_shape.push_back(input_shape[dims_[i]]);
diff --git a/test/ccunit/mace/ops/transpose_test.cc b/test/ccunit/mace/ops/transpose_test.cc
index 4fc64cc9..f74cc0cc 100644
--- a/test/ccunit/mace/ops/transpose_test.cc
+++ b/test/ccunit/mace/ops/transpose_test.cc
@@ -101,6 +101,64 @@ TEST_F(TransposeOpTest, Rank2) {
                           *net.GetOutput("Output"));
 }
 
+namespace {
+void Transpose3DTest(const std::vector<index_t> &input_shape,
+                     const std::vector<float> &input_data,
+                     const std::vector<int> &dest_dims,
+                     const std::vector<index_t> &expected_shape,
+                     const std::vector<float> &expected_data) {
+  // Construct graph
+  OpsTestNet net;
+  // Add input data
+  net.AddInputFromArray<CPU, float>("Input",
+                                    input_shape,
+                                    input_data);
+
+  OpDefBuilder("Transpose", "TransposeNCHWTest")
+      .Input("Input")
+      .Output("Output")
+      .AddIntsArg("dims", dest_dims)
+      .Finalize(net.NewOperatorDef());
+
+  // Run on cpu
+  net.RunOp();
+
+  net.AddInputFromArray<CPU, float>("ExpectedOutput", expected_shape,
+                                    expected_data);
+
+  ExpectTensorNear<float>(*net.GetOutput("ExpectedOutput"),
+                          *net.GetOutput("Output"));
+}
+}  // namespace
+
+TEST_F(TransposeOpTest, Rank3) {
+  Transpose3DTest({2, 3, 2},
+                  {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
+                  {0, 2, 1},
+                  {2, 2, 3},
+                  {1, 3, 5, 2, 4, 6, 7, 9, 11, 8, 10, 12});
+  Transpose3DTest({2, 3, 2},
+                  {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
+                  {1, 0, 2},
+                  {3, 2, 2},
+                  {1, 2, 7, 8, 3, 4, 9, 10, 5, 6, 11, 12});
+  Transpose3DTest({2, 3, 2},
+                  {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
+                  {1, 2, 0},
+                  {3, 2, 2},
+                  {1, 7, 2, 8, 3, 9, 4, 10, 5, 11, 6, 12});
+  Transpose3DTest({2, 3, 2},
+                  {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
+                  {2, 0, 1},
+                  {2, 2, 3},
+                  {1, 3, 5, 7, 9, 11, 2, 4, 6, 8, 10, 12});
+  Transpose3DTest({2, 3, 2},
+                  {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
+                  {2, 1, 0},
+                  {2, 3, 2},
+                  {1, 7, 3, 9, 5, 11, 2, 8, 4, 10, 6, 12});
+}
+
 }  // namespace test
 }  // namespace ops
 }  // namespace mace
-- 
GitLab
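
Note: the expected data in the Rank3 test above follows the usual transpose convention that destination axis n takes the source coordinate of axis dims[n]. The standalone sketch below is illustrative only and is not part of the MACE sources (the helper name NaiveTranspose3D is invented for this note); it reproduces the {0, 2, 1} test vector with a naive triple loop, assuming a row-major {B, H, W} layout as in the patch.

#include <cassert>
#include <cstdio>
#include <vector>

// Naive reference transpose for a rank-3 tensor, used only to sanity-check
// the expected data in the Rank3 unit test above.
std::vector<float> NaiveTranspose3D(const std::vector<float> &in,
                                    const std::vector<int> &shape,   // {B, H, W}
                                    const std::vector<int> &dims) {  // permutation
  const int B = shape[0], H = shape[1], W = shape[2];
  const int out_shape[3] = {shape[dims[0]], shape[dims[1]], shape[dims[2]]};
  std::vector<float> out(in.size());
  for (int b = 0; b < B; ++b) {
    for (int h = 0; h < H; ++h) {
      for (int w = 0; w < W; ++w) {
        const int src[3] = {b, h, w};
        // Destination axis n takes the source coordinate of axis dims[n].
        const int o0 = src[dims[0]], o1 = src[dims[1]], o2 = src[dims[2]];
        out[(o0 * out_shape[1] + o1) * out_shape[2] + o2] =
            in[(b * H + h) * W + w];
      }
    }
  }
  return out;
}

int main() {
  const std::vector<float> input = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
  // dims {0, 2, 1} on shape {2, 3, 2} should give the first expected vector
  // from TEST_F(TransposeOpTest, Rank3).
  const std::vector<float> expected = {1, 3, 5, 2, 4, 6, 7, 9, 11, 8, 10, 12};
  assert(NaiveTranspose3D(input, {2, 3, 2}, {0, 2, 1}) == expected);
  std::printf("ok\n");
  return 0;
}

Compiling this file with any C++11 compiler and running it should print "ok"; the other four Rank3 vectors can be checked the same way by changing dims and the expected data.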