diff --git a/lite/kernels/mlu/bridges/utility.h b/lite/kernels/mlu/bridges/utility.h index 38f3c734b39a7e8ce3b8fb17f8375486361502e7..b75038d9d872528949bbffc0b0743511d9669385 100644 --- a/lite/kernels/mlu/bridges/utility.h +++ b/lite/kernels/mlu/bridges/utility.h @@ -47,74 +47,45 @@ void transpose(dtype input_data, std::vector axis) { int old_index = -1; int new_index = -1; - if (input_shape.size() == 2) { - int dim[2] = {0}; - std::vector shape = input_shape; - for (dim[0] = 0; dim[0] < input_shape[0]; dim[0]++) { - for (dim[1] = 0; dim[1] < input_shape[1]; dim[1]++) { - old_index = dim[0] * shape[1] + dim[1]; - new_index = dim[axis[0]] * shape[axis[1]] + dim[axis[1]]; - output_data[new_index] = input_data[old_index]; - } + std::vector shape; + std::vector expand_axis; + if (input_shape.size() < 5) { + for (int i = 0; i < 5 - input_shape.size(); i++) { + shape.push_back(1); + expand_axis.push_back(i); } - } else if (input_shape.size() == 3) { - int dim[3] = {0}; - std::vector shape = input_shape; - for (dim[0] = 0; dim[0] < input_shape[0]; dim[0]++) { - for (dim[1] = 0; dim[1] < input_shape[1]; dim[1]++) { - for (dim[2] = 0; dim[2] < input_shape[2]; dim[2]++) { - old_index = dim[0] * shape[1] * shape[2] + dim[1] * shape[2] + dim[2]; - new_index = dim[axis[0]] * shape[axis[1]] * shape[axis[2]] + - dim[axis[1]] * shape[axis[2]] + dim[axis[2]]; - output_data[new_index] = input_data[old_index]; - } - } + for (int i = 0; i < input_shape.size(); i++) { + shape.push_back(input_shape[i]); + expand_axis.push_back(axis[i] + 5 - input_shape.size()); } - } else if (input_shape.size() == 4) { - int dim[4] = {0}; - std::vector shape = input_shape; - for (dim[0] = 0; dim[0] < input_shape[0]; dim[0]++) { - for (dim[1] = 0; dim[1] < input_shape[1]; dim[1]++) { - for (dim[2] = 0; dim[2] < input_shape[2]; dim[2]++) { - for (dim[3] = 0; dim[3] < input_shape[3]; dim[3]++) { - old_index = dim[0] * shape[1] * shape[2] * shape[3] + - dim[1] * shape[2] * shape[3] + dim[2] * shape[3] + - dim[3]; - new_index = dim[axis[0]] * shape[axis[1]] * shape[axis[2]] * - shape[axis[3]] + - dim[axis[1]] * shape[axis[2]] * shape[axis[3]] + - dim[axis[2]] * shape[axis[3]] + dim[axis[3]]; + } else { + shape = input_shape; + expand_axis = axis; + } + int dim[5] = {0}; + for (dim[0] = 0; dim[0] < shape[0]; dim[0]++) { + for (dim[1] = 0; dim[1] < shape[1]; dim[1]++) { + for (dim[2] = 0; dim[2] < shape[2]; dim[2]++) { + for (dim[3] = 0; dim[3] < shape[3]; dim[3]++) { + for (dim[4] = 0; dim[4] < shape[4]; dim[4]++) { + old_index = dim[0] * shape[1] * shape[2] * shape[3] * shape[4] + + dim[1] * shape[2] * shape[3] * shape[4] + + dim[2] * shape[3] * shape[4] + dim[3] * shape[4] + + dim[4]; + new_index = dim[expand_axis[0]] * shape[expand_axis[1]] * + shape[expand_axis[2]] * shape[expand_axis[3]] * + shape[expand_axis[4]] + + dim[expand_axis[1]] * shape[expand_axis[2]] * + shape[expand_axis[3]] * shape[expand_axis[4]] + + dim[expand_axis[2]] * shape[expand_axis[3]] * + shape[expand_axis[4]] + + dim[expand_axis[3]] * shape[expand_axis[4]] + + dim[expand_axis[4]]; output_data[new_index] = input_data[old_index]; } } } } - } else if (input_shape.size() == 5) { - int dim[5] = {0}; - std::vector shape = input_shape; - for (dim[0] = 0; dim[0] < input_shape[0]; dim[0]++) { - for (dim[1] = 0; dim[1] < input_shape[1]; dim[1]++) { - for (dim[2] = 0; dim[2] < input_shape[2]; dim[2]++) { - for (dim[3] = 0; dim[3] < input_shape[3]; dim[3]++) { - for (dim[4] = 0; dim[4] < input_shape[4]; dim[4]++) { - old_index = dim[0] * shape[1] * shape[2] * shape[3] * shape[4] + - dim[1] * shape[2] * shape[3] * shape[4] + - dim[2] * shape[3] * shape[4] + dim[3] * shape[4] + - dim[4]; - new_index = dim[axis[0]] * shape[axis[1]] * shape[axis[2]] * - shape[axis[3]] * shape[axis[4]] + - dim[axis[1]] * shape[axis[2]] * shape[axis[3]] * - shape[axis[4]] + - dim[axis[2]] * shape[axis[3]] * shape[axis[4]] + - dim[axis[3]] * shape[axis[4]] + dim[axis[4]]; - output_data[new_index] = input_data[old_index]; - } - } - } - } - } - - } else { } }