提交 92b6051d 编写于 作者: Z zhupy

fix elementwise_add arm kernel and unit test

test=develop
上级 f8fcc594
...@@ -41,15 +41,15 @@ void elementwise_add<float>(const float* dinx, const float* diny, float* dout, ...@@ -41,15 +41,15 @@ void elementwise_add<float>(const float* dinx, const float* diny, float* dout,
float32x4_t diny2 = vld1q_f32(diny_ptr + 8); float32x4_t diny2 = vld1q_f32(diny_ptr + 8);
float32x4_t diny3 = vld1q_f32(diny_ptr + 12); float32x4_t diny3 = vld1q_f32(diny_ptr + 12);
float32x4_t vsum0 = vaddq_f32(dinx0, diny0); float32x4_t dinx0 = vaddq_f32(dinx0, diny0);
float32x4_t vsum1 = vaddq_f32(dinx1, diny1); float32x4_t dinx1 = vaddq_f32(dinx1, diny1);
float32x4_t vsum2 = vaddq_f32(dinx2, diny2); float32x4_t dinx2 = vaddq_f32(dinx2, diny2);
float32x4_t vsum3 = vaddq_f32(dinx3, diny3); float32x4_t dinx3 = vaddq_f32(dinx3, diny3);
vst1q_f32(dout_ptr, vsum0); vst1q_f32(dout_ptr, dinx0);
vst1q_f32(dout_ptr + 4, vsum1); vst1q_f32(dout_ptr + 4, dinx1);
vst1q_f32(dout_ptr + 8, vsum2); vst1q_f32(dout_ptr + 8, dinx2);
vst1q_f32(dout_ptr + 12, vsum3); vst1q_f32(dout_ptr + 12, dinx3);
} }
if (remain > 0) { if (remain > 0) {
const float* dinx_ptr = dinx + (cnt << 4); const float* dinx_ptr = dinx + (cnt << 4);
...@@ -64,6 +64,69 @@ void elementwise_add<float>(const float* dinx, const float* diny, float* dout, ...@@ -64,6 +64,69 @@ void elementwise_add<float>(const float* dinx, const float* diny, float* dout,
} }
} }
template <>
void elementwise_add_axis<float>(const float* dinx, const float* diny,
float* dout, int batch, int channels,
int num) {
#pragma omp parallel for collapse(2)
for (int i = 0; i < batch; ++i) {
for (int j = 0; j < channels; ++j) {
int offset = (i * channels + j) * num;
const float* din_ptr = dinx + offset;
const float diny_data = diny[j];
float* dout_ptr = dout + offset;
int cnt = num >> 4;
int remain = num % 16;
float32x4_t rb = vdupq_n_f32(diny_data);
for (int k = 0; k < cnt; ++k) {
float32x4_t din0 = vld1q_f32(din_ptr);
float32x4_t din1 = vld1q_f32(din_ptr + 4);
float32x4_t din2 = vld1q_f32(din_ptr + 8);
float32x4_t din3 = vld1q_f32(din_ptr + 12);
din0 = vaddq_f32(din0, rb);
din1 = vaddq_f32(din1, rb);
din2 = vaddq_f32(din2, rb);
din3 = vaddq_f32(din3, rb);
vst1q_f32(dout_ptr, din0);
vst1q_f32(dout_ptr + 4, din1);
vst1q_f32(dout_ptr + 8, din2);
vst1q_f32(dout_ptr + 12, din3);
din_ptr += 16;
dout_ptr += 16;
}
if (remain >= 8) {
float32x4_t din0 = vld1q_f32(din_ptr);
float32x4_t din1 = vld1q_f32(din_ptr + 4);
din0 = vaddq_f32(din0, diny_data);
din1 = vaddq_f32(din1, diny_data);
vst1q_f32(dout_ptr, r0);
vst1q_f32(dout_ptr + 4, r1);
din_ptr += 8;
dout_ptr += 8;
remain -= 8;
}
if (remain >= 4) {
float32x4_t din0 = vld1q_f32(din_ptr);
din0 = vaddq_f32(din0, rb);
vst1q_f32(dout_ptr, diny_data);
din_ptr += 4;
dout_ptr += 4;
remain -= 4;
}
if (remain > 0) {
for (p = 0; p < remain; p++) {
*dout_ptr = *dinx_ptr + diny_data;
dout_ptr++;
dinx_ptr++;
}
}
}
}
}
} // namespace math } // namespace math
} // namespace arm } // namespace arm
} // namespace lite } // namespace lite
......
...@@ -22,6 +22,10 @@ namespace math { ...@@ -22,6 +22,10 @@ namespace math {
template <typename T> template <typename T>
void elementwise_add(const T* dinx, const T* diny, T* dout, int num); void elementwise_add(const T* dinx, const T* diny, T* dout, int num);
template <typename T>
void elementwise_add_axis<float>(const T* dinx, const T* diny, T* dout,
int batch, int channels, int num);
} // namespace math } // namespace math
} // namespace arm } // namespace arm
} // namespace lite } // namespace lite
......
...@@ -25,8 +25,31 @@ void ElementwiseAddCompute::Run() { ...@@ -25,8 +25,31 @@ void ElementwiseAddCompute::Run() {
const float* x_data = param.X->data<float>(); const float* x_data = param.X->data<float>();
const float* y_data = param.Y->data<float>(); const float* y_data = param.Y->data<float>();
float* out_data = param.Out->mutable_data<float>(); float* out_data = param.Out->mutable_data<float>();
int n = param.X->dims().production(); int axis = param.axis;
lite::arm::math::elementwise_add(x_data, y_data, out_data, n); auto x_dims = param.X->dims();
auto y_dims = param.Y->dims();
if (axis < 0) {
axis = x_dims.size() - y_dims.size();
}
if (axis == 0) {
lite::arm::math::elementwise_add(x_data, y_data, out_data,
x_dims.production());
} else {
int batch = 1;
int channels = 1;
int num = 1;
for (int i = 0; i < axis; ++i) {
batch *= x_dims[i];
}
for (int i = 0; i < y_dims.size(); ++i) {
channels *= y_dims[i];
}
for (int i = y_dims.size() + axis; i < x_dims.size(); ++i) {
num *= x_dims[i];
}
lite::arm::math::elementwise_add_axis(x_data, y_data, out_data, batch,
channels, num);
}
} }
} // namespace arm } // namespace arm
......
...@@ -41,40 +41,94 @@ void elementwise_add_compute_ref(const operators::ElementwiseParam& param) { ...@@ -41,40 +41,94 @@ void elementwise_add_compute_ref(const operators::ElementwiseParam& param) {
const dtype* x_data = param.X->data<const dtype>(); const dtype* x_data = param.X->data<const dtype>();
const dtype* y_data = param.Y->data<const dtype>(); const dtype* y_data = param.Y->data<const dtype>();
dtype* out_data = param.Out->mutable_data<dtype>(); dtype* out_data = param.Out->mutable_data<dtype>();
DDim dim = param.X->dims(); auto x_dims = param.X->dims();
ASSERT_EQ(dim.data(), param.Out->dims().data()); auto y_dims = param.Y->dims();
for (int i = 0; i < dim.production(); i++) { int axis = param.axis;
out_data[i] = x_data[i] + y_data[i]; int batch = 1;
int channels = 1;
int num = 1;
for (int i = 0; i < axis; ++i) {
batch *= x_dims[i];
}
for (int i = 0; i < y_dims.size(); ++i) {
channels *= y_dims[i];
}
for (int i = y_dims.size() + axis; i < x_dims.size(); ++i) {
num *= x_dims[i];
}
for (int i = 0; i < batch; ++i) {
for (int j = 0; j < channels; ++j) {
int offset = (i * channels + j) * num;
const dtype* din_ptr = x_data + offset;
const dtype diny_data = y_data[j];
dtype* dout_ptr = dout + offset;
for (int k = 0; k < num; ++k) {
dout_ptr[k] = din_ptr[k] + diny_data;
}
}
} }
} }
TEST(elementwise_add, compute) { TEST(elementwise_add, compute) {
ElementwiseAddCompute elementwise_add; ElementwiseAddCompute elementwise_add;
operators::ElementwiseParam param; operators::ElementwiseParam param;
lite::Tensor x, y, output, output_ref;
lite::Tensor x, y, out, out_ref; for (auto n : {1, 3, 4, 11}) {
x.Resize(DDim(std::vector<int64_t>({2, 3, 4, 5}))); for (auto c : {1, 3, 4, 11}) {
y.Resize(DDim(std::vector<int64_t>({2, 3, 4, 5}))); for (auto h : {1, 3, 4, 11}) {
out.Resize(DDim(std::vector<int64_t>({2, 3, 4, 5}))); for (auto w : {1, 3, 4, 11}) {
out_ref.Resize(DDim(std::vector<int64_t>({2, 3, 4, 5}))); for (auto axis : {-1, 0, 1, 2, 3}) {
for (auto yd{{n},
{c},
{h},
{w},
{n, c},
{c, h},
{h, w},
{n, c, h},
{c, h, w},
{n, c, h, w}}) {
auto x_dim = DDim(std::vector<int64_t>({n, c, h, w}));
auto y_dim = DDim(std::vector<int64_t>(yd));
int axis_t = axis < 0 ? x_dim.size() - y_dim.size() : axis;
if (axis_t + y_dim.size() > 4) continue;
bool flag = false;
for (int i = 0; i < y_dim.size(); i++) {
if (x_dim[i + axis_t] != y_dim[i]) flag = true;
}
if (flag) continue;
x.Resize(x_dim);
y.Resize(y_dim);
output.Resize(DDim(std::vector<int64_t>({n, c, h, w})));
output_ref.Resize(DDim(std::vector<int64_t>({n, c, h, w})));
auto* x_data = x.mutable_data<float>(); auto* x_data = x.mutable_data<float>();
auto* y_data = y.mutable_data<float>(); auto* output_data = output.mutable_data<float>();
auto* out_data = out.mutable_data<float>(); auto* output_ref_data = output_ref.mutable_data<float>();
auto* out_ref_data = out_ref.mutable_data<float>();
for (int i = 0; i < x.dims().production(); i++) { for (int i = 0; i < x.dims().production(); i++) {
x_data[i] = y_data[i] = i; x_data[i] = i;
}
for (int i = 0; i < y.dims().production(); i++) {
y_data[i] = i;
} }
param.X = &x; param.X = &x;
param.Y = &y; param.Y = &y;
param.Out = &out; param.axis = axis;
elementwise_add.SetParam(param); param.Out = &output;
elementwise_add.Run(); softmax.SetParam(param);
softmax.Run();
param.Out = &out_ref; param.Out = &output_ref;
elementwise_add_compute_ref<float>(param); elementwise_add_compute_ref<float>(param);
for (int i = 0; i < out.dims().production(); i++) { for (int i = 0; i < out.dims().production(); i++) {
EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-5); EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-5);
}
}
}
}
}
}
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册