提交 9ae05da3 编写于 作者: T tensor-tang

fix format

无相关合并请求
...@@ -172,3 +172,4 @@ add_subdirectory(model_parser) ...@@ -172,3 +172,4 @@ add_subdirectory(model_parser)
add_subdirectory(utils) add_subdirectory(utils)
add_subdirectory(api) add_subdirectory(api)
add_subdirectory(gen_code) add_subdirectory(gen_code)
...@@ -54,3 +54,4 @@ lite_cc_binary(cxx_api_lite_bin SRCS cxx_api_bin.cc ...@@ -54,3 +54,4 @@ lite_cc_binary(cxx_api_lite_bin SRCS cxx_api_bin.cc
mir_passes mir_passes
${ops_lite} ${host_kernels} ${ops_lite} ${host_kernels}
ARM_DEPS ${arm_kernels}) ARM_DEPS ${arm_kernels})
add_subdirectory(math) add_subdirectory(math)
...@@ -34,3 +34,4 @@ cc_library(math_arm SRCS ...@@ -34,3 +34,4 @@ cc_library(math_arm SRCS
split.cc split.cc
DEPS ${lite_kernel_deps} eigen3) DEPS ${lite_kernel_deps} eigen3)
...@@ -12,9 +12,10 @@ ...@@ -12,9 +12,10 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/lite/arm/math/saturate.h" #include "paddle/fluid/lite/arm/math/type_trans.h"
#include <arm_neon.h> #include <arm_neon.h>
#include <string.h> #include <string.h>
#include "paddle/fluid/lite/arm/math/saturate.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -23,563 +24,553 @@ namespace math { ...@@ -23,563 +24,553 @@ namespace math {
template <typename dtype> template <typename dtype>
void int32_to_dtype(const int* din, dtype* dout, const float* scale, void int32_to_dtype(const int* din, dtype* dout, const float* scale,
int axis_size, long long outer_size, long long inner_size); int axis_size, int64_t outer_size, int64_t inner_size);
void fp32_to_int8(const float* din, signed char* dout, const float* scale, void fp32_to_int8(const float* din, signed char* dout, const float* scale,
int axis_size, long long outer_size, long long inner_size) { int axis_size, int64_t outer_size, int64_t inner_size) {
int cnt = inner_size / 16;
int cnt = inner_size / 16; int remain = inner_size & 15;
int remain = inner_size & 15; int64_t loop_size = outer_size * axis_size;
long long loop_size = outer_size * axis_size;
#pragma omp parallel for #pragma omp parallel for
for (int j = 0; j < loop_size; ++j) { for (int j = 0; j < loop_size; ++j) {
float inv_scale = 1.f / scale[j % axis_size]; float inv_scale = 1.f / scale[j % axis_size];
float32x4_t vzero = vdupq_n_f32(0.f); float32x4_t vzero = vdupq_n_f32(0.f);
float32x4_t vscale = vdupq_n_f32(inv_scale); float32x4_t vscale = vdupq_n_f32(inv_scale);
float32x4_t vpoff = vdupq_n_f32(0.5f); float32x4_t vpoff = vdupq_n_f32(0.5f);
float32x4_t vnoff = vdupq_n_f32(-0.5f); float32x4_t vnoff = vdupq_n_f32(-0.5f);
const float* din_c = din + j * inner_size; const float* din_c = din + j * inner_size;
signed char* dout_c = dout + j * inner_size; signed char* dout_c = dout + j * inner_size;
if (cnt > 0) { if (cnt > 0) {
int cnt_loop = cnt; int cnt_loop = cnt;
const float* din_ptr = din_c; const float* din_ptr = din_c;
signed char* dout_ptr = dout_c; signed char* dout_ptr = dout_c;
#ifdef __aarch64__ #ifdef __aarch64__
asm volatile( asm volatile(
"ldp q0, q1, [%[in]], #32 \n" "ldp q0, q1, [%[in]], #32 \n"
"ldp q2, q3, [%[in]], #32 \n" "ldp q2, q3, [%[in]], #32 \n"
"0: \n" /* main loop */ "0: \n" /* main loop */
"fmul v4.4s, v0.4s, %[scale].4s \n" "fmul v4.4s, v0.4s, %[scale].4s \n"
"fmul v5.4s, v1.4s, %[scale].4s \n" "fmul v5.4s, v1.4s, %[scale].4s \n"
"fmul v6.4s, v2.4s, %[scale].4s \n" "fmul v6.4s, v2.4s, %[scale].4s \n"
"fmul v7.4s, v3.4s, %[scale].4s \n" "fmul v7.4s, v3.4s, %[scale].4s \n"
"ldp q0, q1, [%[in]], #32 \n" "ldp q0, q1, [%[in]], #32 \n"
"subs %[cnt], %[cnt], #1 \n" "subs %[cnt], %[cnt], #1 \n"
"FCVTAS v8.4s, v4.4s \n" "FCVTAS v8.4s, v4.4s \n"
"FCVTAS v9.4s, v5.4s \n" "FCVTAS v9.4s, v5.4s \n"
"FCVTAS v10.4s, v6.4s \n" "FCVTAS v10.4s, v6.4s \n"
"FCVTAS v11.4s, v7.4s \n" "FCVTAS v11.4s, v7.4s \n"
"ldp q2, q3, [%[in]], #32 \n" "ldp q2, q3, [%[in]], #32 \n"
"sqxtn v4.4h, v8.4s \n" "sqxtn v4.4h, v8.4s \n"
"sqxtn2 v4.8h, v9.4s \n" "sqxtn2 v4.8h, v9.4s \n"
"sqxtn v5.4h, v10.4s \n" "sqxtn v5.4h, v10.4s \n"
"sqxtn2 v5.8h, v11.4s \n" "sqxtn2 v5.8h, v11.4s \n"
"sqxtn v8.8b, v4.8h \n" "sqxtn v8.8b, v4.8h \n"
"sqxtn2 v8.16b, v5.8h \n" "sqxtn2 v8.16b, v5.8h \n"
"str q8, [%[out]], #16 \n" "str q8, [%[out]], #16 \n"
"bne 0b \n" "bne 0b \n"
: [in] "+r" (din_ptr), [out] "+r" (dout_ptr), [cnt] "+r" (cnt_loop) : [in] "+r"(din_ptr), [out] "+r"(dout_ptr), [cnt] "+r"(cnt_loop)
: [scale] "w" (vscale) : [scale] "w"(vscale)
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11" : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
); "v11");
#else #else
asm volatile( asm volatile(
"vld1.32 {d0-d3}, [%[din]]! @ load in0~in7\n" "vld1.32 {d0-d3}, [%[din]]! @ load in0~in7\n"
"vld1.32 {d4-d7}, [%[din]]! @ load in8~in16\n" "vld1.32 {d4-d7}, [%[din]]! @ load in8~in16\n"
"0: @ main loop\n" "0: @ main loop\n"
"vand.i32 q4, %q[vpoff], %q[vpoff] @ set offset, 0.5\n" "vand.i32 q4, %q[vpoff], %q[vpoff] @ set offset, 0.5\n"
"vand.i32 q5, q4, q4 @ set offset, 0.5\n" "vand.i32 q5, q4, q4 @ set offset, 0.5\n"
"vand.i32 q6, q4, q4 @ set offset, 0.5\n" "vand.i32 q6, q4, q4 @ set offset, 0.5\n"
"vand.i32 q7, q4, q4 @ set offset, 0.5\n" "vand.i32 q7, q4, q4 @ set offset, 0.5\n"
"vcgt.f32 q8, q0, %q[vzero] @ get mask > 0, in0\n" "vcgt.f32 q8, q0, %q[vzero] @ get mask > 0, in0\n"
"vcgt.f32 q9, q1, %q[vzero] @ get mask > 0, in1\n" "vcgt.f32 q9, q1, %q[vzero] @ get mask > 0, in1\n"
"vcgt.f32 q10, q2, %q[vzero] @ get mask > 0, in2\n" "vcgt.f32 q10, q2, %q[vzero] @ get mask > 0, in2\n"
"vcgt.f32 q11, q3, %q[vzero] @ get mask > 0, in3\n" "vcgt.f32 q11, q3, %q[vzero] @ get mask > 0, in3\n"
"vbif.f32 q4, %q[vnoff], q8 @ get right offset\n" "vbif.f32 q4, %q[vnoff], q8 @ get right offset\n"
"vbif.f32 q5, %q[vnoff], q9 @ get right offset\n" "vbif.f32 q5, %q[vnoff], q9 @ get right offset\n"
"vbif.f32 q6, %q[vnoff], q10 @ get right offset\n" "vbif.f32 q6, %q[vnoff], q10 @ get right offset\n"
"vbif.f32 q7, %q[vnoff], q11 @ get right offset\n" "vbif.f32 q7, %q[vnoff], q11 @ get right offset\n"
"vmla.f32 q4, q0, %q[vscale] @ mul scale\n" "vmla.f32 q4, q0, %q[vscale] @ mul scale\n"
"vmla.f32 q5, q1, %q[vscale] @ mul scale\n" "vmla.f32 q5, q1, %q[vscale] @ mul scale\n"
"vmla.f32 q6, q2, %q[vscale] @ mul scale\n" "vmla.f32 q6, q2, %q[vscale] @ mul scale\n"
"vmla.f32 q7, q3, %q[vscale] @ mul scale\n" "vmla.f32 q7, q3, %q[vscale] @ mul scale\n"
"vcvt.s32.f32 q0, q4 @ cvt to int32\n" "vcvt.s32.f32 q0, q4 @ cvt to int32\n"
"vcvt.s32.f32 q1, q5 @ cvt to int32\n" "vcvt.s32.f32 q1, q5 @ cvt to int32\n"
"vcvt.s32.f32 q2, q6 @ cvt to int32\n" "vcvt.s32.f32 q2, q6 @ cvt to int32\n"
"vcvt.s32.f32 q3, q7 @ cvt to int32\n" "vcvt.s32.f32 q3, q7 @ cvt to int32\n"
"vqmovn.s32 d8, q0 @ cnt to int16\n" "vqmovn.s32 d8, q0 @ cnt to int16\n"
"vqmovn.s32 d9, q1 @ cnt to int16\n" "vqmovn.s32 d9, q1 @ cnt to int16\n"
"vqmovn.s32 d10, q2 @ cnt to int16\n" "vqmovn.s32 d10, q2 @ cnt to int16\n"
"vqmovn.s32 d11, q3 @ cnt to int16\n" "vqmovn.s32 d11, q3 @ cnt to int16\n"
"vld1.32 {d0-d3}, [%[din]]! @ load in0~in7\n" "vld1.32 {d0-d3}, [%[din]]! @ load in0~in7\n"
"vqmovn.s16 d12, q4 @ cnt to int8\n" "vqmovn.s16 d12, q4 @ cnt to int8\n"
"vqmovn.s16 d13, q5 @ cnt to int8\n" "vqmovn.s16 d13, q5 @ cnt to int8\n"
"vld1.32 {d4-d7}, [%[din]]! @ load in8~in16\n" "vld1.32 {d4-d7}, [%[din]]! @ load in8~in16\n"
"vst1.32 {d12-d13}, [%[dout]]! @ write to output\n" "vst1.32 {d12-d13}, [%[dout]]! @ write to output\n"
"subs %[cnt], #1 @ loop count -1\n" "subs %[cnt], #1 @ loop count -1\n"
"bne 0b @ to main loop\n" "bne 0b @ to main loop\n"
:[dout]"+r"(dout_ptr), [din]"+r"(din_ptr), [cnt]"+r"(cnt_loop) : [dout] "+r"(dout_ptr), [din] "+r"(din_ptr), [cnt] "+r"(cnt_loop)
:[vscale]"w"(vscale), [vpoff]"w"(vpoff), [vnoff]"w"(vnoff), [vzero]"w"(vzero) : [vscale] "w"(vscale), [vpoff] "w"(vpoff), [vnoff] "w"(vnoff),
:"q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11" [vzero] "w"(vzero)
); : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10",
"q11");
#endif #endif
}
const float* din_r = din_c + 16 * cnt;
signed char* dout_r = dout_c + 16 * cnt;
for (int i = 0; i < remain; ++i) {
dout_r[i] = saturate_cast<int8_t>(roundf(inv_scale * din_r[i]));
}
} }
const float* din_r = din_c + 16 * cnt;
signed char* dout_r = dout_c + 16 * cnt;
for (int i = 0; i < remain; ++i) {
dout_r[i] = saturate_cast<int8_t>(roundf(inv_scale * din_r[i]));
}
}
} }
void fp32_to_int16(const float* din, int16_t* dout, const float* scale, void fp32_to_int16(const float* din, int16_t* dout, const float* scale,
int axis_size, long long outer_size, long long inner_size) { int axis_size, int64_t outer_size, int64_t inner_size) {
int cnt = inner_size / 8;
int cnt = inner_size / 8; int remain = inner_size & 7;
int remain = inner_size & 7; int64_t loop_size = outer_size * axis_size;
long long loop_size = outer_size * axis_size;
#pragma omp parallel for #pragma omp parallel for
for (int j = 0; j < loop_size; ++j) { for (int j = 0; j < loop_size; ++j) {
float inv_scale = 1.f / scale[j % axis_size]; float inv_scale = 1.f / scale[j % axis_size];
float32x4_t vzero = vdupq_n_f32(0.f); float32x4_t vzero = vdupq_n_f32(0.f);
float32x4_t vscale = vdupq_n_f32(inv_scale); float32x4_t vscale = vdupq_n_f32(inv_scale);
float32x4_t vpoff = vdupq_n_f32(0.5f); float32x4_t vpoff = vdupq_n_f32(0.5f);
float32x4_t vnoff = vdupq_n_f32(-0.5f); float32x4_t vnoff = vdupq_n_f32(-0.5f);
const float* din_c = din + j * inner_size; const float* din_c = din + j * inner_size;
int16_t* dout_c = dout + j * inner_size; int16_t* dout_c = dout + j * inner_size;
if (cnt > 0) { if (cnt > 0) {
int cnt_loop = cnt; int cnt_loop = cnt;
const float* din_ptr = din_c; const float* din_ptr = din_c;
int16_t* dout_ptr = dout_c; int16_t* dout_ptr = dout_c;
#ifdef __aarch64__ #ifdef __aarch64__
asm volatile( asm volatile(
"ldp q0, q1, [%[in]], #32 \n" "ldp q0, q1, [%[in]], #32 \n"
"0: \n" /* main loop */ "0: \n" /* main loop */
"fmul v4.4s, v0.4s, %[scale].4s \n" "fmul v4.4s, v0.4s, %[scale].4s \n"
"fmul v5.4s, v1.4s, %[scale].4s \n" "fmul v5.4s, v1.4s, %[scale].4s \n"
"ldp q0, q1, [%[in]], #32 \n" "ldp q0, q1, [%[in]], #32 \n"
"subs %[cnt], %[cnt], #1 \n" "subs %[cnt], %[cnt], #1 \n"
"FCVTAS v8.4s, v4.4s \n" "FCVTAS v8.4s, v4.4s \n"
"FCVTAS v9.4s, v5.4s \n" "FCVTAS v9.4s, v5.4s \n"
"sqxtn v4.4h, v8.4s \n" "sqxtn v4.4h, v8.4s \n"
"sqxtn2 v4.8h, v9.4s \n" "sqxtn2 v4.8h, v9.4s \n"
"str q4, [%[out]], #16 \n" "str q4, [%[out]], #16 \n"
"bne 0b \n" "bne 0b \n"
: [in] "+r" (din_ptr), [out] "+r" (dout_ptr), [cnt] "+r" (cnt_loop) : [in] "+r"(din_ptr), [out] "+r"(dout_ptr), [cnt] "+r"(cnt_loop)
: [scale] "w" (vscale) : [scale] "w"(vscale)
: "v0", "v1", "v4", "v5", "v8", "v9" : "v0", "v1", "v4", "v5", "v8", "v9");
);
#else #else
asm volatile( asm volatile(
"vld1.32 {d0-d3}, [%[din]]! @ load in0~in7\n" "vld1.32 {d0-d3}, [%[din]]! @ load in0~in7\n"
"0: @ main loop\n" "0: @ main loop\n"
"vand.i32 q4, %q[vpoff], %q[vpoff] @ set offset, 0.5\n" "vand.i32 q4, %q[vpoff], %q[vpoff] @ set offset, 0.5\n"
"vand.i32 q5, q4, q4 @ set offset, 0.5\n" "vand.i32 q5, q4, q4 @ set offset, 0.5\n"
"vand.i32 q6, q4, q4 @ set offset, 0.5\n" "vand.i32 q6, q4, q4 @ set offset, 0.5\n"
"vand.i32 q7, q4, q4 @ set offset, 0.5\n" "vand.i32 q7, q4, q4 @ set offset, 0.5\n"
"vcgt.f32 q8, q0, %q[vzero] @ get mask > 0, in0\n" "vcgt.f32 q8, q0, %q[vzero] @ get mask > 0, in0\n"
"vcgt.f32 q9, q1, %q[vzero] @ get mask > 0, in1\n" "vcgt.f32 q9, q1, %q[vzero] @ get mask > 0, in1\n"
"vbif.f32 q4, %q[vnoff], q8 @ get right offset\n" "vbif.f32 q4, %q[vnoff], q8 @ get right offset\n"
"vbif.f32 q5, %q[vnoff], q9 @ get right offset\n" "vbif.f32 q5, %q[vnoff], q9 @ get right offset\n"
"vmla.f32 q4, q0, %q[vscale] @ mul scale\n" "vmla.f32 q4, q0, %q[vscale] @ mul scale\n"
"vmla.f32 q5, q1, %q[vscale] @ mul scale\n" "vmla.f32 q5, q1, %q[vscale] @ mul scale\n"
"vcvt.s32.f32 q0, q4 @ cvt to int32\n" "vcvt.s32.f32 q0, q4 @ cvt to int32\n"
"vcvt.s32.f32 q1, q5 @ cvt to int32\n" "vcvt.s32.f32 q1, q5 @ cvt to int32\n"
"vqmovn.s32 d8, q0 @ cnt to int16\n" "vqmovn.s32 d8, q0 @ cnt to int16\n"
"vqmovn.s32 d9, q1 @ cnt to int16\n" "vqmovn.s32 d9, q1 @ cnt to int16\n"
"vld1.32 {d0-d3}, [%[din]]! @ load in0~in7\n" "vld1.32 {d0-d3}, [%[din]]! @ load in0~in7\n"
"vst1.32 {d8-d9}, [%[dout]]! @ write to output\n" "vst1.32 {d8-d9}, [%[dout]]! @ write to output\n"
"subs %[cnt], #1 @ loop count -1\n" "subs %[cnt], #1 @ loop count -1\n"
"bne 0b @ to main loop\n" "bne 0b @ to main loop\n"
:[dout]"+r"(dout_ptr), [din]"+r"(din_ptr), [cnt]"+r"(cnt_loop) : [dout] "+r"(dout_ptr), [din] "+r"(din_ptr), [cnt] "+r"(cnt_loop)
:[vscale]"w"(vscale), [vpoff]"w"(vpoff), [vnoff]"w"(vnoff), [vzero]"w"(vzero) : [vscale] "w"(vscale), [vpoff] "w"(vpoff), [vnoff] "w"(vnoff),
:"q0", "q1", "q4", "q5", "q6", "q7", "q8", "q9" [vzero] "w"(vzero)
); : "q0", "q1", "q4", "q5", "q6", "q7", "q8", "q9");
#endif #endif
}
const float* din_r = din_c + 8 * cnt;
int16_t* dout_r = dout_c + 8 * cnt;
for (int i = 0; i < remain; ++i) {
dout_r[i] = saturate_cast<int16_t>(roundf(inv_scale * din_r[i]));
}
} }
const float* din_r = din_c + 8 * cnt;
int16_t* dout_r = dout_c + 8 * cnt;
for (int i = 0; i < remain; ++i) {
dout_r[i] = saturate_cast<int16_t>(roundf(inv_scale * din_r[i]));
}
}
} }
void int8_to_fp32(const signed char* in, float* out, const float* scale, void int8_to_fp32(const signed char* in, float* out, const float* scale,
int axis_size, long long outer_size, long long inner_size) { int axis_size, int64_t outer_size, int64_t inner_size) {
int cnt = inner_size / 16;
int cnt = inner_size / 16; int remain = inner_size & 15;
int remain = inner_size & 15; int64_t loop_size = axis_size * outer_size;
long long loop_size = axis_size * outer_size;
#pragma omp parallel for #pragma omp parallel for
for (long long n = 0; n < loop_size; ++n) { for (int64_t n = 0; n < loop_size; ++n) {
float in_scale = scale[n % axis_size]; float in_scale = scale[n % axis_size];
const signed char* din_c = in + n * inner_size; const signed char* din_c = in + n * inner_size;
float* dout_c = out + n * inner_size; float* dout_c = out + n * inner_size;
float32x4_t vscale = vdupq_n_f32(in_scale); float32x4_t vscale = vdupq_n_f32(in_scale);
if (cnt > 0) { if (cnt > 0) {
int loop = cnt; int loop = cnt;
const signed char* din_ptr = din_c; const signed char* din_ptr = din_c;
float* dout_ptr = dout_c; float* dout_ptr = dout_c;
#ifdef __aarch64__ #ifdef __aarch64__
asm volatile( asm volatile(
"ldp d0, d1, [%[in]], #16 \n" /* load 16 int8*/ "ldp d0, d1, [%[in]], #16 \n" /* load 16 int8*/
"0: \n" /* main loop */ "0: \n" /* main loop */
"sshll v2.8h, v0.8b, #0 \n" /* trans to int16*/ "sshll v2.8h, v0.8b, #0 \n" /* trans to int16*/
"sshll v3.8h, v1.8b, #0 \n" /* trans to int16*/ "sshll v3.8h, v1.8b, #0 \n" /* trans to int16*/
"sshll v4.4s, v2.4h, #0 \n" /* trans to int32*/ "sshll v4.4s, v2.4h, #0 \n" /* trans to int32*/
"sshll2 v5.4s, v2.8h, #0 \n" /* trans to int32*/ "sshll2 v5.4s, v2.8h, #0 \n" /* trans to int32*/
"sshll v6.4s, v3.4h, #0 \n" /* trans to int32*/ "sshll v6.4s, v3.4h, #0 \n" /* trans to int32*/
"sshll2 v7.4s, v3.8h, #0 \n" /* trans to int32*/ "sshll2 v7.4s, v3.8h, #0 \n" /* trans to int32*/
"ldp d0, d1, [%[in]], #16 \n" /* load 16 int8*/ "ldp d0, d1, [%[in]], #16 \n" /* load 16 int8*/
"scvtf v8.4s, v4.4s \n" /* trans to fp32*/ "scvtf v8.4s, v4.4s \n" /* trans to fp32*/
"scvtf v9.4s, v5.4s \n" /* trans to fp32*/ "scvtf v9.4s, v5.4s \n" /* trans to fp32*/
"scvtf v10.4s, v6.4s \n" /* trans to fp32*/ "scvtf v10.4s, v6.4s \n" /* trans to fp32*/
"scvtf v11.4s, v7.4s \n" /* trans to fp32*/ "scvtf v11.4s, v7.4s \n" /* trans to fp32*/
"subs %[loop], %[loop], #1 \n" "subs %[loop], %[loop], #1 \n"
"fmul v4.4s, v8.4s, %[scale].4s \n" /* mul with scale*/ "fmul v4.4s, v8.4s, %[scale].4s \n" /* mul with scale*/
"fmul v5.4s, v9.4s, %[scale].4s \n" /* mul with scale*/ "fmul v5.4s, v9.4s, %[scale].4s \n" /* mul with scale*/
"fmul v6.4s, v10.4s, %[scale].4s \n" /* mul with scale*/ "fmul v6.4s, v10.4s, %[scale].4s \n" /* mul with scale*/
"fmul v7.4s, v11.4s, %[scale].4s \n" /* mul with scale*/ "fmul v7.4s, v11.4s, %[scale].4s \n" /* mul with scale*/
"stp q4, q5, [%[out]], #32 \n" /* write to memory*/ "stp q4, q5, [%[out]], #32 \n" /* write to memory*/
"stp q6, q7, [%[out]], #32 \n" /* write to memory*/ "stp q6, q7, [%[out]], #32 \n" /* write to memory*/
"bne 0b \n" "bne 0b \n"
:[loop] "+r" (loop), [in] "+r" (din_ptr), [out] "+r" (dout_ptr) : [loop] "+r"(loop), [in] "+r"(din_ptr), [out] "+r"(dout_ptr)
:[scale] "w" (vscale) : [scale] "w"(vscale)
:"v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11" : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
); "v11");
#else #else
asm volatile( asm volatile(
"vld1.32 {d0-d1}, [%[in]]! @ load 16 int8\n" "vld1.32 {d0-d1}, [%[in]]! @ load 16 int8\n"
"0: @ main loop\n" "0: @ main loop\n"
"vmovl.s8 q2, d0 @ trans to int16\n" "vmovl.s8 q2, d0 @ trans to int16\n"
"vmovl.s8 q3, d1 @ trans to int16\n" "vmovl.s8 q3, d1 @ trans to int16\n"
"vmovl.s16 q4, d4 @ trans to int32\n" "vmovl.s16 q4, d4 @ trans to int32\n"
"vmovl.s16 q5, d5 @ trans to int32\n" "vmovl.s16 q5, d5 @ trans to int32\n"
"vmovl.s16 q6, d6 @ trans to int32\n" "vmovl.s16 q6, d6 @ trans to int32\n"
"vmovl.s16 q7, d7 @ trans to int32\n" "vmovl.s16 q7, d7 @ trans to int32\n"
"vcvt.f32.s32 q0, q4 @ trans to fp32\n" "vcvt.f32.s32 q0, q4 @ trans to fp32\n"
"vcvt.f32.s32 q1, q5 @ trans to fp32\n" "vcvt.f32.s32 q1, q5 @ trans to fp32\n"
"vcvt.f32.s32 q2, q6 @ trans to fp32\n" "vcvt.f32.s32 q2, q6 @ trans to fp32\n"
"vcvt.f32.s32 q3, q7 @ trans to fp32\n" "vcvt.f32.s32 q3, q7 @ trans to fp32\n"
"vmul.f32 q4, q0, %q[scale] @ mul with scale\n" "vmul.f32 q4, q0, %q[scale] @ mul with scale\n"
"vmul.f32 q5, q1, %q[scale] @ mul with scale\n" "vmul.f32 q5, q1, %q[scale] @ mul with scale\n"
"vmul.f32 q6, q2, %q[scale] @ mul with scale\n" "vmul.f32 q6, q2, %q[scale] @ mul with scale\n"
"vmul.f32 q7, q3, %q[scale] @ mul with scale\n" "vmul.f32 q7, q3, %q[scale] @ mul with scale\n"
"vld1.32 {d0-d1}, [%[in]]! @ load 16 int8\n" "vld1.32 {d0-d1}, [%[in]]! @ load 16 int8\n"
"subs %[loop], #1 \n" "subs %[loop], #1 \n"
"vst1.f32 {d8-d11}, [%[out]]! @ write to memory\n" "vst1.f32 {d8-d11}, [%[out]]! @ write to memory\n"
"vst1.f32 {d12-d15}, [%[out]]! @ write to memory\n" "vst1.f32 {d12-d15}, [%[out]]! @ write to memory\n"
"bne 0b \n" "bne 0b \n"
:[loop] "+r" (loop), [in] "+r" (din_ptr), [out] "+r" (dout_ptr) : [loop] "+r"(loop), [in] "+r"(din_ptr), [out] "+r"(dout_ptr)
:[scale] "w" (vscale) : [scale] "w"(vscale)
:"q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7" : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7");
); #endif // __aarch64__
#endif //__aarch64__ }
} const signed char* din_r = din_c + 16 * cnt;
const signed char* din_r = din_c + 16 * cnt; float* dout_r = dout_c + 16 * cnt;
float* dout_r = dout_c + 16 * cnt; for (int i = 0; i < remain; ++i) {
for (int i = 0; i < remain; ++i) { dout_r[i] = in_scale * din_r[i];
dout_r[i] = in_scale * din_r[i];
}
} }
}
} }
void int16_to_fp32(const short* in, float* out, const float* scale, void int16_to_fp32(const int16_t* in, float* out, const float* scale,
int axis_size, long long outer_size, long long inner_size) { int axis_size, int64_t outer_size, int64_t inner_size) {
int cnt = inner_size / 16;
int cnt = inner_size / 16; int remain = inner_size & 15;
int remain = inner_size & 15; int64_t loop_size = axis_size * outer_size;
long long loop_size = axis_size * outer_size;
#pragma omp parallel for #pragma omp parallel for
for (long long n = 0; n < loop_size; ++n) { for (int64_t n = 0; n < loop_size; ++n) {
float in_scale = scale[n % axis_size]; float in_scale = scale[n % axis_size];
const short* din_c = in + n * inner_size; const int16_t* din_c = in + n * inner_size;
float* dout_c = out + n * inner_size; float* dout_c = out + n * inner_size;
float32x4_t vscale = vdupq_n_f32(in_scale); float32x4_t vscale = vdupq_n_f32(in_scale);
if (cnt > 0) { if (cnt > 0) {
int loop = cnt; int loop = cnt;
const short* din_ptr = din_c; const int16_t* din_ptr = din_c;
float* dout_ptr = dout_c; float* dout_ptr = dout_c;
#ifdef __aarch64__ #ifdef __aarch64__
asm volatile( asm volatile(
"ldp q0, q1, [%[in]], #32 \n" /* load 16 int16*/ "ldp q0, q1, [%[in]], #32 \n" /* load 16 int16*/
"0: \n" /* main loop */ "0: \n" /* main loop */
"sshll v4.4s, v0.4h, #0 \n" /* trans to int32*/ "sshll v4.4s, v0.4h, #0 \n" /* trans to int32*/
"sshll2 v5.4s, v0.8h, #0 \n" /* trans to int32*/ "sshll2 v5.4s, v0.8h, #0 \n" /* trans to int32*/
"sshll v6.4s, v1.4h, #0 \n" /* trans to int32*/ "sshll v6.4s, v1.4h, #0 \n" /* trans to int32*/
"sshll2 v7.4s, v1.8h, #0 \n" /* trans to int32*/ "sshll2 v7.4s, v1.8h, #0 \n" /* trans to int32*/
"ldp q0, q1, [%[in]], #32 \n" /* load 16 int16*/ "ldp q0, q1, [%[in]], #32 \n" /* load 16 int16*/
"scvtf v8.4s, v4.4s \n" /* trans to fp32*/ "scvtf v8.4s, v4.4s \n" /* trans to fp32*/
"scvtf v9.4s, v5.4s \n" /* trans to fp32*/ "scvtf v9.4s, v5.4s \n" /* trans to fp32*/
"scvtf v10.4s, v6.4s \n" /* trans to fp32*/ "scvtf v10.4s, v6.4s \n" /* trans to fp32*/
"scvtf v11.4s, v7.4s \n" /* trans to fp32*/ "scvtf v11.4s, v7.4s \n" /* trans to fp32*/
"subs %[loop], %[loop], #1 \n" "subs %[loop], %[loop], #1 \n"
"fmul v4.4s, v8.4s, %[scale].4s \n" /* mul with scale*/ "fmul v4.4s, v8.4s, %[scale].4s \n" /* mul with scale*/
"fmul v5.4s, v9.4s, %[scale].4s \n" /* mul with scale*/ "fmul v5.4s, v9.4s, %[scale].4s \n" /* mul with scale*/
"fmul v6.4s, v10.4s, %[scale].4s \n" /* mul with scale*/ "fmul v6.4s, v10.4s, %[scale].4s \n" /* mul with scale*/
"fmul v7.4s, v11.4s, %[scale].4s \n" /* mul with scale*/ "fmul v7.4s, v11.4s, %[scale].4s \n" /* mul with scale*/
"stp q4, q5, [%[out]], #32 \n" /* write to memory*/ "stp q4, q5, [%[out]], #32 \n" /* write to memory*/
"stp q6, q7, [%[out]], #32 \n" /* write to memory*/ "stp q6, q7, [%[out]], #32 \n" /* write to memory*/
"bne 0b \n" "bne 0b \n"
:[loop] "+r" (loop), [in] "+r" (din_ptr), [out] "+r" (dout_ptr) : [loop] "+r"(loop), [in] "+r"(din_ptr), [out] "+r"(dout_ptr)
:[scale] "w" (vscale) : [scale] "w"(vscale)
:"v0", "v1", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11" : "v0", "v1", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11");
);
#else #else
asm volatile( asm volatile(
"vld1.32 {d0-d3}, [%[in]]! @ load 16 int16\n" "vld1.32 {d0-d3}, [%[in]]! @ load 16 int16\n"
"0: @ main loop\n" "0: @ main loop\n"
"vmovl.s16 q4, d0 @ trans to int32\n" "vmovl.s16 q4, d0 @ trans to int32\n"
"vmovl.s16 q5, d1 @ trans to int32\n" "vmovl.s16 q5, d1 @ trans to int32\n"
"vmovl.s16 q6, d2 @ trans to int32\n" "vmovl.s16 q6, d2 @ trans to int32\n"
"vmovl.s16 q7, d3 @ trans to int32\n" "vmovl.s16 q7, d3 @ trans to int32\n"
"vcvt.f32.s32 q0, q4 @ trans to fp32\n" "vcvt.f32.s32 q0, q4 @ trans to fp32\n"
"vcvt.f32.s32 q1, q5 @ trans to fp32\n" "vcvt.f32.s32 q1, q5 @ trans to fp32\n"
"vcvt.f32.s32 q2, q6 @ trans to fp32\n" "vcvt.f32.s32 q2, q6 @ trans to fp32\n"
"vcvt.f32.s32 q3, q7 @ trans to fp32\n" "vcvt.f32.s32 q3, q7 @ trans to fp32\n"
"vmul.f32 q4, q0, %q[scale] @ mul with scale\n" "vmul.f32 q4, q0, %q[scale] @ mul with scale\n"
"vmul.f32 q5, q1, %q[scale] @ mul with scale\n" "vmul.f32 q5, q1, %q[scale] @ mul with scale\n"
"vmul.f32 q6, q2, %q[scale] @ mul with scale\n" "vmul.f32 q6, q2, %q[scale] @ mul with scale\n"
"vmul.f32 q7, q3, %q[scale] @ mul with scale\n" "vmul.f32 q7, q3, %q[scale] @ mul with scale\n"
"vld1.32 {d0-d3}, [%[in]]! @ load 16 int8\n" "vld1.32 {d0-d3}, [%[in]]! @ load 16 int8\n"
"subs %[loop], #1 \n" "subs %[loop], #1 \n"
"vst1.f32 {d8-d11}, [%[out]]! @ write to memory\n" "vst1.f32 {d8-d11}, [%[out]]! @ write to memory\n"
"vst1.f32 {d12-d15}, [%[out]]! @ write to memory\n" "vst1.f32 {d12-d15}, [%[out]]! @ write to memory\n"
"bne 0b \n" "bne 0b \n"
:[loop] "+r" (loop), [in] "+r" (din_ptr), [out] "+r" (dout_ptr) : [loop] "+r"(loop), [in] "+r"(din_ptr), [out] "+r"(dout_ptr)
:[scale] "w" (vscale) : [scale] "w"(vscale)
:"q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7" : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7");
); #endif // __aarch64__
#endif //__aarch64__ }
} const int16_t* din_r = din_c + 16 * cnt;
const short* din_r = din_c + 16 * cnt; float* dout_r = dout_c + 16 * cnt;
float* dout_r = dout_c + 16 * cnt; for (int i = 0; i < remain; ++i) {
for (int i = 0; i < remain; ++i) { dout_r[i] = in_scale * din_r[i];
dout_r[i] = in_scale * din_r[i];
}
} }
}
} }
void int32_to_fp32(const int* din, float* dout, const float* scale, void int32_to_fp32(const int* din, float* dout, const float* scale,
int axis_size, long long outer_size, long long inner_size) { int axis_size, int64_t outer_size, int64_t inner_size) {
int cnt = inner_size / 16; int cnt = inner_size / 16;
int remain = inner_size & 15; int remain = inner_size & 15;
long long loop_size = axis_size * outer_size; int64_t loop_size = axis_size * outer_size;
#pragma omp parallel for #pragma omp parallel for
for (long long n = 0; n < loop_size; ++n) { for (int64_t n = 0; n < loop_size; ++n) {
float in_scale = scale[n % axis_size]; float in_scale = scale[n % axis_size];
const int* din_c = din + n * inner_size; const int* din_c = din + n * inner_size;
float* dout_c = dout + n * inner_size; float* dout_c = dout + n * inner_size;
float32x4_t vscale = vdupq_n_f32(in_scale); float32x4_t vscale = vdupq_n_f32(in_scale);
if (cnt > 0) { if (cnt > 0) {
int loop = cnt; int loop = cnt;
const int* din_ptr = din_c; const int* din_ptr = din_c;
float* dout_ptr = dout_c; float* dout_ptr = dout_c;
#ifdef __aarch64__ #ifdef __aarch64__
asm volatile( asm volatile(
"ldp q0, q1, [%[in]], #32 \n" "ldp q0, q1, [%[in]], #32 \n"
"ldp q2, q3, [%[in]], #32 \n" "ldp q2, q3, [%[in]], #32 \n"
"0: \n" "0: \n"
"scvtf v4.4s, v0.4s \n" "scvtf v4.4s, v0.4s \n"
"scvtf v5.4s, v1.4s \n" "scvtf v5.4s, v1.4s \n"
"scvtf v6.4s, v2.4s \n" "scvtf v6.4s, v2.4s \n"
"scvtf v7.4s, v3.4s \n" "scvtf v7.4s, v3.4s \n"
"ldp q0, q1, [%[in]], #32 \n" "ldp q0, q1, [%[in]], #32 \n"
"fmul v8.4s, v4.4s, %[scale].4s \n" "fmul v8.4s, v4.4s, %[scale].4s \n"
"fmul v9.4s, v5.4s, %[scale].4s \n" "fmul v9.4s, v5.4s, %[scale].4s \n"
"fmul v10.4s, v6.4s, %[scale].4s \n" "fmul v10.4s, v6.4s, %[scale].4s \n"
"fmul v11.4s, v7.4s, %[scale].4s \n" "fmul v11.4s, v7.4s, %[scale].4s \n"
"ldp q2, q3, [%[in]], #32 \n" "ldp q2, q3, [%[in]], #32 \n"
"stp q8, q9, [%[out]], #32 \n" "stp q8, q9, [%[out]], #32 \n"
"stp q10, q11, [%[out]], #32 \n" "stp q10, q11, [%[out]], #32 \n"
"subs %[loop], %[loop], #1 \n" "subs %[loop], %[loop], #1 \n"
"bne 0b \n" "bne 0b \n"
:[loop] "+r" (loop), [in] "+r" (din_ptr), [out] "+r" (dout_ptr) : [loop] "+r"(loop), [in] "+r"(din_ptr), [out] "+r"(dout_ptr)
:[scale] "w" (vscale) : [scale] "w"(vscale)
:"v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11" : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
); "v11");
#else #else
asm volatile( asm volatile(
"vld1.s32 {d0-d3}, [%[in]]! \n" "vld1.s32 {d0-d3}, [%[in]]! \n"
"vld1.s32 {d4-d7}, [%[in]]! \n" "vld1.s32 {d4-d7}, [%[in]]! \n"
"0: \n" "0: \n"
"vcvt.f32.s32 q4, q0 \n" "vcvt.f32.s32 q4, q0 \n"
"vcvt.f32.s32 q5, q1 \n" "vcvt.f32.s32 q5, q1 \n"
"vcvt.f32.s32 q6, q2 \n" "vcvt.f32.s32 q6, q2 \n"
"vcvt.f32.s32 q7, q3 \n" "vcvt.f32.s32 q7, q3 \n"
"vld1.s32 {d0-d3}, [%[in]]! \n" "vld1.s32 {d0-d3}, [%[in]]! \n"
"vmul.f32 q8, q4, %q[scale] \n" "vmul.f32 q8, q4, %q[scale] \n"
"vmul.f32 q9, q5, %q[scale] \n" "vmul.f32 q9, q5, %q[scale] \n"
"vmul.f32 q10, q6, %q[scale] \n" "vmul.f32 q10, q6, %q[scale] \n"
"vmul.f32 q11, q7, %q[scale] \n" "vmul.f32 q11, q7, %q[scale] \n"
"vld1.s32 {d4-d7}, [%[in]]! \n" "vld1.s32 {d4-d7}, [%[in]]! \n"
"subs %[loop], #1 \n" "subs %[loop], #1 \n"
"vst1.f32 {d16-d19}, [%[out]]! \n" "vst1.f32 {d16-d19}, [%[out]]! \n"
"vst1.f32 {d20-d23}, [%[out]]! \n" "vst1.f32 {d20-d23}, [%[out]]! \n"
"bne 0b \n" "bne 0b \n"
:[loop] "+r" (loop), [in] "+r" (din_ptr), [out] "+r" (dout_ptr) : [loop] "+r"(loop), [in] "+r"(din_ptr), [out] "+r"(dout_ptr)
:[scale] "w" (vscale) : [scale] "w"(vscale)
:"q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11" : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10",
); "q11");
#endif //__aarch64__ #endif // __aarch64__
}
const int* din_r = din_c + 16 * cnt;
float* dout_r = dout_c + 16 * cnt;
for (int i = 0; i < remain; ++i) {
dout_r[i] = in_scale * din_r[i];
}
} }
const int* din_r = din_c + 16 * cnt;
float* dout_r = dout_c + 16 * cnt;
for (int i = 0; i < remain; ++i) {
dout_r[i] = in_scale * din_r[i];
}
}
} }
void int32_to_int8(const int* din, signed char* dout, const float* scale, \ void int32_to_int8(const int* din, signed char* dout, const float* scale,
int axis_size, long long outer_size, long long inner_size) { int axis_size, int64_t outer_size, int64_t inner_size) {
int cnt = inner_size / 16; int cnt = inner_size / 16;
int remain = inner_size & 15; int remain = inner_size & 15;
long long loop_size = outer_size * axis_size; int64_t loop_size = outer_size * axis_size;
#pragma omp parallel for #pragma omp parallel for
for (long long n = 0; n < loop_size; ++n) { for (int64_t n = 0; n < loop_size; ++n) {
float in_scale = scale[n % axis_size]; float in_scale = scale[n % axis_size];
const int* din_c = din + n * inner_size; const int* din_c = din + n * inner_size;
signed char* dout_c = dout + n * inner_size; signed char* dout_c = dout + n * inner_size;
float32x4_t vscale = vdupq_n_f32(in_scale); float32x4_t vscale = vdupq_n_f32(in_scale);
float32x4_t vzero = vdupq_n_f32(0.f); float32x4_t vzero = vdupq_n_f32(0.f);
float32x4_t vpoff = vdupq_n_f32(0.5f); float32x4_t vpoff = vdupq_n_f32(0.5f);
float32x4_t vnoff = vdupq_n_f32(-0.5f); float32x4_t vnoff = vdupq_n_f32(-0.5f);
if (cnt > 0) { if (cnt > 0) {
int loop = cnt; int loop = cnt;
const int* din_ptr = din_c; const int* din_ptr = din_c;
signed char* dout_ptr = dout_c; signed char* dout_ptr = dout_c;
#ifdef __aarch64__ #ifdef __aarch64__
asm volatile( asm volatile(
"0: \n" "0: \n"
"ld1 {v0.4s, v1.4s}, [%[in]], #32 \n" "ld1 {v0.4s, v1.4s}, [%[in]], #32 \n"
"ld1 {v2.4s, v3.4s}, [%[in]], #32 \n" "ld1 {v2.4s, v3.4s}, [%[in]], #32 \n"
"scvtf v4.4s, v0.4s \n" "scvtf v4.4s, v0.4s \n"
"scvtf v5.4s, v1.4s \n" "scvtf v5.4s, v1.4s \n"
"scvtf v6.4s, v2.4s \n" "scvtf v6.4s, v2.4s \n"
"scvtf v7.4s, v3.4s \n" "scvtf v7.4s, v3.4s \n"
"fmul v0.4s, v4.4s, %[scale].4s \n" "fmul v0.4s, v4.4s, %[scale].4s \n"
"fmul v1.4s, v5.4s, %[scale].4s \n" "fmul v1.4s, v5.4s, %[scale].4s \n"
"fmul v2.4s, v6.4s, %[scale].4s \n" "fmul v2.4s, v6.4s, %[scale].4s \n"
"fmul v3.4s, v7.4s, %[scale].4s \n" "fmul v3.4s, v7.4s, %[scale].4s \n"
"fcvtas v4.4s, v0.4s \n" "fcvtas v4.4s, v0.4s \n"
"fcvtas v5.4s, v1.4s \n" "fcvtas v5.4s, v1.4s \n"
"fcvtas v6.4s, v2.4s \n" "fcvtas v6.4s, v2.4s \n"
"fcvtas v7.4s, v3.4s \n" "fcvtas v7.4s, v3.4s \n"
"sqxtn v0.4h, v4.4s \n" "sqxtn v0.4h, v4.4s \n"
"sqxtn2 v0.8h, v5.4s \n" "sqxtn2 v0.8h, v5.4s \n"
"sqxtn v1.4h, v6.4s \n" "sqxtn v1.4h, v6.4s \n"
"sqxtn2 v1.8h, v7.4s \n" "sqxtn2 v1.8h, v7.4s \n"
"sqxtn v2.8b, v0.8h \n" "sqxtn v2.8b, v0.8h \n"
"sqxtn2 v2.16b, v1.8h \n" "sqxtn2 v2.16b, v1.8h \n"
"st1 {v2.16b}, [%[out]], #16 \n" "st1 {v2.16b}, [%[out]], #16 \n"
"subs %[loop], %[loop], #1 \n" "subs %[loop], %[loop], #1 \n"
"bne 0b \n" "bne 0b \n"
:[loop] "+r" (loop), [in] "+r" (din_ptr), [out] "+r" (dout_ptr) : [loop] "+r"(loop), [in] "+r"(din_ptr), [out] "+r"(dout_ptr)
:[scale] "w" (vscale) : [scale] "w"(vscale)
:"v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7");
);
#else #else
asm volatile( asm volatile(
"vld1.32 {d0-d3}, [%[din]]! @ load in0~in7\n" "vld1.32 {d0-d3}, [%[din]]! @ load in0~in7\n"
"vld1.32 {d4-d7}, [%[din]]! @ load in8~in16\n" "vld1.32 {d4-d7}, [%[din]]! @ load in8~in16\n"
"0: @ main loop\n" "0: @ main loop\n"
"vcvt.f32.s32 q4, q0 @ cvt to float\n" "vcvt.f32.s32 q4, q0 @ cvt to float\n"
"vcvt.f32.s32 q5, q1 @ cvt to float\n" "vcvt.f32.s32 q5, q1 @ cvt to float\n"
"vcvt.f32.s32 q6, q2 @ cvt to float\n" "vcvt.f32.s32 q6, q2 @ cvt to float\n"
"vcvt.f32.s32 q7, q3 @ cvt to float\n" "vcvt.f32.s32 q7, q3 @ cvt to float\n"
"vand.i32 q0, %q[vpoff], %q[vpoff] @ set offset, 0.5\n" "vand.i32 q0, %q[vpoff], %q[vpoff] @ set offset, 0.5\n"
"vand.i32 q1, q0, q0 @ set offset, 0.5\n" "vand.i32 q1, q0, q0 @ set offset, 0.5\n"
"vand.i32 q2, q0, q0 @ set offset, 0.5\n" "vand.i32 q2, q0, q0 @ set offset, 0.5\n"
"vand.i32 q3, q0, q0 @ set offset, 0.5\n" "vand.i32 q3, q0, q0 @ set offset, 0.5\n"
"vcgt.f32 q8, q4, %q[vzero] @ get mask > 0, in0\n" "vcgt.f32 q8, q4, %q[vzero] @ get mask > 0, in0\n"
"vcgt.f32 q9, q5, %q[vzero] @ get mask > 0, in1\n" "vcgt.f32 q9, q5, %q[vzero] @ get mask > 0, in1\n"
"vcgt.f32 q10, q6, %q[vzero] @ get mask > 0, in2\n" "vcgt.f32 q10, q6, %q[vzero] @ get mask > 0, in2\n"
"vcgt.f32 q11, q7, %q[vzero] @ get mask > 0, in3\n" "vcgt.f32 q11, q7, %q[vzero] @ get mask > 0, in3\n"
"vbif.f32 q0, %q[vnoff], q8 @ get right offset\n" "vbif.f32 q0, %q[vnoff], q8 @ get right offset\n"
"vbif.f32 q1, %q[vnoff], q9 @ get right offset\n" "vbif.f32 q1, %q[vnoff], q9 @ get right offset\n"
"vbif.f32 q2, %q[vnoff], q10 @ get right offset\n" "vbif.f32 q2, %q[vnoff], q10 @ get right offset\n"
"vbif.f32 q3, %q[vnoff], q11 @ get right offset\n" "vbif.f32 q3, %q[vnoff], q11 @ get right offset\n"
"vmla.f32 q0, q4, %q[vscale] @ mul scale\n" "vmla.f32 q0, q4, %q[vscale] @ mul scale\n"
"vmla.f32 q1, q5, %q[vscale] @ mul scale\n" "vmla.f32 q1, q5, %q[vscale] @ mul scale\n"
"vmla.f32 q2, q6, %q[vscale] @ mul scale\n" "vmla.f32 q2, q6, %q[vscale] @ mul scale\n"
"vmla.f32 q3, q7, %q[vscale] @ mul scale\n" "vmla.f32 q3, q7, %q[vscale] @ mul scale\n"
"vcvt.s32.f32 q4, q0 @ cvt to int32\n" "vcvt.s32.f32 q4, q0 @ cvt to int32\n"
"vcvt.s32.f32 q5, q1 @ cvt to int32\n" "vcvt.s32.f32 q5, q1 @ cvt to int32\n"
"vcvt.s32.f32 q6, q2 @ cvt to int32\n" "vcvt.s32.f32 q6, q2 @ cvt to int32\n"
"vcvt.s32.f32 q7, q3 @ cvt to int32\n" "vcvt.s32.f32 q7, q3 @ cvt to int32\n"
"vqmovn.s32 d16, q4 @ cnt to int16\n" "vqmovn.s32 d16, q4 @ cnt to int16\n"
"vqmovn.s32 d17, q5 @ cnt to int16\n" "vqmovn.s32 d17, q5 @ cnt to int16\n"
"vqmovn.s32 d18, q6 @ cnt to int16\n" "vqmovn.s32 d18, q6 @ cnt to int16\n"
"vqmovn.s32 d19, q7 @ cnt to int16\n" "vqmovn.s32 d19, q7 @ cnt to int16\n"
"vld1.32 {d0-d3}, [%[din]]! @ load in0~in7\n" "vld1.32 {d0-d3}, [%[din]]! @ load in0~in7\n"
"vqmovn.s16 d8, q8 @ cnt to int8\n" "vqmovn.s16 d8, q8 @ cnt to int8\n"
"vqmovn.s16 d9, q9 @ cnt to int8\n" "vqmovn.s16 d9, q9 @ cnt to int8\n"
"vld1.32 {d4-d7}, [%[din]]! @ load in8~in16\n" "vld1.32 {d4-d7}, [%[din]]! @ load in8~in16\n"
"vst1.32 {d8-d9}, [%[dout]]! @ write to output\n" "vst1.32 {d8-d9}, [%[dout]]! @ write to output\n"
"subs %[loop], #1 @ loop count -1\n" "subs %[loop], #1 @ loop count -1\n"
"bne 0b @ to main loop\n" "bne 0b @ to main loop\n"
:[loop] "+r" (loop), [din] "+r" (din_ptr), [dout] "+r" (dout_ptr) : [loop] "+r"(loop), [din] "+r"(din_ptr), [dout] "+r"(dout_ptr)
:[vscale] "w" (vscale), [vzero] "w"(vzero), [vnoff] "w" (vnoff), [vpoff] "w" (vpoff) : [vscale] "w"(vscale), [vzero] "w"(vzero), [vnoff] "w"(vnoff),
:"q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11" [vpoff] "w"(vpoff)
); : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10",
#endif //__aarch64__ "q11");
} #endif // __aarch64__
const int* din_r = din_c + 16 * cnt; }
int8_t* dout_r = dout_c + 16 * cnt; const int* din_r = din_c + 16 * cnt;
for (int i = 0; i < remain; ++i) { int8_t* dout_r = dout_c + 16 * cnt;
dout_r[i] = saturate_cast<int8_t>(roundf(in_scale * din_r[i])); for (int i = 0; i < remain; ++i) {
} dout_r[i] = saturate_cast<int8_t>(roundf(in_scale * din_r[i]));
} }
}
} }
void int32_to_int32(const int* din, int* dout, const float* scale, \ void int32_to_int32(const int* din, int* dout, const float* scale,
int axis_size, long long outer_size, long long inner_size) { int axis_size, int64_t outer_size, int64_t inner_size) {
int size_all = outer_size * axis_size * inner_size; int size_all = outer_size * axis_size * inner_size;
memmove(dout, din, size_all*sizeof(int)); memmove(dout, din, size_all * sizeof(int));
} }
template <> template <>
void int32_to_dtype(const int* din, float* dout, const float* scale, void int32_to_dtype(const int* din, float* dout, const float* scale,
int axis_size, long long outer_size, long long inner_size) { int axis_size, int64_t outer_size, int64_t inner_size) {
return int32_to_fp32(din, dout, scale, axis_size, outer_size, inner_size);
return int32_to_fp32(din, dout, scale, axis_size, outer_size, inner_size);
} }
template <> template <>
void int32_to_dtype(const int* din, signed char* dout, const float* scale, void int32_to_dtype(const int* din, signed char* dout, const float* scale,
int axis_size, long long outer_size, long long inner_size) { int axis_size, int64_t outer_size, int64_t inner_size) {
return int32_to_int8(din, dout, scale, axis_size, outer_size, inner_size);
return int32_to_int8(din, dout, scale, axis_size, outer_size, inner_size);
} }
template <> template <>
void int32_to_dtype(const int* din, int* dout, const float* scale, void int32_to_dtype(const int* din, int* dout, const float* scale,
int axis_size, long long outer_size, long long inner_size) { int axis_size, int64_t outer_size, int64_t inner_size) {
return int32_to_int32(din, dout, scale, axis_size, outer_size, inner_size);
return int32_to_int32(din, dout, scale, axis_size, outer_size, inner_size);
} }
} // namespace math } // namespace math
......
...@@ -57,3 +57,4 @@ lite_cc_test(test_type_system SRCS type_system_test.cc DEPS type_system utils_li ...@@ -57,3 +57,4 @@ lite_cc_test(test_type_system SRCS type_system_test.cc DEPS type_system utils_li
lite_cc_test(test_types_lite SRCS types_test.cc DEPS types_lite) lite_cc_test(test_types_lite SRCS types_test.cc DEPS types_lite)
lite_cc_test(test_memory_lite SRCS memory_test.cc DEPS memory_lite) lite_cc_test(test_memory_lite SRCS memory_test.cc DEPS memory_lite)
lite_cc_test(test_context_lite SRCS context_test.cc DEPS context_lite X86_DEPS operator) lite_cc_test(test_context_lite SRCS context_test.cc DEPS context_lite X86_DEPS operator)
...@@ -59,3 +59,4 @@ if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) ...@@ -59,3 +59,4 @@ if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
pattern_matcher_high_api proto_desc mir_pass_manager fc_op_lite mul_op_lite elementwise_ops_lite pattern_matcher_high_api proto_desc mir_pass_manager fc_op_lite mul_op_lite elementwise_ops_lite
mir_passes compatible_pb_lite program_lite ${ops_lite}) mir_passes compatible_pb_lite program_lite ${ops_lite})
endif() endif()
...@@ -4,3 +4,4 @@ endif() ...@@ -4,3 +4,4 @@ endif()
lite_cc_library(basic_profiler_lite SRCS basic_profiler.cc) lite_cc_library(basic_profiler_lite SRCS basic_profiler.cc)
lite_cc_test(test_basic_profiler SRCS basic_profiler_test.cc DEPS basic_profiler_lite) lite_cc_test(test_basic_profiler SRCS basic_profiler_test.cc DEPS basic_profiler_lite)
...@@ -4,3 +4,4 @@ endif() ...@@ -4,3 +4,4 @@ endif()
nv_library(target_wrapper_cuda SRCS target_wrapper.cc) nv_library(target_wrapper_cuda SRCS target_wrapper.cc)
nv_library(cuda_blas_lite SRCS blas.cc) nv_library(cuda_blas_lite SRCS blas.cc)
...@@ -25,3 +25,4 @@ if (NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) ...@@ -25,3 +25,4 @@ if (NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
add_dependencies(__generated_code__ test_gen_code_lite) add_dependencies(__generated_code__ test_gen_code_lite)
endif() endif()
cc_library(target_wrapper_host SRCS target_wrapper.cc) cc_library(target_wrapper_host SRCS target_wrapper.cc)
...@@ -5,3 +5,4 @@ add_subdirectory(arm) ...@@ -5,3 +5,4 @@ add_subdirectory(arm)
add_subdirectory(cuda) add_subdirectory(cuda)
add_subdirectory(x86) add_subdirectory(x86)
...@@ -40,3 +40,4 @@ set(arm_kernels ...@@ -40,3 +40,4 @@ set(arm_kernels
set(arm_kernels "${arm_kernels}" CACHE INTERNAL "arm kernels") set(arm_kernels "${arm_kernels}" CACHE INTERNAL "arm kernels")
...@@ -9,3 +9,4 @@ cc_library(io_copy_compute_cuda SRCS io_copy_compute.cc DEPS ${tensor_lite}) ...@@ -9,3 +9,4 @@ cc_library(io_copy_compute_cuda SRCS io_copy_compute.cc DEPS ${tensor_lite})
nv_library(kernels_cuda DEPS mul_compute_cuda io_copy_compute_cuda cuda_blas_lite) nv_library(kernels_cuda DEPS mul_compute_cuda io_copy_compute_cuda cuda_blas_lite)
...@@ -13,3 +13,4 @@ set(host_kernels ...@@ -13,3 +13,4 @@ set(host_kernels
) )
set(host_kernels "${host_kernels}" CACHE GLOBAL "host kernels") set(host_kernels "${host_kernels}" CACHE GLOBAL "host kernels")
...@@ -35,3 +35,4 @@ set(x86_kernels ...@@ -35,3 +35,4 @@ set(x86_kernels
) )
set(x86_kernels "${x86_kernels}" CACHE INTERNAL "x86 kernels") set(x86_kernels "${x86_kernels}" CACHE INTERNAL "x86 kernels")
...@@ -27,3 +27,4 @@ lite_cc_test(test_op_desc_lite SRCS op_desc_test.cc DEPS cpp_op_desc_lite op_des ...@@ -27,3 +27,4 @@ lite_cc_test(test_op_desc_lite SRCS op_desc_test.cc DEPS cpp_op_desc_lite op_des
add_subdirectory(pb) add_subdirectory(pb)
add_subdirectory(cpp) add_subdirectory(cpp)
cc_library(cpp_op_desc_lite SRCS op_desc.cc DEPS any_lite) cc_library(cpp_op_desc_lite SRCS op_desc.cc DEPS any_lite)
cc_library(var_desc_lite SRCS var_desc.cc DEPS framework_proto_lite) cc_library(var_desc_lite SRCS var_desc.cc DEPS framework_proto_lite)
cc_library(op_desc_lite SRCS op_desc.cc DEPS framework_proto_lite) cc_library(op_desc_lite SRCS op_desc.cc DEPS framework_proto_lite)
...@@ -56,3 +56,4 @@ lite_cc_test(test_softmax_op_lite SRCS softmax_op_test.cc DEPS softmax_op_lite m ...@@ -56,3 +56,4 @@ lite_cc_test(test_softmax_op_lite SRCS softmax_op_test.cc DEPS softmax_op_lite m
lite_cc_test(test_reshape_op_lite SRCS reshape_op_test.cc DEPS reshape_op_lite memory_lite) lite_cc_test(test_reshape_op_lite SRCS reshape_op_test.cc DEPS reshape_op_lite memory_lite)
lite_cc_test(test_batch_norm_op_lite SRCS batch_norm_op_test.cc DEPS batch_norm_op_lite memory_lite) lite_cc_test(test_batch_norm_op_lite SRCS batch_norm_op_test.cc DEPS batch_norm_op_lite memory_lite)
lite_cc_test(test_concat_op_lite SRCS concat_op_test.cc DEPS concat_op_lite memory_lite) lite_cc_test(test_concat_op_lite SRCS concat_op_test.cc DEPS concat_op_lite memory_lite)
...@@ -88,3 +88,4 @@ RUN pip install -i https://pypi.tuna.tsinghua.edu.cn/simple wheel ...@@ -88,3 +88,4 @@ RUN pip install -i https://pypi.tuna.tsinghua.edu.cn/simple wheel
RUN pip install -i https://pypi.tuna.tsinghua.edu.cn/simple pre-commit RUN pip install -i https://pypi.tuna.tsinghua.edu.cn/simple pre-commit
RUN apt-get autoremove -y && apt-get clean RUN apt-get autoremove -y && apt-get clean
RUN rm -rf /sdk-tools-linux-4333796.zip /tmp/android-ndk-r17c-linux-x86_64.zip /cmake-3.10.3-Linux-x86_64.tar.gz RUN rm -rf /sdk-tools-linux-4333796.zip /tmp/android-ndk-r17c-linux-x86_64.zip /cmake-3.10.3-Linux-x86_64.tar.gz
\ No newline at end of file
...@@ -283,3 +283,4 @@ function main { ...@@ -283,3 +283,4 @@ function main {
} }
main $@ main $@
...@@ -124,3 +124,4 @@ $ adb devices ...@@ -124,3 +124,4 @@ $ adb devices
List of devices attached List of devices attached
5cb00b6 device 5cb00b6 device
``` ```
...@@ -9,3 +9,4 @@ set(utils_DEPS glog) ...@@ -9,3 +9,4 @@ set(utils_DEPS glog)
lite_cc_test(test_varient SRCS varient_test.cc DEPS utils_lite) lite_cc_test(test_varient SRCS varient_test.cc DEPS utils_lite)
cc_library(any_lite SRCS any.cc) cc_library(any_lite SRCS any.cc)
cc_library(utils_lite SRCS cp_logging.cc string.cc DEPS ${utils_DEPS} any_lite) cc_library(utils_lite SRCS cp_logging.cc string.cc DEPS ${utils_DEPS} any_lite)
...@@ -4,3 +4,4 @@ endif() ...@@ -4,3 +4,4 @@ endif()
cc_library(target_wrapper_x86 SRCS target_wrapper.cc) cc_library(target_wrapper_x86 SRCS target_wrapper.cc)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册
反馈
建议
客服 返回
顶部