提交 c7909b85 编写于 作者: L luxuhui

fix: fix `Transpose` op's computing error on batch tensor.

N/A
Signed-off-by: NLuxuhui <luxuhui@xiaomi.com>
上级 cfc800b5
...@@ -225,14 +225,19 @@ MaceStatus Transpose(utils::ThreadPool *thread_pool, ...@@ -225,14 +225,19 @@ MaceStatus Transpose(utils::ThreadPool *thread_pool,
index_t tile_size = std::max(static_cast<index_t>(1), index_t tile_size = std::max(static_cast<index_t>(1),
static_cast<index_t>(std::sqrt( static_cast<index_t>(std::sqrt(
8 * 1024 / channel))); 8 * 1024 / channel)));
for (index_t b = 0; b < input_shape[0]; ++b) {
auto input_offset = input + b * batch_size;
auto output_offset = output + b * batch_size;
for (index_t i = 0; i < height; i += tile_size) { for (index_t i = 0; i < height; i += tile_size) {
for (index_t j = 0; j < width; j += tile_size) { for (index_t j = 0; j < width; j += tile_size) {
index_t end_i = std::min(i + tile_size, height); index_t end_i = std::min(i + tile_size, height);
index_t end_j = std::min(j + tile_size, width); index_t end_j = std::min(j + tile_size, width);
for (index_t tile_i = i; tile_i < end_i; ++tile_i) { for (index_t tile_i = i; tile_i < end_i; ++tile_i) {
for (index_t tile_j = j; tile_j < end_j; ++tile_j) { for (index_t tile_j = j; tile_j < end_j; ++tile_j) {
auto output_ptr = output + (tile_j * stride_i + tile_i) * channel; auto output_ptr =
auto input_ptr = input + (tile_i * stride_j + tile_j) * channel; output_offset + (tile_j * stride_i + tile_i) * channel;
auto input_ptr =
input_offset + (tile_i * stride_j + tile_j) * channel;
for (index_t k = 0; k < channel; ++k) { for (index_t k = 0; k < channel; ++k) {
output_ptr[k] = input_ptr[k]; output_ptr[k] = input_ptr[k];
} }
...@@ -240,6 +245,7 @@ MaceStatus Transpose(utils::ThreadPool *thread_pool, ...@@ -240,6 +245,7 @@ MaceStatus Transpose(utils::ThreadPool *thread_pool,
} }
} }
} }
}
} else { } else {
std::vector<index_t> std::vector<index_t>
in_stride{input_shape[1] * input_shape[2] * input_shape[3], in_stride{input_shape[1] * input_shape[2] * input_shape[3],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册