提交 88513fd0 编写于 作者: D dingminghui 提交者: MaxwellDing

refactor(trans): only use template transpose function

上级 2c8736c1
......@@ -100,7 +100,7 @@ void test_argmax(const std::vector<int64_t>& input_shape, int axis) {
Tensor input_x;
input_x.Resize(DDim(input_shape));
// change input layout from NCHW to NHWC
transpose<float*>(x->mutable_data<float>(),
transpose<float>(x->mutable_data<float>(),
input_x.mutable_data<float>(),
{static_cast<int>(input_shape[0]),
static_cast<int>(input_shape[1]),
......@@ -117,7 +117,7 @@ void test_argmax(const std::vector<int64_t>& input_shape, int axis) {
Tensor output_trans;
output_trans.Resize(out_shape);
// Change output layout from NHWC to NCHW
transpose<int*>(out_data,
transpose<int>(out_data,
output_trans.mutable_data<int>(),
{static_cast<int>(out_shape[0]),
static_cast<int>(out_shape[2]),
......
......@@ -93,7 +93,7 @@ void test_gather() {
Tensor input;
input.Resize({5, 4, 3, 2});
transpose<float*>(x->mutable_data<float>(),
transpose<float>(x->mutable_data<float>(),
input.mutable_data<float>(),
{static_cast<int>(5),
static_cast<int>(4),
......@@ -109,7 +109,7 @@ void test_gather() {
Tensor output;
output.Resize(out->dims());
transpose<float*>(out_data,
transpose<float>(out_data,
output.mutable_data<float>(),
{static_cast<int>(out->dims()[0]),
static_cast<int>(out->dims()[2]),
......
......@@ -50,14 +50,14 @@ void test_layout_NHWC2NCHW(std::vector<int64_t> input_shape) {
input.Resize(DDim(input_shape));
switch (input_shape.size()) {
case 2:
transpose<float*>(
transpose<float>(
x->mutable_data<float>(),
input.mutable_data<float>(),
{static_cast<int>(input_shape[0]), static_cast<int>(input_shape[1])},
{0, 1});
break;
case 3:
transpose<float*>(x->mutable_data<float>(),
transpose<float>(x->mutable_data<float>(),
input.mutable_data<float>(),
{static_cast<int>(input_shape[0]),
static_cast<int>(input_shape[2]),
......@@ -65,7 +65,7 @@ void test_layout_NHWC2NCHW(std::vector<int64_t> input_shape) {
{0, 2, 1});
break;
case 4:
transpose<float*>(x->mutable_data<float>(),
transpose<float>(x->mutable_data<float>(),
input.mutable_data<float>(),
{static_cast<int>(input_shape[0]),
static_cast<int>(input_shape[2]),
......@@ -74,7 +74,7 @@ void test_layout_NHWC2NCHW(std::vector<int64_t> input_shape) {
{0, 3, 1, 2});
break;
case 5:
transpose<float*>(x->mutable_data<float>(),
transpose<float>(x->mutable_data<float>(),
input.mutable_data<float>(),
{static_cast<int>(input_shape[0]),
static_cast<int>(input_shape[2]),
......@@ -123,14 +123,14 @@ void test_layout_NCHW2NHWC(std::vector<int64_t> input_shape) {
input.Resize(DDim(input_shape));
switch (input_shape.size()) {
case 2:
transpose<float*>(
transpose<float>(
x->mutable_data<float>(),
input.mutable_data<float>(),
{static_cast<int>(input_shape[0]), static_cast<int>(input_shape[1])},
{0, 1});
break;
case 3:
transpose<float*>(x->mutable_data<float>(),
transpose<float>(x->mutable_data<float>(),
input.mutable_data<float>(),
{static_cast<int>(input_shape[0]),
static_cast<int>(input_shape[1]),
......@@ -138,7 +138,7 @@ void test_layout_NCHW2NHWC(std::vector<int64_t> input_shape) {
{0, 2, 1});
break;
case 4:
transpose<float*>(x->mutable_data<float>(),
transpose<float>(x->mutable_data<float>(),
input.mutable_data<float>(),
{static_cast<int>(input_shape[0]),
static_cast<int>(input_shape[1]),
......@@ -147,7 +147,7 @@ void test_layout_NCHW2NHWC(std::vector<int64_t> input_shape) {
{0, 2, 3, 1});
break;
case 5:
transpose<float*>(x->mutable_data<float>(),
transpose<float>(x->mutable_data<float>(),
input.mutable_data<float>(),
{static_cast<int>(input_shape[0]),
static_cast<int>(input_shape[1]),
......
......@@ -135,7 +135,7 @@ void test_split(int bs,
Tensor input;
input.Resize({bs, ic, ih, iw});
transpose<float*>(x->mutable_data<float>(),
transpose<float>(x->mutable_data<float>(),
input.mutable_data<float>(),
{static_cast<int>(bs),
static_cast<int>(ic),
......@@ -154,14 +154,14 @@ void test_split(int bs,
Tensor output1, output2;
output1.Resize(out_1->dims());
output2.Resize(out_2->dims());
transpose<float*>(out_data_1,
transpose<float>(out_data_1,
output1.mutable_data<float>(),
{static_cast<int>(out_1->dims()[0]),
static_cast<int>(out_1->dims()[2]),
static_cast<int>(out_1->dims()[3]),
static_cast<int>(out_1->dims()[1])},
{0, 3, 1, 2});
transpose<float*>(out_data_2,
transpose<float>(out_data_2,
output2.mutable_data<float>(),
{static_cast<int>(out_2->dims()[0]),
static_cast<int>(out_2->dims()[2]),
......
......@@ -36,31 +36,6 @@ void transpose2d(float* input_data,
}
}
void transpose(float* input_data,
float* output_data,
std::vector<int> input_shape,
std::vector<int> axis) {
int old_index = -1;
int new_index = -1;
int dim[4] = {0};
std::vector<int> shape = input_shape;
for (dim[0] = 0; dim[0] < input_shape[0]; dim[0]++) {
for (dim[1] = 0; dim[1] < input_shape[1]; dim[1]++) {
for (dim[2] = 0; dim[2] < input_shape[2]; dim[2]++) {
for (dim[3] = 0; dim[3] < input_shape[3]; dim[3]++) {
old_index = dim[0] * shape[1] * shape[2] * shape[3] +
dim[1] * shape[2] * shape[3] + dim[2] * shape[3] + dim[3];
new_index =
dim[axis[0]] * shape[axis[1]] * shape[axis[2]] * shape[axis[3]] +
dim[axis[1]] * shape[axis[2]] * shape[axis[3]] +
dim[axis[2]] * shape[axis[3]] + dim[axis[3]];
output_data[new_index] = input_data[old_index];
}
}
}
}
}
void dequant(float* dst, int8_t* src, size_t size, float scale) {
for (size_t i = 0; i < size; ++i) {
dst[i] = static_cast<float>(src[i]) * scale;
......
......@@ -34,15 +34,10 @@ namespace mlu {
void transpose2d(float* input_data,
float* output_data,
std::vector<int> input_shape);
template <typename dtype>
void transpose(dtype input_data,
dtype output_data,
std::vector<int> input_shape,
std::vector<int> axis);
template <typename dtype>
void transpose(dtype input_data,
dtype output_data,
void transpose(dtype* input_data,
dtype* output_data,
std::vector<int> input_shape,
std::vector<int> axis) {
int old_index = -1;
......@@ -89,11 +84,6 @@ void transpose(dtype input_data,
}
}
void transpose(float* input_data,
float* output_data,
std::vector<int> input_shape,
std::vector<int> axis);
inline int scale2position(float scale) { return std::floor(-std::log2(scale)); }
void dequant(float* dst, int8_t* src, size_t size, float scale);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册