未验证 提交 4ddd9a98 编写于 作者: myq406450149's avatar myq406450149 提交者: GitHub

[ARM]change arm transpose kernel. develop=test (#3973)

* add op affine_grid. test=develop

* fix format. test=develop

* change arm  transpose kernel. develop=test
上级 89d573e1
......@@ -25,135 +25,209 @@ namespace lite {
namespace kernels {
namespace arm {
bool IsShuffleChannel(const std::vector<int> &axis) {
bool is_shuffle_channel = true;
if (axis.size() > 2 && axis[0] == 0 && axis[1] == 2 && axis[2] == 1) {
for (int i = 3; i < axis.size(); ++i) {
if (axis[i] != i) {
is_shuffle_channel = false;
break;
}
template <typename Dtype>
void trans_basic(const int count,
const Dtype* din,
const int* permute_order,
const int* old_steps,
const int* new_steps,
const int num_axes,
Dtype* dout) {
for (int i = 0; i < count; ++i) {
int old_idx = 0;
int idx = i;
for (int j = 0; j < num_axes; ++j) {
int order = permute_order[j];
old_idx += (idx / new_steps[j]) * old_steps[order];
idx %= new_steps[j];
}
} else {
return false;
dout[i] = din[old_idx];
}
return is_shuffle_channel;
}
template <typename Dtype>
void ShuffleChannelCompute(const std::vector<int> &axis,
const lite::Tensor *input,
lite::Tensor *output) {
const Dtype *input_ptr = input->data<Dtype>();
Dtype *output_ptr = output->mutable_data<Dtype>();
// input and output's shape dimension must >= 2 && <= 6.
const DDim &in_dim = input->dims();
const DDim &out_dim = output->dims();
size_t offset = 1;
for (int i = 3; i < axis.size(); ++i) {
offset *= in_dim[i];
}
void transpose_mat(const Dtype* din,
Dtype* dout,
const int num,
const int width,
const int height);
void transpose_mat(const float* din,
float* dout,
const int num,
const int width,
const int height) {
int nw = width >> 2;
int nh = height >> 2;
int size_in = width * height;
#pragma omp parallel for collapse(3)
for (int batch = 0; batch < out_dim[0]; ++batch) {
for (int c1 = 0; c1 < out_dim[1]; ++c1) {
for (int c2 = 0; c2 < out_dim[2]; ++c2) {
size_t out_offset =
((batch * out_dim[1] + c1) * out_dim[2] + c2) * offset;
size_t in_offset = ((batch * in_dim[1] + c2) * in_dim[2] + c1) * offset;
memcpy(output_ptr + out_offset,
input_ptr + in_offset,
offset * sizeof(Dtype));
for (int i = 0; i < num; ++i) {
float* ptr_out = dout + i * size_in;
const float* ptr_in = din + i * size_in;
#pragma omp parallel for
for (int h = 0; h < nh; h++) {
const float* ptr_din_row = ptr_in + h * 4 * width;
for (int w = 0; w < nw; w++) {
float* data_out_ptr = ptr_out + w * 4 * height + h * 4;
const float* din0 = ptr_din_row;
const float* din1 = din0 + width;
const float* din2 = din1 + width;
const float* din3 = din2 + width;
float* dout0 = data_out_ptr;
float* dout1 = dout0 + height;
float* dout2 = dout1 + height;
float* dout3 = dout2 + height;
#ifdef __aarch64__
float32x4_t vr0 = vld1q_f32(din0);
float32x4_t vr1 = vld1q_f32(din1);
float32x4_t vr2 = vld1q_f32(din2);
float32x4_t vr3 = vld1q_f32(din3);
float32x4_t re0 = vtrn1q_f32(vr0, vr1);
float32x4_t re1 = vtrn2q_f32(vr0, vr1);
float32x4_t re2 = vtrn1q_f32(vr2, vr3);
float32x4_t re3 = vtrn2q_f32(vr2, vr3);
vst1_f32(dout0, vget_low_f32(re0));
dout0 += 2;
vst1_f32(dout0, vget_low_f32(re2));
vst1_f32(dout1, vget_low_f32(re1));
dout1 += 2;
vst1_f32(dout1, vget_low_f32(re3));
vst1_f32(dout2, vget_high_f32(re0));
dout2 += 2;
vst1_f32(dout2, vget_high_f32(re2));
vst1_f32(dout3, vget_high_f32(re1));
dout3 += 2;
vst1_f32(dout3, vget_high_f32(re3));
#else
asm("vld1.32 {d0, d1}, [%[in0]] \n"
"vld1.32 {d2, d3}, [%[in1]] \n"
"vld1.32 {d4, d5}, [%[in2]] \n"
"vld1.32 {d6, d7}, [%[in3]] \n"
"vtrn.32 q0, q1 \n"
"vtrn.32 q2, q3 \n"
"vswp d1, d4 \n"
"vswp d3, d6 \n"
"vst1.32 {d0, d1}, [%[out0]] \n"
"vst1.32 {d2, d3}, [%[out1]] \n"
"vst1.32 {d4, d5}, [%[out2]] \n"
"vst1.32 {d6, d7}, [%[out3]] \n"
:
: [out0] "r"(dout0),
[out1] "r"(dout1),
[out2] "r"(dout2),
[out3] "r"(dout3),
[in0] "r"(din0),
[in1] "r"(din1),
[in2] "r"(din2),
[in3] "r"(din3)
: "q0", "q1", "q2", "q3");
#endif
ptr_din_row += 4;
}
}
// remian
for (int h = 0; h < height; h++) {
for (int w = nw * 4; w < width; w++) {
const float* data_in_ptr = ptr_in + h * width + w;
float* data_out_ptr = ptr_out + w * height + h;
*data_out_ptr = *data_in_ptr;
}
}
for (int w = 0; w < width; w++) {
for (int h = nh * 4; h < height; h++) {
const float* data_in_ptr = ptr_in + h * width + w;
float* data_out_ptr = ptr_out + w * height + h;
*data_out_ptr = *data_in_ptr;
}
}
}
}
template <typename Dtype>
void TransposeCompute_(const std::vector<int> &axis,
const lite::Tensor *input,
lite::Tensor *output) {
// const Dtype *input_ptr = input->data<Dtype>();
const Dtype *input_ptr = input->data<float>();
Dtype *output_ptr = output->mutable_data<Dtype>();
// input and output's shape dimension must >= 2 && <= 6.
const DDim &in_dim = input->dims();
const DDim &out_dim = output->dims();
// precompute inverted output dim and strides
size_t rout_dim[6], strides[6];
int permute = axis.size(); // permute must >=2 && <= 6.
for (int i = 0; i < permute; ++i) {
int k = permute - 1 - i;
strides[k] = 1;
for (int j = axis[i] + 1; j < permute; ++j) {
strides[k] *= in_dim[j];
std::vector<int> get_stride(const paddle::lite::DDimLite& dims) {
std::vector<int> data_stride{0};
for (int i = 0; i < dims.size(); ++i) {
data_stride.push_back(dims.count(i, dims.size()));
}
return data_stride;
}
void TransposeCompute::PrepareForRun() {
auto& param = Param<operators::TransposeParam>();
auto* input = param.x;
auto* output = param.output;
int _num_axes = input->dims().size();
CHECK(_num_axes == param.axis.size())
<< "axis size is not match to input dims";
need_trans = false;
for (int i = 0; i < _num_axes; ++i) {
if (param.axis[i] != i) {
need_trans = true;
break;
}
rout_dim[k] = out_dim[i];
}
// unroll the first 2 dimensions
int reamin_dim = 1;
for (int i = 2; i < out_dim.size(); ++i) {
reamin_dim *= out_dim[i];
if (!need_trans) {
return;
}
#pragma omp parallel for collapse(2)
for (int batch = 0; batch < out_dim[0]; ++batch) {
for (int j = 0; j < out_dim[1]; ++j) {
size_t offset = batch * strides[permute - 1] + j * strides[permute - 2];
Dtype *out_ptr = output_ptr + (batch * out_dim[1] + j) * reamin_dim;
int indics[4] = {0, 0, 0, 0};
for (int k = 0; k < reamin_dim; ++k) {
out_ptr[k] = input_ptr[offset];
indics[0] += 1;
offset += strides[0];
for (int p = 0; p < permute - 3; ++p) {
if (indics[p] == rout_dim[p]) {
indics[p + 1] += 1;
indics[p] = 0;
offset += strides[p + 1];
offset -= rout_dim[p] * strides[p];
} else {
break;
}
}
}
std::vector<int> axis_diff;
int j = 0;
for (int i = 0; i < _num_axes; ++i) {
if (param.axis[j] != i) {
axis_diff.push_back(j);
} else {
j++;
}
}
}
for (int i = 0; i < axis_diff.size(); i++) {
}
if (input->dims().count(axis_diff[0], _num_axes) == 1) {
need_trans = false;
return;
}
// Transpose
void TransposeCompute::Run() {
auto &param = Param<operators::TransposeParam>();
auto *input = param.x;
auto *output = param.output;
const std::vector<int> axis = param.axis;
if (axis_diff.size() == 1) {
trans_mat = true;
_trans_num = input->dims().count(0, std::max(axis_diff[0], 0));
_trans_w = input->dims().count(axis_diff[0] + 1, _num_axes);
_trans_h = input->dims()[axis_diff[0]];
bool shuffle_channel = IsShuffleChannel(axis);
if (shuffle_channel) {
ShuffleChannelCompute<float>(axis, input, output);
} else {
TransposeCompute_<float>(axis, input, output);
trans_mat = false;
_new_steps = get_stride(output->dims());
_old_steps = get_stride(input->dims());
}
return;
}
// Transpose2
void Transpose2Compute::Run() {
auto &param = Param<operators::TransposeParam>();
auto *input = param.x;
auto *output = param.output;
// Transpose
void TransposeCompute::Run() {
auto& param = Param<operators::TransposeParam>();
auto* input = param.x;
auto* output = param.output;
const std::vector<int> axis = param.axis;
bool shuffle_channel = IsShuffleChannel(axis);
if (shuffle_channel) {
ShuffleChannelCompute<float>(axis, input, output);
//! only copy the data
if (!need_trans) {
output->CopyDataFrom(*input);
return;
}
const float* din = static_cast<const float*>(input->data<float>());
float* dout = static_cast<float*>(output->mutable_data<float>());
//! transpose the data
if (trans_mat) {
transpose_mat(din, dout, _trans_num, _trans_w, _trans_h);
} else {
TransposeCompute_<float>(axis, input, output);
trans_basic(output->numel(),
din,
param.axis.data(),
_old_steps.data(),
_new_steps.data(),
input->dims().size(),
dout);
}
return;
}
} // namespace arm
......
......@@ -14,6 +14,7 @@
#pragma once
#include <algorithm>
#include <vector>
#include "lite/core/kernel.h"
#include "lite/operators/transpose_op.h"
......@@ -26,19 +27,24 @@ namespace arm {
class TransposeCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
using param_t = operators::TransposeParam;
void PrepareForRun() override;
void Run() override;
virtual ~TransposeCompute() = default;
private:
bool need_trans = false;
bool trans_mat = false;
int _trans_num;
int _trans_w;
int _trans_h;
std::vector<int> _new_steps;
std::vector<int> _old_steps;
};
// Transpose2
class Transpose2Compute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
class Transpose2Compute : public TransposeCompute {
public:
using param_t = operators::TransposeParam;
void Run() override;
virtual ~Transpose2Compute() = default;
};
......
......@@ -107,6 +107,7 @@ TEST(transpose_arm, compute_shape_nchw) {
// run transpose_compute
transpose.SetParam(param);
transpose.PrepareForRun();
transpose.Run();
// run transpose_compute_ref
......@@ -173,6 +174,7 @@ TEST(transpose2_arm, compute_shape_nchw) {
// run transpose_compute
transpose2.SetParam(param);
transpose2.PrepareForRun();
transpose2.Run();
// run transpose_compute_ref
......@@ -183,8 +185,8 @@ TEST(transpose2_arm, compute_shape_nchw) {
auto* output_ref_data = output_ref.data<float>();
for (int i = 0;
i < input_shape[0] * input_shape[1] * input_shape[2] * input_shape[3];
i += 4) {
EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-5);
i += 1) {
EXPECT_NEAR(output_data[i], output_ref_data[i], 0);
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册