提交 8c572b39 编写于 作者: T tensor-tang

enable more conv impls

上级 aa55112f
......@@ -11,9 +11,20 @@ cc_library(math_arm SRCS
scale.cc
elementwise.cc
sgemv.cc
type_trans.cpp
conv_impl.cc
conv_direct_3x3s1.cc
conv_direct_3x3s2.cc
conv_direct.cc
conv_depthwise_3x3_int7.cc
conv_depthwise_3x3_int8.cc
conv_depthwise_5x5s1_int8.cc
conv_depthwise_3x3p0.cc
conv_depthwise_3x3p1.cc
conv_depthwise_5x5s1.cc
conv_depthwise_5x5s2.cc
conv_depthwise.cc
conv_gemmlike.cc
conv_winograd_3x3.cc
DEPS ${lite_kernel_deps} eigen3)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/arm/math/saturate.h"
#include <arm_neon.h>
namespace paddle {
namespace lite {
namespace arm {
namespace math {
template <typename dtype>
void int32_to_dtype(const int* din, dtype* dout, const float* scale,
int axis_size, long long outer_size, long long inner_size);
void fp32_to_int8(const float* din, signed char* dout, const float* scale,
int axis_size, long long outer_size, long long inner_size) {
int cnt = inner_size / 16;
int remain = inner_size & 15;
long long loop_size = outer_size * axis_size;
#pragma omp parallel for
for (int j = 0; j < loop_size; ++j) {
float inv_scale = 1.f / scale[j % axis_size];
float32x4_t vzero = vdupq_n_f32(0.f);
float32x4_t vscale = vdupq_n_f32(inv_scale);
float32x4_t vpoff = vdupq_n_f32(0.5f);
float32x4_t vnoff = vdupq_n_f32(-0.5f);
const float* din_c = din + j * inner_size;
signed char* dout_c = dout + j * inner_size;
if (cnt > 0) {
int cnt_loop = cnt;
const float* din_ptr = din_c;
signed char* dout_ptr = dout_c;
#ifdef __aarch64__
asm volatile(
"ldp q0, q1, [%[in]], #32 \n"
"ldp q2, q3, [%[in]], #32 \n"
"0: \n" /* main loop */
"fmul v4.4s, v0.4s, %[scale].4s \n"
"fmul v5.4s, v1.4s, %[scale].4s \n"
"fmul v6.4s, v2.4s, %[scale].4s \n"
"fmul v7.4s, v3.4s, %[scale].4s \n"
"ldp q0, q1, [%[in]], #32 \n"
"subs %[cnt], %[cnt], #1 \n"
"FCVTAS v8.4s, v4.4s \n"
"FCVTAS v9.4s, v5.4s \n"
"FCVTAS v10.4s, v6.4s \n"
"FCVTAS v11.4s, v7.4s \n"
"ldp q2, q3, [%[in]], #32 \n"
"sqxtn v4.4h, v8.4s \n"
"sqxtn2 v4.8h, v9.4s \n"
"sqxtn v5.4h, v10.4s \n"
"sqxtn2 v5.8h, v11.4s \n"
"sqxtn v8.8b, v4.8h \n"
"sqxtn2 v8.16b, v5.8h \n"
"str q8, [%[out]], #16 \n"
"bne 0b \n"
: [in] "+r" (din_ptr), [out] "+r" (dout_ptr), [cnt] "+r" (cnt_loop)
: [scale] "w" (vscale)
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11"
);
#else
asm volatile(
"vld1.32 {d0-d3}, [%[din]]! @ load in0~in7\n"
"vld1.32 {d4-d7}, [%[din]]! @ load in8~in16\n"
"0: @ main loop\n"
"vand.i32 q4, %q[vpoff], %q[vpoff] @ set offset, 0.5\n"
"vand.i32 q5, q4, q4 @ set offset, 0.5\n"
"vand.i32 q6, q4, q4 @ set offset, 0.5\n"
"vand.i32 q7, q4, q4 @ set offset, 0.5\n"
"vcgt.f32 q8, q0, %q[vzero] @ get mask > 0, in0\n"
"vcgt.f32 q9, q1, %q[vzero] @ get mask > 0, in1\n"
"vcgt.f32 q10, q2, %q[vzero] @ get mask > 0, in2\n"
"vcgt.f32 q11, q3, %q[vzero] @ get mask > 0, in3\n"
"vbif.f32 q4, %q[vnoff], q8 @ get right offset\n"
"vbif.f32 q5, %q[vnoff], q9 @ get right offset\n"
"vbif.f32 q6, %q[vnoff], q10 @ get right offset\n"
"vbif.f32 q7, %q[vnoff], q11 @ get right offset\n"
"vmla.f32 q4, q0, %q[vscale] @ mul scale\n"
"vmla.f32 q5, q1, %q[vscale] @ mul scale\n"
"vmla.f32 q6, q2, %q[vscale] @ mul scale\n"
"vmla.f32 q7, q3, %q[vscale] @ mul scale\n"
"vcvt.s32.f32 q0, q4 @ cvt to int32\n"
"vcvt.s32.f32 q1, q5 @ cvt to int32\n"
"vcvt.s32.f32 q2, q6 @ cvt to int32\n"
"vcvt.s32.f32 q3, q7 @ cvt to int32\n"
"vqmovn.s32 d8, q0 @ cnt to int16\n"
"vqmovn.s32 d9, q1 @ cnt to int16\n"
"vqmovn.s32 d10, q2 @ cnt to int16\n"
"vqmovn.s32 d11, q3 @ cnt to int16\n"
"vld1.32 {d0-d3}, [%[din]]! @ load in0~in7\n"
"vqmovn.s16 d12, q4 @ cnt to int8\n"
"vqmovn.s16 d13, q5 @ cnt to int8\n"
"vld1.32 {d4-d7}, [%[din]]! @ load in8~in16\n"
"vst1.32 {d12-d13}, [%[dout]]! @ write to output\n"
"subs %[cnt], #1 @ loop count -1\n"
"bne 0b @ to main loop\n"
:[dout]"+r"(dout_ptr), [din]"+r"(din_ptr), [cnt]"+r"(cnt_loop)
:[vscale]"w"(vscale), [vpoff]"w"(vpoff), [vnoff]"w"(vnoff), [vzero]"w"(vzero)
:"q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11"
);
#endif
}
const float* din_r = din_c + 16 * cnt;
signed char* dout_r = dout_c + 16 * cnt;
for (int i = 0; i < remain; ++i) {
dout_r[i] = saturate_cast<int8_t>(roundf(inv_scale * din_r[i]));
}
}
}
void fp32_to_int16(const float* din, int16_t* dout, const float* scale,
int axis_size, long long outer_size, long long inner_size) {
int cnt = inner_size / 8;
int remain = inner_size & 7;
long long loop_size = outer_size * axis_size;
#pragma omp parallel for
for (int j = 0; j < loop_size; ++j) {
float inv_scale = 1.f / scale[j % axis_size];
float32x4_t vzero = vdupq_n_f32(0.f);
float32x4_t vscale = vdupq_n_f32(inv_scale);
float32x4_t vpoff = vdupq_n_f32(0.5f);
float32x4_t vnoff = vdupq_n_f32(-0.5f);
const float* din_c = din + j * inner_size;
int16_t* dout_c = dout + j * inner_size;
if (cnt > 0) {
int cnt_loop = cnt;
const float* din_ptr = din_c;
int16_t* dout_ptr = dout_c;
#ifdef __aarch64__
asm volatile(
"ldp q0, q1, [%[in]], #32 \n"
"0: \n" /* main loop */
"fmul v4.4s, v0.4s, %[scale].4s \n"
"fmul v5.4s, v1.4s, %[scale].4s \n"
"ldp q0, q1, [%[in]], #32 \n"
"subs %[cnt], %[cnt], #1 \n"
"FCVTAS v8.4s, v4.4s \n"
"FCVTAS v9.4s, v5.4s \n"
"sqxtn v4.4h, v8.4s \n"
"sqxtn2 v4.8h, v9.4s \n"
"str q4, [%[out]], #16 \n"
"bne 0b \n"
: [in] "+r" (din_ptr), [out] "+r" (dout_ptr), [cnt] "+r" (cnt_loop)
: [scale] "w" (vscale)
: "v0", "v1", "v4", "v5", "v8", "v9"
);
#else
asm volatile(
"vld1.32 {d0-d3}, [%[din]]! @ load in0~in7\n"
"0: @ main loop\n"
"vand.i32 q4, %q[vpoff], %q[vpoff] @ set offset, 0.5\n"
"vand.i32 q5, q4, q4 @ set offset, 0.5\n"
"vand.i32 q6, q4, q4 @ set offset, 0.5\n"
"vand.i32 q7, q4, q4 @ set offset, 0.5\n"
"vcgt.f32 q8, q0, %q[vzero] @ get mask > 0, in0\n"
"vcgt.f32 q9, q1, %q[vzero] @ get mask > 0, in1\n"
"vbif.f32 q4, %q[vnoff], q8 @ get right offset\n"
"vbif.f32 q5, %q[vnoff], q9 @ get right offset\n"
"vmla.f32 q4, q0, %q[vscale] @ mul scale\n"
"vmla.f32 q5, q1, %q[vscale] @ mul scale\n"
"vcvt.s32.f32 q0, q4 @ cvt to int32\n"
"vcvt.s32.f32 q1, q5 @ cvt to int32\n"
"vqmovn.s32 d8, q0 @ cnt to int16\n"
"vqmovn.s32 d9, q1 @ cnt to int16\n"
"vld1.32 {d0-d3}, [%[din]]! @ load in0~in7\n"
"vst1.32 {d8-d9}, [%[dout]]! @ write to output\n"
"subs %[cnt], #1 @ loop count -1\n"
"bne 0b @ to main loop\n"
:[dout]"+r"(dout_ptr), [din]"+r"(din_ptr), [cnt]"+r"(cnt_loop)
:[vscale]"w"(vscale), [vpoff]"w"(vpoff), [vnoff]"w"(vnoff), [vzero]"w"(vzero)
:"q0", "q1", "q4", "q5", "q6", "q7", "q8", "q9"
);
#endif
}
const float* din_r = din_c + 8 * cnt;
int16_t* dout_r = dout_c + 8 * cnt;
for (int i = 0; i < remain; ++i) {
dout_r[i] = saturate_cast<int16_t>(roundf(inv_scale * din_r[i]));
}
}
}
void int8_to_fp32(const signed char* in, float* out, const float* scale,
int axis_size, long long outer_size, long long inner_size) {
int cnt = inner_size / 16;
int remain = inner_size & 15;
long long loop_size = axis_size * outer_size;
#pragma omp parallel for
for (long long n = 0; n < loop_size; ++n) {
float in_scale = scale[n % axis_size];
const signed char* din_c = in + n * inner_size;
float* dout_c = out + n * inner_size;
float32x4_t vscale = vdupq_n_f32(in_scale);
if (cnt > 0) {
int loop = cnt;
const signed char* din_ptr = din_c;
float* dout_ptr = dout_c;
#ifdef __aarch64__
asm volatile(
"ldp d0, d1, [%[in]], #16 \n" /* load 16 int8*/
"0: \n" /* main loop */
"sshll v2.8h, v0.8b, #0 \n" /* trans to int16*/
"sshll v3.8h, v1.8b, #0 \n" /* trans to int16*/
"sshll v4.4s, v2.4h, #0 \n" /* trans to int32*/
"sshll2 v5.4s, v2.8h, #0 \n" /* trans to int32*/
"sshll v6.4s, v3.4h, #0 \n" /* trans to int32*/
"sshll2 v7.4s, v3.8h, #0 \n" /* trans to int32*/
"ldp d0, d1, [%[in]], #16 \n" /* load 16 int8*/
"scvtf v8.4s, v4.4s \n" /* trans to fp32*/
"scvtf v9.4s, v5.4s \n" /* trans to fp32*/
"scvtf v10.4s, v6.4s \n" /* trans to fp32*/
"scvtf v11.4s, v7.4s \n" /* trans to fp32*/
"subs %[loop], %[loop], #1 \n"
"fmul v4.4s, v8.4s, %[scale].4s \n" /* mul with scale*/
"fmul v5.4s, v9.4s, %[scale].4s \n" /* mul with scale*/
"fmul v6.4s, v10.4s, %[scale].4s \n" /* mul with scale*/
"fmul v7.4s, v11.4s, %[scale].4s \n" /* mul with scale*/
"stp q4, q5, [%[out]], #32 \n" /* write to memory*/
"stp q6, q7, [%[out]], #32 \n" /* write to memory*/
"bne 0b \n"
:[loop] "+r" (loop), [in] "+r" (din_ptr), [out] "+r" (dout_ptr)
:[scale] "w" (vscale)
:"v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11"
);
#else
asm volatile(
"vld1.32 {d0-d1}, [%[in]]! @ load 16 int8\n"
"0: @ main loop\n"
"vmovl.s8 q2, d0 @ trans to int16\n"
"vmovl.s8 q3, d1 @ trans to int16\n"
"vmovl.s16 q4, d4 @ trans to int32\n"
"vmovl.s16 q5, d5 @ trans to int32\n"
"vmovl.s16 q6, d6 @ trans to int32\n"
"vmovl.s16 q7, d7 @ trans to int32\n"
"vcvt.f32.s32 q0, q4 @ trans to fp32\n"
"vcvt.f32.s32 q1, q5 @ trans to fp32\n"
"vcvt.f32.s32 q2, q6 @ trans to fp32\n"
"vcvt.f32.s32 q3, q7 @ trans to fp32\n"
"vmul.f32 q4, q0, %q[scale] @ mul with scale\n"
"vmul.f32 q5, q1, %q[scale] @ mul with scale\n"
"vmul.f32 q6, q2, %q[scale] @ mul with scale\n"
"vmul.f32 q7, q3, %q[scale] @ mul with scale\n"
"vld1.32 {d0-d1}, [%[in]]! @ load 16 int8\n"
"subs %[loop], #1 \n"
"vst1.f32 {d8-d11}, [%[out]]! @ write to memory\n"
"vst1.f32 {d12-d15}, [%[out]]! @ write to memory\n"
"bne 0b \n"
:[loop] "+r" (loop), [in] "+r" (din_ptr), [out] "+r" (dout_ptr)
:[scale] "w" (vscale)
:"q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7"
);
#endif //__aarch64__
}
const signed char* din_r = din_c + 16 * cnt;
float* dout_r = dout_c + 16 * cnt;
for (int i = 0; i < remain; ++i) {
dout_r[i] = in_scale * din_r[i];
}
}
}
void int16_to_fp32(const short* in, float* out, const float* scale,
int axis_size, long long outer_size, long long inner_size) {
int cnt = inner_size / 16;
int remain = inner_size & 15;
long long loop_size = axis_size * outer_size;
#pragma omp parallel for
for (long long n = 0; n < loop_size; ++n) {
float in_scale = scale[n % axis_size];
const short* din_c = in + n * inner_size;
float* dout_c = out + n * inner_size;
float32x4_t vscale = vdupq_n_f32(in_scale);
if (cnt > 0) {
int loop = cnt;
const short* din_ptr = din_c;
float* dout_ptr = dout_c;
#ifdef __aarch64__
asm volatile(
"ldp q0, q1, [%[in]], #32 \n" /* load 16 int16*/
"0: \n" /* main loop */
"sshll v4.4s, v0.4h, #0 \n" /* trans to int32*/
"sshll2 v5.4s, v0.8h, #0 \n" /* trans to int32*/
"sshll v6.4s, v1.4h, #0 \n" /* trans to int32*/
"sshll2 v7.4s, v1.8h, #0 \n" /* trans to int32*/
"ldp q0, q1, [%[in]], #32 \n" /* load 16 int16*/
"scvtf v8.4s, v4.4s \n" /* trans to fp32*/
"scvtf v9.4s, v5.4s \n" /* trans to fp32*/
"scvtf v10.4s, v6.4s \n" /* trans to fp32*/
"scvtf v11.4s, v7.4s \n" /* trans to fp32*/
"subs %[loop], %[loop], #1 \n"
"fmul v4.4s, v8.4s, %[scale].4s \n" /* mul with scale*/
"fmul v5.4s, v9.4s, %[scale].4s \n" /* mul with scale*/
"fmul v6.4s, v10.4s, %[scale].4s \n" /* mul with scale*/
"fmul v7.4s, v11.4s, %[scale].4s \n" /* mul with scale*/
"stp q4, q5, [%[out]], #32 \n" /* write to memory*/
"stp q6, q7, [%[out]], #32 \n" /* write to memory*/
"bne 0b \n"
:[loop] "+r" (loop), [in] "+r" (din_ptr), [out] "+r" (dout_ptr)
:[scale] "w" (vscale)
:"v0", "v1", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11"
);
#else
asm volatile(
"vld1.32 {d0-d3}, [%[in]]! @ load 16 int16\n"
"0: @ main loop\n"
"vmovl.s16 q4, d0 @ trans to int32\n"
"vmovl.s16 q5, d1 @ trans to int32\n"
"vmovl.s16 q6, d2 @ trans to int32\n"
"vmovl.s16 q7, d3 @ trans to int32\n"
"vcvt.f32.s32 q0, q4 @ trans to fp32\n"
"vcvt.f32.s32 q1, q5 @ trans to fp32\n"
"vcvt.f32.s32 q2, q6 @ trans to fp32\n"
"vcvt.f32.s32 q3, q7 @ trans to fp32\n"
"vmul.f32 q4, q0, %q[scale] @ mul with scale\n"
"vmul.f32 q5, q1, %q[scale] @ mul with scale\n"
"vmul.f32 q6, q2, %q[scale] @ mul with scale\n"
"vmul.f32 q7, q3, %q[scale] @ mul with scale\n"
"vld1.32 {d0-d3}, [%[in]]! @ load 16 int8\n"
"subs %[loop], #1 \n"
"vst1.f32 {d8-d11}, [%[out]]! @ write to memory\n"
"vst1.f32 {d12-d15}, [%[out]]! @ write to memory\n"
"bne 0b \n"
:[loop] "+r" (loop), [in] "+r" (din_ptr), [out] "+r" (dout_ptr)
:[scale] "w" (vscale)
:"q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7"
);
#endif //__aarch64__
}
const short* din_r = din_c + 16 * cnt;
float* dout_r = dout_c + 16 * cnt;
for (int i = 0; i < remain; ++i) {
dout_r[i] = in_scale * din_r[i];
}
}
}
void int32_to_fp32(const int* din, float* dout, const float* scale,
int axis_size, long long outer_size, long long inner_size) {
int cnt = inner_size / 16;
int remain = inner_size & 15;
long long loop_size = axis_size * outer_size;
#pragma omp parallel for
for (long long n = 0; n < loop_size; ++n) {
float in_scale = scale[n % axis_size];
const int* din_c = din + n * inner_size;
float* dout_c = dout + n * inner_size;
float32x4_t vscale = vdupq_n_f32(in_scale);
if (cnt > 0) {
int loop = cnt;
const int* din_ptr = din_c;
float* dout_ptr = dout_c;
#ifdef __aarch64__
asm volatile(
"ldp q0, q1, [%[in]], #32 \n"
"ldp q2, q3, [%[in]], #32 \n"
"0: \n"
"scvtf v4.4s, v0.4s \n"
"scvtf v5.4s, v1.4s \n"
"scvtf v6.4s, v2.4s \n"
"scvtf v7.4s, v3.4s \n"
"ldp q0, q1, [%[in]], #32 \n"
"fmul v8.4s, v4.4s, %[scale].4s \n"
"fmul v9.4s, v5.4s, %[scale].4s \n"
"fmul v10.4s, v6.4s, %[scale].4s \n"
"fmul v11.4s, v7.4s, %[scale].4s \n"
"ldp q2, q3, [%[in]], #32 \n"
"stp q8, q9, [%[out]], #32 \n"
"stp q10, q11, [%[out]], #32 \n"
"subs %[loop], %[loop], #1 \n"
"bne 0b \n"
:[loop] "+r" (loop), [in] "+r" (din_ptr), [out] "+r" (dout_ptr)
:[scale] "w" (vscale)
:"v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11"
);
#else
asm volatile(
"vld1.s32 {d0-d3}, [%[in]]! \n"
"vld1.s32 {d4-d7}, [%[in]]! \n"
"0: \n"
"vcvt.f32.s32 q4, q0 \n"
"vcvt.f32.s32 q5, q1 \n"
"vcvt.f32.s32 q6, q2 \n"
"vcvt.f32.s32 q7, q3 \n"
"vld1.s32 {d0-d3}, [%[in]]! \n"
"vmul.f32 q8, q4, %q[scale] \n"
"vmul.f32 q9, q5, %q[scale] \n"
"vmul.f32 q10, q6, %q[scale] \n"
"vmul.f32 q11, q7, %q[scale] \n"
"vld1.s32 {d4-d7}, [%[in]]! \n"
"subs %[loop], #1 \n"
"vst1.f32 {d16-d19}, [%[out]]! \n"
"vst1.f32 {d20-d23}, [%[out]]! \n"
"bne 0b \n"
:[loop] "+r" (loop), [in] "+r" (din_ptr), [out] "+r" (dout_ptr)
:[scale] "w" (vscale)
:"q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11"
);
#endif //__aarch64__
}
const int* din_r = din_c + 16 * cnt;
float* dout_r = dout_c + 16 * cnt;
for (int i = 0; i < remain; ++i) {
dout_r[i] = in_scale * din_r[i];
}
}
}
void int32_to_int8(const int* din, signed char* dout, const float* scale, \
int axis_size, long long outer_size, long long inner_size) {
int cnt = inner_size / 16;
int remain = inner_size & 15;
long long loop_size = outer_size * axis_size;
#pragma omp parallel for
for (long long n = 0; n < loop_size; ++n) {
float in_scale = scale[n % axis_size];
const int* din_c = din + n * inner_size;
signed char* dout_c = dout + n * inner_size;
float32x4_t vscale = vdupq_n_f32(in_scale);
float32x4_t vzero = vdupq_n_f32(0.f);
float32x4_t vpoff = vdupq_n_f32(0.5f);
float32x4_t vnoff = vdupq_n_f32(-0.5f);
if (cnt > 0) {
int loop = cnt;
const int* din_ptr = din_c;
signed char* dout_ptr = dout_c;
#ifdef __aarch64__
asm volatile(
"0: \n"
"ld1 {v0.4s, v1.4s}, [%[in]], #32 \n"
"ld1 {v2.4s, v3.4s}, [%[in]], #32 \n"
"scvtf v4.4s, v0.4s \n"
"scvtf v5.4s, v1.4s \n"
"scvtf v6.4s, v2.4s \n"
"scvtf v7.4s, v3.4s \n"
"fmul v0.4s, v4.4s, %[scale].4s \n"
"fmul v1.4s, v5.4s, %[scale].4s \n"
"fmul v2.4s, v6.4s, %[scale].4s \n"
"fmul v3.4s, v7.4s, %[scale].4s \n"
"fcvtas v4.4s, v0.4s \n"
"fcvtas v5.4s, v1.4s \n"
"fcvtas v6.4s, v2.4s \n"
"fcvtas v7.4s, v3.4s \n"
"sqxtn v0.4h, v4.4s \n"
"sqxtn2 v0.8h, v5.4s \n"
"sqxtn v1.4h, v6.4s \n"
"sqxtn2 v1.8h, v7.4s \n"
"sqxtn v2.8b, v0.8h \n"
"sqxtn2 v2.16b, v1.8h \n"
"st1 {v2.16b}, [%[out]], #16 \n"
"subs %[loop], %[loop], #1 \n"
"bne 0b \n"
:[loop] "+r" (loop), [in] "+r" (din_ptr), [out] "+r" (dout_ptr)
:[scale] "w" (vscale)
:"v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
);
#else
asm volatile(
"vld1.32 {d0-d3}, [%[din]]! @ load in0~in7\n"
"vld1.32 {d4-d7}, [%[din]]! @ load in8~in16\n"
"0: @ main loop\n"
"vcvt.f32.s32 q4, q0 @ cvt to float\n"
"vcvt.f32.s32 q5, q1 @ cvt to float\n"
"vcvt.f32.s32 q6, q2 @ cvt to float\n"
"vcvt.f32.s32 q7, q3 @ cvt to float\n"
"vand.i32 q0, %q[vpoff], %q[vpoff] @ set offset, 0.5\n"
"vand.i32 q1, q0, q0 @ set offset, 0.5\n"
"vand.i32 q2, q0, q0 @ set offset, 0.5\n"
"vand.i32 q3, q0, q0 @ set offset, 0.5\n"
"vcgt.f32 q8, q4, %q[vzero] @ get mask > 0, in0\n"
"vcgt.f32 q9, q5, %q[vzero] @ get mask > 0, in1\n"
"vcgt.f32 q10, q6, %q[vzero] @ get mask > 0, in2\n"
"vcgt.f32 q11, q7, %q[vzero] @ get mask > 0, in3\n"
"vbif.f32 q0, %q[vnoff], q8 @ get right offset\n"
"vbif.f32 q1, %q[vnoff], q9 @ get right offset\n"
"vbif.f32 q2, %q[vnoff], q10 @ get right offset\n"
"vbif.f32 q3, %q[vnoff], q11 @ get right offset\n"
"vmla.f32 q0, q4, %q[vscale] @ mul scale\n"
"vmla.f32 q1, q5, %q[vscale] @ mul scale\n"
"vmla.f32 q2, q6, %q[vscale] @ mul scale\n"
"vmla.f32 q3, q7, %q[vscale] @ mul scale\n"
"vcvt.s32.f32 q4, q0 @ cvt to int32\n"
"vcvt.s32.f32 q5, q1 @ cvt to int32\n"
"vcvt.s32.f32 q6, q2 @ cvt to int32\n"
"vcvt.s32.f32 q7, q3 @ cvt to int32\n"
"vqmovn.s32 d16, q4 @ cnt to int16\n"
"vqmovn.s32 d17, q5 @ cnt to int16\n"
"vqmovn.s32 d18, q6 @ cnt to int16\n"
"vqmovn.s32 d19, q7 @ cnt to int16\n"
"vld1.32 {d0-d3}, [%[din]]! @ load in0~in7\n"
"vqmovn.s16 d8, q8 @ cnt to int8\n"
"vqmovn.s16 d9, q9 @ cnt to int8\n"
"vld1.32 {d4-d7}, [%[din]]! @ load in8~in16\n"
"vst1.32 {d8-d9}, [%[dout]]! @ write to output\n"
"subs %[loop], #1 @ loop count -1\n"
"bne 0b @ to main loop\n"
:[loop] "+r" (loop), [din] "+r" (din_ptr), [dout] "+r" (dout_ptr)
:[vscale] "w" (vscale), [vzero] "w"(vzero), [vnoff] "w" (vnoff), [vpoff] "w" (vpoff)
:"q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11"
);
#endif //__aarch64__
}
const int* din_r = din_c + 16 * cnt;
int8_t* dout_r = dout_c + 16 * cnt;
for (int i = 0; i < remain; ++i) {
dout_r[i] = saturate_cast<int8_t>(roundf(in_scale * din_r[i]));
}
}
}
void int32_to_int32(const int* din, int* dout, const float* scale, \
int axis_size, long long outer_size, long long inner_size) {
int size_all = outer_size * axis_size * inner_size;
memmove(dout, din, size_all*sizeof(int));
}
template <>
void int32_to_dtype(const int* din, float* dout, const float* scale,
int axis_size, long long outer_size, long long inner_size) {
return int32_to_fp32(din, dout, scale, axis_size, outer_size, inner_size);
}
template <>
void int32_to_dtype(const int* din, signed char* dout, const float* scale,
int axis_size, long long outer_size, long long inner_size) {
return int32_to_int8(din, dout, scale, axis_size, outer_size, inner_size);
}
template <>
void int32_to_dtype(const int* din, int* dout, const float* scale,
int axis_size, long long outer_size, long long inner_size) {
return int32_to_int32(din, dout, scale, axis_size, outer_size, inner_size);
}
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
......@@ -14,6 +14,8 @@
#include "paddle/fluid/lite/kernels/arm/conv_compute.h"
#include "paddle/fluid/lite/arm/math/conv_direct.h"
#include "paddle/fluid/lite/arm/math/conv_depthwise.h"
#include "paddle/fluid/lite/arm/math/conv_gemmlike.h"
#include "paddle/fluid/lite/arm/math/funcs.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/type_system.h"
......@@ -62,22 +64,26 @@ void ConvCompute::Run() {
// TODO(xxx): enable more
if (param.groups == ic && ic == oc && kps_equal && no_dilation && flag_dw) {
// dw conv impl
// impl_ = new lite::arm::math::prepackA<PRECISION(kFloat)>;
impl_ = new lite::arm::math::DepthwiseConv<PRECISION(kFloat)>;
LOG(INFO) << "invoking dw conv";
} else if (param.groups == 1 && kw == 3 && stride == 1 && kps_equal &&
no_dilation) {
if (ic >= 32 && oc >= 32 && oh > 16 && ow > 16) {
// winograd conv impl
// impl_ = new lite::arm::math::WinogradConv<PRECISION(kFloat)>;
LOG(FATAL) << "TODO!!! winograd conv";
} else {
// direct conv impl
impl_ = new lite::arm::math::DirectConv<PRECISION(kFloat)>;
LOG(INFO) << "invoking direct conv";
}
} else if (param.groups == 1 && kw == 3 && stride == 2 && kps_equal &&
no_dilation) {
// direct conv impl
impl_ = new lite::arm::math::DirectConv<PRECISION(kFloat)>;
} else {
// impl_ = new lite::arm::math::GemmLikeConv<PRECISION(kFloat)>;
impl_ = new lite::arm::math::GemmLikeConv<PRECISION(kFloat)>;
LOG(INFO) << "invoking gemm like conv";
}
this->impl_->create(param, &ctx);
......@@ -98,7 +104,15 @@ PrecisionType ConvCompute::precision() const { return PRECISION(kFloat); }
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(conv, kARM, kFloat, kNCHW,
REGISTER_LITE_KERNEL(conv2d, kARM, kFloat, kNCHW,
paddle::lite::kernels::arm::ConvCompute, def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Filter", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
REGISTER_LITE_KERNEL(depthwise_conv2d, kARM, kFloat, kNCHW,
paddle::lite::kernels::arm::ConvCompute, def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))})
......
......@@ -14,6 +14,8 @@
#include "paddle/fluid/lite/kernels/arm/conv_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "paddle/fluid/lite/arm/math/funcs.h"
#include "paddle/fluid/lite/core/op_registry.h"
......@@ -23,9 +25,95 @@ namespace lite {
namespace kernels {
namespace arm {
template <typename dtype>
void conv_compute_ref(const operators::ConvParam& param) {
auto input = param.x;
auto filter = param.filter;
auto output = param.output;
DDim input_dims = param.x->dims();
DDim filter_dims = param.filter->dims();
DDim output_dims = param.output->dims();
std::vector<int> paddings = param.paddings;
std::vector<int> strides = param.strides;
std::vector<int> dilations = param.dilations;
int groups = param.groups;
auto input_data = param.x->data<float>();
auto output_data = param.output->mutable_data<float>();
auto filter_data = param.filter->mutable_data<float>();
const float* bias_data = nullptr;
if (param.bias != nullptr) {
bias_data = param.bias->mutable_data<float>();
}
bool flag_bias = bias_data != nullptr;
bool flag_relu = false; // TODO(hong19860320) param.relu
int num = input_dims[0];
int chout = output_dims[1];
int hout = output_dims[2];
int wout = output_dims[3];
int chin = input_dims[1];
int hin = input_dims[2];
int win = input_dims[3];
int out_c_group = chout / groups;
int in_c_group = chin / groups;
int stride_h = strides[0];
int stride_w = strides[1];
int dilation_h = dilations[0];
int dilation_w = dilations[1];
int padding_h = paddings[0];
int padding_w = paddings[1];
int kernel_h = filter_dims[2];
int kernel_w = filter_dims[3];
for (int n = 0; n < num; ++n) {
for (int g = 0; g < groups; ++g) {
for (int oc = 0; oc < out_c_group; ++oc) {
for (int oh = 0; oh < hout; ++oh) {
for (int ow = 0; ow < wout; ++ow) {
int out_idx = n * groups * out_c_group * hout * wout +
g * out_c_group * hout * wout + oc * hout * wout +
oh * wout + ow;
output_data[out_idx] = 0.0f;
for (int ic = 0; ic < in_c_group; ++ic) {
for (int kh = 0; kh < kernel_h; ++kh) {
for (int kw = 0; kw < kernel_w; ++kw) {
int iw = ow * stride_w - padding_w + kw * (dilation_w);
int ih = oh * stride_h - padding_h + kh * (dilation_h);
if (iw < 0 || iw >= win) continue;
if (ih < 0 || ih >= hin) continue;
int iidx = n * chin * hin * win + g * in_c_group * hin * win +
ic * hin * win + ih * win + iw;
int widx =
g * out_c_group * in_c_group * kernel_h * kernel_w +
oc * in_c_group * kernel_h * kernel_w +
ic * kernel_h * kernel_w + kh * kernel_w + kw;
output_data[out_idx] +=
(dtype)input_data[iidx] * (dtype)filter_data[widx];
}
}
}
output_data[out_idx] +=
flag_bias ? static_cast<float>(bias_data[g * out_c_group + oc])
: 0.f;
if (flag_relu) {
output_data[out_idx] =
output_data[out_idx] > 0.f ? output_data[out_idx] : 0.f;
}
}
}
}
}
}
}
TEST(conv_arm, retrive_op) {
auto conv =
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>("conv");
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>("conv2d");
ASSERT_FALSE(conv.empty());
ASSERT_TRUE(conv.front());
}
......@@ -36,8 +124,153 @@ TEST(conv_arm, init) {
ASSERT_EQ(conv.target(), TARGET(kARM));
}
TEST(conv_arm, compare_test) {
// TODO(xxx): add more compare
TEST(conv_arm, compute) {
ConvCompute conv;
operators::ConvParam param;
lite::Tensor input;
lite::Tensor filter;
lite::Tensor bias;
lite::Tensor output;
lite::Tensor output_ref;
DeviceInfo::Init();
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<ARMContext>();
conv.SetContext(std::move(ctx));
for (auto n : {1, 2}) {
for (auto chin : {3, 8, /*32, 128*/}) {
for (auto chout : {3, 8, /*32, 128*/}) {
for (auto hin : {7, 14, 28, /*56 , 112, 224, 512*/}) {
for (auto win : {7, 14, 28, /*56, 112, 224, 512*/}) {
for (auto flag_bias : {false , true}) {
for (auto flag_relu : {false , true}) {
for (auto depthwise : {false, true}) {
for (auto dilation : {1 /*, 2*/}) {
for (auto stride : {1, 2}) {
for (auto padding : {0, 1}) {
for (auto ks : {/*1, */3/*, 5*/}) {
int group = 1;
if (depthwise) { // depthwise conv ?
group = chin;
chout = chin;
// remove the follow code if
// all kernels are implemented.
if (ks == 5) {
stride = 2;
padding = 2;
}
}
// get input, filter and output shape
std::vector<int64_t> input_shape = {n, chin, hin,
win};
std::vector<int64_t> filter_shape = {
chout, chin / group, ks, ks};
std::vector<int64_t> output_shape({n, chout});
const int dkernel = dilation * (ks - 1) + 1;
output_shape.push_back(
(hin + 2 * padding - dkernel) / stride + 1);
output_shape.push_back(
(win + 2 * padding - dkernel) / stride + 1);
// resize input, filter and output
input.Resize(DDim(input_shape));
filter.Resize(DDim(filter_shape));
output.Resize(DDim(output_shape));
output_ref.Resize(DDim(output_shape));
auto* input_data = input.mutable_data<float>();
auto* filter_data = filter.mutable_data<float>();
auto* output_data = output.mutable_data<float>();
auto* output_ref_data =
output_ref.mutable_data<float>();
for (int i = 0; i < input.dims().production(); i++) {
input_data[i] = static_cast<float>(i % 128);
}
for (int i = 0; i < filter.dims().production(); i++) {
filter_data[i] = i / 1000.0f;
}
param.x = &input;
param.filter = &filter;
param.output = &output;
param.bias = nullptr;
// TODO(hong19860320) param.relu = flag_relu;
param.paddings = std::vector<int>({padding, padding});
param.strides = std::vector<int>({stride, stride});
param.dilations =
std::vector<int>({dilation, dilation});
param.groups = group;
conv.SetParam(param);
conv.Run();
param.output = &output_ref;
conv_compute_ref<float>(param);
for (int i = 0; i < output.dims().production(); i++) {
EXPECT_NEAR(output_data[i], output_ref_data[i],
1e-3);
}
}
}
}
}
}
}
}
}
}
}
}
}
#if 0
// for testing gemm like conv
int n = 1;
int chin = 8;
int chout = 8;
int hin = 14;
int win = 14;
int flag_bias = false;
int flag_relu = false;
int dilation = 1;
int stride = 1;
int padding = 1;
int ks = 5;
int group = 1;
// get input, filter and output shape
std::vector<int64_t> input_shape = {n, chin, hin, win};
std::vector<int64_t> filter_shape = {chout, chin / group, ks, ks};
std::vector<int64_t> output_shape({n, chout});
const int dkernel = dilation * (ks - 1) + 1;
output_shape.push_back((hin + 2 * padding - dkernel) / stride + 1);
output_shape.push_back((win + 2 * padding - dkernel) / stride + 1);
// resize input, filter and output
input.Resize(DDim(input_shape));
filter.Resize(DDim(filter_shape));
output.Resize(DDim(output_shape));
output_ref.Resize(DDim(output_shape));
auto* input_data = input.mutable_data<float>();
auto* filter_data = filter.mutable_data<float>();
auto* output_data = output.mutable_data<float>();
auto* output_ref_data = output_ref.mutable_data<float>();
for (int i = 0; i < input.dims().production(); i++) {
input_data[i] = static_cast<float>(i % 128);
}
for (int i = 0; i < filter.dims().production(); i++) {
filter_data[i] = i / 1000.0f;
}
param.x = &input;
param.filter = &filter;
param.output = &output;
param.bias = nullptr;
// TODO(hong19860320) param.relu = flag_relu;
param.paddings = std::vector<int>({padding, padding});
param.strides = std::vector<int>({stride, stride});
param.dilations = std::vector<int>({dilation, dilation});
param.groups = group;
conv.SetParam(param);
conv.Run();
param.output = &output_ref;
conv_compute_ref<float>(param);
for (int i = 0; i < output.dims().production(); i++) {
EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-3);
}
#endif
}
} // namespace arm
......@@ -45,4 +278,5 @@ TEST(conv_arm, compare_test) {
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(conv, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(conv2d, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(depthwise_conv2d, kARM, kFloat, kNCHW, def);
......@@ -73,4 +73,5 @@ bool ConvOpLite::InferShape() const {
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(conv, paddle::lite::operators::ConvOpLite);
REGISTER_LITE_OP(conv2d, paddle::lite::operators::ConvOpLite);
REGISTER_LITE_OP(depthwise_conv2d, paddle::lite::operators::ConvOpLite);
\ No newline at end of file
......@@ -41,27 +41,39 @@ class ConvOpLite : public OpLite {
bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override {
auto input = op_desc.Input("Input").front();
auto filter = op_desc.Input("Filter").front();
auto bias = op_desc.Input("Bias").front();
auto resid = op_desc.Input("ResidualData").front(); // maybe not used
auto out = op_desc.Output("Out").front();
param_.x = scope->FindVar(input)->GetMutable<lite::Tensor>();
param_.filter = scope->FindVar(filter)->GetMutable<lite::Tensor>();
param_.residualData = scope->FindVar(resid)->GetMutable<lite::Tensor>();
param_.bias = scope->FindVar(bias)->GetMutable<lite::Tensor>();
CHECK(scope->FindVar(out));
param_.output = scope->FindVar(out)->GetMutable<lite::Tensor>();
param_.strides = op_desc.GetAttr<std::vector<int>>("strides");
param_.paddings = op_desc.GetAttr<std::vector<int>>("paddings");
param_.groups = op_desc.GetAttr<int>("groups");
param_.dilations = op_desc.GetAttr<std::vector<int>>("dilations");
// optional params
std::vector<std::string> input_arg_names = op_desc.InputArgumentNames();
if (std::find(input_arg_names.begin(), input_arg_names.end(), "Bias") !=
input_arg_names.end()) {
auto bias_var = scope->FindVar(op_desc.Input("Bias").front());
if (bias_var != nullptr) {
param_.bias =
const_cast<lite::Tensor *>(&(bias_var->Get<lite::Tensor>()));
}
}
if (std::find(input_arg_names.begin(), input_arg_names.end(), "ResidualData") !=
input_arg_names.end()) {
auto residual_data_var = scope->FindVar(op_desc.Input("ResidualData").front());
if (residual_data_var != nullptr) {
param_.residualData =
const_cast<lite::Tensor *>(&(residual_data_var->Get<lite::Tensor>()));
}
}
return true;
}
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "conv"; }
std::string DebugString() const override { return "conv2d"; }
private:
mutable ConvParam param_;
......
......@@ -124,8 +124,8 @@ struct ConcatParam {
struct ConvParam {
lite::Tensor* x{};
lite::Tensor* filter{};
lite::Tensor* bias{};
lite::Tensor* residualData{};
lite::Tensor* bias{nullptr};
lite::Tensor* residualData{nullptr};
lite::Tensor* output{};
std::vector<int> strides{1, 1};
std::vector<int> paddings{0, 0};
......
......@@ -34,7 +34,7 @@ function cmake_arm {
function build {
file=$1
for _test in $(cat $file); do
make $_test -j$(expr $(nproc) - 2)
make $_test -j$(expr $(nproc))
done
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册