提交 819ccb8a 编写于 作者: Y yongqiang

fix transpose error. test=develop

上级 d6acf017
......@@ -25,26 +25,6 @@ namespace lite {
namespace kernels {
namespace arm {
// Generic N-dimensional transpose driven by precomputed strides.
//
// For each of the `count` linear output positions, the position is split into
// per-axis coordinates using `new_steps` (the output strides). Each coordinate
// is then scaled by the input stride of the source axis that
// `permute_order` maps it to (`old_steps[permute_order[axis]]`), and the sum
// is the linear input position whose element is copied to the output.
//
// count          number of output elements to write
// din            source buffer
// permute_order  for every output axis, the input axis it is taken from
// old_steps      strides of the input, indexed by input axis
// new_steps      strides of the output, indexed by output axis
// num_axes       number of entries read from the three index arrays
// dout           destination buffer; receives `count` elements
template <typename Dtype>
void trans_basic(const int count,
                 const Dtype* din,
                 const int* permute_order,
                 const int* old_steps,
                 const int* new_steps,
                 const int num_axes,
                 Dtype* dout) {
  int out_pos = 0;
  while (out_pos < count) {
    int src_pos = 0;
    int rest = out_pos;  // part of the output index not yet decomposed
    for (int axis = 0; axis < num_axes; ++axis) {
      const int out_stride = new_steps[axis];
      const int src_axis = permute_order[axis];
      const int coord = rest / out_stride;
      src_pos += coord * old_steps[src_axis];
      rest %= out_stride;
    }
    dout[out_pos] = din[src_pos];
    ++out_pos;
  }
}
template <typename Dtype>
void transpose_mat(const Dtype* din,
Dtype* dout,
......@@ -201,6 +181,61 @@ void TransposeCompute::PrepareForRun() {
_old_steps = get_stride(input->dims());
}
}
// Writes into `output` the elements of `input` with axes permuted by `axis`.
//
// axis    for every output axis i, axis[i] is the input axis it is taken from
// input   source tensor; its dims supply the input strides
// output  destination tensor; its dims must already describe the permuted
//         shape, since they are read here and never set
//
// Precondition: axis.size() must be >= 2 and <= 6. The scratch arrays below
// are fixed at 6 entries and the two outermost output axes are unrolled
// unconditionally, so other ranks index out of bounds.
template <typename Dtype>
void TransposeCompute_(const std::vector<int>& axis,
                       const lite::Tensor* input,
                       lite::Tensor* output) {
  // Read the input with the template element type (was hard-coded to float,
  // which made every non-float instantiation ill-formed).
  const Dtype* input_ptr = input->data<Dtype>();
  Dtype* output_ptr = output->mutable_data<Dtype>();

  // input and output's shape dimension must >= 2 && <= 6.
  const DDim& in_dim = input->dims();
  const DDim& out_dim = output->dims();

  // Precompute output dims and the matching input strides, both stored in
  // reversed order: index 0 is the innermost (fastest-varying) output axis.
  size_t rout_dim[6], strides[6];
  const int permute = static_cast<int>(axis.size());  // must be >= 2 && <= 6
  for (int i = 0; i < permute; ++i) {
    const int k = permute - 1 - i;
    // Input stride of axis[i] = product of the input dims that follow it.
    strides[k] = 1;
    for (int j = axis[i] + 1; j < permute; ++j) {
      strides[k] *= in_dim[j];
    }
    rout_dim[k] = out_dim[i];
  }

  // The first 2 output dimensions are unrolled into the parallel loops;
  // remain_dim is the number of elements spanned by all the other axes.
  int remain_dim = 1;
  const int out_rank = static_cast<int>(out_dim.size());
  for (int i = 2; i < out_rank; ++i) {
    remain_dim *= out_dim[i];
  }

#pragma omp parallel for collapse(2)
  for (int batch = 0; batch < out_dim[0]; ++batch) {
    for (int j = 0; j < out_dim[1]; ++j) {
      // Input offset of the first element of this (batch, j) output slice.
      size_t offset = batch * strides[permute - 1] + j * strides[permute - 2];
      Dtype* out_ptr = output_ptr + (batch * out_dim[1] + j) * remain_dim;
      // Odometer over the remaining (at most 4) output axes, innermost first.
      size_t indices[4] = {0, 0, 0, 0};
      for (int k = 0; k < remain_dim; ++k) {
        out_ptr[k] = input_ptr[offset];
        indices[0] += 1;
        offset += strides[0];
        // Propagate carries: when a counter reaches its extent, reset it,
        // bump the next axis, and rewind the offset by the full extent.
        for (int p = 0; p < permute - 3; ++p) {
          if (indices[p] == rout_dim[p]) {
            indices[p + 1] += 1;
            indices[p] = 0;
            offset += strides[p + 1];
            offset -= rout_dim[p] * strides[p];
          } else {
            break;
          }
        }
      }
    }
  }
}
// Transpose
void TransposeCompute::Run() {
auto& param = Param<operators::TransposeParam>();
......@@ -220,13 +255,7 @@ void TransposeCompute::Run() {
if (trans_mat) {
transpose_mat(din, dout, _trans_num, _trans_w, _trans_h);
} else {
trans_basic(output->numel(),
din,
param.axis.data(),
_old_steps.data(),
_new_steps.data(),
input->dims().size(),
dout);
TransposeCompute_<float>(axis, input, output);
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册