diff --git a/mace/kernels/transpose.h b/mace/kernels/transpose.h index 87f9c0e2ab1e9115b520f68ff248b85cbede06e8..04e1caed91d59a82c99315c2034ed1519be0ffc1 100644 --- a/mace/kernels/transpose.h +++ b/mace/kernels/transpose.h @@ -122,16 +122,16 @@ struct TransposeFunctor : OpKernel { if (input->dim_size() == 2) { MACE_CHECK(dims_[0] == 1 && dims_[1] == 0, "no need transform"); - index_t stride_i = input_shape[0]; - index_t stride_j = input_shape[1]; - - index_t tile_size = input_shape[0] > 512 || input_shape[1] > 512 - ? 64 : 32; + index_t height = input_shape[0]; + index_t width = input_shape[1]; + index_t stride_i = height; + index_t stride_j = width; + index_t tile_size = height > 512 || width > 512 ? 64 : 32; #pragma omp parallel for collapse(2) - for (index_t i = 0; i < input_shape[0]; i += tile_size) { - for (index_t j = 0; j < input_shape[1]; j += tile_size) { - index_t end_i = std::min(i + tile_size, input_shape[0]); - index_t end_j = std::min(j + tile_size, input_shape[1]); + for (index_t i = 0; i < height; i += tile_size) { + for (index_t j = 0; j < width; j += tile_size) { + index_t end_i = std::min(i + tile_size, height); + index_t end_j = std::min(j + tile_size, width); for (index_t tile_i = i; tile_i < end_i; ++tile_i) { for (index_t tile_j = j; tile_j < end_j; ++tile_j) { output_data[tile_j * stride_i + tile_i] = @@ -144,6 +144,7 @@ struct TransposeFunctor : OpKernel { std::vector transpose_order_from_NHWC_to_NCHW{0, 3, 1, 2}; std::vector transpose_order_from_NCHW_to_NHWC{0, 2, 3, 1}; index_t batch_size = input->dim(1) * input->dim(2) * input->dim(3); + if (dims_ == transpose_order_from_NHWC_to_NCHW && input->dim(3) == 3) { for (index_t b = 0; b < input->dim(0); ++b) { TransposeNHWCToNCHWC3(input_data + b * batch_size, @@ -159,6 +160,30 @@ struct TransposeFunctor : OpKernel { input->dim(2), input->dim(3)); } + } else if (dims_ == std::vector{0, 2, 1, 3}) { + index_t height = input_shape[1]; + index_t width = input_shape[2]; + index_t channel = input_shape[3]; + index_t channel_raw_size = channel * sizeof(T); + index_t stride_i = height; + index_t stride_j = width; + index_t tile_size = std::max(static_cast(1), + static_cast(std::sqrt( + 8 * 1024 / channel))); +#pragma omp parallel for collapse(2) + for (index_t i = 0; i < height; i += tile_size) { + for (index_t j = 0; j < width; j += tile_size) { + index_t end_i = std::min(i + tile_size, height); + index_t end_j = std::min(j + tile_size, width); + for (index_t tile_i = i; tile_i < end_i; ++tile_i) { + for (index_t tile_j = j; tile_j < end_j; ++tile_j) { + memcpy(output_data + (tile_j * stride_i + tile_i) * channel, + input_data + (tile_i * stride_j + tile_j) * channel, + channel_raw_size); + } + } + } + } } else { std::vector in_stride{input_shape[1] * input_shape[2] * input_shape[3], diff --git a/mace/ops/transpose_benchmark.cc b/mace/ops/transpose_benchmark.cc index c5fe98cd4127c6e3a13f14af51aec4ec1f2666ec..1e68a4a98b2a70084ec6f06511641fd20679add2 100644 --- a/mace/ops/transpose_benchmark.cc +++ b/mace/ops/transpose_benchmark.cc @@ -88,6 +88,7 @@ MACE_BM_TRANSPOSE4D(1, 512, 512, 3, 0, 3, 1, 2); MACE_BM_TRANSPOSE4D(1, 2, 512, 512, 0, 2, 3, 1); MACE_BM_TRANSPOSE4D(1, 64, 64, 512, 0, 3, 1, 2); MACE_BM_TRANSPOSE4D(1, 512, 64, 64, 0, 2, 3, 1); +MACE_BM_TRANSPOSE4D(1, 4, 20, 64, 0, 2, 1, 3); MACE_BM_TRANSPOSE2D(128, 128); MACE_BM_TRANSPOSE2D(512, 512); MACE_BM_TRANSPOSE2D(1024, 1024);