diff --git a/lite/kernels/mlu/bridges/argmax_op_test.cc b/lite/kernels/mlu/bridges/argmax_op_test.cc
index d600650190ac538ea1c68d7fca4c79f900315a18..8c6915d24873754b07c4724597f541d283858565 100644
--- a/lite/kernels/mlu/bridges/argmax_op_test.cc
+++ b/lite/kernels/mlu/bridges/argmax_op_test.cc
@@ -100,13 +100,13 @@ void test_argmax(const std::vector<int64_t>& input_shape, int axis) {
   Tensor input_x;
   input_x.Resize(DDim(input_shape));
   // change input layout from NCHW to NHWC
-  transpose(x->mutable_data<float>(),
-            input_x.mutable_data<float>(),
-            {static_cast<int>(input_shape[0]),
-             static_cast<int>(input_shape[1]),
-             static_cast<int>(input_shape[2]),
-             static_cast<int>(input_shape[3])},
-            {0, 2, 3, 1});
+  transpose<float>(x->mutable_data<float>(),
+                   input_x.mutable_data<float>(),
+                   {static_cast<int>(input_shape[0]),
+                    static_cast<int>(input_shape[1]),
+                    static_cast<int>(input_shape[2]),
+                    static_cast<int>(input_shape[3])},
+                   {0, 2, 3, 1});
   x->CopyDataFrom(input_x);
 
   LaunchOp(op, {x_var_name}, {out_var_name});
@@ -117,13 +117,13 @@ void test_argmax(const std::vector<int64_t>& input_shape, int axis) {
   Tensor output_trans;
   output_trans.Resize(out_shape);
   // Change output layout from NHWC to NCHW
-  transpose(out_data,
-            output_trans.mutable_data<float>(),
-            {static_cast<int>(out_shape[0]),
-             static_cast<int>(out_shape[2]),
-             static_cast<int>(out_shape[3]),
-             static_cast<int>(out_shape[1])},
-            {0, 3, 1, 2});
+  transpose<float>(out_data,
+                   output_trans.mutable_data<float>(),
+                   {static_cast<int>(out_shape[0]),
+                    static_cast<int>(out_shape[2]),
+                    static_cast<int>(out_shape[3]),
+                    static_cast<int>(out_shape[1])},
+                   {0, 3, 1, 2});
   out_data = output_trans.mutable_data<float>();
 
   for (int i = 0; i < out->dims().production(); i++) {
diff --git a/lite/kernels/mlu/bridges/gather_op_test.cc b/lite/kernels/mlu/bridges/gather_op_test.cc
index f9b2153ca51f8a2a9971a0e55184fbcd5f4625d9..413de7c9d7fda750b387c2daa21ef1e40e7982c7 100644
--- a/lite/kernels/mlu/bridges/gather_op_test.cc
+++ b/lite/kernels/mlu/bridges/gather_op_test.cc
@@ -93,13 +93,13 @@ void test_gather() {
 
   Tensor input;
   input.Resize({5, 4, 3, 2});
-  transpose(x->mutable_data<float>(),
-            input.mutable_data<float>(),
-            {static_cast<int>(5),
-             static_cast<int>(4),
-             static_cast<int>(3),
-             static_cast<int>(2)},
-            {0, 2, 3, 1});
+  transpose<float>(x->mutable_data<float>(),
+                   input.mutable_data<float>(),
+                   {static_cast<int>(5),
+                    static_cast<int>(4),
+                    static_cast<int>(3),
+                    static_cast<int>(2)},
+                   {0, 2, 3, 1});
   x->CopyDataFrom(input);
 
   LaunchOp(op, {x_var_name, index_var_name}, {out_var_name});
@@ -109,13 +109,13 @@ void test_gather() {
 
   Tensor output;
   output.Resize(out->dims());
-  transpose(out_data,
-            output.mutable_data<float>(),
-            {static_cast<int>(out->dims()[0]),
-             static_cast<int>(out->dims()[2]),
-             static_cast<int>(out->dims()[3]),
-             static_cast<int>(out->dims()[1])},
-            {0, 3, 1, 2});
+  transpose<float>(out_data,
+                   output.mutable_data<float>(),
+                   {static_cast<int>(out->dims()[0]),
+                    static_cast<int>(out->dims()[2]),
+                    static_cast<int>(out->dims()[3]),
+                    static_cast<int>(out->dims()[1])},
+                   {0, 3, 1, 2});
   out_data = output.mutable_data<float>();
   for (int i = 0; i < out->dims().production(); i++) {
     VLOG(5) << i;
diff --git a/lite/kernels/mlu/bridges/layout_op_test.cc b/lite/kernels/mlu/bridges/layout_op_test.cc
index a3a39d9177f7d7bfe7f9eb59081ada0d05fb616d..69b905b0750fe99e29c6aaa9bffdc9f20229a239 100644
--- a/lite/kernels/mlu/bridges/layout_op_test.cc
+++ b/lite/kernels/mlu/bridges/layout_op_test.cc
@@ -50,38 +50,38 @@ void test_layout_NHWC2NCHW(std::vector<int64_t> input_shape) {
   input.Resize(DDim(input_shape));
   switch (input_shape.size()) {
     case 2:
-      transpose(
+      transpose<float>(
           x->mutable_data<float>(),
           input.mutable_data<float>(),
           {static_cast<int>(input_shape[0]), static_cast<int>(input_shape[1])},
           {0, 1});
       break;
     case 3:
-      transpose(x->mutable_data<float>(),
-                input.mutable_data<float>(),
-                {static_cast<int>(input_shape[0]),
-                 static_cast<int>(input_shape[2]),
-                 static_cast<int>(input_shape[1])},
-                {0, 2, 1});
+      transpose<float>(x->mutable_data<float>(),
+                       input.mutable_data<float>(),
+                       {static_cast<int>(input_shape[0]),
+                        static_cast<int>(input_shape[2]),
+                        static_cast<int>(input_shape[1])},
+                       {0, 2, 1});
       break;
     case 4:
-      transpose(x->mutable_data<float>(),
-                input.mutable_data<float>(),
-                {static_cast<int>(input_shape[0]),
-                 static_cast<int>(input_shape[2]),
-                 static_cast<int>(input_shape[3]),
-                 static_cast<int>(input_shape[1])},
-                {0, 3, 1, 2});
+      transpose<float>(x->mutable_data<float>(),
+                       input.mutable_data<float>(),
+                       {static_cast<int>(input_shape[0]),
+                        static_cast<int>(input_shape[2]),
+                        static_cast<int>(input_shape[3]),
+                        static_cast<int>(input_shape[1])},
+                       {0, 3, 1, 2});
       break;
     case 5:
-      transpose(x->mutable_data<float>(),
-                input.mutable_data<float>(),
-                {static_cast<int>(input_shape[0]),
-                 static_cast<int>(input_shape[2]),
-                 static_cast<int>(input_shape[3]),
-                 static_cast<int>(input_shape[4]),
-                 static_cast<int>(input_shape[1])},
-                {0, 4, 1, 2, 3});
+      transpose<float>(x->mutable_data<float>(),
+                       input.mutable_data<float>(),
+                       {static_cast<int>(input_shape[0]),
+                        static_cast<int>(input_shape[2]),
+                        static_cast<int>(input_shape[3]),
+                        static_cast<int>(input_shape[4]),
+                        static_cast<int>(input_shape[1])},
+                       {0, 4, 1, 2, 3});
       break;
     default:
       CHECK(0) << "Unsupport";
@@ -123,38 +123,38 @@ void test_layout_NCHW2NHWC(std::vector<int64_t> input_shape) {
   input.Resize(DDim(input_shape));
   switch (input_shape.size()) {
     case 2:
-      transpose(
+      transpose<float>(
          x->mutable_data<float>(),
           input.mutable_data<float>(),
           {static_cast<int>(input_shape[0]), static_cast<int>(input_shape[1])},
           {0, 1});
       break;
     case 3:
-      transpose(x->mutable_data<float>(),
-                input.mutable_data<float>(),
-                {static_cast<int>(input_shape[0]),
-                 static_cast<int>(input_shape[1]),
-                 static_cast<int>(input_shape[2])},
-                {0, 2, 1});
+      transpose<float>(x->mutable_data<float>(),
+                       input.mutable_data<float>(),
+                       {static_cast<int>(input_shape[0]),
+                        static_cast<int>(input_shape[1]),
+                        static_cast<int>(input_shape[2])},
+                       {0, 2, 1});
       break;
     case 4:
-      transpose(x->mutable_data<float>(),
-                input.mutable_data<float>(),
-                {static_cast<int>(input_shape[0]),
-                 static_cast<int>(input_shape[1]),
-                 static_cast<int>(input_shape[2]),
-                 static_cast<int>(input_shape[3])},
-                {0, 2, 3, 1});
+      transpose<float>(x->mutable_data<float>(),
+                       input.mutable_data<float>(),
+                       {static_cast<int>(input_shape[0]),
+                        static_cast<int>(input_shape[1]),
+                        static_cast<int>(input_shape[2]),
+                        static_cast<int>(input_shape[3])},
+                       {0, 2, 3, 1});
       break;
     case 5:
-      transpose(x->mutable_data<float>(),
-                input.mutable_data<float>(),
-                {static_cast<int>(input_shape[0]),
-                 static_cast<int>(input_shape[1]),
-                 static_cast<int>(input_shape[2]),
-                 static_cast<int>(input_shape[3]),
-                 static_cast<int>(input_shape[4])},
-                {0, 2, 3, 4, 1});
+      transpose<float>(x->mutable_data<float>(),
+                       input.mutable_data<float>(),
+                       {static_cast<int>(input_shape[0]),
+                        static_cast<int>(input_shape[1]),
+                        static_cast<int>(input_shape[2]),
+                        static_cast<int>(input_shape[3]),
+                        static_cast<int>(input_shape[4])},
+                       {0, 2, 3, 4, 1});
       break;
     default:
       CHECK(0) << "Unsupport";
diff --git a/lite/kernels/mlu/bridges/split_op_test.cc b/lite/kernels/mlu/bridges/split_op_test.cc
index 18bd74ea94ec46a7c01d39cab883ec325cfc0fd0..a44a45504036e9ef6199e9d2b534aa3dde63bb01 100644
--- a/lite/kernels/mlu/bridges/split_op_test.cc
+++ b/lite/kernels/mlu/bridges/split_op_test.cc
@@ -135,13 +135,13 @@ void test_split(int bs,
 
   Tensor input;
   input.Resize({bs, ic, ih, iw});
-  transpose(x->mutable_data<float>(),
-            input.mutable_data<float>(),
-            {static_cast<int>(bs),
-             static_cast<int>(ic),
-             static_cast<int>(ih),
-             static_cast<int>(iw)},
-            {0, 2, 3, 1});
+  transpose<float>(x->mutable_data<float>(),
+                   input.mutable_data<float>(),
+                   {static_cast<int>(bs),
+                    static_cast<int>(ic),
+                    static_cast<int>(ih),
+                    static_cast<int>(iw)},
+                   {0, 2, 3, 1});
   x->CopyDataFrom(input);
 
   LaunchOp(op, {x_var_name}, {out_var_name_1, out_var_name_2});
@@ -154,20 +154,20 @@ void test_split(int bs,
   Tensor output1, output2;
   output1.Resize(out_1->dims());
   output2.Resize(out_2->dims());
-  transpose(out_data_1,
-            output1.mutable_data<float>(),
-            {static_cast<int>(out_1->dims()[0]),
-             static_cast<int>(out_1->dims()[2]),
-             static_cast<int>(out_1->dims()[3]),
-             static_cast<int>(out_1->dims()[1])},
-            {0, 3, 1, 2});
-  transpose(out_data_2,
-            output2.mutable_data<float>(),
-            {static_cast<int>(out_2->dims()[0]),
-             static_cast<int>(out_2->dims()[2]),
-             static_cast<int>(out_2->dims()[3]),
-             static_cast<int>(out_2->dims()[1])},
-            {0, 3, 1, 2});
+  transpose<float>(out_data_1,
+                   output1.mutable_data<float>(),
+                   {static_cast<int>(out_1->dims()[0]),
+                    static_cast<int>(out_1->dims()[2]),
+                    static_cast<int>(out_1->dims()[3]),
+                    static_cast<int>(out_1->dims()[1])},
+                   {0, 3, 1, 2});
+  transpose<float>(out_data_2,
+                   output2.mutable_data<float>(),
+                   {static_cast<int>(out_2->dims()[0]),
+                    static_cast<int>(out_2->dims()[2]),
+                    static_cast<int>(out_2->dims()[3]),
+                    static_cast<int>(out_2->dims()[1])},
+                   {0, 3, 1, 2});
   out_data_1 = output1.mutable_data<float>();
   out_data_2 = output2.mutable_data<float>();
   for (int i = 0; i < out_1->dims().production(); i++) {
diff --git a/lite/kernels/mlu/bridges/utility.cc b/lite/kernels/mlu/bridges/utility.cc
index e04b22843a837e0b9922ebebe20de01e302b81da..f8d78f21d9e35f3fe3e12f146086f467ff111d54 100644
--- a/lite/kernels/mlu/bridges/utility.cc
+++ b/lite/kernels/mlu/bridges/utility.cc
@@ -36,31 +36,6 @@ void transpose2d(float* input_data,
   }
 }
 
-void transpose(float* input_data,
-               float* output_data,
-               std::vector<int> input_shape,
-               std::vector<int> axis) {
-  int old_index = -1;
-  int new_index = -1;
-  int dim[4] = {0};
-  std::vector<int> shape = input_shape;
-  for (dim[0] = 0; dim[0] < input_shape[0]; dim[0]++) {
-    for (dim[1] = 0; dim[1] < input_shape[1]; dim[1]++) {
-      for (dim[2] = 0; dim[2] < input_shape[2]; dim[2]++) {
-        for (dim[3] = 0; dim[3] < input_shape[3]; dim[3]++) {
-          old_index = dim[0] * shape[1] * shape[2] * shape[3] +
-                      dim[1] * shape[2] * shape[3] + dim[2] * shape[3] + dim[3];
-          new_index =
-              dim[axis[0]] * shape[axis[1]] * shape[axis[2]] * shape[axis[3]] +
-              dim[axis[1]] * shape[axis[2]] * shape[axis[3]] +
-              dim[axis[2]] * shape[axis[3]] + dim[axis[3]];
-          output_data[new_index] = input_data[old_index];
-        }
-      }
-    }
-  }
-}
-
 void dequant(float* dst, int8_t* src, size_t size, float scale) {
   for (size_t i = 0; i < size; ++i) {
     dst[i] = static_cast<float>(src[i]) * scale;
diff --git a/lite/kernels/mlu/bridges/utility.h b/lite/kernels/mlu/bridges/utility.h
index b75038d9d872528949bbffc0b0743511d9669385..78f862c0d305b1fb7db7f75f4cf72e6e322ea64b 100644
--- a/lite/kernels/mlu/bridges/utility.h
+++ b/lite/kernels/mlu/bridges/utility.h
@@ -34,15 +34,10 @@ namespace mlu {
 void transpose2d(float* input_data,
                  float* output_data,
                  std::vector<int> input_shape);
-template <typename dtype>
-void transpose(dtype input_data,
-               dtype output_data,
-               std::vector<int> input_shape,
-               std::vector<int> axis);
 
 template <typename dtype>
-void transpose(dtype input_data,
-               dtype output_data,
+void transpose(dtype* input_data,
+               dtype* output_data,
                std::vector<int> input_shape,
                std::vector<int> axis) {
   int old_index = -1;
@@ -89,11 +84,6 @@ void transpose(dtype input_data,
   }
 }
 
-void transpose(float* input_data,
-               float* output_data,
-               std::vector<int> input_shape,
-               std::vector<int> axis);
-
 inline int scale2position(float scale) { return std::floor(-std::log2(scale)); }
 
 void dequant(float* dst, int8_t* src, size_t size, float scale);
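Reviewer note: below is a minimal standalone sketch of the call pattern this patch standardizes on, i.e. the header-only `transpose` template with `dtype*` parameters invoked as `transpose<float>`, exactly as in the updated tests. The 4-D-only copy of the template and the tiny `main` are illustrative assumptions for the sketch, not code from the patch; the real definition in `lite/kernels/mlu/bridges/utility.h` also covers the other ranks exercised by `layout_op_test.cc`.

```cpp
#include <cstdio>
#include <vector>

// Illustrative 4-D-only copy of the templated transpose from utility.h.
// With dtype* parameters, deduction yields dtype = float for float buffers,
// so both transpose<float>(...) and plain transpose(...) resolve here.
template <typename dtype>
void transpose(dtype* input_data,
               dtype* output_data,
               std::vector<int> shape,
               std::vector<int> axis) {
  int dim[4] = {0};
  for (dim[0] = 0; dim[0] < shape[0]; dim[0]++) {
    for (dim[1] = 0; dim[1] < shape[1]; dim[1]++) {
      for (dim[2] = 0; dim[2] < shape[2]; dim[2]++) {
        for (dim[3] = 0; dim[3] < shape[3]; dim[3]++) {
          // Row-major offset of the element in the source layout.
          int old_index = dim[0] * shape[1] * shape[2] * shape[3] +
                          dim[1] * shape[2] * shape[3] + dim[2] * shape[3] +
                          dim[3];
          // Offset of the same element after permuting the dims by `axis`.
          int new_index =
              dim[axis[0]] * shape[axis[1]] * shape[axis[2]] * shape[axis[3]] +
              dim[axis[1]] * shape[axis[2]] * shape[axis[3]] +
              dim[axis[2]] * shape[axis[3]] + dim[axis[3]];
          output_data[new_index] = input_data[old_index];
        }
      }
    }
  }
}

int main() {
  // NCHW tensor of shape {1, 2, 2, 2} where every element stores its channel.
  std::vector<float> nchw = {0, 0, 0, 0, 1, 1, 1, 1};
  std::vector<float> nhwc(nchw.size());
  // Same call pattern as the updated tests: NCHW -> NHWC via axis {0, 2, 3, 1}.
  transpose<float>(nchw.data(), nhwc.data(), {1, 2, 2, 2}, {0, 2, 3, 1});
  for (float v : nhwc) std::printf("%.0f ", v);  // expected: 0 1 0 1 0 1 0 1
  std::printf("\n");
  return 0;
}
```

Taking `dtype*` rather than a deduced `dtype` in the template signature lets callers write `transpose<float>` instead of `transpose<float*>`, which makes the float-only overload previously declared in `utility.h` and defined in `utility.cc` redundant; removing it is what the last two hunks of this patch do.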