提交 c07b215d 编写于 作者: H hong19860320

Merge branch 'incubate/lite' of http://10.87.145.36/inference/paddlelite into hongming/arm-fix

...@@ -41,15 +41,15 @@ void elementwise_add<float>(const float* dinx, const float* diny, float* dout, ...@@ -41,15 +41,15 @@ void elementwise_add<float>(const float* dinx, const float* diny, float* dout,
float32x4_t diny2 = vld1q_f32(diny_ptr + 8); float32x4_t diny2 = vld1q_f32(diny_ptr + 8);
float32x4_t diny3 = vld1q_f32(diny_ptr + 12); float32x4_t diny3 = vld1q_f32(diny_ptr + 12);
float32x4_t vsum0 = vaddq_f32(dinx0, diny0); dinx0 = vaddq_f32(dinx0, diny0);
float32x4_t vsum1 = vaddq_f32(dinx1, diny1); dinx1 = vaddq_f32(dinx1, diny1);
float32x4_t vsum2 = vaddq_f32(dinx2, diny2); dinx2 = vaddq_f32(dinx2, diny2);
float32x4_t vsum3 = vaddq_f32(dinx3, diny3); dinx3 = vaddq_f32(dinx3, diny3);
vst1q_f32(dout_ptr, vsum0); vst1q_f32(dout_ptr, dinx0);
vst1q_f32(dout_ptr + 4, vsum1); vst1q_f32(dout_ptr + 4, dinx1);
vst1q_f32(dout_ptr + 8, vsum2); vst1q_f32(dout_ptr + 8, dinx2);
vst1q_f32(dout_ptr + 12, vsum3); vst1q_f32(dout_ptr + 12, dinx3);
} }
if (remain > 0) { if (remain > 0) {
const float* dinx_ptr = dinx + (cnt << 4); const float* dinx_ptr = dinx + (cnt << 4);
...@@ -64,6 +64,69 @@ void elementwise_add<float>(const float* dinx, const float* diny, float* dout, ...@@ -64,6 +64,69 @@ void elementwise_add<float>(const float* dinx, const float* diny, float* dout,
} }
} }
template <>
void elementwise_add_axis<float>(const float* dinx, const float* diny,
float* dout, int batch, int channels,
int num) {
#pragma omp parallel for collapse(2)
for (int i = 0; i < batch; ++i) {
for (int j = 0; j < channels; ++j) {
int offset = (i * channels + j) * num;
const float* din_ptr = dinx + offset;
const float diny_data = diny[j];
float* dout_ptr = dout + offset;
int cnt = num >> 4;
int remain = num % 16;
float32x4_t rb = vdupq_n_f32(diny_data);
for (int k = 0; k < cnt; ++k) {
float32x4_t din0 = vld1q_f32(din_ptr);
float32x4_t din1 = vld1q_f32(din_ptr + 4);
float32x4_t din2 = vld1q_f32(din_ptr + 8);
float32x4_t din3 = vld1q_f32(din_ptr + 12);
din0 = vaddq_f32(din0, rb);
din1 = vaddq_f32(din1, rb);
din2 = vaddq_f32(din2, rb);
din3 = vaddq_f32(din3, rb);
vst1q_f32(dout_ptr, din0);
vst1q_f32(dout_ptr + 4, din1);
vst1q_f32(dout_ptr + 8, din2);
vst1q_f32(dout_ptr + 12, din3);
din_ptr += 16;
dout_ptr += 16;
}
if (remain >= 8) {
float32x4_t din0 = vld1q_f32(din_ptr);
float32x4_t din1 = vld1q_f32(din_ptr + 4);
din0 = vaddq_f32(din0, rb);
din1 = vaddq_f32(din1, rb);
vst1q_f32(dout_ptr, din0);
vst1q_f32(dout_ptr + 4, din1);
din_ptr += 8;
dout_ptr += 8;
remain -= 8;
}
if (remain >= 4) {
float32x4_t din0 = vld1q_f32(din_ptr);
din0 = vaddq_f32(din0, rb);
vst1q_f32(dout_ptr, din0);
din_ptr += 4;
dout_ptr += 4;
remain -= 4;
}
if (remain > 0) {
for (int p = 0; p < remain; p++) {
*dout_ptr = *din_ptr + diny_data;
dout_ptr++;
din_ptr++;
}
}
}
}
}
} // namespace math } // namespace math
} // namespace arm } // namespace arm
} // namespace lite } // namespace lite
......
...@@ -22,6 +22,10 @@ namespace math { ...@@ -22,6 +22,10 @@ namespace math {
template <typename T> template <typename T>
void elementwise_add(const T* dinx, const T* diny, T* dout, int num); void elementwise_add(const T* dinx, const T* diny, T* dout, int num);
template <typename T>
void elementwise_add_axis(const T* dinx, const T* diny, T* dout, int batch,
int channels, int num);
} // namespace math } // namespace math
} // namespace arm } // namespace arm
} // namespace lite } // namespace lite
......
...@@ -100,7 +100,7 @@ void ConvCompute::Run() { ...@@ -100,7 +100,7 @@ void ConvCompute::Run() {
REGISTER_LITE_KERNEL(conv2d, kARM, kFloat, kNCHW, REGISTER_LITE_KERNEL(conv2d, kARM, kFloat, kNCHW,
paddle::lite::kernels::arm::ConvCompute, def) paddle::lite::kernels::arm::ConvCompute, def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))}) // .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Filter", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("Filter", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Output", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Output", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize(); .Finalize();
...@@ -108,7 +108,7 @@ REGISTER_LITE_KERNEL(conv2d, kARM, kFloat, kNCHW, ...@@ -108,7 +108,7 @@ REGISTER_LITE_KERNEL(conv2d, kARM, kFloat, kNCHW,
REGISTER_LITE_KERNEL(depthwise_conv2d, kARM, kFloat, kNCHW, REGISTER_LITE_KERNEL(depthwise_conv2d, kARM, kFloat, kNCHW,
paddle::lite::kernels::arm::ConvCompute, def) paddle::lite::kernels::arm::ConvCompute, def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))}) // .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Filter", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("Filter", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Output", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Output", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize(); .Finalize();
...@@ -25,8 +25,31 @@ void ElementwiseAddCompute::Run() { ...@@ -25,8 +25,31 @@ void ElementwiseAddCompute::Run() {
const float* x_data = param.X->data<float>(); const float* x_data = param.X->data<float>();
const float* y_data = param.Y->data<float>(); const float* y_data = param.Y->data<float>();
float* out_data = param.Out->mutable_data<float>(); float* out_data = param.Out->mutable_data<float>();
int n = param.X->dims().production(); int axis = param.axis;
// lite::arm::math::elementwise_add(x_data, y_data, out_data, n); auto x_dims = param.X->dims();
auto y_dims = param.Y->dims();
if (axis < 0) {
axis = x_dims.size() - y_dims.size();
}
if (x_dims.size() == y_dims.size()) {
lite::arm::math::elementwise_add(x_data, y_data, out_data,
x_dims.production());
} else {
int batch = 1;
int channels = 1;
int num = 1;
for (int i = 0; i < axis; ++i) {
batch *= x_dims[i];
}
for (int i = 0; i < y_dims.size(); ++i) {
channels *= y_dims[i];
}
for (int i = y_dims.size() + axis; i < x_dims.size(); ++i) {
num *= x_dims[i];
}
lite::arm::math::elementwise_add_axis(x_data, y_data, out_data, batch,
channels, num);
}
} }
} // namespace arm } // namespace arm
......
...@@ -41,40 +41,97 @@ void elementwise_add_compute_ref(const operators::ElementwiseParam& param) { ...@@ -41,40 +41,97 @@ void elementwise_add_compute_ref(const operators::ElementwiseParam& param) {
const dtype* x_data = param.X->data<const dtype>(); const dtype* x_data = param.X->data<const dtype>();
const dtype* y_data = param.Y->data<const dtype>(); const dtype* y_data = param.Y->data<const dtype>();
dtype* out_data = param.Out->mutable_data<dtype>(); dtype* out_data = param.Out->mutable_data<dtype>();
DDim dim = param.X->dims(); auto x_dims = param.X->dims();
ASSERT_EQ(dim.data(), param.Out->dims().data()); auto y_dims = param.Y->dims();
for (int i = 0; i < dim.production(); i++) { int axis = param.axis;
out_data[i] = x_data[i] + y_data[i]; if (axis < 0) {
axis = x_dims.size() - y_dims.size();
}
int batch = 1;
int channels = 1;
int num = 1;
for (int i = 0; i < axis; ++i) {
batch *= x_dims[i];
}
for (int i = 0; i < y_dims.size(); ++i) {
channels *= y_dims[i];
}
for (int i = y_dims.size() + axis; i < x_dims.size(); ++i) {
num *= x_dims[i];
}
for (int i = 0; i < batch; ++i) {
for (int j = 0; j < channels; ++j) {
int offset = (i * channels + j) * num;
const dtype* din_ptr = x_data + offset;
const dtype diny_data = y_data[j];
dtype* dout_ptr = out_data + offset;
for (int k = 0; k < num; ++k) {
*dout_ptr = *din_ptr + diny_data;
dout_ptr++;
din_ptr++;
}
}
} }
} }
TEST(elementwise_add, compute) { TEST(elementwise_add, compute) {
ElementwiseAddCompute elementwise_add; ElementwiseAddCompute elementwise_add;
operators::ElementwiseParam param; operators::ElementwiseParam param;
lite::Tensor x, y, output, output_ref;
lite::Tensor x, y, out, out_ref; for (auto n : {1, 3, 4, 11}) {
x.Resize(DDim(std::vector<int64_t>({2, 3, 4, 5}))); for (auto c : {1, 3, 4, 11}) {
y.Resize(DDim(std::vector<int64_t>({2, 3, 4, 5}))); for (auto h : {1, 3, 4, 11}) {
out.Resize(DDim(std::vector<int64_t>({2, 3, 4, 5}))); for (auto w : {1, 3, 4, 11}) {
out_ref.Resize(DDim(std::vector<int64_t>({2, 3, 4, 5}))); for (auto axis : {-1, 0, 1, 2, 3}) {
auto* x_data = x.mutable_data<float>(); for (auto yd :
auto* y_data = y.mutable_data<float>(); {std::vector<int64_t>({n}), std::vector<int64_t>({c}),
auto* out_data = out.mutable_data<float>(); std::vector<int64_t>({h}), std::vector<int64_t>({w}),
auto* out_ref_data = out_ref.mutable_data<float>(); std::vector<int64_t>({n, c}), std::vector<int64_t>({c, h}),
for (int i = 0; i < x.dims().production(); i++) { std::vector<int64_t>({h, w}), std::vector<int64_t>({n, c, h}),
x_data[i] = y_data[i] = i; std::vector<int64_t>({c, h, w}),
} std::vector<int64_t>({n, c, h, w})}) {
auto x_dim = DDim(std::vector<int64_t>({n, c, h, w}));
auto y_dim = DDim(yd);
int axis_t = axis < 0 ? x_dim.size() - y_dim.size() : axis;
param.X = &x; if (axis_t + y_dim.size() > 4) continue;
param.Y = &y; bool flag = false;
param.Out = &out; for (int i = 0; i < y_dim.size(); i++) {
elementwise_add.SetParam(param); if (x_dim[i + axis_t] != y_dim[i]) flag = true;
elementwise_add.Run(); }
if (flag) continue;
param.Out = &out_ref; x.Resize(x_dim);
elementwise_add_compute_ref<float>(param); y.Resize(y_dim);
for (int i = 0; i < out.dims().production(); i++) { output.Resize(x_dim);
EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-5); output_ref.Resize(x_dim);
auto* x_data = x.mutable_data<float>();
auto* y_data = y.mutable_data<float>();
auto* output_data = output.mutable_data<float>();
auto* output_ref_data = output_ref.mutable_data<float>();
for (int i = 0; i < x_dim.production(); i++) {
x_data[i] = i;
}
for (int i = 0; i < y_dim.production(); i++) {
y_data[i] = i;
}
param.X = &x;
param.Y = &y;
param.axis = axis;
param.Out = &output;
elementwise_add.SetParam(param);
elementwise_add.Run();
param.Out = &output_ref;
elementwise_add_compute_ref<float>(param);
for (int i = 0; i < output.dims().production(); i++) {
EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-5);
}
}
}
}
}
}
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册