提交 88513fd0 编写于 作者: D dingminghui 提交者: MaxwellDing

refactor(trans): only use template transpose function

上级 2c8736c1
...@@ -100,13 +100,13 @@ void test_argmax(const std::vector<int64_t>& input_shape, int axis) { ...@@ -100,13 +100,13 @@ void test_argmax(const std::vector<int64_t>& input_shape, int axis) {
Tensor input_x; Tensor input_x;
input_x.Resize(DDim(input_shape)); input_x.Resize(DDim(input_shape));
// change input layout from NCHW to NHWC // change input layout from NCHW to NHWC
transpose<float*>(x->mutable_data<float>(), transpose<float>(x->mutable_data<float>(),
input_x.mutable_data<float>(), input_x.mutable_data<float>(),
{static_cast<int>(input_shape[0]), {static_cast<int>(input_shape[0]),
static_cast<int>(input_shape[1]), static_cast<int>(input_shape[1]),
static_cast<int>(input_shape[2]), static_cast<int>(input_shape[2]),
static_cast<int>(input_shape[3])}, static_cast<int>(input_shape[3])},
{0, 2, 3, 1}); {0, 2, 3, 1});
x->CopyDataFrom(input_x); x->CopyDataFrom(input_x);
LaunchOp(op, {x_var_name}, {out_var_name}); LaunchOp(op, {x_var_name}, {out_var_name});
...@@ -117,13 +117,13 @@ void test_argmax(const std::vector<int64_t>& input_shape, int axis) { ...@@ -117,13 +117,13 @@ void test_argmax(const std::vector<int64_t>& input_shape, int axis) {
Tensor output_trans; Tensor output_trans;
output_trans.Resize(out_shape); output_trans.Resize(out_shape);
// Change output layout from NHWC to NCHW // Change output layout from NHWC to NCHW
transpose<int*>(out_data, transpose<int>(out_data,
output_trans.mutable_data<int>(), output_trans.mutable_data<int>(),
{static_cast<int>(out_shape[0]), {static_cast<int>(out_shape[0]),
static_cast<int>(out_shape[2]), static_cast<int>(out_shape[2]),
static_cast<int>(out_shape[3]), static_cast<int>(out_shape[3]),
static_cast<int>(out_shape[1])}, static_cast<int>(out_shape[1])},
{0, 3, 1, 2}); {0, 3, 1, 2});
out_data = output_trans.mutable_data<int>(); out_data = output_trans.mutable_data<int>();
for (int i = 0; i < out->dims().production(); i++) { for (int i = 0; i < out->dims().production(); i++) {
......
...@@ -93,13 +93,13 @@ void test_gather() { ...@@ -93,13 +93,13 @@ void test_gather() {
Tensor input; Tensor input;
input.Resize({5, 4, 3, 2}); input.Resize({5, 4, 3, 2});
transpose<float*>(x->mutable_data<float>(), transpose<float>(x->mutable_data<float>(),
input.mutable_data<float>(), input.mutable_data<float>(),
{static_cast<int>(5), {static_cast<int>(5),
static_cast<int>(4), static_cast<int>(4),
static_cast<int>(3), static_cast<int>(3),
static_cast<int>(2)}, static_cast<int>(2)},
{0, 2, 3, 1}); {0, 2, 3, 1});
x->CopyDataFrom(input); x->CopyDataFrom(input);
LaunchOp(op, {x_var_name, index_var_name}, {out_var_name}); LaunchOp(op, {x_var_name, index_var_name}, {out_var_name});
...@@ -109,13 +109,13 @@ void test_gather() { ...@@ -109,13 +109,13 @@ void test_gather() {
Tensor output; Tensor output;
output.Resize(out->dims()); output.Resize(out->dims());
transpose<float*>(out_data, transpose<float>(out_data,
output.mutable_data<float>(), output.mutable_data<float>(),
{static_cast<int>(out->dims()[0]), {static_cast<int>(out->dims()[0]),
static_cast<int>(out->dims()[2]), static_cast<int>(out->dims()[2]),
static_cast<int>(out->dims()[3]), static_cast<int>(out->dims()[3]),
static_cast<int>(out->dims()[1])}, static_cast<int>(out->dims()[1])},
{0, 3, 1, 2}); {0, 3, 1, 2});
out_data = output.mutable_data<float>(); out_data = output.mutable_data<float>();
for (int i = 0; i < out->dims().production(); i++) { for (int i = 0; i < out->dims().production(); i++) {
VLOG(5) << i; VLOG(5) << i;
......
...@@ -50,38 +50,38 @@ void test_layout_NHWC2NCHW(std::vector<int64_t> input_shape) { ...@@ -50,38 +50,38 @@ void test_layout_NHWC2NCHW(std::vector<int64_t> input_shape) {
input.Resize(DDim(input_shape)); input.Resize(DDim(input_shape));
switch (input_shape.size()) { switch (input_shape.size()) {
case 2: case 2:
transpose<float*>( transpose<float>(
x->mutable_data<float>(), x->mutable_data<float>(),
input.mutable_data<float>(), input.mutable_data<float>(),
{static_cast<int>(input_shape[0]), static_cast<int>(input_shape[1])}, {static_cast<int>(input_shape[0]), static_cast<int>(input_shape[1])},
{0, 1}); {0, 1});
break; break;
case 3: case 3:
transpose<float*>(x->mutable_data<float>(), transpose<float>(x->mutable_data<float>(),
input.mutable_data<float>(), input.mutable_data<float>(),
{static_cast<int>(input_shape[0]), {static_cast<int>(input_shape[0]),
static_cast<int>(input_shape[2]), static_cast<int>(input_shape[2]),
static_cast<int>(input_shape[1])}, static_cast<int>(input_shape[1])},
{0, 2, 1}); {0, 2, 1});
break; break;
case 4: case 4:
transpose<float*>(x->mutable_data<float>(), transpose<float>(x->mutable_data<float>(),
input.mutable_data<float>(), input.mutable_data<float>(),
{static_cast<int>(input_shape[0]), {static_cast<int>(input_shape[0]),
static_cast<int>(input_shape[2]), static_cast<int>(input_shape[2]),
static_cast<int>(input_shape[3]), static_cast<int>(input_shape[3]),
static_cast<int>(input_shape[1])}, static_cast<int>(input_shape[1])},
{0, 3, 1, 2}); {0, 3, 1, 2});
break; break;
case 5: case 5:
transpose<float*>(x->mutable_data<float>(), transpose<float>(x->mutable_data<float>(),
input.mutable_data<float>(), input.mutable_data<float>(),
{static_cast<int>(input_shape[0]), {static_cast<int>(input_shape[0]),
static_cast<int>(input_shape[2]), static_cast<int>(input_shape[2]),
static_cast<int>(input_shape[3]), static_cast<int>(input_shape[3]),
static_cast<int>(input_shape[4]), static_cast<int>(input_shape[4]),
static_cast<int>(input_shape[1])}, static_cast<int>(input_shape[1])},
{0, 4, 1, 2, 3}); {0, 4, 1, 2, 3});
break; break;
default: default:
CHECK(0) << "Unsupport"; CHECK(0) << "Unsupport";
...@@ -123,38 +123,38 @@ void test_layout_NCHW2NHWC(std::vector<int64_t> input_shape) { ...@@ -123,38 +123,38 @@ void test_layout_NCHW2NHWC(std::vector<int64_t> input_shape) {
input.Resize(DDim(input_shape)); input.Resize(DDim(input_shape));
switch (input_shape.size()) { switch (input_shape.size()) {
case 2: case 2:
transpose<float*>( transpose<float>(
x->mutable_data<float>(), x->mutable_data<float>(),
input.mutable_data<float>(), input.mutable_data<float>(),
{static_cast<int>(input_shape[0]), static_cast<int>(input_shape[1])}, {static_cast<int>(input_shape[0]), static_cast<int>(input_shape[1])},
{0, 1}); {0, 1});
break; break;
case 3: case 3:
transpose<float*>(x->mutable_data<float>(), transpose<float>(x->mutable_data<float>(),
input.mutable_data<float>(), input.mutable_data<float>(),
{static_cast<int>(input_shape[0]), {static_cast<int>(input_shape[0]),
static_cast<int>(input_shape[1]), static_cast<int>(input_shape[1]),
static_cast<int>(input_shape[2])}, static_cast<int>(input_shape[2])},
{0, 2, 1}); {0, 2, 1});
break; break;
case 4: case 4:
transpose<float*>(x->mutable_data<float>(), transpose<float>(x->mutable_data<float>(),
input.mutable_data<float>(), input.mutable_data<float>(),
{static_cast<int>(input_shape[0]), {static_cast<int>(input_shape[0]),
static_cast<int>(input_shape[1]), static_cast<int>(input_shape[1]),
static_cast<int>(input_shape[2]), static_cast<int>(input_shape[2]),
static_cast<int>(input_shape[3])}, static_cast<int>(input_shape[3])},
{0, 2, 3, 1}); {0, 2, 3, 1});
break; break;
case 5: case 5:
transpose<float*>(x->mutable_data<float>(), transpose<float>(x->mutable_data<float>(),
input.mutable_data<float>(), input.mutable_data<float>(),
{static_cast<int>(input_shape[0]), {static_cast<int>(input_shape[0]),
static_cast<int>(input_shape[1]), static_cast<int>(input_shape[1]),
static_cast<int>(input_shape[2]), static_cast<int>(input_shape[2]),
static_cast<int>(input_shape[3]), static_cast<int>(input_shape[3]),
static_cast<int>(input_shape[4])}, static_cast<int>(input_shape[4])},
{0, 2, 3, 4, 1}); {0, 2, 3, 4, 1});
break; break;
default: default:
CHECK(0) << "Unsupport"; CHECK(0) << "Unsupport";
......
...@@ -135,13 +135,13 @@ void test_split(int bs, ...@@ -135,13 +135,13 @@ void test_split(int bs,
Tensor input; Tensor input;
input.Resize({bs, ic, ih, iw}); input.Resize({bs, ic, ih, iw});
transpose<float*>(x->mutable_data<float>(), transpose<float>(x->mutable_data<float>(),
input.mutable_data<float>(), input.mutable_data<float>(),
{static_cast<int>(bs), {static_cast<int>(bs),
static_cast<int>(ic), static_cast<int>(ic),
static_cast<int>(ih), static_cast<int>(ih),
static_cast<int>(iw)}, static_cast<int>(iw)},
{0, 2, 3, 1}); {0, 2, 3, 1});
x->CopyDataFrom(input); x->CopyDataFrom(input);
LaunchOp(op, {x_var_name}, {out_var_name_1, out_var_name_2}); LaunchOp(op, {x_var_name}, {out_var_name_1, out_var_name_2});
...@@ -154,20 +154,20 @@ void test_split(int bs, ...@@ -154,20 +154,20 @@ void test_split(int bs,
Tensor output1, output2; Tensor output1, output2;
output1.Resize(out_1->dims()); output1.Resize(out_1->dims());
output2.Resize(out_2->dims()); output2.Resize(out_2->dims());
transpose<float*>(out_data_1, transpose<float>(out_data_1,
output1.mutable_data<float>(), output1.mutable_data<float>(),
{static_cast<int>(out_1->dims()[0]), {static_cast<int>(out_1->dims()[0]),
static_cast<int>(out_1->dims()[2]), static_cast<int>(out_1->dims()[2]),
static_cast<int>(out_1->dims()[3]), static_cast<int>(out_1->dims()[3]),
static_cast<int>(out_1->dims()[1])}, static_cast<int>(out_1->dims()[1])},
{0, 3, 1, 2}); {0, 3, 1, 2});
transpose<float*>(out_data_2, transpose<float>(out_data_2,
output2.mutable_data<float>(), output2.mutable_data<float>(),
{static_cast<int>(out_2->dims()[0]), {static_cast<int>(out_2->dims()[0]),
static_cast<int>(out_2->dims()[2]), static_cast<int>(out_2->dims()[2]),
static_cast<int>(out_2->dims()[3]), static_cast<int>(out_2->dims()[3]),
static_cast<int>(out_2->dims()[1])}, static_cast<int>(out_2->dims()[1])},
{0, 3, 1, 2}); {0, 3, 1, 2});
out_data_1 = output1.mutable_data<float>(); out_data_1 = output1.mutable_data<float>();
out_data_2 = output2.mutable_data<float>(); out_data_2 = output2.mutable_data<float>();
for (int i = 0; i < out_1->dims().production(); i++) { for (int i = 0; i < out_1->dims().production(); i++) {
......
...@@ -36,31 +36,6 @@ void transpose2d(float* input_data, ...@@ -36,31 +36,6 @@ void transpose2d(float* input_data,
} }
} }
// Permute a dense 4-D tensor: output is the input with its dimensions
// reordered according to `axis` (output dim i corresponds to input dim
// axis[i]). Both buffers are row-major and must hold
// input_shape[0]*input_shape[1]*input_shape[2]*input_shape[3] floats.
// Note: assumes input_shape and axis each have exactly 4 entries.
void transpose(float* input_data,
               float* output_data,
               std::vector<int> input_shape,
               std::vector<int> axis) {
  const std::vector<int>& s = input_shape;
  int idx[4] = {0};
  for (idx[0] = 0; idx[0] < s[0]; ++idx[0]) {
    for (idx[1] = 0; idx[1] < s[1]; ++idx[1]) {
      for (idx[2] = 0; idx[2] < s[2]; ++idx[2]) {
        for (idx[3] = 0; idx[3] < s[3]; ++idx[3]) {
          // Row-major linear offset in the source layout.
          const int src_offset =
              ((idx[0] * s[1] + idx[1]) * s[2] + idx[2]) * s[3] + idx[3];
          // Offset in the permuted layout: strides come from the permuted
          // shape {s[axis[0]], s[axis[1]], s[axis[2]], s[axis[3]]}.
          const int dst_offset =
              ((idx[axis[0]] * s[axis[1]] + idx[axis[1]]) * s[axis[2]] +
               idx[axis[2]]) *
                  s[axis[3]] +
              idx[axis[3]];
          output_data[dst_offset] = input_data[src_offset];
        }
      }
    }
  }
}
void dequant(float* dst, int8_t* src, size_t size, float scale) { void dequant(float* dst, int8_t* src, size_t size, float scale) {
for (size_t i = 0; i < size; ++i) { for (size_t i = 0; i < size; ++i) {
dst[i] = static_cast<float>(src[i]) * scale; dst[i] = static_cast<float>(src[i]) * scale;
......
...@@ -34,15 +34,10 @@ namespace mlu { ...@@ -34,15 +34,10 @@ namespace mlu {
void transpose2d(float* input_data, void transpose2d(float* input_data,
float* output_data, float* output_data,
std::vector<int> input_shape); std::vector<int> input_shape);
template <typename dtype>
void transpose(dtype input_data,
dtype output_data,
std::vector<int> input_shape,
std::vector<int> axis);
template <typename dtype> template <typename dtype>
void transpose(dtype input_data, void transpose(dtype* input_data,
dtype output_data, dtype* output_data,
std::vector<int> input_shape, std::vector<int> input_shape,
std::vector<int> axis) { std::vector<int> axis) {
int old_index = -1; int old_index = -1;
...@@ -89,11 +84,6 @@ void transpose(dtype input_data, ...@@ -89,11 +84,6 @@ void transpose(dtype input_data,
} }
} }
void transpose(float* input_data,
float* output_data,
std::vector<int> input_shape,
std::vector<int> axis);
inline int scale2position(float scale) { return std::floor(-std::log2(scale)); } inline int scale2position(float scale) { return std::floor(-std::log2(scale)); }
void dequant(float* dst, int8_t* src, size_t size, float scale); void dequant(float* dst, int8_t* src, size_t size, float scale);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册