提交 cb3f16ff 编写于 作者: J jiaopu 提交者: jackzhang235

Simplified xx function

上级 5f12addc
...@@ -47,74 +47,45 @@ void transpose(dtype input_data, ...@@ -47,74 +47,45 @@ void transpose(dtype input_data,
std::vector<int> axis) { std::vector<int> axis) {
int old_index = -1; int old_index = -1;
int new_index = -1; int new_index = -1;
if (input_shape.size() == 2) { std::vector<int> shape;
int dim[2] = {0}; std::vector<int> expand_axis;
std::vector<int> shape = input_shape; if (input_shape.size() < 5) {
for (dim[0] = 0; dim[0] < input_shape[0]; dim[0]++) { for (int i = 0; i < 5 - input_shape.size(); i++) {
for (dim[1] = 0; dim[1] < input_shape[1]; dim[1]++) { shape.push_back(1);
old_index = dim[0] * shape[1] + dim[1]; expand_axis.push_back(i);
new_index = dim[axis[0]] * shape[axis[1]] + dim[axis[1]];
output_data[new_index] = input_data[old_index];
}
} }
} else if (input_shape.size() == 3) { for (int i = 0; i < input_shape.size(); i++) {
int dim[3] = {0}; shape.push_back(input_shape[i]);
std::vector<int> shape = input_shape; expand_axis.push_back(axis[i] + 5 - input_shape.size());
for (dim[0] = 0; dim[0] < input_shape[0]; dim[0]++) {
for (dim[1] = 0; dim[1] < input_shape[1]; dim[1]++) {
for (dim[2] = 0; dim[2] < input_shape[2]; dim[2]++) {
old_index = dim[0] * shape[1] * shape[2] + dim[1] * shape[2] + dim[2];
new_index = dim[axis[0]] * shape[axis[1]] * shape[axis[2]] +
dim[axis[1]] * shape[axis[2]] + dim[axis[2]];
output_data[new_index] = input_data[old_index];
}
}
} }
} else if (input_shape.size() == 4) { } else {
int dim[4] = {0}; shape = input_shape;
std::vector<int> shape = input_shape; expand_axis = axis;
for (dim[0] = 0; dim[0] < input_shape[0]; dim[0]++) { }
for (dim[1] = 0; dim[1] < input_shape[1]; dim[1]++) { int dim[5] = {0};
for (dim[2] = 0; dim[2] < input_shape[2]; dim[2]++) { for (dim[0] = 0; dim[0] < shape[0]; dim[0]++) {
for (dim[3] = 0; dim[3] < input_shape[3]; dim[3]++) { for (dim[1] = 0; dim[1] < shape[1]; dim[1]++) {
old_index = dim[0] * shape[1] * shape[2] * shape[3] + for (dim[2] = 0; dim[2] < shape[2]; dim[2]++) {
dim[1] * shape[2] * shape[3] + dim[2] * shape[3] + for (dim[3] = 0; dim[3] < shape[3]; dim[3]++) {
dim[3]; for (dim[4] = 0; dim[4] < shape[4]; dim[4]++) {
new_index = dim[axis[0]] * shape[axis[1]] * shape[axis[2]] * old_index = dim[0] * shape[1] * shape[2] * shape[3] * shape[4] +
shape[axis[3]] + dim[1] * shape[2] * shape[3] * shape[4] +
dim[axis[1]] * shape[axis[2]] * shape[axis[3]] + dim[2] * shape[3] * shape[4] + dim[3] * shape[4] +
dim[axis[2]] * shape[axis[3]] + dim[axis[3]]; dim[4];
new_index = dim[expand_axis[0]] * shape[expand_axis[1]] *
shape[expand_axis[2]] * shape[expand_axis[3]] *
shape[expand_axis[4]] +
dim[expand_axis[1]] * shape[expand_axis[2]] *
shape[expand_axis[3]] * shape[expand_axis[4]] +
dim[expand_axis[2]] * shape[expand_axis[3]] *
shape[expand_axis[4]] +
dim[expand_axis[3]] * shape[expand_axis[4]] +
dim[expand_axis[4]];
output_data[new_index] = input_data[old_index]; output_data[new_index] = input_data[old_index];
} }
} }
} }
} }
} else if (input_shape.size() == 5) {
int dim[5] = {0};
std::vector<int> shape = input_shape;
for (dim[0] = 0; dim[0] < input_shape[0]; dim[0]++) {
for (dim[1] = 0; dim[1] < input_shape[1]; dim[1]++) {
for (dim[2] = 0; dim[2] < input_shape[2]; dim[2]++) {
for (dim[3] = 0; dim[3] < input_shape[3]; dim[3]++) {
for (dim[4] = 0; dim[4] < input_shape[4]; dim[4]++) {
old_index = dim[0] * shape[1] * shape[2] * shape[3] * shape[4] +
dim[1] * shape[2] * shape[3] * shape[4] +
dim[2] * shape[3] * shape[4] + dim[3] * shape[4] +
dim[4];
new_index = dim[axis[0]] * shape[axis[1]] * shape[axis[2]] *
shape[axis[3]] * shape[axis[4]] +
dim[axis[1]] * shape[axis[2]] * shape[axis[3]] *
shape[axis[4]] +
dim[axis[2]] * shape[axis[3]] * shape[axis[4]] +
dim[axis[3]] * shape[axis[4]] + dim[axis[4]];
output_data[new_index] = input_data[old_index];
}
}
}
}
}
} else {
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册