提交 6d16b491 编写于 作者: 李超

Merge branch 'transpose_bug' into 'master'

fix: fix `Transpose` op's computing error on batch tensor.

See merge request applied-machine-learning/sysml/mace!1302
......@@ -225,16 +225,22 @@ MaceStatus Transpose(utils::ThreadPool *thread_pool,
index_t tile_size = std::max(static_cast<index_t>(1),
static_cast<index_t>(std::sqrt(
8 * 1024 / channel)));
for (index_t i = 0; i < height; i += tile_size) {
for (index_t j = 0; j < width; j += tile_size) {
index_t end_i = std::min(i + tile_size, height);
index_t end_j = std::min(j + tile_size, width);
for (index_t tile_i = i; tile_i < end_i; ++tile_i) {
for (index_t tile_j = j; tile_j < end_j; ++tile_j) {
auto output_ptr = output + (tile_j * stride_i + tile_i) * channel;
auto input_ptr = input + (tile_i * stride_j + tile_j) * channel;
for (index_t k = 0; k < channel; ++k) {
output_ptr[k] = input_ptr[k];
for (index_t b = 0; b < input_shape[0]; ++b) {
auto input_offset = input + b * batch_size;
auto output_offset = output + b * batch_size;
for (index_t i = 0; i < height; i += tile_size) {
for (index_t j = 0; j < width; j += tile_size) {
index_t end_i = std::min(i + tile_size, height);
index_t end_j = std::min(j + tile_size, width);
for (index_t tile_i = i; tile_i < end_i; ++tile_i) {
for (index_t tile_j = j; tile_j < end_j; ++tile_j) {
auto output_ptr =
output_offset + (tile_j * stride_i + tile_i) * channel;
auto input_ptr =
input_offset + (tile_i * stride_j + tile_j) * channel;
for (index_t k = 0; k < channel; ++k) {
output_ptr[k] = input_ptr[k];
}
}
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册