diff --git a/mace/ops/common/transpose.h b/mace/ops/common/transpose.h index 052fc0edcb4263cfa46c922ad4d998594f4651c7..e8e9b1992694fb0db1045b621622a8a87cd96b0b 100644 --- a/mace/ops/common/transpose.h +++ b/mace/ops/common/transpose.h @@ -225,16 +225,22 @@ MaceStatus Transpose(utils::ThreadPool *thread_pool, index_t tile_size = std::max(static_cast(1), static_cast(std::sqrt( 8 * 1024 / channel))); - for (index_t i = 0; i < height; i += tile_size) { - for (index_t j = 0; j < width; j += tile_size) { - index_t end_i = std::min(i + tile_size, height); - index_t end_j = std::min(j + tile_size, width); - for (index_t tile_i = i; tile_i < end_i; ++tile_i) { - for (index_t tile_j = j; tile_j < end_j; ++tile_j) { - auto output_ptr = output + (tile_j * stride_i + tile_i) * channel; - auto input_ptr = input + (tile_i * stride_j + tile_j) * channel; - for (index_t k = 0; k < channel; ++k) { - output_ptr[k] = input_ptr[k]; + for (index_t b = 0; b < input_shape[0]; ++b) { + auto input_offset = input + b * batch_size; + auto output_offset = output + b * batch_size; + for (index_t i = 0; i < height; i += tile_size) { + for (index_t j = 0; j < width; j += tile_size) { + index_t end_i = std::min(i + tile_size, height); + index_t end_j = std::min(j + tile_size, width); + for (index_t tile_i = i; tile_i < end_i; ++tile_i) { + for (index_t tile_j = j; tile_j < end_j; ++tile_j) { + auto output_ptr = + output_offset + (tile_j * stride_i + tile_i) * channel; + auto input_ptr = + input_offset + (tile_i * stride_j + tile_j) * channel; + for (index_t k = 0; k < channel; ++k) { + output_ptr[k] = input_ptr[k]; + } } } }