Unverified commit 81dffbe8, authored by Xiaoyang LI, committed by GitHub

fix bias quantize error && fix clang build error (#2049)

* fix gemm_int8, gemv-int8 and conv-int8 math functions, add float bias

* change conv impl

* neon int8 kernel support float bias

* arm compute kernel support float bias

* add math_test target

* add tensor utils for testing, fix sgemm ut error

* add gemm_int8 unit test, support float bias

* fix build script

* add conv compute unit test for arm

* fix build script, test=develop

* fix fp32 dw conv3x3s1, test=develop

* add fp32 dw conv3x3s1, test=develop

* add armv7 fp32 dw conv3x3s1, test=develop

* add fp32 depthwise conv3x3s2, test=develop

* fix fp32 conv3x3 depthwise build error, test=develop

* fix gemm_like conv trans weights error, test=develop

* fix int8 depthwise conv3x3 error, test=develop

* turn on all test for arm fp32 conv, test=develop

* fix int8 conv1x1 error

* fix int8 direct conv3x3s1 error, test=develop

* fix int8 direct conv3x3s2, test=develop

* turn on all test for arm int8 conv, test=develop

* fix int8 fc error, change mobilenetv1-int8 ground-truth result to fluid, test=develop

* remove debug info, strip ut binary, test=develop

* fix conv compute error, test=develop

* change Init() to ReInitWhenNeeded(), test=develop

* fix code style, test=develop

* remove engine_test, test=develop

* fix building server tests error, test=develop

* fix sdot clang build error, test=develop

* fix sgemm ut timeout error, test=develop

* fix clang build error, test=develop

* turn off math basic test due to ci time out, test=develop

* fix conv_int8 ut error, test=develop
Parent 71bb3188
...@@ -165,6 +165,11 @@ function(lite_cc_binary TARGET)
)
cc_binary(${TARGET} SRCS ${args_SRCS} DEPS ${deps} ${args_DEPS})
target_compile_options(${TARGET} BEFORE PRIVATE -Wno-ignored-qualifiers)
# strip binary target to reduce size
add_custom_command(TARGET ${TARGET} POST_BUILD
COMMAND "${CMAKE_STRIP}" -s
"${TARGET}"
COMMENT "Strip debug symbols from the final executable file.")
# collect targets need to compile for lite
if (NOT args_EXCLUDE_COMPILE_DEPS)
add_dependencies(lite_compile_deps ${TARGET})
...@@ -207,6 +212,11 @@ function(lite_cc_test TARGET)
HVY_DEPS ${args_HVY_DEPS}
)
_lite_cc_test(${TARGET} SRCS ${args_SRCS} DEPS ${deps} ARGS ${args_ARGS})
# strip binary target to reduce size
add_custom_command(TARGET ${TARGET} POST_BUILD
COMMAND "${CMAKE_STRIP}" -s
"${TARGET}"
COMMENT "Strip debug symbols from the final executable file.")
target_compile_options(${TARGET} BEFORE PRIVATE -Wno-ignored-qualifiers)
file(APPEND ${offline_test_registry_file} "${TARGET}\n")
...
...@@ -68,7 +68,6 @@ class LITE_API Predictor {
GenRuntimeProgram();
}
program_->Run();
LOG(INFO) << "running";
}
// Get offset-th col of feed inputs.
...
...@@ -58,8 +58,9 @@ void TestModel(const std::vector<Place>& valid_places,
std::vector<std::vector<float>> results;
// i = 1
// ground truth result from fluid
results.emplace_back(std::vector<float>(
{0.000227548, 0.000262385, 0.000260347, 0.000293865, 0.00025008}));
{0.0002451055, 0.0002585023, 0.0002659616, 0.0002823}));
auto* out = predictor.GetOutput(0);
ASSERT_EQ(out->dims().size(), 2);
ASSERT_EQ(out->dims()[0], 1);
...
...@@ -6,6 +6,17 @@ if(NOT (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND LITE_WITH_ARM))
return()
endif()
set(script_dir ${CMAKE_CURRENT_SOURCE_DIR}/../../../tools/)
message(STATUS "generating arm dotprod code")
find_package(PythonInterp REQUIRED)
execute_process(COMMAND ${PYTHON_EXECUTABLE} ${script_dir}/convert_arm_sdot_to_machine_code.py
"--input_file=${CMAKE_CURRENT_SOURCE_DIR}/dotprod/__gemm_sdot_meta__.h"
"--output_file=${CMAKE_CURRENT_SOURCE_DIR}/dotprod/gemm_sdot.h"
RESULT_VARIABLE gen_code_ret)
if (NOT ${gen_code_ret} STREQUAL "0")
message(FATAL_ERROR "generating dotprod code quit with error: ${gen_code_ret}")
endif ()
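# Note: the converter rewrites each "sdot" mnemonic in the meta header into its
# raw ".word" opcode, so the generated file assembles even when the toolchain
# (e.g. an older clang) lacks dotprod support. Illustrative encoding only, an
# assumption for exposition and not taken from this commit:
#   sdot v0.4s, v1.16b, v2.16b   ->   .word 0x4e829420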
set(HAS_ARM_MATH_LIB_DIR OFF)
# will search name as "libmath_arm.${os}.${abi}.${lang}.a"
if(ARM_MATH_LIB_DIR AND EXISTS "${ARM_MATH_LIB_DIR}")
...@@ -50,6 +61,25 @@ if (NOT HAS_ARM_MATH_LIB_DIR)
funcs.cc
packed_sgemm.cc
sgemm.cc
gemm_prepacked_int8.cc
gemm_s8.cc
sgemv.cc
gemv_arm_int8.cc
conv3x3s1_direct_fp32.cc
conv3x3s2_direct_fp32.cc
conv3x3s1_depthwise_fp32.cc
conv3x3s2_depthwise_fp32.cc
conv3x3s1_direct_int8.cc
conv3x3s2_direct_int8.cc
conv3x3s1_depthwise_int8.cc
conv3x3s2_depthwise_int8.cc
conv5x5s1_depthwise_int8.cc
conv5x5s1_depthwise_fp32.cc
conv5x5s2_depthwise_fp32.cc
conv_depthwise_3x3p0.cc
conv_depthwise_3x3p1.cc
conv_winograd_3x3.cc
conv_impl.cc
softmax.cc
scale.cc
pooling.cc
...@@ -57,32 +87,13 @@ if (NOT HAS_ARM_MATH_LIB_DIR)
lrn.cc
decode_bboxes.cc
concat.cc
sgemv.cc
type_trans.cc
box_coder.cc
conv_impl.cc
conv_direct_3x3s1.cc
conv_direct_3x3s2.cc
conv_direct.cc
conv_depthwise_3x3_int8.cc
conv_depthwise_5x5s1_int8.cc
conv_depthwise_3x3p0.cc
conv_depthwise_3x3p1.cc
conv_depthwise_5x5s1.cc
conv_depthwise_5x5s2.cc
conv_depthwise.cc
conv_gemmlike.cc
conv_winograd_3x3.cc
conv_winograd.cc
split.cc
shuffle_channel.cc
activation.cc
yolo_box.cc
dropout.cc
gemm_prepacked_int8.cc
gemv_arm_int8.cc
conv3x3s1_direct_int8.cc
conv3x3s2_direct_int8.cc
power.cc
interpolate.cc
argmax.cc
...
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <arm_neon.h>
#include "lite/backends/arm/math/conv_block_utils.h"
#include "lite/backends/arm/math/conv_impl.h"
#include "lite/core/context.h"
#include "lite/operators/op_params.h"
#ifdef ARM_WITH_OMP
#include <omp.h>
#endif
namespace paddle {
namespace lite {
namespace arm {
namespace math {
void conv_3x3s1_depthwise_fp32(const float* i_data,
float* o_data,
int bs,
int oc,
int oh,
int ow,
int ic,
int ih,
int win,
const float* weights,
const float* bias,
const operators::ConvParam& param,
ARMContext* ctx) {
int threads = ctx->threads();
const int pad_h = param.paddings[0];
const int pad_w = param.paddings[1];
const int out_c_block = 4;
const int out_h_kernel = 2;
const int out_w_kernel = 4;
const int win_ext = ow + 2;
const int ow_round = ROUNDUP(ow, 4);
const int win_round = ROUNDUP(win_ext, 4);
const int hin_round = oh + 2;
const int prein_size = win_round * hin_round * out_c_block;
auto workspace_size =
threads * prein_size + win_round /*tmp zero*/ + ow_round /*tmp writer*/;
ctx->ExtendWorkspace(sizeof(float) * workspace_size);
bool flag_relu = param.fuse_relu;
bool flag_bias = param.bias != nullptr;
/// get workspace
float* ptr_zero = ctx->workspace_data<float>();
memset(ptr_zero, 0, sizeof(float) * win_round);
float* ptr_write = ptr_zero + win_round;
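// workspace layout (floats): [ptr_zero : win_round][ptr_write : ow_round]
// [per-thread packed input : prein_size each], matching workspace_size above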
int size_in_channel = win * ih;
int size_out_channel = ow * oh;
int ws = -pad_w;
int we = ws + win_round;
int hs = -pad_h;
int he = hs + hin_round;
int w_loop = ow_round / 4;
auto remain = w_loop * 4 - ow;
bool flag_remain = remain > 0;
remain = 4 - remain;
remain = remain > 0 ? remain : 0;
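// e.g. ow = 10: ow_round = 12, w_loop = 3, tail = w_loop * 4 - ow = 2, so
// flag_remain is set and remain = 4 - 2 = 2 valid columns in the last block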
int row_len = win_round * out_c_block;
for (int n = 0; n < bs; ++n) {
const float* din_batch = i_data + n * ic * size_in_channel;
float* dout_batch = o_data + n * oc * size_out_channel;
#pragma omp parallel for num_threads(threads)
for (int c = 0; c < oc; c += out_c_block) {
#ifdef ARM_WITH_OMP
float* pre_din = ptr_write + ow_round + omp_get_thread_num() * prein_size;
#else
float* pre_din = ptr_write + ow_round;
#endif
/// const array size
float pre_out[out_c_block * out_w_kernel * out_h_kernel]; // NOLINT
prepack_input_nxwc4_dw(
din_batch, pre_din, c, hs, he, ws, we, ic, win, ih, ptr_zero);
const float* weight_c = weights + c * 9; // kernel_w * kernel_h
float* dout_c00 = dout_batch + c * size_out_channel;
float bias_local[4] = {0, 0, 0, 0};
if (flag_bias) {
bias_local[0] = bias[c];
bias_local[1] = bias[c + 1];
bias_local[2] = bias[c + 2];
bias_local[3] = bias[c + 3];
}
float32x4_t vbias = vld1q_f32(bias_local);
#ifdef __aarch64__
float32x4_t w0 = vld1q_f32(weight_c); // w0, v23
float32x4_t w1 = vld1q_f32(weight_c + 4); // w1, v24
float32x4_t w2 = vld1q_f32(weight_c + 8); // w2, v25
float32x4_t w3 = vld1q_f32(weight_c + 12); // w3, v26
float32x4_t w4 = vld1q_f32(weight_c + 16); // w4, v27
float32x4_t w5 = vld1q_f32(weight_c + 20); // w5, v28
float32x4_t w6 = vld1q_f32(weight_c + 24); // w6, v29
float32x4_t w7 = vld1q_f32(weight_c + 28); // w7, v30
float32x4_t w8 = vld1q_f32(weight_c + 32); // w8, v31
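// weights are expected pre-packed as [c/4][kernel(9)][c4(4)], so w0..w8 each
// hold one 3x3 tap for four output channels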
#endif
for (int h = 0; h < oh; h += out_h_kernel) {
float* outc00 = dout_c00 + h * ow;
float* outc01 = outc00 + ow;
float* outc10 = outc00 + size_out_channel;
float* outc11 = outc10 + ow;
float* outc20 = outc10 + size_out_channel;
float* outc21 = outc20 + ow;
float* outc30 = outc20 + size_out_channel;
float* outc31 = outc30 + ow;
const float* inr0 = pre_din + h * row_len;
const float* inr1 = inr0 + row_len;
const float* inr2 = inr1 + row_len;
const float* inr3 = inr2 + row_len;
if (c + out_c_block > oc) {
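// the last channel block may run past oc; redirect the out-of-range
// channel outputs to the ptr_write scratch buffer (cases intentionally
// fall through so all trailing channels are covered)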
switch (c + out_c_block - oc) {
case 3:
outc10 = ptr_write;
outc11 = ptr_write;
case 2:
outc20 = ptr_write;
outc21 = ptr_write;
case 1:
outc30 = ptr_write;
outc31 = ptr_write;
default:
break;
}
}
if (h + out_h_kernel > oh) {
outc01 = ptr_write;
outc11 = ptr_write;
outc21 = ptr_write;
outc31 = ptr_write;
}
auto c00 = outc00;
auto c01 = outc01;
auto c10 = outc10;
auto c11 = outc11;
auto c20 = outc20;
auto c21 = outc21;
auto c30 = outc30;
auto c31 = outc31;
for (int w = 0; w < w_loop; ++w) {
bool flag_mask = (w == w_loop - 1) && flag_remain;
float* out0 = pre_out;
// clang-format off
#ifdef __aarch64__
asm volatile(
"ldp q0, q1, [%[inr0]], #32\n" /* load input r0*/
"ldp q6, q7, [%[inr1]], #32\n" /* load input r1*/
"ldp q2, q3, [%[inr0]], #32\n" /* load input r0*/
"ldp q8, q9, [%[inr1]], #32\n" /* load input r1*/
"ldp q4, q5, [%[inr0]]\n" /* load input r0*/
"ldp q10, q11, [%[inr1]]\n" /* load input r1*/
/* r0, r1, mul w0, get out r0, r1 */
"fmul v15.4s , %[w0].4s, v0.4s\n" /* outr00 = w0 * r0, 0*/
"fmul v16.4s , %[w0].4s, v1.4s\n" /* outr01 = w0 * r0, 1*/
"fmul v17.4s , %[w0].4s, v2.4s\n" /* outr02 = w0 * r0, 2*/
"fmul v18.4s , %[w0].4s, v3.4s\n" /* outr03 = w0 * r0, 3*/
"fmul v19.4s , %[w0].4s, v6.4s\n" /* outr10 = w0 * r1, 0*/
"fmul v20.4s , %[w0].4s, v7.4s\n" /* outr11 = w0 * r1, 1*/
"fmul v21.4s , %[w0].4s, v8.4s\n" /* outr12 = w0 * r1, 2*/
"fmul v22.4s , %[w0].4s, v9.4s\n" /* outr13 = w0 * r1, 3*/
/* r0, r1, mul w1, get out r0, r1 */
"fmla v15.4s , %[w1].4s, v1.4s\n" /* outr00 = w1 * r0[1]*/
"ldp q0, q1, [%[inr2]], #32\n" /* load input r2*/
"fmla v16.4s , %[w1].4s, v2.4s\n" /* outr01 = w1 * r0[2]*/
"fmla v17.4s , %[w1].4s, v3.4s\n" /* outr02 = w1 * r0[3]*/
"fmla v18.4s , %[w1].4s, v4.4s\n" /* outr03 = w1 * r0[4]*/
"fmla v19.4s , %[w1].4s, v7.4s\n" /* outr10 = w1 * r1[1]*/
"fmla v20.4s , %[w1].4s, v8.4s\n" /* outr11 = w1 * r1[2]*/
"fmla v21.4s , %[w1].4s, v9.4s\n" /* outr12 = w1 * r1[3]*/
"fmla v22.4s , %[w1].4s, v10.4s\n"/* outr13 = w1 * r1[4]*/
/* r0, r1, mul w2, get out r0, r1 */
"fmla v15.4s , %[w2].4s, v2.4s\n" /* outr00 = w2 * r0[2]*/
"fmla v16.4s , %[w2].4s, v3.4s\n" /* outr01 = w2 * r0[3]*/
"ldp q2, q3, [%[inr2]], #32\n" /* load input r2*/
"fmla v17.4s , %[w2].4s, v4.4s\n" /* outr02 = w2 * r0[4]*/
"fmla v18.4s , %[w2].4s, v5.4s\n" /* outr03 = w2 * r0[5]*/
"ldp q4, q5, [%[inr2]]\n" /* load input r2*/
"fmla v19.4s , %[w2].4s, v8.4s\n" /* outr10 = w2 * r1[2]*/
"fmla v20.4s , %[w2].4s, v9.4s\n" /* outr11 = w2 * r1[3]*/
"fmla v21.4s , %[w2].4s, v10.4s\n"/* outr12 = w2 * r1[4]*/
"fmla v22.4s , %[w2].4s, v11.4s\n"/* outr13 = w2 * r1[5]*/
/* r1, r2, mul w3, get out r0, r1 */
"fmla v15.4s , %[w3].4s, v6.4s\n" /* outr00 = w3 * r1[0]*/
"fmla v16.4s , %[w3].4s, v7.4s\n" /* outr01 = w3 * r1[1]*/
"fmla v17.4s , %[w3].4s, v8.4s\n" /* outr02 = w3 * r1[2]*/
"fmla v18.4s , %[w3].4s, v9.4s\n" /* outr03 = w3 * r1[3]*/
"fmla v19.4s , %[w3].4s, v0.4s\n" /* outr10 = w3 * r2[0]*/
"fmla v20.4s , %[w3].4s, v1.4s\n" /* outr11 = w3 * r2[1]*/
"fmla v21.4s , %[w3].4s, v2.4s\n" /* outr12 = w3 * r2[2]*/
"fmla v22.4s , %[w3].4s, v3.4s\n" /* outr13 = w3 * r2[3]*/
/* r1, r2, mul w4, get out r0, r1 */
"fmla v15.4s , %[w4].4s, v7.4s\n" /* outr00 = w4 * r1[1]*/
"ldp q6, q7, [%[inr3]], #32\n" /* load input r3*/
"fmla v16.4s , %[w4].4s, v8.4s\n" /* outr01 = w4 * r1[2]*/
"fmla v17.4s , %[w4].4s, v9.4s\n" /* outr02 = w4 * r1[3]*/
"fmla v18.4s , %[w4].4s, v10.4s\n"/* outr03 = w4 * r1[4]*/
"fmla v19.4s , %[w4].4s, v1.4s\n" /* outr10 = w4 * r2[1]*/
"fmla v20.4s , %[w4].4s, v2.4s\n" /* outr11 = w4 * r2[2]*/
"fmla v21.4s , %[w4].4s, v3.4s\n" /* outr12 = w4 * r2[3]*/
"fmla v22.4s , %[w4].4s, v4.4s\n" /* outr13 = w4 * r2[4]*/
/* r1, r2, mul w5, get out r0, r1 */
"fmla v15.4s , %[w5].4s, v8.4s\n" /* outr00 = w5 * r1[2]*/
"fmla v16.4s , %[w5].4s, v9.4s\n" /* outr01 = w5 * r1[3]*/
"ldp q8, q9, [%[inr3]], #32\n" /* load input r3*/
"fmla v17.4s , %[w5].4s, v10.4s\n"/* outr02 = w5 * r1[4]*/
"fmla v18.4s , %[w5].4s, v11.4s\n"/* outr03 = w5 * r1[5]*/
"ldp q10, q11, [%[inr3]]\n" /* load input r3*/
"fmla v19.4s , %[w5].4s, v2.4s\n" /* outr10 = w5 * r2[2]*/
"fmla v20.4s , %[w5].4s, v3.4s\n" /* outr11 = w5 * r2[3]*/
"fmla v21.4s , %[w5].4s, v4.4s\n" /* outr12 = w5 * r2[4]*/
"fmla v22.4s , %[w5].4s, v5.4s\n" /* outr13 = w5 * r2[5]*/
/* r2, r3, mul w6, get out r0, r1 */
"fmla v15.4s , %[w6].4s, v0.4s\n" /* outr00 = w6 * r2[0]*/
"fmla v16.4s , %[w6].4s, v1.4s\n" /* outr01 = w6 * r2[1]*/
"fmla v17.4s , %[w6].4s, v2.4s\n" /* outr02 = w6 * r2[2]*/
"fmla v18.4s , %[w6].4s, v3.4s\n" /* outr03 = w6 * r2[3]*/
"fmla v19.4s , %[w6].4s, v6.4s\n" /* outr10 = w6 * r3[0]*/
"fmla v20.4s , %[w6].4s, v7.4s\n" /* outr11 = w6 * r3[1]*/
"fmla v21.4s , %[w6].4s, v8.4s\n" /* outr12 = w6 * r3[2]*/
"fmla v22.4s , %[w6].4s, v9.4s\n" /* outr13 = w6 * r3[3]*/
/* r2, r3, mul w7, get out r0, r1 */
"fmla v15.4s , %[w7].4s, v1.4s\n" /* outr00 = w7 * r2[1]*/
"fmla v16.4s , %[w7].4s, v2.4s\n" /* outr01 = w7 * r2[2]*/
"fmla v17.4s , %[w7].4s, v3.4s\n" /* outr02 = w7 * r2[3]*/
"fmla v18.4s , %[w7].4s, v4.4s\n" /* outr03 = w7 * r2[4]*/
"fmla v19.4s , %[w7].4s, v7.4s\n" /* outr10 = w7 * r3[1]*/
"fmla v20.4s , %[w7].4s, v8.4s\n" /* outr11 = w7 * r3[2]*/
"fmla v21.4s , %[w7].4s, v9.4s\n" /* outr12 = w7 * r3[3]*/
"fmla v22.4s , %[w7].4s, v10.4s\n"/* outr13 = w7 * r3[4]*/
/* r2, r3, mul w8, get out r0, r1 */
"fmla v15.4s , %[w8].4s, v2.4s\n" /* outr00 = w8 * r2[2]*/
"fmla v16.4s , %[w8].4s, v3.4s\n" /* outr01 = w8 * r2[3]*/
"fmla v17.4s , %[w8].4s, v4.4s\n" /* outr02 = w8 * r2[0]*/
"fmla v18.4s , %[w8].4s, v5.4s\n" /* outr03 = w8 * r2[1]*/
"fmla v19.4s , %[w8].4s, v8.4s\n" /* outr10 = w8 * r3[2]*/
"fmla v20.4s , %[w8].4s, v9.4s\n" /* outr11 = w8 * r3[3]*/
"fmla v21.4s , %[w8].4s, v10.4s\n"/* outr12 = w8 * r3[0]*/
"fmla v22.4s , %[w8].4s, v11.4s\n"/* outr13 = w8 * r3[1]*/
/* save result */
"stp q15, q16, [%[out]], #32\n"
"stp q17, q18, [%[out]], #32\n"
"stp q19, q20, [%[out]], #32\n"
"stp q21, q22, [%[out]]\n"
:[inr0] "+r"(inr0), [inr1] "+r"(inr1),
[inr2] "+r"(inr2), [inr3] "+r"(inr3),
[out]"+r"(out0)
:[w0] "w"(w0), [w1] "w"(w1), [w2] "w"(w2),
[w3] "w"(w3), [w4] "w"(w4), [w5] "w"(w5),
[w6] "w"(w6), [w7] "w"(w7), [w8] "w"(w8)
: "cc", "memory",
"v0","v1","v2","v3","v4","v5","v6","v7",
"v8", "v9", "v10", "v11", "v15",
"v16","v17","v18","v19","v20","v21","v22"
);
#else
asm volatile(
/* load weights */
"vld1.32 {d10-d13}, [%[wc0]]! @ load w0, w1, to q5, q6\n"
"vld1.32 {d14-d15}, [%[wc0]]! @ load w2, to q7\n"
/* load r0, r1 */
"vld1.32 {d0-d3}, [%[r0]]! @ load r0, q0, q1\n"
"vld1.32 {d4-d7}, [%[r0]]! @ load r0, q2, q3\n"
/* main loop */
"0: @ main loop\n"
/* mul r0 with w0, w1, w2, get out r0 */
"vmul.f32 q8, q5, q0 @ w0 * inr00\n"
"vmul.f32 q9, q5, q1 @ w0 * inr01\n"
"vmul.f32 q10, q5, q2 @ w0 * inr02\n"
"vmul.f32 q11, q5, q3 @ w0 * inr03\n"
"vmla.f32 q8, q6, q1 @ w1 * inr01\n"
"vld1.32 {d0-d3}, [%[r0]] @ load r0, q0, q1\n"
"vmla.f32 q9, q6, q2 @ w1 * inr02\n"
"vmla.f32 q10, q6, q3 @ w1 * inr03\n"
"vmla.f32 q11, q6, q0 @ w1 * inr04\n"
"vmla.f32 q8, q7, q2 @ w2 * inr02\n"
"vmla.f32 q9, q7, q3 @ w2 * inr03\n"
"vld1.32 {d4-d7}, [%[r1]]! @ load r0, q2, q3\n"
"vmla.f32 q10, q7, q0 @ w2 * inr04\n"
"vmla.f32 q11, q7, q1 @ w2 * inr05\n"
"vld1.32 {d0-d3}, [%[r1]]! @ load r0, q0, q1\n"
"vld1.32 {d8-d9}, [%[wc0]]! @ load w3 to q4\n"
/* mul r1 with w0-w5, get out r0, r1 */
"vmul.f32 q12, q5, q2 @ w0 * inr10\n"
"vmul.f32 q13, q5, q3 @ w0 * inr11\n"
"vmul.f32 q14, q5, q0 @ w0 * inr12\n"
"vmul.f32 q15, q5, q1 @ w0 * inr13\n"
"vld1.32 {d10-d11}, [%[wc0]]! @ load w4 to q5\n"
"vmla.f32 q8, q4, q2 @ w3 * inr10\n"
"vmla.f32 q9, q4, q3 @ w3 * inr11\n"
"vmla.f32 q10, q4, q0 @ w3 * inr12\n"
"vmla.f32 q11, q4, q1 @ w3 * inr13\n"
/* mul r1 with w1, w4, get out r1, r0 */
"vmla.f32 q8, q5, q3 @ w4 * inr11\n"
"vmla.f32 q12, q6, q3 @ w1 * inr11\n"
"vld1.32 {d4-d7}, [%[r1]] @ load r1, q2, q3\n"
"vmla.f32 q9, q5, q0 @ w4 * inr12\n"
"vmla.f32 q13, q6, q0 @ w1 * inr12\n"
"vmla.f32 q10, q5, q1 @ w4 * inr13\n"
"vmla.f32 q14, q6, q1 @ w1 * inr13\n"
"vmla.f32 q11, q5, q2 @ w4 * inr14\n"
"vmla.f32 q15, q6, q2 @ w1 * inr14\n"
"vld1.32 {d12-d13}, [%[wc0]]! @ load w5 to q6\n"
/* mul r1 with w2, w5, get out r1, r0 */
"vmla.f32 q12, q7, q0 @ w2 * inr12\n"
"vmla.f32 q13, q7, q1 @ w2 * inr13\n"
"vmla.f32 q8, q6, q0 @ w5 * inr12\n"
"vmla.f32 q9, q6, q1 @ w5 * inr13\n"
"vld1.32 {d0-d3}, [%[r2]]! @ load r2, q0, q1\n"
"vmla.f32 q14, q7, q2 @ w2 * inr14\n"
"vmla.f32 q15, q7, q3 @ w2 * inr15\n"
"vmla.f32 q10, q6, q2 @ w5 * inr14\n"
"vmla.f32 q11, q6, q3 @ w5 * inr15\n"
"vld1.32 {d4-d7}, [%[r2]]! @ load r2, q0, q1\n"
"vld1.32 {d14-d15}, [%[wc0]]! @ load w6, to q7\n"
/* mul r2 with w3-w8, get out r0, r1 */
"vmla.f32 q12, q4, q0 @ w3 * inr20\n"
"vmla.f32 q13, q4, q1 @ w3 * inr21\n"
"vmla.f32 q14, q4, q2 @ w3 * inr22\n"
"vmla.f32 q15, q4, q3 @ w3 * inr23\n"
"vld1.32 {d8-d9}, [%[wc0]]! @ load w7, to q4\n"
"vmla.f32 q8, q7, q0 @ w6 * inr20\n"
"vmla.f32 q9, q7, q1 @ w6 * inr21\n"
"vmla.f32 q10, q7, q2 @ w6 * inr22\n"
"vmla.f32 q11, q7, q3 @ w6 * inr23\n"
/* mul r2 with w4, w7, get out r1, r0 */
"vmla.f32 q8, q4, q1 @ w7 * inr21\n"
"vmla.f32 q12, q5, q1 @ w4 * inr21\n"
"vld1.32 {d0-d3}, [%[r2]] @ load r2, q0, q1\n"
"vmla.f32 q9, q4, q2 @ w7 * inr22\n"
"vmla.f32 q13, q5, q2 @ w4 * inr22\n"
"vmla.f32 q10, q4, q3 @ w7 * inr23\n"
"vmla.f32 q14, q5, q3 @ w4 * inr23\n"
"vmla.f32 q11, q4, q0 @ w7 * inr24\n"
"vmla.f32 q15, q5, q0 @ w4 * inr24\n"
"vld1.32 {d10-d11}, [%[wc0]]! @ load w8 to q5\n"
/* mul r1 with w5, w8, get out r1, r0 */
"vmla.f32 q12, q6, q2 @ w5 * inr22\n"
"vmla.f32 q13, q6, q3 @ w5 * inr23\n"
"vmla.f32 q8, q5, q2 @ w8 * inr22\n"
"vmla.f32 q9, q5, q3 @ w8 * inr23\n"
"vld1.32 {d4-d7}, [%[r3]]! @ load r3, q2, q3\n"
"vmla.f32 q14, q6, q0 @ w5 * inr24\n"
"vmla.f32 q15, q6, q1 @ w5 * inr25\n"
"vmla.f32 q10, q5, q0 @ w8 * inr24\n"
"vmla.f32 q11, q5, q1 @ w8 * inr25\n"
"vld1.32 {d0-d3}, [%[r3]]! @ load r3, q0, q1\n"
"sub %[wc0], %[wc0], #144 @ wc0 - 144 to start address\n"
/* mul r3 with w6, w7, w8, get out r1 */
"vmla.f32 q12, q7, q2 @ w6 * inr30\n"
"vmla.f32 q13, q7, q3 @ w6 * inr31\n"
"vst1.32 {d16-d19}, [%[out0]]! @ save r00, r01, c0~c3\n"
"vmla.f32 q14, q7, q0 @ w6 * inr32\n"
"vmla.f32 q15, q7, q1 @ w6 * inr33\n"
"vst1.32 {d20-d23}, [%[out0]]! @ save r02, r03, c0~c3\n"
"vmla.f32 q12, q4, q3 @ w7 * inr31\n"
"vld1.32 {d4-d7}, [%[r3]] @ load r3, q2, q3\n"
"vmla.f32 q13, q4, q0 @ w7 * inr32\n"
"vmla.f32 q14, q4, q1 @ w7 * inr33\n"
"vmla.f32 q15, q4, q2 @ w7 * inr34\n"
"vmla.f32 q12, q5, q0 @ w8 * inr32\n"
"vmla.f32 q13, q5, q1 @ w8 * inr33\n"
"vmla.f32 q14, q5, q2 @ w8 * inr34\n"
"vmla.f32 q15, q5, q3 @ w8 * inr35\n"
"vst1.32 {d24-d27}, [%[out0]]! @ save r10, r11, c0~c3\n"
"vst1.32 {d28-d31}, [%[out0]]! @ save r12, r13, c0~c3\n"
: [r0] "+r"(inr0), [r1] "+r"(inr1),
[r2] "+r"(inr2), [r3] "+r"(inr3),
[out0] "+r"(out0), [wc0] "+r"(weight_c)
:
: "cc", "memory",
"q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9",
"q10", "q11", "q12", "q13","q14", "q15"
);
#endif  // __aarch64__
float* out1 = pre_out;
if (flag_mask) {
c00 = outc00;
c01 = outc01;
c10 = outc10;
c11 = outc11;
c20 = outc20;
c21 = outc21;
c30 = outc30;
c31 = outc31;
outc00 = pre_out;
outc01 = pre_out + 4;
outc10 = pre_out + 8;
outc11 = pre_out + 12;
outc20 = pre_out + 16;
outc21 = pre_out + 20;
outc30 = pre_out + 24;
outc31 = pre_out + 28;
}
#ifdef __aarch64__
asm volatile(
"ldp q0, q1, [%[din]], #32\n" /* load input*/
"ldp q2, q3, [%[din]], #32\n" /* load input*/
"fadd v15.4s, v0.4s, %[vbias].4s\n" /* add bias */
"fadd v16.4s, v1.4s, %[vbias].4s\n" /* add bias */
"ldp q4, q5, [%[din]], #32\n" /* load input*/
"fadd v17.4s, v2.4s, %[vbias].4s\n" /* add bias */
"fadd v18.4s, v3.4s, %[vbias].4s\n" /* add bias */
"ldp q6, q7, [%[din]]\n" /* load input*/
"fadd v19.4s, v4.4s, %[vbias].4s\n" /* add bias */
"fadd v20.4s, v5.4s, %[vbias].4s\n" /* add bias */
"fadd v21.4s, v6.4s, %[vbias].4s\n" /* add bias */
"fadd v22.4s, v7.4s, %[vbias].4s\n" /* add bias */
/* transpose */
"trn1 v0.4s, v15.4s, v16.4s\n" /* r0: a0a1c0c1*/
"trn2 v1.4s, v15.4s, v16.4s\n" /* r0: b0b1d0d1*/
"trn1 v2.4s, v17.4s, v18.4s\n" /* r0: a2a3c2c3*/
"trn2 v3.4s, v17.4s, v18.4s\n" /* r0: b2b3d2d3*/
"trn1 v4.4s, v19.4s, v20.4s\n" /* r1: a0a1c0c1*/
"trn2 v5.4s, v19.4s, v20.4s\n" /* r1: b0b1d0d1*/
"trn1 v6.4s, v21.4s, v22.4s\n" /* r1: a2a3c2c3*/
"trn2 v7.4s, v21.4s, v22.4s\n" /* r1: b2b3d2d3*/
"trn1 v15.2d, v0.2d, v2.2d\n" /* r0: a0a1a2a3*/
"trn2 v19.2d, v0.2d, v2.2d\n" /* r0: c0c1c2c3*/
"trn1 v17.2d, v1.2d, v3.2d\n" /* r0: b0b1b2b3*/
"trn2 v21.2d, v1.2d, v3.2d\n" /* r0: d0d1d2d3*/
"trn1 v16.2d, v4.2d, v6.2d\n" /* r1: a0a1a2a3*/
"trn2 v20.2d, v4.2d, v6.2d\n" /* r1: c0c1c2c3*/
"trn1 v18.2d, v5.2d, v7.2d\n" /* r1: b0b1b2b3*/
"trn2 v22.2d, v5.2d, v7.2d\n" /* r1: d0d1d2d3*/
"cbz %w[flag_relu], 0f\n" /* skip relu*/
"movi v0.4s, #0\n" /* for relu */
"fmax v15.4s, v15.4s, v0.4s\n"
"fmax v16.4s, v16.4s, v0.4s\n"
"fmax v17.4s, v17.4s, v0.4s\n"
"fmax v18.4s, v18.4s, v0.4s\n"
"fmax v19.4s, v19.4s, v0.4s\n"
"fmax v20.4s, v20.4s, v0.4s\n"
"fmax v21.4s, v21.4s, v0.4s\n"
"fmax v22.4s, v22.4s, v0.4s\n"
"0:\n"
"str q15, [%[outc00]], #16\n" /* save outc00*/
"str q16, [%[outc01]], #16\n" /* save outc01*/
"str q17, [%[outc10]], #16\n" /* save outc10*/
"str q18, [%[outc11]], #16\n" /* save outc11*/
"str q19, [%[outc20]], #16\n" /* save outc20*/
"str q20, [%[outc21]], #16\n" /* save outc21*/
"str q21, [%[outc30]], #16\n" /* save outc30*/
"str q22, [%[outc31]], #16\n" /* save outc31*/
:[outc00] "+r"(outc00), [outc01] "+r"(outc01),
[outc10] "+r"(outc10), [outc11] "+r"(outc11),
[outc20] "+r"(outc20), [outc21] "+r"(outc21),
[outc30] "+r"(outc30), [outc31] "+r"(outc31),
[din] "+r"(out1)
:[vbias]"w" (vbias), [flag_relu] "r"(flag_relu)
: "cc", "memory",
"v0","v1","v2","v3","v4","v5","v6","v7",
"v15", "v16","v17","v18","v19","v20","v21","v22"
);
#else
asm volatile(
"vld1.32 {d0-d3}, [%[din]]!\n" /* load input*/
"vld1.32 {d4-d7}, [%[din]]!\n" /* load input*/
"vadd.f32 q0, q0, %q[vbias]\n" /* add bias */
"vadd.f32 q1, q1, %q[vbias]\n" /* add bias */
"vld1.32 {d8-d11}, [%[din]]!\n" /* load input*/
"vadd.f32 q2, q2, %q[vbias]\n" /* add bias */
"vadd.f32 q3, q3, %q[vbias]\n" /* add bias */
"vld1.32 {d12-d15}, [%[din]]!\n" /* load input*/
"vadd.f32 q4, q4, %q[vbias]\n" /* add bias */
"vadd.f32 q5, q5, %q[vbias]\n" /* add bias */
"vadd.f32 q6, q6, %q[vbias]\n" /* add bias */
"vadd.f32 q7, q7, %q[vbias]\n" /* add bias */
/* transpose */
"vtrn.32 q0, q1\n" /* r0: q0: a0a1c0c1, q1: b0b1d0d1*/
"vtrn.32 q2, q3\n" /* r0: q2: a2a3c2c3, q3: b2b3d2d3*/
"vtrn.32 q4, q5\n" /* r1: q4: a0a1c0c1, q5: b0b1d0d1*/
"vtrn.32 q6, q7\n" /* r1: q6: a2a3c2c3, q7: b2b3d2d3*/
"vswp d1, d4\n" /* r0: q0: a0a1a2a3, q2: c0c1c2c3*/
"vswp d3, d6\n" /* r0: q1: b0b1b2b3, q3: d0d1d2d3*/
"vswp d9, d12\n" /* r1: q4: a0a1a2a3, q6: c0c1c2c3*/
"vswp d11, d14\n" /* r1: q5: b0b1b2b3, q7: d0d1d2d3*/
"cmp %[flag_relu], #0\n"
"beq 0f\n" /* skip relu*/
"vmov.u32 q15, #0\n"
"vmax.f32 q0, q0, q15\n"
"vmax.f32 q1, q1, q15\n"
"vmax.f32 q2, q2, q15\n"
"vmax.f32 q3, q3, q15\n"
"vmax.f32 q4, q4, q15\n"
"vmax.f32 q5, q5, q15\n"
"vmax.f32 q6, q6, q15\n"
"vmax.f32 q7, q7, q15\n"
"0:\n"
"vst1.32 {d0-d1}, [%[outc00]]!\n" /* save outc00*/
"vst1.32 {d2-d3}, [%[outc10]]!\n" /* save outc10*/
"vst1.32 {d4-d5}, [%[outc20]]!\n" /* save outc20*/
"vst1.32 {d6-d7}, [%[outc30]]!\n" /* save outc30*/
"vst1.32 {d8-d9}, [%[outc01]]!\n" /* save outc01*/
"vst1.32 {d10-d11}, [%[outc11]]!\n" /* save outc11*/
"vst1.32 {d12-d13}, [%[outc21]]!\n" /* save outc21*/
"vst1.32 {d14-d15}, [%[outc31]]!\n" /* save outc31*/
:[outc00] "+r"(outc00), [outc01] "+r"(outc01),
[outc10] "+r"(outc10), [outc11] "+r"(outc11),
[outc20] "+r"(outc20), [outc21] "+r"(outc21),
[outc30] "+r"(outc30), [outc31] "+r"(outc31),
[din] "+r"(out1)
:[vbias]"w" (vbias), [flag_relu] "r"(flag_relu)
: "cc", "memory",
"q0","q1","q2","q3","q4","q5","q6","q7", "q15"
);
#endif // __aarch64__
// clang-format on
if (flag_mask) {
for (int i = 0; i < remain; ++i) {
c00[i] = pre_out[i];
c01[i] = pre_out[i + 4];
c10[i] = pre_out[i + 8];
c11[i] = pre_out[i + 12];
c20[i] = pre_out[i + 16];
c21[i] = pre_out[i + 20];
c30[i] = pre_out[i + 24];
c31[i] = pre_out[i + 28];
}
}
}
}
}
}
}
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <arm_neon.h>
#include "lite/backends/arm/math/conv_block_utils.h"
#include "lite/backends/arm/math/conv_depthwise.h"
#include "lite/backends/arm/math/conv_impl.h"
#include "lite/core/context.h"
#include "lite/operators/op_params.h"
#ifdef ARM_WITH_OMP
#include <omp.h>
#endif
namespace paddle {
namespace lite {
namespace arm {
namespace math {
#define ROUNDUP(a, b) ((((a) + (b)-1) / (b)) * (b))
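// e.g. ROUNDUP(10, 4) == 12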
template <typename Dtype>
void conv_depthwise_3x3s1_int8(Dtype* dout,
const int8_t* din,
const int8_t* weights,
const float* scale,
const float* bias,
bool flag_bias,
bool flag_relu,
int num,
int chin,
int hin,
int win,
int hout,
int wout,
int padw,
int padh,
ARMContext* ctx) {
const int threads = ctx->threads();
int llc_size = ctx->llc_size() / 4;
const int hout_c_block = 8;
const int hout_r_kernel = 1;
const int wout_block = 4;
const int wout_round = ((wout + wout_block - 1) / wout_block) * wout_block;
const int win_round = wout_round + 2;
//! get h block
//! llc_size = threads * win_round * hin_r_block * sizeof(int8_t) + wout_round
//! * hout_c_block * hout_r_block * threads * sizeof(int32_t)
//! win_round = wout_round + 2
//! hin_r_block = hout_r_block + 2
int hout_r_block =
(llc_size - 2 * win_round * threads) /
(win_round * threads + hout_c_block * wout_round * threads * 4);
hout_r_block = hout_r_block > hout ? hout : hout_r_block;
hout_r_block =
((hout_r_block + hout_r_kernel - 1) / hout_r_kernel) * hout_r_kernel;
hout_r_block = hout_r_block < hout_r_kernel ? hout_r_kernel : hout_r_block;
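// rough sizing example (assumed values, not from this commit): with
// llc_size = 131072, threads = 1, wout = 32 -> wout_round = 32 and
// win_round = 34, so hout_r_block ~= (131072 - 68) / (34 + 8 * 32 * 4)
// ~= 123 rows, then clamped to hout and rounded to hout_r_kernel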
const int hin_r_block = hout_r_block + 2;
auto tmp_work_space = ctx->workspace_data<int8_t>();
int8_t ptr_zero[win_round]; // NOLINT
memset(ptr_zero, 0, sizeof(int8_t) * win_round);
Dtype ptr_write[wout_round]; // NOLINT
int in_len = win_round * hout_c_block;
int pre_in_size = hin_r_block * in_len;
pre_in_size = ROUNDUP(pre_in_size, 4);
int pre_out_size = hout_c_block * hout_r_block * wout_round;
int8_t* tmp_din = tmp_work_space;
int size_in_channel = win * hin;
int size_out_channel = wout * hout;
int w_stride = 9; // kernel_w * kernel_h;
int ws = -padw;
int we = ws + win_round;
int w_loop = wout_round / 4;
int chout = chin;
int out_row_stride = hout_c_block * wout_round;
for (int n = 0; n < num; ++n) {
const int8_t* din_batch = din + n * chin * size_in_channel;
int8_t* dout_batch = reinterpret_cast<int8_t*>(dout) +
n * chout * size_out_channel * sizeof(Dtype);
for (int h = 0; h < hout; h += hout_r_block) {
int h_kernel = hout_r_block;
if (h + hout_r_block > hout) {
h_kernel = hout - h;
}
int hs = h - padh;
int he = hs + h_kernel + 2;
#pragma omp parallel for num_threads(threads)
for (int c = 0; c < chout; c += hout_c_block) {
#ifdef ARM_WITH_OMP
int8_t* pre_din =
tmp_din + omp_get_thread_num() * (pre_in_size + pre_out_size * 4);
int32_t* pre_out = reinterpret_cast<int*>(pre_din + pre_in_size);
#else
int32_t* pre_out = reinterpret_cast<int32_t*>(tmp_din + pre_in_size);
auto pre_din = tmp_din;
#endif
prepack_input_nxw_c8_int8(din_batch,
pre_din,
c,
c + hout_c_block,
hs,
he,
ws,
we,
chin,
win,
hin);
const int8_t* block_inr0 = pre_din;
const int8_t* block_inr1 = block_inr0 + in_len;
const int8_t* block_inr2 = block_inr1 + in_len;
const int8_t* weight_c = weights + c * w_stride;
float bias_local[8] = {0, 0, 0, 0, 0, 0, 0, 0};
if (flag_bias) {
bias_local[0] = bias[c];
bias_local[1] = bias[c + 1];
bias_local[2] = bias[c + 2];
bias_local[3] = bias[c + 3];
bias_local[4] = bias[c + 4];
bias_local[5] = bias[c + 5];
bias_local[6] = bias[c + 6];
bias_local[7] = bias[c + 7];
}
#ifdef __aarch64__
int8x8_t vw0 = vld1_s8(weight_c);
int8x8_t vw1 = vld1_s8(weight_c + 8);
int8x8_t vw2 = vld1_s8(weight_c + 16);
int8x8_t vw3 = vld1_s8(weight_c + 24);
int8x8_t vw4 = vld1_s8(weight_c + 32);
int8x8_t vw5 = vld1_s8(weight_c + 40);
int8x8_t vw6 = vld1_s8(weight_c + 48);
int8x8_t vw7 = vld1_s8(weight_c + 56);
int8x8_t vw8 = vld1_s8(weight_c + 64);
#endif
for (int hk = 0; hk < h_kernel; hk += hout_r_kernel) {
int cnt = w_loop;
const int8_t* inr0 = block_inr0;
const int8_t* inr1 = block_inr1;
const int8_t* inr2 = block_inr2;
int32_t* ptr_out0 = pre_out + hk * out_row_stride;
#ifdef __aarch64__
asm volatile(
"ld1 {v0.8b, v1.8b, v2.8b, v3.8b}, [%[r0]], #32\n"
"1:\n"
/* inr0 -> outr0 */
"ldp d4, d5, [%[r0]]\n" /* load r0, 4 */
"smull v20.8h, v0.8b, %[w0].8b\n" /* int16, out0 */
"smull v21.8h, v1.8b, %[w0].8b\n" /* int16, out1 */
"smull v22.8h, v2.8b, %[w0].8b\n" /* int16, out2 */
"smull v23.8h, v3.8b, %[w0].8b\n" /* int16, out3 */
"smlal v20.8h, v1.8b, %[w1].8b\n" /* int16, out0 */
"smlal v21.8h, v2.8b, %[w1].8b\n" /* int16, out1 */
"smlal v22.8h, v3.8b, %[w1].8b\n" /* int16, out2 */
"smlal v23.8h, v4.8b, %[w1].8b\n" /* int16, out3 */
"ldp d0, d1, [%[r1]], #16\n" /* load r1, 0,1 */
"sxtl v24.4s, v20.4h\n"
"sxtl2 v25.4s, v20.8h\n"
"sxtl v26.4s, v21.4h\n"
"sxtl2 v27.4s, v21.8h\n"
"sxtl v28.4s, v22.4h\n"
"sxtl2 v29.4s, v22.8h\n"
"sxtl v30.4s, v23.4h\n"
"sxtl2 v31.4s, v23.8h\n"
"smull v20.8h, v2.8b, %[w2].8b\n" /* int16, out0 */
"smull v21.8h, v3.8b, %[w2].8b\n" /* int16, out1 */
"smull v22.8h, v4.8b, %[w2].8b\n" /* int16, out2 */
"smull v23.8h, v5.8b, %[w2].8b\n" /* int16, out3 */
"ldp d2, d3, [%[r1]], #16\n" /* load r1, 2,3 */
"smlal v20.8h, v0.8b, %[w3].8b\n" /* int16, out0 */
"smlal v21.8h, v1.8b, %[w3].8b\n" /* int16, out1 */
"smlal v22.8h, v2.8b, %[w3].8b\n" /* int16, out2 */
"smlal v23.8h, v3.8b, %[w3].8b\n" /* int16, out3 */
"saddw v24.4s, v24.4s, v20.4h\n"
"saddw2 v25.4s, v25.4s, v20.8h\n"
"saddw v26.4s, v26.4s, v21.4h\n"
"saddw2 v27.4s, v27.4s, v21.8h\n"
"ldp d4, d5, [%[r1]]\n" /* load r1, 4,5 */
"saddw v28.4s, v28.4s, v22.4h\n"
"saddw2 v29.4s, v29.4s, v22.8h\n"
"saddw v30.4s, v30.4s, v23.4h\n"
"saddw2 v31.4s, v31.4s, v23.8h\n"
"smull v20.8h, v1.8b, %[w4].8b\n" /* int16, out0 */
"smull v21.8h, v2.8b, %[w4].8b\n" /* int16, out1 */
"smull v22.8h, v3.8b, %[w4].8b\n" /* int16, out1 */
"smull v23.8h, v4.8b, %[w4].8b\n" /* int16, out1 */
"ldp d0, d1, [%[r2]], #16\n" /* load r2, 0,1 */
"smlal v20.8h, v2.8b, %[w5].8b\n" /* int16, out0 */
"smlal v21.8h, v3.8b, %[w5].8b\n" /* int16, out1 */
"smlal v22.8h, v4.8b, %[w5].8b\n" /* int16, out2 */
"smlal v23.8h, v5.8b, %[w5].8b\n" /* int16, out3 */
"ldp d2, d3, [%[r2]], #16\n" /* load r2, 2,3 */
"saddw v24.4s, v24.4s, v20.4h\n"
"saddw2 v25.4s, v25.4s, v20.8h\n"
"saddw v26.4s, v26.4s, v21.4h\n"
"saddw2 v27.4s, v27.4s, v21.8h\n"
"ldp d4, d5, [%[r2]]\n" /* load r2 */
"saddw v28.4s, v28.4s, v22.4h\n"
"saddw2 v29.4s, v29.4s, v22.8h\n"
"saddw v30.4s, v30.4s, v23.4h\n"
"saddw2 v31.4s, v31.4s, v23.8h\n"
"smull v20.8h, v0.8b, %[w6].8b\n" /* int16, out0 */
"smull v21.8h, v1.8b, %[w6].8b\n" /* int16, out1 */
"smull v22.8h, v2.8b, %[w6].8b\n" /* int16, out1 */
"smull v23.8h, v3.8b, %[w6].8b\n" /* int16, out1 */
"smlal v20.8h, v1.8b, %[w7].8b\n" /* int16, out0 */
"smlal v21.8h, v2.8b, %[w7].8b\n" /* int16, out1 */
"smlal v22.8h, v3.8b, %[w7].8b\n" /* int16, out1 */
"smlal v23.8h, v4.8b, %[w7].8b\n" /* int16, out1 */
"ldp d0, d1, [%[r0]], #16\n" /* load r0, 0,1 */
"saddw v24.4s, v24.4s, v20.4h\n"
"saddw2 v25.4s, v25.4s, v20.8h\n"
"saddw v26.4s, v26.4s, v21.4h\n"
"saddw2 v27.4s, v27.4s, v21.8h\n"
"saddw v28.4s, v28.4s, v22.4h\n"
"saddw2 v29.4s, v29.4s, v22.8h\n"
"saddw v30.4s, v30.4s, v23.4h\n"
"saddw2 v31.4s, v31.4s, v23.8h\n"
"smull v20.8h, v2.8b, %[w8].8b\n" /* int16, out0 */
"smull v21.8h, v3.8b, %[w8].8b\n" /* int16, out1 */
"smull v22.8h, v4.8b, %[w8].8b\n" /* int16, out1 */
"smull v23.8h, v5.8b, %[w8].8b\n" /* int16, out1 */
"ldp d2, d3, [%[r0]], #16\n" /* load r0, 2,3 */
"saddw v24.4s, v24.4s, v20.4h\n"
"saddw2 v25.4s, v25.4s, v20.8h\n"
"saddw v26.4s, v26.4s, v21.4h\n"
"saddw2 v27.4s, v27.4s, v21.8h\n"
"stp q24, q25, [%[ptr_out0]], #32\n"
"saddw v28.4s, v28.4s, v22.4h\n"
"saddw2 v29.4s, v29.4s, v22.8h\n"
"stp q26, q27, [%[ptr_out0]], #32\n"
"saddw v30.4s, v30.4s, v23.4h\n"
"saddw2 v31.4s, v31.4s, v23.8h\n"
"subs %w[cnt], %w[cnt], #1\n"
"stp q28, q29, [%[ptr_out0]], #32\n"
"stp q30, q31, [%[ptr_out0]], #32\n"
"bne 1b\n"
: [cnt] "+r"(cnt),
[r0] "+r"(inr0),
[r1] "+r"(inr1),
[r2] "+r"(inr2),
[ptr_out0] "+r"(ptr_out0)
: [w0] "w"(vw0),
[w1] "w"(vw1),
[w2] "w"(vw2),
[w3] "w"(vw3),
[w4] "w"(vw4),
[w5] "w"(vw5),
[w6] "w"(vw6),
[w7] "w"(vw7),
[w8] "w"(vw8)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v20",
"v21",
"v22",
"v23",
"v24",
"v25",
"v26",
"v27",
"v28",
"v29",
"v30",
"v31"
);
#else
auto wptr = weight_c;
asm volatile(
"vld1.32 {d0-d3}, [%[r0]]!\n" /* load r0, 0-4 */
"vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w0-w1 */
"1:\n"
/* inr0 -> outr0 */
"vld1.32 {d4-d5}, [%[r0]]\n" /* load r0, 5-6 */
"vmull.s8 q4, d0, d6\n" /* int16, out0 */
"vmull.s8 q5, d1, d6\n" /* int16, out1 */
"vmull.s8 q6, d2, d6\n" /* int16, out2 */
"vmull.s8 q7, d3, d6\n" /* int16, out3 */
"vld1.32 {d6}, [%[wptr]]!\n" /* load w2 */
"vmlal.s8 q4, d1, d7\n" /* int16, out0 */
"vmlal.s8 q5, d2, d7\n" /* int16, out1 */
"vmlal.s8 q6, d3, d7\n" /* int16, out2 */
"vmlal.s8 q7, d4, d7\n" /* int16, out3 */
"vld1.32 {d7}, [%[wptr]]!\n" /* load w3 */
"vmovl.s16 q8, d8\n"
"vmovl.s16 q9, d9\n"
"vmovl.s16 q10, d10\n"
"vmovl.s16 q11, d11\n"
"vld1.32 {d0-d1}, [%[r1]]!\n" /* load r1, 0-1 */
"vmovl.s16 q12, d12\n"
"vmovl.s16 q13, d13\n"
"vmovl.s16 q14, d14\n"
"vmovl.s16 q15, d15\n"
"vmull.s8 q4, d2, d6\n" /* int16, out0 */
"vmull.s8 q5, d3, d6\n" /* int16, out1 */
"vld1.32 {d2-d3}, [%[r1]]!\n" /* load r1, 2-3 */
"vmull.s8 q6, d4, d6\n" /* int16, out2 */
"vmull.s8 q7, d5, d6\n" /* int16, out3 */
"vld1.32 {d6}, [%[wptr]]!\n" /* load w4 */
/* inr1 -> outr0 */
"vmlal.s8 q4, d0, d7\n" /* int16, out0 */
"vmlal.s8 q5, d1, d7\n" /* int16, out1 */
"vmlal.s8 q6, d2, d7\n" /* int16, out2 */
"vmlal.s8 q7, d3, d7\n" /* int16, out3 */
"vld1.32 {d4-d5}, [%[r1]]\n" /* load r1, 4-5 */
"vaddw.s16 q8, q8, d8\n"
"vaddw.s16 q9, q9, d9\n"
"vaddw.s16 q10, q10, d10\n"
"vaddw.s16 q11, q11, d11\n"
"vld1.32 {d7}, [%[wptr]]!\n" /* load w5 */
"vaddw.s16 q12, q12, d12\n"
"vaddw.s16 q13, q13, d13\n"
"vaddw.s16 q14, q14, d14\n"
"vaddw.s16 q15, q15, d15\n"
"vmull.s8 q4, d1, d6\n" /* int16, out0 */
"vmull.s8 q5, d2, d6\n" /* int16, out1 */
"vmull.s8 q6, d3, d6\n" /* int16, out2 */
"vmull.s8 q7, d4, d6\n" /* int16, out3 */
"vld1.32 {d6}, [%[wptr]]!\n" /* load w6 */
"vld1.32 {d0-d1}, [%[r2]]!\n" /* load r2, 0-1 */
"vmlal.s8 q4, d2, d7\n" /* int16, out0 */
"vmlal.s8 q5, d3, d7\n" /* int16, out1 */
"vmlal.s8 q6, d4, d7\n" /* int16, out2 */
"vmlal.s8 q7, d5, d7\n" /* int16, out3 */
"vld1.32 {d7}, [%[wptr]]!\n" /* load w7 */
"vaddw.s16 q8, q8, d8\n"
"vaddw.s16 q9, q9, d9\n"
"vaddw.s16 q10, q10, d10\n"
"vaddw.s16 q11, q11, d11\n"
"vld1.32 {d2-d3}, [%[r2]]!\n" /* load r2, 2-3 */
"vaddw.s16 q12, q12, d12\n"
"vaddw.s16 q13, q13, d13\n"
"vaddw.s16 q14, q14, d14\n"
"vaddw.s16 q15, q15, d15\n"
"vld1.32 {d4-d5}, [%[r2]]\n" /* load r2, 4-5 */
/* inr2 -> outr0 */
"vmull.s8 q4, d0, d6\n" /* int16, out0 */
"vmull.s8 q5, d1, d6\n" /* int16, out1 */
"vmull.s8 q6, d2, d6\n" /* int16, out2 */
"vmull.s8 q7, d3, d6\n" /* int16, out3 */
"vld1.32 {d6}, [%[wptr]]!\n" /* load w8 */
"vmlal.s8 q4, d1, d7\n" /* int16, out0 */
"vmlal.s8 q5, d2, d7\n" /* int16, out1 */
"vmlal.s8 q6, d3, d7\n" /* int16, out2 */
"vmlal.s8 q7, d4, d7\n" /* int16, out3 */
"vaddw.s16 q8, q8, d8\n"
"vaddw.s16 q9, q9, d9\n"
"vaddw.s16 q10, q10, d10\n"
"vaddw.s16 q11, q11, d11\n"
"vld1.32 {d0-d1}, [%[r0]]!\n" /* load r0, 0-1 */
"vaddw.s16 q12, q12, d12\n"
"vaddw.s16 q13, q13, d13\n"
"vaddw.s16 q14, q14, d14\n"
"vaddw.s16 q15, q15, d15\n"
"sub %[wptr], %[wptr], #72\n"
"vmull.s8 q4, d2, d6\n" /* int16, out0 */
"vmull.s8 q5, d3, d6\n" /* int16, out1 */
"vmull.s8 q6, d4, d6\n" /* int16, out2 */
"vmull.s8 q7, d5, d6\n" /* int16, out3 */
"vld1.32 {d2-d3}, [%[r0]]!\n" /* load r0, 2-3 */
"vaddw.s16 q8, q8, d8\n"
"vaddw.s16 q9, q9, d9\n"
"vaddw.s16 q10, q10, d10\n"
"vaddw.s16 q11, q11, d11\n"
"vst1.32 {d16-d19}, [%[ptr_out0]]!\n"
"vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w0-w1 */
"vaddw.s16 q12, q12, d12\n"
"vaddw.s16 q13, q13, d13\n"
"vst1.32 {d20-d23}, [%[ptr_out0]]!\n"
"vaddw.s16 q14, q14, d14\n"
"vaddw.s16 q15, q15, d15\n"
"subs %[cnt], #1\n"
"vst1.32 {d24-d27}, [%[ptr_out0]]!\n"
"vst1.32 {d28-d31}, [%[ptr_out0]]!\n"
"bne 1b\n"
: [cnt] "+r"(cnt),
[r0] "+r"(inr0),
[r1] "+r"(inr1),
[r2] "+r"(inr2),
[ptr_out0] "+r"(ptr_out0),
[wptr] "+r"(wptr)
:
: "cc",
"memory",
"q0",
"q1",
"q2",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
#endif
block_inr0 = block_inr1;
block_inr1 = block_inr2;
block_inr2 = block_inr1 + in_len;
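// slide the three-row input window down by one output row for the next hk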
}
write_int32_nchwc8_to_nchw<Dtype>(pre_out,
reinterpret_cast<Dtype*>(dout_batch),
c,
c + hout_c_block,
h,
h + h_kernel,
0,
wout_round,
chout,
hout,
wout,
flag_relu,
bias_local,
flag_bias,
ptr_write,
scale + c);
}
}
}
}
template void conv_depthwise_3x3s1_int8<int8_t>(int8_t* dout,
const int8_t* din,
const int8_t* weights,
const float* scale,
const float* bias,
bool flag_bias,
bool flag_relu,
int num,
int chin,
int hin,
int win,
int hout,
int wout,
int padw,
int padh,
ARMContext* ctx);
template void conv_depthwise_3x3s1_int8<float>(float* dout,
const int8_t* din,
const int8_t* weights,
const float* scale,
const float* bias,
bool flag_bias,
bool flag_relu,
int num,
int chin,
int hin,
int win,
int hout,
int wout,
int padw,
int padh,
ARMContext* ctx);
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
...@@ -51,12 +51,7 @@ void conv_3x3s1_direct_fp32(const float* i_data,
const int win_round = wout_round + 2;
bool flag_relu = param.fuse_relu;
bool flag_bias = param.bias != nullptr;
// if (param.activation_param.has_active) {
//   if (param.activation_param.active == Active_relu &&
//       fabs(param.activation_param.negative_slope) < 1e-6f) {
//     flag_relu = true;
//   }
// }
int hout_r_block = (l2_size - 2 * win_round * ic) /
(win_round * ic + hout_c_block * wout_round * threads);
hout_r_block = hout_r_block > oh ? oh : hout_r_block;
...
...@@ -26,9 +26,9 @@ namespace lite {
namespace arm {
namespace math {
#ifdef __aarch64__
template <typename Dtype>
void conv_3x3s1_direct_int8(const int8_t* din,
int32_t* dout,
Dtype* dout,
int num,
int chout,
int hout,
...@@ -37,62 +37,74 @@ void conv_3x3s1_direct_int8(const int8_t* din,
int hin,
int win,
const int8_t* weights,
const int32_t* bias,
const float* bias,
const operators::ConvParam& param,
Context<TARGET(kARM)>* ctx,
PrecisionType out_type,
const float* scale) {
const int hin_r_block = 4;
const int hout_c_block = 4; // 8;
const int hout_r_block = 2;
int stride_w = param.strides[1];
int pad_w = param.paddings[1];
int pad_h = param.paddings[0];
bool flag_relu = param.fuse_relu;
bool flag_bias = (param.bias != nullptr);
int wout_round = ((wout + 3) / 4) * 4;
int win_round = wout_round * stride_w + 4;
int threads = ctx->threads();
int* tmp_work_space = ctx->workspace_data<int>();
int* ptr_zero = tmp_work_space;
memset(ptr_zero, 0, sizeof(int) * win_round);
int* ptr_write = ptr_zero + win_round;
bool flag_bias = param.bias;
int pad_h = param.paddings[0];
int pad_w = param.paddings[1];
const int threads = ctx->threads();
int llc_size = ctx->llc_size() / 4;
const int hout_c_block = 4;
const int hout_r_kernel = 2;
const int wout_block = 4;
const int wout_round = ((wout + wout_block - 1) / wout_block) * wout_block;
const int win_round = wout_round + 2;
//! get h block
//! llc_size = win_round * chin * hin_r_block * sizeof(int8_t) + wout_round *
//! hout_c_block * hout_r_block * threads * sizeof(int32_t)
//! win_round = wout_round + 2
//! hin_r_block = hout_r_block + 2
int hout_r_block =
(llc_size - 2 * win_round * chin) /
(win_round * chin + hout_c_block * wout_round * threads * 4);
hout_r_block = hout_r_block > hout ? hout : hout_r_block;
hout_r_block =
((hout_r_block + hout_r_kernel - 1) / hout_r_kernel) * hout_r_kernel;
hout_r_block = hout_r_block < hout_r_kernel ? hout_r_kernel : hout_r_block;
const int hin_r_block = hout_r_block + 2;
auto tmp_work_space = ctx->workspace_data<int8_t>();
int8_t ptr_zero[win_round]; // NOLINT
memset(ptr_zero, 0, sizeof(int8_t) * win_round);
Dtype ptr_write[wout_round]; // NOLINT
int in_len = win_round * chin;
int pre_in_size = hin_r_block * in_len;
pre_in_size = ROUNDUP(pre_in_size, 4);
int pre_out_size = hout_c_block * hout_r_block * wout_round;
signed char* pre_din = reinterpret_cast<signed char*>(ptr_write + wout_round);
int8_t* pre_din = tmp_work_space;
int size_in_channel = win * hin;
int size_out_channel = wout * hout;
int w_stride = chin * 9;  // kernel_w * kernel_h;
int w_stride_chin = hout_c_block * 9;  // kernel_w * kernel_h *
int ws = -pad_w;
int we = ws + win_round;
int w_loop = wout_round / 4;
int size_out = wout_round * hout_c_block;
int out_row_stride = hout_c_block * wout_round;
// printf("win_round: %d, wout_round: %d, ws: %d, we: %d\n", win_round,
// wout_round, ws, we);
// here
for (int n = 0; n < num; ++n) {
const signed char* din_batch =
static_cast<const signed char*>(din) + n * chin * size_in_channel;
signed char* dout_batch =
reinterpret_cast<signed char*>(dout) +
n * chout * size_out_channel * PrecisionTypeLength(out_type);
for (int h = 0; h < hout; h += 2) {
const int8_t* din_batch = din + n * chin * size_in_channel;
Dtype* dout_batch = dout + n * chout * size_out_channel;
for (int h = 0; h < hout; h += hout_r_block) {
int h_kernel = hout_r_block;
if (h + hout_r_block > hout) {
h_kernel = hout - h;
}
int hs = h - pad_h;
int he = hs + 4;
int he = hs + h_kernel + 2;
// printf("hs: %d, he: %d, chin: %d, hin: %d, win: %d \n", hs, he, chin,
// hin, win);
prepack_input_nxw(din_batch,
pre_din,
0,
...@@ -104,701 +116,370 @@ void conv_3x3s1_direct_int8(const int8_t* din,
chin,
win,
hin,
(signed char*)ptr_zero);
ptr_zero);
#pragma omp parallel for num_threads(threads)
for (int c = 0; c < chout; c += hout_c_block) {
#ifdef ARM_WITH_OMP
int* pre_out =
reinterpret_cast<int*>(pre_din + (pre_in_size + 3) / 4 * 4) +
omp_get_thread_num() * pre_out_size;
int32_t* pre_out = reinterpret_cast<int*>(pre_din + pre_in_size) +
omp_get_thread_num() * pre_out_size;
#else
int* pre_out =
reinterpret_cast<int*>(pre_din + (pre_in_size + 3) / 4 * 4);
auto pre_out = reinterpret_cast<int32_t*>(pre_din + pre_in_size);
#endif
// printf("ptr_zero_int: %x, ptr_zero: %x, ptr_write: %x, pre_din: %x,
// pre_out: %x \n", ptr_zero_int, ptr_zero, ptr_write, pre_din,
// pre_out);
const signed char* inr0 = pre_din;
const signed char* inr1 = inr0 + in_len;
const signed char* inr2 = inr1 + in_len;
const signed char* inr3 = inr2 + in_len;
const signed char* wc0 =
static_cast<const signed char*>(weights) + c * w_stride;
const int* bias_ptr = ptr_zero;
const int8_t* block_inr0 = pre_din;
const int8_t* block_inr1 = block_inr0 + in_len;
const int8_t* block_inr2 = block_inr1 + in_len;
const int8_t* block_inr3 = block_inr2 + in_len;
const int8_t* weight_c = weights + c * w_stride;
float bias_local[4] = {0, 0, 0, 0};
if (flag_bias) {
bias_ptr = static_cast<const int*>(bias) + c;
}
// hout_r_block * wout_round * hout_c_block
fill_packed_bias_nxmw_int8(
bias_ptr, pre_out, hout_c_block, hout_r_block, wout_round);
bias_local[0] = bias[c];
bias_local[1] = bias[c + 1];
bias_local[2] = bias[c + 2];
bias_local[3] = bias[c + 3];
}
for (int i = 0; i < chin; ++i) {
const signed char* r0 = inr0;
const signed char* r1 = inr1;
const signed char* r2 = inr2;
const signed char* r3 = inr3;
int* ptr_out0 = pre_out;
int* ptr_out1 = pre_out + size_out;
int cnt = w_loop;
const signed char* ptr_wc0 = wc0;
asm volatile(
"ldp q4, q5, [%[wc0]] \n" /* w4 w5 w6 w7 */
"ldr q6, [%[wc0], #32] \n" /* w8 */
"SXTL v11.8h, v4.8b \n" /* w to int16 */
"SXTL2 v12.8h, v4.16b \n" /* w to int16 */
"SXTL v13.8h, v5.8b \n" /* to int16 */
"SXTL2 v14.8h, v5.16b \n" /* to int16 */
"SXTL v15.8h, v6.8b \n" /* to int16 */
"1: \n" /* main loop*/
"ldr d0, [%[r0]] \n" /* load data din0-dinn7*/
"SXTL v1.8h, v0.8b \n" /* to int16 */
/*output 1st row*/
"smull v16.4s, v11.4h, v1.h[0] \n" /* */
"smull v17.4s, v11.4h, v1.h[1] \n" /* */
"smull v18.4s, v11.4h, v1.h[2] \n" /* */
"smull v19.4s, v11.4h, v1.h[3] \n" /* */
"add %[r0], %[r0], #4\n"
/*output 1st row*/
"smlal2 v16.4s, v11.8h, v1.h[1] \n" /* */
"smlal2 v17.4s, v11.8h, v1.h[2] \n" /* */
"smlal2 v18.4s, v11.8h, v1.h[3] \n" /* */
"smlal2 v19.4s, v11.8h, v1.h[4] \n" /* */
"ldr d0, [%[r1]] \n" /* load data */
/*output 1st row*/
"smlal v16.4s, v12.4h, v1.h[2] \n" /* */
"smlal v17.4s, v12.4h, v1.h[3] \n" /* */
"SXTL v2.8h, v0.8b \n" /* to int16 */
"smlal v18.4s, v12.4h, v1.h[4] \n" /* */
"smlal v19.4s, v12.4h, v1.h[5] \n" /* */
"add %[r1], %[r1], #4 \n"
/*output 1st row*/
"smlal2 v16.4s, v12.8h, v2.h[0] \n" /* */
"smlal2 v17.4s, v12.8h, v2.h[1] \n" /* */
"smlal2 v18.4s, v12.8h, v2.h[2] \n" /* */
"smlal2 v19.4s, v12.8h, v2.h[3] \n" /* */
/*output 1st row*/
"smlal v16.4s, v13.4h, v2.h[1] \n" /* */
"smlal v17.4s, v13.4h, v2.h[2] \n" /* */
"smlal v18.4s, v13.4h, v2.h[3] \n" /* */
"smlal v19.4s, v13.4h, v2.h[4] \n" /* */
/*output 1st row*/
"smlal2 v16.4s, v13.8h, v2.h[2] \n" /* */
"smlal2 v17.4s, v13.8h, v2.h[3] \n" /* */
"smlal2 v18.4s, v13.8h, v2.h[4] \n" /* */
"smlal2 v19.4s, v13.8h, v2.h[5] \n" /* */
/*output 2rd row*/
"smull v24.4s, v11.4h, v2.h[0] \n" /* */
"smull v25.4s, v11.4h, v2.h[1] \n" /* */
"smull v26.4s, v11.4h, v2.h[2] \n" /* */
"smull v27.4s, v11.4h, v2.h[3] \n" /* */
/*output 2rd row*/
"smlal2 v24.4s, v11.8h, v2.h[1] \n" /* */
"smlal2 v25.4s, v11.8h, v2.h[2] \n" /* */
"smlal2 v26.4s, v11.8h, v2.h[3] \n" /* */
"smlal2 v27.4s, v11.8h, v2.h[4] \n" /* */
"ldr d0, [%[r2]] \n" /* load data */
/*output 2rd row*/
"smlal v24.4s, v12.4h, v2.h[2] \n" /* */
"smlal v25.4s, v12.4h, v2.h[3] \n" /* */
"SXTL v1.8h, v0.8b \n" /* to int16 */
"smlal v26.4s, v12.4h, v2.h[4] \n" /* */
"smlal v27.4s, v12.4h, v2.h[5] \n" /* */
/*output 1st row*/
"smlal v16.4s, v14.4h, v1.h[0] \n" /* */
"smlal v17.4s, v14.4h, v1.h[1] \n" /* */
"smlal v18.4s, v14.4h, v1.h[2] \n" /* */
"smlal v19.4s, v14.4h, v1.h[3] \n" /* */
"add %[r2], %[r2], #4 \n"
/*output 1st row*/
"smlal2 v16.4s, v14.8h, v1.h[1] \n" /* */
"smlal2 v17.4s, v14.8h, v1.h[2] \n" /* */
"smlal2 v18.4s, v14.8h, v1.h[3] \n" /* */
"smlal2 v19.4s, v14.8h, v1.h[4] \n" /* */
"ldp q3, q4, [%[ptr_out0]] \n"
"ldp q5, q6, [%[ptr_out0], #32] \n"
/*output 1st row*/
"smlal v16.4s, v15.4h, v1.h[2] \n" /* */
"smlal v17.4s, v15.4h, v1.h[3] \n" /* */
"smlal v18.4s, v15.4h, v1.h[4] \n" /* */
"smlal v19.4s, v15.4h, v1.h[5] \n" /* */
"ADD v3.4s, v16.4s, v3.4s \n"
"ADD v4.4s, v17.4s, v4.4s \n"
"ADD v5.4s, v18.4s, v5.4s \n"
"ADD v6.4s, v19.4s, v6.4s \n"
"stp q3, q4, [%[ptr_out0]], #32 \n" /* save to
output*/
"stp q5, q6, [%[ptr_out0]], #32 \n" /* save to
output*/
/*output 2rd row*/
"smlal2 v24.4s, v12.8h, v1.h[0] \n" /* */
"smlal2 v25.4s, v12.8h, v1.h[1] \n" /* */
"smlal2 v26.4s, v12.8h, v1.h[2] \n" /* */
"smlal2 v27.4s, v12.8h, v1.h[3] \n" /* */
/*output 2rd row*/
"smlal v24.4s, v13.4h, v1.h[1] \n" /* */
"smlal v25.4s, v13.4h, v1.h[2] \n" /* */
"smlal v26.4s, v13.4h, v1.h[3] \n" /* */
"smlal v27.4s, v13.4h, v1.h[4] \n" /* */
"ldr d0, [%[r3]] \n" /* load data */
/*output 2rd row*/
"smlal2 v24.4s, v13.8h, v1.h[2] \n" /* */
"smlal2 v25.4s, v13.8h, v1.h[3] \n" /* */
"SXTL v2.8h, v0.8b \n" /* to int16 */
"smlal2 v26.4s, v13.8h, v1.h[4] \n" /* */
"smlal2 v27.4s, v13.8h, v1.h[5] \n" /* */
/*output 2rd row*/
"smlal v24.4s, v14.4h, v2.h[0] \n" /* */
"smlal v25.4s, v14.4h, v2.h[1] \n" /* */
"smlal v26.4s, v14.4h, v2.h[2] \n" /* */
"smlal v27.4s, v14.4h, v2.h[3] \n" /* */
"add %[r3], %[r3], #4 \n"
/*output 2rd row*/
"smlal2 v24.4s, v14.8h, v2.h[1] \n" /* */
"smlal2 v25.4s, v14.8h, v2.h[2] \n" /* */
"smlal2 v26.4s, v14.8h, v2.h[3] \n" /* */
"smlal2 v27.4s, v14.8h, v2.h[4] \n" /* */
"ldp q3, q4, [%[ptr_out1]] \n"
"ldp q5, q6, [%[ptr_out1], #32] \n"
"subs %w[cnt], %w[cnt], #1 \n" /* loop count -1 */
/*output 2rd row*/
"smlal v24.4s, v15.4h, v2.h[2] \n" /* */
"smlal v25.4s, v15.4h, v2.h[3] \n" /* */
"smlal v26.4s, v15.4h, v2.h[4] \n" /* */
"smlal v27.4s, v15.4h, v2.h[5] \n" /* */
"ADD v3.4s, v24.4s, v3.4s \n"
"ADD v4.4s, v25.4s, v4.4s \n"
"ADD v5.4s, v26.4s, v5.4s \n"
"ADD v6.4s, v27.4s, v6.4s \n"
"stp q3, q4, [%[ptr_out1]], #32 \n" /* save to output*/
"stp q5, q6, [%[ptr_out1]], #32 \n" /* save to output*/
"bne 1b \n" /* jump to main loop*/
: [cnt] "+r"(cnt),
[wc0] "+r"(ptr_wc0),
[r0] "+r"(r0),
[r1] "+r"(r1),
[r2] "+r"(r2),
[r3] "+r"(r3),
[ptr_out0] "+r"(ptr_out0),
[ptr_out1] "+r"(ptr_out1)
:
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v24",
"v25",
"v26",
"v27"
);
wc0 += 9 * hout_c_block;
inr0 += win_round;
inr1 += win_round;
inr2 += win_round;
inr3 += win_round;
}
if (out_type == PRECISION(kFloat)) {
write_to_output_c4_int32_1(pre_out,
reinterpret_cast<float*>(dout_batch),
hout_c_block,
hout_r_block,
c,
c + 4,
h,
h + 2,
0,
wout_round,
chout,
hout,
wout,
flag_relu,
reinterpret_cast<float*>(ptr_write),
&scale[c],
out_type);
} else if (out_type == PRECISION(kInt8)) {
write_to_output_c4_int32_1(pre_out,
dout_batch,
hout_c_block,
hout_r_block,
c,
c + 4,
h,
h + 2,
0,
wout_round,
chout,
hout,
wout,
flag_relu,
reinterpret_cast<signed char*>(ptr_write),
&scale[c],
out_type);
} else { // int32
write_to_output_c4_int32(pre_out,
reinterpret_cast<int*>(dout_batch),
hout_c_block,
hout_r_block,
c,
c + 4,
h,
h + 2,
0,
wout_round,
chout,
hout,
wout,
flag_relu,
ptr_write);
        }
      }
    }
  }
}
#else

void conv_3x3s1_direct_int8(const int8_t* din,
                            int32_t* dout,
                            int num,
                            int chout,
                            int hout,
                            int wout,
                            int chin,
                            int hin,
                            int win,
                            const int8_t* weights,
                            const int32_t* bias,
                            const operators::ConvParam& param,
                            Context<TARGET(kARM)>* ctx,
                            PrecisionType out_type,
                            const float* scale) {
  // printf("conv2_3x3s1_direct_int8 \n");
  const int hin_r_block = 4;
  const int hout_c_block = 4;
  const int hout_r_block = 2;

  int stride_w = param.strides[1];
  int pad_w = param.paddings[1];
  int pad_h = param.paddings[0];
  bool flag_relu = param.fuse_relu;
  bool flag_bias = (param.bias != nullptr);

  int wout_round = ((wout + 3) / 4) * 4;
  int win_round = wout_round * stride_w + 4;

  int threads = ctx->threads();

  int* tmp_work_space = ctx->workspace_data<int>();
  int* ptr_zero = tmp_work_space;
  memset(ptr_zero, 0, sizeof(int) * win_round);
  int* ptr_write = ptr_zero + win_round;

  int in_len = win_round * chin;
  int pre_in_size = hin_r_block * in_len;
  int pre_out_size = hout_c_block * hout_r_block * wout_round;

  signed char* pre_din =
      reinterpret_cast<signed char*>(ptr_write + wout_round);

  int size_in_channel = win * hin;
  int size_out_channel = wout * hout;
  int w_stride = chin * 9;

  int ws = -pad_w;
  int we = ws + win_round;
  int w_loop = wout_round / 4;

  int size_out = wout_round * hout_c_block;

  // printf("win_round: %d, wout_round: %d, ws: %d, we: %d\n", win_round,
  // wout_round, ws, we);

  for (int n = 0; n < num; ++n) {
    const signed char* din_batch =
        static_cast<const signed char*>(din) + n * chin * size_in_channel;
    signed char* dout_batch =
        reinterpret_cast<signed char*>(dout) +
        n * chout * size_out_channel * PrecisionTypeLength(out_type);

    for (int h = 0; h < hout; h += 2) {
      int hs = h - pad_h;
      int he = hs + 4;
      // printf("hs: %d, he: %d, chin: %d, hin: %d, win: %d \n", hs, he, chin,
      // hin, win);
      prepack_input_nxw(din_batch,
                        pre_din,
                        0,
                        chin,
                        hs,
                        he,
                        ws,
                        we,
                        chin,
                        win,
                        hin,
                        (signed char*)ptr_zero);

#pragma omp parallel for num_threads(threads)
      for (int c = 0; c < chout; c += hout_c_block) {  // 4
#ifdef ARM_WITH_OMP
        int* pre_out =
            reinterpret_cast<int*>(pre_din + (pre_in_size + 3) / 4 * 4) +
            omp_get_thread_num() * pre_out_size;
#else
        int* pre_out =
            reinterpret_cast<int*>(pre_din + (pre_in_size + 3) / 4 * 4);
#endif
        // printf("ptr_zero_int: %x, ptr_zero: %x, ptr_write: %x, pre_din: %x,
        // pre_out: %x \n", ptr_zero_int, ptr_zero, ptr_write, pre_din,
        // pre_out);
        const signed char* inr0 = pre_din;
        const signed char* inr1 = inr0 + in_len;
        const signed char* inr2 = inr1 + in_len;
        const signed char* inr3 = inr2 + in_len;

        const signed char* wc0 =
            static_cast<const signed char*>(weights) + c * w_stride;

        const int* bias_ptr = ptr_zero;
        if (flag_bias) {
          bias_ptr = static_cast<const int*>(bias) + c;
        }
        // hout_r_block * wout_round * hout_c_block
        fill_packed_bias_nxmw_int8(
            bias_ptr, pre_out, hout_c_block, hout_r_block, wout_round);

        for (int i = 0; i < chin; ++i) {
          const signed char* r0 = inr0;
          const signed char* r1 = inr1;
          const signed char* r2 = inr2;
          const signed char* r3 = inr3;

          int* ptr_out0 = pre_out;
          int* ptr_out1 = pre_out + size_out;

          int cnt = w_loop;
          const signed char* ptr_wc = wc0;

          asm volatile(
              "vld1.s8 {d0-d3}, [%[wc0]]!\n" /* wc0 - wc7 */
              "vld1.s8 {d4}, [%[wc0]]!\n"    /* wc8 */
              "vmovl.s8 q3, d0\n"            /* q3 = w0, w1 */
              "vmovl.s8 q4, d1\n"            /* q4 = w2, w3 */
              "vmovl.s8 q5, d2\n"            /* q5 = w4, w5 */
              "vmovl.s8 q6, d3\n"            /* q6 = w6, w7 */
              "vmovl.s8 q7, d4\n"            /* q7 = w8 */

              "1:\n"                         /* main loop */
              "vld1.s32 {d0}, [%[r0]]\n"     /* load data din0-din7 */
              "vmovl.s8 q0, d0\n"            /* movl d0 -> q0 */
              /* output 1st row */
              "vmull.s16 q8, d6, d0[0]\n"    /* q8 = w0 * r0[0] */
              "vmull.s16 q9, d6, d0[1]\n"    /* q9 = w0 * r0[1] */
              "vmull.s16 q10, d6, d0[2]\n"   /* q10 = w0 * r0[2] */
              "vmull.s16 q11, d6, d0[3]\n"   /* q11 = w0 * r0[3] */

              "add %[r0], #4\n"

              /* output 1st row */
              "vmlal.s16 q8, d7, d0[1]\n"    /* q8 = w1 * r0[1] */
              "vmlal.s16 q9, d7, d0[2]\n"    /* q9 = w1 * r0[2] */
              "vmlal.s16 q10, d7, d0[3]\n"   /* q10 = w1 * r0[3] */
              "vmlal.s16 q11, d7, d1[0]\n"   /* q11 = w1 * r0[4] */

              "vld1.s32 {d2}, [%[r1]]\n"     /* load input r1 -> d2 */
              "vmovl.s8 q1, d2\n"            /* movl d2 -> q1 */

              /* output 1st row */
              "vmlal.s16 q8, d8, d0[2]\n"    /* q8 = w2 * r0[2] */
              "vmlal.s16 q9, d8, d0[3]\n"    /* q9 = w2 * r0[3] */
              "vmlal.s16 q10, d8, d1[0]\n"   /* q10 = w2 * r0[4] */
              "vmlal.s16 q11, d8, d1[1]\n"   /* q11 = w2 * r0[5] */

              /* output 1st row */
              "vmlal.s16 q8, d9, d2[0]\n"    /* w3 * r1[0] */
              "vmlal.s16 q9, d9, d2[1]\n"    /* w3 * r1[1] */
              "vmlal.s16 q10, d9, d2[2]\n"   /* w3 * r1[2] */
              "vmlal.s16 q11, d9, d2[3]\n"   /* w3 * r1[3] */

              "add %[r1], #4\n"

              /* output 1st row */
              "vmlal.s16 q8, d10, d2[1]\n"   /* w4 * r1[1] */
              "vmlal.s16 q9, d10, d2[2]\n"   /* w4 * r1[2] */
              "vmlal.s16 q10, d10, d2[3]\n"  /* w4 * r1[3] */
              "vmlal.s16 q11, d10, d3[0]\n"  /* w4 * r1[4] */

              /* output 1st row */
              "vmlal.s16 q8, d11, d2[2]\n"   /* w5 * r1[2] */
              "vmlal.s16 q9, d11, d2[3]\n"   /* w5 * r1[3] */
              "vmlal.s16 q10, d11, d3[0]\n"  /* w5 * r1[4] */
              "vmlal.s16 q11, d11, d3[1]\n"  /* w5 * r1[5] */

              /* output 2nd row */
              "vmull.s16 q12, d6, d2[0]\n"   /* w0 * r1[0] */
              "vmull.s16 q13, d6, d2[1]\n"   /* w0 * r1[1] */
              "vmull.s16 q14, d6, d2[2]\n"   /* w0 * r1[2] */
              "vmull.s16 q15, d6, d2[3]\n"   /* w0 * r1[3] */

              "vld1.s32 {d0}, [%[r2]]\n"     /* load input r2 -> d0 */
              "vmovl.s8 q0, d0\n"            /* movl d0 -> q0 */

              /* output 2nd row */
              "vmlal.s16 q12, d7, d2[1]\n"   /* w1 * r1[1] */
              "vmlal.s16 q13, d7, d2[2]\n"   /* w1 * r1[2] */
              "vmlal.s16 q14, d7, d2[3]\n"   /* w1 * r1[3] */
              "vmlal.s16 q15, d7, d3[0]\n"   /* w1 * r1[4] */

              /* output 2nd row */
              "vmlal.s16 q12, d8, d2[2]\n"   /* w2 * r1[2] */
              "vmlal.s16 q13, d8, d2[3]\n"   /* w2 * r1[3] */
              "vmlal.s16 q14, d8, d3[0]\n"   /* w2 * r1[4] */
              "vmlal.s16 q15, d8, d3[1]\n"   /* w2 * r1[5] */

              "add %[r2], #4\n"

              /* output 1st row */
              "vmlal.s16 q8, d12, d0[0]\n"   /* w6 * r2[0] */
              "vmlal.s16 q9, d12, d0[1]\n"   /* w6 * r2[1] */
              "vmlal.s16 q10, d12, d0[2]\n"  /* w6 * r2[2] */
              "vmlal.s16 q11, d12, d0[3]\n"  /* w6 * r2[3] */

              /* output 1st row */
              "vmlal.s16 q8, d13, d0[1]\n"   /* w7 * r2[1] */
              "vmlal.s16 q9, d13, d0[2]\n"   /* w7 * r2[2] */
              "vmlal.s16 q10, d13, d0[3]\n"  /* w7 * r2[3] */
              "vmlal.s16 q11, d13, d1[0]\n"  /* w7 * r2[4] */

              "vld1.32 {d2-d5}, [%[ptr_out0]]\n" /* load ptr_out0 -> q1, q2 */

              /* output 1st row */
              "vmlal.s16 q8, d14, d0[2]\n"   /* w8 * r2[2] */
              "vmlal.s16 q9, d14, d0[3]\n"   /* w8 * r2[3] */
              "vmlal.s16 q10, d14, d1[0]\n"  /* w8 * r2[4] */
              "vmlal.s16 q11, d14, d1[1]\n"  /* w8 * r2[5] */

              /* load & store output 1st row */
              "vadd.s32 q1, q8, q1\n"        /* out[0] += q8 */
              "vadd.s32 q2, q9, q2\n"        /* out[1] += q9 */
              "vst1.s32 {d2-d5}, [%[ptr_out0]]!\n"

              /* output 2nd row */
              "vmlal.s16 q12, d9, d0[0]\n"   /* w3 * r2[0] */
              "vmlal.s16 q13, d9, d0[1]\n"   /* w3 * r2[1] */
              "vmlal.s16 q14, d9, d0[2]\n"   /* w3 * r2[2] */
              "vmlal.s16 q15, d9, d0[3]\n"   /* w3 * r2[3] */

              "vld1.32 {d2-d5}, [%[ptr_out0]]\n" /* load ptr_out0 -> q1, q2 */

              /* output 2nd row */
              "vmlal.s16 q12, d10, d0[1]\n"  /* w4 * r2[1] */
              "vmlal.s16 q13, d10, d0[2]\n"  /* w4 * r2[2] */
              "vadd.s32 q1, q10, q1\n"       /* out[2] += q10 */
              "vadd.s32 q2, q11, q2\n"       /* out[3] += q11 */
              "vmlal.s16 q14, d10, d0[3]\n"  /* w4 * r2[3] */
              "vst1.s32 {d2-d5}, [%[ptr_out0]]!\n"
              "vmlal.s16 q15, d10, d1[0]\n"  /* w4 * r2[4] */

              /* output 2nd row */
              "vmlal.s16 q12, d11, d0[2]\n"  /* w5 * r2[2] */
              "vmlal.s16 q13, d11, d0[3]\n"  /* w5 * r2[3] */
              "vld1.s32 {d4}, [%[r3]]\n"     /* load input r3 -> d4 */
              "vmovl.s8 q2, d4\n"            /* movl d4 -> q2 */
              "vmlal.s16 q14, d11, d1[0]\n"  /* w5 * r2[4] */
              "vmlal.s16 q15, d11, d1[1]\n"  /* w5 * r2[5] */

              "add %[r3], #4\n"

              /* output 2nd row */
              "vmlal.s16 q12, d12, d4[0]\n"  /* w6 * r3[0] */
              "vmlal.s16 q13, d12, d4[1]\n"  /* w6 * r3[1] */
              "vmlal.s16 q14, d12, d4[2]\n"  /* w6 * r3[2] */
              "vmlal.s16 q15, d12, d4[3]\n"  /* w6 * r3[3] */

              "vld1.32 {d0-d3}, [%[ptr_out1]]\n"

              /* output 2nd row */
              "vmlal.s16 q12, d13, d4[1]\n"  /* w7 * r3[1] */
              "vmlal.s16 q13, d13, d4[2]\n"  /* w7 * r3[2] */
              "vmlal.s16 q14, d13, d4[3]\n"  /* w7 * r3[3] */
              "vmlal.s16 q15, d13, d5[0]\n"  /* w7 * r3[4] */
              "subs %[cnt], #1\n"

              /* output 2nd row */
              "vmlal.s16 q12, d14, d4[2]\n"  /* w8 * r3[2] */
              "vmlal.s16 q13, d14, d4[3]\n"  /* w8 * r3[3] */
              "vmlal.s16 q14, d14, d5[0]\n"  /* w8 * r3[4] */
              "vmlal.s16 q15, d14, d5[1]\n"  /* w8 * r3[5] */

              /* load & store output 2nd row */
              "vadd.s32 q0, q12, q0\n"
              "vadd.s32 q1, q13, q1\n"
              "vst1.s32 {d0-d3}, [%[ptr_out1]]!\n"
              "vld1.32 {d0-d3}, [%[ptr_out1]]\n"
              "vadd.s32 q0, q14, q0\n"
              "vadd.s32 q1, q15, q1\n"
              "vst1.s32 {d0-d3}, [%[ptr_out1]]!\n"
              "bne 1b\n"                     /* jump to main loop */
              : [cnt] "+r"(cnt),
                [r0] "+r"(r0),
                [r1] "+r"(r1),
                [r2] "+r"(r2),
                [r3] "+r"(r3),
                [ptr_out0] "+r"(ptr_out0),
                [ptr_out1] "+r"(ptr_out1),
                [wc0] "+r"(ptr_wc)
              :
              : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6",
                "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
          wc0 += 9 * hout_c_block;
          inr0 += win_round;
          inr1 += win_round;
          inr2 += win_round;
          inr3 += win_round;
        }

        if (out_type == PRECISION(kFloat)) {
          write_to_output_c4_int32_1(pre_out,
                                     reinterpret_cast<float*>(dout_batch),
                                     hout_c_block,
                                     hout_r_block,
                                     c,
                                     c + 4,
                                     h,
                                     h + 2,
                                     0,
                                     wout_round,
                                     chout,
                                     hout,
                                     wout,
                                     flag_relu,
                                     reinterpret_cast<float*>(ptr_write),
                                     &scale[c],
                                     out_type);
        } else if (out_type == PRECISION(kInt8)) {
          write_to_output_c4_int32_1(pre_out,
                                     dout_batch,
                                     hout_c_block,
                                     hout_r_block,
                                     c,
                                     c + 4,
                                     h,
                                     h + 2,
                                     0,
                                     wout_round,
                                     chout,
                                     hout,
                                     wout,
                                     flag_relu,
                                     reinterpret_cast<signed char*>(ptr_write),
                                     &scale[c],
                                     out_type);
        } else {  // int32
          write_to_output_c4_int32(pre_out,
                                   reinterpret_cast<int*>(dout_batch),
                                   hout_c_block,
                                   hout_r_block,
                                   c,
                                   c + 4,
                                   h,
                                   h + 2,
                                   0,
                                   wout_round,
                                   chout,
                                   hout,
                                   wout,
                                   flag_relu,
                                   ptr_write);
        }
      }
    }
  }
}
#endif  // __aarch64__
        }
        memset(pre_out, 0, pre_out_size * sizeof(int32_t));
        for (int hk = 0; hk < h_kernel; hk += hout_r_kernel) {
          const int8_t* wc0 = weight_c;

          const int8_t* inr0 = block_inr0;
          const int8_t* inr1 = block_inr1;
          const int8_t* inr2 = block_inr2;
          const int8_t* inr3 = block_inr3;

          int32_t* pre_out0 = pre_out + hk * out_row_stride;
          int32_t* pre_out1 = pre_out0 + out_row_stride;

          for (int i = 0; i < chin; ++i) {
            int32_t* ptr_out0 = pre_out0;
            int32_t* ptr_out1 = pre_out1;

            const signed char* r0 = inr0;
            const signed char* r1 = inr1;
            const signed char* r2 = inr2;
            const signed char* r3 = inr3;

            int cnt = w_loop;
            const int8_t* ptr_wc0 = wc0;
            // clang-format off
#ifdef __aarch64__
            asm volatile(
                "ldp q4, q5, [%[wc0]]\n"
                "ldr d6, [%[wc0], #32]\n"
                "sxtl v11.8h, v4.8b\n"
                "sxtl2 v12.8h, v4.16b\n"
                "sxtl v13.8h, v5.8b\n"
                "sxtl2 v14.8h, v5.16b\n"
                "sxtl v15.8h, v6.8b\n"
                "ldp q16, q17, [%[ptr_out0]]\n"
                "ldp q18, q19, [%[ptr_out0], #32]\n"
                "ldr d0, [%[r1]], #4\n"            /* load r1 */
                "ldr d1, [%[r2]], #4\n"            /* load r2 */
                "sxtl v2.8h, v0.8b\n"              /* r1, cvt to int16 */
                "sxtl v3.8h, v1.8b\n"              /* r2, cvt to int16 */
                "1:\n"
                /* inr1 -> outr0, outr1 */
                "ldp q20, q21, [%[ptr_out1]]\n"
                "ldr d0, [%[r0]], #4\n"            /* load r0 */
                "smlal2 v16.4s, v12.8h, v2.h[0]\n" /* out00, w10 * r10 */
                "smlal2 v17.4s, v12.8h, v2.h[1]\n" /* out01, w10 * r11 */
                "smlal2 v18.4s, v12.8h, v2.h[2]\n" /* out02, w10 * r12 */
                "smlal2 v19.4s, v12.8h, v2.h[3]\n" /* out03, w10 * r13 */
                "ldp q22, q23, [%[ptr_out1], #32]\n"
                "smlal v16.4s, v13.4h, v2.h[1]\n"  /* out00, w11 * r11 */
                "smlal v17.4s, v13.4h, v2.h[2]\n"  /* out01, w11 * r12 */
                "smlal v18.4s, v13.4h, v2.h[3]\n"  /* out02, w11 * r13 */
                "smlal v19.4s, v13.4h, v2.h[4]\n"  /* out03, w11 * r14 */
                "smlal2 v16.4s, v13.8h, v2.h[2]\n" /* out00, w12 * r12 */
                "smlal2 v17.4s, v13.8h, v2.h[3]\n" /* out01, w12 * r13 */
                "smlal2 v18.4s, v13.8h, v2.h[4]\n" /* out02, w12 * r14 */
                "smlal2 v19.4s, v13.8h, v2.h[5]\n" /* out03, w12 * r15 */
                "smlal v20.4s, v11.4h, v2.h[0]\n"  /* out10, w00 * r10 */
                "smlal v21.4s, v11.4h, v2.h[1]\n"  /* out11, w00 * r11 */
                "smlal v22.4s, v11.4h, v2.h[2]\n"  /* out12, w00 * r12 */
                "smlal v23.4s, v11.4h, v2.h[3]\n"  /* out13, w00 * r13 */
                "smlal2 v20.4s, v11.8h, v2.h[1]\n" /* out10, w01 * r11 */
                "smlal2 v21.4s, v11.8h, v2.h[2]\n" /* out11, w01 * r12 */
                "smlal2 v22.4s, v11.8h, v2.h[3]\n" /* out12, w01 * r13 */
                "smlal2 v23.4s, v11.8h, v2.h[4]\n" /* out13, w01 * r14 */
                "smlal v20.4s, v12.4h, v2.h[2]\n"  /* out10, w02 * r12 */
                "smlal v21.4s, v12.4h, v2.h[3]\n"  /* out11, w02 * r13 */
                "smlal v22.4s, v12.4h, v2.h[4]\n"  /* out12, w02 * r14 */
                "smlal v23.4s, v12.4h, v2.h[5]\n"  /* out13, w02 * r15 */
                "sxtl v2.8h, v0.8b\n"              /* r0, cvt to int16 */
                /* inr2 -> outr0, outr1 */
                "ldr d1, [%[r3]], #4\n"            /* load r3 */
                "smlal v16.4s, v14.4h, v3.h[0]\n"  /* out00, w20 * r20 */
                "smlal v17.4s, v14.4h, v3.h[1]\n"  /* out01, w20 * r21 */
                "smlal v18.4s, v14.4h, v3.h[2]\n"  /* out02, w20 * r22 */
                "smlal v19.4s, v14.4h, v3.h[3]\n"  /* out03, w20 * r23 */
                "smlal2 v20.4s, v12.8h, v3.h[0]\n" /* out10, w10 * r20 */
                "smlal2 v21.4s, v12.8h, v3.h[1]\n" /* out11, w10 * r21 */
                "smlal2 v22.4s, v12.8h, v3.h[2]\n" /* out12, w10 * r22 */
                "smlal2 v23.4s, v12.8h, v3.h[3]\n" /* out13, w10 * r23 */
                "smlal2 v16.4s, v14.8h, v3.h[1]\n" /* out00, w21 * r21 */
                "smlal2 v17.4s, v14.8h, v3.h[2]\n" /* out01, w21 * r22 */
                "smlal2 v18.4s, v14.8h, v3.h[3]\n" /* out02, w21 * r23 */
                "smlal2 v19.4s, v14.8h, v3.h[4]\n" /* out03, w21 * r24 */
                "smlal v20.4s, v13.4h, v3.h[1]\n"  /* out10, w11 * r21 */
                "smlal v21.4s, v13.4h, v3.h[2]\n"  /* out11, w11 * r22 */
                "smlal v22.4s, v13.4h, v3.h[3]\n"  /* out12, w11 * r23 */
                "smlal v23.4s, v13.4h, v3.h[4]\n"  /* out13, w11 * r24 */
                "smlal v16.4s, v15.4h, v3.h[2]\n"  /* out00, w22 * r22 */
                "smlal v17.4s, v15.4h, v3.h[3]\n"  /* out01, w22 * r23 */
                "smlal v18.4s, v15.4h, v3.h[4]\n"  /* out02, w22 * r24 */
                "smlal v19.4s, v15.4h, v3.h[5]\n"  /* out03, w22 * r25 */
                "smlal2 v20.4s, v13.8h, v3.h[2]\n" /* out10, w12 * r22 */
                "smlal2 v21.4s, v13.8h, v3.h[3]\n" /* out11, w12 * r23 */
                "smlal2 v22.4s, v13.8h, v3.h[4]\n" /* out12, w12 * r24 */
                "smlal2 v23.4s, v13.8h, v3.h[5]\n" /* out13, w12 * r25 */
                "sxtl v3.8h, v1.8b\n"              /* r3, cvt to int16 */
                /* inr0 -> outr0 */
                "ldr d0, [%[r1]], #4\n"            /* load r1 */
                "smlal v16.4s, v11.4h, v2.h[0]\n"  /* out00, w00 * r00 */
                "smlal v17.4s, v11.4h, v2.h[1]\n"  /* out01, w00 * r01 */
                "smlal v18.4s, v11.4h, v2.h[2]\n"  /* out02, w00 * r02 */
                "smlal v19.4s, v11.4h, v2.h[3]\n"  /* out03, w00 * r03 */
                "smlal2 v16.4s, v11.8h, v2.h[1]\n" /* out00, w01 * r01 */
                "smlal2 v17.4s, v11.8h, v2.h[2]\n" /* out01, w01 * r02 */
                "smlal2 v18.4s, v11.8h, v2.h[3]\n" /* out02, w01 * r03 */
                "smlal2 v19.4s, v11.8h, v2.h[4]\n" /* out03, w01 * r04 */
                "smlal v16.4s, v12.4h, v2.h[2]\n"  /* out00, w02 * r02 */
                "smlal v17.4s, v12.4h, v2.h[3]\n"  /* out01, w02 * r03 */
                "smlal v18.4s, v12.4h, v2.h[4]\n"  /* out02, w02 * r04 */
                "smlal v19.4s, v12.4h, v2.h[5]\n"  /* out03, w02 * r05 */
                "sxtl v2.8h, v0.8b\n"              /* r1, cvt to int16 */
                /* inr3 -> outr1 */
                "ldr d1, [%[r2]], #4\n"            /* load r2 */
                "stp q16, q17, [%[ptr_out0]], #32\n"
                "smlal v20.4s, v14.4h, v3.h[0]\n"  /* out10, w20 * r30 */
                "smlal v21.4s, v14.4h, v3.h[1]\n"  /* out11, w20 * r31 */
                "smlal v22.4s, v14.4h, v3.h[2]\n"  /* out12, w20 * r32 */
                "smlal v23.4s, v14.4h, v3.h[3]\n"  /* out13, w20 * r33 */
                "stp q18, q19, [%[ptr_out0]], #32\n"
                "ldp q16, q17, [%[ptr_out0]]\n"
                "smlal2 v20.4s, v14.8h, v3.h[1]\n" /* out10, w21 * r31 */
                "smlal2 v21.4s, v14.8h, v3.h[2]\n" /* out11, w21 * r32 */
                "smlal2 v22.4s, v14.8h, v3.h[3]\n" /* out12, w21 * r33 */
                "smlal2 v23.4s, v14.8h, v3.h[4]\n" /* out13, w21 * r34 */
                "ldp q18, q19, [%[ptr_out0], #32]\n"
                "smlal v20.4s, v15.4h, v3.h[2]\n"  /* out10, w22 * r32 */
                "smlal v21.4s, v15.4h, v3.h[3]\n"  /* out11, w22 * r33 */
                "smlal v22.4s, v15.4h, v3.h[4]\n"  /* out12, w22 * r34 */
                "smlal v23.4s, v15.4h, v3.h[5]\n"  /* out13, w22 * r35 */
                "sxtl v3.8h, v1.8b\n"              /* r2, cvt to int16 */
                "subs %w[cnt], %w[cnt], #1\n"
                "stp q20, q21, [%[ptr_out1]], #32\n"
                "stp q22, q23, [%[ptr_out1]], #32\n"
                "bne 1b\n"
                : [cnt] "+r"(cnt),
                  [wc0] "+r"(ptr_wc0),
                  [r0] "+r"(r0),
                  [r1] "+r"(r1),
                  [r2] "+r"(r2),
                  [r3] "+r"(r3),
                  [ptr_out0] "+r"(ptr_out0),
                  [ptr_out1] "+r"(ptr_out1)
                :
                : "cc", "memory", "v0", "v1", "v2", "v3", "v4",
                  "v5", "v6", "v11", "v12", "v13", "v14", "v15",
                  "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
            );
#else
            asm volatile(
                "vld1.32 {d0-d3}, [%[wc0]]!\n"
                "vld1.32 {d4}, [%[wc0]]!\n"
                "vmovl.s8 q3, d0\n"           /* q3 = w0, w1 */
                "vmovl.s8 q4, d1\n"           /* q4 = w2, w3 */
                "vmovl.s8 q5, d2\n"           /* q5 = w4, w5 */
                "vmovl.s8 q6, d3\n"           /* q6 = w6, w7 */
                "vmovl.s8 q7, d4\n"           /* q7 = w8 */
                "vld1.32 d0, [%[r1]]\n"
                "vld1.32 d1, [%[r2]]\n"
                "vld1.32 {d16-d19}, [%[ptr_out0]]!\n"
                "vld1.32 {d20-d23}, [%[ptr_out0]]\n"
                "vmovl.s8 q1, d0\n"
                "vmovl.s8 q2, d1\n"
                "1:\n"
                /* inr1 -> outr0, outr1 */
                "vld1.32 {d24-d27}, [%[ptr_out1]]!\n"
                "vld1.32 d0, [%[r0]]\n"       /* load r0 */
                "vmlal.s16 q8, d9, d2[0]\n"   /* out00, w10 * r10 */
                "vmlal.s16 q9, d9, d2[1]\n"   /* out01, w10 * r11 */
                "vmlal.s16 q10, d9, d2[2]\n"  /* out02, w10 * r12 */
                "vmlal.s16 q11, d9, d2[3]\n"  /* out03, w10 * r13 */
                "vld1.32 {d28-d31}, [%[ptr_out1]]\n"
                "vmlal.s16 q8, d10, d2[1]\n"  /* out00, w11 * r11 */
                "vmlal.s16 q9, d10, d2[2]\n"  /* out01, w11 * r12 */
                "vmlal.s16 q10, d10, d2[3]\n" /* out02, w11 * r13 */
                "vmlal.s16 q11, d10, d3[0]\n" /* out03, w11 * r14 */
                "sub %[ptr_out0], %[ptr_out0], #32\n"
                "vmlal.s16 q8, d11, d2[2]\n"  /* out00, w12 * r12 */
                "vmlal.s16 q9, d11, d2[3]\n"  /* out01, w12 * r13 */
                "vmlal.s16 q10, d11, d3[0]\n" /* out02, w12 * r14 */
                "vmlal.s16 q11, d11, d3[1]\n" /* out03, w12 * r15 */
                "vmlal.s16 q12, d6, d2[0]\n"  /* out10, w00 * r10 */
                "vmlal.s16 q13, d6, d2[1]\n"  /* out11, w00 * r11 */
                "vmlal.s16 q14, d6, d2[2]\n"  /* out12, w00 * r12 */
                "vmlal.s16 q15, d6, d2[3]\n"  /* out13, w00 * r13 */
                "add %[r1], %[r1], #4\n"
                "vmlal.s16 q12, d7, d2[1]\n"  /* out10, w01 * r11 */
                "vmlal.s16 q13, d7, d2[2]\n"  /* out11, w01 * r12 */
                "vmlal.s16 q14, d7, d2[3]\n"  /* out12, w01 * r13 */
                "vmlal.s16 q15, d7, d3[0]\n"  /* out13, w01 * r14 */
                "sub %[ptr_out1], %[ptr_out1], #32\n"
                "vmlal.s16 q12, d8, d2[2]\n"  /* out10, w02 * r12 */
                "vmlal.s16 q13, d8, d2[3]\n"  /* out11, w02 * r13 */
                "vmlal.s16 q14, d8, d3[0]\n"  /* out12, w02 * r14 */
                "vmlal.s16 q15, d8, d3[1]\n"  /* out13, w02 * r15 */
                "vmovl.s8 q1, d0\n"           /* r0, cvt to int16 */
                /* inr2 -> outr0, outr1 */
                "vld1.32 d1, [%[r3]]\n"       /* load r3 */
                "vmlal.s16 q8, d12, d4[0]\n"  /* out00, w20 * r20 */
                "vmlal.s16 q9, d12, d4[1]\n"  /* out01, w20 * r21 */
                "vmlal.s16 q10, d12, d4[2]\n" /* out02, w20 * r22 */
                "vmlal.s16 q11, d12, d4[3]\n" /* out03, w20 * r23 */
                "add %[r2], %[r2], #4\n"
                "vmlal.s16 q12, d9, d4[0]\n"  /* out10, w10 * r20 */
                "vmlal.s16 q13, d9, d4[1]\n"  /* out11, w10 * r21 */
                "vmlal.s16 q14, d9, d4[2]\n"  /* out12, w10 * r22 */
                "vmlal.s16 q15, d9, d4[3]\n"  /* out13, w10 * r23 */
                "vmlal.s16 q8, d13, d4[1]\n"  /* out00, w21 * r21 */
                "vmlal.s16 q9, d13, d4[2]\n"  /* out01, w21 * r22 */
                "vmlal.s16 q10, d13, d4[3]\n" /* out02, w21 * r23 */
                "vmlal.s16 q11, d13, d5[0]\n" /* out03, w21 * r24 */
                "add %[r0], %[r0], #4\n"
                "vmlal.s16 q12, d10, d4[1]\n" /* out10, w11 * r21 */
                "vmlal.s16 q13, d10, d4[2]\n" /* out11, w11 * r22 */
                "vmlal.s16 q14, d10, d4[3]\n" /* out12, w11 * r23 */
                "vmlal.s16 q15, d10, d5[0]\n" /* out13, w11 * r24 */
                "vmlal.s16 q8, d14, d4[2]\n"  /* out00, w22 * r22 */
                "vmlal.s16 q9, d14, d4[3]\n"  /* out01, w22 * r23 */
                "vmlal.s16 q10, d14, d5[0]\n" /* out02, w22 * r24 */
                "vmlal.s16 q11, d14, d5[1]\n" /* out03, w22 * r25 */
                "add %[r3], %[r3], #4\n"
                "vmlal.s16 q12, d11, d4[2]\n" /* out10, w12 * r22 */
                "vmlal.s16 q13, d11, d4[3]\n" /* out11, w12 * r23 */
                "vmlal.s16 q14, d11, d5[0]\n" /* out12, w12 * r24 */
                "vmlal.s16 q15, d11, d5[1]\n" /* out13, w12 * r25 */
                "vmovl.s8 q2, d1\n"           /* r3, cvt to int16 */
                /* inr0 -> outr0 */
                "vld1.32 d0, [%[r1]]\n"       /* load r1 */
                "vmlal.s16 q8, d6, d2[0]\n"   /* out00, w00 * r00 */
                "vmlal.s16 q9, d6, d2[1]\n"   /* out01, w00 * r01 */
                "vmlal.s16 q10, d6, d2[2]\n"  /* out02, w00 * r02 */
                "vmlal.s16 q11, d6, d2[3]\n"  /* out03, w00 * r03 */
                "vmlal.s16 q8, d7, d2[1]\n"   /* out00, w01 * r01 */
                "vmlal.s16 q9, d7, d2[2]\n"   /* out01, w01 * r02 */
                "vmlal.s16 q10, d7, d2[3]\n"  /* out02, w01 * r03 */
                "vmlal.s16 q11, d7, d3[0]\n"  /* out03, w01 * r04 */
                "vmlal.s16 q8, d8, d2[2]\n"   /* out00, w02 * r02 */
                "vmlal.s16 q9, d8, d2[3]\n"   /* out01, w02 * r03 */
                "vmlal.s16 q10, d8, d3[0]\n"  /* out02, w02 * r04 */
                "vmlal.s16 q11, d8, d3[1]\n"  /* out03, w02 * r05 */
                "vmovl.s8 q1, d0\n"           /* r1, cvt to int16 */
                /* inr3 -> outr1 */
                "vld1.32 {d1}, [%[r2]]\n"     /* load r2 */
                "vst1.32 {d16-d19}, [%[ptr_out0]]!\n"
                "vmlal.s16 q12, d12, d4[0]\n" /* out10, w20 * r30 */
                "vmlal.s16 q13, d12, d4[1]\n" /* out11, w20 * r31 */
                "vmlal.s16 q14, d12, d4[2]\n" /* out12, w20 * r32 */
                "vmlal.s16 q15, d12, d4[3]\n" /* out13, w20 * r33 */
                "vst1.32 {d20-d23}, [%[ptr_out0]]!\n"
                "vld1.32 {d16-d19}, [%[ptr_out0]]!\n"
                "vmlal.s16 q12, d13, d4[1]\n" /* out10, w21 * r31 */
                "vmlal.s16 q13, d13, d4[2]\n" /* out11, w21 * r32 */
                "vmlal.s16 q14, d13, d4[3]\n" /* out12, w21 * r33 */
                "vmlal.s16 q15, d13, d5[0]\n" /* out13, w21 * r34 */
                "vld1.32 {d20-d23}, [%[ptr_out0]]\n"
                "vmlal.s16 q12, d14, d4[2]\n" /* out10, w22 * r32 */
                "vmlal.s16 q13, d14, d4[3]\n" /* out11, w22 * r33 */
                "vmlal.s16 q14, d14, d5[0]\n" /* out12, w22 * r34 */
                "vmlal.s16 q15, d14, d5[1]\n" /* out13, w22 * r35 */
                "vmovl.s8 q2, d1\n"           /* r2, cvt to int16 */
                "subs %[cnt], #1\n"
                "vst1.32 {d24-d27}, [%[ptr_out1]]!\n"
                "vst1.32 {d28-d31}, [%[ptr_out1]]!\n"
                "bne 1b\n"
                : [cnt] "+r"(cnt),
                  [r0] "+r"(r0),
                  [r1] "+r"(r1),
                  [r2] "+r"(r2),
                  [r3] "+r"(r3),
                  [ptr_out0] "+r"(ptr_out0),
                  [ptr_out1] "+r"(ptr_out1),
                  [wc0] "+r"(ptr_wc0)
                :
                : "cc", "memory", "q0", "q1", "q2", "q3",
                  "q4", "q5", "q6", "q7", "q8", "q9", "q10",
                  "q11", "q12", "q13", "q14", "q15"
            );
#endif  // __aarch64__
            // clang-format on
            wc0 += 9 * hout_c_block;
            inr0 += win_round;
            inr1 += win_round;
            inr2 += win_round;
            inr3 += win_round;
          }
          block_inr0 = block_inr2;
          block_inr1 = block_inr3;
          block_inr2 = block_inr1 + in_len;
          block_inr3 = block_inr2 + in_len;
        }
        write_int32_nchwc4_to_nchw(pre_out,
                                   dout_batch,
                                   c,
                                   c + 4,
                                   h,
                                   h + hout_r_block,
                                   0,
                                   wout_round,
                                   chout,
                                   hout,
                                   wout,
                                   flag_relu,
                                   bias_local,
                                   flag_bias,
                                   ptr_write,
                                   scale + c);
      }
    }
  }
}

template void conv_3x3s1_direct_int8(const int8_t* din,
                                     float* dout,
                                     int num,
                                     int chout,
                                     int hout,
                                     int wout,
                                     int chin,
                                     int hin,
                                     int win,
                                     const int8_t* weights,
                                     const float* bias,
                                     const operators::ConvParam& param,
                                     Context<TARGET(kARM)>* ctx,
                                     const float* scale);
template void conv_3x3s1_direct_int8(const int8_t* din,
                                     int8_t* dout,
                                     int num,
                                     int chout,
                                     int hout,
                                     int wout,
                                     int chin,
                                     int hin,
                                     int win,
                                     const int8_t* weights,
                                     const float* bias,
                                     const operators::ConvParam& param,
                                     Context<TARGET(kARM)>* ctx,
                                     const float* scale);
}  // namespace math
}  // namespace arm
...
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <arm_neon.h>
#include "lite/backends/arm/math/conv_block_utils.h"
#include "lite/backends/arm/math/conv_impl.h"
#include "lite/core/context.h"
#include "lite/operators/op_params.h"
#ifdef ARM_WITH_OMP
#include <omp.h>
#endif
namespace paddle {
namespace lite {
namespace arm {
namespace math {
void conv_3x3s2_depthwise_fp32(const float* i_data,
float* o_data,
int bs,
int oc,
int oh,
int ow,
int ic,
int ih,
int win,
const float* weights,
const float* bias,
const operators::ConvParam& param,
ARMContext* ctx) {
int threads = ctx->threads();
const int pad_h = param.paddings[0];
const int pad_w = param.paddings[1];
const int out_c_block = 4;
const int out_h_kernel = 1;
const int out_w_kernel = 4;
const int win_ext = ow * 2 + 1;
const int ow_round = ROUNDUP(ow, 4);
const int win_round = ROUNDUP(win_ext, 4);
const int hin_round = oh * 2 + 1;
const int prein_size = win_round * hin_round * out_c_block;
auto workspace_size =
threads * prein_size + win_round /*tmp zero*/ + ow_round /*tmp writer*/;
ctx->ExtendWorkspace(sizeof(float) * workspace_size);
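  /// workspace layout from here on: [win_round zero floats | ow_round
  /// write-back scratch | one prein_size packed-input block per thread],
  /// matching the ptr_zero / ptr_write / pre_din pointers computed below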
bool flag_relu = param.fuse_relu;
bool flag_bias = param.bias != nullptr;
/// get workspace
auto ptr_zero = ctx->workspace_data<float>();
memset(ptr_zero, 0, sizeof(float) * win_round);
float* ptr_write = ptr_zero + win_round;
int size_in_channel = win * ih;
int size_out_channel = ow * oh;
int ws = -pad_w;
int we = ws + win_round;
int hs = -pad_h;
int he = hs + hin_round;
int w_loop = ow_round / 4;
auto remain = w_loop * 4 - ow;
bool flag_remain = remain > 0;
remain = 4 - remain;
remain = remain > 0 ? remain : 0;
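  // After this, `remain` is the number of valid output columns in the last
  // 4-wide block: e.g. ow = 10 gives ow_round = 12 and w_loop = 3, so the
  // final block computes 4 columns but only remain = 2 of them are written.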
int row_len = win_round * out_c_block;
for (int n = 0; n < bs; ++n) {
const float* din_batch = i_data + n * ic * size_in_channel;
float* dout_batch = o_data + n * oc * size_out_channel;
#pragma omp parallel for num_threads(threads)
for (int c = 0; c < oc; c += out_c_block) {
#ifdef ARM_WITH_OMP
float* pre_din = ptr_write + ow_round + omp_get_thread_num() * prein_size;
#else
float* pre_din = ptr_write + ow_round;
#endif
/// const array size
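      /// pack rows [hs, he) x cols [ws, we) of input channels [c, c + 4) into
      /// the c4-interleaved pre_din buffer; out-of-range (padding) positions
      /// are filled from the zeroed ptr_zero line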
prepack_input_nxwc4_dw(
din_batch, pre_din, c, hs, he, ws, we, ic, win, ih, ptr_zero);
const float* weight_c = weights + c * 9; // kernel_w * kernel_h
float* dout_c00 = dout_batch + c * size_out_channel;
float bias_local[4] = {0, 0, 0, 0};
if (flag_bias) {
bias_local[0] = bias[c];
bias_local[1] = bias[c + 1];
bias_local[2] = bias[c + 2];
bias_local[3] = bias[c + 3];
}
#ifdef __aarch64__
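      // weight_c points at 9 consecutive float32x4 vectors: w0..w8 each hold
      // one of the nine 3x3 taps for the four output channels of this block.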
float32x4_t w0 = vld1q_f32(weight_c); // w0, v23
float32x4_t w1 = vld1q_f32(weight_c + 4); // w1, v24
float32x4_t w2 = vld1q_f32(weight_c + 8); // w2, v25
float32x4_t w3 = vld1q_f32(weight_c + 12); // w3, v26
float32x4_t w4 = vld1q_f32(weight_c + 16); // w4, v27
float32x4_t w5 = vld1q_f32(weight_c + 20); // w5, v28
float32x4_t w6 = vld1q_f32(weight_c + 24); // w6, v29
float32x4_t w7 = vld1q_f32(weight_c + 28); // w7, v30
float32x4_t w8 = vld1q_f32(weight_c + 32); // w8, v31
#endif
for (int h = 0; h < oh; h += out_h_kernel) {
float* outc0 = dout_c00 + h * ow;
float* outc1 = outc0 + size_out_channel;
float* outc2 = outc1 + size_out_channel;
float* outc3 = outc2 + size_out_channel;
const float* inr0 = pre_din + h * 2 * row_len;
const float* inr1 = inr0 + row_len;
const float* inr2 = inr1 + row_len;
if (c + out_c_block > oc) {
switch (c + out_c_block - oc) {
case 3:
outc1 = ptr_write;
case 2:
outc2 = ptr_write;
case 1:
outc3 = ptr_write;
default:
break;
}
}
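        // Channel tail: when fewer than 4 output channels remain, the switch
        // above (with deliberate case fallthrough) redirects the surplus outc
        // pointers to the shared ptr_write scratch row, so the kernel can
        // always store four channel rows unconditionally.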
auto c0 = outc0;
auto c1 = outc1;
auto c2 = outc2;
auto c3 = outc3;
float pre_out[16];
for (int w = 0; w < w_loop; ++w) {
bool flag_mask = (w == w_loop - 1) && flag_remain;
if (flag_mask) {
c0 = outc0;
c1 = outc1;
c2 = outc2;
c3 = outc3;
outc0 = pre_out;
outc1 = pre_out + 4;
outc2 = pre_out + 8;
outc3 = pre_out + 12;
}
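        // Width tail: for the last (partial) 4-column block the results are
        // staged into the local pre_out[16] buffer instead, and only the
        // valid `remain` columns are copied back to c0..c3 after the asm.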
// clang-format off
#ifdef __aarch64__
asm volatile(
"ldr q8, [%[bias]]\n" /* load bias */
"ldp q0, q1, [%[inr0]], #32\n" /* load input r0*/
"and v19.16b, v8.16b, v8.16b\n"
"ldp q2, q3, [%[inr0]], #32\n" /* load input r0*/
"and v20.16b, v8.16b, v8.16b\n"
"ldp q4, q5, [%[inr0]], #32\n" /* load input r0*/
"and v21.16b, v8.16b, v8.16b\n"
"ldp q6, q7, [%[inr0]], #32\n" /* load input r0*/
"and v22.16b, v8.16b, v8.16b\n"
"ldr q8, [%[inr0]]\n" /* load input r0*/
/* r0 mul w0-w2, get out */
"fmla v19.4s , %[w0].4s, v0.4s\n" /* outr0 = w0 * r0, 0*/
"fmla v20.4s , %[w0].4s, v2.4s\n" /* outr1 = w0 * r0, 2*/
"fmla v21.4s , %[w0].4s, v4.4s\n" /* outr2 = w0 * r0, 4*/
"fmla v22.4s , %[w0].4s, v6.4s\n" /* outr3 = w0 * r0, 6*/
"fmla v19.4s , %[w1].4s, v1.4s\n" /* outr0 = w1 * r0, 1*/
"ldp q0, q1, [%[inr1]], #32\n" /* load input r1*/
"fmla v20.4s , %[w1].4s, v3.4s\n" /* outr1 = w1 * r0, 3*/
"fmla v21.4s , %[w1].4s, v5.4s\n" /* outr2 = w1 * r0, 5*/
"fmla v22.4s , %[w1].4s, v7.4s\n" /* outr3 = w1 * r0, 7*/
"fmla v19.4s , %[w2].4s, v2.4s\n" /* outr0 = w0 * r0, 2*/
"ldp q2, q3, [%[inr1]], #32\n" /* load input r1*/
"fmla v20.4s , %[w2].4s, v4.4s\n" /* outr1 = w0 * r0, 4*/
"ldp q4, q5, [%[inr1]], #32\n" /* load input r1*/
"fmla v21.4s , %[w2].4s, v6.4s\n" /* outr2 = w0 * r0, 6*/
"ldp q6, q7, [%[inr1]], #32\n" /* load input r1*/
"fmla v22.4s , %[w2].4s, v8.4s\n" /* outr3 = w0 * r0, 8*/
"ldr q8, [%[inr1]]\n" /* load input r1*/
/* r1, mul w3-w5, get out */
"fmla v19.4s , %[w3].4s, v0.4s\n" /* outr0 = w3 * r1, 0*/
"fmla v20.4s , %[w3].4s, v2.4s\n" /* outr1 = w3 * r1, 2*/
"fmla v21.4s , %[w3].4s, v4.4s\n" /* outr2 = w3 * r1, 4*/
"fmla v22.4s , %[w3].4s, v6.4s\n" /* outr3 = w3 * r1, 6*/
"fmla v19.4s , %[w4].4s, v1.4s\n" /* outr0 = w4 * r1, 1*/
"ldp q0, q1, [%[inr2]], #32\n" /* load input r2*/
"fmla v20.4s , %[w4].4s, v3.4s\n" /* outr1 = w4 * r1, 3*/
"fmla v21.4s , %[w4].4s, v5.4s\n" /* outr2 = w4 * r1, 5*/
"fmla v22.4s , %[w4].4s, v7.4s\n" /* outr3 = w4 * r1, 7*/
"fmla v19.4s , %[w5].4s, v2.4s\n" /* outr0 = w5 * r1, 2*/
"ldp q2, q3, [%[inr2]], #32\n" /* load input r2*/
"fmla v20.4s , %[w5].4s, v4.4s\n" /* outr1 = w5 * r1, 4*/
"ldp q4, q5, [%[inr2]], #32\n" /* load input r2*/
"fmla v21.4s , %[w5].4s, v6.4s\n" /* outr2 = w5 * r1, 6*/
"ldp q6, q7, [%[inr2]], #32\n" /* load input r2*/
"fmla v22.4s , %[w5].4s, v8.4s\n" /* outr3 = w5 * r1, 8*/
"ldr q8, [%[inr2]]\n" /* load input r2*/
/* r2, mul w6-w8, get out r0, r1 */
"fmla v19.4s , %[w6].4s, v0.4s\n" /* outr0 = w6 * r2, 0*/
"fmla v20.4s , %[w6].4s, v2.4s\n" /* outr1 = w6 * r2, 2*/
"fmla v21.4s , %[w6].4s, v4.4s\n" /* outr2 = w6 * r2, 4*/
"fmla v22.4s , %[w6].4s, v6.4s\n" /* outr3 = w6 * r2, 6*/
"fmla v19.4s , %[w7].4s, v1.4s\n" /* outr0 = w7 * r2, 1*/
"fmla v20.4s , %[w7].4s, v3.4s\n" /* outr1 = w7 * r2, 3*/
"fmla v21.4s , %[w7].4s, v5.4s\n" /* outr2 = w7 * r2, 5*/
"fmla v22.4s , %[w7].4s, v7.4s\n" /* outr3 = w7 * r2, 7*/
"fmla v19.4s , %[w8].4s, v2.4s\n" /* outr0 = w8 * r2, 2*/
"fmla v20.4s , %[w8].4s, v4.4s\n" /* outr1 = w8 * r2, 4*/
"fmla v21.4s , %[w8].4s, v6.4s\n" /* outr2 = w8 * r2, 6*/
"fmla v22.4s , %[w8].4s, v8.4s\n" /* outr3 = w8 * r2, 8*/
/* transpose */
"trn1 v0.4s, v19.4s, v20.4s\n" /* r0: a0a1c0c1*/
"trn2 v1.4s, v19.4s, v20.4s\n" /* r0: b0b1d0d1*/
"trn1 v2.4s, v21.4s, v22.4s\n" /* r0: a2a3c2c3*/
"trn2 v3.4s, v21.4s, v22.4s\n" /* r0: b2b3d2d3*/
"trn1 v19.2d, v0.2d, v2.2d\n" /* r0: a0a1a2a3*/
"trn2 v21.2d, v0.2d, v2.2d\n" /* r0: c0c1c2c3*/
"trn1 v20.2d, v1.2d, v3.2d\n" /* r0: b0b1b2b3*/
"trn2 v22.2d, v1.2d, v3.2d\n" /* r0: d0d1d2d3*/
/* relu */
"cbz %w[flag_relu], 0f\n" /* skip relu*/
"movi v0.4s, #0\n" /* for relu */
"fmax v19.4s, v19.4s, v0.4s\n"
"fmax v20.4s, v20.4s, v0.4s\n"
"fmax v21.4s, v21.4s, v0.4s\n"
"fmax v22.4s, v22.4s, v0.4s\n"
/* save result */
"0:\n"
"str q19, [%[outc0]], #16\n"
"str q20, [%[outc1]], #16\n"
"str q21, [%[outc2]], #16\n"
"str q22, [%[outc3]], #16\n"
:[inr0] "+r"(inr0), [inr1] "+r"(inr1),
[inr2] "+r"(inr2),
[outc0]"+r"(outc0), [outc1]"+r"(outc1),
[outc2]"+r"(outc2), [outc3]"+r"(outc3)
:[w0] "w"(w0), [w1] "w"(w1), [w2] "w"(w2),
[w3] "w"(w3), [w4] "w"(w4), [w5] "w"(w5),
[w6] "w"(w6), [w7] "w"(w7), [w8] "w"(w8),
[bias] "r" (bias_local), [flag_relu]"r"(flag_relu)
: "cc", "memory",
"v0","v1","v2","v3","v4","v5","v6","v7",
"v8", "v19","v20","v21","v22"
);
#else
asm volatile(
/* fill with bias */
"vld1.32 {d16-d17}, [%[bias]]\n" /* load bias */
/* load weights */
"vld1.32 {d18-d21}, [%[wc0]]!\n" /* load w0-2, to q9-11 */
"vld1.32 {d0-d3}, [%[r0]]!\n" /* load input r0, 0,1*/
"vand.i32 q12, q8, q8\n"
"vld1.32 {d4-d7}, [%[r0]]!\n" /* load input r0, 2,3*/
"vand.i32 q13, q8, q8\n"
"vld1.32 {d8-d11}, [%[r0]]!\n" /* load input r0, 4,5*/
"vand.i32 q14, q8, q8\n"
"vld1.32 {d12-d15}, [%[r0]]!\n" /* load input r0, 6,7*/
"vand.i32 q15, q8, q8\n"
"vld1.32 {d16-d17}, [%[r0]]\n" /* load input r0, 8*/
/* mul r0 with w0, w1, w2 */
"vmla.f32 q12, q9, q0 @ w0 * inr0\n"
"vmla.f32 q13, q9, q2 @ w0 * inr2\n"
"vld1.32 {d22-d23}, [%[wc0]]!\n" /* load w2, to q11 */
"vmla.f32 q14, q9, q4 @ w0 * inr4\n"
"vmla.f32 q15, q9, q6 @ w0 * inr6\n"
"vmla.f32 q12, q10, q1 @ w1 * inr1\n"
"vld1.32 {d0-d3}, [%[r1]]! @ load r1, 0, 1\n"
"vmla.f32 q13, q10, q3 @ w1 * inr3\n"
"vmla.f32 q14, q10, q5 @ w1 * inr5\n"
"vmla.f32 q15, q10, q7 @ w1 * inr7\n"
"vld1.32 {d18-d21}, [%[wc0]]!\n" /* load w3-4, to q9-10 */
"vmla.f32 q12, q11, q2 @ w2 * inr2\n"
"vld1.32 {d4-d7}, [%[r1]]! @ load r1, 2, 3\n"
"vmla.f32 q13, q11, q4 @ w2 * inr4\n"
"vld1.32 {d8-d11}, [%[r1]]! @ load r1, 4, 5\n"
"vmla.f32 q14, q11, q6 @ w2 * inr6\n"
"vld1.32 {d12-d15}, [%[r1]]! @ load r1, 6, 7\n"
"vmla.f32 q15, q11, q8 @ w2 * inr8\n"
/* mul r1 with w3, w4, w5 */
"vmla.f32 q12, q9, q0 @ w3 * inr0\n"
"vmla.f32 q13, q9, q2 @ w3 * inr2\n"
"vld1.32 {d22-d23}, [%[wc0]]!\n" /* load w5, to q11 */
"vmla.f32 q14, q9, q4 @ w3 * inr4\n"
"vmla.f32 q15, q9, q6 @ w3 * inr6\n"
"vld1.32 {d16-d17}, [%[r1]]\n" /* load input r1, 8*/
"vmla.f32 q12, q10, q1 @ w4 * inr1\n"
"vld1.32 {d0-d3}, [%[r2]]! @ load r2, 0, 1\n"
"vmla.f32 q13, q10, q3 @ w4 * inr3\n"
"vmla.f32 q14, q10, q5 @ w4 * inr5\n"
"vmla.f32 q15, q10, q7 @ w4 * inr7\n"
"vld1.32 {d18-d21}, [%[wc0]]!\n" /* load w6-7, to q9-10 */
"vmla.f32 q12, q11, q2 @ w5 * inr2\n"
"vld1.32 {d4-d7}, [%[r2]]! @ load r2, 2, 3\n"
"vmla.f32 q13, q11, q4 @ w5 * inr4\n"
"vld1.32 {d8-d11}, [%[r2]]! @ load r2, 4, 5\n"
"vmla.f32 q14, q11, q6 @ w5 * inr6\n"
"vld1.32 {d12-d15}, [%[r2]]! @ load r2, 6, 7\n"
"vmla.f32 q15, q11, q8 @ w5 * inr8\n"
/* mul r2 with w6, w7, w8 */
"vmla.f32 q12, q9, q0 @ w6 * inr0\n"
"vmla.f32 q13, q9, q2 @ w6 * inr2\n"
"vld1.32 {d22-d23}, [%[wc0]]!\n" /* load w8, to q11 */
"vmla.f32 q14, q9, q4 @ w6 * inr4\n"
"vmla.f32 q15, q9, q6 @ w6 * inr6\n"
"vld1.32 {d16-d17}, [%[r2]]\n" /* load input r2, 8*/
"vmla.f32 q12, q10, q1 @ w7 * inr1\n"
"vmla.f32 q13, q10, q3 @ w7 * inr3\n"
"vmla.f32 q14, q10, q5 @ w7 * inr5\n"
"vmla.f32 q15, q10, q7 @ w7 * inr7\n"
"sub %[wc0], %[wc0], #144 @ wc0 - 144 to start address\n"
"vmla.f32 q12, q11, q2 @ w8 * inr2\n"
"vmla.f32 q13, q11, q4 @ w8 * inr4\n"
"vmla.f32 q14, q11, q6 @ w8 * inr6\n"
"vmla.f32 q15, q11, q8 @ w8 * inr8\n"
/* transpose */
"vtrn.32 q12, q13\n" /* a0a1c0c1, b0b1d0d1*/
"vtrn.32 q14, q15\n" /* a2a3c2c3, b2b3d2d3*/
"vswp d25, d28\n" /* a0a1a2a3, c0c1c2c3*/
"vswp d27, d30\n" /* b0b1b2b3, d0d1d2d3*/
"cmp %[flag_relu], #0\n"
"beq 0f\n" /* skip relu*/
"vmov.u32 q0, #0\n"
"vmax.f32 q12, q12, q0\n"
"vmax.f32 q13, q13, q0\n"
"vmax.f32 q14, q14, q0\n"
"vmax.f32 q15, q15, q0\n"
"0:\n"
"vst1.32 {d24-d25}, [%[outc0]]!\n" /* save outc0*/
"vst1.32 {d26-d27}, [%[outc1]]!\n" /* save outc1*/
"vst1.32 {d28-d29}, [%[outc2]]!\n" /* save outc2*/
"vst1.32 {d30-d31}, [%[outc3]]!\n" /* save outc3*/
:[r0] "+r"(inr0), [r1] "+r"(inr1),
[r2] "+r"(inr2), [wc0] "+r" (weight_c),
[outc0]"+r"(outc0), [outc1]"+r"(outc1),
[outc2]"+r"(outc2), [outc3]"+r"(outc3)
:[bias] "r" (bias_local),
[flag_relu]"r"(flag_relu)
:"cc", "memory",
"q0","q1","q2","q3","q4","q5","q6","q7",
"q8", "q9","q10","q11","q12","q13","q14","q15"
);
#endif  // __aarch64__
          // clang-format on
if (flag_mask) {
for (int i = 0; i < remain; ++i) {
c0[i] = pre_out[i];
c1[i] = pre_out[i + 4];
c2[i] = pre_out[i + 8];
c3[i] = pre_out[i + 12];
}
}
}
}
}
}
}
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <arm_neon.h>
#include "lite/backends/arm/math/conv_block_utils.h"
#include "lite/backends/arm/math/conv_depthwise.h"
#include "lite/backends/arm/math/conv_impl.h"
#include "lite/core/context.h"
#include "lite/operators/op_params.h"
#ifdef ARM_WITH_OMP
#include <omp.h>
#endif
namespace paddle {
namespace lite {
namespace arm {
namespace math {
#define ROUNDUP(a, b) ((((a) + (b)-1) / (b)) * (b))
template <typename Dtype>
void conv_depthwise_3x3s2_int8(Dtype* dout,
const int8_t* din,
const int8_t* weights,
const float* scale,
const float* bias,
bool flag_bias,
bool flag_relu,
int num,
int chin,
int hin,
int win,
int hout,
int wout,
int padw,
int padh,
ARMContext* ctx) {
const int threads = ctx->threads();
int llc_size = ctx->llc_size() / 4;
const int hout_c_block = 8;
const int hout_r_kernel = 1;
const int wout_block = 4;
const int wout_round = ROUNDUP(wout, wout_block);
const int win_round = wout_round * 2 /*stride*/ + 1;
//! get h block
//! llc_size = threads * win_round * hin_r_block * sizeof(int8_t) + wout_round
//! * hout_c_block * hout_r_block * threads * sizeof(int32_t)
//! win_round = wout_round + 2
//! hin_r_block = hout_r_block + 2
int hout_r_block =
(llc_size - 2 * win_round * threads) /
(2 * win_round * threads + hout_c_block * wout_round * threads * 4);
hout_r_block = hout_r_block > hout ? hout : hout_r_block;
hout_r_block =
((hout_r_block + hout_r_kernel - 1) / hout_r_kernel) * hout_r_kernel;
hout_r_block = hout_r_block < hout_r_kernel ? hout_r_kernel : hout_r_block;
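  //! i.e. solve the budget equation above for hout_r_block (subtracting
  //! 2 * win_round * threads more than covers the single odd boundary input
  //! row), then clamp it to [hout_r_kernel, hout] and round it to a multiple
  //! of hout_r_kernel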
const int hin_r_block = hout_r_block * 2 /*stride*/ + 1;
auto tmp_work_space = ctx->workspace_data<int8_t>();
int8_t ptr_zero[win_round]; // NOLINT
memset(ptr_zero, 0, sizeof(int8_t) * win_round);
Dtype ptr_write[wout_round]; // NOLINT
int in_len = win_round * hout_c_block;
int pre_in_size = hin_r_block * in_len;
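  // round the packed-input size up to a multiple of 4 so that the int32
  // pre_out buffer carved out right behind it stays 4-byte aligned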
pre_in_size = ROUNDUP(pre_in_size, 4);
int pre_out_size = hout_c_block * hout_r_block * wout_round;
int8_t* tmp_din = tmp_work_space;
int size_in_channel = win * hin;
int size_out_channel = wout * hout;
int w_stride = 9; // kernel_w * kernel_h;
int ws = -padw;
int we = ws + win_round;
int w_loop = wout_round / 4;
int chout = chin;
int out_row_stride = hout_c_block * wout_round;
for (int n = 0; n < num; ++n) {
const int8_t* din_batch = din + n * chin * size_in_channel;
int8_t* dout_batch = reinterpret_cast<int8_t*>(dout) +
n * chout * size_out_channel * sizeof(Dtype);
for (int h = 0; h < hout; h += hout_r_block) {
int h_kernel = hout_r_block;
if (h + hout_r_block > hout) {
h_kernel = hout - h;
}
int hs = h * 2 /*stride*/ - padh;
int he = hs + h_kernel * 2 /*stride*/ + 1;
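      // a block of h_kernel stride-2 output rows reads the 2 * h_kernel + 1
      // input rows [hs, he), e.g. h_kernel = 2 needs 5 input rows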
#pragma omp parallel for num_threads(threads)
for (int c = 0; c < chout; c += hout_c_block) {
#ifdef ARM_WITH_OMP
int8_t* pre_din =
tmp_din + omp_get_thread_num() * (pre_in_size + pre_out_size * 4);
int32_t* pre_out = reinterpret_cast<int*>(pre_din + pre_in_size);
#else
int32_t* pre_out = reinterpret_cast<int32_t*>(tmp_din + pre_in_size);
auto pre_din = tmp_din;
#endif
prepack_input_nxw_c8_int8(din_batch,
pre_din,
c,
c + hout_c_block,
hs,
he,
ws,
we,
chin,
win,
hin);
const int8_t* block_inr0 = pre_din;
const int8_t* block_inr1 = block_inr0 + in_len;
const int8_t* block_inr2 = block_inr1 + in_len;
const int8_t* weight_c = weights + c * w_stride;
float bias_local[8] = {0, 0, 0, 0, 0, 0, 0, 0};
if (flag_bias) {
bias_local[0] = bias[c];
bias_local[1] = bias[c + 1];
bias_local[2] = bias[c + 2];
bias_local[3] = bias[c + 3];
bias_local[4] = bias[c + 4];
bias_local[5] = bias[c + 5];
bias_local[6] = bias[c + 6];
bias_local[7] = bias[c + 7];
}
#ifdef __aarch64__
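      // vw0..vw8 hold the nine 3x3 taps, one int8 lane per channel of this
      // 8-channel block; the armv7 path below instead reloads them through
      // wptr and rewinds it by 72 = 9 x 8 bytes every iteration.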
int8x8_t vw0 = vld1_s8(weight_c);
int8x8_t vw1 = vld1_s8(weight_c + 8);
int8x8_t vw2 = vld1_s8(weight_c + 16);
int8x8_t vw3 = vld1_s8(weight_c + 24);
int8x8_t vw4 = vld1_s8(weight_c + 32);
int8x8_t vw5 = vld1_s8(weight_c + 40);
int8x8_t vw6 = vld1_s8(weight_c + 48);
int8x8_t vw7 = vld1_s8(weight_c + 56);
int8x8_t vw8 = vld1_s8(weight_c + 64);
#endif
for (int hk = 0; hk < h_kernel; hk += hout_r_kernel) {
int cnt = w_loop;
const int8_t* inr0 = block_inr0;
const int8_t* inr1 = block_inr1;
const int8_t* inr2 = block_inr2;
int32_t* ptr_out0 = pre_out + hk * out_row_stride;
#ifdef __aarch64__
asm volatile(
"ld1 {v0.8b, v1.8b, v2.8b, v3.8b}, [%[r0]], #32\n"
"ld1 {v4.8b, v5.8b, v6.8b, v7.8b}, [%[r0]], #32\n"
"1:\n"
/* inr0 -> outr0 */
"smull v20.8h, v0.8b, %[w0].8b\n" /* int16, out0 */
"smull v21.8h, v2.8b, %[w0].8b\n" /* int16, out1 */
"smull v22.8h, v4.8b, %[w0].8b\n" /* int16, out2 */
"smull v23.8h, v6.8b, %[w0].8b\n" /* int16, out3 */
"smlal v20.8h, v1.8b, %[w1].8b\n" /* int16, out0 */
"smlal v21.8h, v3.8b, %[w1].8b\n" /* int16, out1 */
"smlal v22.8h, v5.8b, %[w1].8b\n" /* int16, out2 */
"smlal v23.8h, v7.8b, %[w1].8b\n" /* int16, out3 */
"ldr d8, [%[r0]]\n" /* load r0, 8 */
"ldp d0, d1, [%[r1]], #16\n" /* load r1, 0,1 */
"sxtl v24.4s, v20.4h\n"
"sxtl2 v25.4s, v20.8h\n"
"smull v20.8h, v2.8b, %[w2].8b\n" /* int16, out0 */
"ldp d2, d3, [%[r1]], #16\n" /* load r1, 2,3 */
"sxtl v26.4s, v21.4h\n"
"sxtl2 v27.4s, v21.8h\n"
"smull v21.8h, v4.8b, %[w2].8b\n" /* int16, out1 */
"ldp d4, d5, [%[r1]], #16\n" /* load r1, 4,5 */
"sxtl v28.4s, v22.4h\n"
"sxtl2 v29.4s, v22.8h\n"
"smull v22.8h, v6.8b, %[w2].8b\n" /* int16, out2 */
"ldp d6, d7, [%[r1]], #16\n" /* load r1, 6,7 */
"sxtl v30.4s, v23.4h\n"
"sxtl2 v31.4s, v23.8h\n"
"smull v23.8h, v8.8b, %[w2].8b\n" /* int16, out3 */
"smlal v20.8h, v0.8b, %[w3].8b\n" /* int16, out0 */
"smlal v21.8h, v2.8b, %[w3].8b\n" /* int16, out1 */
"smlal v22.8h, v4.8b, %[w3].8b\n" /* int16, out2 */
"smlal v23.8h, v6.8b, %[w3].8b\n" /* int16, out3 */
"saddw v24.4s, v24.4s, v20.4h\n"
"saddw2 v25.4s, v25.4s, v20.8h\n"
"saddw v26.4s, v26.4s, v21.4h\n"
"saddw2 v27.4s, v27.4s, v21.8h\n"
"ldr d8, [%[r1]]\n" /* load r1, 8 */
"saddw v28.4s, v28.4s, v22.4h\n"
"saddw2 v29.4s, v29.4s, v22.8h\n"
"saddw v30.4s, v30.4s, v23.4h\n"
"saddw2 v31.4s, v31.4s, v23.8h\n"
"smull v20.8h, v1.8b, %[w4].8b\n" /* int16, out0 */
"smull v21.8h, v3.8b, %[w4].8b\n" /* int16, out1 */
"smull v22.8h, v5.8b, %[w4].8b\n" /* int16, out1 */
"smull v23.8h, v7.8b, %[w4].8b\n" /* int16, out1 */
"ldp d0, d1, [%[r2]], #16\n" /* load r2, 0,1 */
"smlal v20.8h, v2.8b, %[w5].8b\n" /* int16, out0 */
"smlal v21.8h, v4.8b, %[w5].8b\n" /* int16, out1 */
"ldp d2, d3, [%[r2]], #16\n" /* load r2, 2,3 */
"smlal v22.8h, v6.8b, %[w5].8b\n" /* int16, out2 */
"smlal v23.8h, v8.8b, %[w5].8b\n" /* int16, out3 */
"ldp d4, d5, [%[r2]], #16\n" /* load r2, 4,5 */
"saddw v24.4s, v24.4s, v20.4h\n"
"saddw2 v25.4s, v25.4s, v20.8h\n"
"saddw v26.4s, v26.4s, v21.4h\n"
"saddw2 v27.4s, v27.4s, v21.8h\n"
"ldp d6, d7, [%[r2]], #16\n" /* load r2, 6,7 */
"saddw v28.4s, v28.4s, v22.4h\n"
"saddw2 v29.4s, v29.4s, v22.8h\n"
"saddw v30.4s, v30.4s, v23.4h\n"
"saddw2 v31.4s, v31.4s, v23.8h\n"
"smull v20.8h, v0.8b, %[w6].8b\n" /* int16, out0 */
"smull v21.8h, v2.8b, %[w6].8b\n" /* int16, out1 */
"smull v22.8h, v4.8b, %[w6].8b\n" /* int16, out1 */
"smull v23.8h, v6.8b, %[w6].8b\n" /* int16, out1 */
"smlal v20.8h, v1.8b, %[w7].8b\n" /* int16, out0 */
"smlal v21.8h, v3.8b, %[w7].8b\n" /* int16, out1 */
"smlal v22.8h, v5.8b, %[w7].8b\n" /* int16, out1 */
"smlal v23.8h, v7.8b, %[w7].8b\n" /* int16, out1 */
"ldp d0, d1, [%[r0]], #16\n" /* load r0, 0,1 */
"saddw v24.4s, v24.4s, v20.4h\n"
"saddw2 v25.4s, v25.4s, v20.8h\n"
"saddw v26.4s, v26.4s, v21.4h\n"
"saddw2 v27.4s, v27.4s, v21.8h\n"
"ldr d8, [%[r2]]\n" /* load r2 */
"saddw v28.4s, v28.4s, v22.4h\n"
"saddw2 v29.4s, v29.4s, v22.8h\n"
"saddw v30.4s, v30.4s, v23.4h\n"
"saddw2 v31.4s, v31.4s, v23.8h\n"
"smull v20.8h, v2.8b, %[w8].8b\n" /* int16, out0 */
"smull v21.8h, v4.8b, %[w8].8b\n" /* int16, out1 */
"ldp d2, d3, [%[r0]], #16\n" /* load r0, 2,3 */
"smull v22.8h, v6.8b, %[w8].8b\n" /* int16, out1 */
"smull v23.8h, v8.8b, %[w8].8b\n" /* int16, out1 */
"ldp d4, d5, [%[r0]], #16\n" /* load r0, 5 */
"saddw v24.4s, v24.4s, v20.4h\n"
"saddw2 v25.4s, v25.4s, v20.8h\n"
"saddw v26.4s, v26.4s, v21.4h\n"
"saddw2 v27.4s, v27.4s, v21.8h\n"
"ldp d6, d7, [%[r0]], #16\n" /* load r0, 6 */
"stp q24, q25, [%[ptr_out0]], #32\n"
"saddw v28.4s, v28.4s, v22.4h\n"
"saddw2 v29.4s, v29.4s, v22.8h\n"
"stp q26, q27, [%[ptr_out0]], #32\n"
"saddw v30.4s, v30.4s, v23.4h\n"
"saddw2 v31.4s, v31.4s, v23.8h\n"
"subs %w[cnt], %w[cnt], #1\n"
"stp q28, q29, [%[ptr_out0]], #32\n"
"stp q30, q31, [%[ptr_out0]], #32\n"
"bne 1b\n"
: [cnt] "+r"(cnt),
[r0] "+r"(inr0),
[r1] "+r"(inr1),
[r2] "+r"(inr2),
[ptr_out0] "+r"(ptr_out0)
: [w0] "w"(vw0),
[w1] "w"(vw1),
[w2] "w"(vw2),
[w3] "w"(vw3),
[w4] "w"(vw4),
[w5] "w"(vw5),
[w6] "w"(vw6),
[w7] "w"(vw7),
[w8] "w"(vw8)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v20",
"v21",
"v22",
"v23",
"v24",
"v25",
"v26",
"v27",
"v28",
"v29",
"v30",
"v31"
);
#else
auto wptr = weight_c;
asm volatile(
"vld1.32 {d0-d3}, [%[r0]]!\n" /* load r0, 0-3 */
"vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w0-w1 */
"vld1.32 {d4-d5}, [%[r0]]!\n" /* load r0, 4-5 */
"1:\n"
/* inr0 -> outr0 */
"vmull.s8 q4, d1, d7\n" /* int16, out0 */
"vld1.32 {d1}, [%[r0]]!\n" /* load r0, 6 */
"vmull.s8 q5, d3, d7\n" /* int16, out1 */
"vld1.32 {d3}, [%[r0]]!\n" /* load r0, 7 */
"vmull.s8 q6, d5, d7\n" /* int16, out2 */
"vld1.32 {d5}, [%[r0]]\n" /* load r0, 8 */
"vmull.s8 q7, d1, d6\n" /* int16, out0 */
"vmlal.s8 q4, d0, d6\n" /* int16, out3 */
"vmlal.s8 q5, d2, d6\n" /* int16, out1 */
"vmlal.s8 q6, d4, d6\n" /* int16, out2 */
"vmlal.s8 q7, d3, d7\n" /* int16, out3 */
"vmovl.s16 q8, d8\n"
"vmovl.s16 q9, d9\n"
"vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w2-w3 */
"vmovl.s16 q10, d10\n"
"vmovl.s16 q11, d11\n"
"vmovl.s16 q12, d12\n"
"vmovl.s16 q13, d13\n"
"vmovl.s16 q14, d14\n"
"vmovl.s16 q15, d15\n"
"vmull.s8 q4, d2, d6\n" /* int16, out0 */
"vmull.s8 q6, d1, d6\n" /* int16, out2 */
"vld1.32 {d0-d3}, [%[r1]]!\n" /* load r1, 0-3 */
"vmull.s8 q5, d4, d6\n" /* int16, out1 */
"vmull.s8 q7, d5, d6\n" /* int16, out3 */
"vld1.32 {d4-d5}, [%[r1]]!\n" /* load r1, 4,5 */
/* inr1 -> outr0 */
"vmlal.s8 q4, d0, d7\n" /* int16, out0 */
"vld1.32 {d0}, [%[r1]]!\n" /* load r1, 6 */
"vmlal.s8 q5, d2, d7\n" /* int16, out1 */
"vmlal.s8 q6, d4, d7\n" /* int16, out2 */
"vaddw.s16 q8, q8, d8\n"
"vaddw.s16 q9, q9, d9\n"
"vmlal.s8 q7, d0, d7\n" /* int16, out3 */
"vaddw.s16 q10, q10, d10\n"
"vaddw.s16 q11, q11, d11\n"
"vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w4-w5 */
"vaddw.s16 q12, q12, d12\n"
"vaddw.s16 q13, q13, d13\n"
"vaddw.s16 q14, q14, d14\n"
"vaddw.s16 q15, q15, d15\n"
"vmull.s8 q4, d1, d6\n" /* int16, out0 */
"vld1.32 {d1}, [%[r1]]!\n" /* load r1, 7 */
"vmull.s8 q5, d3, d6\n" /* int16, out1 */
"vld1.32 {d3}, [%[r1]]\n" /* load r1, 8 */
"vmull.s8 q6, d5, d6\n" /* int16, out2 */
"vmull.s8 q7, d1, d6\n" /* int16, out3 */
"vmlal.s8 q4, d2, d7\n" /* int16, out0 */
"vmlal.s8 q5, d4, d7\n" /* int16, out2 */
"vmlal.s8 q6, d0, d7\n" /* int16, out1 */
"vmlal.s8 q7, d3, d7\n" /* int16, out3 */
"vld1.32 {d0-d3}, [%[r2]]!\n" /* load r2, 0-3 */
"vaddw.s16 q8, q8, d8\n"
"vaddw.s16 q9, q9, d9\n"
"vaddw.s16 q10, q10, d10\n"
"vaddw.s16 q11, q11, d11\n"
"vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w6-w7 */
"vaddw.s16 q12, q12, d12\n"
"vaddw.s16 q13, q13, d13\n"
"vaddw.s16 q14, q14, d14\n"
"vaddw.s16 q15, q15, d15\n"
"vld1.32 {d4-d5}, [%[r2]]!\n" /* load r2, 4-5 */
/* inr2 -> outr0 */
"vmull.s8 q4, d1, d7\n" /* int16, out0 */
"vld1.32 {d1}, [%[r2]]!\n" /* load r2, 6 */
"vmull.s8 q5, d3, d7\n" /* int16, out1 */
"vld1.32 {d3}, [%[r2]]!\n" /* load r2, 7 */
"vmull.s8 q6, d5, d7\n" /* int16, out2 */
"vld1.32 {d5}, [%[r2]]\n" /* load r2, 8 */
"vmull.s8 q7, d1, d6\n" /* int16, out3 */
"vmlal.s8 q4, d0, d6\n" /* int16, out0 */
"vmlal.s8 q5, d2, d6\n" /* int16, out1 */
"vmlal.s8 q6, d4, d6\n" /* int16, out2 */
"vmlal.s8 q7, d3, d7\n" /* int16, out3 */
"vld1.32 {d6}, [%[wptr]]!\n" /* load w8 */
"vaddw.s16 q8, q8, d8\n"
"vaddw.s16 q9, q9, d9\n"
"vaddw.s16 q10, q10, d10\n"
"vaddw.s16 q11, q11, d11\n"
"vaddw.s16 q12, q12, d12\n"
"vaddw.s16 q13, q13, d13\n"
"vaddw.s16 q14, q14, d14\n"
"vaddw.s16 q15, q15, d15\n"
"sub %[wptr], %[wptr], #72\n"
"vmull.s8 q4, d2, d6\n" /* int16, out0 */
"vmull.s8 q5, d4, d6\n" /* int16, out1 */
"vmull.s8 q6, d1, d6\n" /* int16, out2 */
"vmull.s8 q7, d5, d6\n" /* int16, out3 */
"vld1.32 {d0-d3}, [%[r0]]!\n" /* load r0, 0-3 */
"vaddw.s16 q8, q8, d8\n"
"vaddw.s16 q9, q9, d9\n"
"vld1.32 {d4-d5}, [%[r0]]!\n" /* load r0, 4-5 */
"vaddw.s16 q10, q10, d10\n"
"vaddw.s16 q11, q11, d11\n"
"vst1.32 {d16-d19}, [%[ptr_out0]]!\n"
"vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w0-w1 */
"vaddw.s16 q12, q12, d12\n"
"vaddw.s16 q13, q13, d13\n"
"vst1.32 {d20-d23}, [%[ptr_out0]]!\n"
"vaddw.s16 q14, q14, d14\n"
"vaddw.s16 q15, q15, d15\n"
"subs %[cnt], #1\n"
"vst1.32 {d24-d27}, [%[ptr_out0]]!\n"
"vst1.32 {d28-d31}, [%[ptr_out0]]!\n"
"bne 1b\n"
: [cnt] "+r"(cnt),
[r0] "+r"(inr0),
[r1] "+r"(inr1),
[r2] "+r"(inr2),
[ptr_out0] "+r"(ptr_out0),
[wptr] "+r"(wptr)
:
: "cc",
"memory",
"q0",
"q1",
"q2",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
#endif
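          // advance the input-row pointers for the next output row: with
          // stride 2 and one output row per iteration, r0 moves down two
          // input rows, i.e. onto the row previously addressed by r2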
block_inr0 = block_inr2;
block_inr1 = block_inr0 + in_len;
block_inr2 = block_inr1 + in_len;
}
write_int32_nchwc8_to_nchw<Dtype>(pre_out,
reinterpret_cast<Dtype*>(dout_batch),
c,
c + hout_c_block,
h,
h + h_kernel,
0,
wout_round,
chout,
hout,
wout,
flag_relu,
bias_local,
flag_bias,
ptr_write,
scale + c);
}
}
}
}
template void conv_depthwise_3x3s2_int8<int8_t>(int8_t* dout,
const int8_t* din,
const int8_t* weights,
const float* scale,
const float* bias,
bool flag_bias,
bool flag_relu,
int num,
int chin,
int hin,
int win,
int hout,
int wout,
int padw,
int padh,
ARMContext* ctx);
template void conv_depthwise_3x3s2_int8<float>(float* dout,
const int8_t* din,
const int8_t* weights,
const float* scale,
const float* bias,
bool flag_bias,
bool flag_relu,
int num,
int chin,
int hin,
int win,
int hout,
int wout,
int padw,
int padh,
ARMContext* ctx);
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
...@@ -51,12 +51,7 @@ void conv_3x3s2_direct_fp32(const float* i_data,
  const int win_round = wout_round * 2 /*stride_w*/ + 1;
  bool flag_relu = param.fuse_relu;
  bool flag_bias = param.bias != nullptr;
  // if (param.activation_param.has_active) {
  //   if (param.activation_param.active == Active_relu &&
  //       fabs(param.activation_param.negative_slope) < 1e-6f) {
  //     flag_relu = true;
  //   }
  // }
  //! get h block
  //! win_round * ic * hin_r_block + wout_round * hout_c_block * hout_r_block
  //! * threads = l2_size
...
...@@ -28,8 +28,9 @@ namespace math {
#ifdef __aarch64__
int conv_3x3s2_direct_int8_c_num() { return 8; }

template <typename Dtype>
void conv_3x3s2_direct_int8(const int8_t* din,
                            int32_t* dout,
                            Dtype* dout,
                            int num,
                            int chout,
                            int hout,
...@@ -38,27 +39,25 @@ void conv_3x3s2_direct_int8(const int8_t* din,
                            int hin,
                            int win,
                            const int8_t* weights,
                            const int32_t* bias,
                            const float* bias,
                            const operators::ConvParam& param,
                            Context<TARGET(kARM)>* ctx,
                            PrecisionType out_type,
                            const float* scale) {
  //! 3x3s2 int8 convolution, implemented by direct algorithm
  //! prepack input to tmp buffer
  //! write output to tmp buffer
  int threads = ctx->threads();
  int stride_w = param.strides[1];
  int pad_w = param.paddings[1];
  int pad_h = param.paddings[0];
  bool flag_relu = param.fuse_relu;
  bool flag_bias = (param.bias != nullptr);
  bool flag_bias = param.bias;
  int pad_h = param.paddings[0];
  int pad_w = param.paddings[1];
  const int threads = ctx->threads();
  int llc_size = ctx->llc_size() / 4;
  //! set 2/3 l2 cache
  int l2_size = ctx->llc_size() / 3 * 2;
  const int hout_c_block = 8;
  const int hout_r_kernel = 2;
  const int wout_round = ((wout + 3) / 4) * 4;
  const int win_round = wout_round * stride_w + 1;
  const int win_round = wout_round * 2 /*stride_w*/ + 1;

  //! get h block
  //! win_round * chin * hin_r_block * sizeof(int8_t) + wout_round *
...@@ -66,7 +65,7 @@ void conv_3x3s2_direct_int8(const int8_t* din,
  //! win_round = 2 * wout_round + 1
  //! hin_r_block = 2 * hout_r_block + 1
  int hout_r_block =
      (l2_size - 2 * wout_round * chin - chin) /
      (llc_size - 2 * wout_round * chin - chin) /
      ((4 * wout_round + 2) * chin + wout_round * hout_c_block * threads * 4);
  hout_r_block = hout_r_block > hout ? hout : hout_r_block;
  hout_r_block = (hout_r_block / hout_r_kernel) * hout_r_kernel;
...@@ -74,16 +73,15 @@ void conv_3x3s2_direct_int8(const int8_t* din,
  const int hin_r_block = hout_r_block * 2 + 1;

  int8_t* tmp_work_space = ctx->workspace_data<int8_t>();
  auto tmp_work_space = ctx->workspace_data<int8_t>();
  int zero_size = chout > (win_round + 3) / 4 ? chout : (win_round + 3) / 4;
  const int kZeroSize = zero_size;
  int32_t ptr_zero[kZeroSize];
  int32_t ptr_zero[zero_size];  // NOLINT
  memset(ptr_zero, 0, sizeof(int32_t) * zero_size);
  const int kWoutRound = wout_round;
  int32_t ptr_write[kWoutRound];
  Dtype ptr_write[wout_round];  // NOLINT

  int in_len = win_round * chin;
  int pre_in_size = hin_r_block * in_len;
  pre_in_size = ROUNDUP(pre_in_size, 4);
  int pre_out_size = hout_c_block * hout_r_block * wout_round;

  //! l2_cache start
...@@ -100,10 +98,8 @@ void conv_3x3s2_direct_int8(const int8_t* din,
  int out_row_stride = hout_c_block * wout_round;

  for (int n = 0; n < num; ++n) {
    const int8_t* din_batch = din + n * chin * size_in_channel;
    int8_t* dout_batch =
        reinterpret_cast<int8_t*>(dout) +
        n * chout * size_out_channel * PrecisionTypeLength(out_type);
    auto din_batch = din + n * chin * size_in_channel;
    auto dout_batch = dout + n * chout * size_out_channel;
    for (int h = 0; h < hout; h += hout_r_block) {
      int h_kernel = hout_r_block;
      if (h + hout_r_block > hout) {
...@@ -133,12 +129,10 @@ void conv_3x3s2_direct_int8(const int8_t* din,
#pragma omp parallel for num_threads(threads)
      for (int c = 0; c < chout; c += hout_c_block) {
#ifdef ARM_WITH_OMP
        int32_t* pre_out =
            reinterpret_cast<int*>(pre_din + (pre_in_size + 3) / 4 * 4) +
            omp_get_thread_num() * pre_out_size;
        auto pre_out = reinterpret_cast<int*>(pre_din + pre_in_size) +
                       omp_get_thread_num() * pre_out_size;
#else
        int32_t* pre_out =
            reinterpret_cast<int32_t*>(pre_din + (pre_in_size + 3) / 4 * 4);
        auto pre_out = reinterpret_cast<int32_t*>(pre_din + pre_in_size);
#endif
        const int8_t* block_inr0 = cblock_inr0;
        const int8_t* block_inr1 = cblock_inr1;
...@@ -147,12 +141,19 @@ void conv_3x3s2_direct_int8(const int8_t* din,
        const int8_t* block_inr4 = cblock_inr4;

        const int8_t* weight_c = weights + c * w_stride;
        const int32_t* bias_ptr = ptr_zero;
        float bias_local[8] = {0, 0, 0, 0, 0, 0, 0, 0};
        if (flag_bias) {
          bias_ptr = bias + c;
          bias_local[0] = bias[c];
          bias_local[1] = bias[c + 1];
          bias_local[2] = bias[c + 2];
          bias_local[3] = bias[c + 3];
          bias_local[4] = bias[c + 4];
          bias_local[5] = bias[c + 5];
          bias_local[6] = bias[c + 6];
          bias_local[7] = bias[c + 7];
        }
        fill_packed_bias_nxmw_int8(bias_ptr, pre_out, 8, h_kernel, wout_round);
        memset(pre_out, 0, pre_out_size * sizeof(int32_t));

        for (int hk = 0; hk < h_kernel; hk += hout_r_kernel) {
          const int8_t* wc0 = weight_c;
...@@ -186,490 +187,236 @@ void conv_3x3s2_direct_int8(const int8_t* din,
            int32_t* ptr_out0 = pre_out0;
            int32_t* ptr_out1 = pre_out1;
            int cnt = w_loop;
            // clang-format off
            asm volatile(
                "ldr q0, [%[r0]], #8 \n"    /* load input r0 */
                "ldr q1, [%[r2]], #8 \n"    /* load input r2 */
                "sshll v0.8h, v0.8b, #0 \n" /* r0: int8 -> int16 */
                "sshll v1.8h, v1.8b, #0 \n" /* r1: int8 -> int16 */
                "1: \n"                     /* main loop */
                /* r0, r2 mul w00 */
                "smull v4.4s, %[v0].4h, v0.h[0]\n"  /* outr00 = v0 * r0[0] */
                "smull2 v5.4s, %[v0].8h, v0.h[0]\n" /* outr00 = v0 * r0[0] */
                "smull v6.4s, %[v0].4h, v0.h[2]\n"  /* outr01 = v0 * r0[2] */
                "smull2 v7.4s, %[v0].8h, v0.h[2]\n" /* outr01 = v0 * r0[2] */
                "smull v8.4s, %[v0].4h, v0.h[4]\n"  /* outr02 = v0 * r0[4] */
                "smull2 v9.4s, %[v0].8h, v0.h[4]\n" /* outr02 = v0 * r0[4] */
                "smull v10.4s, %[v0].4h, v0.h[6]\n" /* outr03 = v0 * r0[6] */
                "smull2 v11.4s, %[v0].8h, v0.h[6]\n" /* outr03 = v0 * r0[6] */
                "smull v12.4s, %[v0].4h, v1.h[0]\n" /* outr10 = v0 * r2[0] */
                "smull2 v13.4s, %[v0].8h, v1.h[0]\n" /* outr10 = v0 * r2[0] */
                "smull v14.4s, %[v0].4h, v1.h[2]\n" /* outr11 = v0 * r2[2] */
                "smull2 v15.4s, %[v0].8h, v1.h[2]\n" /* outr11 = v0 * r2[2] */
                "smull v16.4s, %[v0].4h, v1.h[4]\n" /* outr12 = v0 * r2[4] */
                "smull2 v17.4s, %[v0].8h, v1.h[4]\n" /* outr12 = v0 * r2[4] */
                "smull v18.4s, %[v0].4h, v1.h[6]\n" /* outr13 = v0 * r2[6] */
                "smull2 v19.4s, %[v0].8h, v1.h[6]\n" /* outr13 = v0 * r2[6] */
                /* r2, mul w06 */
                "smlal v4.4s, %[v6].4h, v1.h[0]\n"  /* outr00 += v6 * r2[0] */
                "smlal2 v5.4s, %[v6].8h, v1.h[0]\n" /* outr00 += v6 * r2[0] */
                "smlal v6.4s, %[v6].4h, v1.h[2]\n"  /* outr01 += v6 * r2[2] */
                "smlal2 v7.4s, %[v6].8h, v1.h[2]\n" /* outr01 += v6 * r2[2] */
                "smlal v8.4s, %[v6].4h, v1.h[4]\n"  /* outr02 += v6 * r2[4] */
                "smlal2 v9.4s, %[v6].8h, v1.h[4]\n" /* outr02 += v6 * r2[4] */
                "smlal v10.4s, %[v6].4h, v1.h[6]\n" /* outr03 += v6 * r2[6] */
                "smlal2 v11.4s, %[v6].8h, v1.h[6]\n" /* outr03 += v6 * r2[6] */
                "ldr q2, [%[r0]]\n" /* load r0, 9th data, v2.b[0] */
                /* r0, r2, mul w01 */
                "smlal v4.4s, %[v1].4h, v0.h[1]\n"  /* outr00 += v1 * r0[1] */
                "smlal2 v5.4s, %[v1].8h, v0.h[1]\n" /* outr00 += v1 * r0[1] */
                "smlal v6.4s, %[v1].4h, v0.h[3]\n"  /* outr01 += v1 * r0[3] */
                "smlal2 v7.4s, %[v1].8h, v0.h[3]\n" /* outr01 += v1 * r0[3] */
                "sshll v2.8h, v2.8b, #0 \n" /* r0: int8 -> int16 */
                "smlal v8.4s, %[v1].4h, v0.h[5]\n"  /* outr02 += v1 * r0[5] */
                "smlal2 v9.4s, %[v1].8h, v0.h[5]\n" /* outr02 += v1 * r0[5] */
                "smlal v10.4s, %[v1].4h, v0.h[7]\n" /* outr03 += v1 * r0[7] */
                "smlal2 v11.4s, %[v1].8h, v0.h[7]\n" /* outr03 += v1 * r0[7] */
                "smlal v12.4s, %[v1].4h, v1.h[1]\n" /* outr10 += v1 * r2[1] */
                "smlal2 v13.4s, %[v1].8h, v1.h[1]\n" /* outr10 += v1 * r2[1] */
                "smlal v14.4s, %[v1].4h, v1.h[3]\n" /* outr11 += v1 * r2[3] */
                "smlal2 v15.4s, %[v1].8h, v1.h[3]\n" /* outr11 += v1 * r2[3] */
                "smlal v16.4s, %[v1].4h, v1.h[5]\n" /* outr12 += v1 * r2[5] */
                "smlal2 v17.4s, %[v1].8h, v1.h[5]\n" /* outr12 += v1 * r2[5] */
                "smlal v18.4s, %[v1].4h, v1.h[7]\n" /* outr13 += v1 * r2[7] */
                "smlal2 v19.4s, %[v1].8h, v1.h[7]\n" /* outr13 += v1 * r2[7] */
                /* r2, mul w07 */
                "smlal v4.4s, %[v7].4h, v1.h[1]\n"  /* outr00 += v7 * r2[1] */
                "smlal2 v5.4s, %[v7].8h, v1.h[1]\n" /* outr00 += v7 * r2[1] */
                "smlal v6.4s, %[v7].4h, v1.h[3]\n"  /* outr01 += v7 * r2[3] */
                "smlal2 v7.4s, %[v7].8h, v1.h[3]\n" /* outr01 += v7 * r2[3] */
                "smlal v8.4s, %[v7].4h, v1.h[5]\n"  /* outr02 += v7 * r2[5] */
                "smlal2 v9.4s, %[v7].8h, v1.h[5]\n" /* outr02 += v7 * r2[5] */
                "smlal v10.4s, %[v7].4h, v1.h[7]\n" /* outr03 += v7 * r2[7] */
                "smlal2 v11.4s, %[v7].8h, v1.h[7]\n" /* outr03 += v7 * r2[7] */
                "ldr q3, [%[r2]]\n" /* load r2, 9th data, v3.b[0] */
                /* r0, r2, mul w02 */
                "smlal v4.4s, %[v2].4h, v0.h[2]\n"  /* outr00 += v2 * r0[2] */
                "smlal2 v5.4s, %[v2].8h, v0.h[2]\n" /* outr00 += v2 * r0[2] */
                "smlal v6.4s, %[v2].4h, v0.h[4]\n"  /* outr01 += v2 * r0[4] */
                "smlal2 v7.4s, %[v2].8h, v0.h[4]\n" /* outr01 += v2 * r0[4] */
                "sshll v3.8h, v3.8b, #0 \n" /* r2: int8 -> int16 */
                "smlal v8.4s, %[v2].4h, v0.h[6]\n"  /* outr02 += v2 * r0[6] */
                "smlal2 v9.4s, %[v2].8h, v0.h[6]\n" /* outr02 += v2 * r0[6] */
                "smlal v10.4s, %[v2].4h, v2.h[0]\n" /* outr03 += v2 * r0[8] */
                "smlal2 v11.4s, %[v2].8h, v2.h[0]\n" /* outr03 += v2 * r0[8] */
                "ldr q0, [%[r1]], #8 \n" /* load input r1 */
                "smlal v12.4s, %[v2].4h, v1.h[2]\n" /* outr10 += v2 * r2[2] */
                "smlal2 v13.4s, %[v2].8h, v1.h[2]\n" /* outr10 += v2 * r2[2] */
                "smlal v14.4s, %[v2].4h, v1.h[4]\n" /* outr11 += v2 * r2[4] */
                "smlal2 v15.4s, %[v2].8h, v1.h[4]\n" /* outr11 += v2 * r2[4] */
                "sshll v0.8h, v0.8b, #0 \n" /* r1 : int8 -> int16 */
                "smlal v16.4s, %[v2].4h, v1.h[6]\n" /* outr12 += v2 * r2[6] */
*/ "smlal2 v17.4s, %[v2].8h, v1.h[6]\n" /* outr11 = v0 * r2[2]*/
"smlal2 v11.4s, %[v1].8h, v0.h[7]\n" /* outr00 = v0 * r0[0] "smlal v18.4s, %[v2].4h, v3.h[0]\n" /* outr12 = v0 * r2[4]*/
*/ "smlal2 v19.4s, %[v2].8h, v3.h[0]\n" /* outr13 = v0 * r2[6]*/
/* r2, mul w08 */
"smlal v12.4s, %[v1].4h, v1.h[1]\n" /* outr10 = v0 * r2[0] "smlal v4.4s, %[v8].4h, v1.h[2]\n" /* outr00 = v6 * r2[1]*/
*/ "smlal2 v5.4s, %[v8].8h, v1.h[2]\n" /* outr01 = v6 * r2[3]*/
"smlal2 v13.4s, %[v1].8h, v1.h[1]\n" /* outr11 = v0 * r2[2] "smlal v6.4s, %[v8].4h, v1.h[4]\n" /* outr02 = v6 * r2[5]*/
*/ "smlal2 v7.4s, %[v8].8h, v1.h[4]\n" /* outr03 = v6 * r2[7]*/
"smlal v14.4s, %[v1].4h, v1.h[3]\n" /* outr12 = v0 * r2[4] "smlal v8.4s, %[v8].4h, v1.h[6]\n" /* outr00 = v6 * r2[1]*/
*/ "smlal2 v9.4s, %[v8].8h, v1.h[6]\n" /* outr01 = v6 * r2[3]*/
"smlal2 v15.4s, %[v1].8h, v1.h[3]\n" /* outr13 = v0 * r2[6] "smlal v10.4s, %[v8].4h, v3.h[0]\n" /* outr02 = v6 * r2[5]*/
*/ "smlal2 v11.4s, %[v8].8h, v3.h[0]\n" /* outr03 = v6 * r2[7]*/
"smlal v16.4s, %[v1].4h, v1.h[5]\n" /* outr10 = v0 * r2[0] "ldr q1, [%[r3]], #8 \n" /* load input r3 */
*/ /* r1, r3, mul w03 */
"smlal2 v17.4s, %[v1].8h, v1.h[5]\n" /* outr11 = v0 * r2[2] "smlal v4.4s, %[v3].4h, v0.h[0]\n" /* outr00 = v0 * r0[0]*/
*/ "smlal2 v5.4s, %[v3].8h, v0.h[0]\n" /* outr00 = v0 * r0[0]*/
"smlal v18.4s, %[v1].4h, v1.h[7]\n" /* outr12 = v0 * r2[4] "smlal v6.4s, %[v3].4h, v0.h[2]\n" /* outr01 = v0 * r0[2]*/
*/ "smlal2 v7.4s, %[v3].8h, v0.h[2]\n" /* outr00 = v0 * r0[0]*/
"smlal2 v19.4s, %[v1].8h, v1.h[7]\n" /* outr13 = v0 * r2[6] "sshll v1.8h, v1.8b, #0 \n" /* r3: int8 -> int16 */
*/ "smlal v8.4s, %[v3].4h, v0.h[4]\n" /* outr02 = v0 * r0[4]*/
"smlal2 v9.4s, %[v3].8h, v0.h[4]\n" /* outr00 = v0 * r0[0]*/
/* r2, mul w07 */ "smlal v10.4s, %[v3].4h, v0.h[6]\n" /* outr03 = v0 * r0[6]*/
"smlal v4.4s, %[v7].4h, v1.h[1]\n" /* outr00 = v6 * r2[1] "smlal2 v11.4s, %[v3].8h, v0.h[6]\n" /* outr00 = v0 * r0[0]*/
*/ "ldr q2, [%[r1]]\n" /* load r1, 9th data,v10.s[0] */
"smlal2 v5.4s, %[v7].8h, v1.h[1]\n" /* outr01 = v6 * r2[3] "smlal v12.4s, %[v3].4h, v1.h[0]\n" /* outr10 = v0 * r2[0]*/
*/ "smlal2 v13.4s, %[v3].8h, v1.h[0]\n" /* outr11 = v0 * r2[2]*/
"smlal v6.4s, %[v7].4h, v1.h[3]\n" /* outr02 = v6 * r2[5] "smlal v14.4s, %[v3].4h, v1.h[2]\n" /* outr12 = v0 * r2[4]*/
*/ "smlal2 v15.4s, %[v3].8h, v1.h[2]\n" /* outr13 = v0 * r2[6]*/
"smlal2 v7.4s, %[v7].8h, v1.h[3]\n" /* outr03 = v6 * r2[7] "ldr q3, [%[r3]]\n" /* load r3, 9th data,v11.s[0] */
*/ "smlal v16.4s, %[v3].4h, v1.h[4]\n" /* outr10 = v0 * r2[0]*/
"smlal v8.4s, %[v7].4h, v1.h[5]\n" /* outr00 = v6 * r2[1] "smlal2 v17.4s, %[v3].8h, v1.h[4]\n" /* outr11 = v0 * r2[2]*/
*/ "smlal v18.4s, %[v3].4h, v1.h[6]\n" /* outr12 = v0 * r2[4]*/
"smlal2 v9.4s, %[v7].8h, v1.h[5]\n" /* outr01 = v6 * r2[3] "smlal2 v19.4s, %[v3].8h, v1.h[6]\n" /* outr13 = v0 * r2[6]*/
*/ "sshll v2.8h, v2.8b, #0 \n" /* r1 : int8 -> int16 */
"smlal v10.4s, %[v7].4h, v1.h[7]\n" /* outr02 = v6 * r2[5] /* r1, r3, mul w05 */
*/ "smlal v4.4s, %[v5].4h, v0.h[2]\n" /* outr00 = v0 * r0[0]*/
"smlal2 v11.4s, %[v7].8h, v1.h[7]\n" /* outr03 = v6 * r2[7] "smlal2 v5.4s, %[v5].8h, v0.h[2]\n" /* outr00 = v0 * r0[0]*/
*/ "smlal v6.4s, %[v5].4h, v0.h[4]\n" /* outr01 = v0 * r0[2]*/
"smlal2 v7.4s, %[v5].8h, v0.h[4]\n" /* outr00 = v0 * r0[0]*/
"ldr q3, [%[r2]] \n" /* load r2, 9th "sshll v3.8h, v3.8b, #0 \n" /* r3 : int8 -> int16 */
data,v11.s[0] */ "smlal v8.4s, %[v5].4h, v0.h[6]\n" /* outr02 = v0 * r0[4]*/
"smlal2 v9.4s, %[v5].8h, v0.h[6]\n" /* outr00 = v0 * r0[0]*/
/* r0, r2, mul w02 */ "smlal v10.4s, %[v5].4h, v2.h[0]\n" /* outr03 = v0 * r0[6]*/
"smlal v4.4s, %[v2].4h, v0.h[2]\n" /* outr00 = v0 * r0[0] "smlal2 v11.4s, %[v5].8h, v2.h[0]\n" /* outr00 = v0 * r0[0]*/
*/ "smlal v12.4s, %[v5].4h, v1.h[2]\n" /* outr10 = v0 * r2[0]*/
"smlal2 v5.4s, %[v2].8h, v0.h[2]\n" /* outr00 = v0 * r0[0] "smlal2 v13.4s, %[v5].8h, v1.h[2]\n" /* outr11 = v0 * r2[2]*/
*/ "smlal v14.4s, %[v5].4h, v1.h[4]\n" /* outr12 = v0 * r2[4]*/
"smlal v6.4s, %[v2].4h, v0.h[4]\n" /* outr01 = v0 * r0[2] "smlal2 v15.4s, %[v5].8h, v1.h[4]\n" /* outr13 = v0 * r2[6]*/
*/ "smlal v16.4s, %[v5].4h, v1.h[6]\n" /* outr10 = v0 * r2[0]*/
"smlal2 v7.4s, %[v2].8h, v0.h[4]\n" /* outr00 = v0 * r0[0] "smlal2 v17.4s, %[v5].8h, v1.h[6]\n" /* outr11 = v0 * r2[2]*/
*/ "smlal v18.4s, %[v5].4h, v3.h[0]\n" /* outr12 = v0 * r2[4]*/
"sshll v3.8h, v3.8b, #0 \n" /* r2: int8 -> int16*/ "smlal2 v19.4s, %[v5].8h, v3.h[0]\n" /* outr13 = v0 * r2[6]*/
"smlal v8.4s, %[v2].4h, v0.h[6]\n" /* outr02 = v0 * r0[4] "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1 */
*/ /* r1, r3, mul w04 */
"smlal2 v9.4s, %[v2].8h, v0.h[6]\n" /* outr00 = v0 * r0[0] "smlal v4.4s, %[v4].4h, v0.h[1]\n" /* outr00 = v0 * r0[0]*/
*/ "smlal2 v5.4s, %[v4].8h, v0.h[1]\n" /* outr00 = v0 * r0[0]*/
"smlal v10.4s, %[v2].4h, v2.h[0]\n" /* outr03 = v0 * r0[6] "smlal v6.4s, %[v4].4h, v0.h[3]\n" /* outr01 = v0 * r0[2]*/
*/ "smlal2 v7.4s, %[v4].8h, v0.h[3]\n" /* outr00 = v0 * r0[0]*/
"smlal2 v11.4s, %[v2].8h, v2.h[0]\n" /* outr00 = v0 * r0[0] "smlal v8.4s, %[v4].4h, v0.h[5]\n" /* outr02 = v0 * r0[4]*/
*/ "smlal2 v9.4s, %[v4].8h, v0.h[5]\n" /* outr00 = v0 * r0[0]*/
"smlal v10.4s, %[v4].4h, v0.h[7]\n" /* outr03 = v0 * r0[6]*/
"ldr q0, [%[r1]], #8 \n" /* load input r1 */ "smlal2 v11.4s, %[v4].8h, v0.h[7]\n" /* outr00 = v0 * r0[0]*/
"ldr q0, [%[r4]], #8 \n" /* load input r4 */
"smlal v12.4s, %[v2].4h, v1.h[2]\n" /* outr10 = v0 * r2[0] "smlal v12.4s, %[v4].4h, v1.h[1]\n" /* outr10 = v0 * r2[0]*/
*/ "smlal2 v13.4s, %[v4].8h, v1.h[1]\n" /* outr11 = v0 * r2[2]*/
"smlal2 v13.4s, %[v2].8h, v1.h[2]\n" /* outr11 = v0 * r2[2] "smlal v14.4s, %[v4].4h, v1.h[3]\n" /* outr12 = v0 * r2[4]*/
*/ "smlal2 v15.4s, %[v4].8h, v1.h[3]\n" /* outr13 = v0 * r2[6]*/
"smlal v14.4s, %[v2].4h, v1.h[4]\n" /* outr12 = v0 * r2[4] "sshll v0.8h, v0.8b, #0 \n" /* r4 : int8 -> int16 */
*/ "smlal v16.4s, %[v4].4h, v1.h[5]\n" /* outr10 = v0 * r2[0]*/
"smlal2 v15.4s, %[v2].8h, v1.h[4]\n" /* outr13 = v0 * r2[6] "smlal2 v17.4s, %[v4].8h, v1.h[5]\n" /* outr11 = v0 * r2[2]*/
*/ "smlal v18.4s, %[v4].4h, v1.h[7]\n" /* outr12 = v0 * r2[4]*/
"sshll v0.8h, v0.8b, #0 \n" /* r1 : int8 -> int16 */ "smlal2 v19.4s, %[v4].8h, v1.h[7]\n" /* outr13 = v0 * r2[6]*/
"smlal v16.4s, %[v2].4h, v1.h[6]\n" /* outr10 = v0 * r2[0] "ldr q2, [%[r4]]\n" /* load r4, 9th data,v10.s[0] */
*/ "sshll v2.8h, v2.8b, #0\n" /* r4 : int8 -> int16 */
"smlal2 v17.4s, %[v2].8h, v1.h[6]\n" /* outr11 = v0 * r2[2] "ldp q1, q3, [%[ptr_out0]]\n" /* load ptr_out */
*/ "ldp q20, q21, [%[ptr_out0], #32]\n" /* load ptr_out */
"smlal v18.4s, %[v2].4h, v3.h[0]\n" /* outr12 = v0 * r2[4] "add v4.4s, v1.4s , v4.4s\n" /* v10 = outr00[0].low + q2 */
*/ "add v5.4s, v3.4s , v5.4s\n" /* v11 = outr00[0].high+ q3 */
"smlal2 v19.4s, %[v2].8h, v3.h[0]\n" /* outr13 = v0 * r2[6] "add v6.4s, v20.4s, v6.4s\n" /* v12 = outr01[0].low + q4 */
*/ "add v7.4s, v21.4s, v7.4s\n" /* v13 = outr01[0].high+ q5 */
"ldp q1 , q3 , [%[ptr_out0], #64]\n" /* load ptr_out*/
/* r2, mul w08 */ "ldp q20, q21, [%[ptr_out0], #96]\n" /* load ptr_out*/
"smlal v4.4s, %[v8].4h, v1.h[2]\n" /* outr00 = v6 * r2[1] "stp q4, q5 , [%[ptr_out0]], #32\n" /* store q10, q11*/
*/ "stp q6, q7 , [%[ptr_out0]], #32\n" /* store q10, q11*/
"smlal2 v5.4s, %[v8].8h, v1.h[2]\n" /* outr01 = v6 * r2[3] "add v8.4s , v1.4s , v8.4s\n" /* v10 = outr00[0].low+ q2 */
*/ "add v9.4s , v3.4s , v9.4s\n" /* v11 = outr00[0].high+q3 */
"smlal v6.4s, %[v8].4h, v1.h[4]\n" /* outr02 = v6 * r2[5] "add v10.4s, v20.4s, v10.4s\n" /* v12 = outr01[0].low+q4 */
*/ "add v11.4s, v21.4s, v11.4s\n" /* v13 = outr01[0].high+q5 */
"smlal2 v7.4s, %[v8].8h, v1.h[4]\n" /* outr03 = v6 * r2[7] "stp q8, q9, [%[ptr_out0]], #32\n" /* store q14, q15*/
*/ "stp q10, q11, [%[ptr_out0]], #32\n" /* store q16, q17*/
"smlal v8.4s, %[v8].4h, v1.h[6]\n" /* outr00 = v6 * r2[1] /* r4, mul w08 */
*/ "smlal v12.4s, %[v8].4h, v0.h[2]\n" /* outr00 = v0 * r0[0]*/
"smlal2 v9.4s, %[v8].8h, v1.h[6]\n" /* outr01 = v6 * r2[3] "smlal2 v13.4s, %[v8].8h, v0.h[2]\n" /* outr00 = v0 * r0[0]*/
*/ "smlal v14.4s, %[v8].4h, v0.h[4]\n" /* outr01 = v0 * r0[2]*/
"smlal v10.4s, %[v8].4h, v3.h[0]\n" /* outr02 = v6 * r2[5] "smlal2 v15.4s, %[v8].8h, v0.h[4]\n" /* outr00 = v0 * r0[0]*/
*/ "smlal v16.4s, %[v8].4h, v0.h[6]\n" /* outr02 = v0 * r0[4]*/
"smlal2 v11.4s, %[v8].8h, v3.h[0]\n" /* outr03 = v6 * r2[7] "smlal2 v17.4s, %[v8].8h, v0.h[6]\n" /* outr00 = v0 * r0[0]*/
*/ "smlal v18.4s, %[v8].4h, v2.h[0]\n" /* outr03 = v0 * r0[6]*/
"smlal2 v19.4s, %[v8].8h, v2.h[0]\n" /* outr00 = v0 * r0[0]*/
"ldr q1, [%[r3]], #8 \n" /* load input r3 */ /* r4, mul w07 */
"smlal v12.4s, %[v7].4h, v0.h[1]\n" /* outr00 = v0 * r0[0]*/
/* r1, r3, mul w03 */ "smlal2 v13.4s, %[v7].8h, v0.h[1]\n" /* outr00 = v0 * r0[0]*/
"smlal v4.4s, %[v3].4h, v0.h[0]\n" /* outr00 = v0 * r0[0] "smlal v14.4s, %[v7].4h, v0.h[3]\n" /* outr01 = v0 * r0[2]*/
*/ "smlal2 v15.4s, %[v7].8h, v0.h[3]\n" /* outr00 = v0 * r0[0]*/
"smlal2 v5.4s, %[v3].8h, v0.h[0]\n" /* outr00 = v0 * r0[0] "ldr q1, [%[r2]], #8 \n" /* load input r2 */
*/ "smlal v16.4s, %[v7].4h, v0.h[5]\n" /* outr02 = v0 * r0[4]*/
"smlal v6.4s, %[v3].4h, v0.h[2]\n" /* outr01 = v0 * r0[2] "smlal2 v17.4s, %[v7].8h, v0.h[5]\n" /* outr00 = v0 * r0[0]*/
*/ "smlal v18.4s, %[v7].4h, v0.h[7]\n" /* outr03 = v0 * r0[6]*/
"smlal2 v7.4s, %[v3].8h, v0.h[2]\n" /* outr00 = v0 * r0[0] "smlal2 v19.4s, %[v7].8h, v0.h[7]\n" /* outr00 = v0 * r0[0]*/
*/ "sshll v1.8h, v1.8b, #0 \n" /* r2: int8 -> int16*/
"sshll v1.8h, v1.8b, #0 \n" /* r3: int8 -> int16 */ /* r4, mul w06 */
"smlal v8.4s, %[v3].4h, v0.h[4]\n" /* outr02 = v0 * r0[4] "ldp q4, q5, [%[ptr_out1]] \n" /* load ptr_out*/
*/ "smlal v12.4s, %[v6].4h, v0.h[0]\n" /* outr00 = v0 * r0[0]*/
"smlal2 v9.4s, %[v3].8h, v0.h[4]\n" /* outr00 = v0 * r0[0] "smlal2 v13.4s, %[v6].8h, v0.h[0]\n" /* outr00 = v0 * r0[0]*/
*/ "smlal v14.4s, %[v6].4h, v0.h[2]\n" /* outr01 = v0 * r0[2]*/
"smlal v10.4s, %[v3].4h, v0.h[6]\n" /* outr03 = v0 * r0[6] "ldp q8, q9, [%[ptr_out1], #64]\n" /* load ptr_out*/
*/ "smlal2 v15.4s, %[v6].8h, v0.h[2]\n" /* outr00 = v0 * r0[0]*/
"smlal2 v11.4s, %[v3].8h, v0.h[6]\n" /* outr00 = v0 * r0[0] "smlal v16.4s, %[v6].4h, v0.h[4]\n" /* outr02 = v0 * r0[4]*/
*/ "smlal2 v17.4s, %[v6].8h, v0.h[4]\n" /* outr00 = v0 * r0[0]*/
"ldr q2, [%[r1]] \n" /* load r1, 9th "ldp q10, q11, [%[ptr_out1], #96]\n" /* load ptr_out*/
data,v10.s[0] */ "smlal v18.4s, %[v6].4h, v0.h[6]\n" /* outr03 = v0 * r0[6]*/
"smlal2 v19.4s, %[v6].8h, v0.h[6]\n" /* outr00 = v0 * r0[0]*/
"smlal v12.4s, %[v3].4h, v1.h[0]\n" /* outr10 = v0 * r2[0] "ldr q0, [%[r0]], #8 \n" /* load input r2 */
*/ "ldp q6, q7, [%[ptr_out1], #32]\n" /* load ptr_out*/
"smlal2 v13.4s, %[v3].8h, v1.h[0]\n" /* outr11 = v0 * r2[2] "sshll v0.8h, v0.8b, #0 \n" /* r0: int8 -> int16 */
*/ /* store outr1 */
"smlal v14.4s, %[v3].4h, v1.h[2]\n" /* outr12 = v0 * r2[4] "add v12.4s, v4.4s , v12.4s\n" /* v10 = outr10[0].low + q2 */
*/ "add v13.4s, v5.4s , v13.4s\n" /* v11 = outr10[0].high + q3 */
"smlal2 v15.4s, %[v3].8h, v1.h[2]\n" /* outr13 = v0 * r2[6] "add v14.4s, v6.4s , v14.4s\n" /* v12 = outr11[0].low + q4 */
*/ "add v15.4s, v7.4s , v15.4s\n" /* v13 = outr11[0].high + q5 */
"ldr q3, [%[r3]] \n" /* load r3, 9th "stp q12, q13, [%[ptr_out1]], #32\n" /* store q10, q11*/
data,v11.s[0] */ "add v16.4s, v8.4s , v16.4s\n" /* v14 = outr12[0].low + q6 */
"smlal v16.4s, %[v3].4h, v1.h[4]\n" /* outr10 = v0 * r2[0] "add v17.4s, v9.4s , v17.4s\n" /* v15 = outr12[0].high + q7 */
*/ "stp q14, q15, [%[ptr_out1]], #32\n" /* store q12, q13*/
"smlal2 v17.4s, %[v3].8h, v1.h[4]\n" /* outr11 = v0 * r2[2] "add v18.4s, v10.4s, v18.4s\n" /* v16 = outr13[0].low + q8 */
*/ "add v19.4s, v11.4s, v19.4s\n" /* v17 = outr13[0].high + q9 */
"smlal v18.4s, %[v3].4h, v1.h[6]\n" /* outr12 = v0 * r2[4] "stp q16, q17, [%[ptr_out1]], #32\n" /* store q14, q15*/
*/ "stp q18, q19, [%[ptr_out1]], #32\n" /* store q16, q17*/
"smlal2 v19.4s, %[v3].8h, v1.h[6]\n" /* outr13 = v0 * r2[6] "bne 1b\n" /* jump to main loop */
*/ : [cnt] "+r"(cnt), [r0] "+r"(r0), [r1] "+r"(r1),
"sshll v2.8h, v2.8b, #0 \n" /* r1 : int8 -> int16 */ [r2] "+r"(r2), [r3] "+r"(r3), [r4] "+r"(r4),
[ptr_out0] "+r"(ptr_out0), [ptr_out1] "+r"(ptr_out1)
/* r1, r3, mul w05 */ : [v0] "w"(v0), [v1] "w"(v1), [v2] "w"(v2),
"smlal v4.4s, %[v5].4h, v0.h[2]\n" /* outr00 = v0 * r0[0] [v3] "w"(v3), [v4] "w"(v4), [v5] "w"(v5),
*/ [v6] "w"(v6), [v7] "w"(v7), [v8] "w"(v8)
"smlal2 v5.4s, %[v5].8h, v0.h[2]\n" /* outr00 = v0 * r0[0] : "cc", "memory", "v0", "v1", "v2", "v3",
*/ "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"smlal v6.4s, %[v5].4h, v0.h[4]\n" /* outr01 = v0 * r0[2] "v11", "v12", "v13", "v14", "v15", "v16",
*/ "v17", "v18", "v19", "v20", "v21", "v22"
"smlal2 v7.4s, %[v5].8h, v0.h[4]\n" /* outr00 = v0 * r0[0] );
*/ // clang-format on
"sshll v3.8h, v3.8b, #0 \n" /* r3 : int8 -> int16 */
"smlal v8.4s, %[v5].4h, v0.h[6]\n" /* outr02 = v0 * r0[4]
*/
"smlal2 v9.4s, %[v5].8h, v0.h[6]\n" /* outr00 = v0 * r0[0]
*/
"smlal v10.4s, %[v5].4h, v2.h[0]\n" /* outr03 = v0 * r0[6]
*/
"smlal2 v11.4s, %[v5].8h, v2.h[0]\n" /* outr00 = v0 * r0[0]
*/
"smlal v12.4s, %[v5].4h, v1.h[2]\n" /* outr10 = v0 * r2[0]
*/
"smlal2 v13.4s, %[v5].8h, v1.h[2]\n" /* outr11 = v0 * r2[2]
*/
"smlal v14.4s, %[v5].4h, v1.h[4]\n" /* outr12 = v0 * r2[4]
*/
"smlal2 v15.4s, %[v5].8h, v1.h[4]\n" /* outr13 = v0 * r2[6]
*/
"smlal v16.4s, %[v5].4h, v1.h[6]\n" /* outr10 = v0 * r2[0]
*/
"smlal2 v17.4s, %[v5].8h, v1.h[6]\n" /* outr11 = v0 * r2[2]
*/
"smlal v18.4s, %[v5].4h, v3.h[0]\n" /* outr12 = v0 * r2[4]
*/
"smlal2 v19.4s, %[v5].8h, v3.h[0]\n" /* outr13 = v0 * r2[6]
*/
"subs %w[cnt], %w[cnt], #1 \n" /* loop count -1 */
/* r1, r3, mul w04 */
"smlal v4.4s, %[v4].4h, v0.h[1]\n" /* outr00 = v0 * r0[0]
*/
"smlal2 v5.4s, %[v4].8h, v0.h[1]\n" /* outr00 = v0 * r0[0]
*/
"smlal v6.4s, %[v4].4h, v0.h[3]\n" /* outr01 = v0 * r0[2]
*/
"smlal2 v7.4s, %[v4].8h, v0.h[3]\n" /* outr00 = v0 * r0[0]
*/
"smlal v8.4s, %[v4].4h, v0.h[5]\n" /* outr02 = v0 * r0[4]
*/
"smlal2 v9.4s, %[v4].8h, v0.h[5]\n" /* outr00 = v0 * r0[0]
*/
"smlal v10.4s, %[v4].4h, v0.h[7]\n" /* outr03 = v0 * r0[6]
*/
"smlal2 v11.4s, %[v4].8h, v0.h[7]\n" /* outr00 = v0 * r0[0]
*/
"ldr q0, [%[r4]], #8 \n" /* load input r4 */
"smlal v12.4s, %[v4].4h, v1.h[1]\n" /* outr10 = v0 * r2[0]
*/
"smlal2 v13.4s, %[v4].8h, v1.h[1]\n" /* outr11 = v0 * r2[2]
*/
"smlal v14.4s, %[v4].4h, v1.h[3]\n" /* outr12 = v0 * r2[4]
*/
"smlal2 v15.4s, %[v4].8h, v1.h[3]\n" /* outr13 = v0 * r2[6]
*/
"sshll v0.8h, v0.8b, #0 \n" /* r4 : int8 -> int16 */
"smlal v16.4s, %[v4].4h, v1.h[5]\n" /* outr10 = v0 * r2[0]
*/
"smlal2 v17.4s, %[v4].8h, v1.h[5]\n" /* outr11 = v0 * r2[2]
*/
"smlal v18.4s, %[v4].4h, v1.h[7]\n" /* outr12 = v0 * r2[4]
*/
"smlal2 v19.4s, %[v4].8h, v1.h[7]\n" /* outr13 = v0 * r2[6]
*/
"ldr q2, [%[r4]] \n" /* load r4, 9th
data,v10.s[0] */
"sshll v2.8h, v2.8b, #0 \n" /* r4 : int8 -> int16 */
"ldp q1, q3, [%[ptr_out0]] \n" /* load ptr_out + 0 ->
q2, q3 */
"ldp q20, q21, [%[ptr_out0], #32]\n" /* load ptr_out + 32 ->
q4, q5 */
"add v4.4s, v1.4s , v4.4s \n" /* v10 = outr00[0].low
+ q2 */
"add v5.4s, v3.4s , v5.4s \n" /* v11 = outr00[0].high
+ q3 */
"add v6.4s, v20.4s, v6.4s \n" /* v12 = outr01[0].low
+ q4 */
"add v7.4s, v21.4s, v7.4s \n" /* v13 = outr01[0].high
+ q5 */
"ldp q1 , q3 , [%[ptr_out0], #64]\n" /* load ptr_out + 64 ->
q6, q7 */
"ldp q20, q21, [%[ptr_out0], #96]\n" /* load ptr_out + 96 ->
q8, q9 */
"stp q4, q5 , [%[ptr_out0]], #32\n" /* store q10, q11 ->
ptr_out */
"stp q6, q7 , [%[ptr_out0]], #32\n" /* store q10, q11 ->
ptr_out */
"add v8.4s , v1.4s , v8.4s \n" /* v10 = outr00[0].low
+ q2 */
"add v9.4s , v3.4s , v9.4s \n" /* v11 = outr00[0].high
+ q3 */
"add v10.4s, v20.4s, v10.4s \n" /* v12 = outr01[0].low
+ q4 */
"add v11.4s, v21.4s, v11.4s \n" /* v13 = outr01[0].high
+ q5 */
"stp q8, q9, [%[ptr_out0]], #32\n" /* store q14, q15 ->
ptr_out += 64 */
"stp q10, q11, [%[ptr_out0]], #32\n" /* store q16, q17 ->
ptr_out += 96 */
/* r4, mul w08 */
"smlal v12.4s, %[v8].4h, v0.h[2]\n" /* outr00 = v0 * r0[0]
*/
"smlal2 v13.4s, %[v8].8h, v0.h[2]\n" /* outr00 = v0 * r0[0]
*/
"smlal v14.4s, %[v8].4h, v0.h[4]\n" /* outr01 = v0 * r0[2]
*/
"smlal2 v15.4s, %[v8].8h, v0.h[4]\n" /* outr00 = v0 * r0[0]
*/
"smlal v16.4s, %[v8].4h, v0.h[6]\n" /* outr02 = v0 * r0[4]
*/
"smlal2 v17.4s, %[v8].8h, v0.h[6]\n" /* outr00 = v0 * r0[0]
*/
"smlal v18.4s, %[v8].4h, v2.h[0]\n" /* outr03 = v0 * r0[6]
*/
"smlal2 v19.4s, %[v8].8h, v2.h[0]\n" /* outr00 = v0 * r0[0]
*/
/* r4, mul w07 */
"smlal v12.4s, %[v7].4h, v0.h[1]\n" /* outr00 = v0 * r0[0]
*/
"smlal2 v13.4s, %[v7].8h, v0.h[1]\n" /* outr00 = v0 * r0[0]
*/
"smlal v14.4s, %[v7].4h, v0.h[3]\n" /* outr01 = v0 * r0[2]
*/
"smlal2 v15.4s, %[v7].8h, v0.h[3]\n" /* outr00 = v0 * r0[0]
*/
"ldr q1, [%[r2]], #8 \n" /* load input r2 */
"smlal v16.4s, %[v7].4h, v0.h[5]\n" /* outr02 = v0 * r0[4]
*/
"smlal2 v17.4s, %[v7].8h, v0.h[5]\n" /* outr00 = v0 * r0[0]
*/
"smlal v18.4s, %[v7].4h, v0.h[7]\n" /* outr03 = v0 * r0[6]
*/
"smlal2 v19.4s, %[v7].8h, v0.h[7]\n" /* outr00 = v0 * r0[0]
*/
"sshll v1.8h, v1.8b, #0 \n" /* r2: int8 -> int16
*/
/* r4, mul w06 */
"ldp q4, q5, [%[ptr_out1]] \n" /* load ptr_out + 0 ->
q2, q3 */
"smlal v12.4s, %[v6].4h, v0.h[0]\n" /* outr00 = v0 * r0[0]
*/
"smlal2 v13.4s, %[v6].8h, v0.h[0]\n" /* outr00 = v0 * r0[0]
*/
"smlal v14.4s, %[v6].4h, v0.h[2]\n" /* outr01 = v0 * r0[2]
*/
"ldp q8, q9, [%[ptr_out1], #64]\n" /* load ptr_out + 64 ->
q6, q7 */
"smlal2 v15.4s, %[v6].8h, v0.h[2]\n" /* outr00 = v0 * r0[0]
*/
"smlal v16.4s, %[v6].4h, v0.h[4]\n" /* outr02 = v0 * r0[4]
*/
"smlal2 v17.4s, %[v6].8h, v0.h[4]\n" /* outr00 = v0 * r0[0]
*/
"ldp q10, q11, [%[ptr_out1], #96]\n" /* load ptr_out + 96 ->
q8, q9 */
"smlal v18.4s, %[v6].4h, v0.h[6]\n" /* outr03 = v0 * r0[6]
*/
"smlal2 v19.4s, %[v6].8h, v0.h[6]\n" /* outr00 = v0 * r0[0]
*/
"ldr q0, [%[r0]], #8 \n" /* load input r2 */
"ldp q6, q7, [%[ptr_out1], #32]\n" /* load ptr_out + 32 ->
q4, q5 */
"sshll v0.8h, v0.8b, #0 \n" /* r0: int8 -> int16 */
/* store outr1 */
"add v12.4s, v4.4s , v12.4s\n" /* v10 = outr10[0].low + q2 */
"add v13.4s, v5.4s , v13.4s\n" /* v11 = outr10[0].high + q3 */
"add v14.4s, v6.4s , v14.4s\n" /* v12 = outr11[0].low + q4 */
"add v15.4s, v7.4s , v15.4s\n" /* v13 = outr11[0].high + q5 */
"stp q12, q13, [%[ptr_out1]], #32\n" /* store q10, q11 ->
ptr_out */
"add v16.4s, v8.4s , v16.4s\n" /* v14 = outr12[0].low + q6 */
"add v17.4s, v9.4s , v17.4s\n" /* v15 = outr12[0].high + q7 */
"stp q14, q15, [%[ptr_out1]], #32\n" /* store q12, q13 ->
ptr_out += 32 */
"add v18.4s, v10.4s, v18.4s\n" /* v16 = outr13[0].low + q8 */
"add v19.4s, v11.4s, v19.4s\n" /* v17 = outr13[0].high + q9 */
"stp q16, q17, [%[ptr_out1]], #32\n" /* store q14, q15 ->
ptr_out += 64 */
"stp q18, q19, [%[ptr_out1]], #32\n" /* store q16, q17 ->
ptr_out += 96 */
"bne 1b \n" /* jump to main loop */
: [cnt] "+r"(cnt),
[r0] "+r"(r0),
[r1] "+r"(r1),
[r2] "+r"(r2),
[r3] "+r"(r3),
[r4] "+r"(r4),
[ptr_out0] "+r"(ptr_out0),
[ptr_out1] "+r"(ptr_out1)
: [v0] "w"(v0),
[v1] "w"(v1),
[v2] "w"(v2),
[v3] "w"(v3),
[v4] "w"(v4),
[v5] "w"(v5),
[v6] "w"(v6),
[v7] "w"(v7),
[v8] "w"(v8)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22");
wc0 += 9 * hout_c_block; wc0 += 9 * hout_c_block;
inr0 += win_round; inr0 += win_round;
inr1 += win_round; inr1 += win_round;
...@@ -683,47 +430,8 @@ void conv_3x3s2_direct_int8(const int8_t* din, ...@@ -683,47 +430,8 @@ void conv_3x3s2_direct_int8(const int8_t* din,
block_inr3 = block_inr2 + in_len; block_inr3 = block_inr2 + in_len;
block_inr4 = block_inr3 + in_len; block_inr4 = block_inr3 + in_len;
} }
if (out_type == PRECISION(kFloat)) { write_int32_nchwc8_to_nchw(pre_out,
write_to_output_c8_int32_1(pre_out, dout_batch,
reinterpret_cast<float*>(dout_batch),
hout_c_block,
2,
c,
c + hout_c_block,
h,
h + h_kernel,
0,
wout_round,
chout,
hout,
wout,
flag_relu,
reinterpret_cast<float*>(ptr_write),
&scale[c],
out_type);
} else if (out_type == PRECISION(kInt8)) {
write_to_output_c8_int32_1(pre_out,
dout_batch,
hout_c_block,
2,
c,
c + hout_c_block,
h,
h + h_kernel,
0,
wout_round,
chout,
hout,
wout,
flag_relu,
reinterpret_cast<signed char*>(ptr_write),
&scale[c],
out_type);
} else {
write_to_output_c8_int32(pre_out,
reinterpret_cast<int*>(dout_batch),
hout_c_block,
2,
c, c,
c + hout_c_block, c + hout_c_block,
h, h,
...@@ -734,8 +442,10 @@ void conv_3x3s2_direct_int8(const int8_t* din, ...@@ -734,8 +442,10 @@ void conv_3x3s2_direct_int8(const int8_t* din,
hout, hout,
wout, wout,
flag_relu, flag_relu,
ptr_write); bias_local,
} flag_bias,
ptr_write,
scale + c);
} }
} }
} }
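The templated write_int32_nchwc8_to_nchw call above folds the new float bias and the per-channel scale into the NCHWC8 to NCHW store, replacing the three write_to_output_c8_int32* branches. A minimal scalar sketch of that dequantize step, assuming a round-to-nearest policy saturating to [-127, 127] (the exact policy is the library's, not shown in this diff):

#include <algorithm>
#include <cmath>
#include <cstdint>

template <typename Dtype>
Dtype dequant_one(int32_t acc, float scale_c, float bias_c, bool relu);

template <>
float dequant_one<float>(int32_t acc, float scale_c, float bias_c, bool relu) {
  float v = static_cast<float>(acc) * scale_c + bias_c;  // dequant + float bias
  return relu ? std::max(v, 0.f) : v;                    // fused relu
}

template <>
int8_t dequant_one<int8_t>(int32_t acc, float scale_c, float bias_c, bool relu) {
  float v = static_cast<float>(acc) * scale_c + bias_c;
  if (relu) v = std::max(v, 0.f);
  v = std::min(std::max(v, -127.f), 127.f);   // assumed saturation range
  return static_cast<int8_t>(std::round(v));  // assumed rounding mode
}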
...@@ -743,8 +453,10 @@ void conv_3x3s2_direct_int8(const int8_t* din, ...@@ -743,8 +453,10 @@ void conv_3x3s2_direct_int8(const int8_t* din,
#else // __aarch64__ #else // __aarch64__
int conv_3x3s2_direct_int8_c_num() { return 4; } int conv_3x3s2_direct_int8_c_num() { return 4; }
template <typename Dtype>
void conv_3x3s2_direct_int8(const int8_t* din, void conv_3x3s2_direct_int8(const int8_t* din,
int32_t* dout, Dtype* dout,
int num, int num,
int chout, int chout,
int hout, int hout,
...@@ -753,27 +465,24 @@ void conv_3x3s2_direct_int8(const int8_t* din, ...@@ -753,27 +465,24 @@ void conv_3x3s2_direct_int8(const int8_t* din,
int hin, int hin,
int win, int win,
const int8_t* weights, const int8_t* weights,
const int32_t* bias, const float* bias,
const operators::ConvParam& param, const operators::ConvParam& param,
Context<TARGET(kARM)>* ctx, Context<TARGET(kARM)>* ctx,
PrecisionType out_type,
const float* scale) { const float* scale) {
//! 3x3s2 int8 convolution, implemented by direct algorithm //! 3x3s2 int8 convolution, implemented by direct algorithm
//! prepack input to tmp buffer //! prepack input to tmp buffer
//! write output to tmp buffer //! write output to tmp buffer
int threads = ctx->threads();
int stride_w = param.strides[1];
int pad_w = param.paddings[1];
int pad_h = param.paddings[0];
bool flag_relu = param.fuse_relu; bool flag_relu = param.fuse_relu;
bool flag_bias = (param.bias != nullptr); bool flag_bias = param.bias;
int pad_h = param.paddings[0];
//! set 2/3 l2 cache int pad_w = param.paddings[1];
int l2_size = ctx->llc_size() / 3 * 2; const int threads = ctx->threads();
//! set 1/4 l2 cache
int llc_size = ctx->llc_size() / 4;
const int hout_c_block = 4; const int hout_c_block = 4;
const int hout_r_kernel = 1; const int hout_r_kernel = 1;
const int wout_round = ((wout + 3) / 4) * 4; const int wout_round = ((wout + 3) / 4) * 4;
const int win_round = wout_round * stride_w + 1; const int win_round = wout_round * 2 /*stride_w*/ + 1;
//! get h block //! get h block
//! win_round * chin * hin_r_block * sizeof(int8_t) + wout_round * //! win_round * chin * hin_r_block * sizeof(int8_t) + wout_round *
...@@ -781,7 +490,7 @@ void conv_3x3s2_direct_int8(const int8_t* din, ...@@ -781,7 +490,7 @@ void conv_3x3s2_direct_int8(const int8_t* din,
//! win_round = 2 * wout_round + 1 //! win_round = 2 * wout_round + 1
//! hin_r_block = 2 * hout_r_block + 1 //! hin_r_block = 2 * hout_r_block + 1
int hout_r_block = int hout_r_block =
(l2_size - 2 * wout_round * chin - chin) / (llc_size - 2 * wout_round * chin - chin) /
((4 * wout_round + 2) * chin + wout_round * hout_c_block * threads * 4); ((4 * wout_round + 2) * chin + wout_round * hout_c_block * threads * 4);
hout_r_block = hout_r_block > hout ? hout : hout_r_block; hout_r_block = hout_r_block > hout ? hout : hout_r_block;
hout_r_block = (hout_r_block / hout_r_kernel) * hout_r_kernel; hout_r_block = (hout_r_block / hout_r_kernel) * hout_r_kernel;
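// Worked example of the block-height formula above, with assumed numbers:
// llc_size = 131072 (1/4 of a 512 KB L2), chin = 32, wout_round = 112,
// hout_c_block = 4, threads = 4:
//   hout_r_block = (131072 - 2 * 112 * 32 - 32)
//                / ((4 * 112 + 2) * 32 + 112 * 4 * 4 * 4)
//                = 123872 / 21568, i.e. 5 output rows per block.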
...@@ -789,16 +498,15 @@ void conv_3x3s2_direct_int8(const int8_t* din, ...@@ -789,16 +498,15 @@ void conv_3x3s2_direct_int8(const int8_t* din,
const int hin_r_block = hout_r_block * 2 + 1; const int hin_r_block = hout_r_block * 2 + 1;
int8_t* tmp_work_space = ctx->workspace_data<int8_t>(); auto tmp_work_space = ctx->workspace_data<int8_t>();
int zero_size = chout > (win_round + 3) / 4 ? chout : (win_round + 3) / 4; int zero_size = chout > (win_round + 3) / 4 ? chout : (win_round + 3) / 4;
const int kZeroSize = zero_size; int32_t ptr_zero[zero_size]; // NOLINT
int32_t ptr_zero[kZeroSize];
memset(ptr_zero, 0, sizeof(int32_t) * zero_size); memset(ptr_zero, 0, sizeof(int32_t) * zero_size);
const int kWoutRound = wout_round; Dtype ptr_write[wout_round]; // NOLINT
int32_t ptr_write[kWoutRound];
int in_len = win_round * chin; int in_len = win_round * chin;
int pre_in_size = hin_r_block * in_len; int pre_in_size = hin_r_block * in_len;
pre_in_size = ROUNDUP(pre_in_size, 4);
int pre_out_size = hout_c_block * hout_r_block * wout_round; int pre_out_size = hout_c_block * hout_r_block * wout_round;
//! l2_cache start //! l2_cache start
...@@ -815,10 +523,9 @@ void conv_3x3s2_direct_int8(const int8_t* din, ...@@ -815,10 +523,9 @@ void conv_3x3s2_direct_int8(const int8_t* din,
int out_row_stride = hout_c_block * wout_round; int out_row_stride = hout_c_block * wout_round;
for (int n = 0; n < num; ++n) { for (int n = 0; n < num; ++n) {
const int8_t* din_batch = din + n * chin * size_in_channel; const int8_t* din_batch =
int8_t* dout_batch = static_cast<const int8_t*>(din) + n * chin * size_in_channel;
reinterpret_cast<int8_t*>(dout) + auto dout_batch = dout + n * chout * size_out_channel;
n * chout * size_out_channel * PrecisionTypeLength(out_type);
for (int h = 0; h < hout; h += hout_r_block) { for (int h = 0; h < hout; h += hout_r_block) {
int h_kernel = hout_r_block; int h_kernel = hout_r_block;
if (h + hout_r_block > hout) { if (h + hout_r_block > hout) {
...@@ -845,24 +552,23 @@ void conv_3x3s2_direct_int8(const int8_t* din, ...@@ -845,24 +552,23 @@ void conv_3x3s2_direct_int8(const int8_t* din,
#pragma omp parallel for num_threads(threads) #pragma omp parallel for num_threads(threads)
for (int c = 0; c < chout; c += hout_c_block) { for (int c = 0; c < chout; c += hout_c_block) {
#ifdef ARM_WITH_OMP #ifdef ARM_WITH_OMP
int32_t* pre_out = int32_t* pre_out = reinterpret_cast<int*>(pre_din + pre_in_size) +
reinterpret_cast<int*>(pre_din + (pre_in_size + 3) / 4 * 4) + omp_get_thread_num() * pre_out_size;
omp_get_thread_num() * pre_out_size;
#else #else
int32_t* pre_out = int32_t* pre_out = reinterpret_cast<int32_t*>(pre_din + pre_in_size);
reinterpret_cast<int32_t*>(pre_din + (pre_in_size + 3) / 4 * 4);
#endif #endif
const int8_t* block_inr0 = cblock_inr0; const int8_t* block_inr0 = cblock_inr0;
const int8_t* block_inr1 = cblock_inr1; const int8_t* block_inr1 = cblock_inr1;
const int8_t* block_inr2 = cblock_inr2; const int8_t* block_inr2 = cblock_inr2;
const int8_t* weight_c = weights + c * w_stride; const int8_t* weight_c = weights + c * w_stride;
const int32_t* bias_ptr = ptr_zero; float bias_local[4] = {0, 0, 0, 0};
if (flag_bias) { if (flag_bias) {
bias_ptr = bias + c; bias_local[0] = bias[c];
bias_local[1] = bias[c + 1];
bias_local[2] = bias[c + 2];
bias_local[3] = bias[c + 3];
} }
memset(pre_out, 0, pre_out_size * sizeof(int32_t));
fill_packed_bias_nxmw_int8(bias_ptr, pre_out, 4, h_kernel, wout_round);
for (int hk = 0; hk < h_kernel; hk += hout_r_kernel) { for (int hk = 0; hk < h_kernel; hk += hout_r_kernel) {
const int8_t* wc0 = weight_c; const int8_t* wc0 = weight_c;
...@@ -879,134 +585,97 @@ void conv_3x3s2_direct_int8(const int8_t* din, ...@@ -879,134 +585,97 @@ void conv_3x3s2_direct_int8(const int8_t* din,
int32_t* ptr_out0 = pre_out0; int32_t* ptr_out0 = pre_out0;
const signed char* ptr_wc0 = wc0; const signed char* ptr_wc0 = wc0;
int cnt = w_loop; int cnt = w_loop;
// clang-format off
asm volatile( asm volatile(
"vld1.s32 {d0-d3}, [%[wc0]]! \n" /* w0-w7 */ "vld1.s32 {d0-d3}, [%[wc0]]! \n" /* w0-w7 */
"vld1.s32 {d4}, [%[wc0]]! \n" /* w8 */ "vld1.s32 {d4}, [%[wc0]]! \n" /* w8 */
"vmovl.s8 q3, d0 \n" /* q3 = w0, w1 */ "vmovl.s8 q3, d0 \n" /* q3 = w0, w1 */
"vmovl.s8 q4, d1 \n" /* q4 = w2 ,w3 */ "vmovl.s8 q4, d1 \n" /* q4 = w2 ,w3 */
"vmovl.s8 q5, d2 \n" /* q5 = w4, w5 */ "vmovl.s8 q5, d2 \n" /* q5 = w4, w5 */
"vmovl.s8 q6, d3 \n" /* q6 = w6, w7 */ "vmovl.s8 q6, d3 \n" /* q6 = w6, w7 */
"vmovl.s8 q7, d4 \n" /* q7 = w8 */ "vmovl.s8 q7, d4 \n" /* q7 = w8 */
"vld1.s32 {d0}, [%[r0]]! \n" /* load input r0 -> d0 */ "vld1.s32 {d0}, [%[r0]]! \n" /* load input r0, d0 */
"vmovl.s8 q0, d0 \n" /* movl d0 -> q0 */ "vmovl.s8 q0, d0 \n" /* movl d0 -> q0 */
"1: \n" /* main loop */ "1: \n" /* main loop */
/* r0 mul w0 */
/* r0 mul w0 */ "vmull.s16 q8, d6, d0[0] \n" /* q8 = w0 * r0[0] */
"vmull.s16 q8, d6, d0[0] \n" /* q8 = w0 * r0[0] */ "vmull.s16 q9, d6, d0[2] \n" /* q9 = w0 * r0[2] */
"vmull.s16 q9, d6, d0[2] \n" /* q9 = w0 * r0[2] */ "vmull.s16 q10, d6, d1[0] \n" /* q10 = w0 * r0[4] */
"vmull.s16 q10, d6, d1[0] \n" /* q10 = w0 * r0[4] */ "vmull.s16 q11, d6, d1[2] \n" /* q11 = w0 * r0[6] */
"vmull.s16 q11, d6, d1[2] \n" /* q11 = w0 * r0[6] */ "vld1.s32 {d2}, [%[r1]]! \n" /* load input r1, d2 */
"vmovl.s8 q1, d2 \n" /* movl d2 -> q1 */
"vld1.s32 {d2}, [%[r1]]! \n" /* load input r1 -> d2 */ /* r0 mul w1 */
"vmovl.s8 q1, d2 \n" /* movl d2 -> q1 */ "vmlal.s16 q8, d7, d0[1] \n" /* q8 = w1 * r0[1] */
"vmlal.s16 q9, d7, d0[3] \n" /* q9 = w1 * r0[3] */
/* r0 mul w1 */ "vmlal.s16 q10, d7, d1[1] \n" /* q10 = w1 * r0[5] */
"vmlal.s16 q8, d7, d0[1] \n" /* q8 = w1 * r0[1] */ "vmlal.s16 q11, d7, d1[3] \n" /* q11 = w1 * r0[7] */
"vmlal.s16 q9, d7, d0[3] \n" /* q9 = w1 * r0[3] */ "vld1.s32 {d4}, [%[r0]] \n" /* load r0[8] -> d4 */
"vmlal.s16 q10, d7, d1[1] \n" /* q10 = w1 * r0[5] */ "vmovl.s8 q2 , d4 \n" /* movl d4 -> q2 */
"vmlal.s16 q11, d7, d1[3] \n" /* q11 = w1 * r0[7] */ /* r0 mul w2 */
"vmlal.s16 q8, d8, d0[2] \n" /* q8 = w2 * r0[2] */
"vld1.s32 {d4}, [%[r0]] \n" /* load r0[8] -> d4 */ "vmlal.s16 q9, d8, d1[0] \n" /* q9 = w2 * r0[4] */
"vmovl.s8 q2 , d4 \n" /* movl d4 -> q2 */ "vmlal.s16 q10, d8, d1[2] \n" /* q10 = w2 * r0[6] */
"vmlal.s16 q11, d8, d4[0] \n" /* q11 = w2 * r0[8] */
/* r0 mul w2 */ "subs %[cnt], #1 \n" /* loop count -1 */
"vmlal.s16 q8, d8, d0[2] \n" /* q8 = w2 * r0[2] */ /* r1 mul w3 */
"vmlal.s16 q9, d8, d1[0] \n" /* q9 = w2 * r0[4] */ "vmlal.s16 q8, d9, d2[0] \n" /* q8 = w3 * r1[0] */
"vmlal.s16 q10, d8, d1[2] \n" /* q10 = w2 * r0[6] */ "vmlal.s16 q9, d9, d2[2] \n" /* q9 = w3 * r1[2] */
"vmlal.s16 q11, d8, d4[0] \n" /* q11 = w2 * r0[8] */ "vmlal.s16 q10, d9, d3[0] \n" /* q10 = w3 * r1[4] */
"vmlal.s16 q11, d9, d3[2] \n" /* q11 = w3 * r1[6] */
"subs %[cnt], #1 \n" /* loop count -1 */ "vld1.s32 {d4}, [%[r2]]! \n" /* load input r2, d4*/
"vmovl.s8 q2, d4 \n" /* movl d4 -> q2 */
/* r1 mul w3 */ /* r1 mul w4 */
"vmlal.s16 q8, d9, d2[0] \n" /* q8 = w3 * r1[0] */ "vmlal.s16 q8, d10, d2[1] \n" /* q8 = w4 * r1[1] */
"vmlal.s16 q9, d9, d2[2] \n" /* q9 = w3 * r1[2] */ "vmlal.s16 q9, d10, d2[3] \n" /* q9 = w4 * r1[3] */
"vmlal.s16 q10, d9, d3[0] \n" /* q10 = w3 * r1[4] */ "vmlal.s16 q10, d10, d3[1] \n" /* q10 = w4 * r1[5] */
"vmlal.s16 q11, d9, d3[2] \n" /* q11 = w3 * r1[6] */ "vmlal.s16 q11, d10, d3[3] \n" /* q11 = w4 * r1[7] */
"vld1.s32 {d0}, [%[r1]] \n" /* load r1[8] -> d0 */
"vld1.s32 {d4}, [%[r2]]! \n" /* load input r2 -> d4*/ "vmovl.s8 q0, d0 \n" /* movl d0 -> q0 */
"vmovl.s8 q2, d4 \n" /* movl d4 -> q2 */ /* r1 mul w5 */
"vmlal.s16 q8, d11, d2[2] \n" /* q8 = w5 * r1[2] */
/* r1 mul w4 */ "vmlal.s16 q9, d11, d3[0] \n" /* q9 = w5 * r1[4] */
"vmlal.s16 q8, d10, d2[1] \n" /* q8 = w4 * r1[1] */ "vmlal.s16 q10, d11, d3[2] \n" /* q10 = w5 * r1[6] */
"vmlal.s16 q9, d10, d2[3] \n" /* q9 = w4 * r1[3] */ "vmlal.s16 q11, d11, d0[0] \n" /* q11 = w5 * r1[8] */
"vmlal.s16 q10, d10, d3[1] \n" /* q10 = w4 * r1[5] */ /* r2 mul w6 */
"vmlal.s16 q11, d10, d3[3] \n" /* q11 = w4 * r1[7] */ "vmlal.s16 q8, d12, d4[0] \n" /* q8 = w6 * r2[0] */
"vmlal.s16 q9, d12, d4[2] \n" /* q9 = w6 * r2[2] */
"vld1.s32 {d0}, [%[r1]] \n" /* load r1[8] -> d0 */ "vmlal.s16 q10, d12, d5[0] \n" /* q10 = w6 * r2[4] */
"vmovl.s8 q0, d0 \n" /* movl d0 -> q0 */ "vmlal.s16 q11, d12, d5[2] \n" /* q11 = w6 * r2[6] */
"vld1.s32 {d24-d27}, [%[ptr_out0]] \n" /* load output, q12,q13 */
/* r1 mul w5 */ /* r2 mul w7 */
"vmlal.s16 q8, d11, d2[2] \n" /* q8 = w5 * r1[2] */ "vmlal.s16 q8, d13, d4[1] \n" /* q8 = w7 * r2[1] */
"vmlal.s16 q9, d11, d3[0] \n" /* q9 = w5 * r1[4] */ "vmlal.s16 q9, d13, d4[3] \n" /* q9 = w7 * r2[3] */
"vmlal.s16 q10, d11, d3[2] \n" /* q10 = w5 * r1[6] */ "vmlal.s16 q10, d13, d5[1] \n" /* q10 = w7 * r2[5] */
"vmlal.s16 q11, d11, d0[0] \n" /* q11 = w5 * r1[8] */ "vmlal.s16 q11, d13, d5[3] \n" /* q11 = w7 * r2[7] */
"vld1.s32 {d0}, [%[r2]] \n" /* load r2[8] -> d0 */
/* r2 mul w6 */ "vmovl.s8 q0, d0 \n" /* movl d0 -> q0 */
"vmlal.s16 q8, d12, d4[0] \n" /* q8 = w6 * r2[0] */ /* r2 mul w8 */
"vmlal.s16 q9, d12, d4[2] \n" /* q9 = w6 * r2[2] */ "vmlal.s16 q8, d14, d4[2] \n" /* q8 = w8 * r2[2] */
"vmlal.s16 q10, d12, d5[0] \n" /* q10 = w6 * r2[4] */ "vmlal.s16 q9, d14, d5[0] \n" /* q9 = w8 * r2[4] */
"vmlal.s16 q11, d12, d5[2] \n" /* q11 = w6 * r2[6] */ "vmlal.s16 q10, d14, d5[2] \n" /* q10 = w8 * r2[6] */
"vmlal.s16 q11, d14, d0[0] \n" /* q11 = w8 * r2[8] */
"vld1.s32 {d24-d27}, [%[ptr_out0]] \n" /* load output -> q12, "vadd.s32 q12, q8, q12 \n" /* out[0] += q8 */
q13 */ "vadd.s32 q13, q9, q13 \n" /* out[1] += q9 */
"vst1.s32 {d24-d27}, [%[ptr_out0]]! \n" /* store output[0,1]*/
/* r2 mul w7 */ "vld1.s32 {d0}, [%[r0]]! \n" /* load next input r0, d0*/
"vmlal.s16 q8, d13, d4[1] \n" /* q8 = w7 * r2[1] */ "vmovl.s8 q0, d0 \n" /* movl d0 -> q0 */
"vmlal.s16 q9, d13, d4[3] \n" /* q9 = w7 * r2[3] */ "vld1.s32 {d28-d31}, [%[ptr_out0]] \n" /* load output[0,1]*/
"vmlal.s16 q10, d13, d5[1] \n" /* q10 = w7 * r2[5] */ "vadd.s32 q14, q10, q14 \n" /* out[2] += q10 */
"vmlal.s16 q11, d13, d5[3] \n" /* q11 = w7 * r2[7] */ "vadd.s32 q15, q11, q15 \n" /* out[3] += q11 */
"vst1.s32 {d28-d31}, [%[ptr_out0]]! \n" /* store output[2,3] */
"vld1.s32 {d0}, [%[r2]] \n" /* load r2[8] -> d0 */ "bne 1b \n" /* jump to main loop */
"vmovl.s8 q0, d0 \n" /* movl d0 -> q0 */ : [cnt] "+r"(cnt),
[r0] "+r"(r0),
/* r2 mul w8 */ [r1] "+r"(r1),
"vmlal.s16 q8, d14, d4[2] \n" /* q8 = w8 * r2[2] */ [r2] "+r"(r2),
"vmlal.s16 q9, d14, d5[0] \n" /* q9 = w8 * r2[4] */ [ptr_out0] "+r"(ptr_out0),
"vmlal.s16 q10, d14, d5[2] \n" /* q10 = w8 * r2[6] */ [wc0] "+r"(ptr_wc0)
"vmlal.s16 q11, d14, d0[0] \n" /* q11 = w8 * r2[8] */ :
: "cc", "memory", "q0", "q1", "q2", "q3",
"vadd.s32 q12, q8, q12 \n" /* out[0] += q8 */ "q4", "q5", "q6", "q7", "q8", "q9",
"vadd.s32 q13, q9, q13 \n" /* out[1] += q9 */ "q10", "q11", "q12", "q13", "q14", "q15"
"vst1.s32 {d24-d27}, [%[ptr_out0]]! \n" /* store q12, q13 -> );
output[0,1] */ // clang-format on
"vld1.s32 {d0}, [%[r0]]! \n" /* load next input r0 -> d0*/
"vmovl.s8 q0, d0 \n" /* movl d0 -> q0 */
"vld1.s32 {d28-d31}, [%[ptr_out0]] \n" /* load output[0,1] ->
q14, q15 */
"vadd.s32 q14, q10, q14 \n" /* out[2] += q10 */
"vadd.s32 q15, q11, q15 \n" /* out[3] += q11 */
"vst1.s32 {d28-d31}, [%[ptr_out0]]! \n" /* store q14, q15 ->
output[2,3] */
"bne 1b \n" /* jump to main loop */
: [cnt] "+r"(cnt),
[r0] "+r"(r0),
[r1] "+r"(r1),
[r2] "+r"(r2),
[ptr_out0] "+r"(ptr_out0),
[wc0] "+r"(ptr_wc0)
:
: "cc",
"memory",
"q0",
"q1",
"q2",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
wc0 += 9 * hout_c_block; wc0 += 9 * hout_c_block;
inr0 += win_round; inr0 += win_round;
inr1 += win_round; inr1 += win_round;
...@@ -1016,47 +685,8 @@ void conv_3x3s2_direct_int8(const int8_t* din, ...@@ -1016,47 +685,8 @@ void conv_3x3s2_direct_int8(const int8_t* din,
block_inr1 = block_inr0 + in_len; block_inr1 = block_inr0 + in_len;
block_inr2 = block_inr1 + in_len; block_inr2 = block_inr1 + in_len;
} }
if (out_type == PRECISION(kFloat)) { write_int32_nchwc4_to_nchw(pre_out,
write_to_output_c4_int32_1(pre_out, dout_batch,
reinterpret_cast<float*>(dout_batch),
hout_c_block,
1,
c,
c + hout_c_block,
h,
h + h_kernel,
0,
wout_round,
chout,
hout,
wout,
flag_relu,
reinterpret_cast<float*>(ptr_write),
&scale[c],
out_type);
} else if (out_type == PRECISION(kInt8)) {
write_to_output_c4_int32_1(pre_out,
dout_batch,
hout_c_block,
1,
c,
c + hout_c_block,
h,
h + h_kernel,
0,
wout_round,
chout,
hout,
wout,
flag_relu,
reinterpret_cast<signed char*>(ptr_write),
&scale[c],
out_type);
} else {
write_to_output_c4_int32(pre_out,
reinterpret_cast<int*>(dout_batch),
hout_c_block,
1,
c, c,
c + hout_c_block, c + hout_c_block,
h, h,
...@@ -1067,14 +697,46 @@ void conv_3x3s2_direct_int8(const int8_t* din, ...@@ -1067,14 +697,46 @@ void conv_3x3s2_direct_int8(const int8_t* din,
hout, hout,
wout, wout,
flag_relu, flag_relu,
ptr_write); bias_local,
} flag_bias,
ptr_write,
scale + c);
} }
} }
} }
} }
#endif // __aarch64__ #endif // __aarch64__
template void conv_3x3s2_direct_int8(const int8_t* din,
float* dout,
int num,
int chout,
int hout,
int wout,
int chin,
int hin,
int win,
const int8_t* weights,
const float* bias,
const operators::ConvParam& param,
Context<TARGET(kARM)>* ctx,
const float* scale);
template void conv_3x3s2_direct_int8(const int8_t* din,
int8_t* dout,
int num,
int chout,
int hout,
int wout,
int chin,
int hin,
int win,
const int8_t* weights,
const float* bias,
const operators::ConvParam& param,
Context<TARGET(kARM)>* ctx,
const float* scale);
} // namespace math } // namespace math
} // namespace arm } // namespace arm
} // namespace lite } // namespace lite
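The explicit instantiations above let the header keep a single templated declaration while this .cc emits both output flavors; the old out_type parameter and the int32 output path disappear because the output precision is now a template argument. A self-contained miniature of the same pattern (names here are illustrative, not the library's):

#include <cstdint>
#include <cstdio>

template <typename Dtype>
void write_out(const int32_t* acc, Dtype* out, int n, float scale) {
  for (int i = 0; i < n; ++i) {
    out[i] = static_cast<Dtype>(static_cast<float>(acc[i]) * scale);
  }
}

// explicit instantiations, one per supported output precision
template void write_out<float>(const int32_t*, float*, int, float);
template void write_out<int8_t>(const int32_t*, int8_t*, int, float);

int main() {
  int32_t acc[4] = {10, -20, 30, -40};
  float f[4];
  int8_t q[4];
  write_out(acc, f, 4, 0.5f);
  write_out(acc, q, 4, 0.5f);
  printf("%.1f %d\n", f[1], static_cast<int>(q[2]));  // prints: -10.0 15
  return 0;
}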
......
...@@ -12,8 +12,8 @@ ...@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "lite/backends/arm/math/conv_depthwise.h"
#include <arm_neon.h> #include <arm_neon.h>
#include "lite/backends/arm/math/conv_depthwise.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -5073,7 +5073,7 @@ void conv_depthwise_5x5s1_small_impl(const float* din, ...@@ -5073,7 +5073,7 @@ void conv_depthwise_5x5s1_small_impl(const float* din,
int w_in_new = w_in + 2 * pad_new; int w_in_new = w_in + 2 * pad_new;
int h_out_new = h_out - 2 * pad_0; int h_out_new = h_out - 2 * pad_0;
int w_out_new = w_out - 2 * pad_0; int w_out_new = w_out - 2 * pad_0;
float zero_ptr[w_in_new + w_out]; float zero_ptr[w_in_new + w_out]; // NOLINT
memset(zero_ptr, 0, w_in_new * sizeof(float)); memset(zero_ptr, 0, w_in_new * sizeof(float));
float* write_ptr = zero_ptr + w_in_new; float* write_ptr = zero_ptr + w_in_new;
int pad_cnt = pad_0 >> 2; int pad_cnt = pad_0 >> 2;
...@@ -5320,7 +5320,7 @@ void conv_depthwise_5x5s1_small_relu_impl(const float* din, ...@@ -5320,7 +5320,7 @@ void conv_depthwise_5x5s1_small_relu_impl(const float* din,
int pad_0 = pad - pad_new; int pad_0 = pad - pad_new;
int h_in_new = h_in + 2 * pad_new; int h_in_new = h_in + 2 * pad_new;
int w_in_new = w_in + 2 * pad_new; int w_in_new = w_in + 2 * pad_new;
float zero_ptr[w_in_new + w_out]; float zero_ptr[w_in_new + w_out]; // NOLINT
memset(zero_ptr, 0, w_in_new * sizeof(float)); memset(zero_ptr, 0, w_in_new * sizeof(float));
float* write_ptr = zero_ptr + w_in_new; float* write_ptr = zero_ptr + w_in_new;
int h_out_new = h_out - 2 * pad_0; int h_out_new = h_out - 2 * pad_0;
...@@ -9177,7 +9177,7 @@ void conv_depthwise_5x5s1_small_impl(const float* din, ...@@ -9177,7 +9177,7 @@ void conv_depthwise_5x5s1_small_impl(const float* din,
int w_in_new = w_in + 2 * pad_new; int w_in_new = w_in + 2 * pad_new;
int h_out_new = h_out - 2 * pad_0; int h_out_new = h_out - 2 * pad_0;
int w_out_new = w_out - 2 * pad_0; int w_out_new = w_out - 2 * pad_0;
float zero_ptr[w_in_new + w_out]; float zero_ptr[w_in_new + w_out]; // NOLINT
memset(zero_ptr, 0, w_in_new * sizeof(float)); memset(zero_ptr, 0, w_in_new * sizeof(float));
float* write_ptr = zero_ptr + w_in_new; float* write_ptr = zero_ptr + w_in_new;
int pad_cnt = pad_0 >> 2; int pad_cnt = pad_0 >> 2;
...@@ -9359,7 +9359,7 @@ void conv_depthwise_5x5s1_small_relu_impl(const float* din, ...@@ -9359,7 +9359,7 @@ void conv_depthwise_5x5s1_small_relu_impl(const float* din,
int w_in_new = w_in + 2 * pad_new; int w_in_new = w_in + 2 * pad_new;
int h_out_new = h_out - 2 * pad_0; int h_out_new = h_out - 2 * pad_0;
int w_out_new = w_out - 2 * pad_0; int w_out_new = w_out - 2 * pad_0;
float zero_ptr[w_in_new + w_out]; float zero_ptr[w_in_new + w_out]; // NOLINT
memset(zero_ptr, 0, w_in_new * sizeof(float)); memset(zero_ptr, 0, w_in_new * sizeof(float));
float* write_ptr = zero_ptr + w_in_new; float* write_ptr = zero_ptr + w_in_new;
int pad_cnt = pad_0 >> 2; int pad_cnt = pad_0 >> 2;
...@@ -9523,21 +9523,21 @@ void conv_depthwise_5x5s1_small_relu_impl(const float* din, ...@@ -9523,21 +9523,21 @@ void conv_depthwise_5x5s1_small_relu_impl(const float* din,
} }
#endif // __aarch64__ #endif // __aarch64__
void conv_depthwise_5x5s1(const float* din, void conv_depthwise_5x5s1_fp32(const float* din,
float* dout, float* dout,
int num, int num,
int chout, int chout,
int hout, int hout,
int wout, int wout,
int chin, int chin,
int hin, int hin,
int win, int win,
const float* weights, const float* weights,
const float* bias, const float* bias,
int pad, int pad,
bool flag_bias, bool flag_bias,
bool flag_relu, bool flag_relu,
ARMContext* ctx) { ARMContext* ctx) {
if (win < 4) { if (win < 4) {
if (flag_relu) { if (flag_relu) {
conv_depthwise_5x5s1_small_relu_impl(din, conv_depthwise_5x5s1_small_relu_impl(din,
......
...@@ -12,8 +12,8 @@ ...@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "lite/backends/arm/math/conv_depthwise.h"
#include <arm_neon.h> #include <arm_neon.h>
#include "lite/backends/arm/math/conv_depthwise.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -80,21 +80,21 @@ void conv_depthwise_5x5s2p2_relu_s(const float* din, ...@@ -80,21 +80,21 @@ void conv_depthwise_5x5s2p2_relu_s(const float* din,
bool flag_relu, bool flag_relu,
ARMContext* ctx); ARMContext* ctx);
void conv_depthwise_5x5s2(const float* din, void conv_depthwise_5x5s2_fp32(const float* din,
float* dout, float* dout,
int num, int num,
int chout, int chout,
int hout, int hout,
int wout, int wout,
int chin, int chin,
int hin, int hin,
int win, int win,
const float* weights, const float* weights,
const float* bias, const float* bias,
int pad, int pad,
bool flag_bias, bool flag_bias,
bool flag_relu, bool flag_relu,
ARMContext* ctx) { ARMContext* ctx) {
if (pad == 2) { if (pad == 2) {
if (win >= 9) { if (win >= 9) {
if (flag_relu) { if (flag_relu) {
......
...@@ -15,7 +15,9 @@ ...@@ -15,7 +15,9 @@
#pragma once #pragma once
#include <arm_neon.h> #include <arm_neon.h>
#include <cmath> #include <cmath>
#include "lite/backends/arm/math/gemm_s8.h"
#include "lite/backends/arm/math/saturate.h" #include "lite/backends/arm/math/saturate.h"
#include "lite/backends/arm/math/sgemm.h"
#include "lite/backends/arm/math/type_trans.h" #include "lite/backends/arm/math/type_trans.h"
#include "lite/core/target_wrapper.h" #include "lite/core/target_wrapper.h"
#include "lite/utils/cp_logging.h" #include "lite/utils/cp_logging.h"
...@@ -26,6 +28,47 @@ namespace arm { ...@@ -26,6 +28,47 @@ namespace arm {
namespace math { namespace math {
#define LITEMAX(a, b) ((a) > (b) ? (a) : (b)) #define LITEMAX(a, b) ((a) > (b) ? (a) : (b))
#define ROUNDUP(a, b) ((((a) + (b)-1) / (b)) * (b))
template <PrecisionType Ptype>
inline void trans_gemm_weights(const Tensor& tin,
Tensor& tout, // NOLINT
int group,
ARMContext* ctx);
template <>
inline void trans_gemm_weights<PRECISION(kFloat)>(const Tensor& tin,
Tensor& tout, // NOLINT
int group,
ARMContext* ctx) {
CHECK_EQ(tin.dims().size(), 4) << "conv weights dims size must be 4";
int m = tin.dims()[0] / group;
int k = tin.dims().count(1, 4);
int hblock = lite::arm::math::get_hblock(ctx);
int m_roundup = hblock * ((m + hblock - 1) / hblock);
int group_size_round_up = ((m_roundup * k + 15) / 16) * 16;
float* w_trans_ptr = nullptr;
tout.Resize({group_size_round_up * group});
w_trans_ptr = tout.mutable_data<float>();
const auto* w_data = tin.data<float>();
for (int g = 0; g < group; ++g) {
const float* weights_group = w_data + g * m * k;
float* weights_trans_ptr = w_trans_ptr + g * group_size_round_up;
lite::arm::math::prepackA(
weights_trans_ptr, weights_group, 1.f, k, 0, m, 0, k, false, ctx);
}
}
template <>
inline void trans_gemm_weights<PRECISION(kInt8)>(const Tensor& tin,
Tensor& tout, // NOLINT
int group,
ARMContext* ctx) {
CHECK_EQ(tin.dims().size(), 4) << "conv weights dims size must be 4";
int m = tin.dims()[0] / group;
int k = tin.dims().count(1, 4);
prepackA_int8(&tout, tin, m, k, group, false, ctx);
}
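A worked instance of the padding arithmetic in trans_gemm_weights<PRECISION(kFloat)>, assuming hblock = 8 (the real value comes from get_hblock(ctx)):

#include <cstdio>

int main() {
  int m = 30, k = 27, group = 2, hblock = 8;                   // assumed shapes
  int m_roundup = hblock * ((m + hblock - 1) / hblock);        // 32
  int group_size_round_up = ((m_roundup * k + 15) / 16) * 16;  // 864
  printf("%d floats per group, %d in total\n",
         group_size_round_up,
         group_size_round_up * group);
  return 0;
}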
inline void fill_packed_biasc4(float* dout, const float* bias, int size) { inline void fill_packed_biasc4(float* dout, const float* bias, int size) {
float32x4_t vb = vld1q_f32(bias); float32x4_t vb = vld1q_f32(bias);
...@@ -159,6 +202,426 @@ static bool prepack_input_nxw(const dtype* din, ...@@ -159,6 +202,426 @@ static bool prepack_input_nxw(const dtype* din,
return true; return true;
} }
inline void transpose_4x4(float32x4_t v0,
float32x4_t v1,
float32x4_t v2,
float32x4_t v3,
float* dout) {
#ifdef __aarch64__
asm volatile(
"trn1 v0.4s, %[v0].4s, %[v1].4s\n" /* trans q0, q1, a0b0a2b2*/
"trn2 v1.4s, %[v0].4s, %[v1].4s\n" /* trans q0, q1, a1b1a3b3*/
"trn1 v2.4s, %[v2].4s, %[v3].4s\n" /* trans q2, q3, c0d0c2d2*/
"trn2 v3.4s, %[v2].4s, %[v3].4s\n" /* trans q2, q3, c1d1c3d3*/
"trn1 v4.2d, v0.2d, v2.2d\n" /* trans q0, q2, a0b0c0d0*/
"trn2 v6.2d, v0.2d, v2.2d\n" /* trans q0, q2, a2b2c2d2*/
"trn1 v5.2d, v1.2d, v3.2d\n" /* trans q1, q3, a1b1c1d1*/
"trn2 v7.2d, v1.2d, v3.2d\n" /* trans q1, q3, a3b3c3d3*/
"stp q4, q5, [%[dout]], #32\n"
"stp q6, q7, [%[dout]]\n"
: [dout] "+r"(dout)
: [v0] "w"(v0), [v1] "w"(v1), [v2] "w"(v2), [v3] "w"(v3)
: "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7");
#else
asm volatile(
"vtrn.32 %q[v0], %q[v1]\n" /* trans q0, q1, a0b0a2b2, a1b1a3b3*/
"vtrn.32 %q[v2], %q[v3]\n" /* trans q2, q3, c0d0c2d2, c1d1c3d3*/
"vswp %f[v0], %e[v2]\n" /* trans q0, q2, a0b0c0d0, a2b2c2d2*/
"vswp %f[v1], %e[v3]\n" /* trans q1, q3, a1b1c1d1, a3b3c3d3*/
"vst1.32 {%e[v0], %f[v0]}, [%[dout]]!\n"
"vst1.32 {%e[v1], %f[v1]}, [%[dout]]!\n"
"vst1.32 {%e[v2], %f[v2]}, [%[dout]]!\n"
"vst1.32 {%e[v3], %f[v3]}, [%[dout]]\n"
: [dout] "+r"(dout)
: [v0] "w"(v0), [v1] "w"(v1), [v2] "w"(v2), [v3] "w"(v3)
:);
#endif
}
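transpose_4x4 stores the 4x4 block transposed: dout[4 * i + j] receives lane i of input row j. A quick check sketch, assuming this header is included and the target supports the NEON intrinsics used above:

#include <arm_neon.h>
#include <cstdio>

int main() {
  float r[4][4] = {
      {0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 10, 11}, {12, 13, 14, 15}};
  float dout[16];
  transpose_4x4(
      vld1q_f32(r[0]), vld1q_f32(r[1]), vld1q_f32(r[2]), vld1q_f32(r[3]), dout);
  for (int i = 0; i < 16; ++i) printf("%.0f ", dout[i]);
  // expected: 0 4 8 12 1 5 9 13 2 6 10 14 3 7 11 15
  return 0;
}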
inline void prepack_input_nxwc4_dw(const float* din,
float* dout,
int cs,
int hs,
int he,
int ws,
int we,
int channel,
int width,
int height,
float* zero_ptr) {
int n = he - hs;
if (n <= 0) {
LOG(FATAL) << "prepack_dw_input, valid height must > zero";
}
float32x4_t vzero = vdupq_n_f32(0.f);
int size_w = we - ws;
int w0 = ws < 0 ? 0 : ws;
int w1 = we > width ? width : we;
int valid_w = w1 - w0;
int mask[4] = {0, 1, 2, 3};
int pad_l = ws < 0 ? -ws : 0;
int pad_r = we > width ? we - width : 0;
int cnt_l = pad_l / 4;
int left_remain = pad_l - cnt_l * 4;
bool flag_ext_l = left_remain > 0;
int left_sl = 4 - left_remain;
uint32x4_t vmask_padl;
bool flag_mask_l = false;
if (flag_ext_l) {
if (valid_w < 3) {
flag_mask_l = true;
vmask_padl = vcltq_s32(vld1q_s32(mask), vdupq_n_s32(valid_w));
}
valid_w -= left_sl;
valid_w = valid_w > 0 ? valid_w : 0;
}
int cnt_valid = valid_w / 4;
int valid_sl = valid_w - cnt_valid * 4;
bool flag_mask_valid = valid_sl > 0;
uint32x4_t vmask_valid;
if (flag_mask_valid) {
vmask_valid = vcltq_s32(vld1q_s32(mask), vdupq_n_s32(valid_sl));
pad_r -= 4 - valid_sl;
pad_r = pad_r > 0 ? pad_r : 0;
}
int size_c = width * height;
for (int h = hs; h < he; ++h) {
auto ptr_c0 = din + cs * size_c + h * width;
auto ptr_c1 = ptr_c0 + size_c;
auto ptr_c2 = ptr_c1 + size_c;
auto ptr_c3 = ptr_c2 + size_c;
if (h < 0 || h >= height) {
memset(dout, 0, sizeof(float) * size_w * 4);
dout += size_w * 4;
continue;
} else if (cs + 4 > channel) {
switch (cs + 4 - channel) {
case 3:
ptr_c1 = zero_ptr;
case 2:
ptr_c2 = zero_ptr;
case 1:
ptr_c3 = zero_ptr;
default:
break;
}
}
/// left padding
if (cnt_l > 0) {
memset(dout, 0, sizeof(float) * 16 * cnt_l);
dout += 16 * cnt_l;
}
/// left mask
if (flag_ext_l) {
float32x4_t vc0 = vld1q_f32(ptr_c0);
float32x4_t vc1 = vld1q_f32(ptr_c1);
float32x4_t vc2 = vld1q_f32(ptr_c2);
float32x4_t vc3 = vld1q_f32(ptr_c3);
if (flag_mask_l) {
vc0 = vbslq_f32(vmask_padl, vc0, vzero);
vc1 = vbslq_f32(vmask_padl, vc1, vzero);
vc2 = vbslq_f32(vmask_padl, vc2, vzero);
vc3 = vbslq_f32(vmask_padl, vc3, vzero);
}
switch (left_sl) {
case 1:
vc0 = vextq_f32(vzero, vc0, 1);
vc1 = vextq_f32(vzero, vc1, 1);
vc2 = vextq_f32(vzero, vc2, 1);
vc3 = vextq_f32(vzero, vc3, 1);
break;
case 2:
vc0 = vextq_f32(vzero, vc0, 2);
vc1 = vextq_f32(vzero, vc1, 2);
vc2 = vextq_f32(vzero, vc2, 2);
vc3 = vextq_f32(vzero, vc3, 2);
break;
case 3:
vc0 = vextq_f32(vzero, vc0, 3);
vc1 = vextq_f32(vzero, vc1, 3);
vc2 = vextq_f32(vzero, vc2, 3);
vc3 = vextq_f32(vzero, vc3, 3);
break;
default:
break;
}
transpose_4x4(vc0, vc1, vc2, vc3, dout);
dout += 16;
ptr_c0 += left_sl;
ptr_c1 += left_sl;
ptr_c2 += left_sl;
ptr_c3 += left_sl;
}
/// valid
for (int i = 0; i < cnt_valid; ++i) {
float32x4_t vc0 = vld1q_f32(ptr_c0);
float32x4_t vc1 = vld1q_f32(ptr_c1);
float32x4_t vc2 = vld1q_f32(ptr_c2);
float32x4_t vc3 = vld1q_f32(ptr_c3);
transpose_4x4(vc0, vc1, vc2, vc3, dout);
dout += 16;
ptr_c0 += 4;
ptr_c1 += 4;
ptr_c2 += 4;
ptr_c3 += 4;
}
if (flag_mask_valid) {
float32x4_t vc0 = vld1q_f32(ptr_c0);
float32x4_t vc1 = vld1q_f32(ptr_c1);
float32x4_t vc2 = vld1q_f32(ptr_c2);
float32x4_t vc3 = vld1q_f32(ptr_c3);
vc0 = vbslq_f32(vmask_valid, vc0, vzero);
vc1 = vbslq_f32(vmask_valid, vc1, vzero);
vc2 = vbslq_f32(vmask_valid, vc2, vzero);
vc3 = vbslq_f32(vmask_valid, vc3, vzero);
transpose_4x4(vc0, vc1, vc2, vc3, dout);
dout += 16;
}
/// right padding
if (pad_r > 0) {
memset(dout, 0, sizeof(float) * 4 * pad_r);
dout += 4 * pad_r;
}
}
}
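In scalar form, the c4 pre-pack above writes, for every (h, w) in the padded window, four consecutive floats holding channels cs .. cs + 3, zero-filled wherever h, w, or the channel index falls outside the tensor. A plain reference sketch (for reasoning about the layout, not a replacement for the NEON path):

inline void prepack_nxwc4_ref(const float* din, float* dout,
                              int cs, int hs, int he, int ws, int we,
                              int channel, int width, int height) {
  for (int h = hs; h < he; ++h) {
    for (int w = ws; w < we; ++w) {
      for (int j = 0; j < 4; ++j) {
        int c = cs + j;
        bool valid =
            (h >= 0 && h < height && w >= 0 && w < width && c < channel);
        *dout++ = valid ? din[(c * height + h) * width + w] : 0.f;
      }
    }
  }
}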
inline void prepack_input_nxw_c8_int8(const int8_t* din,
int8_t* dout,
int cs,
int ce,
int hs,
int he,
int ws,
int we,
int channel,
int width,
int height) {
int n = he - hs;
if (n <= 0) {
LOG(FATAL) << "prepack_input_nxw_c8 input height must > 0";
return;
}
int w0 = ws < 0 ? 0 : ws;
int w1 = we > width ? width : we;
int size_w = we - ws;
int size_channel_in = width * height;
int size_out_row = size_w * 8;
int valid_w = w1 - w0;
size_t valid_w_byte = valid_w * sizeof(int8_t);
auto ptr_c = static_cast<int8_t*>(TargetMalloc(TARGET(kARM), 8 * size_w));
int8_t* ptr_r[8];
int8_t* ptr_c_ori[8] = {ptr_c,
ptr_c + size_w,
ptr_c + 2 * size_w,
ptr_c + 3 * size_w,
ptr_c + 4 * size_w,
ptr_c + 5 * size_w,
ptr_c + 6 * size_w,
ptr_c + 7 * size_w};
int8_t zero_ptr[size_w * 2]; // NOLINT
memset(zero_ptr, 0, size_w * 2);
int loop = size_w / 8;
int remain = size_w - loop * 8;
for (int c = cs; c < ce; c += 8) {
auto din_c = din + c * size_channel_in;
for (int j = 0; j < 8; ++j) {
ptr_r[j] = ptr_c_ori[j];
}
//! valid channel
if (c + 8 > channel) {
switch (c + 8 - channel) {
case 7:
ptr_r[1] = zero_ptr;
case 6:
ptr_r[2] = zero_ptr;
case 5:
ptr_r[3] = zero_ptr;
case 4:
ptr_r[4] = zero_ptr;
case 3:
ptr_r[5] = zero_ptr;
case 2:
ptr_r[6] = zero_ptr;
case 1:
ptr_r[7] = zero_ptr;
default:
break;
}
}
//! valid height
int j = 0;
for (int i = hs; i < he; i++) {
auto din_r = din_c + i * width;
for (int k = 0; k < 8; ++k) {
if (ptr_r[k] != zero_ptr) {
if (i < 0 || i >= height) {
ptr_r[k] = zero_ptr + size_w;
} else {
ptr_r[k] = ptr_c_ori[k];
auto ptr = ptr_r[k];
for (int w = ws; w < w0; ++w) {
*(ptr++) = 0;
}
memcpy(ptr, din_r + k * size_channel_in, valid_w_byte);
ptr += valid_w;
for (int w = w1; w < we; ++w) {
*(ptr++) = 0;
}
}
}
}
int cnt = loop;
int8_t* inr0 = ptr_r[0];
int8_t* inr1 = ptr_r[1];
int8_t* inr2 = ptr_r[2];
int8_t* inr3 = ptr_r[3];
int8_t* inr4 = ptr_r[4];
int8_t* inr5 = ptr_r[5];
int8_t* inr6 = ptr_r[6];
int8_t* inr7 = ptr_r[7];
auto ptr_out = dout + j * size_out_row;
if (cnt > 0) {
#ifdef __aarch64__
asm volatile(
/* main loop */
"1:\n"
"ldr d0, [%[r0]], #8\n"
"ldr d1, [%[r1]], #8\n"
"ldr d2, [%[r2]], #8\n"
"ldr d3, [%[r3]], #8\n"
"ldr d4, [%[r4]], #8\n"
"ldr d5, [%[r5]], #8\n"
"ldr d6, [%[r6]], #8\n"
"ldr d7, [%[r7]], #8\n"
"trn1 v8.8b, v0.8b, v1.8b\n"
"trn2 v9.8b, v0.8b, v1.8b\n"
"trn1 v10.8b, v2.8b, v3.8b\n"
"trn2 v11.8b, v2.8b, v3.8b\n"
"trn1 v12.8b, v4.8b, v5.8b\n"
"trn2 v13.8b, v4.8b, v5.8b\n"
"trn1 v14.8b, v6.8b, v7.8b\n"
"trn2 v15.8b, v6.8b, v7.8b\n"
"trn1 v0.4h, v8.4h, v10.4h\n"
"trn2 v1.4h, v8.4h, v10.4h\n"
"trn1 v2.4h, v9.4h, v11.4h\n"
"trn2 v3.4h, v9.4h, v11.4h\n"
"trn1 v4.4h, v12.4h, v14.4h\n"
"trn2 v5.4h, v12.4h, v14.4h\n"
"trn1 v6.4h, v13.4h, v15.4h\n"
"trn2 v7.4h, v13.4h, v15.4h\n"
"trn1 v8.2s, v0.2s, v4.2s\n"
"trn1 v9.2s, v2.2s, v6.2s\n"
"trn1 v10.2s, v1.2s, v5.2s\n"
"trn1 v11.2s, v3.2s, v7.2s\n"
"stp d8, d9, [%[ptr_out]], #16\n"
"trn2 v12.2s, v0.2s, v4.2s\n"
"trn2 v13.2s, v2.2s, v6.2s\n"
"stp d10, d11, [%[ptr_out]], #16\n"
"trn2 v14.2s, v1.2s, v5.2s\n"
"trn2 v15.2s, v3.2s, v7.2s\n"
"subs %w[cnt], %w[cnt], #1\n"
"stp d12, d13, [%[ptr_out]], #16\n"
"stp d14, d15, [%[ptr_out]], #16\n"
"bne 1b\n"
: [cnt] "+r"(cnt),
[r0] "+r"(inr0),
[r1] "+r"(inr1),
[r2] "+r"(inr2),
[r3] "+r"(inr3),
[r4] "+r"(inr4),
[r5] "+r"(inr5),
[r6] "+r"(inr6),
[r7] "+r"(inr7),
[ptr_out] "+r"(ptr_out)
:
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15");
#else
asm volatile(
/* main loop */
"1:\n"
"vld1.32 {d0}, [%[r0]]!\n"
"vld1.32 {d1}, [%[r1]]!\n"
"vld1.32 {d2}, [%[r2]]!\n"
"vld1.32 {d3}, [%[r3]]!\n"
"vld1.32 {d4}, [%[r4]]!\n"
"vld1.32 {d5}, [%[r5]]!\n"
"vld1.32 {d6}, [%[r6]]!\n"
"vld1.32 {d7}, [%[r7]]!\n"
"vtrn.8 d0, d1\n"
"vtrn.8 d2, d3\n"
"vtrn.8 d4, d5\n"
"vtrn.8 d6, d7\n"
"vtrn.16 d0, d2\n"
"vtrn.16 d1, d3\n"
"vtrn.16 d4, d6\n"
"vtrn.16 d5, d7\n"
"vtrn.32 d0, d4\n"
"vtrn.32 d2, d6\n"
"vtrn.32 d1, d5\n"
"vtrn.32 d3, d7\n"
"subs %[cnt], #1\n"
"vst1.32 {d0-d3}, [%[ptr_out]]!\n"
"vst1.32 {d4-d7}, [%[ptr_out]]!\n"
"bne 1b\n"
: [cnt] "+r"(cnt),
[r0] "+r"(inr0),
[r1] "+r"(inr1),
[r2] "+r"(inr2),
[r3] "+r"(inr3),
[r4] "+r"(inr4),
[r5] "+r"(inr5),
[r6] "+r"(inr6),
[r7] "+r"(inr7),
[ptr_out] "+r"(ptr_out)
:
: "cc", "memory", "q0", "q1", "q2", "q3");
#endif // aarch64
}
for (int k = 0; k < remain; ++k) {
ptr_out[0] = *(inr0++);
ptr_out[1] = *(inr1++);
ptr_out[2] = *(inr2++);
ptr_out[3] = *(inr3++);
ptr_out[4] = *(inr4++);
ptr_out[5] = *(inr5++);
ptr_out[6] = *(inr6++);
ptr_out[7] = *(inr7++);
ptr_out += 8;
}
j++;
}
}
TargetFree(TARGET(kARM), ptr_c);
}
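// Reference sketch (not part of this patch): the trn/vtrn sequences above are
// an 8x8 byte transpose that turns eight per-channel rows into one
// channel-interleaved c8 row. Scalar equivalent, assuming the padded rows are
// already materialized as in ptr_r[]; hypothetical helper for illustration.
inline void prepack_c8_rows_ref(const int8_t* rows[8],
                                int8_t* out,
                                int size_w) {
  for (int x = 0; x < size_w; ++x) {
    for (int c = 0; c < 8; ++c) {
      out[x * 8 + c] = rows[c][x];  // channel becomes the fastest dimension
    }
  }
}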
/*write result in outputs
 * input din: [n, c, h, w], output dout: [n, c, h, w]
 */
...@@ -1195,1570 +1658,1144 @@ inline bool write_to_output_c8_fp32(const float* din,
  return true;
}
template <typename Dtype>
inline void int32_nchwc4_kernel(Dtype*& dout0,        // NOLINT
                                Dtype*& dout1,        // NOLINT
                                Dtype*& dout2,        // NOLINT
                                Dtype*& dout3,        // NOLINT
                                const int32_t*& din,  // NOLINT
                                int cnt,
                                float32x4_t scale,
                                float32x4_t bias,
                                bool is_relu);

#ifdef __aarch64__
#define NCHWC4_TRANS_INT32                   \
  "ldp q0, q1, [%[ptr_din]], #32\n"          \
  "ldp q2, q3, [%[ptr_din]], #32\n"          \
  "movi v20.4s, #0\n"                        \
  "1:\n"                                     \
  "trn1 v8.8b, v0.8b, v1.8b\n"               \
  "trn1 v8.4s, v0.4s, v1.4s\n"               \
  "trn2 v9.4s, v0.4s, v1.4s\n"               \
  "ldp q0, q1, [%[ptr_din]], #32\n"          \
  "trn1 v10.4s, v2.4s, v3.4s\n"              \
"trn2 v11.4s, v2.4s, v3.4s\n" \
"ldp q2, q3, [%[ptr_din]], #32\n" \
"trn1 v16.2d, v8.2d, v10.2d\n" \
"trn2 v17.2d, v8.2d, v10.2d\n" \
"trn1 v18.2d, v9.2d, v11.2d\n" \
"trn2 v19.2d, v9.2d, v11.2d\n" /* int32 --> fp32 */ \
"scvtf v4.4s, v16.4s\n" \
"scvtf v5.4s, v17.4s\n" \
"scvtf v6.4s, v18.4s\n" \
"scvtf v7.4s, v19.4s\n" /* add bias */ \
"dup v16.4s, %[bias].s[0]\n" \
"dup v17.4s, %[bias].s[2]\n" \
"dup v18.4s, %[bias].s[1]\n" \
"dup v19.4s, %[bias].s[3]\n" /* mul scale */ \
"fmla v16.4s, v4.4s, %[scale].s[0]\n" \
"fmla v17.4s, v5.4s, %[scale].s[2]\n" \
"fmla v18.4s, v6.4s, %[scale].s[1]\n" \
"fmla v19.4s, v7.4s, %[scale].s[3]\n" /* relu */ \
"cbz %w[relu], 2f\n" \
"fmax v16.4s, v16.4s, v20.4s \n" \
"fmax v17.4s, v17.4s, v20.4s \n" \
"fmax v18.4s, v18.4s, v20.4s \n" \
"fmax v19.4s, v19.4s, v20.4s \n" \
"2:\n"
#else
#define NCHWC4_TRANS_INT32 \
"vld1.32 {d4-d7}, [%[ptr_din]]!\n" \
"vld1.32 {d8-d11}, [%[ptr_din]]!\n" \
"vmov.u32 q15, #0\n" \
"1:\n" /* transpose */ \
"vtrn.32 q2, q3\n" \
"vtrn.32 q4, q5\n" \
"vswp.32 d5, d8\n" \
"vswp.32 d7, d10\n" /* int32-> fp32 */ \
"vcvt.f32.s32 q6, q2\n" \
"vcvt.f32.s32 q7, q3\n" \
"vcvt.f32.s32 q8, q4\n" \
"vcvt.f32.s32 q9, q5\n" /* add bias */ \
"vdup.32 q10, %e[bias][0]\n" \
"vdup.32 q11, %e[bias][1]\n" \
"vdup.32 q12, %f[bias][0]\n" \
"vdup.32 q13, %f[bias][1]\n" /* mul scale */ \
"vmla.f32 q10, q6, %e[scale][0]\n" \
"vmla.f32 q11, q7, %e[scale][1]\n" \
"vmla.f32 q12, q8, %f[scale][0]\n" \
"vmla.f32 q13, q9, %f[scale][1]\n" /* relu */ \
"cmp %[relu], #0\n" \
"beq 2f\n" \
"vmax.f32 q10, q10, q15\n" \
"vmax.f32 q11, q11, q15\n" \
"vmax.f32 q12, q12, q15\n" \
"vmax.f32 q13, q13, q15\n" \
"2:\n"
#endif
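// Reference sketch (not part of this patch): NCHWC4_TRANS_INT32 processes a
// 4x4 tile (4 spatial positions x 4 channels), transposes it so each vector
// holds one channel, then applies out = in * scale[c] + bias[c] with an
// optional relu. Scalar equivalent; hypothetical helper for illustration.
inline void nchwc4_dequant_tile_ref(const int32_t in[16],
                                    const float scale[4],
                                    const float bias[4],
                                    bool is_relu,
                                    float out[4][4]) {
  for (int x = 0; x < 4; ++x) {    // spatial position within the tile
    for (int c = 0; c < 4; ++c) {  // channel within the c4 block
      float v = in[x * 4 + c] * scale[c] + bias[c];
      out[c][x] = (is_relu && v < 0.f) ? 0.f : v;
    }
  }
}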
template <>
inline void int32_nchwc4_kernel(float*& dout0, // NOLINT
float*& dout1, // NOLINT
float*& dout2, // NOLINT
float*& dout3, // NOLINT
const int32_t*& din, // NOLINT
int cnt,
float32x4_t scale,
float32x4_t bias,
bool is_relu) {
#ifdef __aarch64__
asm volatile(NCHWC4_TRANS_INT32
"subs %w[cnt], %w[cnt], #1\n"
/* store result */
"str q16, [%[doutc0r0]], #16\n"
"str q17, [%[doutc2r0]], #16\n"
"str q18, [%[doutc1r0]], #16\n"
"str q19, [%[doutc3r0]], #16\n"
"bne 1b\n"
: [doutc0r0] "+r"(dout0),
[doutc1r0] "+r"(dout1),
[doutc2r0] "+r"(dout2),
[doutc3r0] "+r"(dout3),
[ptr_din] "+r"(din),
[cnt] "+r"(cnt)
: [scale] "w"(scale), [bias] "w"(bias), [relu] "r"(is_relu)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v31");
#else
asm volatile(NCHWC4_TRANS_INT32
"subs %[cnt], %[cnt], #1\n"
/* store result */
"vld1.32 {d4-d7}, [%[ptr_din]]!\n"
"vst1.32 {d20-d21}, [%[doutc0r0]]!\n"
"vst1.32 {d22-d23}, [%[doutc1r0]]!\n"
"vld1.32 {d8-d11}, [%[ptr_din]]!\n"
"vst1.32 {d24-d25}, [%[doutc2r0]]!\n"
"vst1.32 {d26-d27}, [%[doutc3r0]]!\n"
"bne 1b\n"
: [doutc0r0] "+r"(dout0),
[doutc1r0] "+r"(dout1),
[doutc2r0] "+r"(dout2),
[doutc3r0] "+r"(dout3),
[ptr_din] "+r"(din),
[cnt] "+r"(cnt)
: [scale] "w"(scale), [bias] "w"(bias), [relu] "r"(is_relu)
: "cc",
"memory",
"q2",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
#endif
}
template <>
inline void int32_nchwc4_kernel(int8_t*& dout0, // NOLINT
int8_t*& dout1, // NOLINT
int8_t*& dout2, // NOLINT
int8_t*& dout3, // NOLINT
const int32_t*& din, // NOLINT
int cnt,
float32x4_t scale,
float32x4_t bias,
bool is_relu) {
#ifdef __aarch64__
asm volatile(NCHWC4_TRANS_INT32
"subs %w[cnt], %w[cnt], #1\n"
/* fp32-int32 */
"fcvtas v4.4s, v16.4s\n"
"fcvtas v5.4s, v18.4s\n"
"fcvtas v6.4s, v17.4s\n"
"fcvtas v7.4s, v19.4s\n"
/* int32-int16 */
"sqxtn v8.4h, v4.4s\n"
"sqxtn v9.4h, v5.4s\n"
"sqxtn v10.4h, v6.4s\n"
"sqxtn v11.4h, v7.4s\n"
/* int16-int8 */
"sqxtn v16.8b, v8.8h\n"
"sqxtn v17.8b, v9.8h\n"
"sqxtn v18.8b, v10.8h\n"
"sqxtn v19.8b, v11.8h\n"
/* store result */
"str s16, [%[doutc0r0]], #4\n"
"str s17, [%[doutc1r0]], #4\n"
"str s18, [%[doutc2r0]], #4\n"
"str s19, [%[doutc3r0]], #4\n"
"bne 1b\n"
: [doutc0r0] "+r"(dout0),
[doutc1r0] "+r"(dout1),
[doutc2r0] "+r"(dout2),
[doutc3r0] "+r"(dout3),
[ptr_din] "+r"(din),
[cnt] "+r"(cnt)
: [scale] "w"(scale), [bias] "w"(bias), [relu] "r"(is_relu)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v31");
#else
asm volatile(NCHWC4_TRANS_INT32
/* set 0.5 offset */
"vmov.f32 q2, #0.5\n"
"vmov.f32 q14, #-0.5\n"
"vand.i32 q3, q2, q2 @ set offset, 0.5\n"
"vand.i32 q4, q2, q2 @ set offset, 0.5\n"
"vand.i32 q5, q2, q2 @ set offset, 0.5\n"
"vcgt.f32 q6, q10, q15 @ get mask > 0, in0\n"
"vcgt.f32 q7, q11, q15 @ get mask > 0, in1\n"
"vcgt.f32 q8, q12, q15 @ get mask > 0, in2\n"
"vcgt.f32 q9, q13, q15 @ get mask > 0, in3\n"
/* set 0.5 offset */
"vbif.f32 q2, q14, q6 @ get right offset\n"
"vbif.f32 q3, q14, q7 @ get right offset\n"
"vbif.f32 q4, q14, q8 @ get right offset\n"
"vbif.f32 q5, q14, q9 @ get right offset\n"
/* add offset */
"vadd.f32 q10, q2, q10\n"
"vadd.f32 q11, q3, q11\n"
"vadd.f32 q12, q4, q12\n"
"vadd.f32 q13, q5, q13\n"
/* fp32 to int32 */
"vcvt.s32.f32 q6, q10 @ cvt to int32\n"
"vcvt.s32.f32 q7, q11 @ cvt to int32\n"
"vcvt.s32.f32 q8, q12 @ cvt to int32\n"
"vcvt.s32.f32 q9, q13 @ cvt to int32\n"
/* int32 to int16 */
"vqmovn.s32 d20, q6 @ cnt to int16\n"
"vqmovn.s32 d22, q7 @ cnt to int16\n"
"vqmovn.s32 d24, q8 @ cnt to int16\n"
"vqmovn.s32 d26, q9 @ cnt to int16\n"
/* int16 to int8 */
"vqmovn.s16 d12, q10 @ cnt to int8\n"
"vqmovn.s16 d13, q11 @ cnt to int8\n"
"vqmovn.s16 d14, q12 @ cnt to int8\n"
"vqmovn.s16 d15, q13 @ cnt to int8\n"
"subs %[cnt], %[cnt], #1\n"
/* store */
"vld1.32 {d4-d7}, [%[ptr_din]]!\n"
"vst1.32 {d12[0]}, [%[doutc0r0]]!\n"
"vst1.32 {d13[0]}, [%[doutc1r0]]!\n"
"vld1.32 {d8-d11}, [%[ptr_din]]!\n"
"vst1.32 {d14[0]}, [%[doutc2r0]]!\n"
"vst1.32 {d15[0]}, [%[doutc3r0]]!\n"
"bne 1b @ jump to main loop\n"
: [doutc0r0] "+r"(dout0),
[doutc1r0] "+r"(dout1),
[doutc2r0] "+r"(dout2),
[doutc3r0] "+r"(dout3),
[ptr_din] "+r"(din),
[cnt] "+r"(cnt)
: [scale] "w"(scale), [bias] "w"(bias), [relu] "r"(is_relu)
: "cc",
"memory",
"q2",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
#endif
}
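// Rounding note (my reading of the asm above, not from this patch): the
// aarch64 path rounds with fcvtas (to nearest, ties away from zero), while
// the armv7 path emulates it by adding +0.5 or -0.5 depending on sign before
// the truncating vcvt. Scalar model; hypothetical helper for illustration.
inline int8_t round_saturate_int8_ref(float v) {
  float offset = v > 0.f ? 0.5f : -0.5f;         // sign-dependent offset (vbif)
  int32_t r = static_cast<int32_t>(v + offset);  // truncation (vcvt.s32.f32)
  if (r > 127) r = 127;                          // saturate (sqxtn / vqmovn)
  if (r < -128) r = -128;
  return static_cast<int8_t>(r);
}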
template <>
inline void int32_nchwc4_kernel(int32_t*& dout0, // NOLINT
int32_t*& dout1, // NOLINT
int32_t*& dout2, // NOLINT
int32_t*& dout3, // NOLINT
const int32_t*& din, // NOLINT
int cnt,
float32x4_t scale,
float32x4_t bias,
bool is_relu) {
#ifdef __aarch64__
asm volatile(
"ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */
"ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */
"movi v20.4s, #0 \n" /* for relu */
"1: \n" /* main loop*/
"trn1 v8.4s, v0.4s, v1.4s \n" /* trans q0, q1*/
"trn2 v9.4s, v0.4s, v1.4s \n" /* trans q0, q1*/
"ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */
"trn1 v10.4s, v2.4s, v3.4s \n" /* trans q2, q3*/
"trn2 v11.4s, v2.4s, v3.4s \n" /* trans q2, q3*/
"ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */
"trn1 v16.2d, v8.2d, v10.2d \n" /* trans q8, q10*/
"trn2 v17.2d, v8.2d, v10.2d \n" /* trans q8, q10*/
"trn1 v18.2d, v9.2d, v11.2d \n" /* trans q9, q11*/
"trn2 v19.2d, v9.2d, v11.2d \n" /* trans q9, q11*/
"cbz %w[relu], 2f\n"
"smax v16.4s, v16.4s, v20.4s \n" /* relu */
"smax v17.4s, v17.4s, v20.4s \n" /* relu */
"smax v18.4s, v18.4s, v20.4s \n" /* relu */
"smax v19.4s, v19.4s, v20.4s \n" /* relu */
"2:\n"
"str q16, [%[doutc0r0]], #16 \n" /* store c0r0*/
"str q17, [%[doutc2r0]], #16 \n" /* store c2r0*/
"str q18, [%[doutc1r0]], #16 \n" /* store c1r0*/
"str q19, [%[doutc3r0]], #16 \n" /* store c3r0*/
"subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/
"bne 1b \n" /* jump to main loop*/
: [doutc0r0] "+r"(dout0),
[doutc1r0] "+r"(dout1),
[doutc2r0] "+r"(dout2),
[doutc3r0] "+r"(dout3),
[ptr_din] "+r"(din),
[cnt] "+r"(cnt)
: [relu] "r"(is_relu)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20");
#else
asm volatile(
"vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n"
"vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n"
"vmov.u32 q15, #0 @ dump zero\n"
"1: @ main loop\n"
"vtrn.32 q0, q1 @ trans q0, q1 \n"
"vtrn.32 q2, q3 @ trans q2, q3 \n"
"vswp.32 d1, d4 @ swap d1, d4 \n"
"vswp.32 d3, d6 @ swap d3, d6 \n"
"cmp %[relu], #0\n"
"bne 2f\n"
"vmax.s32 q0, q0, q15 @ relu\n"
"vmax.s32 q1, q1, q15 @ relu\n"
"vmax.s32 q2, q2, q15 @ relu\n"
"vmax.s32 q3, q3, q15 @ relu\n"
"2:\n"
"vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add pointer\n"
"vst1.32 {d2-d3}, [%[doutc1r0]]! @ store result, add pointer\n"
"vst1.32 {d4-d5}, [%[doutc2r0]]! @ store result, add pointer\n"
"vst1.32 {d6-d7}, [%[doutc3r0]]! @ store result, add pointer\n"
"subs %[cnt], %[cnt], #1 @ loop count - 1\n"
"vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n"
"vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n"
"bne 1b @ jump to main loop\n"
: [doutc0r0] "+r"(dout0),
[doutc1r0] "+r"(dout1),
[doutc2r0] "+r"(dout2),
[doutc3r0] "+r"(dout3),
[ptr_din] "+r"(din),
[cnt] "+r"(cnt)
: [relu] "r"(is_relu)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q15");
#endif
}
template <typename Dtype>
inline Dtype cvt_kernel(int din, float scale, float bias, bool flag_relu);
template <>
inline float cvt_kernel(int din, float scale, float bias, bool flag_relu) {
if (flag_relu) {
return LITEMAX(din * scale + bias, 0);
  }
  return din * scale + bias;
}
template <>
inline int8_t cvt_kernel(int din, float scale, float bias, bool flag_relu) {
  if (flag_relu) {
    return saturate_cast<int8_t>(round(LITEMAX(din * scale + bias, 0)));
  }
  return saturate_cast<int8_t>(round(din * scale + bias));
}
template <>
inline int32_t cvt_kernel(int din, float scale, float bias, bool flag_relu) {
  if (flag_relu) {
    return LITEMAX(din, 0);
  }
  return din;
}
template <typename Dtype>
inline void write_int32_nchwc4_to_nchw(const int* din,
                                       Dtype* dout,
                                       int cs,
                                       int ce,
                                       int hs,
                                       int he,
                                       int ws,
                                       int we,
                                       int channel,
                                       int height,
                                       int width,
                                       bool flag_relu,
                                       float* bias,
                                       bool flag_bias,
                                       Dtype* trash_ptr,
                                       const float* scale) {
  int size_c_out = width * height;

  Dtype* doutc0r0 = dout + cs * size_c_out + hs * width + ws;
  Dtype* doutc1r0 = doutc0r0 + size_c_out;
  Dtype* doutc2r0 = doutc1r0 + size_c_out;
  Dtype* doutc3r0 = doutc2r0 + size_c_out;

  int size_h = (he > height ? height : he) - hs;  // size_h == hei_n
  int valid_w = we - ws;
  int cnt = valid_w / 4;
  float32x4_t w_scale = vld1q_f32(scale);
  float32x4_t w_bias = flag_bias ? vld1q_f32(bias) : vdupq_n_f32(0.f);

  if (we > width) {
    cnt--;
  }
  for (int i = 0; i < size_h; i++) {
    int size_w = i * width;
    Dtype* doutc0_ptr = doutc0r0 + size_w;
    Dtype* doutc1_ptr = doutc1r0 + size_w;
    Dtype* doutc2_ptr = doutc2r0 + size_w;
    Dtype* doutc3_ptr = doutc3r0 + size_w;
    if (ce > channel) {
      switch (ce - channel) {
        case 3:
          doutc1_ptr = trash_ptr;
        case 2:
          doutc2_ptr = trash_ptr;
        case 1:
          doutc3_ptr = trash_ptr;
        default:
          break;
      }
    }
    int index = i * valid_w * 4;
    const int* din_hei_ptr = din + index;
    if (cnt > 0) {
      int32_nchwc4_kernel<Dtype>(doutc0_ptr,
                                 doutc1_ptr,
                                 doutc2_ptr,
                                 doutc3_ptr,
                                 din_hei_ptr,
                                 cnt,
                                 w_scale,
                                 w_bias,
                                 flag_relu);
    }
    if (we > width) {
      int offset = 16 * (valid_w / 4 - 1);
      din_hei_ptr = din + index + offset;
      int j = we - 4;
      for (; j < width; ++j) {
        *(doutc0_ptr++) =
            cvt_kernel<Dtype>(din_hei_ptr[0], scale[0], bias[0], flag_relu);
        *(doutc1_ptr++) =
            cvt_kernel<Dtype>(din_hei_ptr[1], scale[1], bias[1], flag_relu);
        *(doutc2_ptr++) =
            cvt_kernel<Dtype>(din_hei_ptr[2], scale[2], bias[2], flag_relu);
        *(doutc3_ptr++) =
            cvt_kernel<Dtype>(din_hei_ptr[3], scale[3], bias[3], flag_relu);
        din_hei_ptr += 4;
      }
    }
  }
}
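// Usage sketch (shapes and names are hypothetical, not from this patch):
// dequantize one c4 block of int32 conv output back to NCHW fp32 with
// per-channel scale and float bias, relu fused into the store.
inline void demo_write_c4_block(const int* tile,
                                float* dout,
                                const float* scale,
                                float* bias,
                                int c,
                                int chout,
                                int hout,
                                int wout,
                                float* trash) {  // scratch row, >= wout floats
  write_int32_nchwc4_to_nchw<float>(tile, dout, c, c + 4, 0, hout, 0, wout,
                                    chout, hout, wout, /*flag_relu=*/true,
                                    bias + c, /*flag_bias=*/true, trash,
                                    scale + c);
}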
template <typename Dtype>
inline void int32_nchwc8_kernel(Dtype*& dout0, // NOLINT
Dtype*& dout1, // NOLINT
Dtype*& dout2, // NOLINT
Dtype*& dout3, // NOLINT
Dtype*& dout4, // NOLINT
Dtype*& dout5, // NOLINT
Dtype*& dout6, // NOLINT
Dtype*& dout7, // NOLINT
const int32_t*& din, // NOLINT
int cnt,
float32x4_t scale0,
float32x4_t scale1,
float32x4_t bias0,
float32x4_t bias1,
bool is_relu);
// clang-format off
#ifdef __aarch64__
#define INT32_NCHWC8_TO_NCHW_FP32                                  \
  "ldp q0, q1, [%[ptr_din]], #32\n" /* load r00, r01 to q0, q1 */  \
  "ldp q2, q3, [%[ptr_din]], #32\n" /* load r02, r03 to q2, q3 */  \
  "ldp q4, q5, [%[ptr_din]], #32\n" /* load r04, r05 to q4, q5 */  \
  "ldp q6, q7, [%[ptr_din]], #32\n" /* load r06, r07 to q6, q7 */  \
  "movi v31.4s, #0\n"                                              \
  "1:\n" /* main loop */                                           \
  "trn1 v8.4s, v0.4s, v2.4s\n"  /* trans q0, q2 */                 \
  "trn2 v9.4s, v0.4s, v2.4s\n"  /* trans q0, q2 */                 \
  "trn1 v10.4s, v1.4s, v3.4s\n" /* trans q1, q3 */                 \
  "trn2 v11.4s, v1.4s, v3.4s\n" /* trans q1, q3 */                 \
  "ldp q0, q1, [%[ptr_din]], #32\n"                                \
  "trn1 v12.4s, v4.4s, v6.4s\n" /* trans q4, q6 */                 \
  "trn2 v13.4s, v4.4s, v6.4s\n" /* trans q4, q6 */                 \
  "trn1 v14.4s, v5.4s, v7.4s\n" /* trans q5, q7 */                 \
  "trn2 v15.4s, v5.4s, v7.4s\n" /* trans q5, q7 */                 \
  "ldp q2, q3, [%[ptr_din]], #32\n"                                \
  "trn1 v16.2d, v8.2d, v12.2d\n"  /* 00 01 02 03 */                \
  "trn2 v17.2d, v8.2d, v12.2d\n"  /* 20 21 22 23 */                \
  "trn1 v18.2d, v9.2d, v13.2d\n"  /* 10 11 12 13 */                \
  "trn2 v19.2d, v9.2d, v13.2d\n"  /* 30 31 32 33 */                \
  "ldp q4, q5, [%[ptr_din]], #32\n"                                \
  "trn1 v8.2d, v10.2d, v14.2d\n"  /* 40 41 42 43 */                \
  "trn2 v9.2d, v10.2d, v14.2d\n"  /* 60 61 62 63 */                \
  "trn1 v12.2d, v11.2d, v15.2d\n" /* 50 51 52 53 */                \
  "trn2 v13.2d, v11.2d, v15.2d\n" /* 70 71 72 73 */                \
  "ldp q6, q7, [%[ptr_din]], #32\n"                                \
  /* int32 -> fp32 */                                              \
  "scvtf v10.4s, v16.4s\n"                                         \
  "scvtf v11.4s, v17.4s\n"                                         \
  "scvtf v14.4s, v18.4s\n"                                         \
  "scvtf v15.4s, v19.4s\n"                                         \
  /* add bias */                                                   \
  "dup v16.4s, %[bias0].s[0]\n"                                    \
  "dup v17.4s, %[bias0].s[2]\n"                                    \
  "dup v18.4s, %[bias0].s[1]\n"                                    \
  "dup v19.4s, %[bias0].s[3]\n"                                    \
  /* mul scale */                                                  \
  "fmla v16.4s, v10.4s, %[scale0].s[0]\n"                          \
  "fmla v17.4s, v11.4s, %[scale0].s[2]\n"                          \
  "fmla v18.4s, v14.4s, %[scale0].s[1]\n"                          \
  "fmla v19.4s, v15.4s, %[scale0].s[3]\n"                          \
  "scvtf v10.4s, v8.4s\n"                                          \
  "scvtf v11.4s, v9.4s\n"                                          \
  "scvtf v14.4s, v12.4s\n"                                         \
  "scvtf v15.4s, v13.4s\n"                                         \
  /* add bias */                                                   \
  "dup v8.4s, %[bias1].s[0]\n"                                     \
  "dup v9.4s, %[bias1].s[2]\n"                                     \
  "dup v12.4s, %[bias1].s[1]\n"                                    \
  "dup v13.4s, %[bias1].s[3]\n"                                    \
  /* mul scale */                                                  \
  "fmla v8.4s, v10.4s, %[scale1].s[0]\n"                           \
  "fmla v9.4s, v11.4s, %[scale1].s[2]\n"                           \
  "fmla v12.4s, v14.4s, %[scale1].s[1]\n"                          \
  "fmla v13.4s, v15.4s, %[scale1].s[3]\n"                          \
  /* relu */                                                       \
  "cbz %w[relu], 2f\n"                                             \
  "fmax v16.4s, v16.4s, v31.4s\n"                                  \
  "fmax v17.4s, v17.4s, v31.4s\n"                                  \
  "fmax v18.4s, v18.4s, v31.4s\n"                                  \
  "fmax v19.4s, v19.4s, v31.4s\n"                                  \
  "fmax v8.4s, v8.4s, v31.4s\n"                                    \
  "fmax v9.4s, v9.4s, v31.4s\n"                                    \
  "fmax v12.4s, v12.4s, v31.4s\n"                                  \
  "fmax v13.4s, v13.4s, v31.4s\n"                                  \
  "2:\n"
#else
#define INT32_NCHWC8_TO_NCHW_FP32                  \
  "1: @ main loop\n"                               \
  "vld1.32 {d0-d3}, [%[ptr_din]]!   @ load data\n" \
  "vld1.32 {d4-d7}, [%[ptr_din]]!   @ load data\n" \
  "vld1.32 {d8-d11}, [%[ptr_din]]!  @ load data\n" \
  "vld1.32 {d12-d15}, [%[ptr_din]]! @ load data\n" \
  /* int32 -> fp32 */                              \
  "vcvt.f32.s32 q8, q0\n"                          \
  "vcvt.f32.s32 q9, q1\n"                          \
  "vcvt.f32.s32 q10, q2\n"                         \
  "vcvt.f32.s32 q11, q3\n"                         \
  "vand.32 q0, %q[bias0], %q[bias0]\n"             \
  "vand.32 q1, %q[bias1], %q[bias1]\n"             \
  "vand.32 q2, %q[bias0], %q[bias0]\n"             \
  "vand.32 q3, %q[bias1], %q[bias1]\n"             \
  /* mul scale */                                  \
  "vmla.f32 q0, q8, %q[scale0]\n"                  \
  "vmla.f32 q1, q9, %q[scale1]\n"                  \
  "vmla.f32 q2, q10, %q[scale0]\n"                 \
  "vmla.f32 q3, q11, %q[scale1]\n"                 \
  /* int32 -> fp32 */                              \
  "vcvt.f32.s32 q8, q4\n"                          \
  "vcvt.f32.s32 q9, q5\n"                          \
  "vcvt.f32.s32 q10, q6\n"                         \
  "vcvt.f32.s32 q11, q7\n"                         \
  "vand.32 q4, %q[bias0], %q[bias0]\n"             \
  "vand.32 q5, %q[bias1], %q[bias1]\n"             \
  "vand.32 q6, %q[bias0], %q[bias0]\n"             \
  "vand.32 q7, %q[bias1], %q[bias1]\n"             \
  /* mul scale */                                  \
  "vmla.f32 q4, q8, %q[scale0]\n"                  \
  "vmla.f32 q5, q9, %q[scale1]\n"                  \
  "vmla.f32 q6, q10, %q[scale0]\n"                 \
  "vmla.f32 q7, q11, %q[scale1]\n"                 \
  /* transpose */                                  \
  "vtrn.32 q0, q2\n"                               \
  "vtrn.32 q1, q3\n"                               \
  "vtrn.32 q4, q6\n"                               \
  "vtrn.32 q5, q7\n"                               \
  "vswp d1, d8\n"  /* q0: a0-a3, q4: c0-c3 */      \
  "vswp d5, d12\n" /* q2: b0-b3, q6: d0-d3 */      \
  "vswp d3, d10\n" /* q1: e0-e3, q5: g0-g3 */      \
  "vswp d7, d14\n" /* q3: f0-f3, q7: h0-h3 */      \
  /* relu */                                       \
  "vmov.i32 q8, #0\n"                              \
  "cmp %[relu], #0\n"                              \
  "beq 2f\n"                                       \
  "vmax.f32 q0, q0, q8\n"                          \
  "vmax.f32 q2, q2, q8\n"                          \
  "vmax.f32 q4, q4, q8\n"                          \
  "vmax.f32 q6, q6, q8\n"                          \
  "vmax.f32 q1, q1, q8\n"                          \
  "vmax.f32 q3, q3, q8\n"                          \
  "vmax.f32 q5, q5, q8\n"                          \
  "vmax.f32 q7, q7, q8\n"                          \
  "2:\n"
#endif
// clang-format on
[doutc3r0] "+r"(doutc3_ptr), template <>
[cnt] "+r"(cnt_loop), inline void int32_nchwc8_kernel(float*& dout0, // NOLINT
[ptr_din] "+r"(din_hei_ptr) float*& dout1, // NOLINT
: float*& dout2, // NOLINT
: "v0", float*& dout3, // NOLINT
"v1", float*& dout4, // NOLINT
"v2", float*& dout5, // NOLINT
"v3", float*& dout6, // NOLINT
"v4", float*& dout7, // NOLINT
"v5", const int32_t*& din, // NOLINT
"v6", int cnt,
"v7", float32x4_t scale0,
"v8", float32x4_t scale1,
"v9", float32x4_t bias0,
"v10", float32x4_t bias1,
"v11", bool is_relu) {
"v12", #ifdef __aarch64__
"v13", asm volatile(INT32_NCHWC8_TO_NCHW_FP32
"v14", "subs %w[cnt], %w[cnt], #1\n" /* loop count -1*/
"v15", "str q16, [%[doutc0r0]], #16\n" /* store c0r0*/
"v16", "str q17, [%[doutc2r0]], #16\n" /* store c2r0*/
"v17", "str q18, [%[doutc1r0]], #16\n" /* store c1r0*/
"v18", "str q19, [%[doutc3r0]], #16\n" /* store c3r0*/
"v19", "str q8, [%[doutc4r0]], #16\n" /* store c4r0*/
"v20"); "str q9, [%[doutc6r0]], #16\n" /* store c6r0*/
"str q12, [%[doutc5r0]], #16\n" /* store c5r0*/
"str q13, [%[doutc7r0]], #16\n" /* store c7r0*/
"bne 1b\n" /* jump to main loop*/
: [doutc0r0] "+r"(dout0),
[doutc1r0] "+r"(dout1),
[doutc2r0] "+r"(dout2),
[doutc3r0] "+r"(dout3),
[doutc4r0] "+r"(dout4),
[doutc5r0] "+r"(dout5),
[doutc6r0] "+r"(dout6),
[doutc7r0] "+r"(dout7),
[ptr_din] "+r"(din),
[cnt] "+r"(cnt)
: [scale0] "w"(scale0),
[scale1] "w"(scale1),
[bias0] "w"(bias0),
[bias1] "w"(bias1),
[relu] "r"(is_relu)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v31");
#else #else
asm volatile( asm volatile(INT32_NCHWC8_TO_NCHW_FP32
"vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" "subs %[cnt], #1\n" /* loop count -1*/
"vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" "vst1.32 {d0-d1}, [%[doutc0r0]]!\n" /* store c0r0*/
"vmov.u32 q15, #0 @ dump zero\n" "vst1.32 {d4-d5}, [%[doutc1r0]]!\n" /* store c0r0*/
"1: @ main loop\n" "vst1.32 {d8-d9}, [%[doutc2r0]]!\n" /* store c0r0*/
"vtrn.32 q0, q1 @ trans q0, q1 \n" "vst1.32 {d12-d13}, [%[doutc3r0]]!\n" /* store c0r0*/
"vtrn.32 q2, q3 @ trans q2, q3 \n" "vst1.32 {d2-d3}, [%[doutc4r0]]!\n" /* store c0r0*/
"vswp.32 d1, d4 @ swap d1, d4 \n" "vst1.32 {d6-d7}, [%[doutc5r0]]!\n" /* store c0r0*/
"vswp.32 d3, d6 @ swap d3, d6 \n" "vst1.32 {d10-d11}, [%[doutc6r0]]!\n" /* store c0r0*/
"vst1.32 {d14-d15}, [%[doutc7r0]]!\n" /* store c0r0*/
"vmax.s32 q0, q0, q15 @ relu\n" "bne 1b\n" /* jump to main loop*/
"vmax.s32 q1, q1, q15 @ relu\n" : [doutc0r0] "+r"(dout0),
"vmax.s32 q2, q2, q15 @ relu\n" [doutc1r0] "+r"(dout1),
"vmax.s32 q3, q3, q15 @ relu\n" [doutc2r0] "+r"(dout2),
[doutc3r0] "+r"(dout3),
[doutc4r0] "+r"(dout4),
[doutc5r0] "+r"(dout5),
[doutc6r0] "+r"(dout6),
[doutc7r0] "+r"(dout7),
[ptr_din] "+r"(din),
[cnt] "+r"(cnt)
: [scale0] "w"(scale0),
[scale1] "w"(scale1),
[bias0] "w"(bias0),
[bias1] "w"(bias1),
[relu] "r"(is_relu)
: "cc",
"memory",
"q0",
"q1",
"q2",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11");
#endif
}
"vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add pointer\n" template <>
"vst1.32 {d2-d3}, [%[doutc1r0]]! @ store result, add pointer\n" inline void int32_nchwc8_kernel(int8_t*& dout0, // NOLINT
"vst1.32 {d4-d5}, [%[doutc2r0]]! @ store result, add pointer\n" int8_t*& dout1, // NOLINT
"vst1.32 {d6-d7}, [%[doutc3r0]]! @ store result, add pointer\n" int8_t*& dout2, // NOLINT
int8_t*& dout3, // NOLINT
int8_t*& dout4, // NOLINT
int8_t*& dout5, // NOLINT
int8_t*& dout6, // NOLINT
int8_t*& dout7, // NOLINT
const int32_t*& din, // NOLINT
int cnt,
float32x4_t scale0,
float32x4_t scale1,
float32x4_t bias0,
float32x4_t bias1,
bool is_relu) {
#ifdef __aarch64__
asm volatile(INT32_NCHWC8_TO_NCHW_FP32 /* fp32-int32 */
"fcvtas v10.4s, v16.4s\n"
"fcvtas v11.4s, v17.4s\n"
"fcvtas v14.4s, v18.4s\n"
"fcvtas v15.4s, v19.4s\n"
"fcvtas v20.4s, v8.4s\n"
"fcvtas v21.4s, v9.4s\n"
"fcvtas v22.4s, v12.4s\n"
"fcvtas v23.4s, v13.4s\n"
/* int32-int16 */
"sqxtn v16.4h, v10.4s\n"
"sqxtn v17.4h, v11.4s\n"
"sqxtn v18.4h, v14.4s\n"
"sqxtn v19.4h, v15.4s\n"
"sqxtn v8.4h, v20.4s\n"
"sqxtn v9.4h, v21.4s\n"
"sqxtn v12.4h, v22.4s\n"
"sqxtn v13.4h, v23.4s\n"
/* int16-int8 */
"sqxtn v10.8b, v16.8h\n"
"sqxtn v11.8b, v17.8h\n"
"sqxtn v14.8b, v18.8h\n"
"sqxtn v15.8b, v19.8h\n"
"sqxtn v20.8b, v8.8h\n"
"sqxtn v21.8b, v9.8h\n"
"sqxtn v22.8b, v12.8h\n"
"sqxtn v23.8b, v13.8h\n"
"str s10, [%[doutc0r0]], #4 \n" /* store c0r0*/
"str s11, [%[doutc2r0]], #4 \n" /* store c2r0*/
"str s14, [%[doutc1r0]], #4 \n" /* store c1r0*/
"str s15, [%[doutc3r0]], #4 \n" /* store c3r0*/
"subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/
"str s20, [%[doutc4r0]], #4 \n" /* store c0r0*/
"str s21, [%[doutc6r0]], #4 \n" /* store c2r0*/
"str s22, [%[doutc5r0]], #4 \n" /* store c1r0*/
"str s23, [%[doutc7r0]], #4 \n" /* store c3r0*/
"bne 1b \n" /* jump to main loop*/
: [doutc0r0] "+r"(dout0),
[doutc1r0] "+r"(dout1),
[doutc2r0] "+r"(dout2),
[doutc3r0] "+r"(dout3),
[doutc4r0] "+r"(dout4),
[doutc5r0] "+r"(dout5),
[doutc6r0] "+r"(dout6),
[doutc7r0] "+r"(dout7),
[ptr_din] "+r"(din),
[cnt] "+r"(cnt)
: [scale0] "w"(scale0),
[scale1] "w"(scale1),
[bias0] "w"(bias0),
[bias1] "w"(bias1),
[relu] "r"(is_relu)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"v23",
"v31");
#else
asm volatile(INT32_NCHWC8_TO_NCHW_FP32 /* set +-0.5 offset */
"vmov.f32 q10, #-0.5\n"
"vmov.f32 q9, #0.5\n"
"vcgt.f32 q11, q0, q8 @ get mask > 0, in0\n"
"vbif.f32 q9, q10, q11 @ get right offset\n"
"vadd.f32 q0, q0, q9\n"
"vmov.f32 q9, #0.5\n"
"vcgt.f32 q11, q2, q8 @ get mask > 0, in0\n"
"vbif.f32 q9, q10, q11 @ get right offset\n"
"vadd.f32 q2, q2, q9\n"
"vmov.f32 q9, #0.5\n"
"vcgt.f32 q11, q4, q8 @ get mask > 0, in0\n"
"vbif.f32 q9, q10, q11 @ get right offset\n"
"vadd.f32 q4, q4, q9\n"
"vmov.f32 q9, #0.5\n"
"vcgt.f32 q11, q6, q8 @ get mask > 0, in0\n"
"vbif.f32 q9, q10, q11 @ get right offset\n"
"vadd.f32 q6, q6, q9\n"
"vmov.f32 q9, #0.5\n"
"vcgt.f32 q11, q1, q8 @ get mask > 0, in0\n"
"vbif.f32 q9, q10, q11 @ get right offset\n"
"vadd.f32 q1, q1, q9\n"
"vmov.f32 q9, #0.5\n"
"vcgt.f32 q11, q3, q8 @ get mask > 0, in0\n"
"vbif.f32 q9, q10, q11 @ get right offset\n"
"vadd.f32 q3, q3, q9\n"
"vmov.f32 q9, #0.5\n"
"vcgt.f32 q11, q5, q8 @ get mask > 0, in0\n"
"vbif.f32 q9, q10, q11 @ get right offset\n"
"vadd.f32 q5, q5, q9\n"
"vmov.f32 q9, #0.5\n"
"vcgt.f32 q11, q7, q8 @ get mask > 0, in0\n"
"vbif.f32 q9, q10, q11 @ get right offset\n"
"vadd.f32 q7, q7, q9\n"
/* fp32 to int32 */
"vcvt.s32.f32 q8, q0 @ cvt to int32\n"
"vcvt.s32.f32 q9, q2 @ cvt to int32\n"
"vcvt.s32.f32 q10, q4 @ cvt to int32\n"
"vcvt.s32.f32 q11, q6 @ cvt to int32\n"
/* int32 to int16 */
"vqmovn.s32 d0, q8 @ cnt to int16\n"
"vqmovn.s32 d4, q9 @ cnt to int16\n"
"vqmovn.s32 d8, q10 @ cnt to int16\n"
"vqmovn.s32 d12, q11 @ cnt to int16\n"
/* fp32 to int32 */
"vcvt.s32.f32 q8, q1 @ cvt to int32\n"
"vcvt.s32.f32 q9, q3 @ cvt to int32\n"
"vcvt.s32.f32 q10, q5 @ cvt to int32\n"
"vcvt.s32.f32 q11, q7 @ cvt to int32\n"
/* int32 to int16 */
"vqmovn.s32 d2, q8 @ cnt to int16\n"
"vqmovn.s32 d6, q9 @ cnt to int16\n"
"vqmovn.s32 d10, q10 @ cnt to int16\n"
"vqmovn.s32 d14, q11 @ cnt to int16\n"
/* int16 to int8 */
"vqmovn.s16 d16, q0 @ cnt to int8\n"
"vqmovn.s16 d17, q2 @ cnt to int8\n"
"vqmovn.s16 d18, q4 @ cnt to int8\n"
"vqmovn.s16 d19, q6 @ cnt to int8\n"
"vst1.32 {d16[0]}, [%[doutc0r0]]!\n"
"vqmovn.s16 d20, q1 @ cnt to int8\n"
"vst1.32 {d17[0]}, [%[doutc1r0]]!\n"
"vqmovn.s16 d21, q3 @ cnt to int8\n"
"vst1.32 {d18[0]}, [%[doutc2r0]]!\n"
"vqmovn.s16 d22, q5 @ cnt to int8\n"
"vst1.32 {d19[0]}, [%[doutc3r0]]!\n"
"vqmovn.s16 d23, q7 @ cnt to int8\n"
"subs %[cnt], #1\n"
"vst1.32 {d20[0]}, [%[doutc4r0]]!\n"
"vst1.32 {d21[0]}, [%[doutc5r0]]!\n"
"vst1.32 {d22[0]}, [%[doutc6r0]]!\n"
"vst1.32 {d23[0]}, [%[doutc7r0]]!\n"
"bne 1b\n" /* jump to main loop*/
: [doutc0r0] "+r"(dout0),
[doutc1r0] "+r"(dout1),
[doutc2r0] "+r"(dout2),
[doutc3r0] "+r"(dout3),
[doutc4r0] "+r"(dout4),
[doutc5r0] "+r"(dout5),
[doutc6r0] "+r"(dout6),
[doutc7r0] "+r"(dout7),
[ptr_din] "+r"(din),
[cnt] "+r"(cnt)
: [scale0] "w"(scale0),
[scale1] "w"(scale1),
[bias0] "w"(bias0),
[bias1] "w"(bias1),
[relu] "r"(is_relu)
: "cc",
"memory",
"q0",
"q1",
"q2",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11");
#endif
}
"subs %[cnt], %[cnt], #1 @ loop count - 1\n" template <>
inline void int32_nchwc8_kernel(int32_t*& dout0, // NOLINT
int32_t*& dout1, // NOLINT
int32_t*& dout2, // NOLINT
int32_t*& dout3, // NOLINT
int32_t*& dout4, // NOLINT
int32_t*& dout5, // NOLINT
int32_t*& dout6, // NOLINT
int32_t*& dout7, // NOLINT
const int32_t*& din, // NOLINT
int cnt,
float32x4_t scale0,
float32x4_t scale1,
float32x4_t bias0,
float32x4_t bias1,
bool is_relu) {
#ifdef __aarch64__
asm volatile(
"ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */
"ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */
"ldp q4, q5, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */
"ldp q6, q7, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */
"movi v20.4s, #0 \n" /* for relu */
"1: \n" /* main loop*/
"trn1 v8.4s, v0.4s, v2.4s \n" /* trans q0, q1*/
"trn2 v9.4s, v0.4s, v2.4s \n" /* trans q0, q1*/
"trn1 v10.4s, v1.4s, v3.4s \n" /* trans q2, q3*/
"trn2 v11.4s, v1.4s, v3.4s \n" /* trans q2, q3*/
"ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */
"trn1 v12.4s, v4.4s, v6.4s \n" /* trans q0, q1*/
"trn2 v13.4s, v4.4s, v6.4s \n" /* trans q0, q1*/
"trn1 v14.4s, v5.4s, v7.4s \n" /* trans q2, q3*/
"trn2 v15.4s, v5.4s, v7.4s \n" /* trans q2, q3*/
"ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */
"trn1 v16.2d, v8.2d, v12.2d \n" /* trans q8, q10 00 01 02 03*/
"trn2 v17.2d, v8.2d, v12.2d \n" /* trans q8, q10 20 21 22 23*/
"trn1 v18.2d, v9.2d, v13.2d \n" /* trans q9, q11 10 11 12 13*/
"trn2 v19.2d, v9.2d, v13.2d \n" /* trans q9, q11 30 31 32 33*/
"ldp q4, q5, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */
"trn1 v8.2d, v10.2d, v14.2d \n" /* trans q8, q10 40 41 42 43*/
"trn2 v9.2d, v10.2d, v14.2d \n" /* trans q8, q10 60 61 62 63*/
"trn1 v12.2d, v11.2d, v15.2d \n" /* trans q9, q11 50 51 52 53*/
"trn2 v13.2d, v11.2d, v15.2d \n" /* trans q9, q11 70 71 72 73*/
"ldp q6, q7, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */
"cbz %w[relu], 2f\n"
"smax v16.4s, v16.4s, v20.4s \n" /*relu*/
"smax v17.4s, v17.4s, v20.4s \n" /*relu*/
"smax v18.4s, v18.4s, v20.4s \n" /*relu*/
"smax v19.4s, v19.4s, v20.4s \n" /*relu*/
"smax v8.4s, v8.4s, v20.4s \n" /*relu*/
"smax v9.4s, v9.4s, v20.4s \n" /*relu*/
"smax v12.4s, v12.4s, v20.4s \n" /*relu*/
"smax v13.4s, v13.4s, v20.4s \n" /*relu*/
"2:\n"
"str q16, [%[doutc0r0]], #16 \n" /* store c0r0*/
"str q17, [%[doutc2r0]], #16 \n" /* store c2r0*/
"str q18, [%[doutc1r0]], #16 \n" /* store c1r0*/
"str q19, [%[doutc3r0]], #16 \n" /* store c3r0*/
"subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/
"str q8, [%[doutc4r0]], #16 \n" /* store c0r0*/
"str q9, [%[doutc6r0]], #16 \n" /* store c2r0*/
"str q12, [%[doutc5r0]], #16 \n" /* store c1r0*/
"str q13, [%[doutc7r0]], #16 \n" /* store c3r0*/
"bne 1b \n" /* jump to main loop*/
: [doutc0r0] "+r"(dout0),
[doutc1r0] "+r"(dout1),
[doutc2r0] "+r"(dout2),
[doutc3r0] "+r"(dout3),
[doutc4r0] "+r"(dout4),
[doutc5r0] "+r"(dout5),
[doutc6r0] "+r"(dout6),
[doutc7r0] "+r"(dout7),
[ptr_din] "+r"(din),
[cnt] "+r"(cnt)
: [relu] "r"(is_relu)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20");
#else
asm volatile(
"vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n"
"vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n"
"vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n"
"vld1.32 {d12-d15}, [%[ptr_din]]! @load data \n"
"vmov.s32 q15, #0 @ dump zero\n"
"1: @ main loop\n"
"vtrn.32 q0, q2 @ trans q0, q2 \n"
"vtrn.32 q4, q6 @ trans q4, q6 \n"
"vswp.32 d1, d8 @ swap d1, d8 \n"
"vswp.32 d5, d12 @ swap d5, d12\n"
"vtrn.32 q1, q3 @ trans q1, q3 \n"
"vtrn.32 q5, q7 @ trans q5, q7 \n"
"vswp.32 d3, d10 @ swap d3, d10\n"
"vswp.32 d7, d14 @ swap d7, d14\n"
"cmp %[relu], #0\n"
"bne 2f\n"
"vmax.s32 q0, q0, q15 @ relu\n"
"vmax.s32 q1, q1, q15 @ relu\n"
"vmax.s32 q2, q2, q15 @ relu\n"
"vmax.s32 q3, q3, q15 @ relu\n"
"vmax.s32 q4, q4, q15 @ relu\n"
"vmax.s32 q5, q5, q15 @ relu\n"
"vmax.s32 q6, q6, q15 @ relu\n"
"vmax.s32 q7, q7, q15 @ relu\n"
"2:\n"
"subs %[cnt], %[cnt], #1 @ loop count - 1\n"
"vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add pointer\n"
"vst1.32 {d2-d3}, [%[doutc4r0]]! @ store result, add pointer\n"
"vst1.32 {d4-d5}, [%[doutc1r0]]! @ store result, add pointer\n"
"vst1.32 {d6-d7}, [%[doutc5r0]]! @ store result, add pointer\n"
"vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n"
"vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n"
"vst1.32 {d8-d9}, [%[doutc2r0]]! @ store result, add pointer\n"
"vst1.32 {d10-d11}, [%[doutc6r0]]! @ store result, add pointer\n"
"vst1.32 {d12-d13}, [%[doutc3r0]]! @ store result, add pointer\n"
"vst1.32 {d14-d15}, [%[doutc7r0]]! @ store result, add pointer\n"
"vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n"
"vld1.32 {d12-d15}, [%[ptr_din]]! @load data \n"
"bne 1b @ jump to main loop\n"
: [doutc0r0] "+r"(dout0),
[doutc1r0] "+r"(dout1),
[doutc2r0] "+r"(dout2),
[doutc3r0] "+r"(dout3),
[doutc4r0] "+r"(dout4),
[doutc5r0] "+r"(dout5),
[doutc6r0] "+r"(dout6),
[doutc7r0] "+r"(dout7),
[ptr_din] "+r"(din)
: [cnt] "r"(cnt), [relu] "r"(is_relu)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q15");
#endif
}
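// Note (my reading of the asm above, not from this patch): for int32 output
// this specialization only transposes and optionally applies relu; scale and
// bias are ignored, matching cvt_kernel<int32_t>. Scalar model of one element
// (hypothetical helper, for illustration):
inline int32_t int32_passthrough_ref(int32_t v, bool is_relu) {
  return (is_relu && v < 0) ? 0 : v;
}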
"vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" /*wirte result in outputs
"vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n"
"bne 1b @ jump to main loop\n"
: [doutc0r0] "+r"(doutc0_ptr),
[doutc1r0] "+r"(doutc1_ptr),
[doutc2r0] "+r"(doutc2_ptr),
[doutc3r0] "+r"(doutc3_ptr),
[ptr_din] "+r"(din_hei_ptr),
[cnt] "+r"(cnt_loop)
:
: "q0", "q1", "q2", "q3", "q4", "q15");
#endif
}
if (we > width) {
int offset = 16 * (valid_w / 4 - 1);
din_hei_ptr = ptr_din + offset;
int i = we - 4;
for (; i < width; ++i) {
*(doutc0_ptr++) = LITEMAX(din_hei_ptr[0], 0);
*(doutc1_ptr++) = LITEMAX(din_hei_ptr[1], 0);
*(doutc2_ptr++) = LITEMAX(din_hei_ptr[2], 0);
*(doutc3_ptr++) = LITEMAX(din_hei_ptr[3], 0);
din_hei_ptr += 4;
}
}
}
} else {
for (int i = 0; i < size_h; i++) {
int size_w = i * width;
int* doutc0_ptr = doutc0r0 + size_w; // doutc0r0 + width;
int* doutc1_ptr = doutc1r0 + size_w;
int* doutc2_ptr = doutc2r0 + size_w;
int* doutc3_ptr = doutc3r0 + size_w;
if (ce > channel) {
switch (ce - channel) {
case 3:
doutc1_ptr = trash_ptr;
case 2:
doutc2_ptr = trash_ptr;
case 1:
doutc3_ptr = trash_ptr;
default:
break;
}
}
ptr_din = din + i * valid_w * ch_n;
const int* din_hei_ptr = ptr_din;
if (cnt > 0) {
int cnt_loop = cnt;
#ifdef __aarch64__
asm volatile(
"ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */
"ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */
"1: \n" /* main loop*/
"trn1 v8.4s, v0.4s, v1.4s \n" /* trans q0, q1*/
"trn2 v9.4s, v0.4s, v1.4s \n" /* trans q0, q1*/
"ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */
"trn1 v10.4s, v2.4s, v3.4s \n" /* trans q2, q3*/
"trn2 v11.4s, v2.4s, v3.4s \n" /* trans q2, q3*/
"ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */
"trn1 v16.2d, v8.2d, v10.2d \n" /* trans q8, q10*/
"trn2 v17.2d, v8.2d, v10.2d \n" /* trans q8, q10*/
"trn1 v18.2d, v9.2d, v11.2d \n" /* trans q9, q11*/
"trn2 v19.2d, v9.2d, v11.2d \n" /* trans q9, q11*/
"str q16, [%[doutc0r0]], #16 \n" /* store c0r0*/
"str q17, [%[doutc2r0]], #16 \n" /* store c2r0*/
"str q18, [%[doutc1r0]], #16 \n" /* store c1r0*/
"str q19, [%[doutc3r0]], #16 \n" /* store c3r0*/
"subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/
"bne 1b \n" /* jump to main loop*/
: [doutc0r0] "+r"(doutc0_ptr),
[doutc1r0] "+r"(doutc1_ptr),
[doutc2r0] "+r"(doutc2_ptr),
[doutc3r0] "+r"(doutc3_ptr),
[cnt] "+r"(cnt_loop),
[ptr_din] "+r"(din_hei_ptr)
:
: "v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20");
#else
asm volatile(
"vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n"
"vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n"
"1: @ main loop\n"
"vtrn.32 q0, q1 @ trans q0, q1\n"
"vtrn.32 q2, q3 @ trans q2, q3\n"
"vswp.32 d1, d4 @ swap d1, d4 \n"
"vswp.32 d3, d6 @ swap d3, d6 \n"
"subs %[cnt], %[cnt], #1 @ loop count - 1\n"
"vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add "
"pointer\n"
"vst1.32 {d2-d3}, [%[doutc1r0]]! @ store result, add "
"pointer\n"
"vst1.32 {d4-d5}, [%[doutc2r0]]! @ store result, add "
"pointer\n"
"vst1.32 {d6-d7}, [%[doutc3r0]]! @ store result, add "
"pointer\n"
"vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n"
"vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n"
"bne 1b @ jump to main loop\n"
: [doutc0r0] "+r"(doutc0_ptr),
[doutc1r0] "+r"(doutc1_ptr),
[doutc2r0] "+r"(doutc2_ptr),
[doutc3r0] "+r"(doutc3_ptr),
[ptr_din] "+r"(din_hei_ptr),
[cnt] "+r"(cnt_loop)
:
: "q0", "q1", "q2", "q3", "q4", "q15");
#endif
}
if (we > width) {
int offset = 16 * (valid_w / 4 - 1);
din_hei_ptr = ptr_din + offset;
int i = we - 4;
for (; i < width; ++i) {
*(doutc0_ptr++) = din_hei_ptr[0];
*(doutc1_ptr++) = din_hei_ptr[1];
*(doutc2_ptr++) = din_hei_ptr[2];
*(doutc3_ptr++) = din_hei_ptr[3];
din_hei_ptr += 4;
}
}
}
}
return true;
}
/*write result in outputs --int8, fp32
 * input din: [n, c / 4, h, w * 4], output dout: [n, c, h, w]
 */
template <typename dtype>
inline bool write_to_output_c4_int32_1(const int* din,
dtype* dout,
int ch_n,
int hei_n,
int cs,
int ce,
int hs,
int he,
int ws,
int we,
int channel,
int height,
int width,
bool flag_relu,
dtype* trash_ptr,
const float* scale,
PrecisionType out_dtype) {
if (ch_n != 4 || hei_n <= 0) {
LOG(ERROR) << "ch_n must be equal 4 and hei_n is more than zero";
return false;
}
int size_c_out = width * height;
dtype* doutc0r0 = dout + cs * size_c_out + hs * width + ws;
dtype* doutc1r0 = doutc0r0 + size_c_out;
dtype* doutc2r0 = doutc1r0 + size_c_out;
dtype* doutc3r0 = doutc2r0 + size_c_out;
const int* ptr_din = din;
int size_h = (he > height ? height : he) - hs; // size_h == hei_n
int valid_w = we - ws;
int cnt = valid_w / 4;
float32x4_t w_scale = vld1q_f32(scale);
// float32x4_t vzero = vdupq_n_f32(0.f);
if (we > width) {
cnt--;
}
if (out_dtype == PRECISION(kFloat)) {
// int32_to_fp32
if (flag_relu) {
for (int i = 0; i < size_h; i++) {
int size_w = i * width;
dtype* doutc0_ptr = doutc0r0 + size_w; // doutc0r0 + width;
dtype* doutc1_ptr = doutc1r0 + size_w;
dtype* doutc2_ptr = doutc2r0 + size_w;
dtype* doutc3_ptr = doutc3r0 + size_w;
if (ce > channel) {
switch (ce - channel) {
case 3:
doutc1_ptr = trash_ptr;
case 2:
doutc2_ptr = trash_ptr;
case 1:
doutc3_ptr = trash_ptr;
default:
break;
}
}
ptr_din = din + i * valid_w * ch_n;
const int* din_hei_ptr = ptr_din;
if (cnt > 0) {
int cnt_loop = cnt;
#ifdef __aarch64__
asm volatile(
"ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */
"ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */
"movi v20.4s, #0 \n" /* for relu */
"1: \n" /* main loop*/
"trn1 v8.4s, v0.4s, v1.4s \n" /* trans q0, q1*/
"trn2 v9.4s, v0.4s, v1.4s \n" /* trans q0, q1*/
"ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */
"trn1 v10.4s, v2.4s, v3.4s \n" /* trans q2, q3*/
"trn2 v11.4s, v2.4s, v3.4s \n" /* trans q2, q3*/
"ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */
"trn1 v16.2d, v8.2d, v10.2d \n" /* trans q8, q10*/
"trn2 v17.2d, v8.2d, v10.2d \n" /* trans q8, q10*/
"trn1 v18.2d, v9.2d, v11.2d \n" /* trans q9, q11*/
"trn2 v19.2d, v9.2d, v11.2d \n" /* trans q9, q11*/
"smax v16.4s, v16.4s, v20.4s \n" /* relu */
"smax v17.4s, v17.4s, v20.4s \n" /* relu */
"smax v18.4s, v18.4s, v20.4s \n" /* relu */
"smax v19.4s, v19.4s, v20.4s \n" /* relu */
// int32 --> fp32
"scvtf v4.4s, v16.4s \n"
"scvtf v5.4s, v17.4s \n"
"scvtf v6.4s, v18.4s \n"
"scvtf v7.4s, v19.4s \n"
// mul
"fmul v16.4s, v4.4s, %[scale].s[0] \n"
"fmul v17.4s, v5.4s, %[scale].s[2] \n"
"fmul v18.4s, v6.4s, %[scale].s[1] \n"
"fmul v19.4s, v7.4s, %[scale].s[3] \n"
// res
"str q16, [%[doutc0r0]], #16 \n" /* store c0r0*/
"str q17, [%[doutc2r0]], #16 \n" /* store c2r0*/
"str q18, [%[doutc1r0]], #16 \n" /* store c1r0*/
"str q19, [%[doutc3r0]], #16 \n" /* store c3r0*/
"subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/
"bne 1b \n" /* jump to main loop*/
: [doutc0r0] "+r"(doutc0_ptr),
[doutc1r0] "+r"(doutc1_ptr),
[doutc2r0] "+r"(doutc2_ptr),
[doutc3r0] "+r"(doutc3_ptr),
[cnt] "+r"(cnt_loop),
[ptr_din] "+r"(din_hei_ptr)
: [scale] "w"(w_scale)
: "v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20");
#else
asm volatile(
"vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n"
"vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n"
"vmov.u32 q15, #0 @ dump zero\n"
"1: @ main loop\n"
"vtrn.32 q2, q3 @ trans q0, q1 \n"
"vtrn.32 q4, q5 @ trans q2, q3 \n"
"vswp.32 d5, d8 @ swap d1, d4 \n"
"vswp.32 d7, d10 @ swap d3, d6 \n"
"vmax.s32 q2, q2, q15 @ relu\n"
"vmax.s32 q3, q3, q15 @ relu\n"
"vmax.s32 q4, q4, q15 @ relu\n"
"vmax.s32 q5, q5, q15 @ relu\n"
// int32-> fp32
"vcvt.f32.s32 q6, q2 \n"
"vcvt.f32.s32 q7, q3 \n"
"vcvt.f32.s32 q8, q4 \n"
"vcvt.f32.s32 q9, q5 \n"
// mul
"vmul.f32 q2, q6, %e[scale][0] \n"
"vmul.f32 q3, q7, %e[scale][1] \n"
"vmul.f32 q4, q8, %f[scale][0] \n"
"vmul.f32 q5, q9, %f[scale][1] \n"
"vst1.32 {d4-d5}, [%[doutc0r0]]! @ store result, add "
"pointer\n"
"vst1.32 {d6-d7}, [%[doutc1r0]]! @ store result, add "
"pointer\n"
"vst1.32 {d8-d9}, [%[doutc2r0]]! @ store result, add "
"pointer\n"
"vst1.32 {d10-d11}, [%[doutc3r0]]! @ store result, add "
"pointer\n"
"subs %[cnt], %[cnt], #1 @ loop count - 1\n"
"vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n"
"vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n"
"bne 1b @ jump to main loop\n"
: [doutc0r0] "+r"(doutc0_ptr),
[doutc1r0] "+r"(doutc1_ptr),
[doutc2r0] "+r"(doutc2_ptr),
[doutc3r0] "+r"(doutc3_ptr),
[ptr_din] "+r"(din_hei_ptr),
[cnt] "+r"(cnt_loop)
: [scale] "w"(w_scale)
: "q2",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
#endif
}
if (we > width) {
int offset = 16 * (valid_w / 4 - 1);
din_hei_ptr = ptr_din + offset;
int j = we - 4;
for (; j < width; ++j) {
*(doutc0_ptr++) = LITEMAX(din_hei_ptr[0] * scale[0], 0);
*(doutc1_ptr++) = LITEMAX(din_hei_ptr[1] * scale[1], 0);
*(doutc2_ptr++) = LITEMAX(din_hei_ptr[2] * scale[2], 0);
*(doutc3_ptr++) = LITEMAX(din_hei_ptr[3] * scale[3], 0);
din_hei_ptr += 4;
}
}
}
} else {
for (int i = 0; i < size_h; i++) {
int size_w = i * width;
dtype* doutc0_ptr = doutc0r0 + size_w; // doutc0r0 + width;
dtype* doutc1_ptr = doutc1r0 + size_w;
dtype* doutc2_ptr = doutc2r0 + size_w;
dtype* doutc3_ptr = doutc3r0 + size_w;
if (ce > channel) {
switch (ce - channel) {
case 3:
doutc1_ptr = trash_ptr;
case 2:
doutc2_ptr = trash_ptr;
case 1:
doutc3_ptr = trash_ptr;
default:
break;
}
}
ptr_din = din + i * valid_w * ch_n;
const int* din_hei_ptr = ptr_din;
if (cnt > 0) {
int cnt_loop = cnt;
#ifdef __aarch64__
asm volatile(
"ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */
"ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */
"movi v20.4s, #0 \n" /* for relu */
"1: \n" /* main loop*/
"trn1 v8.4s, v0.4s, v1.4s \n" /* trans q0, q1*/
"trn2 v9.4s, v0.4s, v1.4s \n" /* trans q0, q1*/
"ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */
"trn1 v10.4s, v2.4s, v3.4s \n" /* trans q2, q3*/
"trn2 v11.4s, v2.4s, v3.4s \n" /* trans q2, q3*/
"ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */
"trn1 v16.2d, v8.2d, v10.2d \n" /* trans q8, q10*/
"trn2 v17.2d, v8.2d, v10.2d \n" /* trans q8, q10*/
"trn1 v18.2d, v9.2d, v11.2d \n" /* trans q9, q11*/
"trn2 v19.2d, v9.2d, v11.2d \n" /* trans q9, q11*/
// int32 --> fp32
"scvtf v4.4s, v16.4s \n"
"scvtf v5.4s, v17.4s \n"
"scvtf v6.4s, v18.4s \n"
"scvtf v7.4s, v19.4s \n"
// mul
"fmul v16.4s, v4.4s, %[scale].s[0] \n"
"fmul v17.4s, v5.4s, %[scale].s[2] \n"
"fmul v18.4s, v6.4s, %[scale].s[1] \n"
"fmul v19.4s, v7.4s, %[scale].s[3] \n"
// res
"str q16, [%[doutc0r0]], #16 \n" /* store c0r0*/
"str q17, [%[doutc2r0]], #16 \n" /* store c2r0*/
"str q18, [%[doutc1r0]], #16 \n" /* store c1r0*/
"str q19, [%[doutc3r0]], #16 \n" /* store c3r0*/
"subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/
"bne 1b \n" /* jump to main loop*/
: [doutc0r0] "+r"(doutc0_ptr),
[doutc1r0] "+r"(doutc1_ptr),
[doutc2r0] "+r"(doutc2_ptr),
[doutc3r0] "+r"(doutc3_ptr),
[cnt] "+r"(cnt_loop),
[ptr_din] "+r"(din_hei_ptr)
: [scale] "w"(w_scale)
: "v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20");
#else
asm volatile(
"vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n"
"vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n"
"vmov.u32 q15, #0 @ dump zero\n"
"1: @ main loop\n"
"vtrn.32 q2, q3 @ trans q0, q1 \n"
"vtrn.32 q4, q5 @ trans q2, q3 \n"
"vswp.32 d5, d8 @ swap d1, d4 \n"
"vswp.32 d7, d10 @ swap d3, d6 \n"
// int32-> fp32
"vcvt.f32.s32 q6, q2 \n"
"vcvt.f32.s32 q7, q3 \n"
"vcvt.f32.s32 q8, q4 \n"
"vcvt.f32.s32 q9, q5 \n"
// mul
"vmul.f32 q2, q6, %e[scale][0] \n"
"vmul.f32 q3, q7, %e[scale][1] \n"
"vmul.f32 q4, q8, %f[scale][0] \n"
"vmul.f32 q5, q9, %f[scale][1] \n"
"vst1.32 {d4-d5}, [%[doutc0r0]]! @ store result, add "
"pointer\n"
"vst1.32 {d6-d7}, [%[doutc1r0]]! @ store result, add "
"pointer\n"
"vst1.32 {d8-d9}, [%[doutc2r0]]! @ store result, add "
"pointer\n"
"vst1.32 {d10-d11}, [%[doutc3r0]]! @ store result, add "
"pointer\n"
"subs %[cnt], %[cnt], #1 @ loop count - 1\n"
"vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n"
"vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n"
"bne 1b @ jump to main loop\n"
: [doutc0r0] "+r"(doutc0_ptr),
[doutc1r0] "+r"(doutc1_ptr),
[doutc2r0] "+r"(doutc2_ptr),
[doutc3r0] "+r"(doutc3_ptr),
[ptr_din] "+r"(din_hei_ptr),
[cnt] "+r"(cnt_loop)
: [scale] "w"(w_scale)
: "q2",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
#endif
}
if (we > width) {
int offset = 16 * (valid_w / 4 - 1);
din_hei_ptr = ptr_din + offset;
int j = we - 4;
for (; j < width; ++j) {
*(doutc0_ptr++) = din_hei_ptr[0] * scale[0];
*(doutc1_ptr++) = din_hei_ptr[1] * scale[1];
*(doutc2_ptr++) = din_hei_ptr[2] * scale[2];
*(doutc3_ptr++) = din_hei_ptr[3] * scale[3];
din_hei_ptr += 4;
}
}
}
}
} else if (out_dtype == PRECISION(kInt8)) {
// int32_to_int8
if (flag_relu) {
for (int i = 0; i < size_h; i++) {
int size_w = i * width;
dtype* doutc0_ptr = doutc0r0 + size_w; // doutc0r0 + width;
dtype* doutc1_ptr = doutc1r0 + size_w;
dtype* doutc2_ptr = doutc2r0 + size_w;
dtype* doutc3_ptr = doutc3r0 + size_w;
if (ce > channel) {
switch (ce - channel) {
case 3:
doutc1_ptr = trash_ptr;
case 2:
doutc2_ptr = trash_ptr;
case 1:
doutc3_ptr = trash_ptr;
default:
break;
}
}
ptr_din = din + i * valid_w * ch_n;
const int* din_hei_ptr = ptr_din;
if (cnt > 0) {
int cnt_loop = cnt;
#ifdef __aarch64__
asm volatile(
"ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */
"ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */
"movi v20.4s, #0 \n" /* for relu */
"1: \n" /* main loop*/
"trn1 v8.4s, v0.4s, v1.4s \n" /* trans q0, q1*/
"trn2 v9.4s, v0.4s, v1.4s \n" /* trans q0, q1*/
"ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */
"trn1 v10.4s, v2.4s, v3.4s \n" /* trans q2, q3*/
"trn2 v11.4s, v2.4s, v3.4s \n" /* trans q2, q3*/
"ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */
"trn1 v16.2d, v8.2d, v10.2d \n" /* trans q8, q10*/
"trn2 v17.2d, v8.2d, v10.2d \n" /* trans q8, q10*/
"trn1 v18.2d, v9.2d, v11.2d \n" /* trans q9, q11*/
"trn2 v19.2d, v9.2d, v11.2d \n" /* trans q9, q11*/
"smax v16.4s, v16.4s, v20.4s \n" /* relu */
"smax v17.4s, v17.4s, v20.4s \n" /* relu */
"smax v18.4s, v18.4s, v20.4s \n" /* relu */
"smax v19.4s, v19.4s, v20.4s \n" /* relu */
// int32 --> fp32
"scvtf v4.4s, v16.4s \n"
"scvtf v5.4s, v17.4s \n"
"scvtf v6.4s, v18.4s \n"
"scvtf v7.4s, v19.4s \n"
// mul
"fmul v16.4s, v4.4s, %[scale].s[0] \n"
"fmul v17.4s, v5.4s, %[scale].s[2] \n"
"fmul v18.4s, v6.4s, %[scale].s[1] \n"
"fmul v19.4s, v7.4s, %[scale].s[3] \n"
// fp32-int32
"fcvtas v4.4s, v16.4s \n"
"fcvtas v5.4s, v17.4s \n"
"fcvtas v6.4s, v18.4s \n"
"fcvtas v7.4s, v19.4s \n"
// int32-int16
"sqxtn v8.4h, v4.4s \n"
"sqxtn v9.4h, v5.4s \n"
"sqxtn v10.4h, v6.4s \n"
"sqxtn v11.4h, v7.4s \n"
"sqxtn v16.8b, v8.8h \n"
"sqxtn v17.8b, v9.8h \n"
"sqxtn v18.8b, v10.8h \n"
"sqxtn v19.8b, v11.8h \n"
// res
"str s16, [%[doutc0r0]], #4 \n"
"str s17, [%[doutc2r0]], #4 \n"
"str s18, [%[doutc1r0]], #4 \n"
"str s19, [%[doutc3r0]], #4 \n"
"subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/
"bne 1b \n" /* jump to main loop*/
: [doutc0r0] "+r"(doutc0_ptr),
[doutc1r0] "+r"(doutc1_ptr),
[doutc2r0] "+r"(doutc2_ptr),
[doutc3r0] "+r"(doutc3_ptr),
[cnt] "+r"(cnt_loop),
[ptr_din] "+r"(din_hei_ptr)
: [scale] "w"(w_scale)
: "v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20");
#else
asm volatile(
"vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n"
"vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n"
"vmov.u32 q15, #0 @ dump zero\n"
"1: @ main loop\n"
"vtrn.32 q2, q3 @ trans q0, q1 \n"
"vtrn.32 q4, q5 @ trans q2, q3 \n"
"vswp.32 d5, d8 @ swap d1, d4 \n"
"vswp.32 d7, d10 @ swap d3, d6 \n"
"vmax.s32 q2, q2, q15 @ relu\n"
"vmax.s32 q3, q3, q15 @ relu\n"
"vmax.s32 q4, q4, q15 @ relu\n"
"vmax.s32 q5, q5, q15 @ relu\n"
// int32-> fp32
"vcvt.f32.s32 q6, q2 \n"
"vcvt.f32.s32 q7, q3 \n"
"vcvt.f32.s32 q8, q4 \n"
"vcvt.f32.s32 q9, q5 \n"
"vmov.f32 q2, #0.5 \n"
// "vand.i32 q0, %q[vpoff], %q[vpoff] @ set offset, 0.5\n"
"vand.i32 q3, q2, q2 @ set offset, 0.5\n"
"vand.i32 q4, q2, q2 @ set offset, 0.5\n"
"vand.i32 q5, q2, q2 @ set offset, 0.5\n"
"vcgt.f32 q10, q6, q15 @ get mask > 0, in0\n"
"vcgt.f32 q11, q7, q15 @ get mask > 0, in1\n"
"vcgt.f32 q12, q8, q15 @ get mask > 0, in2\n"
"vcgt.f32 q13, q9, q15 @ get mask > 0, in3\n"
"vmov.f32 q15, #-0.5 \n"
"vbif.f32 q2, q15, q10 @ get right offset\n"
"vbif.f32 q3, q15, q11 @ get right offset\n"
"vbif.f32 q4, q15, q12 @ get right offset\n"
"vbif.f32 q5, q15, q13 @ get right offset\n"
"vmla.f32 q2, q6, %e[scale][0] @ mul scale\n"
"vmla.f32 q3, q7, %e[scale][1] @ mul scale\n"
"vmla.f32 q4, q8, %f[scale][0] @ mul scale\n"
"vmla.f32 q5, q9, %f[scale][1] @ mul scale\n"
"vcvt.s32.f32 q6, q2 @ cvt to int32\n"
"vcvt.s32.f32 q7, q3 @ cvt to int32\n"
"vcvt.s32.f32 q8, q4 @ cvt to int32\n"
"vcvt.s32.f32 q9, q5 @ cvt to int32\n"
"vqmovn.s32 d20, q6 @ cnt to int16\n"
"vqmovn.s32 d22, q7 @ cnt to int16\n"
"vqmovn.s32 d24, q8 @ cnt to int16\n"
"vqmovn.s32 d26, q9 @ cnt to int16\n"
"vqmovn.s16 d8, q10 @ cnt to int8\n"
"vqmovn.s16 d9, q11 @ cnt to int8\n"
"vqmovn.s16 d10, q12 @ cnt to int8\n"
"vqmovn.s16 d11, q13 @ cnt to int8\n"
"vst1.32 {d8[0]}, [%[doutc0r0]] @ write to output\n"
"vst1.32 {d9[0]}, [%[doutc1r0]] @ write to output\n"
"vst1.32 {d10[0]}, [%[doutc2r0]] @ write to output\n"
"vst1.32 {d11[0]}, [%[doutc3r0]] @ write to output\n"
"add %[doutc0r0], #4 \n"
"add %[doutc1r0], #4 \n"
"add %[doutc2r0], #4 \n"
"add %[doutc3r0], #4 \n"
"subs %[cnt], %[cnt], #1 @ loop count - 1\n"
"vmov.u32 q15, #0 @ dump zero\n"
"vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n"
"vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n"
"bne 1b @ jump to main loop\n"
: [doutc0r0] "+r"(doutc0_ptr),
[doutc1r0] "+r"(doutc1_ptr),
[doutc2r0] "+r"(doutc2_ptr),
[doutc3r0] "+r"(doutc3_ptr),
[ptr_din] "+r"(din_hei_ptr),
[cnt] "+r"(cnt_loop)
: [scale] "w"(w_scale)
: "q2",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
#endif
}
if (we > width) {
int offset = 16 * (valid_w / 4 - 1);
din_hei_ptr = ptr_din + offset;
int j = we - 4;
for (; j < width; ++j) {
*(doutc0_ptr++) = saturate_cast<signed char>(
roundf(LITEMAX(din_hei_ptr[0], 0) * scale[0]));
*(doutc1_ptr++) = saturate_cast<signed char>(
roundf(LITEMAX(din_hei_ptr[1], 0) * scale[1]));
*(doutc2_ptr++) = saturate_cast<signed char>(
roundf(LITEMAX(din_hei_ptr[2], 0) * scale[2]));
*(doutc3_ptr++) = saturate_cast<signed char>(
roundf(LITEMAX(din_hei_ptr[3], 0) * scale[3]));
din_hei_ptr += 4;
}
}
}
} else {
for (int i = 0; i < size_h; i++) { // size_h
int size_w = i * width;
dtype* doutc0_ptr = doutc0r0 + size_w; // doutc0r0 + width;
dtype* doutc1_ptr = doutc1r0 + size_w;
dtype* doutc2_ptr = doutc2r0 + size_w;
dtype* doutc3_ptr = doutc3r0 + size_w;
if (ce > channel) {
switch (ce - channel) {
case 3:
doutc1_ptr = trash_ptr;
case 2:
doutc2_ptr = trash_ptr;
case 1:
doutc3_ptr = trash_ptr;
default:
break;
}
}
ptr_din = din + i * valid_w * ch_n;
const int* din_hei_ptr = ptr_din;
if (cnt > 0) {
int cnt_loop = cnt;
#ifdef __aarch64__
asm volatile(
"ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */
"ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */
"movi v20.4s, #0 \n" /* for relu */
"1: \n" /* main loop*/
"trn1 v8.4s, v0.4s, v1.4s \n" /* trans q0, q1*/
"trn2 v9.4s, v0.4s, v1.4s \n" /* trans q0, q1*/
"ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */
"trn1 v10.4s, v2.4s, v3.4s \n" /* trans q2, q3*/
"trn2 v11.4s, v2.4s, v3.4s \n" /* trans q2, q3*/
"ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */
"trn1 v16.2d, v8.2d, v10.2d \n" /* trans q8, q10*/
"trn2 v17.2d, v8.2d, v10.2d \n" /* trans q8, q10*/
"trn1 v18.2d, v9.2d, v11.2d \n" /* trans q9, q11*/
"trn2 v19.2d, v9.2d, v11.2d \n" /* trans q9, q11*/
// int32 --> fp32
"scvtf v4.4s, v16.4s \n"
"scvtf v5.4s, v17.4s \n"
"scvtf v6.4s, v18.4s \n"
"scvtf v7.4s, v19.4s \n"
// mul
"fmul v16.4s, v4.4s, %[scale].s[0] \n"
"fmul v17.4s, v5.4s, %[scale].s[2] \n"
"fmul v18.4s, v6.4s, %[scale].s[1] \n"
"fmul v19.4s, v7.4s, %[scale].s[3] \n"
// fp32-int32
"fcvtas v4.4s, v16.4s \n"
"fcvtas v5.4s, v17.4s \n"
"fcvtas v6.4s, v18.4s \n"
"fcvtas v7.4s, v19.4s \n"
// int32-int16
"sqxtn v8.4h, v4.4s \n"
"sqxtn v9.4h, v5.4s \n"
"sqxtn v10.4h, v6.4s \n"
"sqxtn v11.4h, v7.4s \n"
"sqxtn v16.8b, v8.8h \n"
"sqxtn v17.8b, v9.8h \n"
"sqxtn v18.8b, v10.8h \n"
"sqxtn v19.8b, v11.8h \n"
// res
"str s16, [%[doutc0r0]], #4 \n"
"str s17, [%[doutc2r0]], #4 \n"
"str s18, [%[doutc1r0]], #4 \n"
"str s19, [%[doutc3r0]], #4 \n"
"subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/
"bne 1b \n" /* jump to main loop*/
: [doutc0r0] "+r"(doutc0_ptr),
[doutc1r0] "+r"(doutc1_ptr),
[doutc2r0] "+r"(doutc2_ptr),
[doutc3r0] "+r"(doutc3_ptr),
[cnt] "+r"(cnt_loop),
[ptr_din] "+r"(din_hei_ptr)
: [scale] "w"(w_scale)
: "v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20");
#else
asm volatile(
"vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n"
"vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n"
"vmov.u32 q15, #0 @ dump zero\n"
"1: @ main loop\n"
"vtrn.32 q2, q3 @ trans q0, q1 \n"
"vtrn.32 q4, q5 @ trans q2, q3 \n"
"vswp.32 d5, d8 @ swap d1, d4 \n"
"vswp.32 d7, d10 @ swap d3, d6 \n"
// int32-> fp32
"vcvt.f32.s32 q6, q2 \n"
"vcvt.f32.s32 q7, q3 \n"
"vcvt.f32.s32 q8, q4 \n"
"vcvt.f32.s32 q9, q5 \n"
"vmov.f32 q2, #0.5 \n"
// "vand.i32 q0, %q[vpoff], %q[vpoff] @ set offset, 0.5\n"
"vand.i32 q3, q2, q2 @ set offset, 0.5\n"
"vand.i32 q4, q2, q2 @ set offset, 0.5\n"
"vand.i32 q5, q2, q2 @ set offset, 0.5\n"
"vcgt.f32 q10, q6, q15 @ get mask > 0, in0\n"
"vcgt.f32 q11, q7, q15 @ get mask > 0, in1\n"
"vcgt.f32 q12, q8, q15 @ get mask > 0, in2\n"
"vcgt.f32 q13, q9, q15 @ get mask > 0, in3\n"
"vmov.f32 q15, #-0.5 \n"
"vbif.f32 q2, q15, q10 @ get right offset\n"
"vbif.f32 q3, q15, q11 @ get right offset\n"
"vbif.f32 q4, q15, q12 @ get right offset\n"
"vbif.f32 q5, q15, q13 @ get right offset\n"
"vmla.f32 q2, q6, %e[scale][0] @ mul scale\n"
"vmla.f32 q3, q7, %e[scale][1] @ mul scale\n"
"vmla.f32 q4, q8, %f[scale][0] @ mul scale\n"
"vmla.f32 q5, q9, %f[scale][1] @ mul scale\n"
"vcvt.s32.f32 q6, q2 @ cvt to int32\n"
"vcvt.s32.f32 q7, q3 @ cvt to int32\n"
"vcvt.s32.f32 q8, q4 @ cvt to int32\n"
"vcvt.s32.f32 q9, q5 @ cvt to int32\n"
"vqmovn.s32 d20, q6 @ cnt to int16\n"
"vqmovn.s32 d22, q7 @ cnt to int16\n"
"vqmovn.s32 d24, q8 @ cnt to int16\n"
"vqmovn.s32 d26, q9 @ cnt to int16\n"
"vqmovn.s16 d8, q10 @ cnt to int8\n"
"vqmovn.s16 d9, q11 @ cnt to int8\n"
"vqmovn.s16 d10, q12 @ cnt to int8\n"
"vqmovn.s16 d11, q13 @ cnt to int8\n"
"vst1.32 {d8[0]}, [%[doutc0r0]] @ write to output\n"
"vst1.32 {d9[0]}, [%[doutc1r0]] @ write to output\n"
"vst1.32 {d10[0]}, [%[doutc2r0]] @ write to output\n"
"vst1.32 {d11[0]}, [%[doutc3r0]] @ write to output\n"
"add %[doutc0r0], #4 \n"
"add %[doutc1r0], #4 \n"
"add %[doutc2r0], #4 \n"
"add %[doutc3r0], #4 \n"
"subs %[cnt], %[cnt], #1 @ loop count - 1\n"
"vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n"
"vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n"
"vmov.u32 q15, #0 @ dump zero\n"
"bne 1b @ jump to main loop\n"
: [doutc0r0] "+r"(doutc0_ptr),
[doutc1r0] "+r"(doutc1_ptr),
[doutc2r0] "+r"(doutc2_ptr),
[doutc3r0] "+r"(doutc3_ptr),
[ptr_din] "+r"(din_hei_ptr),
[cnt] "+r"(cnt_loop)
: [scale] "w"(w_scale)
: "q2",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
#endif
}
if (we > width) {
int offset = 16 * (valid_w / 4 - 1);
din_hei_ptr = ptr_din + offset;
int j = we - 4;
for (; j < width; ++j) {
*(doutc0_ptr++) =
saturate_cast<int8_t>(roundf(din_hei_ptr[0] * scale[0]));
*(doutc1_ptr++) =
saturate_cast<int8_t>(roundf(din_hei_ptr[1] * scale[1]));
*(doutc2_ptr++) =
saturate_cast<int8_t>(roundf(din_hei_ptr[2] * scale[2]));
*(doutc3_ptr++) =
saturate_cast<int8_t>(roundf(din_hei_ptr[3] * scale[3]));
din_hei_ptr += 4;
}
}
}
}
} else {
LOG(ERROR) << "ERROR: unsupported input data type!!";
return false;
}
return true;
}
/* write result in outputs
* input din: [n, c / 8, h, w * 8], output dout: [n, c, h, w]
*/
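// Scalar reference for the remap done below (a sketch only; the real
// code vectorizes 4 pixels x 8 channels at a time):
//   for (int c = cs; c < ce && c < channel; ++c)
//     for (int h = hs; h < he && h < height; ++h)
//       for (int w = ws; w < we && w < width; ++w)
//         dout[c * height * width + h * width + w] =
//             din[((h - hs) * (we - ws) + (w - ws)) * 8 + (c - cs)];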
inline bool write_to_output_c8_int32(const int* din,
int* dout,
int ch_n,
int hei_n,
int cs,
int ce,
int hs,
int he,
int ws,
int we,
int channel,
int height,
int width,
bool flag_relu,
int* trash_ptr) {
if (ch_n != 8 || hei_n <= 0) {
LOG(ERROR) << "ch_n must be equal 8 and hei_n is more than zero";
return false;
}
int size_c_out = width * height;
int* doutc0r0 = dout + cs * size_c_out + hs * width + ws;
int* doutc1r0 = doutc0r0 + size_c_out;
int* doutc2r0 = doutc1r0 + size_c_out;
int* doutc3r0 = doutc2r0 + size_c_out;
int* doutc4r0 = doutc3r0 + size_c_out;
int* doutc5r0 = doutc4r0 + size_c_out;
int* doutc6r0 = doutc5r0 + size_c_out;
int* doutc7r0 = doutc6r0 + size_c_out;
const int* ptr_din = din;
int size_h = (he > height ? height : he) - hs; // size_h == hei_n
int valid_w = we - ws;
int cnt = valid_w / 4;
if (we > width) {
cnt--;
}
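// If the tile overruns the output width (we > width), the last 4-pixel
// block is dropped from the vector loop and written by the scalar tail
// instead, so the 16-byte NEON stores never run past the end of a row.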
if (flag_relu) {
for (int i = 0; i < size_h; i++) {
int size_w = i * width;
int* doutc0_ptr = doutc0r0 + size_w; // doutc0r0 + width;
int* doutc1_ptr = doutc1r0 + size_w;
int* doutc2_ptr = doutc2r0 + size_w;
int* doutc3_ptr = doutc3r0 + size_w;
int* doutc4_ptr = doutc4r0 + size_w;
int* doutc5_ptr = doutc5r0 + size_w;
int* doutc6_ptr = doutc6r0 + size_w;
int* doutc7_ptr = doutc7r0 + size_w;
if (ce > channel) {
switch (ce - channel) {
case 7:
doutc1_ptr = trash_ptr;
case 6:
doutc2_ptr = trash_ptr;
case 5:
doutc3_ptr = trash_ptr;
case 4:
doutc4_ptr = trash_ptr;
case 3:
doutc5_ptr = trash_ptr;
case 2:
doutc6_ptr = trash_ptr;
case 1:
doutc7_ptr = trash_ptr;
default:
break;
}
}
ptr_din = din + i * valid_w * ch_n;
const int* din_hei_ptr = ptr_din;
if (cnt > 0) {
int cnt_loop = cnt;
#ifdef __aarch64__
asm volatile(
"ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */
"ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */
"ldp q4, q5, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */
"ldp q6, q7, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */
"movi v20.4s, #0 \n" /* for relu */
"1: \n" /* main loop*/
"trn1 v8.4s, v0.4s, v2.4s \n" /* trans q0, q1*/
"trn2 v9.4s, v0.4s, v2.4s \n" /* trans q0, q1*/
"trn1 v10.4s, v1.4s, v3.4s \n" /* trans q2, q3*/
"trn2 v11.4s, v1.4s, v3.4s \n" /* trans q2, q3*/
"ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */
"trn1 v12.4s, v4.4s, v6.4s \n" /* trans q0, q1*/
"trn2 v13.4s, v4.4s, v6.4s \n" /* trans q0, q1*/
"trn1 v14.4s, v5.4s, v7.4s \n" /* trans q2, q3*/
"trn2 v15.4s, v5.4s, v7.4s \n" /* trans q2, q3*/
"ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */
"trn1 v16.2d, v8.2d, v12.2d \n" /* trans q8, q10 00 01 02 03*/
"trn2 v17.2d, v8.2d, v12.2d \n" /* trans q8, q10 20 21 22 23*/
"trn1 v18.2d, v9.2d, v13.2d \n" /* trans q9, q11 10 11 12 13*/
"trn2 v19.2d, v9.2d, v13.2d \n" /* trans q9, q11 30 31 32 33*/
"ldp q4, q5, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */
"trn1 v8.2d, v10.2d, v14.2d \n" /* trans q8, q10 40 41 42 43*/
"trn2 v9.2d, v10.2d, v14.2d \n" /* trans q8, q10 60 61 62 63*/
"trn1 v12.2d, v11.2d, v15.2d \n" /* trans q9, q11 50 51 52 53*/
"trn2 v13.2d, v11.2d, v15.2d \n" /* trans q9, q11 70 71 72 73*/
"ldp q6, q7, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */
"smax v16.4s, v16.4s, v20.4s \n" /*relu*/
"smax v17.4s, v17.4s, v20.4s \n" /*relu*/
"smax v18.4s, v18.4s, v20.4s \n" /*relu*/
"smax v19.4s, v19.4s, v20.4s \n" /*relu*/
"smax v8.4s, v8.4s, v20.4s \n" /*relu*/
"smax v9.4s, v9.4s, v20.4s \n" /*relu*/
"smax v12.4s, v12.4s, v20.4s \n" /*relu*/
"smax v13.4s, v13.4s, v20.4s \n" /*relu*/
"str q16, [%[doutc0r0]], #16 \n" /* store c0r0*/
"str q17, [%[doutc2r0]], #16 \n" /* store c2r0*/
"str q18, [%[doutc1r0]], #16 \n" /* store c1r0*/
"str q19, [%[doutc3r0]], #16 \n" /* store c3r0*/
"subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/
"str q8, [%[doutc4r0]], #16 \n" /* store c0r0*/
"str q9, [%[doutc6r0]], #16 \n" /* store c2r0*/
"str q12, [%[doutc5r0]], #16 \n" /* store c1r0*/
"str q13, [%[doutc7r0]], #16 \n" /* store c3r0*/
"bne 1b \n" /* jump to main loop*/
: [doutc0r0] "+r"(doutc0_ptr),
[doutc1r0] "+r"(doutc1_ptr),
[doutc2r0] "+r"(doutc2_ptr),
[doutc3r0] "+r"(doutc3_ptr),
[doutc4r0] "+r"(doutc4_ptr),
[doutc5r0] "+r"(doutc5_ptr),
[doutc6r0] "+r"(doutc6_ptr),
[doutc7r0] "+r"(doutc7_ptr),
[cnt] "+r"(cnt_loop),
[ptr_din] "+r"(din_hei_ptr)
:
: "v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20");
#else
asm volatile(
"vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n"
"vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n"
"vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n"
"vld1.32 {d12-d15}, [%[ptr_din]]! @load data \n"
"vmov.s32 q15, #0 @ dump zero\n"
"1: @ main loop\n"
"vtrn.32 q0, q2 @ trans q0, q2 \n"
"vtrn.32 q4, q6 @ trans q4, q6 \n"
"vswp.32 d1, d8 @ swap d1, d8 \n"
"vswp.32 d5, d12 @ swap d5, d12\n"
"vtrn.32 q1, q3 @ trans q1, q3 \n"
"vtrn.32 q5, q7 @ trans q5, q7 \n"
"vswp.32 d3, d10 @ swap d3, d10\n"
"vswp.32 d7, d14 @ swap d7, d14\n"
"vmax.s32 q0, q0, q15 @ relu\n"
"vmax.s32 q1, q1, q15 @ relu\n"
"vmax.s32 q2, q2, q15 @ relu\n"
"vmax.s32 q3, q3, q15 @ relu\n"
"vmax.s32 q4, q4, q15 @ relu\n"
"vmax.s32 q5, q5, q15 @ relu\n"
"vmax.s32 q6, q6, q15 @ relu\n"
"vmax.s32 q7, q7, q15 @ relu\n"
"subs %[cnt], %[cnt], #1 @ loop count - 1\n"
"vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add pointer\n"
"vst1.32 {d2-d3}, [%[doutc4r0]]! @ store result, add pointer\n"
"vst1.32 {d4-d5}, [%[doutc1r0]]! @ store result, add pointer\n"
"vst1.32 {d6-d7}, [%[doutc5r0]]! @ store result, add pointer\n"
"vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n"
"vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n"
"vst1.32 {d8-d9}, [%[doutc2r0]]! @ store result, add pointer\n"
"vst1.32 {d10-d11}, [%[doutc6r0]]! @ store result, add pointer\n"
"vst1.32 {d12-d13}, [%[doutc3r0]]! @ store result, add pointer\n"
"vst1.32 {d14-d15}, [%[doutc7r0]]! @ store result, add pointer\n"
"vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n"
"vld1.32 {d12-d15}, [%[ptr_din]]! @load data \n"
"bne 1b @ jump to main loop\n"
: [doutc0r0] "+r"(doutc0_ptr),
[doutc1r0] "+r"(doutc1_ptr),
[doutc2r0] "+r"(doutc2_ptr),
[doutc3r0] "+r"(doutc3_ptr),
[doutc4r0] "+r"(doutc4_ptr),
[doutc5r0] "+r"(doutc5_ptr),
[doutc6r0] "+r"(doutc6_ptr),
[doutc7r0] "+r"(doutc7_ptr),
[ptr_din] "+r"(din_hei_ptr),
[cnt] "+r"(cnt_loop)
:
: "q0", "q1", "q2", "q3", "q4", "q15");
#endif
}
if (we > width) {
int offset = 32 * (valid_w / 4 - 1);
din_hei_ptr = ptr_din + offset;
int i = we - 4;
for (; i < width; ++i) {
*(doutc0_ptr++) = LITEMAX(din_hei_ptr[0], 0);
*(doutc1_ptr++) = LITEMAX(din_hei_ptr[1], 0);
*(doutc2_ptr++) = LITEMAX(din_hei_ptr[2], 0);
*(doutc3_ptr++) = LITEMAX(din_hei_ptr[3], 0);
*(doutc4_ptr++) = LITEMAX(din_hei_ptr[4], 0);
*(doutc5_ptr++) = LITEMAX(din_hei_ptr[5], 0);
*(doutc6_ptr++) = LITEMAX(din_hei_ptr[6], 0);
*(doutc7_ptr++) = LITEMAX(din_hei_ptr[7], 0);
din_hei_ptr += 8;
}
}
}
} else {
for (int i = 0; i < size_h; i++) {
int size_w = i * width;
int* doutc0_ptr = doutc0r0 + size_w; // doutc0r0 + width;
int* doutc1_ptr = doutc1r0 + size_w;
int* doutc2_ptr = doutc2r0 + size_w;
int* doutc3_ptr = doutc3r0 + size_w;
int* doutc4_ptr = doutc4r0 + size_w;
int* doutc5_ptr = doutc5r0 + size_w;
int* doutc6_ptr = doutc6r0 + size_w;
int* doutc7_ptr = doutc7r0 + size_w;
if (ce > channel) {
switch (ce - channel) {
case 7:
doutc1_ptr = trash_ptr;
case 6:
doutc2_ptr = trash_ptr;
case 5:
doutc3_ptr = trash_ptr;
case 4:
doutc4_ptr = trash_ptr;
case 3:
doutc5_ptr = trash_ptr;
case 2:
doutc6_ptr = trash_ptr;
case 1:
doutc7_ptr = trash_ptr;
default:
break;
}
}
ptr_din = din + i * valid_w * ch_n;
const int* din_hei_ptr = ptr_din;
if (cnt > 0) {
int cnt_loop = cnt;
#ifdef __aarch64__
asm volatile(
"ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */
"ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */
"ldp q4, q5, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */
"ldp q6, q7, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */
"1: \n" /* main loop*/
"trn1 v8.4s, v0.4s, v2.4s \n" /* trans q0, q1*/
"trn2 v9.4s, v0.4s, v2.4s \n" /* trans q0, q1*/
"trn1 v10.4s, v1.4s, v3.4s \n" /* trans q2, q3*/
"trn2 v11.4s, v1.4s, v3.4s \n" /* trans q2, q3*/
"ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */
"trn1 v12.4s, v4.4s, v6.4s \n" /* trans q0, q1*/
"trn2 v13.4s, v4.4s, v6.4s \n" /* trans q0, q1*/
"trn1 v14.4s, v5.4s, v7.4s \n" /* trans q2, q3*/
"trn2 v15.4s, v5.4s, v7.4s \n" /* trans q2, q3*/
"ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */
"trn1 v16.2d, v8.2d, v12.2d \n" /* trans q8, q10 00 01 02 03*/
"trn2 v17.2d, v8.2d, v12.2d \n" /* trans q8, q10 20 21 22 23*/
"trn1 v18.2d, v9.2d, v13.2d \n" /* trans q9, q11 10 11 12 13*/
"trn2 v19.2d, v9.2d, v13.2d \n" /* trans q9, q11 30 31 32 33*/
"ldp q4, q5, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */
"trn1 v8.2d, v10.2d, v14.2d \n" /* trans q8, q10 40 41 42 43*/
"trn2 v9.2d, v10.2d, v14.2d \n" /* trans q8, q10 60 61 62 63*/
"trn1 v12.2d, v11.2d, v15.2d \n" /* trans q9, q11 50 51 52 53*/
"trn2 v13.2d, v11.2d, v15.2d \n" /* trans q9, q11 70 71 72 73*/
"ldp q6, q7, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */
"str q16, [%[doutc0r0]], #16 \n" /* store c0r0*/
"str q17, [%[doutc2r0]], #16 \n" /* store c2r0*/
"str q18, [%[doutc1r0]], #16 \n" /* store c1r0*/
"str q19, [%[doutc3r0]], #16 \n" /* store c3r0*/
"subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/
"str q8, [%[doutc4r0]], #16 \n" /* store c0r0*/
"str q9, [%[doutc6r0]], #16 \n" /* store c2r0*/
"str q12, [%[doutc5r0]], #16 \n" /* store c1r0*/
"str q13, [%[doutc7r0]], #16 \n" /* store c3r0*/
"bne 1b \n" /* jump to main loop*/
: [doutc0r0] "+r"(doutc0_ptr),
[doutc1r0] "+r"(doutc1_ptr),
[doutc2r0] "+r"(doutc2_ptr),
[doutc3r0] "+r"(doutc3_ptr),
[doutc4r0] "+r"(doutc4_ptr),
[doutc5r0] "+r"(doutc5_ptr),
[doutc6r0] "+r"(doutc6_ptr),
[doutc7r0] "+r"(doutc7_ptr),
[cnt] "+r"(cnt_loop),
[ptr_din] "+r"(din_hei_ptr)
:
: "v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20");
#else
asm volatile(
"vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n"
"vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n"
"vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n"
"vld1.32 {d12-d15}, [%[ptr_din]]! @load data \n"
"1: @ main loop\n"
"vtrn.32 q0, q2 @ trans q0, q2 \n"
"vtrn.32 q4, q6 @ trans q4, q6 \n"
"vswp.32 d1, d8 @ swap d1, d8 \n"
"vswp.32 d5, d12 @ swap d5, d12\n"
"vtrn.32 q1, q3 @ trans q1, q3 \n"
"vtrn.32 q5, q7 @ trans q5, q7 \n"
"vswp.32 d3, d10 @ swap d3, d10\n"
"vswp.32 d7, d14 @ swap d7, d14\n"
"subs %[cnt], %[cnt], #1 @ loop count - 1\n"
"vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add pointer\n"
"vst1.32 {d2-d3}, [%[doutc4r0]]! @ store result, add pointer\n"
"vst1.32 {d4-d5}, [%[doutc1r0]]! @ store result, add pointer\n"
"vst1.32 {d6-d7}, [%[doutc5r0]]! @ store result, add pointer\n"
"vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n"
"vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n"
"vst1.32 {d8-d9}, [%[doutc2r0]]! @ store result, add pointer\n"
"vst1.32 {d10-d11}, [%[doutc6r0]]! @ store result, add pointer\n"
"vst1.32 {d12-d13}, [%[doutc3r0]]! @ store result, add pointer\n"
"vst1.32 {d14-d15}, [%[doutc7r0]]! @ store result, add pointer\n"
"vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n"
"vld1.32 {d12-d15}, [%[ptr_din]]! @load data \n"
"bne 1b @ jump to main loop\n"
: [doutc0r0] "+r"(doutc0_ptr),
[doutc1r0] "+r"(doutc1_ptr),
[doutc2r0] "+r"(doutc2_ptr),
[doutc3r0] "+r"(doutc3_ptr),
[doutc4r0] "+r"(doutc4_ptr),
[doutc5r0] "+r"(doutc5_ptr),
[doutc6r0] "+r"(doutc6_ptr),
[doutc7r0] "+r"(doutc7_ptr),
[ptr_din] "+r"(din_hei_ptr),
[cnt] "+r"(cnt_loop)
:
: "q0", "q1", "q2", "q3", "q4", "q15");
#endif
}
if (we > width) {
int offset = 32 * (valid_w / 4 - 1);
din_hei_ptr = ptr_din + offset;
int i = we - 4;
for (; i < width; ++i) {
*(doutc0_ptr++) = din_hei_ptr[0];
*(doutc1_ptr++) = din_hei_ptr[1];
*(doutc2_ptr++) = din_hei_ptr[2];
*(doutc3_ptr++) = din_hei_ptr[3];
*(doutc4_ptr++) = din_hei_ptr[4];
*(doutc5_ptr++) = din_hei_ptr[5];
*(doutc6_ptr++) = din_hei_ptr[6];
*(doutc7_ptr++) = din_hei_ptr[7];
din_hei_ptr += 8;
}
}
}
}
return true;
}
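// Example call (illustrative values only): scatter one 8-channel tile of
// int32 GEMM output, rows [h, h + h_tile), into the NCHW int32 output,
// with channels beyond `chout` dumped into a scratch buffer. All names
// here are hypothetical:
//   int trash[MAX_TILE_W];  // hypothetical scratch, sized to one row
//   write_to_output_c8_int32(tile_ptr, out_ptr, 8, h_tile, c, c + 8,
//                            h, h + h_tile, 0, w_round, chout, hout,
//                            wout, has_relu, trash);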
/* write result in outputs--int8, fp32
* input din: [n, c / 8, h, w * 8], output dout: [n, c, h, w]
*/
template <typename dtype>
static bool write_to_output_c8_int32_1(const int* din,
dtype* dout,
int ch_n,
int hei_n,
int cs,
int ce,
int hs,
...@@ -2769,1307 +2806,126 @@ static bool write_to_output_c8_int32_1(const int* din,
int height,
int width,
bool flag_relu,
dtype* trash_ptr,
const float* scale,
PrecisionType out_dtype) {
if (ch_n != 8 || hei_n <= 0) {
LOG(ERROR) << "ch_n must equal 8 and hei_n must be greater than zero";
return false;
}
int size_c_out = width * height;
dtype* doutc0r0 = dout + cs * size_c_out + hs * width + ws;
dtype* doutc1r0 = doutc0r0 + size_c_out;
dtype* doutc2r0 = doutc1r0 + size_c_out;
dtype* doutc3r0 = doutc2r0 + size_c_out;
dtype* doutc4r0 = doutc3r0 + size_c_out;
dtype* doutc5r0 = doutc4r0 + size_c_out;
dtype* doutc6r0 = doutc5r0 + size_c_out;
dtype* doutc7r0 = doutc6r0 + size_c_out;
const int* ptr_din = din;
int size_h = (he > height ? height : he) - hs;  // size_h == hei_n
int valid_w = we - ws;
int cnt = valid_w / 4;
float32x4_t w_scale0 = vld1q_f32(scale);
float32x4_t w_scale1 = vld1q_f32(scale + 4);
float32x4_t vzero = vdupq_n_f32(0.f);
if (we > width) {
cnt--;
}
if (out_dtype == PRECISION(kFloat)) {
if (flag_relu) {
for (int i = 0; i < size_h; i++) {
int size_w = i * width;
dtype* doutc0_ptr = doutc0r0 + size_w;  // doutc0r0 + width;
dtype* doutc1_ptr = doutc1r0 + size_w;
dtype* doutc2_ptr = doutc2r0 + size_w;
dtype* doutc3_ptr = doutc3r0 + size_w;
dtype* doutc4_ptr = doutc4r0 + size_w;
dtype* doutc5_ptr = doutc5r0 + size_w;
dtype* doutc6_ptr = doutc6r0 + size_w;
dtype* doutc7_ptr = doutc7r0 + size_w;
if (ce > channel) {
switch (ce - channel) {
case 7:
doutc1_ptr = trash_ptr;
case 6:
doutc2_ptr = trash_ptr;
case 5:
doutc3_ptr = trash_ptr;
case 4:
doutc4_ptr = trash_ptr;
case 3:
doutc5_ptr = trash_ptr;
case 2:
doutc6_ptr = trash_ptr;
case 1:
doutc7_ptr = trash_ptr;
default:
break;
}
}
ptr_din = din + i * valid_w * ch_n;
const int* din_hei_ptr = ptr_din;
if (cnt > 0) {
int cnt_loop = cnt;
#ifdef __aarch64__
asm volatile(
"ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */
"ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */
"ldp q4, q5, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */
"ldp q6, q7, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */
"movi v20.4s, #0 \n" /* for relu */
"1: \n" /* main loop*/
"trn1 v8.4s, v0.4s, v2.4s \n" /* trans q0, q1*/
"trn2 v9.4s, v0.4s, v2.4s \n" /* trans q0, q1*/
"trn1 v10.4s, v1.4s, v3.4s \n" /* trans q2, q3*/
"trn2 v11.4s, v1.4s, v3.4s \n" /* trans q2, q3*/
"ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */
"trn1 v12.4s, v4.4s, v6.4s \n" /* trans q0, q1*/
"trn2 v13.4s, v4.4s, v6.4s \n" /* trans q0, q1*/
"trn1 v14.4s, v5.4s, v7.4s \n" /* trans q2, q3*/
"trn2 v15.4s, v5.4s, v7.4s \n" /* trans q2, q3*/
"ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */
"trn1 v16.2d, v8.2d, v12.2d \n" /* trans q8, q10 00 01 02 03*/
"trn2 v17.2d, v8.2d, v12.2d \n" /* trans q8, q10 20 21 22 23*/
"trn1 v18.2d, v9.2d, v13.2d \n" /* trans q9, q11 10 11 12 13*/
"trn2 v19.2d, v9.2d, v13.2d \n" /* trans q9, q11 30 31 32 33*/
"ldp q4, q5, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */
"trn1 v8.2d, v10.2d, v14.2d \n" /* trans q8, q10 40 41 42 43*/
"trn2 v9.2d, v10.2d, v14.2d \n" /* trans q8, q10 60 61 62 63*/
"trn1 v12.2d, v11.2d, v15.2d \n" /* trans q9, q11 50 51 52 53*/
"trn2 v13.2d, v11.2d, v15.2d \n" /* trans q9, q11 70 71 72 73*/
"ldp q6, q7, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */
"smax v16.4s, v16.4s, v20.4s \n" /*relu*/
"smax v17.4s, v17.4s, v20.4s \n" /*relu*/
"smax v18.4s, v18.4s, v20.4s \n" /*relu*/
"smax v19.4s, v19.4s, v20.4s \n" /*relu*/
"smax v8.4s, v8.4s, v20.4s \n" /*relu*/
"smax v9.4s, v9.4s, v20.4s \n" /*relu*/
"smax v12.4s, v12.4s, v20.4s \n" /*relu*/
"smax v13.4s, v13.4s, v20.4s \n" /*relu*/
// int32->fp32
"scvtf v10.4s, v16.4s \n"
"scvtf v11.4s, v17.4s \n"
"scvtf v14.4s, v18.4s \n"
"scvtf v15.4s, v19.4s \n"
// mul
"fmul v16.4s, v10.4s, %[scale0].s[0] \n"
"fmul v17.4s, v11.4s, %[scale0].s[2] \n"
"fmul v18.4s, v14.4s, %[scale0].s[1] \n"
"fmul v19.4s, v15.4s, %[scale0].s[3] \n"
"scvtf v10.4s, v8.4s \n"
"scvtf v11.4s, v9.4s \n"
"scvtf v14.4s, v12.4s \n"
"scvtf v15.4s, v13.4s \n"
"str q16, [%[doutc0r0]], #16 \n" /* store c0r0*/
"str q17, [%[doutc2r0]], #16 \n" /* store c2r0*/
"str q18, [%[doutc1r0]], #16 \n" /* store c1r0*/
"str q19, [%[doutc3r0]], #16 \n" /* store c3r0*/
// mul
"fmul v8.4s, v10.4s, %[scale1].s[0] \n"
"fmul v9.4s, v11.4s, %[scale1].s[2] \n"
"fmul v12.4s, v14.4s, %[scale1].s[1] \n"
"fmul v13.4s, v15.4s, %[scale1].s[3] \n"
"subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/
"str q8, [%[doutc4r0]], #16 \n" /* store c0r0*/
"str q9, [%[doutc6r0]], #16 \n" /* store c2r0*/
"str q12, [%[doutc5r0]], #16 \n" /* store c1r0*/
"str q13, [%[doutc7r0]], #16 \n" /* store c3r0*/
"bne 1b \n" /* jump to main loop*/
: [doutc0r0] "+r"(doutc0_ptr),
[doutc1r0] "+r"(doutc1_ptr),
[doutc2r0] "+r"(doutc2_ptr),
[doutc3r0] "+r"(doutc3_ptr),
[doutc4r0] "+r"(doutc4_ptr),
[doutc5r0] "+r"(doutc5_ptr),
[doutc6r0] "+r"(doutc6_ptr),
[doutc7r0] "+r"(doutc7_ptr),
[cnt] "+r"(cnt_loop),
[ptr_din] "+r"(din_hei_ptr)
: [scale0] "w"(w_scale0), [scale1] "w"(w_scale1)
: "v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20");
#else
asm volatile(
"vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n"
"vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n"
"vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n"
"vld1.32 {d12-d15}, [%[ptr_din]]! @load data \n"
"vmov.s32 q15, #0 @ dump zero\n"
"1: @ main loop\n"
"vmax.s32 q0, q0, q15 @ relu\n"
"vmax.s32 q1, q1, q15 @ relu\n"
"vmax.s32 q2, q2, q15 @ relu\n"
"vmax.s32 q3, q3, q15 @ relu\n"
"vmax.s32 q4, q4, q15 @ relu\n"
"vmax.s32 q5, q5, q15 @ relu\n"
"vmax.s32 q6, q6, q15 @ relu\n"
"vmax.s32 q7, q7, q15 @ relu\n"
// int32-> fp32
"vcvt.f32.s32 q8, q0 \n"
"vcvt.f32.s32 q9, q1 \n"
"vcvt.f32.s32 q10, q2 \n"
"vcvt.f32.s32 q11, q3 \n"
// mul
"vmul.f32 q0, q8, %q[scale0] \n"
"vmul.f32 q1, q9, %q[scale1] \n"
"vmul.f32 q2, q10, %q[scale0] \n"
"vmul.f32 q3, q11, %q[scale1] \n"
// int32-> fp32
"vcvt.f32.s32 q8, q4 \n"
"vcvt.f32.s32 q9, q5 \n"
"vcvt.f32.s32 q10, q6 \n"
"vcvt.f32.s32 q11, q7 \n"
// mul
"vmul.f32 q4, q8, %q[scale0] \n"
"vmul.f32 q5, q9, %q[scale1] \n"
"vmul.f32 q6, q10, %q[scale0] \n"
"vmul.f32 q7, q11, %q[scale1] \n"
"vtrn.32 q0, q2 @ trans q0, q2 \n"
"vtrn.32 q4, q6 @ trans q4, q6 \n"
"vswp.32 d1, d8 @ swap d1, d8 \n"
"vswp.32 d5, d12 @ swap d5, d12\n"
"vtrn.32 q1, q3 @ trans q1, q3 \n"
"vtrn.32 q5, q7 @ trans q5, q7 \n"
"vswp.32 d3, d10 @ swap d3, d10\n"
"vswp.32 d7, d14 @ swap d7, d14\n"
"vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add pointer\n"
"vst1.32 {d4-d5}, [%[doutc1r0]]! @ store result, add pointer\n"
"vst1.32 {d8-d9}, [%[doutc2r0]]! @ store result, add "
"pointer\n"
"vst1.32 {d12-d13}, [%[doutc3r0]]! @ store result, add "
"pointer\n"
"vst1.32 {d2-d3}, [%[doutc4r0]]! @ store result, add pointer\n"
"vst1.32 {d6-d7}, [%[doutc5r0]]! @ store result, add pointer\n"
"vst1.32 {d10-d11}, [%[doutc6r0]]! @ store result, add "
"pointer\n"
"vst1.32 {d14-d15}, [%[doutc7r0]]! @ store result, add "
"pointer\n"
"subs %[cnt], %[cnt], #1 @ loop count - 1\n"
"vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n"
"vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n"
"vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n"
"vld1.32 {d12-d15}, [%[ptr_din]]! @load data \n"
"bne 1b @ jump to main loop\n"
: [doutc0r0] "+r"(doutc0_ptr),
[doutc1r0] "+r"(doutc1_ptr),
[doutc2r0] "+r"(doutc2_ptr),
[doutc3r0] "+r"(doutc3_ptr),
[doutc4r0] "+r"(doutc4_ptr),
[doutc5r0] "+r"(doutc5_ptr),
[doutc6r0] "+r"(doutc6_ptr),
[doutc7r0] "+r"(doutc7_ptr),
[ptr_din] "+r"(din_hei_ptr),
[cnt] "+r"(cnt_loop)
: [scale0] "w"(w_scale0), [scale1] "w"(w_scale1)
: "q0",
"q1",
"q2",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q15");
#endif
}
if (we > width) {
int offset = 32 * (valid_w / 4 - 1);
din_hei_ptr = ptr_din + offset;
int i = we - 4;
for (; i < width; ++i) {
*(doutc0_ptr++) = LITEMAX(din_hei_ptr[0] * scale[0], 0);
*(doutc1_ptr++) = LITEMAX(din_hei_ptr[1] * scale[1], 0);
*(doutc2_ptr++) = LITEMAX(din_hei_ptr[2] * scale[2], 0);
*(doutc3_ptr++) = LITEMAX(din_hei_ptr[3] * scale[3], 0);
*(doutc4_ptr++) = LITEMAX(din_hei_ptr[4] * scale[4], 0);
*(doutc5_ptr++) = LITEMAX(din_hei_ptr[5] * scale[5], 0);
*(doutc6_ptr++) = LITEMAX(din_hei_ptr[6] * scale[6], 0);
*(doutc7_ptr++) = LITEMAX(din_hei_ptr[7] * scale[7], 0);
din_hei_ptr += 8;
}
}
}
} else {
for (int i = 0; i < size_h; i++) {
int size_w = i * width;
dtype* doutc0_ptr = doutc0r0 + size_w; // doutc0r0 + width;
dtype* doutc1_ptr = doutc1r0 + size_w;
dtype* doutc2_ptr = doutc2r0 + size_w;
dtype* doutc3_ptr = doutc3r0 + size_w;
dtype* doutc4_ptr = doutc4r0 + size_w;
dtype* doutc5_ptr = doutc5r0 + size_w;
dtype* doutc6_ptr = doutc6r0 + size_w;
dtype* doutc7_ptr = doutc7r0 + size_w;
if (ce > channel) {
switch (ce - channel) {
case 7:
doutc1_ptr = trash_ptr;
case 6:
doutc2_ptr = trash_ptr;
case 5:
doutc3_ptr = trash_ptr;
case 4:
doutc4_ptr = trash_ptr;
case 3:
doutc5_ptr = trash_ptr;
case 2:
doutc6_ptr = trash_ptr;
case 1:
doutc7_ptr = trash_ptr;
default:
break;
}
}
ptr_din = din + i * valid_w * ch_n;
const int* din_hei_ptr = ptr_din;
if (cnt > 0) {
int cnt_loop = cnt;
#ifdef __aarch64__
asm volatile(
"ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */
"ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */
"ldp q4, q5, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */
"ldp q6, q7, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */
"movi v20.4s, #0 \n" /* for relu */
"1: \n" /* main loop*/
"trn1 v8.4s, v0.4s, v2.4s \n" /* trans q0, q1*/
"trn2 v9.4s, v0.4s, v2.4s \n" /* trans q0, q1*/
"trn1 v10.4s, v1.4s, v3.4s \n" /* trans q2, q3*/
"trn2 v11.4s, v1.4s, v3.4s \n" /* trans q2, q3*/
"ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */
"trn1 v12.4s, v4.4s, v6.4s \n" /* trans q0, q1*/
"trn2 v13.4s, v4.4s, v6.4s \n" /* trans q0, q1*/
"trn1 v14.4s, v5.4s, v7.4s \n" /* trans q2, q3*/
"trn2 v15.4s, v5.4s, v7.4s \n" /* trans q2, q3*/
"ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */
"trn1 v16.2d, v8.2d, v12.2d \n" /* trans q8, q10 00 01 02 03*/
"trn2 v17.2d, v8.2d, v12.2d \n" /* trans q8, q10 20 21 22 23*/
"trn1 v18.2d, v9.2d, v13.2d \n" /* trans q9, q11 10 11 12 13*/
"trn2 v19.2d, v9.2d, v13.2d \n" /* trans q9, q11 30 31 32 33*/
"ldp q4, q5, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */
"trn1 v8.2d, v10.2d, v14.2d \n" /* trans q8, q10 40 41 42 43*/
"trn2 v9.2d, v10.2d, v14.2d \n" /* trans q8, q10 60 61 62 63*/
"trn1 v12.2d, v11.2d, v15.2d \n" /* trans q9, q11 50 51 52 53*/
"trn2 v13.2d, v11.2d, v15.2d \n" /* trans q9, q11 70 71 72 73*/
"ldp q6, q7, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */
// int32->fp32
"scvtf v10.4s, v16.4s \n"
"scvtf v11.4s, v17.4s \n"
"scvtf v14.4s, v18.4s \n"
"scvtf v15.4s, v19.4s \n"
// mul
"fmul v16.4s, v10.4s, %[scale0].s[0] \n"
"fmul v17.4s, v11.4s, %[scale0].s[2] \n"
"fmul v18.4s, v14.4s, %[scale0].s[1] \n"
"fmul v19.4s, v15.4s, %[scale0].s[3] \n"
"scvtf v10.4s, v8.4s \n"
"scvtf v11.4s, v9.4s \n"
"scvtf v14.4s, v12.4s \n"
"scvtf v15.4s, v13.4s \n"
"str q16, [%[doutc0r0]], #16 \n" /* store c0r0*/
"str q17, [%[doutc2r0]], #16 \n" /* store c2r0*/
"str q18, [%[doutc1r0]], #16 \n" /* store c1r0*/
"str q19, [%[doutc3r0]], #16 \n" /* store c3r0*/
// mul
"fmul v8.4s, v10.4s, %[scale1].s[0] \n"
"fmul v9.4s, v11.4s, %[scale1].s[2] \n"
"fmul v12.4s, v14.4s, %[scale1].s[1] \n"
"fmul v13.4s, v15.4s, %[scale1].s[3] \n"
"subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/
"str q8, [%[doutc4r0]], #16 \n" /* store c0r0*/
"str q9, [%[doutc6r0]], #16 \n" /* store c2r0*/
"str q12, [%[doutc5r0]], #16 \n" /* store c1r0*/
"str q13, [%[doutc7r0]], #16 \n" /* store c3r0*/
"bne 1b \n" /* jump to main loop*/
: [doutc0r0] "+r"(doutc0_ptr),
[doutc1r0] "+r"(doutc1_ptr),
[doutc2r0] "+r"(doutc2_ptr),
[doutc3r0] "+r"(doutc3_ptr),
[doutc4r0] "+r"(doutc4_ptr),
[doutc5r0] "+r"(doutc5_ptr),
[doutc6r0] "+r"(doutc6_ptr),
[doutc7r0] "+r"(doutc7_ptr),
[cnt] "+r"(cnt_loop),
[ptr_din] "+r"(din_hei_ptr)
: [scale0] "w"(w_scale0), [scale1] "w"(w_scale1)
: "v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20");
#else
asm volatile(
"vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n"
"vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n"
"vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n"
"vld1.32 {d12-d15}, [%[ptr_din]]! @load data \n"
"vmov.s32 q15, #0 @ dump zero\n"
"1: @ main loop\n"
// int32-> fp32
"vcvt.f32.s32 q8, q0 \n"
"vcvt.f32.s32 q9, q1 \n"
"vcvt.f32.s32 q10, q2 \n"
"vcvt.f32.s32 q11, q3 \n"
// mul
"vmul.f32 q0, q8, %q[scale0] \n"
"vmul.f32 q1, q9, %q[scale1] \n"
"vmul.f32 q2, q10, %q[scale0] \n"
"vmul.f32 q3, q11, %q[scale1] \n"
// int32-> fp32
"vcvt.f32.s32 q8, q4 \n"
"vcvt.f32.s32 q9, q5 \n"
"vcvt.f32.s32 q10, q6 \n"
"vcvt.f32.s32 q11, q7 \n"
// mul
"vmul.f32 q4, q8, %q[scale0] \n"
"vmul.f32 q5, q9, %q[scale1] \n"
"vmul.f32 q6, q10, %q[scale0] \n"
"vmul.f32 q7, q11, %q[scale1] \n"
"vtrn.32 q0, q2 @ trans q0, q2 \n"
"vtrn.32 q4, q6 @ trans q4, q6 \n"
"vswp.32 d1, d8 @ swap d1, d8 \n"
"vswp.32 d5, d12 @ swap d5, d12\n"
"vtrn.32 q1, q3 @ trans q1, q3 \n"
"vtrn.32 q5, q7 @ trans q5, q7 \n"
"vswp.32 d3, d10 @ swap d3, d10\n"
"vswp.32 d7, d14 @ swap d7, d14\n"
"vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add pointer\n"
"vst1.32 {d4-d5}, [%[doutc1r0]]! @ store result, add pointer\n"
"vst1.32 {d8-d9}, [%[doutc2r0]]! @ store result, add "
"pointer\n"
"vst1.32 {d12-d13}, [%[doutc3r0]]! @ store result, add "
"pointer\n"
"vst1.32 {d2-d3}, [%[doutc4r0]]! @ store result, add pointer\n"
"vst1.32 {d6-d7}, [%[doutc5r0]]! @ store result, add pointer\n"
"vst1.32 {d10-d11}, [%[doutc6r0]]! @ store result, add "
"pointer\n"
"vst1.32 {d14-d15}, [%[doutc7r0]]! @ store result, add "
"pointer\n"
"subs %[cnt], %[cnt], #1 @ loop count - 1\n"
"vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n"
"vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n"
"vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n"
"vld1.32 {d12-d15}, [%[ptr_din]]! @load data \n"
"bne 1b @ jump to main loop\n"
: [doutc0r0] "+r"(doutc0_ptr),
[doutc1r0] "+r"(doutc1_ptr),
[doutc2r0] "+r"(doutc2_ptr),
[doutc3r0] "+r"(doutc3_ptr),
[doutc4r0] "+r"(doutc4_ptr),
[doutc5r0] "+r"(doutc5_ptr),
[doutc6r0] "+r"(doutc6_ptr),
[doutc7r0] "+r"(doutc7_ptr),
[ptr_din] "+r"(din_hei_ptr),
[cnt] "+r"(cnt_loop)
: [scale0] "w"(w_scale0), [scale1] "w"(w_scale1)
: "q0",
"q1",
"q2",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q15");
#endif
}
if (we > width) {
int offset = 32 * (valid_w / 4 - 1);
din_hei_ptr = ptr_din + offset;
int i = we - 4;
for (; i < width; ++i) {
*(doutc0_ptr++) = din_hei_ptr[0] * scale[0];
*(doutc1_ptr++) = din_hei_ptr[1] * scale[1];
*(doutc2_ptr++) = din_hei_ptr[2] * scale[2];
*(doutc3_ptr++) = din_hei_ptr[3] * scale[3];
*(doutc4_ptr++) = din_hei_ptr[4] * scale[4];
*(doutc5_ptr++) = din_hei_ptr[5] * scale[5];
*(doutc6_ptr++) = din_hei_ptr[6] * scale[6];
*(doutc7_ptr++) = din_hei_ptr[7] * scale[7];
din_hei_ptr += 8;
}
}
}
}
} else if (out_dtype == PRECISION(kInt8)) {
// int32_to_int8
float32x4_t vpoff = vdupq_n_f32(0.5f);
float32x4_t vnoff = vdupq_n_f32(-0.5f);
if (flag_relu) {
for (int i = 0; i < size_h; i++) {
int size_w = i * width;
dtype* doutc0_ptr = doutc0r0 + size_w;  // doutc0r0 + width;
dtype* doutc1_ptr = doutc1r0 + size_w;
dtype* doutc2_ptr = doutc2r0 + size_w;
dtype* doutc3_ptr = doutc3r0 + size_w;
dtype* doutc4_ptr = doutc4r0 + size_w;
dtype* doutc5_ptr = doutc5r0 + size_w;
dtype* doutc6_ptr = doutc6r0 + size_w;
dtype* doutc7_ptr = doutc7r0 + size_w;
if (ce > channel) {
switch (ce - channel) {
case 7:
doutc1_ptr = trash_ptr;
case 6:
doutc2_ptr = trash_ptr;
case 5:
doutc3_ptr = trash_ptr;
case 4:
doutc4_ptr = trash_ptr;
case 3:
doutc5_ptr = trash_ptr;
case 2:
doutc6_ptr = trash_ptr;
case 1:
doutc7_ptr = trash_ptr;
default:
break;
}
}
ptr_din = din + i * valid_w * ch_n;
const int* din_hei_ptr = ptr_din;
if (cnt > 0) {
int cnt_loop = cnt;
#ifdef __aarch64__
asm volatile(
"ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */
"ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */
"ldp q4, q5, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */
"ldp q6, q7, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */
// "movi v20.4s, #0 \n" /* for relu */
"1: \n" /* main loop*/
"trn1 v8.4s, v0.4s, v2.4s \n" /* trans q0, q1*/
"trn2 v9.4s, v0.4s, v2.4s \n" /* trans q0, q1*/
"trn1 v10.4s, v1.4s, v3.4s \n" /* trans q2, q3*/
"trn2 v11.4s, v1.4s, v3.4s \n" /* trans q2, q3*/
"ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */
"trn1 v12.4s, v4.4s, v6.4s \n" /* trans q0, q1*/
"trn2 v13.4s, v4.4s, v6.4s \n" /* trans q0, q1*/
"trn1 v14.4s, v5.4s, v7.4s \n" /* trans q2, q3*/
"trn2 v15.4s, v5.4s, v7.4s \n" /* trans q2, q3*/
"ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */
"trn1 v16.2d, v8.2d, v12.2d \n" /* trans q8, q10 00 01 02 03*/
"trn2 v17.2d, v8.2d, v12.2d \n" /* trans q8, q10 20 21 22 23*/
"trn1 v18.2d, v9.2d, v13.2d \n" /* trans q9, q11 10 11 12 13*/
"trn2 v19.2d, v9.2d, v13.2d \n" /* trans q9, q11 30 31 32 33*/
"ldp q4, q5, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */
"trn1 v8.2d, v10.2d, v14.2d \n" /* trans q8, q10 40 41 42 43*/
"trn2 v9.2d, v10.2d, v14.2d \n" /* trans q8, q10 60 61 62 63*/
"trn1 v12.2d, v11.2d, v15.2d \n" /* trans q9, q11 50 51 52 53*/
"trn2 v13.2d, v11.2d, v15.2d \n" /* trans q9, q11 70 71 72 73*/
"ldp q6, q7, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */
"smax v16.4s, v16.4s, %[vzero].4s \n" /*relu*/
"smax v17.4s, v17.4s, %[vzero].4s \n" /*relu*/
"smax v18.4s, v18.4s, %[vzero].4s \n" /*relu*/
"smax v19.4s, v19.4s, %[vzero].4s \n" /*relu*/
"smax v8.4s, v8.4s, %[vzero].4s \n" /*relu*/
"smax v9.4s, v9.4s, %[vzero].4s \n" /*relu*/
"smax v12.4s, v12.4s, %[vzero].4s \n" /*relu*/
"smax v13.4s, v13.4s, %[vzero].4s \n" /*relu*/
// int32 --> fp32
"scvtf v10.4s, v16.4s \n"
"scvtf v11.4s, v17.4s \n"
"scvtf v14.4s, v18.4s \n"
"scvtf v15.4s, v19.4s \n"
"scvtf v20.4s, v8.4s \n"
"scvtf v21.4s, v9.4s \n"
"scvtf v22.4s, v12.4s \n"
"scvtf v23.4s, v13.4s \n"
// mul
"fmul v16.4s, v10.4s, %[scale0].s[0] \n"
"fmul v17.4s, v11.4s, %[scale0].s[2] \n"
"fmul v18.4s, v14.4s, %[scale0].s[1] \n"
"fmul v19.4s, v15.4s, %[scale0].s[3] \n"
"fmul v8.4s, v20.4s, %[scale1].s[0] \n"
"fmul v9.4s, v21.4s, %[scale1].s[2] \n"
"fmul v12.4s, v22.4s, %[scale1].s[1] \n"
"fmul v13.4s, v23.4s, %[scale1].s[3] \n"
// fp32-int32
"fcvtas v10.4s, v16.4s \n"
"fcvtas v11.4s, v17.4s \n"
"fcvtas v14.4s, v18.4s \n"
"fcvtas v15.4s, v19.4s \n"
"fcvtas v20.4s, v8.4s \n"
"fcvtas v21.4s, v9.4s \n"
"fcvtas v22.4s, v12.4s \n"
"fcvtas v23.4s, v13.4s \n"
// int32-int16
"sqxtn v16.4h, v10.4s \n"
"sqxtn v17.4h, v11.4s \n"
"sqxtn v18.4h, v14.4s \n"
"sqxtn v19.4h, v15.4s \n"
"sqxtn v8.4h, v20.4s \n"
"sqxtn v9.4h, v21.4s \n"
"sqxtn v12.4h, v22.4s \n"
"sqxtn v13.4h, v23.4s \n"
// int16-int8
"sqxtn v10.8b, v16.8h \n"
"sqxtn v11.8b, v17.8h \n"
"sqxtn v14.8b, v18.8h \n"
"sqxtn v15.8b, v19.8h \n"
"sqxtn v20.8b, v8.8h \n"
"sqxtn v21.8b, v9.8h \n"
"sqxtn v22.8b, v12.8h \n"
"sqxtn v23.8b, v13.8h \n"
"str s10, [%[doutc0r0]], #4 \n" /* store c0r0*/
"str s11, [%[doutc2r0]], #4 \n" /* store c2r0*/
"str s14, [%[doutc1r0]], #4 \n" /* store c1r0*/
"str s15, [%[doutc3r0]], #4 \n" /* store c3r0*/
"subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/
"str s20, [%[doutc4r0]], #4 \n" /* store c0r0*/
"str s21, [%[doutc6r0]], #4 \n" /* store c2r0*/
"str s22, [%[doutc5r0]], #4 \n" /* store c1r0*/
"str s23, [%[doutc7r0]], #4 \n" /* store c3r0*/
"bne 1b \n" /* jump to main loop*/
: [doutc0r0] "+r"(doutc0_ptr),
[doutc1r0] "+r"(doutc1_ptr),
[doutc2r0] "+r"(doutc2_ptr),
[doutc3r0] "+r"(doutc3_ptr),
[doutc4r0] "+r"(doutc4_ptr),
[doutc5r0] "+r"(doutc5_ptr),
[doutc6r0] "+r"(doutc6_ptr),
[doutc7r0] "+r"(doutc7_ptr),
[cnt] "+r"(cnt_loop),
[ptr_din] "+r"(din_hei_ptr)
:
[scale0] "w"(w_scale0), [scale1] "w"(w_scale1), [vzero] "w"(vzero)
: "v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"v23");
#else
asm volatile(
"vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n"
"vld1.32 {d12-d15}, [%[ptr_din]]! @load data \n"
"1: @ main loop\n"
"vmax.s32 q4, q4, %q[vzero] @ relu\n"
"vmax.s32 q5, q5, %q[vzero] @ relu\n"
"vmax.s32 q6, q6, %q[vzero] @ relu\n"
"vmax.s32 q7, q7, %q[vzero] @ relu\n"
// int32-> fp32
"vmov.f32 q15, #0.5 \n"
"vcvt.f32.s32 q8, q4 \n"
"vcvt.f32.s32 q9, q5 \n"
"vcvt.f32.s32 q10, q6 \n"
"vcvt.f32.s32 q11, q7 \n"
"vand.i32 q4, q15, q15 @ set offset, 0.5\n"
"vand.i32 q5, q15, q15 @ set offset, 0.5\n"
"vand.i32 q6, q15, q15 @ set offset, 0.5\n"
"vand.i32 q7, q15, q15 @ set offset, 0.5\n"
"vmov.f32 q15, #-0.5 \n"
"vcgt.f32 q12, q8, %q[vzero] @ get mask > 0, in0\n"
"vcgt.f32 q13, q9, %q[vzero] @ get mask > 0, in0\n"
"vcgt.f32 q14, q10, %q[vzero] @ get mask > 0, in0\n"
"vcgt.f32 q3, q11, %q[vzero] @ get mask > 0, in0\n"
"vbif.f32 q4, q15, q12 @ get right offset\n"
"vbif.f32 q5, q15, q13 @ get right offset\n"
"vbif.f32 q6, q15, q14 @ get right offset\n"
"vbif.f32 q7, q15, q3 @ get right offset\n"
"vld1.32 {d24-d27}, [%[ptr_din]]! @load data \n"
"vld1.32 {d28-d29}, [%[ptr_din]]! @load data \n"
"vld1.32 {d6-d7}, [%[ptr_din]]! @load data \n"
"vmla.f32 q4, q8, %q[scale0] @ mul scale\n"
"vmla.f32 q5, q9, %q[scale1] @ mul scale\n"
"vmla.f32 q6, q10, %q[scale0] @ mul scale\n"
"vmla.f32 q7, q11, %q[scale1] @ mul scale\n"
"vmax.s32 q12, q12, %q[vzero] @ relu\n"
"vmax.s32 q13, q13, %q[vzero] @ relu\n"
"vmax.s32 q14, q14, %q[vzero] @ relu\n"
"vmax.s32 q3, q3, %q[vzero] @ relu\n"
"vcvt.s32.f32 q8, q4 @ cvt to int32\n"
"vcvt.s32.f32 q9, q5 @ cvt to int32\n"
"vcvt.s32.f32 q10, q6 @ cvt to int32\n"
"vcvt.s32.f32 q11, q7 @ cvt to int32\n"
"vqmovn.s32 d8, q8 @ cnt to int16\n"
"vqmovn.s32 d10, q9 @ cnt to int16\n"
"vqmovn.s32 d12, q10 @ cnt to int16\n"
"vqmovn.s32 d14, q11 @ cnt to int16\n"
"vqmovn.s16 d16, q4 @ cnt to int8\n"
"vqmovn.s16 d17, q5 @ cnt to int8\n"
"vqmovn.s16 d18, q6 @ cnt to int8\n"
"vqmovn.s16 d19, q7 @ cnt to int8\n"
"vmov.f32 q15, #0.5 \n"
"vcvt.f32.s32 q4, q12 \n"
"vcvt.f32.s32 q5, q13 \n"
"vcvt.f32.s32 q6, q14 \n"
"vcvt.f32.s32 q7, q3 \n"
"vand.i32 q12, q15, q15 @ set offset, 0.5\n"
"vand.i32 q13, q15, q15 @ set offset, 0.5\n"
"vand.i32 q14, q15, q15 @ set offset, 0.5\n"
"vand.i32 q3, q15, q15 @ set offset, 0.5\n"
"vmov.f32 q15, #-0.5 \n"
"vcgt.f32 q10, q4, %q[vzero] @ get mask > 0, in0\n"
"vcgt.f32 q11, q5, %q[vzero] @ get mask > 0, in0\n"
"vbif.f32 q12, q15, q10 @ get right offset\n"
"vbif.f32 q13, q15, q11 @ get right offset\n"
"vcgt.f32 q10, q6, %q[vzero] @ get mask > 0, in0\n"
"vcgt.f32 q11, q7, %q[vzero] @ get mask > 0, in0\n"
"vbif.f32 q14, q15, q10 @ get right offset\n"
"vbif.f32 q3, q15, q11 @ get right offset\n"
"vmla.f32 q12, q4, %q[scale0] @ mul scale\n"
"vmla.f32 q13, q5, %q[scale1] @ mul scale\n"
"vmla.f32 q14, q6, %q[scale0] @ mul scale\n"
"vmla.f32 q3, q7, %q[scale1] @ mul scale\n"
"vcvt.s32.f32 q4, q12 @ cvt to int32\n"
"vcvt.s32.f32 q5, q13 @ cvt to int32\n"
"vcvt.s32.f32 q6, q14 @ cvt to int32\n"
"vcvt.s32.f32 q7, q3 @ cvt to int32\n"
"vqmovn.s32 d24, q4 @ cnt to int16\n"
"vqmovn.s32 d26, q5 @ cnt to int16\n"
"vqmovn.s32 d28, q6 @ cnt to int16\n"
"vqmovn.s32 d6, q7 @ cnt to int16\n"
"vqmovn.s16 d20, q12 @ cnt to int8\n"
"vqmovn.s16 d21, q13 @ cnt to int8\n"
"vqmovn.s16 d22, q14 @ cnt to int8\n"
"vqmovn.s16 d23, q3 @ cnt to int8\n"
"vtrn.8 d16, d18 @ trans q0, q2 \n"
"vtrn.8 d20, d22 @ trans q4, q6 \n"
"vtrn.16 d16, d20 @ trans q0, q2 \n"
"vtrn.16 d18, d22 @ trans q4, q6 \n"
"vtrn.8 d17, d19 @ trans q0, q2 \n"
"vtrn.8 d21, d23 @ trans q4, q6 \n"
"vtrn.16 d17, d21 @ trans q0, q2 \n"
"vtrn.16 d19, d23 @ trans q4, q6 \n"
"vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n"
"vld1.32 {d12-d15}, [%[ptr_din]]! @load data \n"
"vst1.32 {d16[0]}, [%[doutc0r0]] @ store result, add "
"pointer\n"
"vst1.32 {d18[0]}, [%[doutc1r0]] @ store result, add "
"pointer\n"
"vst1.32 {d20[0]}, [%[doutc2r0]] @ store result, add "
"pointer\n"
"vst1.32 {d22[0]}, [%[doutc3r0]] @ store result, add "
"pointer\n"
"vst1.32 {d17[0]}, [%[doutc4r0]] @ store result, add "
"pointer\n"
"vst1.32 {d19[0]}, [%[doutc5r0]] @ store result, add "
"pointer\n"
"vst1.32 {d21[0]}, [%[doutc6r0]] @ store result, add "
"pointer\n"
"vst1.32 {d23[0]}, [%[doutc7r0]] @ store result, add "
"pointer\n"
"add %[doutc0r0], #4 @ add \n"
"add %[doutc1r0], #4 @ add \n"
"add %[doutc2r0], #4 @ add \n"
"add %[doutc3r0], #4 @ add \n"
"subs %[cnt], %[cnt], #1 @ loop count - 1\n"
"add %[doutc4r0], #4 @ add \n"
"add %[doutc5r0], #4 @ add \n"
"add %[doutc6r0], #4 @ add \n"
"add %[doutc7r0], #4 @ add \n"
"bne 1b @ jump to main loop\n"
: [doutc0r0] "+r"(doutc0_ptr),
[doutc1r0] "+r"(doutc1_ptr),
[doutc2r0] "+r"(doutc2_ptr),
[doutc3r0] "+r"(doutc3_ptr),
[doutc4r0] "+r"(doutc4_ptr),
[doutc5r0] "+r"(doutc5_ptr),
[doutc6r0] "+r"(doutc6_ptr),
[doutc7r0] "+r"(doutc7_ptr),
[ptr_din] "+r"(din_hei_ptr),
[cnt] "+r"(cnt_loop)
:
[scale0] "w"(w_scale0), [scale1] "w"(w_scale1), [vzero] "w"(vzero)
: "q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
#endif
}
if (we > width) {
int offset = 32 * (valid_w / 4 - 1);
din_hei_ptr = ptr_din + offset;
int i = we - 4;
for (; i < width; ++i) {
*(doutc0_ptr++) = saturate_cast<signed char>(
roundf(LITEMAX(din_hei_ptr[0] * scale[0], 0)));
*(doutc1_ptr++) = saturate_cast<signed char>(
roundf(LITEMAX(din_hei_ptr[1] * scale[1], 0)));
*(doutc2_ptr++) = saturate_cast<signed char>(
roundf(LITEMAX(din_hei_ptr[2] * scale[2], 0)));
*(doutc3_ptr++) = saturate_cast<signed char>(
roundf(LITEMAX(din_hei_ptr[3] * scale[3], 0)));
*(doutc4_ptr++) = saturate_cast<signed char>(
roundf(LITEMAX(din_hei_ptr[4] * scale[4], 0)));
*(doutc5_ptr++) = saturate_cast<signed char>(
roundf(LITEMAX(din_hei_ptr[5] * scale[5], 0)));
*(doutc6_ptr++) = saturate_cast<signed char>(
roundf(LITEMAX(din_hei_ptr[6] * scale[6], 0)));
*(doutc7_ptr++) = saturate_cast<signed char>(
roundf(LITEMAX(din_hei_ptr[7] * scale[7], 0)));
din_hei_ptr += 8;
}
}
}
} else {
for (int i = 0; i < size_h; i++) {
int size_w = i * width;
dtype* doutc0_ptr = doutc0r0 + size_w; // doutc0r0 + width;
dtype* doutc1_ptr = doutc1r0 + size_w;
dtype* doutc2_ptr = doutc2r0 + size_w;
dtype* doutc3_ptr = doutc3r0 + size_w;
dtype* doutc4_ptr = doutc4r0 + size_w;
dtype* doutc5_ptr = doutc5r0 + size_w;
dtype* doutc6_ptr = doutc6r0 + size_w;
dtype* doutc7_ptr = doutc7r0 + size_w;
if (ce > channel) {
switch (ce - channel) {
case 7:
doutc1_ptr = trash_ptr;
case 6:
doutc2_ptr = trash_ptr;
case 5:
doutc3_ptr = trash_ptr;
case 4:
doutc4_ptr = trash_ptr;
case 3:
doutc5_ptr = trash_ptr;
case 2:
doutc6_ptr = trash_ptr;
case 1:
doutc7_ptr = trash_ptr;
default:
break;
}
}
ptr_din = din + i * valid_w * ch_n;
const int* din_hei_ptr = ptr_din;
if (cnt > 0) {
int cnt_loop = cnt;
#ifdef __aarch64__
asm volatile(
"ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */
"ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */
"ldp q4, q5, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */
"ldp q6, q7, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */
// "movi v20.4s, #0 \n" /* for relu */
"1: \n" /* main loop*/
"trn1 v8.4s, v0.4s, v2.4s \n" /* trans q0, q1*/
"trn2 v9.4s, v0.4s, v2.4s \n" /* trans q0, q1*/
"trn1 v10.4s, v1.4s, v3.4s \n" /* trans q2, q3*/
"trn2 v11.4s, v1.4s, v3.4s \n" /* trans q2, q3*/
"ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */
"trn1 v12.4s, v4.4s, v6.4s \n" /* trans q0, q1*/
"trn2 v13.4s, v4.4s, v6.4s \n" /* trans q0, q1*/
"trn1 v14.4s, v5.4s, v7.4s \n" /* trans q2, q3*/
"trn2 v15.4s, v5.4s, v7.4s \n" /* trans q2, q3*/
"ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */
"trn1 v16.2d, v8.2d, v12.2d \n" /* trans q8, q10 00 01 02 03*/
"trn2 v17.2d, v8.2d, v12.2d \n" /* trans q8, q10 20 21 22 23*/
"trn1 v18.2d, v9.2d, v13.2d \n" /* trans q9, q11 10 11 12 13*/
"trn2 v19.2d, v9.2d, v13.2d \n" /* trans q9, q11 30 31 32 33*/
"ldp q4, q5, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */
"trn1 v8.2d, v10.2d, v14.2d \n" /* trans q8, q10 40 41 42 43*/
"trn2 v9.2d, v10.2d, v14.2d \n" /* trans q8, q10 60 61 62 63*/
"trn1 v12.2d, v11.2d, v15.2d \n" /* trans q9, q11 50 51 52 53*/
"trn2 v13.2d, v11.2d, v15.2d \n" /* trans q9, q11 70 71 72 73*/
"ldp q6, q7, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */
// int32 --> fp32
"scvtf v10.4s, v16.4s \n"
"scvtf v11.4s, v17.4s \n"
"scvtf v14.4s, v18.4s \n"
"scvtf v15.4s, v19.4s \n"
"scvtf v20.4s, v8.4s \n"
"scvtf v21.4s, v9.4s \n"
"scvtf v22.4s, v12.4s \n"
"scvtf v23.4s, v13.4s \n"
// mul
"fmul v16.4s, v10.4s, %[scale0].s[0] \n"
"fmul v17.4s, v11.4s, %[scale0].s[2] \n"
"fmul v18.4s, v14.4s, %[scale0].s[1] \n"
"fmul v19.4s, v15.4s, %[scale0].s[3] \n"
"fmul v8.4s, v20.4s, %[scale1].s[0] \n"
"fmul v9.4s, v21.4s, %[scale1].s[2] \n"
"fmul v12.4s, v22.4s, %[scale1].s[1] \n"
"fmul v13.4s, v23.4s, %[scale1].s[3] \n"
// fp32-int32
"fcvtas v10.4s, v16.4s \n"
"fcvtas v11.4s, v17.4s \n"
"fcvtas v14.4s, v18.4s \n"
"fcvtas v15.4s, v19.4s \n"
"fcvtas v20.4s, v8.4s \n"
"fcvtas v21.4s, v9.4s \n"
"fcvtas v22.4s, v12.4s \n"
"fcvtas v23.4s, v13.4s \n"
// int32-int16
"sqxtn v16.4h, v10.4s \n"
"sqxtn v17.4h, v11.4s \n"
"sqxtn v18.4h, v14.4s \n"
"sqxtn v19.4h, v15.4s \n"
"sqxtn v8.4h, v20.4s \n"
"sqxtn v9.4h, v21.4s \n"
"sqxtn v12.4h, v22.4s \n"
"sqxtn v13.4h, v23.4s \n"
// int16-int8
"sqxtn v10.8b, v16.8h \n"
"sqxtn v11.8b, v17.8h \n"
"sqxtn v14.8b, v18.8h \n"
"sqxtn v15.8b, v19.8h \n"
"sqxtn v20.8b, v8.8h \n"
"sqxtn v21.8b, v9.8h \n"
"sqxtn v22.8b, v12.8h \n"
"sqxtn v23.8b, v13.8h \n"
"str s10, [%[doutc0r0]], #4 \n" /* store c0r0*/
"str s11, [%[doutc2r0]], #4 \n" /* store c2r0*/
"str s14, [%[doutc1r0]], #4 \n" /* store c1r0*/
"str s15, [%[doutc3r0]], #4 \n" /* store c3r0*/
"subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/
"str s20, [%[doutc4r0]], #4 \n" /* store c0r0*/
"str s21, [%[doutc6r0]], #4 \n" /* store c2r0*/
"str s22, [%[doutc5r0]], #4 \n" /* store c1r0*/
"str s23, [%[doutc7r0]], #4 \n" /* store c3r0*/
"bne 1b \n" /* jump to main loop*/
: [doutc0r0] "+r"(doutc0_ptr),
[doutc1r0] "+r"(doutc1_ptr),
[doutc2r0] "+r"(doutc2_ptr),
[doutc3r0] "+r"(doutc3_ptr),
[doutc4r0] "+r"(doutc4_ptr),
[doutc5r0] "+r"(doutc5_ptr),
[doutc6r0] "+r"(doutc6_ptr),
[doutc7r0] "+r"(doutc7_ptr),
[cnt] "+r"(cnt_loop),
[ptr_din] "+r"(din_hei_ptr)
: [scale0] "w"(w_scale0), [scale1] "w"(w_scale1)
: "v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"v23");
#else
asm volatile(
"vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n"
"vld1.32 {d12-d15}, [%[ptr_din]]! @load data \n"
"1: @ main loop\n"
// int32-> fp32
"vmov.f32 q15, #0.5 \n"
"vcvt.f32.s32 q8, q4 \n"
"vcvt.f32.s32 q9, q5 \n"
"vcvt.f32.s32 q10, q6 \n"
"vcvt.f32.s32 q11, q7 \n"
"vand.i32 q4, q15, q15 @ set offset, 0.5\n"
"vand.i32 q5, q4, q4 @ set offset, 0.5\n"
"vand.i32 q6, q4, q4 @ set offset, 0.5\n"
"vand.i32 q7, q4, q4 @ set offset, 0.5\n"
"vmov.f32 q15, #-0.5 \n"
"vcgt.f32 q12, q8, %q[vzero] @ get mask > 0, in0\n"
"vcgt.f32 q13, q9, %q[vzero] @ get mask > 0, in0\n"
"vcgt.f32 q14, q10, %q[vzero] @ get mask > 0, in0\n"
"vcgt.f32 q3, q11, %q[vzero] @ get mask > 0, in0\n"
"vbif.f32 q4, q15, q12 @ get right offset\n"
"vbif.f32 q5, q15, q13 @ get right offset\n"
"vbif.f32 q6, q15, q14 @ get right offset\n"
"vbif.f32 q7, q15, q3 @ get right offset\n"
"vld1.32 {d24-d27}, [%[ptr_din]]! @load data \n"
"vld1.32 {d28-d29}, [%[ptr_din]]! @load data \n"
"vld1.32 {d6-d7}, [%[ptr_din]]! @load data \n"
"vmla.f32 q4, q8, %q[scale0] @ mul scale\n"
"vmla.f32 q5, q9, %q[scale1] @ mul scale\n"
"vmla.f32 q6, q10, %q[scale0] @ mul scale\n"
"vmla.f32 q7, q11, %q[scale1] @ mul scale\n"
"vcvt.s32.f32 q8, q4 @ cvt to int32\n"
"vcvt.s32.f32 q9, q5 @ cvt to int32\n"
"vcvt.s32.f32 q10, q6 @ cvt to int32\n"
"vcvt.s32.f32 q11, q7 @ cvt to int32\n"
"vqmovn.s32 d8, q8 @ cnt to int16\n"
"vqmovn.s32 d10, q9 @ cnt to int16\n"
"vqmovn.s32 d12, q10 @ cnt to int16\n"
"vqmovn.s32 d14, q11 @ cnt to int16\n"
"vqmovn.s16 d16, q4 @ cnt to int8\n"
"vqmovn.s16 d17, q5 @ cnt to int8\n"
"vqmovn.s16 d18, q6 @ cnt to int8\n"
"vqmovn.s16 d19, q7 @ cnt to int8\n"
"vmov.f32 q15, #0.5 \n"
"vcvt.f32.s32 q4, q12 \n"
"vcvt.f32.s32 q5, q13 \n"
"vcvt.f32.s32 q6, q14 \n"
"vcvt.f32.s32 q7, q3 \n"
"vand.i32 q12, q15, q15 @ set offset, 0.5\n"
"vand.i32 q13, q12, q12 @ set offset, 0.5\n"
"vand.i32 q14, q12, q12 @ set offset, 0.5\n"
"vand.i32 q3, q12, q12 @ set offset, 0.5\n"
"vmov.f32 q15, #-0.5 \n"
"vcgt.f32 q10, q4, %q[vzero] @ get mask > 0, in0\n"
"vcgt.f32 q11, q5, %q[vzero] @ get mask > 0, in0\n"
"vbif.f32 q12, q15, q10 @ get right offset\n"
"vbif.f32 q13, q15, q11 @ get right offset\n"
"vcgt.f32 q10, q6, %q[vzero] @ get mask > 0, in0\n"
"vcgt.f32 q11, q7, %q[vzero] @ get mask > 0, in0\n"
"vbif.f32 q14, q15, q10 @ get right offset\n"
"vbif.f32 q3, q15, q11 @ get right offset\n"
"vmla.f32 q12, q4, %q[scale0] @ mul scale\n"
"vmla.f32 q13, q5, %q[scale1] @ mul scale\n"
"vmla.f32 q14, q6, %q[scale0] @ mul scale\n"
"vmla.f32 q3, q7, %q[scale1] @ mul scale\n"
"vcvt.s32.f32 q4, q12 @ cvt to int32\n"
"vcvt.s32.f32 q5, q13 @ cvt to int32\n"
"vcvt.s32.f32 q6, q14 @ cvt to int32\n"
"vcvt.s32.f32 q7, q3 @ cvt to int32\n"
"vqmovn.s32 d24, q4 @ cnt to int16\n"
"vqmovn.s32 d26, q5 @ cnt to int16\n"
"vqmovn.s32 d28, q6 @ cnt to int16\n"
"vqmovn.s32 d6, q7 @ cnt to int16\n"
"vqmovn.s16 d20, q12 @ cnt to int8\n"
"vqmovn.s16 d21, q13 @ cnt to int8\n"
"vqmovn.s16 d22, q14 @ cnt to int8\n"
"vqmovn.s16 d23, q3 @ cnt to int8\n"
"vtrn.8 d16, d18 @ trans q0, q2 \n"
"vtrn.8 d20, d22 @ trans q4, q6 \n"
"vtrn.16 d16, d20 @ trans q0, q2 \n"
"vtrn.16 d18, d22 @ trans q4, q6 \n"
"vtrn.8 d17, d19 @ trans q0, q2 \n"
"vtrn.8 d21, d23 @ trans q4, q6 \n"
"vtrn.16 d17, d21 @ trans q0, q2 \n"
"vtrn.16 d19, d23 @ trans q4, q6 \n"
"vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n"
"vld1.32 {d12-d15}, [%[ptr_din]]! @load data \n"
"vst1.32 {d16[0]}, [%[doutc0r0]] @ store result, add "
"pointer\n"
"vst1.32 {d18[0]}, [%[doutc1r0]] @ store result, add "
"pointer\n"
"vst1.32 {d20[0]}, [%[doutc2r0]] @ store result, add "
"pointer\n"
"vst1.32 {d22[0]}, [%[doutc3r0]] @ store result, add "
"pointer\n"
"vst1.32 {d17[0]}, [%[doutc4r0]] @ store result, add "
"pointer\n"
"vst1.32 {d19[0]}, [%[doutc5r0]] @ store result, add "
"pointer\n"
"vst1.32 {d21[0]}, [%[doutc6r0]] @ store result, add "
"pointer\n"
"vst1.32 {d23[0]}, [%[doutc7r0]] @ store result, add "
"pointer\n"
"add %[doutc0r0], #4 @ add \n"
"add %[doutc1r0], #4 @ add \n"
"add %[doutc2r0], #4 @ add \n"
"add %[doutc3r0], #4 @ add \n"
"subs %[cnt], %[cnt], #1 @ loop count - 1\n"
"add %[doutc4r0], #4 @ add \n"
"add %[doutc5r0], #4 @ add \n"
"add %[doutc6r0], #4 @ add \n"
"add %[doutc7r0], #4 @ add \n"
"bne 1b @ jump to main loop\n"
: [doutc0r0] "+r"(doutc0_ptr),
[doutc1r0] "+r"(doutc1_ptr),
[doutc2r0] "+r"(doutc2_ptr),
[doutc3r0] "+r"(doutc3_ptr),
[doutc4r0] "+r"(doutc4_ptr),
[doutc5r0] "+r"(doutc5_ptr),
[doutc6r0] "+r"(doutc6_ptr),
[doutc7r0] "+r"(doutc7_ptr),
[ptr_din] "+r"(din_hei_ptr),
[cnt] "+r"(cnt_loop)
:
[scale0] "w"(w_scale0), [scale1] "w"(w_scale1), [vzero] "w"(vzero)
: "q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
#endif
}
if (we > width) {
int offset = 32 * (valid_w / 4 - 1);
din_hei_ptr = ptr_din + offset;
int i = we - 4;
for (; i < width; ++i) {
*(doutc0_ptr++) =
saturate_cast<signed char>(roundf(din_hei_ptr[0] * scale[0]));
*(doutc1_ptr++) =
saturate_cast<signed char>(roundf(din_hei_ptr[1] * scale[1]));
*(doutc2_ptr++) =
saturate_cast<signed char>(roundf(din_hei_ptr[2] * scale[2]));
*(doutc3_ptr++) =
saturate_cast<signed char>(roundf(din_hei_ptr[3] * scale[3]));
*(doutc4_ptr++) =
saturate_cast<signed char>(roundf(din_hei_ptr[4] * scale[4]));
*(doutc5_ptr++) =
saturate_cast<signed char>(roundf(din_hei_ptr[5] * scale[5]));
*(doutc6_ptr++) =
saturate_cast<signed char>(roundf(din_hei_ptr[6] * scale[6]));
*(doutc7_ptr++) =
saturate_cast<signed char>(roundf(din_hei_ptr[7] * scale[7]));
din_hei_ptr += 8;
}
}
}
}
} else {
LOG(ERROR) << "ERROR: unsupported input data type!!";
return false;
}
return true;
}
/* write result in outputs
* input din: [n, c / 8, h, w * 8], output dout: [n, c, h, w]
*/
template <typename Dtype>
inline void write_int32_nchwc8_to_nchw(const int* din,
Dtype* dout,
int cs,
int ce,
int hs,
int he,
int ws,
int we,
int channel,
int height,
int width,
bool flag_relu,
float* bias,
bool flag_bias,
Dtype* trash_ptr,
const float* scale) {
int size_c_out = width * height;
Dtype* doutc0r0 = dout + cs * size_c_out + hs * width + ws;
Dtype* doutc1r0 = doutc0r0 + size_c_out;
Dtype* doutc2r0 = doutc1r0 + size_c_out;
Dtype* doutc3r0 = doutc2r0 + size_c_out;
Dtype* doutc4r0 = doutc3r0 + size_c_out;
Dtype* doutc5r0 = doutc4r0 + size_c_out;
Dtype* doutc6r0 = doutc5r0 + size_c_out;
Dtype* doutc7r0 = doutc6r0 + size_c_out;
const int* ptr_din = din;
int size_h = (he > height ? height : he) - hs;  // size_h == hei_n
int w_stride = we - ws;
int valid_w = (we > width ? width : we) - ws;
int cnt = valid_w / 4;
float32x4_t w_scale0 = vld1q_f32(scale);
float32x4_t w_scale1 = vld1q_f32(scale + 4);
float32x4_t w_bias0 = flag_bias ? vld1q_f32(bias) : vdupq_n_f32(0.f);
float32x4_t w_bias1 = flag_bias ? vld1q_f32(bias + 4) : vdupq_n_f32(0.f);
for (int i = 0; i < size_h; i++) {
int size_w = i * width;
Dtype* doutc0_ptr = doutc0r0 + size_w;  // doutc0r0 + width;
Dtype* doutc1_ptr = doutc1r0 + size_w;
Dtype* doutc2_ptr = doutc2r0 + size_w;
Dtype* doutc3_ptr = doutc3r0 + size_w;
Dtype* doutc4_ptr = doutc4r0 + size_w;
Dtype* doutc5_ptr = doutc5r0 + size_w;
Dtype* doutc6_ptr = doutc6r0 + size_w;
Dtype* doutc7_ptr = doutc7r0 + size_w;
if (ce > channel) {
switch (ce - channel) {
case 7:
doutc1_ptr = trash_ptr;
case 6:
doutc2_ptr = trash_ptr;
case 5:
doutc3_ptr = trash_ptr;
case 4:
doutc4_ptr = trash_ptr;
case 3:
doutc5_ptr = trash_ptr;
case 2:
doutc6_ptr = trash_ptr;
case 1:
doutc7_ptr = trash_ptr;
default:
break;
}
}
ptr_din = din + i * w_stride * 8;
const int* din_hei_ptr = ptr_din;
if (cnt > 0) {
int32_nchwc8_kernel(doutc0_ptr,
doutc1_ptr,
doutc2_ptr,
doutc3_ptr,
doutc4_ptr,
doutc5_ptr,
doutc6_ptr,
doutc7_ptr,
din_hei_ptr,
cnt,
w_scale0,
w_scale1,
w_bias0,
w_bias1,
flag_relu);
}
if (we > width) {
int offset = 32 * cnt;
din_hei_ptr = ptr_din + offset;
for (int j = ws + cnt * 4; j < width; ++j) {
if (flag_bias) {
*(doutc0_ptr++) =
cvt_kernel<Dtype>(din_hei_ptr[0], scale[0], bias[0], flag_relu);
*(doutc1_ptr++) =
cvt_kernel<Dtype>(din_hei_ptr[1], scale[1], bias[1], flag_relu);
*(doutc2_ptr++) =
cvt_kernel<Dtype>(din_hei_ptr[2], scale[2], bias[2], flag_relu);
*(doutc3_ptr++) =
cvt_kernel<Dtype>(din_hei_ptr[3], scale[3], bias[3], flag_relu);
*(doutc4_ptr++) =
cvt_kernel<Dtype>(din_hei_ptr[4], scale[4], bias[4], flag_relu);
*(doutc5_ptr++) =
cvt_kernel<Dtype>(din_hei_ptr[5], scale[5], bias[5], flag_relu);
*(doutc6_ptr++) =
cvt_kernel<Dtype>(din_hei_ptr[6], scale[6], bias[6], flag_relu);
*(doutc7_ptr++) =
cvt_kernel<Dtype>(din_hei_ptr[7], scale[7], bias[7], flag_relu);
} else {
*(doutc0_ptr++) =
cvt_kernel<Dtype>(din_hei_ptr[0], scale[0], 0.f, flag_relu);
*(doutc1_ptr++) =
cvt_kernel<Dtype>(din_hei_ptr[1], scale[1], 0.f, flag_relu);
*(doutc2_ptr++) =
cvt_kernel<Dtype>(din_hei_ptr[2], scale[2], 0.f, flag_relu);
*(doutc3_ptr++) =
cvt_kernel<Dtype>(din_hei_ptr[3], scale[3], 0.f, flag_relu);
*(doutc4_ptr++) =
cvt_kernel<Dtype>(din_hei_ptr[4], scale[4], 0.f, flag_relu);
*(doutc5_ptr++) =
cvt_kernel<Dtype>(din_hei_ptr[5], scale[5], 0.f, flag_relu);
*(doutc6_ptr++) =
cvt_kernel<Dtype>(din_hei_ptr[6], scale[6], 0.f, flag_relu);
*(doutc7_ptr++) =
cvt_kernel<Dtype>(din_hei_ptr[7], scale[7], 0.f, flag_relu);
}
din_hei_ptr += 8;
}
}
}
}
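// Both helpers used above are introduced elsewhere in this patch:
// int32_nchwc8_kernel covers the vectorized 8-channel body, and
// cvt_kernel<Dtype> converts one value in the scalar tail. From the asm
// paths above, a plausible form of the latter (an assumed sketch, not
// the verbatim helper):
//   template <typename Dtype>  // Dtype is float or int8_t
//   Dtype cvt_kernel(int din, float scale, float bias, bool flag_relu);
//   // float:  out = din * scale + bias; if (flag_relu) out = fmaxf(out, 0.f);
//   // int8_t: out = saturate_cast<int8_t>(roundf(din * scale + bias)),
//   //         with the relu applied before rounding.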
/*
......
...@@ -16,20 +16,16 @@
#include <cmath>
#include <vector>
#include "lite/backends/arm/math/conv_impl.h"
#include "lite/core/context.h"
#include "lite/core/target_wrapper.h"
#include "lite/operators/op_params.h"
namespace paddle {
namespace lite {
namespace arm {
namespace math {
...@@ -37,62 +33,125 @@ class DepthwiseConv
template <PrecisionType Ptype>
class DepthwiseConv
: public ImplBase<TARGET(kARM), Ptype, operators::ConvParam> {
public:
typedef void (*conv_dw_impl)(const float* i_data,
float* o_data,
int bs,
int oc,
int oh,
int ow,
int ic,
int ih,
int kw,
const float* w_data,
const float* b_data,
const operators::ConvParam& param,
Context<TARGET(kARM)>* ctx);
DepthwiseConv() = default;
~DepthwiseConv() {}
virtual bool init(const operators::ConvParam& param,
Context<TARGET(kARM)>* ctx);
virtual bool create(const operators::ConvParam& param,
Context<TARGET(kARM)>* ctx);
virtual bool run(const operators::ConvParam& param);
private:
conv_dw_impl impl_{nullptr};
};
template <PrecisionType Ptype_out>
class DepthwiseConvInt8
: public ImplBase<TARGET(kARM), PRECISION(kInt8), operators::ConvParam> {
public:
typedef void (*conv_dw_int8_impl)(const int8_t* i_data,
int32_t* o_data,
int bs,
int oc,
int oh,
int ow,
int ic,
int ih,
int kw,
const int8_t* w_data,
const int32_t* b_data,
const operators::ConvParam& param,
Context<TARGET(kARM)>* ctx,
PrecisionType out_type,
const float* scale);
DepthwiseConvInt8() = default;
~DepthwiseConvInt8() {}
virtual bool init(const operators::ConvParam& param,
Context<TARGET(kARM)>* ctx);
virtual bool create(const operators::ConvParam& param,
Context<TARGET(kARM)>* ctx);
virtual bool run(const operators::ConvParam& param);
private:
conv_dw_int8_impl impl_{nullptr};
std::vector<float> w_scale_;
Tensor tmp_int32_out_;
};
void conv_3x3s1_depthwise_fp32(const float* i_data,
float* o_data,
int bs,
int oc,
int oh,
int ow,
int ic,
int ih,
int win,
const float* weights,
const float* bias,
const operators::ConvParam& param,
ARMContext* ctx);
void conv_3x3s2_depthwise_fp32(const float* i_data,
float* o_data,
int bs,
int oc,
int oh,
int ow,
int ic,
int ih,
int win,
const float* weights,
const float* bias,
const operators::ConvParam& param,
ARMContext* ctx);
void conv_depthwise_3x3p0_fp32(const float* din,
float* dout,
int num,
int ch_out,
int h_out,
int w_out,
int ch_in,
int h_in,
int w_in,
const float* weights,
const float* bias,
int stride,
bool flag_bias,
bool flag_relu,
ARMContext* ctx);
void conv_depthwise_3x3p1_fp32(const float* din,
float* dout,
int num,
int ch_out,
int h_out,
int w_out,
int ch_in,
int h_in,
int w_in,
const float* weights,
const float* bias,
int stride,
bool flag_bias,
bool flag_relu,
ARMContext* ctx);
template <typename Dtype>
void conv_depthwise_3x3s1_int8(Dtype* dout,
const int8_t* din,
const int8_t* weights,
const float* scale,
const float* bias,
bool flag_bias,
bool flag_relu,
int num,
int chin,
int hin,
int win,
int hout,
int wout,
int padw,
int padh,
ARMContext* ctx);
template <typename Dtype>
void conv_depthwise_3x3s2_int8(Dtype* dout,
const int8_t* din,
const int8_t* weights,
const float* scale,
const float* bias,
bool flag_bias,
bool flag_relu,
int num,
int chin,
int hin,
int win,
int hout,
int wout,
int padw,
int padh,
ARMContext* ctx);
void conv_depthwise_5x5s1_fp32(const float* din,
float* dout,
int num,
int chout,
int hout,
int wout,
int chin,
int hin,
int win,
const float* weights,
const float* bias,
int pad,
bool flag_bias,
bool flag_relu,
ARMContext* ctx);
void conv_depthwise_5x5s2_fp32(const float* din,
float* dout,
int num,
int chout,
int hout,
int wout,
int chin,
int hin,
int win,
const float* weights,
const float* bias,
int pad,
bool flag_bias,
bool flag_relu,
ARMContext* ctx);
} // namespace math } // namespace math
} // namespace arm } // namespace arm
......
This source diff is too large to display. You can view the blob instead.
...@@ -128,21 +128,21 @@ void conv_depthwise_3x3s2p0_bias_s_relu(float* dout, ...@@ -128,21 +128,21 @@ void conv_depthwise_3x3s2p0_bias_s_relu(float* dout,
const int w_out, const int w_out,
ARMContext* ctx); ARMContext* ctx);
void conv_depthwise_3x3p0(const float* din, void conv_depthwise_3x3p0_fp32(const float* din,
float* dout, float* dout,
int num, int num,
int ch_out, int ch_out,
int h_out, int h_out,
int w_out, int w_out,
int ch_in, int ch_in,
int h_in, int h_in,
int w_in, int w_in,
const float* weights, const float* weights,
const float* bias, const float* bias,
int stride, int stride,
bool flag_bias, bool flag_bias,
bool flag_relu, bool flag_relu,
ARMContext* ctx) { ARMContext* ctx) {
if (stride == 1) { if (stride == 1) {
if (flag_relu) { if (flag_relu) {
if (w_in > 5) { if (w_in > 5) {
......
...@@ -128,21 +128,21 @@ void conv_depthwise_3x3s2p1_bias_s_relu(float* dout, ...@@ -128,21 +128,21 @@ void conv_depthwise_3x3s2p1_bias_s_relu(float* dout,
const int w_out, const int w_out,
ARMContext* ctx); ARMContext* ctx);
void conv_depthwise_3x3p1(const float* din, void conv_depthwise_3x3p1_fp32(const float* din,
float* dout, float* dout,
int num, int num,
int ch_out, int ch_out,
int h_out, int h_out,
int w_out, int w_out,
int ch_in, int ch_in,
int h_in, int h_in,
int w_in, int w_in,
const float* weights, const float* weights,
const float* bias, const float* bias,
int stride, int stride,
bool flag_bias, bool flag_bias,
bool flag_relu, bool flag_relu,
ARMContext* ctx) { ARMContext* ctx) {
if (stride == 1) { if (stride == 1) {
if (flag_relu) { if (flag_relu) {
if (w_in > 4) { if (w_in > 4) {
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/backends/arm/math/conv_direct.h"
#include "lite/backends/arm/math/conv_block_utils.h"
#include "lite/backends/arm/math/conv_impl.h"
namespace paddle {
namespace lite {
namespace arm {
namespace math {
template <>
bool DirectConv<PRECISION(kFloat)>::create(const operators::ConvParam& param,
ARMContext* ctx) {
this->ctx_ = ctx;
auto x_dims = param.x->dims();
auto w_dims = param.filter->dims();
auto o_dims = param.output->dims();
int iw = x_dims[3]; // nchw
int ic = x_dims[1];
int ow = o_dims[3];
int oc = o_dims[1];
int kw = w_dims[3];
int sw = param.strides[1];
// select dw conv kernel
const auto* w_data = param.filter->data<float>();
if (kw == 3 && sw == 1) {
VLOG(5) << "invoke 3x3s1 direct conv";
impl_ = conv_3x3s1_direct_fp32;
constexpr int cblock = 4;
int cround = (oc + cblock - 1) / cblock * cblock;
weights_trans_.Resize({cround, ic, kw, kw});
float* transed_w_data = weights_trans_.mutable_data<float>();
conv_trans_weights_numc(w_data, transed_w_data, oc, ic, cblock, kw * kw);
is_weights_transed_ = true;
} else if (kw == 3 && sw == 2) {
VLOG(5) << "invoke 3x3s2 direct conv";
impl_ = conv_3x3s2_direct_fp32;
constexpr int cblock = 4;
int cround = (oc + cblock - 1) / cblock * cblock;
weights_trans_.Resize({cround, ic, kw, kw});
float* transed_w_data = weights_trans_.mutable_data<float>();
conv_trans_weights_numc(w_data, transed_w_data, oc, ic, cblock, kw * kw);
is_weights_transed_ = true;
} else {
LOG(ERROR) << "this type of direct conv is not implemented";
return false;
}
return true;
}
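The create() step above rounds the output-channel count up to a full multiple of the 4-channel block (cblock) before calling conv_trans_weights_numc, so the transposed weight buffer always holds whole NC4 blocks. A minimal standalone sketch of that round-up arithmetic (names are illustrative):

#include <cstdio>

// Round oc up to the next multiple of cblock, as weights_trans_ is sized above;
// the trailing (cround - oc) channels are zero padding.
int round_up_to_block(int oc, int cblock) {
  return (oc + cblock - 1) / cblock * cblock;
}

int main() {
  printf("%d\n", round_up_to_block(10, 4));  // prints 12: 10 channels -> 3 full NC4 blocks
  return 0;
}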
template <>
bool DirectConv<PRECISION(kFloat)>::init(const operators::ConvParam& param,
Context<TARGET(kARM)>* ctx) {
this->ctx_ = ctx;
return create(param, ctx);
}
template <>
bool DirectConv<PRECISION(kFloat)>::run(const operators::ConvParam& param) {
// start timer
const auto* i_data = param.x->data<float>();
const auto* w_data = param.filter->data<float>();
const auto* b_data = param.bias ? param.bias->data<float>() : nullptr;
auto* o_data = param.output->mutable_data<float>();
if (is_weights_transed_ == true) {
w_data = weights_trans_.data<float>();
}
auto x_dims = param.x->dims();
auto w_dims = param.filter->dims();
auto o_dims = param.output->dims();
int iw = x_dims[3]; // nchw
int ih = x_dims[2];
int ic = x_dims[1];
int bs = x_dims[0];
int oh = o_dims[2];
int ow = o_dims[3];
int oc = o_dims[1];
impl_(i_data,
o_data,
bs,
oc,
oh,
ow,
ic,
ih,
iw,
w_data,
b_data,
param,
this->ctx_);
// timer end
return true;
}
template <PrecisionType Ptype_out>
bool DirectConvInt8<Ptype_out>::create(const operators::ConvParam& param,
ARMContext* ctx) {
this->ctx_ = ctx;
auto x_dims = param.x->dims();
auto w_dims = param.filter->dims();
auto o_dims = param.output->dims();
int iw = x_dims[3]; // nchw
int ic = x_dims[1];
int ow = o_dims[3];
int oc = o_dims[1];
int kw = w_dims[3];
int sw = param.strides[1];
// select dw conv kernel
w_scale_ = param.weight_scale;
//! update weights scale
const auto* w_data = param.filter->data<int8_t>();
if (Ptype_out == PRECISION(kInt8) || Ptype_out == PRECISION(kFloat)) {
CHECK_EQ(this->w_scale_.size(), oc) << "weights scale size must be chout";
float input_scale = param.input_scale;
for (auto& w_s : w_scale_) {
w_s *= input_scale;
if (Ptype_out == PRECISION(kInt8)) {
w_s /= param.output_scale;
}
}
}
if (kw == 3 && sw == 1) {
VLOG(5) << "invoke 3x3s1 direct conv";
impl_int8_ = conv_3x3s1_direct_int8;
constexpr int cblock = 4;
int inpad = 4;
int cround = (oc + cblock - 1) / cblock * cblock;
weights_trans_.Resize({cround, ic, kw, kw});
int8_t* transed_w_data = weights_trans_.mutable_data<int8_t>();
conv_trans_weights_numc(w_data, transed_w_data, oc, ic, cblock, kw * kw);
int wout_round = ((ow + 3) / 4) * 4;
int win_round = wout_round * sw + inpad;
int row_out = 2;
int row_in = 4;
int tmp_size_out = wout_round * row_out * cblock;
int in_len = win_round * ic;
int tmp_size_in = row_in * in_len;
ctx_->ExtendWorkspace(ctx_->threads() * tmp_size_out +
(tmp_size_in + 3) / 4 * 4 + wout_round + win_round);
is_weights_transed_ = true;
} else if (kw == 3 && sw == 2) {
VLOG(5) << "invoke 3x3s2 direct conv";
impl_int8_ = conv_3x3s2_direct_int8;
// constexpr int cblock = 4;
int cblock = conv_3x3s2_direct_int8_c_num();
int cround = (oc + cblock - 1) / cblock * cblock;
weights_trans_.Resize({cround, ic, kw, kw});
int8_t* transed_w_data = weights_trans_.mutable_data<int8_t>();
conv_trans_weights_numc(w_data, transed_w_data, oc, ic, cblock, kw * kw);
is_weights_transed_ = true;
} else {
LOG(ERROR) << "this type of direct conv is not implemented";
return false;
}
return true;
}
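For the 3x3s1 int8 path, create() also pre-reserves scratch space: row_out output rows per thread in cblock-channel tiles, plus row_in pre-padded input rows shared across the channel loop. A worked instance of that sizing arithmetic, with illustrative shapes (the real values come from the tensors at hand):

#include <cstdio>

int main() {
  int ow = 56, ic = 32, sw = 1, threads = 4;
  const int cblock = 4, inpad = 4, row_out = 2, row_in = 4;
  int wout_round = ((ow + 3) / 4) * 4;               // 56: output width, 4-aligned
  int win_round = wout_round * sw + inpad;           // 60: input width incl. border
  int tmp_size_out = wout_round * row_out * cblock;  // 448 per-thread output tile
  int tmp_size_in = row_in * win_round * ic;         // 7680 shared padded input rows
  int total = threads * tmp_size_out + (tmp_size_in + 3) / 4 * 4
              + wout_round + win_round;
  printf("%d\n", total);  // 4*448 + 7680 + 56 + 60 = 9588 workspace units
  return 0;
}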
template <PrecisionType Ptype_out>
bool DirectConvInt8<Ptype_out>::init(const operators::ConvParam& param,
Context<TARGET(kARM)>* ctx) {
this->ctx_ = ctx;
return create(param, ctx);
}
template <PrecisionType Ptype_out>
bool DirectConvInt8<Ptype_out>::run(const operators::ConvParam& param) {
// start timer
const auto* i_data = param.x->data<int8_t>();
const auto* w_data = param.filter->data<int8_t>();
const auto* b_data = param.bias ? param.bias->data<int32_t>() : nullptr;
auto* o_data = param.output->mutable_data<int32_t>();
if (is_weights_transed_ == true) {
w_data = weights_trans_.data<int8_t>();
}
auto x_dims = param.x->dims();
auto w_dims = param.filter->dims();
auto o_dims = param.output->dims();
int iw = x_dims[3]; // nchw
int ih = x_dims[2];
int ic = x_dims[1];
int bs = x_dims[0];
int oh = o_dims[2];
int ow = o_dims[3];
int oc = o_dims[1];
impl_int8_(i_data,
o_data,
bs,
oc,
oh,
ow,
ic,
ih,
iw,
w_data,
b_data,
param,
this->ctx_,
Ptype_out,
w_scale_.data());
// the result above was written as int32 for debug convenience; re-declare the output as int8
if (Ptype_out == PRECISION(kInt8)) param.output->mutable_data<int8_t>();
return true;
}
template class DirectConvInt8<PRECISION(kInt8)>;
template class DirectConvInt8<PRECISION(kFloat)>;
template class DirectConvInt8<PRECISION(kInt32)>;
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cmath>
#include <vector>
#include "lite/backends/arm/math/conv_impl.h"
#include "lite/core/context.h"
#include "lite/core/target_wrapper.h"
namespace paddle {
namespace lite {
namespace arm {
namespace math {
template <PrecisionType Ptype>
class DirectConv : public ImplBase<TARGET(kARM), Ptype, operators::ConvParam> {
public:
typedef void (*conv_direct_impl)(const float* din,
float* dout,
int num,
int chout,
int hout,
int wout,
int chin,
int hin,
int win,
const float* weights,
const float* bias,
const operators::ConvParam& param,
Context<TARGET(kARM)>* ctx);
DirectConv() = default;
~DirectConv() {}
virtual bool init(const operators::ConvParam& param,
Context<TARGET(kARM)>* ctx);
virtual bool create(const operators::ConvParam& param,
Context<TARGET(kARM)>* ctx);
virtual bool run(const operators::ConvParam& param);
protected:
bool is_weights_transed_{false};
Tensor weights_trans_;
Tensor _tmp_out;
private:
conv_direct_impl impl_{nullptr};
};
template <PrecisionType Ptype_out>
class DirectConvInt8
: public ImplBase<TARGET(kARM), PRECISION(kInt8), operators::ConvParam> {
public:
typedef void (*conv_direct_int8_impl)(const int8_t* din,
int32_t* dout,
int num,
int chout,
int hout,
int wout,
int chin,
int hin,
int win,
const int8_t* weights,
const int32_t* bias,
const operators::ConvParam& param,
Context<TARGET(kARM)>* ctx,
PrecisionType out_type,
const float* scale);
DirectConvInt8() = default;
~DirectConvInt8() {}
virtual bool init(const operators::ConvParam& param,
Context<TARGET(kARM)>* ctx);
virtual bool create(const operators::ConvParam& param,
Context<TARGET(kARM)>* ctx);
virtual bool run(const operators::ConvParam& param);
private:
bool is_weights_transed_{false};
Tensor weights_trans_;
Tensor _tmp_out;
conv_direct_int8_impl impl_int8_{nullptr};
std::vector<float> w_scale_;
};
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
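Both classes in this header follow the same dispatch pattern: create() inspects the shapes once and caches a kernel function pointer in impl_, so run() reduces to a single indirect call. A stripped-down, self-contained sketch of the pattern with hypothetical kernel names:

#include <cstdio>

typedef void (*kernel_fn)(const float* in, float* out, int n);

void k_stride1(const float*, float*, int) { puts("s1 kernel"); }
void k_stride2(const float*, float*, int) { puts("s2 kernel"); }

struct Conv {
  kernel_fn impl_{nullptr};
  bool create(int stride) {            // select once, at setup time
    if (stride == 1) impl_ = k_stride1;
    else if (stride == 2) impl_ = k_stride2;
    else return false;                 // unsupported shape
    return true;
  }
  void run(const float* in, float* out, int n) { impl_(in, out, n); }
};

int main() {
  Conv c;
  if (c.create(2)) c.run(nullptr, nullptr, 0);  // prints "s2 kernel"
  return 0;
}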
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/backends/arm/math/conv_gemmlike.h"
#include <vector>
#include "lite/backends/arm/math/gemm_prepacked_int8.h"
#include "lite/backends/arm/math/packed_sgemm.h"
namespace paddle {
namespace lite {
namespace arm {
namespace math {
/********************* Gemmlike Conv Precision Is Float ***********************/
template <>
bool GemmLikeConv<PRECISION(kFloat)>::create(const operators::ConvParam& param,
ARMContext* ctx) {
this->ctx_ = ctx;
auto x_dims = param.x->dims();
auto w_dims = param.filter->dims();
auto o_dims = param.output->dims();
int iw = x_dims[3]; // nchw
int ih = x_dims[2];
int ic = x_dims[1];
int ow = o_dims[3];
int oh = o_dims[2];
int oc = o_dims[1];
int kw = w_dims[3];
int kh = w_dims[2];
int sw = param.strides[1];
int sh = param.strides[0];
int pw = param.paddings[1];
int ph = param.paddings[0];
int dw = param.dilations[1];
int dh = param.dilations[0];
int m = oc / param.groups;
int k = ic * kh * kw / param.groups;
int n = oh * ow;
bool kps_equal = (pw == ph) && (sw == sh) && (kw == kh);
bool ks_equal = (sw == sh) && (kw == kh);
//! select conv gemmlike kernel
if (kw == 1 && sw == 1 && pw == 0 && kps_equal) {
//! 1x1s1p0 gemmlike conv
impl_ = conv1x1s1_gemm;
} else {
//! otherwise case
if (kw == 3 && sw == 1 && n > 1 && ks_equal) {
idx_data_.Resize({1, 1, 1, n * kh * kw});
int* idx_out = idx_data_.mutable_data<int>();
for (int i = 0; i < oh; ++i) {
for (int j = 0; j < ow; ++j) {
compute_offset(idx_out, i, j, kh, kw, ih, iw, ph, pw, dh, dw);
idx_out += kh * kw;
}
}
}
//! im2col gemmlike conv
impl_ = conv_im2col_gemm;
this->ctx_->ExtendWorkspace(k * n * sizeof(float));
}
if (n > 1) {
int hblock = get_hblock(this->ctx_->arch());
int m_roundup = hblock * ((m + hblock - 1) / hblock);
int group_size_round_up = ((m_roundup * k + 15) / 16) * 16;
float* w_trans_ptr = nullptr;
weights_trans_.Resize({1, 1, 1, group_size_round_up * param.groups});
w_trans_ptr = weights_trans_.mutable_data<float>();
const auto* w_data = param.filter->data<float>();
for (int g = 0; g < param.groups; ++g) {
const float* weights_group = w_data + g * m * k;
float* weights_trans_ptr = w_trans_ptr + g * group_size_round_up;
prepackA(weights_trans_ptr,
weights_group,
1.f,
k,
0,
m,
0,
k,
false,
this->ctx_);
}
is_weights_transed_ = true;
}
return true;
}
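The prepack step above packs each group's weights into panels of hblock rows, then pads the packed block to a 16-float boundary so every group starts cache-line aligned. A small worked instance of the rounding (hblock is architecture-dependent; 4 is used here purely for illustration):

#include <cstdio>

int main() {
  int hblock = 4;  // illustrative; the real value comes from get_hblock()
  int m = 10, k = 5, groups = 2;
  int m_roundup = hblock * ((m + hblock - 1) / hblock);        // 12
  int group_size_round_up = ((m_roundup * k + 15) / 16) * 16;  // 60 -> 64
  printf("%d floats per group, %d total\n",
         group_size_round_up, group_size_round_up * groups);
  return 0;
}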
template <>
bool GemmLikeConv<PRECISION(kFloat)>::init(const operators::ConvParam& param,
ARMContext* ctx) {
this->ctx_ = ctx;
return create(param, ctx);
}
template <>
bool GemmLikeConv<PRECISION(kFloat)>::run(const operators::ConvParam& param) {
// start timer
const auto* i_data = param.x->data<float>();
const auto* w_data = param.filter->data<float>();
const auto* b_data = param.bias ? param.bias->data<float>() : nullptr;
auto* o_data = param.output->mutable_data<float>();
const int* idx_data = idx_data_.mutable_data<int>();
if (is_weights_transed_) {
w_data = weights_trans_.data<float>();
}
auto x_dims = param.x->dims();
auto w_dims = param.filter->dims();
auto o_dims = param.output->dims();
int iw = x_dims[3]; // nchw
int ih = x_dims[2];
int ic = x_dims[1];
int bs = x_dims[0];
int oh = o_dims[2];
int ow = o_dims[3];
int oc = o_dims[1];
impl_(i_data,
o_data,
bs,
oc,
oh,
ow,
ic,
ih,
iw,
w_data,
b_data,
param,
this->ctx_,
idx_data);
// timer end
return true;
}
/********************* Gemmlike Conv Precision Is Int8 ************************/
template <PrecisionType Ptype_out>
bool GemmLikeConvInt8<Ptype_out>::create(const operators::ConvParam& param,
ARMContext* ctx) {
this->ctx_ = ctx;
auto x_dims = param.x->dims();
auto w_dims = param.filter->dims();
auto o_dims = param.output->dims();
int iw = x_dims[3]; // nchw
int ih = x_dims[2];
int ic = x_dims[1];
int ow = o_dims[3];
int oh = o_dims[2];
int oc = o_dims[1];
int kw = w_dims[3];
int kh = w_dims[2];
int sw = param.strides[1];
int sh = param.strides[0];
int pw = param.paddings[1];
int ph = param.paddings[0];
int dw = param.dilations[1];
int dh = param.dilations[0];
int m = oc / param.groups;
int k = ic * kh * kw / param.groups;
int n = oh * ow;
w_scale_ = param.weight_scale;
//! update weights scale
if (Ptype_out == PRECISION(kInt8) || Ptype_out == PRECISION(kFloat)) {
CHECK_EQ(this->w_scale_.size(), oc) << "weights scale size must be chout";
float input_scale = param.input_scale;
for (auto& w_s : w_scale_) {
w_s *= input_scale;
if (Ptype_out == PRECISION(kInt8)) {
w_s /= param.output_scale;
}
}
}
bool kps_equal = (pw == ph) && (sw == sh) && (kw == kh);
bool ks_equal = (sw == sh) && (kw == kh);
//! select conv gemmlike kernel
if (kw == 1 && sw == 1 && pw == 0 && kps_equal) {
//! 1x1s1p0 gemmlike conv
impl_int8_ = conv1x1s1_gemm_int8;
} else {
//! otherwise case
if (kw == 3 && sw == 1 && n > 1 && ks_equal) {
idx_data_.Resize({1, 1, 1, n * kh * kw});
int* idx_out = idx_data_.mutable_data<int>();
for (int i = 0; i < oh; ++i) {
for (int j = 0; j < ow; ++j) {
compute_offset(idx_out, i, j, kh, kw, ih, iw, ph, pw, dh, dw);
idx_out += kh * kw;
}
}
}
//! im2col gemmlike conv
impl_int8_ = conv_im2col_gemm_int8;
this->ctx_->ExtendWorkspace(k * n);
}
if (n > 1) {
prepackA_int8(&this->weights_trans_,
*param.filter,
m,
k,
param.groups,
false,
this->ctx_);
this->is_weights_transed_ = true;
}
return true;
}
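As in the direct-conv path, create() folds the quantization scales once: each per-output-channel weight scale is multiplied by the input scale, and divided by the output scale as well when the kernel emits int8. A hypothetical helper mirroring that loop, assuming the usual symmetric-quantization convention:

#include <vector>

// Effective per-channel factor: w_scale * in_scale for float output,
// (w_scale * in_scale) / out_scale when the result is requantized to int8.
std::vector<float> fold_scales(std::vector<float> w_scale, float in_scale,
                               float out_scale, bool int8_out) {
  for (auto& s : w_scale) {
    s *= in_scale;
    if (int8_out) s /= out_scale;
  }
  return w_scale;
}

int main() {
  auto s = fold_scales({0.1f, 0.2f}, 0.5f, 0.25f, true);
  return s.size() == 2 ? 0 : 1;  // s == {0.2f, 0.4f}
}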
template <PrecisionType Ptype_out>
bool GemmLikeConvInt8<Ptype_out>::init(const operators::ConvParam& param,
ARMContext* ctx) {
this->ctx_ = ctx;
return create(param, ctx);
}
template <PrecisionType Ptype_out>
bool GemmLikeConvInt8<Ptype_out>::run(const operators::ConvParam& param) {
const auto* i_data = param.x->data<int8_t>();
const auto* w_data = param.filter->data<int8_t>();
const auto* b_data = param.bias ? param.bias->data<int32_t>() : nullptr;
auto* o_data = param.output->mutable_data<int32_t>();
const int32_t* idx_data = idx_data_.mutable_data<int32_t>();
if (this->is_weights_transed_ == true) {
w_data = this->weights_trans_.template data<int8_t>();
}
auto x_dims = param.x->dims();
auto w_dims = param.filter->dims();
auto o_dims = param.output->dims();
int iw = x_dims[3]; // nchw
int ih = x_dims[2];
int ic = x_dims[1];
int bs = x_dims[0];
int oh = o_dims[2];
int ow = o_dims[3];
int oc = o_dims[1];
impl_int8_(i_data,
o_data,
bs,
oc,
oh,
ow,
ic,
ih,
iw,
w_data,
b_data,
param,
this->ctx_,
Ptype_out,
this->w_scale_.data(),
idx_data);
return true;
}
template class GemmLikeConvInt8<PRECISION(kInt8)>;
template class GemmLikeConvInt8<PRECISION(kFloat)>;
template class GemmLikeConvInt8<PRECISION(kInt32)>;
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cmath>
#include <vector>
#include "lite/backends/arm/math/conv_impl.h"
#include "lite/core/context.h"
#include "lite/core/target_wrapper.h"
namespace paddle {
namespace lite {
namespace arm {
namespace math {
template <PrecisionType Ptype>
class GemmLikeConv
: public ImplBase<TARGET(kARM), Ptype, operators::ConvParam> {
public:
typedef void (*conv_im2col_gemm_impl)(const float* din,
float* dout,
int num,
int chout,
int hout,
int wout,
int chin,
int hin,
int win,
const float* weights,
const float* bias,
const operators::ConvParam& param,
ARMContext* ctx,
const int* idx_ptr);
GemmLikeConv() = default;
~GemmLikeConv() {}
virtual bool init(const operators::ConvParam& param, ARMContext* ctx) {
LOG(FATAL) << "GemmLikeConv::init() not implemented.";
}
virtual bool create(const operators::ConvParam& param, ARMContext* ctx) {
LOG(FATAL) << "GemmLikeConv::create() not implemented.";
}
virtual bool run(const operators::ConvParam& param) {
LOG(FATAL) << "GemmLikeConv::run() not implemented.";
}
protected:
bool is_weights_transed_{false};
Tensor idx_data_;
Tensor weights_trans_;
private:
conv_im2col_gemm_impl impl_{nullptr};
};
template <PrecisionType Ptype_out>
class GemmLikeConvInt8 : public GemmLikeConv<PRECISION(kInt8)> {
public:
typedef void (*conv_im2col_gemm_int8_impl)(const int8_t* din,
int32_t* dout,
int num,
int chout,
int hout,
int wout,
int chin,
int hin,
int win,
const int8_t* weights,
const int32_t* bias,
const operators::ConvParam& param,
ARMContext* ctx,
PrecisionType out_type,
const float* scale,
const int* idx_ptr);
GemmLikeConvInt8() = default;
~GemmLikeConvInt8() {}
virtual bool init(const operators::ConvParam& param, ARMContext* ctx);
virtual bool create(const operators::ConvParam& param, ARMContext* ctx);
virtual bool run(const operators::ConvParam& param);
private:
conv_im2col_gemm_int8_impl impl_int8_{nullptr};
std::vector<float> w_scale_;
};
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
...@@ -12,14 +12,9 @@ ...@@ -12,14 +12,9 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// #include "saber/funcs/impl/arm/neon/impl/conv_arm_depthwise.h"
// #include "saber/funcs/impl/arm/neon/impl/conv_arm_impl.h"
// #include "saber/funcs/impl/arm/neon/impl/gemm_prepacked_int8.h"
// #include "saber/funcs/impl/arm/neon/impl/gemv_arm_int8.h"
// #include "saber/funcs/impl/arm/neon/impl/sgemv_arm.h"
#include "lite/backends/arm/math/conv_impl.h" #include "lite/backends/arm/math/conv_impl.h"
#include <arm_neon.h> #include <arm_neon.h>
#include "lite/backends/arm/math/conv_depthwise.h"
#include "lite/backends/arm/math/gemm_prepacked_int8.h" #include "lite/backends/arm/math/gemm_prepacked_int8.h"
#include "lite/backends/arm/math/gemv_arm_int8.h" #include "lite/backends/arm/math/gemv_arm_int8.h"
#include "lite/backends/arm/math/packed_sgemm.h" #include "lite/backends/arm/math/packed_sgemm.h"
...@@ -107,17 +102,17 @@ inline bool is_a_ge_zero_and_a_lt_b(int a, int b) { ...@@ -107,17 +102,17 @@ inline bool is_a_ge_zero_and_a_lt_b(int a, int b) {
*/ */
template <typename Dtype> template <typename Dtype>
void im2col(const Dtype* data_im, void im2col(const Dtype* data_im,
const int channels, int channels,
const int height, int height,
const int width, int width,
const int kernel_h, int kernel_h,
const int kernel_w, int kernel_w,
const int pad_h, int pad_h,
const int pad_w, int pad_w,
const int stride_h, int stride_h,
const int stride_w, int stride_w,
const int dilation_h, int dilation_h,
const int dilation_w, int dilation_w,
Dtype* data_col) { Dtype* data_col) {
const int output_h = const int output_h =
(height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
...@@ -150,121 +145,6 @@ void im2col(const Dtype* data_im, ...@@ -150,121 +145,6 @@ void im2col(const Dtype* data_im,
} }
} }
} }
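Both im2col variants size their output with the standard convolution shape rule, out = (in + 2*pad - (dilation*(k-1) + 1)) / stride + 1. A worked check of that arithmetic:

#include <cstdio>

int main() {
  int height = 224, pad_h = 1, kernel_h = 3, dilation_h = 1, stride_h = 2;
  int output_h =
      (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
  printf("%d\n", output_h);  // (224 + 2 - 3) / 2 + 1 = 112
  return 0;
}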
void compute_offset(int* idx_out,
int h,
int w,
int kernel_h,
int kernel_w,
int height,
int width,
int pad_h,
int pad_w,
int dilation_h,
int dilation_w) {
int idx_h[kernel_h]; // NOLINT
int idx_w[kernel_w]; // NOLINT
for (int i = 0; i < kernel_h; ++i) {
idx_h[i] = h - pad_h + i * dilation_h;
}
for (int i = 0; i < kernel_w; ++i) {
idx_w[i] = w - pad_w + i * dilation_w;
}
for (int k_h = 0; k_h < kernel_h; ++k_h) {
for (int k_w = 0; k_w < kernel_w; ++k_w) {
idx_out[k_h * kernel_w + k_w] =
(idx_h[k_h] >= 0 && idx_w[k_w] >= 0 && idx_h[k_h] < height &&
idx_w[k_w] < width)
? idx_h[k_h] * width + idx_w[k_w]
: -1;
}
}
}
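compute_offset builds, for one output pixel, the flat input offsets of every kernel tap, writing -1 for taps that land in the padding; the im2col3x3 fast path below then branches on those -1 entries. A worked trace for output pixel (0, 0) of a 3x3 window over a 5x5 image with pad 1 (assuming the definition above is in scope):

#include <cstdio>

int main() {
  int idx[9];
  compute_offset(idx, 0, 0, 3, 3, 5, 5, 1, 1, 1, 1);
  for (int i = 0; i < 9; ++i) printf("%d ", idx[i]);
  // prints: -1 -1 -1 -1 0 1 -1 5 6   (-1 marks padded taps)
  return 0;
}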
template <typename Dtype>
void im2col3x3(const Dtype* data_im,
const int channels,
const int height,
const int width,
const int kernel_h,
const int kernel_w,
const int pad_h,
const int pad_w,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
Dtype* data_col,
const int* idx) {
const int output_h =
(height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
const int output_w =
(width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
int kernel_stride = kernel_h * kernel_w;
int in_channel_stride = height * width;
const int* idx_out = idx;
Dtype* data_col_ptr = data_col;
bool flag_continue = false;
if (dilation_h == 1 && dilation_w == 1) {
flag_continue = true;
}
for (int o = 0; o < output_h * output_w; o += 1) {
const Dtype* data_im_ptr = data_im;
// int* idx_out_d = idx_out;
int idx_out_d0 = idx_out[0];
int idx_out_d1 = idx_out[1];
int idx_out_d2 = idx_out[2];
int idx_out_d3 = idx_out[3];
int idx_out_d4 = idx_out[4];
int idx_out_d5 = idx_out[5];
int idx_out_d6 = idx_out[6];
int idx_out_d7 = idx_out[7];
int idx_out_d8 = idx_out[8];
for (int i = 0; i < channels; i += 1) {
if (idx_out_d0 >= 0 && idx_out_d2 >= 0 && idx_out_d6 >= 0 &&
idx_out_d8 >= 0) {
if (flag_continue) {
memcpy(
data_col_ptr, data_im_ptr + idx_out_d0, kernel_w * sizeof(Dtype));
memcpy(data_col_ptr + kernel_w,
data_im_ptr + idx_out_d3,
kernel_w * sizeof(Dtype));
memcpy(data_col_ptr + kernel_w + kernel_w,
data_im_ptr + idx_out_d6,
kernel_w * sizeof(Dtype));
} else {
data_col_ptr[0] = data_im_ptr[idx_out_d0];
data_col_ptr[1] = data_im_ptr[idx_out_d1];
data_col_ptr[2] = data_im_ptr[idx_out_d2];
data_col_ptr[3] = data_im_ptr[idx_out_d3];
data_col_ptr[4] = data_im_ptr[idx_out_d4];
data_col_ptr[5] = data_im_ptr[idx_out_d5];
data_col_ptr[6] = data_im_ptr[idx_out_d6];
data_col_ptr[7] = data_im_ptr[idx_out_d7];
data_col_ptr[8] = data_im_ptr[idx_out_d8];
}
} else {
data_col_ptr[0] = (idx_out_d0 < 0) ? 0 : data_im_ptr[idx_out_d0];
data_col_ptr[1] = (idx_out_d1 < 0) ? 0 : data_im_ptr[idx_out_d1];
data_col_ptr[2] = (idx_out_d2 < 0) ? 0 : data_im_ptr[idx_out_d2];
data_col_ptr[3] = (idx_out_d3 < 0) ? 0 : data_im_ptr[idx_out_d3];
data_col_ptr[4] = (idx_out_d4 < 0) ? 0 : data_im_ptr[idx_out_d4];
data_col_ptr[5] = (idx_out_d5 < 0) ? 0 : data_im_ptr[idx_out_d5];
data_col_ptr[6] = (idx_out_d6 < 0) ? 0 : data_im_ptr[idx_out_d6];
data_col_ptr[7] = (idx_out_d7 < 0) ? 0 : data_im_ptr[idx_out_d7];
data_col_ptr[8] = (idx_out_d8 < 0) ? 0 : data_im_ptr[idx_out_d8];
}
data_im_ptr += height * width;
data_col_ptr += kernel_stride;
}
// data_col_ptr += channels * kernel_stride;
// idx_out += kernel_stride * 2;
idx_out += kernel_stride;
}
}
/** /**
* \brief convolution function for kernel size 1x1, stride size 1, gemm * \brief convolution function for kernel size 1x1, stride size 1, gemm
...@@ -282,8 +162,7 @@ void conv1x1s1_gemm(const float* i_data, ...@@ -282,8 +162,7 @@ void conv1x1s1_gemm(const float* i_data,
const float* weights, const float* weights,
const float* bias, const float* bias,
const operators::ConvParam& param, const operators::ConvParam& param,
ARMContext* ctx, ARMContext* ctx) {
const int* idx_ptr) {
int channel_size_out = ow * oh; int channel_size_out = ow * oh;
int channel_size_in = win * ih; int channel_size_in = win * ih;
...@@ -294,21 +173,14 @@ void conv1x1s1_gemm(const float* i_data, ...@@ -294,21 +173,14 @@ void conv1x1s1_gemm(const float* i_data,
bool flag_relu = param.fuse_relu; bool flag_relu = param.fuse_relu;
bool flag_bias = param.bias != nullptr; bool flag_bias = param.bias != nullptr;
// if (param.activation_param.has_active) {
// if (param.activation_param.active == Active_relu && int hblock = get_hblock(ctx);
// fabs(param.activation_param.negative_slope) < 1e-6f) {
// flag_relu = true;
// }
// }
int hblock = get_hblock(ctx->arch());
int m_roundup = hblock * ((m + hblock - 1) / hblock); int m_roundup = hblock * ((m + hblock - 1) / hblock);
int weights_size_per_group = m * k; int weights_size_per_group = m * k;
if (n > 1) { if (n > 1) {
weights_size_per_group = ((m_roundup * k + 15) / 16) * 16; weights_size_per_group = ((m_roundup * k + 15) / 16) * 16;
} }
// int weights_size_per_group = m_roundup * k;//oc * ic / (group *
// group);
//! use gemv when the output channel size = 1 //! use gemv when the output channel size = 1
for (int b = 0; b < num; ++b) { for (int b = 0; b < num; ++b) {
// dC // dC
...@@ -351,8 +223,9 @@ void conv1x1s1_gemm(const float* i_data, ...@@ -351,8 +223,9 @@ void conv1x1s1_gemm(const float* i_data,
} }
} }
template <typename Dtype>
void conv1x1s1_gemm_int8(const int8_t* i_data, void conv1x1s1_gemm_int8(const int8_t* i_data,
int32_t* o_data, Dtype* o_data,
int num, int num,
int oc, int oc,
int oh, int oh,
...@@ -361,12 +234,10 @@ void conv1x1s1_gemm_int8(const int8_t* i_data, ...@@ -361,12 +234,10 @@ void conv1x1s1_gemm_int8(const int8_t* i_data,
int ih, int ih,
int win, int win,
const int8_t* weights, const int8_t* weights,
const int32_t* bias, const float* bias,
const operators::ConvParam& param, const operators::ConvParam& param,
ARMContext* ctx, ARMContext* ctx,
PrecisionType out_type, const float* scale) {
const float* scale,
const int32_t* idx_ptr) {
int group = param.groups; int group = param.groups;
int channel_size_out = ow * oh; int channel_size_out = ow * oh;
int channel_size_in = win * ih; int channel_size_in = win * ih;
...@@ -386,94 +257,71 @@ void conv1x1s1_gemm_int8(const int8_t* i_data, ...@@ -386,94 +257,71 @@ void conv1x1s1_gemm_int8(const int8_t* i_data,
for (int b = 0; b < num; ++b) { for (int b = 0; b < num; ++b) {
// dC // dC
for (int g = 0; g < group; ++g) { for (int g = 0; g < group; ++g) {
signed char* dout_group = Dtype* dout_group = o_data + (b * oc + g * m) * channel_size_out;
reinterpret_cast<signed char*>(o_data) +
(b * oc + g * m) * channel_size_out * PrecisionTypeLength(out_type);
const int8_t* din_group = i_data + (b * ic + g * k) * channel_size_in; const int8_t* din_group = i_data + (b * ic + g * k) * channel_size_in;
const int8_t* weights_group = weights + g * weights_size_per_group; const int8_t* weights_group = weights + g * weights_size_per_group;
const int* bias_group = bias + g * m; const float* bias_group = bias + g * m;
const float* scale_group = scale + g * m; const float* scale_group = scale + g * m;
if (n == 1) { if (n == 1) {
if (out_type == PRECISION(kFloat)) { gemv_int8(weights_group,
gemv_int8(weights_group, din_group,
din_group, dout_group,
reinterpret_cast<float*>(dout_group), false,
false, m,
m, k,
k, scale_group,
scale_group, flag_bias,
flag_bias, bias_group,
bias_group, flag_relu,
flag_relu); ctx);
} else if (out_type == PRECISION(kInt8)) { // int8
gemv_int8(weights_group,
din_group,
dout_group,
false,
m,
k,
scale_group,
flag_bias,
bias_group,
flag_relu);
} else {
gemv_int8(weights_group,
din_group,
reinterpret_cast<int*>(dout_group),
false,
m,
k,
scale_group,
flag_bias,
bias_group,
flag_relu);
}
} else { } else {
if (out_type == PRECISION(kFloat)) { gemm_prepack_int8(weights_group,
gemm_prepack_int8(weights_group, din_group,
din_group, bias_group,
bias_group, dout_group,
reinterpret_cast<float*>(dout_group), m,
m, n,
n, k,
k, flag_bias,
flag_bias, flag_relu,
flag_relu, false,
false, scale_group,
scale_group, ctx);
ctx);
} else if (out_type == PRECISION(kInt8)) { // int8
gemm_prepack_int8(weights_group,
din_group,
bias_group,
dout_group,
m,
n,
k,
flag_bias,
flag_relu,
false,
scale_group,
ctx);
} else {
gemm_prepack_int8(weights_group,
din_group,
bias_group,
reinterpret_cast<int*>(dout_group),
m,
n,
k,
flag_bias,
flag_relu,
false,
scale_group,
ctx);
}
} }
} }
} }
} }
template void conv1x1s1_gemm_int8<int8_t>(const int8_t* i_data,
int8_t* o_data,
int num,
int oc,
int oh,
int ow,
int ic,
int ih,
int win,
const int8_t* weights,
const float* bias,
const operators::ConvParam& param,
ARMContext* ctx,
const float* scale);
template void conv1x1s1_gemm_int8<float>(const int8_t* i_data,
float* o_data,
int num,
int oc,
int oh,
int ow,
int ic,
int ih,
int win,
const int8_t* weights,
const float* bias,
const operators::ConvParam& param,
ARMContext* ctx,
const float* scale);
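Because the conv1x1s1_gemm_int8 template is defined in this .cc file rather than the header, the two output types that callers link against are instantiated explicitly above. A minimal standalone illustration of the pattern, with hypothetical names:

#include <cstdint>
#include <cstdio>

// The header would only declare: template <typename T> void scale_buf(T*, int, float);
template <typename T>
void scale_buf(T* p, int n, float s) {  // definition stays in the .cc
  for (int i = 0; i < n; ++i) p[i] = static_cast<T>(p[i] * s);
}
// ...so the .cc names every T that other translation units may use:
template void scale_buf<float>(float*, int, float);
template void scale_buf<int8_t>(int8_t*, int, float);

int main() {
  float v[2] = {1.f, 2.f};
  scale_buf(v, 2, 0.5f);
  printf("%g %g\n", v[0], v[1]);  // 0.5 1
  return 0;
}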
/** /**
* \brief convolution function for kernel size 3x3, stride size 2, gemm * \brief convolution function for kernel size 3x3, stride size 2, gemm
* implementation * implementation
...@@ -490,8 +338,7 @@ void conv_im2col_gemm(const float* i_data, ...@@ -490,8 +338,7 @@ void conv_im2col_gemm(const float* i_data,
const float* weights, const float* weights,
const float* bias, const float* bias,
const operators::ConvParam& param, const operators::ConvParam& param,
ARMContext* ctx, ARMContext* ctx) {
const int* idx_ptr) {
const int group = param.groups; const int group = param.groups;
auto filter_dims = param.filter->dims(); auto filter_dims = param.filter->dims();
const int kernel_h = filter_dims[2]; const int kernel_h = filter_dims[2];
...@@ -504,22 +351,13 @@ void conv_im2col_gemm(const float* i_data, ...@@ -504,22 +351,13 @@ void conv_im2col_gemm(const float* i_data,
int channel_size_in = win * ih; int channel_size_in = win * ih;
bool flag_relu = param.fuse_relu; bool flag_relu = param.fuse_relu;
bool flag_bias = param.bias != nullptr; bool flag_bias = param.bias != nullptr;
// if (param.activation_param.has_active) { int hblock = get_hblock(ctx);
// if (param.activation_param.active == Active_relu &&
// fabs(param.activation_param.negative_slope) < 1e-6f) {
// flag_relu = true;
// }
// }
int hblock = get_hblock(ctx->arch());
int m_roundup = hblock * ((m + hblock - 1) / hblock); int m_roundup = hblock * ((m + hblock - 1) / hblock);
int weights_size_per_group = m * k; int weights_size_per_group = m * k;
if (n > 1) { if (n > 1) {
weights_size_per_group = ((m_roundup * k + 15) / 16) * 16; weights_size_per_group = ((m_roundup * k + 15) / 16) * 16;
} }
bool flag_im2col2 = (kernel_h == 3 && kernel_w == 3 &&
param.strides[0] == 1 && param.strides[1] == 1 && n > 1);
float* tmp_work_space = float* tmp_work_space =
ctx->workspace_data<float>() + ctx->llc_size() / sizeof(float); ctx->workspace_data<float>() + ctx->llc_size() / sizeof(float);
...@@ -534,36 +372,20 @@ void conv_im2col_gemm(const float* i_data, ...@@ -534,36 +372,20 @@ void conv_im2col_gemm(const float* i_data,
const float* bias_group = bias + g * m; const float* bias_group = bias + g * m;
float* dB = tmp_work_space; float* dB = tmp_work_space;
if (flag_im2col2) { im2col(din_group,
im2col3x3(din_group, chin_per_group,
chin_per_group, ih,
ih, win,
win, kernel_h,
kernel_h, kernel_w,
kernel_w, param.paddings[0],
param.paddings[0], param.paddings[1],
param.paddings[1], param.strides[0],
param.strides[0], param.strides[1],
param.strides[1], param.dilations[0],
param.dilations[0], param.dilations[1],
param.dilations[1], dB);
dB,
idx_ptr);
} else {
im2col(din_group,
chin_per_group,
ih,
win,
kernel_h,
kernel_w,
param.paddings[0],
param.paddings[1],
param.strides[0],
param.strides[1],
param.dilations[0],
param.dilations[1],
dB);
}
if (n == 1) { if (n == 1) {
sgemv(weights_group, sgemv(weights_group,
dB, dB,
...@@ -576,10 +398,7 @@ void conv_im2col_gemm(const float* i_data, ...@@ -576,10 +398,7 @@ void conv_im2col_gemm(const float* i_data,
flag_relu); flag_relu);
} else { } else {
int ldb = n; int ldb = n;
if (flag_im2col2) { sgemm_prepack(false,
ldb = k;
}
sgemm_prepack(flag_im2col2,
m, m,
n, n,
k, k,
...@@ -598,8 +417,9 @@ void conv_im2col_gemm(const float* i_data, ...@@ -598,8 +417,9 @@ void conv_im2col_gemm(const float* i_data,
} }
} }
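conv_im2col_gemm lowers each group's convolution to one GEMM: M = oc/groups weight rows, K = ic/groups * kh * kw unrolled patch elements, and N = oh * ow output pixels, with the im2col buffer of k * n floats drawn from the context workspace. A worked shape example under illustrative dimensions:

#include <cstdio>

int main() {
  int oc = 64, ic = 32, groups = 1, kh = 3, kw = 3, oh = 56, ow = 56;
  int M = oc / groups;            // 64 rows of the packed weight matrix
  int K = ic / groups * kh * kw;  // 288 unrolled input patch length
  int N = oh * ow;                // 3136 output pixels
  printf("GEMM %dx%d * %dx%d, im2col buffer %ld bytes\n",
         M, K, K, N, (long)K * N * (long)sizeof(float));  // 3612672 bytes
  return 0;
}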
template <typename Dtype>
void conv_im2col_gemm_int8(const int8_t* i_data, void conv_im2col_gemm_int8(const int8_t* i_data,
int32_t* o_data, Dtype* o_data,
int num, int num,
int oc, int oc,
int oh, int oh,
...@@ -608,12 +428,10 @@ void conv_im2col_gemm_int8(const int8_t* i_data, ...@@ -608,12 +428,10 @@ void conv_im2col_gemm_int8(const int8_t* i_data,
int ih, int ih,
int win, int win,
const int8_t* weights, const int8_t* weights,
const int32_t* bias, const float* bias,
const operators::ConvParam& param, const operators::ConvParam& param,
ARMContext* ctx, ARMContext* ctx,
PrecisionType out_type, const float* scale) {
const float* scale,
const int32_t* idx_ptr) {
int group = param.groups; int group = param.groups;
auto filter_dims = param.filter->dims(); auto filter_dims = param.filter->dims();
int kernel_h = filter_dims[2]; int kernel_h = filter_dims[2];
...@@ -641,9 +459,6 @@ void conv_im2col_gemm_int8(const int8_t* i_data, ...@@ -641,9 +459,6 @@ void conv_im2col_gemm_int8(const int8_t* i_data,
weights_size_per_group = ((m_roundup * k_roundup + 15) / 16) * 16; weights_size_per_group = ((m_roundup * k_roundup + 15) / 16) * 16;
} }
bool flag_im2col2 = (kernel_h == 3 && kernel_w == 3 && stride_h == 1 &&
stride_w == 1 && n > 1);
int8_t* tmp_work_space = int8_t* tmp_work_space =
ctx->workspace_data<int8_t>() + ctx->llc_size() / sizeof(int8_t); ctx->workspace_data<int8_t>() + ctx->llc_size() / sizeof(int8_t);
...@@ -651,249 +466,345 @@ void conv_im2col_gemm_int8(const int8_t* i_data, ...@@ -651,249 +466,345 @@ void conv_im2col_gemm_int8(const int8_t* i_data,
for (int b = 0; b < num; ++b) { for (int b = 0; b < num; ++b) {
// dC // dC
for (int g = 0; g < group; ++g) { for (int g = 0; g < group; ++g) {
signed char* dout_group = Dtype* dout_group = o_data + (b * oc + g * m) * channel_size_out;
reinterpret_cast<signed char*>(o_data) +
(b * oc + g * m) * channel_size_out * PrecisionTypeLength(out_type);
const int8_t* din_group = static_cast<const int8_t*>(i_data) + const int8_t* din_group = static_cast<const int8_t*>(i_data) +
(b * ic + g * chin_per_group) * channel_size_in; (b * ic + g * chin_per_group) * channel_size_in;
const int8_t* weights_group = const int8_t* weights_group =
static_cast<const int8_t*>(weights) + g * weights_size_per_group; static_cast<const int8_t*>(weights) + g * weights_size_per_group;
const int* bias_group = static_cast<const int*>(bias) + g * m; const float* bias_group = bias + g * m;
int8_t* dB = tmp_work_space; int8_t* dB = tmp_work_space;
const float* scale_group = scale + g * m; const float* scale_group = scale + g * m;
if (flag_im2col2) { im2col(din_group,
im2col3x3(din_group, chin_per_group,
chin_per_group, ih,
ih, win,
win, kernel_h,
kernel_h, kernel_w,
kernel_w, pad_h,
pad_h, pad_w,
pad_w, stride_h,
stride_h, stride_w,
stride_w, dila_h,
dila_h, dila_w,
dila_w, dB);
dB,
idx_ptr);
} else {
im2col(din_group,
chin_per_group,
ih,
win,
kernel_h,
kernel_w,
pad_h,
pad_w,
stride_h,
stride_w,
dila_h,
dila_w,
dB);
}
if (n == 1) { if (n == 1) {
if (out_type == PRECISION(kFloat)) { gemv_int8(weights_group,
gemv_int8(weights_group, dB,
dB, dout_group,
reinterpret_cast<float*>(dout_group), false,
false, m,
m, k,
k, scale_group,
scale_group, flag_bias,
flag_bias, bias_group,
bias_group, flag_relu,
flag_relu); ctx);
} else if (out_type == PRECISION(kInt8)) { // int8
gemv_int8(weights_group,
dB,
dout_group,
false,
m,
k,
scale_group,
flag_bias,
bias_group,
flag_relu);
} else {
gemv_int8(weights_group,
dB,
reinterpret_cast<int*>(dout_group),
false,
m,
k,
scale_group,
flag_bias,
bias_group,
flag_relu);
}
} else { } else {
if (out_type == PRECISION(kFloat)) { gemm_prepack_int8(weights_group,
gemm_prepack_int8(weights_group, dB,
dB, bias_group,
bias_group, dout_group,
reinterpret_cast<float*>(dout_group), m,
m, n,
n, k,
k, flag_bias,
flag_bias, flag_relu,
flag_relu, false,
flag_im2col2, scale_group,
scale_group, ctx);
ctx);
} else if (out_type == PRECISION(kInt8)) { // int8
gemm_prepack_int8(weights_group,
dB,
bias_group,
dout_group,
m,
n,
k,
flag_bias,
flag_relu,
flag_im2col2,
scale_group,
ctx);
} else {
gemm_prepack_int8(weights_group,
dB,
bias_group,
reinterpret_cast<int*>(dout_group),
m,
n,
k,
flag_bias,
flag_relu,
flag_im2col2,
scale_group,
ctx);
}
} }
} }
} }
} }
void conv_depthwise_3x3(const float* i_data, template void conv_im2col_gemm_int8<int8_t>(const int8_t* i_data,
float* o_data, int8_t* o_data,
int num, int num,
int oc, int oc,
int oh, int oh,
int ow, int ow,
int ic, int ic,
int ih, int ih,
int win, int win,
const float* weights, const int8_t* weights,
const float* bias, const float* bias,
const operators::ConvParam& param, const operators::ConvParam& param,
ARMContext* ctx) { ARMContext* ctx,
int pad = param.paddings[1]; const float* scale);
template void conv_im2col_gemm_int8<float>(const int8_t* i_data,
float* o_data,
int num,
int oc,
int oh,
int ow,
int ic,
int ih,
int win,
const int8_t* weights,
const float* bias,
const operators::ConvParam& param,
ARMContext* ctx,
const float* scale);
void conv_depthwise_3x3_fp32(const void* din,
void* dout,
int num,
int ch_out,
int h_out,
int w_out,
int ch_in,
int h_in,
int w_in,
const void* weights,
const float* bias,
const operators::ConvParam& param,
ARMContext* ctx,
const float* scale) {
int stride = param.strides[1]; int stride = param.strides[1];
bool flag_relu = param.fuse_relu; if (stride == 1) {
bool flag_bias = param.bias != nullptr; conv_3x3s1_depthwise_fp32(reinterpret_cast<const float*>(din),
// if (param.activation_param.has_active) { reinterpret_cast<float*>(dout),
// if (param.activation_param.active == Active_relu && num,
// fabs(param.activation_param.negative_slope) < 1e-6f) { ch_out,
// flag_relu = true; h_out,
// } w_out,
// } ch_in,
h_in,
w_in,
reinterpret_cast<const float*>(weights),
bias,
param,
ctx);
} else if (stride == 2) {
conv_3x3s2_depthwise_fp32(reinterpret_cast<const float*>(din),
reinterpret_cast<float*>(dout),
num,
ch_out,
h_out,
w_out,
ch_in,
h_in,
w_in,
reinterpret_cast<const float*>(weights),
bias,
param,
ctx);
} else {
LOG(FATAL) << "fp32 depthwise conv3x3 stride: " << stride << " unsupported";
}
#if 0
if (pad == 1) { if (pad == 1) {
conv_depthwise_3x3p1(i_data, conv_depthwise_3x3p1_fp32(reinterpret_cast<const float*>(din),
o_data, reinterpret_cast<float*>(dout),
num, num,
oc, ch_out,
oh, h_out,
ow, w_out,
ic, ch_in,
ih, h_in,
win, w_in,
weights, reinterpret_cast<const float*>(weights),
bias, bias,
stride, stride,
flag_bias, flag_bias,
flag_relu, flag_relu,
ctx); ctx);
} else if (pad == 0 && ih > 2) { } else if (pad == 0 && h_in > 2) {
conv_depthwise_3x3p0(i_data, conv_depthwise_3x3p0_fp32(reinterpret_cast<const float*>(din),
o_data, reinterpret_cast<float*>(dout),
num, num,
oc, ch_out,
oh, h_out,
ow, w_out,
ic, ch_in,
ih, h_in,
win, w_in,
weights, reinterpret_cast<const float*>(weights),
bias, bias,
stride, stride,
flag_bias, flag_bias,
flag_relu, flag_relu,
ctx); ctx);
} else { } else {
LOG(FATAL) << "unsupported 3x3 dw conv type"; LOG(FATAL) << "unsupported 3x3 dw conv type";
} }
#endif
} }
void conv_depthwise_5x5(const float* i_data, void conv_depthwise_5x5_fp32(const void* din,
float* o_data, void* dout,
int num, int num,
int oc, int ch_out,
int oh, int h_out,
int ow, int w_out,
int ic, int ch_in,
int ih, int h_in,
int win, int w_in,
const float* weights, const void* weights,
const float* bias, const float* bias,
const operators::ConvParam& param, const operators::ConvParam& param,
ARMContext* ctx) { ARMContext* ctx,
const float* scale) {
int pad = param.paddings[1]; int pad = param.paddings[1];
int stride = param.strides[1]; int stride = param.strides[1];
bool flag_relu = param.fuse_relu; bool flag_relu = param.fuse_relu;
bool flag_bias = param.bias != nullptr; bool flag_bias = param.bias != nullptr;
// if (param.activation_param.has_active && ctx->ExtendWorkspace((w_in + w_out) * sizeof(float));
// fabs(param.activation_param.negative_slope) < 1e-6f) {
// if (param.activation_param.active == Active_relu) {
// flag_relu = true;
// }
// }
if (pad == 2 && stride == 2) { if (pad == 2 && stride == 2) {
conv_depthwise_5x5s2(i_data, conv_depthwise_5x5s2_fp32(reinterpret_cast<const float*>(din),
o_data, reinterpret_cast<float*>(dout),
num, num,
oc, ch_out,
oh, h_out,
ow, w_out,
ic, ch_in,
ih, h_in,
win, w_in,
weights, reinterpret_cast<const float*>(weights),
bias, bias,
pad, pad,
flag_bias, flag_bias,
flag_relu, flag_relu,
ctx); ctx);
} else if (stride == 1) { } else if (stride == 1) {
conv_depthwise_5x5s1(i_data, conv_depthwise_5x5s1_fp32(reinterpret_cast<const float*>(din),
o_data, reinterpret_cast<float*>(dout),
num, num,
oc, ch_out,
oh, h_out,
ow, w_out,
ic, ch_in,
ih, h_in,
win, w_in,
weights, reinterpret_cast<const float*>(weights),
bias, bias,
pad, pad,
flag_bias, flag_bias,
flag_relu, flag_relu,
ctx); ctx);
} else { } else {
LOG(FATAL) << "unsupported 5x5 dw conv type"; LOG(FATAL) << "unsupported 5x5 dw conv type";
} }
} }
void conv_depthwise_3x3_int8_fp32(const void* din,
void* dout,
int num,
int ch_out,
int h_out,
int w_out,
int ch_in,
int h_in,
int w_in,
const void* weights,
const float* bias,
const operators::ConvParam& param,
ARMContext* ctx,
const float* scale) {
int pad_h = param.paddings[0];
int pad_w = param.paddings[1];
int stride = param.strides[1];
bool flag_relu = param.fuse_relu;
bool flag_bias = param.bias != nullptr;
if (stride == 1) {
conv_depthwise_3x3s1_int8(reinterpret_cast<float*>(dout),
reinterpret_cast<const int8_t*>(din),
reinterpret_cast<const int8_t*>(weights),
scale,
bias,
flag_bias,
flag_relu,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
pad_w,
pad_h,
ctx);
} else if (stride == 2) {
conv_depthwise_3x3s2_int8(reinterpret_cast<float*>(dout),
reinterpret_cast<const int8_t*>(din),
reinterpret_cast<const int8_t*>(weights),
scale,
bias,
flag_bias,
flag_relu,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
pad_w,
pad_h,
ctx);
} else {
LOG(FATAL) << "unsupported 3x3 dw conv int8 type";
}
}
void conv_depthwise_3x3_int8_int8(const void* din,
void* dout,
int num,
int ch_out,
int h_out,
int w_out,
int ch_in,
int h_in,
int w_in,
const void* weights,
const float* bias,
const operators::ConvParam& param,
ARMContext* ctx,
const float* scale) {
int pad_h = param.paddings[0];
int pad_w = param.paddings[1];
int stride = param.strides[1];
bool flag_relu = param.fuse_relu;
bool flag_bias = param.bias != nullptr;
if (stride == 1) {
conv_depthwise_3x3s1_int8(reinterpret_cast<int8_t*>(dout),
reinterpret_cast<const int8_t*>(din),
reinterpret_cast<const int8_t*>(weights),
scale,
bias,
flag_bias,
flag_relu,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
pad_w,
pad_h,
ctx);
} else if (stride == 2) {
conv_depthwise_3x3s2_int8(reinterpret_cast<int8_t*>(dout),
reinterpret_cast<const int8_t*>(din),
reinterpret_cast<const int8_t*>(weights),
scale,
bias,
flag_bias,
flag_relu,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
pad_w,
pad_h,
ctx);
} else {
LOG(FATAL) << "unsupported 3x3 dw conv int8 type";
}
}
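The two wrappers above differ only in the pointer type dout is cast to before forwarding to the templated conv_depthwise_3x3s1_int8 / conv_depthwise_3x3s2_int8 kernels. A consolidation sketch (hypothetical, not part of this patch) that expresses both through one template over the output type, reusing the declarations from the header:

template <typename Dtype>
void conv_depthwise_3x3_int8(const void* din, void* dout, int num, int ch_out,
                             int h_out, int w_out, int ch_in, int h_in,
                             int w_in, const void* weights, const float* bias,
                             const operators::ConvParam& param, ARMContext* ctx,
                             const float* scale) {
  int pad_h = param.paddings[0];
  int pad_w = param.paddings[1];
  int stride = param.strides[1];
  bool flag_relu = param.fuse_relu;
  bool flag_bias = param.bias != nullptr;
  auto* out = reinterpret_cast<Dtype*>(dout);
  const auto* in = reinterpret_cast<const int8_t*>(din);
  const auto* wei = reinterpret_cast<const int8_t*>(weights);
  if (stride == 1) {
    conv_depthwise_3x3s1_int8(out, in, wei, scale, bias, flag_bias, flag_relu,
                              num, ch_in, h_in, w_in, h_out, w_out, pad_w,
                              pad_h, ctx);
  } else if (stride == 2) {
    conv_depthwise_3x3s2_int8(out, in, wei, scale, bias, flag_bias, flag_relu,
                              num, ch_in, h_in, w_in, h_out, w_out, pad_w,
                              pad_h, ctx);
  } else {
    LOG(FATAL) << "unsupported 3x3 dw conv int8 stride: " << stride;
  }
}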
} // namespace math } // namespace math
} // namespace arm } // namespace arm
} // namespace lite } // namespace lite
......
...@@ -23,26 +23,6 @@ namespace lite { ...@@ -23,26 +23,6 @@ namespace lite {
namespace arm { namespace arm {
namespace math { namespace math {
// TODO(TJ): move to somewhere else common
template <TargetType TType, PrecisionType PType, typename Param>
class ImplBase {
public:
ImplBase() {}
virtual ~ImplBase() {}
virtual bool create(const Param& param, Context<TType>* ctx) { return false; }
virtual bool init(const Param& param, Context<TType>* ctx) { return false; }
virtual bool run(const Param& param) { return false; }
// void set_op_name(const char* name){_op_name = name;}
// const char* get_op_name() { return _op_name.c_str();}
protected:
Param* param_;
Context<TType>* ctx_;
};
void conv_3x3s1_direct_fp32(const float* din, void conv_3x3s1_direct_fp32(const float* din,
float* dout, float* dout,
int num, int num,
...@@ -55,26 +35,11 @@ void conv_3x3s1_direct_fp32(const float* din, ...@@ -55,26 +35,11 @@ void conv_3x3s1_direct_fp32(const float* din,
const float* weights, const float* weights,
const float* bias, const float* bias,
const operators::ConvParam& param, const operators::ConvParam& param,
Context<TARGET(kARM)>* ctx); ARMContext* ctx);
template <typename Dtype>
void conv_3x3s1_direct_int8(const int8_t* din, void conv_3x3s1_direct_int8(const int8_t* din,
int32_t* dout, Dtype* dout,
int num,
int chout,
int hout,
int wout,
int chin,
int hin,
int win,
const int8_t* weights,
const int32_t* bias,
const operators::ConvParam& param,
Context<TARGET(kARM)>* ctx,
PrecisionType out_type,
const float* scale);
void conv_3x3s1_direct_int7(const int8_t* din,
int32_t* dout,
int num, int num,
int chout, int chout,
int hout, int hout,
...@@ -83,10 +48,9 @@ void conv_3x3s1_direct_int7(const int8_t* din, ...@@ -83,10 +48,9 @@ void conv_3x3s1_direct_int7(const int8_t* din,
int hin, int hin,
int win, int win,
const int8_t* weights, const int8_t* weights,
const int32_t* bias, const float* bias,
const operators::ConvParam& param, const operators::ConvParam& param,
Context<TARGET(kARM)>* ctx, ARMContext* ctx,
PrecisionType out_type,
const float* scale); const float* scale);
void conv_3x3s2_direct_fp32(const float* din, void conv_3x3s2_direct_fp32(const float* din,
...@@ -101,12 +65,13 @@ void conv_3x3s2_direct_fp32(const float* din, ...@@ -101,12 +65,13 @@ void conv_3x3s2_direct_fp32(const float* din,
const float* weights, const float* weights,
const float* bias, const float* bias,
const operators::ConvParam& param, const operators::ConvParam& param,
Context<TARGET(kARM)>* ctx); ARMContext* ctx);
int conv_3x3s2_direct_int8_c_num(); int conv_3x3s2_direct_int8_c_num();
template <typename Dtype>
void conv_3x3s2_direct_int8(const int8_t* din, void conv_3x3s2_direct_int8(const int8_t* din,
int32_t* dout, Dtype* dout,
int num, int num,
int chout, int chout,
int hout, int hout,
...@@ -115,14 +80,13 @@ void conv_3x3s2_direct_int8(const int8_t* din, ...@@ -115,14 +80,13 @@ void conv_3x3s2_direct_int8(const int8_t* din,
int hin, int hin,
int win, int win,
const int8_t* weights, const int8_t* weights,
const int32_t* bias, const float* bias,
const operators::ConvParam& param, const operators::ConvParam& param,
Context<TARGET(kARM)>* ctx, ARMContext* ctx,
PrecisionType out_type,
const float* scale); const float* scale);
void conv_1x5s1_direct(const void* din, void conv_1x5s1_direct(const float* din,
void* dout, float* dout,
int num, int num,
int chout, int chout,
int hout, int hout,
...@@ -130,8 +94,8 @@ void conv_1x5s1_direct(const void* din, ...@@ -130,8 +94,8 @@ void conv_1x5s1_direct(const void* din,
int chin, int chin,
int hin, int hin,
int win, int win,
const void* weights, const float* weights,
const void* bias, const float* bias,
int group, int group,
int kernel_w, int kernel_w,
int kernel_h, int kernel_h,
...@@ -143,12 +107,10 @@ void conv_1x5s1_direct(const void* din, ...@@ -143,12 +107,10 @@ void conv_1x5s1_direct(const void* din,
int pad_h, int pad_h,
bool flag_bias, bool flag_bias,
bool flag_relu, bool flag_relu,
Context<TARGET(kARM)>& ctx, ARMContext& ctx); // NOLINT
void* work_space,
const void* idx_ptr);
void conv_5x1s1_direct(const void* din, void conv_5x1s1_direct(const float* din,
void* dout, float* dout,
int num, int num,
int chout, int chout,
int hout, int hout,
...@@ -156,8 +118,8 @@ void conv_5x1s1_direct(const void* din, ...@@ -156,8 +118,8 @@ void conv_5x1s1_direct(const void* din,
int chin, int chin,
int hin, int hin,
int win, int win,
const void* weights, const float* weights,
const void* bias, const float* bias,
int group, int group,
int kernel_w, int kernel_w,
int kernel_h, int kernel_h,
...@@ -169,9 +131,7 @@ void conv_5x1s1_direct(const void* din, ...@@ -169,9 +131,7 @@ void conv_5x1s1_direct(const void* din,
int pad_h, int pad_h,
bool flag_bias, bool flag_bias,
bool flag_relu, bool flag_relu,
Context<TARGET(kARM)>& ctx, ARMContext& ctx); // NOLINT
void* work_space,
const void* idx_ptr);
void conv1x1s1_gemm(const float* din, void conv1x1s1_gemm(const float* din,
float* dout, float* dout,
...@@ -185,11 +145,11 @@ void conv1x1s1_gemm(const float* din, ...@@ -185,11 +145,11 @@ void conv1x1s1_gemm(const float* din,
const float* weights, const float* weights,
const float* bias, const float* bias,
const operators::ConvParam& param, const operators::ConvParam& param,
Context<TARGET(kARM)>* ctx, ARMContext* ctx);
const int* idx_ptr);
template <typename Dtype>
void conv1x1s1_gemm_int8(const int8_t* din, void conv1x1s1_gemm_int8(const int8_t* din,
int32_t* dout, Dtype* dout,
int num, int num,
int chout, int chout,
int hout, int hout,
...@@ -198,12 +158,10 @@ void conv1x1s1_gemm_int8(const int8_t* din, ...@@ -198,12 +158,10 @@ void conv1x1s1_gemm_int8(const int8_t* din,
int hin, int hin,
int win, int win,
const int8_t* weights, const int8_t* weights,
const int32_t* bias, const float* bias,
const operators::ConvParam& param, const operators::ConvParam& param,
Context<TARGET(kARM)>* ctx, ARMContext* ctx,
PrecisionType out_type, const float* scale);
const float* scale,
const int32_t* idx_ptr);
void conv_im2col_gemm(const float* din, void conv_im2col_gemm(const float* din,
float* dout, float* dout,
...@@ -217,11 +175,11 @@ void conv_im2col_gemm(const float* din, ...@@ -217,11 +175,11 @@ void conv_im2col_gemm(const float* din,
const float* weights, const float* weights,
const float* bias, const float* bias,
const operators::ConvParam& param, const operators::ConvParam& param,
Context<TARGET(kARM)>* ctx, ARMContext* ctx);
const int* idx_ptr);
template <typename Dtype>
void conv_im2col_gemm_int8(const int8_t* din, void conv_im2col_gemm_int8(const int8_t* din,
int32_t* dout, Dtype* dout,
int num, int num,
int chout, int chout,
int hout, int hout,
...@@ -230,157 +188,103 @@ void conv_im2col_gemm_int8(const int8_t* din, ...@@ -230,157 +188,103 @@ void conv_im2col_gemm_int8(const int8_t* din,
int hin, int hin,
int win, int win,
const int8_t* weights, const int8_t* weights,
const int32_t* bias, const float* bias,
const operators::ConvParam& param, const operators::ConvParam& param,
Context<TARGET(kARM)>* ctx, ARMContext* ctx,
PrecisionType out_type, const float* scale);
const float* scale,
const int32_t* idx_ptr);
/**
* \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias
*/
void conv_depthwise_3x3p0(const float* din,
float* dout,
int num,
int ch_out,
int h_out,
int w_out,
int ch_in,
int h_in,
int w_in,
const float* weights,
const float* bias,
int stride,
bool flag_bias,
bool flag_relu,
ARMContext* ctx);
void conv_depthwise_3x3p1(const float* din,
float* dout,
int num,
int ch_out,
int h_out,
int w_out,
int ch_in,
int h_in,
int w_in,
const float* weights,
const float* bias,
int stride,
bool flag_bias,
bool flag_relu,
ARMContext* ctx);
void conv_depthwise_5x5s1(const float* din,
float* dout,
int num,
int chout,
int hout,
int wout,
int chin,
int hin,
int win,
const float* weights,
const float* bias,
int pad,
bool flag_bias,
bool flag_relu,
ARMContext* ctx);
void conv_depthwise_5x5s2(const float* din,
float* dout,
int num,
int chout,
int hout,
int wout,
int chin,
int hin,
int win,
const float* weights,
const float* bias,
int pad,
bool flag_bias,
bool flag_relu,
ARMContext* ctx);
void conv_depthwise_3x3(const float* din, float* dout, int num, int chout,
                        int hout, int wout, int chin, int hin, int win,
                        const float* weights, const float* bias,
                        const operators::ConvParam& param,
                        Context<TARGET(kARM)>* ctx);
void conv_depthwise_3x3_int8(const int8_t* din, int32_t* dout, int num,
                             int chout, int hout, int wout, int chin, int hin,
                             int win, const int8_t* weights,
                             const int32_t* bias,
                             const operators::ConvParam& param,
                             Context<TARGET(kARM)>* ctx,
                             PrecisionType out_type, const float* scale);
void conv_depthwise_3x3_int7(const int8_t* din, int32_t* dout, int num,
                             int chout, int hout, int wout, int chin, int hin,
                             int win, int8_t* weights, const int32_t* bias,
                             const operators::ConvParam& param,
                             Context<TARGET(kARM)>* ctx,
                             PrecisionType out_type, const float* scale);
void conv_depthwise_5x5(const float* din, float* dout, int num, int chout,
                        int hout, int wout, int chin, int hin, int win,
                        const float* weights, const float* bias,
                        const operators::ConvParam& param,
                        Context<TARGET(kARM)>* ctx);
void conv_depthwise_5x5_int8(const int8_t* din, int32_t* dout, int num,
                             int chout, int hout, int wout, int chin, int hin,
                             int win, const int8_t* weights,
                             const int32_t* bias,
                             const operators::ConvParam& param,
                             Context<TARGET(kARM)>* ctx,
                             PrecisionType out_type, const float* scale);

/// depthwise conv
void conv_depthwise_3x3_fp32(const void* din, void* dout, int num, int ch_out,
                             int h_out, int w_out, int ch_in, int h_in,
                             int w_in, const void* weights, const float* bias,
                             const operators::ConvParam& param,
                             ARMContext* ctx, const float* scale);
void conv_depthwise_3x3_int8_fp32(const void* din, void* dout, int num,
                                  int ch_out, int h_out, int w_out, int ch_in,
                                  int h_in, int w_in, const void* weights,
                                  const float* bias,
                                  const operators::ConvParam& param,
                                  ARMContext* ctx, const float* scale);
void conv_depthwise_3x3_int8_int8(const void* din, void* dout, int num,
                                  int ch_out, int h_out, int w_out, int ch_in,
                                  int h_in, int w_in, const void* weights,
                                  const float* bias,
                                  const operators::ConvParam& param,
                                  ARMContext* ctx, const float* scale);
void conv_depthwise_5x5_fp32(const void* din, void* dout, int num, int ch_out,
                             int h_out, int w_out, int ch_in, int h_in,
                             int w_in, const void* weights, const float* bias,
                             const operators::ConvParam& param,
                             ARMContext* ctx, const float* scale);
void conv_depthwise_5x5_int8_fp32(const void* din, void* dout, int num,
                                  int ch_out, int h_out, int w_out, int ch_in,
                                  int h_in, int w_in, const void* weights,
                                  const float* bias,
                                  const operators::ConvParam& param,
                                  ARMContext* ctx, const float* scale);
void conv_depthwise_5x5_int8_int8(const void* din, void* dout, int num,
                                  int ch_out, int h_out, int w_out, int ch_in,
                                  int h_in, int w_in, const void* weights,
                                  const float* bias,
                                  const operators::ConvParam& param,
                                  ARMContext* ctx, const float* scale);
/// winograd conv, only support 3x3s1
void conv_winograd3x3(const float* din,
                      float* dout,
                      int num,
...@@ -393,23 +297,11 @@ void conv_winograd3x3(const float* din,
                      const float* weights,
                      const float* bias,
                      const operators::ConvParam& param,
                      ARMContext* ctx);
void winograd_transform_weights(
    void* dout, const void* din, int ch_out, int ch_in, void* work_space);
void compute_offset(int* idx_out,
int h,
int w,
int kernel_h,
int kernel_w,
int height,
int width,
int pad_h,
int pad_w,
int dilation_h,
int dilation_w);
void fill_bias(float* tensor, const float* bias, int channel, int channel_size);
void fill_bias_int8(int* tensor,
......
...@@ -102,7 +102,7 @@ void conv_winograd3x3(const float* din,
  //! dot mul
  //! transpose input, convert from ch_in * tile_h * tile_w * 64 to
  //! 64 * ch_in * tile_h * tile_w
  int hblock = get_hblock(ctx->arch());
  int hblock = get_hblock(ctx);
  int m_round = hblock * ((chout + hblock - 1) / hblock);
  int stride_a = m_round * chin;
  int stride_b = chin * size_tile;
......
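The m_round line above is the usual round-up-to-a-multiple idiom, so the packed A panel is always a whole number of hblock rows. A quick sketch of the arithmetic (function name illustrative):

// hblock * ((m + hblock - 1) / hblock) rounds m up to a multiple of hblock,
// e.g. with hblock = 8 and chout = 21: ((21 + 7) / 8) * 8 == 24.
inline int round_up_to(int m, int hblock) {
  return hblock * ((m + hblock - 1) / hblock);
}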
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
// clang-format off
#define GEMM_SDOT_INT8_KERNEL \
"ldp q0, q1, [%[a_ptr]], #32\n" /* load a00,a01 to q0, q1*/ \
"ldp q4, q5, [%[b_ptr]], #32\n" /* load b0, b1 to q4, q5*/ \
"eor v8.16b, v8.16b, v8.16b\n" /* out0 = 0 */ \
"eor v9.16b, v9.16b, v9.16b\n" /* out1 = 0 */ \
"eor v10.16b, v10.16b, v10.16b\n" /* out2 = 0 */ \
"eor v11.16b, v11.16b, v11.16b\n" /* out3 = 0 */ \
"eor v12.16b, v12.16b, v12.16b\n" /* out4 = 0 */ \
"prfm pldl1keep, [%[b_ptr], #64]\n" /* preload b*/ \
"eor v13.16b, v13.16b, v13.16b\n" /* out5 = 0 */ \
"prfm pldl1keep, [%[a_ptr], #64]\n" /* preload a*/ \
"eor v14.16b, v14.16b, v14.16b\n" /* out6 = 0 */ \
"prfm pldl1keep, [%[b_ptr], #128]\n" /* preload b*/ \
"eor v15.16b, v15.16b, v15.16b\n" /* out7 = 0 */ \
"prfm pldl1keep, [%[a_ptr], #128]\n" /* preload a*/ \
"eor v16.16b, v16.16b, v16.16b\n" /* out8 = 0 */ \
"prfm pldl1keep, [%[b_ptr], #192]\n" /* preload b*/ \
"eor v17.16b, v17.16b, v17.16b\n" /* out9 = 0 */ \
"prfm pldl1keep, [%[b_ptr], #256]\n" /* preload b*/ \
"eor v18.16b, v18.16b, v18.16b\n" /* out10 = 0 */ \
"prfm pldl1keep, [%[a_ptr], #192]\n" /* preload a*/ \
"eor v19.16b, v19.16b, v19.16b\n" /* out11 = 0 */ \
"prfm pldl1keep, [%[b_ptr], #320]\n" /* preload b*/ \
"eor v20.16b, v20.16b, v20.16b\n" /* out12 = 0 */ \
"prfm pldl1keep, [%[a_ptr], #256]\n" /* preload a*/ \
"eor v21.16b, v21.16b, v21.16b\n" /* out13 = 0 */ \
"prfm pldl1keep, [%[b_ptr], #384]\n" /* preload b*/ \
"eor v22.16b, v22.16b, v22.16b\n" /* out14 = 0 */ \
"eor v23.16b, v23.16b, v23.16b\n" /* out15 = 0 */ \
"eor v24.16b, v24.16b, v24.16b\n" /* out16 = 0 */ \
"eor v25.16b, v25.16b, v25.16b\n" /* out17 = 0 */ \
"eor v26.16b, v26.16b, v26.16b\n" /* out18 = 0 */ \
"eor v27.16b, v27.16b, v27.16b\n" /* out19 = 0 */ \
"eor v28.16b, v28.16b, v28.16b\n" /* out20 = 0 */ \
"eor v29.16b, v29.16b, v29.16b\n" /* out21 = 0 */ \
"eor v30.16b, v30.16b, v30.16b\n" /* out22 = 0 */ \
"eor v31.16b, v31.16b, v31.16b\n" /* out23 = 0 */ \
"cbz %w[k], 2f\n" /* check loop count > 0 */ \
/* main loop, unroll 0 */ \
"1:\n" /* main loop */ \
"sdot v8.4s , v4.16b, v0.4b[0]\n" /* out0 = b0 * a00[0], b0 = q4 */ \
"sdot v11.4s , v4.16b, v0.4b[1]\n" /* out1 = b0 * a00[1], b0 = q4 */ \
"ldp q6, q7, [%[b_ptr]], #32\n" /* load b2, b0 to q6, q7 */ \
"sdot v14.4s, v4.16b, v0.4b[2]\n" /* out2 = b0 * a00[2], b0 = q4 */ \
"sdot v17.4s, v4.16b, v0.4b[3]\n" /* out3 = b0 * a00[3], b0 = q4 */ \
"ldp q2, q3, [%[a_ptr]], #32\n" /* load a10, a11 to q3, q4 */ \
"sdot v20.4s, v4.16b, v1.4b[0]\n" /* out4 = b0 * a01[0], b0 = q4 */ \
"sdot v23.4s, v4.16b, v1.4b[1]\n" /* out5 = b0 * a01[1], b0 = q4 */ \
"sdot v26.4s, v4.16b, v1.4b[2]\n" /* out6 = b0 * a01[2], b0 = q4 */ \
"sdot v29.4s, v4.16b, v1.4b[3]\n" /* out7 = b0 * a01[3], b0 = q4 */ \
"sdot v9.4s, v5.16b, v0.4b[0]\n" /* out8 = b1 * a00[0], b1 = q5 */ \
"sdot v12.4s, v5.16b, v0.4b[1]\n" /* out9 = b1 * a00[1], b1 = q5 */ \
"sdot v15.4s, v5.16b, v0.4b[2]\n" /* out10 = b1 * a00[2], b1 = q5*/ \
"sdot v18.4s, v5.16b, v0.4b[3]\n" /* out11 = b1 * a00[3], b1 = q5*/ \
"sdot v21.4s, v5.16b, v1.4b[0]\n" /* out12 = b1 * a01[0], b1 = q5*/ \
"sdot v24.4s, v5.16b, v1.4b[1]\n" /* out13 = b1 * a01[1], b1 = q5*/ \
"sdot v27.4s, v5.16b, v1.4b[2]\n" /* out14 = b1 * a01[2], b1 = q5*/ \
"sdot v30.4s, v5.16b, v1.4b[3]\n" /* out15 = b1 * a01[3], b1 = q5*/ \
"ldp q4, q5, [%[b_ptr]], #32\n" /* load b1, b2 to q4, q5 */ \
"sdot v10.4s, v6.16b, v0.4b[0]\n" /* out16 = b2 * a00[0], b2 = q6*/ \
"sdot v13.4s, v6.16b, v0.4b[1]\n" /* out17 = b2 * a00[1], b2 = q6*/ \
"prfm pldl1keep, [%[b_ptr], #384]\n" \
"sdot v16.4s, v6.16b, v0.4b[2]\n" /* out18 = b2 * a00[2], b2 = q6*/ \
"sdot v19.4s, v6.16b, v0.4b[3]\n" /* out19 = b2 * a00[3], b2 = q6*/ \
"sdot v22.4s, v6.16b, v1.4b[0]\n" /* out20 = b2 * a00[0], b2 = q6*/ \
"sdot v25.4s, v6.16b, v1.4b[1]\n" /* out21 = b2 * a00[1], b2 = q6*/ \
"sdot v28.4s, v6.16b, v1.4b[2]\n" /* out22 = b2 * a00[2], b2 = q6*/ \
"sdot v31.4s, v6.16b, v1.4b[3]\n" /* out23 = b2 * a00[3], b2 = q6*/ \
"ldp q0, q1, [%[a_ptr]], #32\n" /* load a00, a01 to q0, q1 */ \
/* unroll 1 */ \
"sdot v8.4s , v7.16b, v2.4b[0]\n" /* out0 = b0 * a10[0], b0 = q7 */ \
"sdot v11.4s , v7.16b, v2.4b[1]\n"/* out1 = b0 * a10[1], b0 = q7 */ \
"sdot v14.4s, v7.16b, v2.4b[2]\n" /* out2 = b0 * a10[2], b0 = q7 */ \
"prfm pldl1keep, [%[a_ptr], #256]\n" \
"sdot v17.4s, v7.16b, v2.4b[3]\n" /* out3 = b0 * a10[3], b0 = q7 */ \
"sdot v20.4s, v7.16b, v3.4b[0]\n" /* out4 = b0 * a11[0], b0 = q7 */ \
"sdot v23.4s, v7.16b, v3.4b[1]\n" /* out5 = b0 * a11[1], b0 = q7 */ \
"sdot v26.4s, v7.16b, v3.4b[2]\n" /* out6 = b0 * a11[2], b0 = q7 */ \
"sdot v29.4s, v7.16b, v3.4b[3]\n" /* out7 = b0 * a11[3], b0 = q7 */ \
"ldp q6, q7, [%[b_ptr]], #32\n" /* load b0, b1 to q6, q7 */ \
"sdot v9.4s, v4.16b, v2.4b[0]\n" /* out8 = b0 * a10[0], b1 = q4 */ \
"sdot v12.4s, v4.16b, v2.4b[1]\n" /* out9 = b0 * a10[1], b1 = q4 */ \
"sdot v15.4s, v4.16b, v2.4b[2]\n" /* out10 = b1 * a10[2], b1 = q4*/ \
"sdot v18.4s, v4.16b, v2.4b[3]\n" /* out11 = b1 * a10[3], b1 = q4*/ \
"sdot v21.4s, v4.16b, v3.4b[0]\n" /* out12 = b1 * a10[0], b1 = q4*/ \
"sdot v24.4s, v4.16b, v3.4b[1]\n" /* out13 = b1 * a10[1], b1 = q4*/ \
"sdot v27.4s, v4.16b, v3.4b[2]\n" /* out14 = b1 * a10[2], b1 = q4*/ \
"sdot v30.4s, v4.16b, v3.4b[3]\n" /* out15 = b1 * a10[3], b1 = q4*/ \
"sdot v10.4s, v5.16b, v2.4b[0]\n" /* out16 = b2 * a10[0], b2 = q5*/ \
"sdot v13.4s, v5.16b, v2.4b[1]\n" /* out17 = b2 * a10[0], b2 = q5*/ \
"sdot v16.4s, v5.16b, v2.4b[2]\n" /* out18 = b2 * a10[0], b2 = q5*/ \
"sdot v19.4s, v5.16b, v2.4b[3]\n" /* out19 = b2 * a10[0], b2 = q5*/ \
"sdot v22.4s, v5.16b, v3.4b[0]\n" /* out20 = b2 * a10[0], b2 = q5*/ \
"sdot v25.4s, v5.16b, v3.4b[1]\n" /* out21 = b2 * a10[0], b2 = q5*/ \
"sdot v28.4s, v5.16b, v3.4b[2]\n" /* out22 = b2 * a10[0], b2 = q5*/ \
"sdot v31.4s, v5.16b, v3.4b[3]\n" /* out23 = b2 * a10[0], b2 = q5*/ \
"ldp q4, q5, [%[b_ptr]], #32\n" /* load b2, b0 to q4, q5 */ \
/* unroll 2 */ \
"sdot v8.4s , v6.16b, v0.4b[0]\n" /* out0 = b0 * a00[0], b0 = q6 */ \
"sdot v11.4s , v6.16b, v0.4b[1]\n" /* out1 = b0 * a00[1], b0 = q6 */ \
"ldp q2, q3, [%[a_ptr]], #32\n" /* load a10, a11 to q3, q4*/ \
"sdot v14.4s, v6.16b, v0.4b[2]\n" /* out2 = b0 * a00[2], b0 = q6*/ \
"sdot v17.4s, v6.16b, v0.4b[3]\n" /* out3 = b0 * a00[3], b0 = q6*/ \
"sdot v20.4s, v6.16b, v1.4b[0]\n" /* out4 = b0 * a01[0], b0 = q6*/ \
"sdot v23.4s, v6.16b, v1.4b[1]\n" /* out5 = b0 * a01[1], b0 = q6*/ \
"sdot v26.4s, v6.16b, v1.4b[2]\n" /* out6 = b0 * a01[2], b0 = q6*/ \
"sdot v29.4s, v6.16b, v1.4b[3]\n" /* out7 = b0 * a01[3], b0 = q6*/ \
"sdot v9.4s, v7.16b, v0.4b[0]\n" /* out8 = b1 * a00[0], b1 = q7*/ \
"sdot v12.4s, v7.16b, v0.4b[1]\n" /* out9 = b1 * a00[1], b1 = q7*/ \
"prfm pldl1keep, [%[b_ptr], #384]\n" \
"sdot v15.4s, v7.16b, v0.4b[2]\n" /* out10 = b1 * a00[2], b1 = q7*/ \
"sdot v18.4s, v7.16b, v0.4b[3]\n" /* out11 = b1 * a00[3], b1 = q7*/ \
"sdot v21.4s, v7.16b, v1.4b[0]\n" /* out12 = b1 * a01[0], b1 = q7*/ \
"sdot v24.4s, v7.16b, v1.4b[1]\n" /* out13 = b1 * a01[1], b1 = q7*/ \
"sdot v27.4s, v7.16b, v1.4b[2]\n" /* out14 = b1 * a01[2], b1 = q7*/ \
"sdot v30.4s, v7.16b, v1.4b[3]\n" /* out15 = b1 * a01[3], b1 = q7*/ \
"ldp q6, q7, [%[b_ptr]], #32\n" /* load b1, b2 to q6, q7*/ \
"sdot v10.4s, v4.16b, v0.4b[0]\n" /* out16 = b2 * a00[0], b2 = q4*/ \
"sdot v13.4s, v4.16b, v0.4b[1]\n" /* out17 = b2 * a00[1], b2 = q4*/ \
"sdot v16.4s, v4.16b, v0.4b[2]\n" /* out18 = b2 * a00[2], b2 = q4*/ \
"sdot v19.4s, v4.16b, v0.4b[3]\n" /* out19 = b2 * a00[3], b2 = q4*/ \
"sdot v22.4s, v4.16b, v1.4b[0]\n" /* out20 = b2 * a00[0], b2 = q4*/ \
"sdot v25.4s, v4.16b, v1.4b[1]\n" /* out21 = b2 * a00[1], b2 = q4*/ \
"sdot v28.4s, v4.16b, v1.4b[2]\n" /* out22 = b2 * a00[2], b2 = q4*/ \
"sdot v31.4s, v4.16b, v1.4b[3]\n" /* out23 = b2 * a00[3], b2 = q4*/ \
"ldp q0, q1, [%[a_ptr]], #32\n" /* load a00, a01 to q0, q1*/ \
/* unroll 3 */ \
"sdot v8.4s , v5.16b, v2.4b[0]\n" /* out0 = b0 * a10[0], b0 = q5*/ \
"sdot v11.4s , v5.16b, v2.4b[1]\n" /* out1 = b0 * a10[1], b0 = q5*/ \
"sdot v14.4s, v5.16b, v2.4b[2]\n" /* out2 = b0 * a10[2], b0 = q5*/ \
"sdot v17.4s, v5.16b, v2.4b[3]\n" /* out3 = b0 * a10[3], b0 = q5*/ \
"sdot v20.4s, v5.16b, v3.4b[0]\n" /* out4 = b0 * a11[0], b0 = q5*/ \
"sdot v23.4s, v5.16b, v3.4b[1]\n" /* out5 = b0 * a11[1], b0 = q5*/ \
"sdot v26.4s, v5.16b, v3.4b[2]\n" /* out6 = b0 * a11[2], b0 = q5*/ \
"sdot v29.4s, v5.16b, v3.4b[3]\n" /* out7 = b0 * a11[3], b0 = q5*/ \
"ldp q4, q5, [%[b_ptr]], #32\n" /* load b0, b1 to q4, q5*/ \
"sdot v9.4s, v6.16b, v2.4b[0]\n" /* out8 = b0 * a10[0], b1 = q6*/ \
"sdot v12.4s, v6.16b, v2.4b[1]\n" /* out9 = b0 * a10[1], b1 = q6*/ \
"prfm pldl1keep, [%[a_ptr], #256]\n" \
"sdot v15.4s, v6.16b, v2.4b[2]\n" /* out10 = b1 * a10[2], b1 = q6*/ \
"sdot v18.4s, v6.16b, v2.4b[3]\n" /* out11 = b1 * a10[3], b1 = q6*/ \
"sdot v21.4s, v6.16b, v3.4b[0]\n" /* out12 = b1 * a10[0], b1 = q6*/ \
"sdot v24.4s, v6.16b, v3.4b[1]\n" /* out13 = b1 * a10[1], b1 = q6*/ \
"sdot v27.4s, v6.16b, v3.4b[2]\n" /* out14 = b1 * a10[2], b1 = q6*/ \
"prfm pldl1keep, [%[b_ptr], #384]\n" \
"sdot v30.4s, v6.16b, v3.4b[3]\n" /* out15 = b1 * a10[3], b1 = q6*/ \
"sdot v10.4s, v7.16b, v2.4b[0]\n" /* out16 = b2 * a10[0], b2 = q7*/ \
"sdot v13.4s, v7.16b, v2.4b[1]\n" /* out17 = b2 * a10[0], b2 = q7*/ \
"sdot v16.4s, v7.16b, v2.4b[2]\n" /* out18 = b2 * a10[0], b2 = q7*/ \
"sdot v19.4s, v7.16b, v2.4b[3]\n" /* out19 = b2 * a10[0], b2 = q7*/ \
"sdot v22.4s, v7.16b, v3.4b[0]\n" /* out20 = b2 * a10[0], b2 = q7*/ \
"sdot v25.4s, v7.16b, v3.4b[1]\n" /* out21 = b2 * a10[0], b2 = q7*/ \
"subs %w[k], %w[k], #1\n" /* loop count - 1*/ \
"sdot v28.4s, v7.16b, v3.4b[2]\n" /* out22 = b2 * a10[0], b2 = q7*/ \
"sdot v31.4s, v7.16b, v3.4b[3]\n" /* out23 = b2 * a10[0], b2 = q7*/ \
"bne 1b\n" /* Target to use when K is 1 or 2 */ \
"2:\n" /* process tail*/ \
"subs %w[tail], %w[tail], #1\n" /* tail--*/ \
"beq 3f\n" /*jump to tail = 1*/ \
/* final unroll 0, tail > 1 */ \
"sdot v8.4s , v4.16b, v0.4b[0]\n" /* out0 = b0 * a00[0], b0 = q4*/ \
"sdot v11.4s , v4.16b, v0.4b[1]\n" /* out1 = b0 * a00[1], b0 = q4*/ \
"ldp q6, q7, [%[b_ptr]], #32\n" /* load b2, b0 to q6, q7*/ \
"sdot v14.4s, v4.16b, v0.4b[2]\n" /* out2 = b0 * a00[2], b0 = q4*/ \
"sdot v17.4s, v4.16b, v0.4b[3]\n" /* out3 = b0 * a00[3], b0 = q4*/ \
"ldp q2, q3, [%[a_ptr]], #32\n" /* load a10, a11 to q2, q3*/ \
"sdot v20.4s, v4.16b, v1.4b[0]\n" /* out4 = b0 * a01[0], b0 = q4*/ \
"sdot v23.4s, v4.16b, v1.4b[1]\n" /* out5 = b0 * a01[1], b0 = q4*/ \
"sdot v26.4s, v4.16b, v1.4b[2]\n" /* out6 = b0 * a01[2], b0 = q4*/ \
"sdot v29.4s, v4.16b, v1.4b[3]\n" /* out7 = b0 * a01[3], b0 = q4*/ \
"subs %w[tail], %w[tail], #1\n" /* tail--*/ \
"sdot v9.4s, v5.16b, v0.4b[0]\n" /* out8 = b1 * a00[0], b1 = q5*/ \
"sdot v12.4s, v5.16b, v0.4b[1]\n" /* out9 = b1 * a00[1], b1 = q5*/ \
"sdot v15.4s, v5.16b, v0.4b[2]\n" /* out10 = b1 * a00[2], b1 = q5*/ \
"sdot v18.4s, v5.16b, v0.4b[3]\n" /* out11 = b1 * a00[3], b1 = q5*/ \
"sdot v21.4s, v5.16b, v1.4b[0]\n" /* out12 = b1 * a01[0], b1 = q5*/ \
"sdot v24.4s, v5.16b, v1.4b[1]\n" /* out13 = b1 * a01[1], b1 = q5*/ \
"sdot v27.4s, v5.16b, v1.4b[2]\n" /* out14 = b1 * a01[2], b1 = q5*/ \
"sdot v30.4s, v5.16b, v1.4b[3]\n" /* out15 = b1 * a01[3], b1 = q5*/ \
"ldp q4, q5, [%[b_ptr]], #32\n" /* load b1, b2 to q4, q5*/ \
"sdot v10.4s, v6.16b, v0.4b[0]\n" /* out16 = b2 * a00[0], b2 = q6*/ \
"sdot v13.4s, v6.16b, v0.4b[1]\n" /* out17 = b2 * a00[1], b2 = q6*/ \
"sdot v16.4s, v6.16b, v0.4b[2]\n" /* out18 = b2 * a00[2], b2 = q6*/ \
"sdot v19.4s, v6.16b, v0.4b[3]\n" /* out19 = b2 * a00[3], b2 = q6*/ \
"sdot v22.4s, v6.16b, v1.4b[0]\n" /* out20 = b2 * a00[0], b2 = q6*/ \
"sdot v25.4s, v6.16b, v1.4b[1]\n" /* out21 = b2 * a00[1], b2 = q6*/ \
"sdot v28.4s, v6.16b, v1.4b[2]\n" /* out22 = b2 * a00[2], b2 = q6*/ \
"sdot v31.4s, v6.16b, v1.4b[3]\n" /* out23 = b2 * a00[3], b2 = q6*/ \
"beq 4f\n" /*jump to tail = 2*/ \
/* unroll 1, tail > 2 */ \
"ldp q0, q1, [%[a_ptr]], #32\n" /* load a00, a01 to q0, q1*/ \
"sdot v8.4s , v7.16b, v2.4b[0]\n" /* out0 = b0 * a10[0], b0 = q7*/ \
"sdot v11.4s , v7.16b, v2.4b[1]\n" /* out1 = b0 * a10[1], b0 = q7*/ \
"sdot v14.4s, v7.16b, v2.4b[2]\n" /* out2 = b0 * a10[2], b0 = q7*/ \
"sdot v17.4s, v7.16b, v2.4b[3]\n" /* out3 = b0 * a10[3], b0 = q7*/ \
"sdot v20.4s, v7.16b, v3.4b[0]\n" /* out4 = b0 * a11[0], b0 = q7*/ \
"sdot v23.4s, v7.16b, v3.4b[1]\n" /* out5 = b0 * a11[1], b0 = q7*/ \
"sdot v26.4s, v7.16b, v3.4b[2]\n" /* out6 = b0 * a11[2], b0 = q7*/ \
"sdot v29.4s, v7.16b, v3.4b[3]\n" /* out7 = b0 * a11[3], b0 = q7*/ \
"ldp q6, q7, [%[b_ptr]], #32\n" /* load b0, b1 to q6, q7*/ \
"sdot v9.4s, v4.16b, v2.4b[0]\n" /* out8 = b0 * a10[0], b1 = q4*/ \
"sdot v12.4s, v4.16b, v2.4b[1]\n" /* out9 = b0 * a10[1], b1 = q4*/ \
"sdot v15.4s, v4.16b, v2.4b[2]\n" /* out10 = b1 * a10[2], b1 = q4*/ \
"sdot v18.4s, v4.16b, v2.4b[3]\n" /* out11 = b1 * a10[3], b1 = q4*/ \
"sdot v21.4s, v4.16b, v3.4b[0]\n" /* out12 = b1 * a10[0], b1 = q4*/ \
"sdot v24.4s, v4.16b, v3.4b[1]\n" /* out13 = b1 * a10[1], b1 = q4*/ \
"sdot v27.4s, v4.16b, v3.4b[2]\n" /* out14 = b1 * a10[2], b1 = q4*/ \
"sdot v30.4s, v4.16b, v3.4b[3]\n" /* out15 = b1 * a10[3], b1 = q4*/ \
"subs %w[tail], %w[tail], #1\n" /* tail--*/ \
"sdot v10.4s, v5.16b, v2.4b[0]\n" /* out16 = b2 * a10[0], b2 = q5*/ \
"sdot v13.4s, v5.16b, v2.4b[1]\n" /* out17 = b2 * a10[0], b2 = q5*/ \
"sdot v16.4s, v5.16b, v2.4b[2]\n" /* out18 = b2 * a10[0], b2 = q5*/ \
"sdot v19.4s, v5.16b, v2.4b[3]\n" /* out19 = b2 * a10[0], b2 = q5*/ \
"sdot v22.4s, v5.16b, v3.4b[0]\n" /* out20 = b2 * a10[0], b2 = q5*/ \
"sdot v25.4s, v5.16b, v3.4b[1]\n" /* out21 = b2 * a10[0], b2 = q5*/ \
"sdot v28.4s, v5.16b, v3.4b[2]\n" /* out22 = b2 * a10[0], b2 = q5*/ \
"sdot v31.4s, v5.16b, v3.4b[3]\n" /* out23 = b2 * a10[0], b2 = q5*/ \
"beq 5f\n" /*jump to tail = 3*/ \
/* unroll 2, tail = 4 */ \
"ldp q4, q5, [%[b_ptr]], #32\n" /* load b2, b0 to q4, q5*/ \
"sdot v8.4s , v6.16b, v0.4b[0]\n" /* out0 = b0 * a00[0], b0 = q6*/ \
"sdot v11.4s , v6.16b, v0.4b[1]\n" /* out1 = b0 * a00[1], b0 = q6*/ \
"ldp q2, q3, [%[a_ptr]], #32\n" /* load a10, a11 to q3, q4*/ \
"sdot v14.4s, v6.16b, v0.4b[2]\n" /* out2 = b0 * a00[2], b0 = q6*/ \
"sdot v17.4s, v6.16b, v0.4b[3]\n" /* out3 = b0 * a00[3], b0 = q6*/ \
"sdot v20.4s, v6.16b, v1.4b[0]\n" /* out4 = b0 * a01[0], b0 = q6*/ \
"sdot v23.4s, v6.16b, v1.4b[1]\n" /* out5 = b0 * a01[1], b0 = q6*/ \
"sdot v26.4s, v6.16b, v1.4b[2]\n" /* out6 = b0 * a01[2], b0 = q6*/ \
"sdot v29.4s, v6.16b, v1.4b[3]\n" /* out7 = b0 * a01[3], b0 = q6*/ \
"sdot v9.4s, v7.16b, v0.4b[0]\n" /* out8 = b1 * a00[0], b1 = q7*/ \
"sdot v12.4s, v7.16b, v0.4b[1]\n" /* out9 = b1 * a00[1], b1 = q7*/ \
"sdot v15.4s, v7.16b, v0.4b[2]\n" /* out10 = b1 * a00[2], b1 = q7*/ \
"sdot v18.4s, v7.16b, v0.4b[3]\n" /* out11 = b1 * a00[3], b1 = q7*/ \
"sdot v21.4s, v7.16b, v1.4b[0]\n" /* out12 = b1 * a01[0], b1 = q7*/ \
"sdot v24.4s, v7.16b, v1.4b[1]\n" /* out13 = b1 * a01[1], b1 = q7*/ \
"sdot v27.4s, v7.16b, v1.4b[2]\n" /* out14 = b1 * a01[2], b1 = q7*/ \
"sdot v30.4s, v7.16b, v1.4b[3]\n" /* out15 = b1 * a01[3], b1 = q7*/ \
"ldp q6, q7, [%[b_ptr]], #32\n" /* load b1, b2 to q6, q7*/ \
"sdot v10.4s, v4.16b, v0.4b[0]\n" /* out16 = b2 * a00[0], b2 = q4*/ \
"sdot v13.4s, v4.16b, v0.4b[1]\n" /* out17 = b2 * a00[1], b2 = q4*/ \
"sdot v16.4s, v4.16b, v0.4b[2]\n" /* out18 = b2 * a00[2], b2 = q4*/ \
"sdot v19.4s, v4.16b, v0.4b[3]\n" /* out19 = b2 * a00[3], b2 = q4*/ \
"sdot v22.4s, v4.16b, v1.4b[0]\n" /* out20 = b2 * a00[0], b2 = q4*/ \
"sdot v25.4s, v4.16b, v1.4b[1]\n" /* out21 = b2 * a00[1], b2 = q4*/ \
"sdot v28.4s, v4.16b, v1.4b[2]\n" /* out22 = b2 * a00[2], b2 = q4*/ \
"sdot v31.4s, v4.16b, v1.4b[3]\n" /* out23 = b2 * a00[3], b2 = q4*/ \
/* unroll 3, tail = 4 */ \
"sdot v8.4s , v5.16b, v2.4b[0]\n" /* out0 = b0 * a10[0], b0 = q5*/ \
"sdot v11.4s , v5.16b, v2.4b[1]\n" /* out1 = b0 * a10[1], b0 = q5*/ \
"sdot v14.4s, v5.16b, v2.4b[2]\n" /* out2 = b0 * a10[2], b0 = q5*/ \
"sdot v17.4s, v5.16b, v2.4b[3]\n" /* out3 = b0 * a10[3], b0 = q5*/ \
"sdot v20.4s, v5.16b, v3.4b[0]\n" /* out4 = b0 * a11[0], b0 = q5*/ \
"sdot v23.4s, v5.16b, v3.4b[1]\n" /* out5 = b0 * a11[1], b0 = q5*/ \
"sdot v26.4s, v5.16b, v3.4b[2]\n" /* out6 = b0 * a11[2], b0 = q5*/ \
"sdot v29.4s, v5.16b, v3.4b[3]\n" /* out7 = b0 * a11[3], b0 = q5*/ \
"sdot v9.4s, v6.16b, v2.4b[0]\n" /* out8 = b0 * a10[0], b1 = q6*/ \
"sdot v12.4s, v6.16b, v2.4b[1]\n" /* out9 = b1 * a10[1], b1 = q6*/ \
"sdot v15.4s, v6.16b, v2.4b[2]\n" /* out10 = b1 * a10[2], b1 = q6*/ \
"sdot v18.4s, v6.16b, v2.4b[3]\n" /* out11 = b1 * a10[3], b1 = q6*/ \
"sdot v21.4s, v6.16b, v3.4b[0]\n" /* out12 = b1 * a10[0], b1 = q6*/ \
"sdot v24.4s, v6.16b, v3.4b[1]\n" /* out13 = b1 * a10[1], b1 = q6*/ \
"sdot v27.4s, v6.16b, v3.4b[2]\n" /* out14 = b1 * a10[2], b1 = q6*/ \
"sdot v30.4s, v6.16b, v3.4b[3]\n" /* out15 = b1 * a10[3], b1 = q6*/ \
"sdot v10.4s, v7.16b, v2.4b[0]\n" /* out16 = b2 * a10[0], b2 = q7*/ \
"sdot v13.4s, v7.16b, v2.4b[1]\n" /* out17 = b2 * a10[0], b2 = q7*/ \
"sdot v16.4s, v7.16b, v2.4b[2]\n" /* out18 = b2 * a10[0], b2 = q7*/ \
"sdot v19.4s, v7.16b, v2.4b[3]\n" /* out19 = b2 * a10[0], b2 = q7*/ \
"sdot v22.4s, v7.16b, v3.4b[0]\n" /* out20 = b2 * a10[0], b2 = q7*/ \
"sdot v25.4s, v7.16b, v3.4b[1]\n" /* out21 = b2 * a10[0], b2 = q7*/ \
"sdot v28.4s, v7.16b, v3.4b[2]\n" /* out22 = b2 * a10[0], b2 = q7*/ \
"sdot v31.4s, v7.16b, v3.4b[3]\n" /* out23 = b2 * a10[0], b2 = q7*/ \
"b 11f\n" /* tails==1 final tail*/ \
"3: \n" /* tail=1*/ \
"ldr q6, [%[b_ptr]], #16\n" /* load b2 to q6*/ \
"sdot v8.4s , v4.16b, v0.4b[0]\n" /* out0 = b0 * a10[0], b0 = q5*/ \
"sdot v11.4s , v4.16b, v0.4b[1]\n" /* out1 = b0 * a10[1], b0 = q5*/ \
"sdot v14.4s, v4.16b, v0.4b[2]\n" /* out2 = b0 * a10[2], b0 = q5*/ \
"sdot v17.4s, v4.16b, v0.4b[3]\n" /* out3 = b0 * a10[3], b0 = q5*/ \
"sdot v20.4s, v4.16b, v1.4b[0]\n" /* out4 = b0 * a11[0], b0 = q5*/ \
"sdot v23.4s, v4.16b, v1.4b[1]\n" /* out5 = b0 * a11[1], b0 = q5*/ \
"sdot v26.4s, v4.16b, v1.4b[2]\n" /* out6 = b0 * a11[2], b0 = q5*/ \
"sdot v29.4s, v4.16b, v1.4b[3]\n" /* out7 = b0 * a11[3], b0 = q5*/ \
"sdot v9.4s, v5.16b, v0.4b[0]\n" /* out8 = b0 * a10[0], b1 = q6*/ \
"sdot v12.4s, v5.16b, v0.4b[1]\n" /* out9 = b1 * a10[1], b1 = q6*/ \
"sdot v15.4s, v5.16b, v0.4b[2]\n" /* out10 = b1 * a10[2], b1 = q6*/ \
"sdot v18.4s, v5.16b, v0.4b[3]\n" /* out11 = b1 * a10[3], b1 = q6*/ \
"sdot v21.4s, v5.16b, v1.4b[0]\n" /* out12 = b1 * a10[0], b1 = q6*/ \
"sdot v24.4s, v5.16b, v1.4b[1]\n" /* out13 = b1 * a10[1], b1 = q6*/ \
"sdot v27.4s, v5.16b, v1.4b[2]\n" /* out14 = b1 * a10[2], b1 = q6*/ \
"sdot v30.4s, v5.16b, v1.4b[3]\n" /* out15 = b1 * a10[3], b1 = q6*/ \
"sdot v10.4s, v6.16b, v0.4b[0]\n" /* out16 = b2 * a10[0], b2 = q7*/ \
"sdot v13.4s, v6.16b, v0.4b[1]\n" /* out17 = b2 * a10[0], b2 = q7*/ \
"sdot v16.4s, v6.16b, v0.4b[2]\n" /* out18 = b2 * a10[0], b2 = q7*/ \
"sdot v19.4s, v6.16b, v0.4b[3]\n" /* out19 = b2 * a10[0], b2 = q7*/ \
"sdot v22.4s, v6.16b, v1.4b[0]\n" /* out20 = b2 * a10[0], b2 = q7*/ \
"sdot v25.4s, v6.16b, v1.4b[1]\n" /* out21 = b2 * a10[0], b2 = q7*/ \
"sdot v28.4s, v6.16b, v1.4b[2]\n" /* out22 = b2 * a10[0], b2 = q7*/ \
"sdot v31.4s, v6.16b, v1.4b[3]\n" /* out23 = b2 * a10[0], b2 = q7*/ \
"b 11f\n" /* tails==2 final tail*/ \
"4:\n" /* tail = 2*/ \
"sdot v8.4s , v7.16b, v2.4b[0]\n" /* out0 = b0 * a10[0], b0 = q5*/ \
"sdot v11.4s , v7.16b, v2.4b[1]\n" /* out1 = b0 * a10[1], b0 = q5*/ \
"sdot v14.4s, v7.16b, v2.4b[2]\n" /* out2 = b0 * a10[2], b0 = q5*/ \
"sdot v17.4s, v7.16b, v2.4b[3]\n" /* out3 = b0 * a10[3], b0 = q5*/ \
"sdot v20.4s, v7.16b, v3.4b[0]\n" /* out4 = b0 * a11[0], b0 = q5*/ \
"sdot v23.4s, v7.16b, v3.4b[1]\n" /* out5 = b0 * a11[1], b0 = q5*/ \
"sdot v26.4s, v7.16b, v3.4b[2]\n" /* out6 = b0 * a11[2], b0 = q5*/ \
"sdot v29.4s, v7.16b, v3.4b[3]\n" /* out7 = b0 * a11[3], b0 = q5*/ \
"sdot v9.4s, v4.16b, v2.4b[0]\n" /* out8 = b0 * a10[0], b1 = q6*/ \
"sdot v12.4s, v4.16b, v2.4b[1]\n" /* out9 = b1 * a10[1], b1 = q6*/ \
"sdot v15.4s, v4.16b, v2.4b[2]\n" /* out10 = b1 * a10[2], b1 = q6*/ \
"sdot v18.4s, v4.16b, v2.4b[3]\n" /* out11 = b1 * a10[3], b1 = q6*/ \
"sdot v21.4s, v4.16b, v3.4b[0]\n" /* out12 = b1 * a10[0], b1 = q6*/ \
"sdot v24.4s, v4.16b, v3.4b[1]\n" /* out13 = b1 * a10[1], b1 = q6*/ \
"sdot v27.4s, v4.16b, v3.4b[2]\n" /* out14 = b1 * a10[2], b1 = q6*/ \
"sdot v30.4s, v4.16b, v3.4b[3]\n" /* out15 = b1 * a10[3], b1 = q6*/ \
"sdot v10.4s, v5.16b, v2.4b[0]\n" /* out16 = b2 * a10[0], b2 = q7*/ \
"sdot v13.4s, v5.16b, v2.4b[1]\n" /* out17 = b2 * a10[0], b2 = q7*/ \
"sdot v16.4s, v5.16b, v2.4b[2]\n" /* out18 = b2 * a10[0], b2 = q7*/ \
"sdot v19.4s, v5.16b, v2.4b[3]\n" /* out19 = b2 * a10[0], b2 = q7*/ \
"sdot v22.4s, v5.16b, v3.4b[0]\n" /* out20 = b2 * a10[0], b2 = q7*/ \
"sdot v25.4s, v5.16b, v3.4b[1]\n" /* out21 = b2 * a10[0], b2 = q7*/ \
"sdot v28.4s, v5.16b, v3.4b[2]\n" /* out22 = b2 * a10[0], b2 = q7*/ \
"sdot v31.4s, v5.16b, v3.4b[3]\n" /* out23 = b2 * a10[0], b2 = q7*/ \
"b 11f\n" /* tails==3 final tail*/ \
"5:\n" /* tail = 3*/ \
"ldr q4, [%[b_ptr]], #16\n" /* load b2, b0 to q4*/ \
"sdot v8.4s , v6.16b, v0.4b[0]\n" /* out0 = b0 * a10[0], b0 = q5*/ \
"sdot v11.4s , v6.16b, v0.4b[1]\n" /* out1 = b0 * a10[1], b0 = q5*/ \
"sdot v14.4s, v6.16b, v0.4b[2]\n" /* out2 = b0 * a10[2], b0 = q5*/ \
"sdot v17.4s, v6.16b, v0.4b[3]\n" /* out3 = b0 * a10[3], b0 = q5*/ \
"sdot v20.4s, v6.16b, v1.4b[0]\n" /* out4 = b0 * a11[0], b0 = q5*/ \
"sdot v23.4s, v6.16b, v1.4b[1]\n" /* out5 = b0 * a11[1], b0 = q5*/ \
"sdot v26.4s, v6.16b, v1.4b[2]\n" /* out6 = b0 * a11[2], b0 = q5*/ \
"sdot v29.4s, v6.16b, v1.4b[3]\n" /* out7 = b0 * a11[3], b0 = q5*/ \
"sdot v9.4s, v7.16b, v0.4b[0]\n" /* out8 = b0 * a10[0], b1 = q6*/ \
"sdot v12.4s, v7.16b, v0.4b[1]\n" /* out9 = b1 * a10[1], b1 = q6*/ \
"sdot v15.4s, v7.16b, v0.4b[2]\n" /* out10 = b1 * a10[2], b1 = q6*/ \
"sdot v18.4s, v7.16b, v0.4b[3]\n" /* out11 = b1 * a10[3], b1 = q6*/ \
"sdot v21.4s, v7.16b, v1.4b[0]\n" /* out12 = b1 * a10[0], b1 = q6*/ \
"sdot v24.4s, v7.16b, v1.4b[1]\n" /* out13 = b1 * a10[1], b1 = q6*/ \
"sdot v27.4s, v7.16b, v1.4b[2]\n" /* out14 = b1 * a10[2], b1 = q6*/ \
"sdot v30.4s, v7.16b, v1.4b[3]\n" /* out15 = b1 * a10[3], b1 = q6*/ \
"sdot v10.4s, v4.16b, v0.4b[0]\n" /* out16 = b2 * a10[0], b2 = q7*/ \
"sdot v13.4s, v4.16b, v0.4b[1]\n" /* out17 = b2 * a10[0], b2 = q7*/ \
"sdot v16.4s, v4.16b, v0.4b[2]\n" /* out18 = b2 * a10[0], b2 = q7*/ \
"sdot v19.4s, v4.16b, v0.4b[3]\n" /* out19 = b2 * a10[0], b2 = q7*/ \
"sdot v22.4s, v4.16b, v1.4b[0]\n" /* out20 = b2 * a10[0], b2 = q7*/ \
"sdot v25.4s, v4.16b, v1.4b[1]\n" /* out21 = b2 * a10[0], b2 = q7*/ \
"sdot v28.4s, v4.16b, v1.4b[2]\n" /* out22 = b2 * a10[0], b2 = q7*/ \
"sdot v31.4s, v4.16b, v1.4b[3]\n" /* out23 = b2 * a10[0], b2 = q7*/ \
"11: \n" /* end */
...@@ -27,14 +27,15 @@
#include "lite/backends/arm/math/box_coder.h"
#include "lite/backends/arm/math/col_im_transform.h"
#include "lite/backends/arm/math/concat.h"
#include "lite/backends/arm/math/conv_depthwise.h"
#include "lite/backends/arm/math/conv_direct.h"
#include "lite/backends/arm/math/conv_gemmlike.h"
#include "lite/backends/arm/math/conv_winograd.h"
#include "lite/backends/arm/math/conv_block_utils.h"
#include "lite/backends/arm/math/conv_impl.h"
#include "lite/backends/arm/math/decode_bboxes.h"
#include "lite/backends/arm/math/dropout.h"
#include "lite/backends/arm/math/elementwise.h"
#include "lite/backends/arm/math/fill_bias_relu.h"
#include "lite/backends/arm/math/gemm_prepacked_int8.h"
#include "lite/backends/arm/math/gemm_s8.h"
#include "lite/backends/arm/math/gemv_arm_int8.h"
#include "lite/backends/arm/math/im2sequence.h"
#include "lite/backends/arm/math/increment.h"
#include "lite/backends/arm/math/interpolate.h"
...@@ -61,6 +62,7 @@
#include "lite/backends/arm/math/stack.h"
#include "lite/backends/arm/math/topk.h"
#include "lite/backends/arm/math/yolo_box.h"

namespace paddle {
namespace lite {
namespace arm {
...@@ -261,7 +263,7 @@ inline float32x4_t exp_ps(float32x4_t x) {
// almost no extra price so both sin_ps and cos_ps make use of
// sincos_ps..
//
inline void sincos_ps(float32x4_t x, float32x4_t* ysin, float32x4_t* ycos) {
  // any x
  float32x4_t xmm1, xmm2, xmm3, y;
...@@ -350,23 +352,23 @@ inline float32x4_t pow_ps(float32x4_t a, float32x4_t b) {
}

template <typename T>
void fill_bias_fc(T* tensor, const T* bias, int num, int channel);

template <lite_api::ActivationType Act = lite_api::ActivationType::kIndentity>
inline float32x4_t vactive_f32(const float32x4_t& x) {
  return x;
}

template <>
inline float32x4_t vactive_f32<lite_api::ActivationType::kRelu>(
    const float32x4_t& x) {
  float32x4_t __zero = vdupq_n_f32(0.f);
  return vmaxq_f32(x, __zero);
}

template <>
inline float32x4_t vactive_f32<lite_api::ActivationType::kRelu6>(
    const float32x4_t& x) {
  float32x4_t __zero = vdupq_n_f32(0.f);
  float32x4_t __six = vdupq_n_f32(6.f);
  return vminq_f32(vmaxq_f32(x, __zero), __six);
...@@ -374,7 +376,7 @@ inline float32x4_t vactive_f32<lite_api::ActivationType::kRelu6>(

template <>
inline float32x4_t vactive_f32<lite_api::ActivationType::kSigmoid>(
    const float32x4_t& x) {
  float32x4_t __one = vdupq_n_f32(1.f);
  float32x4_t __x = vnegq_f32(x);
  __x = exp_ps(__x);
...@@ -385,7 +387,7 @@ inline float32x4_t vactive_f32<lite_api::ActivationType::kSigmoid>(

template <>
inline float32x4_t vactive_f32<lite_api::ActivationType::kTanh>(
    const float32x4_t& x) {
  float32x4_t __one = vdupq_n_f32(1.f);
  float32x4_t __x = vmulq_n_f32(x, -2.f);
  __x = exp_ps(__x);
...@@ -397,27 +399,27 @@ inline float32x4_t vactive_f32<lite_api::ActivationType::kTanh>(
}

template <lite_api::ActivationType Act = lite_api::ActivationType::kIndentity>
inline float active_f32(const float& x) {
  return x;
}

template <>
inline float active_f32<lite_api::ActivationType::kRelu>(const float& x) {
  return std::max(x, 0.f);
}

template <>
inline float active_f32<lite_api::ActivationType::kRelu6>(const float& x) {
  return std::min(std::max(x, 0.f), 6.f);
}

template <>
inline float active_f32<lite_api::ActivationType::kSigmoid>(const float& x) {
  return 1.f / (1.f + exp(-x));
}

template <>
inline float active_f32<lite_api::ActivationType::kTanh>(const float& x) {
  return 2.f / (1.f + exp(-2.f * x)) - 1.f;
}
......
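active_f32<kTanh> relies on the identity tanh(x) = 2 / (1 + e^(-2x)) - 1, which needs only one exponential per element. A quick numeric check of that identity (illustrative):

#include <cassert>
#include <cmath>
// Verify 2 / (1 + exp(-2x)) - 1 == tanh(x) at a few sample points.
inline void check_tanh_identity() {
  for (float x : {-2.f, -0.5f, 0.f, 1.f, 3.f}) {
    float lhs = 2.f / (1.f + std::exp(-2.f * x)) - 1.f;
    assert(std::fabs(lhs - std::tanh(x)) < 1e-6f);
  }
}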
(source diff too large to display; view the file blob instead)
...@@ -15,7 +15,6 @@
#pragma once
#include <cmath>
#include "lite/core/context.h"
#include "lite/core/device_info.h"
#include "lite/core/tensor.h"

namespace paddle {
...@@ -34,7 +33,7 @@ const int NBLOCK_INT8_OTH = 16;
const int MBLOCK_INT8_DOT = 8;
const int NBLOCK_INT8_DOT = 12;

inline int get_hblock_int8(const ARMContext* ctx) {
inline int get_hblock_int8(ARMContext* ctx) {
#ifdef WITH_ARM_DOTPROD
  if (ctx->has_dot()) {
    return MBLOCK_INT8_DOT;
...@@ -51,7 +50,7 @@ inline int get_hblock_int8(const ARMContext* ctx) {
const int MBLOCK_INT8_OTH = 4;
const int NBLOCK_INT8_OTH = 8;

inline int get_hblock_int8(const ARMContext* ctx) { return 4; }
inline int get_hblock_int8(ARMContext* ctx) { return 4; }
#endif  // __aarch64__

void prepackA_int8(void* out,
...@@ -75,7 +74,7 @@ void prepackA_int8(TensorLite* tout,

template <typename dtype>
void gemm_prepack_int8(const int8_t* A_packed,
                       const int8_t* B,
                       const int* bias,
                       const float* bias,
                       dtype* C,
                       int M,
                       int N,
...@@ -87,7 +86,6 @@ void gemm_prepack_int8(const int8_t* A_packed,
                       ARMContext* ctx);

#define ROUNDUP(a, b) ((((a) + (b)-1) / (b)) * (b))

}  // namespace math
}  // namespace arm
}  // namespace lite
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/backends/arm/math/gemm_s8.h"
namespace paddle {
namespace lite {
namespace arm {
namespace math {
template <typename Dtype>
void gemm_s8(bool is_transA,
bool is_transB,
int M,
int N,
int K,
const int8_t* A,
const int8_t* B,
Dtype* C,
const float* bias,
bool is_bias,
bool is_relu,
const float* scale,
ARMContext* ctx) {
int hblock = get_hblock_int8(ctx);
int m_roundup = hblock * ((M + hblock - 1) / hblock);
auto packed_A = static_cast<int8_t*>(
TargetMalloc(TargetType::kARM, m_roundup * K * sizeof(int8_t)));
int lda = is_transA ? M : K;
prepackA_int8(packed_A, A, lda, 0, M, 0, K, is_transA, ctx);
gemm_prepack_int8(
packed_A, B, bias, C, M, N, K, is_bias, is_relu, is_transB, scale, ctx);
TargetFree(TargetType::kARM, packed_A);
}
template void gemm_s8<float>(bool is_transA,
bool is_transB,
int M,
int N,
int K,
const int8_t* A,
const int8_t* B,
float* C,
const float* bias,
bool is_bias,
bool is_relu,
const float* scale,
ARMContext* ctx);
template void gemm_s8<int8_t>(bool is_transA,
bool is_transB,
int M,
int N,
int K,
const int8_t* A,
const int8_t* B,
int8_t* C,
const float* bias,
bool is_bias,
bool is_relu,
const float* scale,
ARMContext* ctx);
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
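gemm_s8 wraps prepackA_int8 plus gemm_prepack_int8 into a one-shot call; C may be float or int8_t, with dequantization driven by scale. A hypothetical call (shapes and values are illustrative only, and it assumes the usual lite ARMContext typedef and an initialized context):

#include <vector>
// Sketch: C(MxN) = A(MxK) * B(KxN) on int8 inputs, float output,
// per-row scale and float bias.
void run_gemm_s8_example(paddle::lite::ARMContext* ctx) {
  const int M = 4, N = 8, K = 16;
  std::vector<int8_t> A(M * K, 1), B(K * N, 1);
  std::vector<float> C(M * N, 0.f), bias(M, 0.5f), scale(M, 0.05f);
  paddle::lite::arm::math::gemm_s8<float>(
      false, false, M, N, K, A.data(), B.data(), C.data(), bias.data(),
      true /*is_bias*/, false /*is_relu*/, scale.data(), ctx);
}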
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cmath>
#include "lite/backends/arm/math/gemm_prepacked_int8.h"
#include "lite/core/context.h"
namespace paddle {
namespace lite {
namespace arm {
namespace math {
template <typename Dtype>
void gemm_s8(bool is_transA,
bool is_transB,
int M,
int N,
int K,
const int8_t* A,
const int8_t* B,
Dtype* C,
const float* bias,
bool is_bias,
bool is_relu,
const float* scale,
ARMContext* ctx);
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
...@@ -22,36 +22,90 @@ namespace arm {
namespace math {

template <typename dtype>
inline void write_gemv_out(const int* in, dtype* out, const float* scale);

template <>
inline void write_gemv_out(const int* in, int* out, const float* scale) {
  out[0] = in[0];
}

template <>
inline void write_gemv_out(const int* in, float* out, const float* scale) {
  out[0] = in[0] * scale[0];
}

template <>
inline void write_gemv_out(const int* in,
                           signed char* out,
                           const float* scale) {
  out[0] = saturate_cast<signed char>(roundf(in[0] * scale[0]));
}

template <typename dtype>
inline void write_gemv_out(const int* in,
                           dtype* out,
                           const float* scale,
                           const float* bias,
                           int size,
                           bool is_relu);

template <>
inline void write_gemv_out(const int* in,
                           float* out,
                           const float* scale,
                           const float* bias,
                           int size,
                           bool is_relu) {
  int i = 0;
  float32x4_t vzero = vdupq_n_f32(0.f);
  for (; i < size - 7; i += 8) {
    float32x4_t vout0 = bias ? vld1q_f32(bias) : vdupq_n_f32(0.f);
    float32x4_t vout1 = bias ? vld1q_f32(bias + 4) : vdupq_n_f32(0.f);
    int32x4_t vin0 = vld1q_s32(in);
    int32x4_t vin1 = vld1q_s32(in + 4);
    float32x4_t vscale0 = vld1q_f32(scale);
    float32x4_t vscale1 = vld1q_f32(scale + 4);
    float32x4_t vinf0 = vcvtq_f32_s32(vin0);
    float32x4_t vinf1 = vcvtq_f32_s32(vin1);
    vout0 = vmlaq_f32(vout0, vinf0, vscale0);
    vout1 = vmlaq_f32(vout1, vinf1, vscale1);
    if (is_relu) {
      vout0 = vmaxq_f32(vout0, vzero);
      vout1 = vmaxq_f32(vout1, vzero);
    }
    vst1q_f32(out, vout0);
    vst1q_f32(out + 4, vout1);
    bias += 8;
    in += 8;
    out += 8;
    scale += 8;
  }
  for (; i < size; ++i) {
    out[0] = *(in++) * *(scale++);
    out[0] += bias ? *(bias++) : 0.f;
    out[0] = is_relu ? (out[0] > 0.f ? out[0] : 0.f) : out[0];
    out++;
  }
}

template <>
inline void write_gemv_out(const int* in,
                           signed char* out,
                           const float* scale,
                           const float* bias,
                           int size,
                           bool flag_relu) {
  if (bias) {
    for (int i = 0; i < size; ++i) {
      out[0] =
          saturate_cast<signed char>(roundf(*(in++) * *(scale++) + *(bias++)));
      if (flag_relu) {
        out[0] = out[0] > 0 ? out[0] : 0;
      }
      out++;
    }
  } else {
    for (int i = 0; i < size; ++i) {
      out[0] = saturate_cast<signed char>(roundf(*(in++) * *(scale++)));
      if (flag_relu) {
        out[0] = out[0] > 0 ? out[0] : 0;
      }
      out++;
    }
  }
}

template <typename dtype>
bool gemv_int8(const int8_t* A,
               const int8_t* x,
               dtype* y,
               bool transA,
               int M,
               int N,
               const float* scale,
               bool is_bias,
               const int* bias,
               bool is_relu) {
template <typename dtype>
bool gemv_int8_oth(const int8_t* A,
                   const int8_t* x,
                   dtype* y,
                   bool transA,
                   int M,
                   int N,
                   const float* scale,
                   bool is_bias,
                   const float* bias,
                   bool is_relu) {
  if (transA) {
    LOG(ERROR) << "ERROR: sgemv, transA is not supported now";
    return false;
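The two write_gemv_out overloads above fold dequantization, bias and relu into the store. In scalar form the float path reduces to (illustrative model, not part of the patch):

// Scalar model of the float path: out = acc * scale + bias, then optional relu.
inline void write_out_model(const int* acc, float* out, const float* scale,
                            const float* bias, int size, bool relu) {
  for (int i = 0; i < size; ++i) {
    float v = acc[i] * scale[i] + (bias ? bias[i] : 0.f);
    out[i] = relu ? (v > 0.f ? v : 0.f) : v;
  }
}

Note the NEON version advances the bias pointer unconditionally inside the vector loop; it is only ever dereferenced when bias is non-null.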
...@@ -61,7 +115,6 @@ bool gemv_int8(const int8_t* A,
  const int8_t* weights_ptr = A;
  int cnt = N >> 4;
  int tail = N & 15;
  int flag_bias = is_bias ? 1 : 0;

#ifdef __aarch64__
  int out_cnt = M >> 3;
...@@ -80,7 +133,7 @@ bool gemv_int8(const int8_t* A,
    const int8_t* ptr_w5 = ptr_w4 + N;
    const int8_t* ptr_w6 = ptr_w5 + N;
    const int8_t* ptr_w7 = ptr_w6 + N;
    const int* bias_ptr = is_bias ? (bias + out_idx) : nullptr;
    auto bias_ptr = is_bias ? bias + out_idx : nullptr;
    int cnt_loop = cnt;
    asm volatile(
        "prfm pldl1keep, [%[in]] \n" /* preload din */
...@@ -153,13 +206,6 @@ bool gemv_int8(const int8_t* A,
        "addp v12.4s, v8.4s , v9.4s \n" /* pair add to 4 int32*/
        "addp v13.4s, v10.4s, v11.4s \n" /* pair add to 4 int32*/
        "cmp %w[bias], #1 \n" /* check whether has bias */
        "blt 0f \n" /* jump to tail */
        "ldp q8, q9, [%[bias_ptr]]\n" /* load bias to q8, q9*/
        "add v12.4s, v12.4s, v8.4s \n" /* add bias */
        "add v13.4s, v13.4s, v9.4s \n" /* add bias */
        "0: \n" /* end of add bias */
        /* write to output */
        "stp q12, q13, [%[out]] \n" /* save result */
        : [in] "+r"(ptr_in),
...@@ -172,7 +218,7 @@ bool gemv_int8(const int8_t* A,
        [w6] "+r"(ptr_w6),
        [w7] "+r"(ptr_w7),
        [cnt] "+r"(cnt_loop)
        : [out] "r"(ptr_out), [bias_ptr] "r"(bias_ptr), [bias] "r"(flag_bias)
        : [out] "r"(ptr_out)
        : "cc",
        "memory",
        "v0",
...@@ -211,25 +257,8 @@ bool gemv_int8(const int8_t* A,
      ptr_out[6] += ptr_in[i] * ptr_w6[i];
      ptr_out[7] += ptr_in[i] * ptr_w7[i];
    }
    if (is_relu) {
      ptr_out[0] = ptr_out[0] > 0 ? ptr_out[0] : 0;
      ptr_out[1] = ptr_out[1] > 0 ? ptr_out[1] : 0;
      ptr_out[2] = ptr_out[2] > 0 ? ptr_out[2] : 0;
      ptr_out[3] = ptr_out[3] > 0 ? ptr_out[3] : 0;
      ptr_out[4] = ptr_out[4] > 0 ? ptr_out[4] : 0;
      ptr_out[5] = ptr_out[5] > 0 ? ptr_out[5] : 0;
      ptr_out[6] = ptr_out[6] > 0 ? ptr_out[6] : 0;
      ptr_out[7] = ptr_out[7] > 0 ? ptr_out[7] : 0;
    }
    write_gemv_out(ptr_out, out_ptr, scale_ptr);
    write_gemv_out(ptr_out + 1, out_ptr + 1, scale_ptr + 1);
    write_gemv_out(ptr_out + 2, out_ptr + 2, scale_ptr + 2);
    write_gemv_out(ptr_out + 3, out_ptr + 3, scale_ptr + 3);
    write_gemv_out(ptr_out + 4, out_ptr + 4, scale_ptr + 4);
    write_gemv_out(ptr_out + 5, out_ptr + 5, scale_ptr + 5);
    write_gemv_out(ptr_out + 6, out_ptr + 6, scale_ptr + 6);
    write_gemv_out(ptr_out + 7, out_ptr + 7, scale_ptr + 7);
    write_gemv_out(ptr_out, out_ptr, scale_ptr, bias_ptr, 8, is_relu);
  }
  //! deal with remains
...@@ -242,12 +271,11 @@ bool gemv_int8(const int8_t* A,
    const int8_t* ptr_in = data_in;
    const int8_t* ptr_w0 = weights_ptr + (N * j);
    int cnt_loop = cnt;
    int bias0 = is_bias ? bias[j] : 0;
    auto bias_ptr = is_bias ? bias + j : nullptr;
    asm volatile(
        "prfm pldl1keep, [%[in]] \n" /* preload din */
        "prfm pldl1keep, [%[w0]] \n" /* preload w0 */
        "movi v0.4s, #0 \n" /* set out0 to 0 */
        "fmov s0, %w[bias0] \n" /* set bias */
        /* check main loop */
        "cmp %w[cnt], #1 \n" /* check whether has main loop */
        "blt 2f \n" /* jump to tail */
...@@ -269,17 +297,14 @@ bool gemv_int8(const int8_t* A,
        /* write to output */
        "str s8, [%[out]] \n" /* save result */
        : [in] "+r"(ptr_in), [w0] "+r"(ptr_w0), [cnt] "+r"(cnt_loop)
        : [out] "r"(ptr_out), [bias0] "r"(bias0)
        : [out] "r"(ptr_out)
        : "cc", "memory", "v0", "v8", "v9", "v18");
    for (int i = 0; i < tail; ++i) {
      ptr_out[0] += ptr_in[i] * ptr_w0[i];
    }
    if (is_relu) {
      ptr_out[0] = ptr_out[0] > 0 ? ptr_out[0] : 0;
    }
    write_gemv_out(ptr_out, out_ptr, scale_ptr);
    write_gemv_out(ptr_out, out_ptr, scale_ptr, bias_ptr, 1, is_relu);
  }
#else   // __aarch64__
  int out_cnt = M >> 2;
#pragma omp parallel for
  for (int j = 0; j < out_cnt; j++) {
...@@ -293,10 +318,7 @@ bool gemv_int8(const int8_t* A,
    const int8_t* ptr_w2 = ptr_w1 + N;
    const int8_t* ptr_w3 = ptr_w2 + N;
    int cnt_loop = cnt;
    int bias0 = is_bias ? bias[out_idx] : 0;
    int bias1 = is_bias ? bias[out_idx + 1] : 0;
    int bias2 = is_bias ? bias[out_idx + 2] : 0;
    int bias3 = is_bias ? bias[out_idx + 3] : 0;
    auto bias_ptr = is_bias ? bias + out_idx : nullptr;
    asm volatile(
        "pld [%[in]] @ preload cache line, input\n"
        "pld [%[w0]] @ preload cache line, weights r0\n"
...@@ -307,10 +329,6 @@ bool gemv_int8(const int8_t* A,
        "vmov.u32 q1, #0 @ set q1 to 0\n"
        "vmov.u32 q2, #0 @ set q2 to 0\n"
        "vmov.u32 q3, #0 @ set q3 to 0\n"
        "vmov s0, %[bias0] @ set q0 to bias0\n"
        "vmov s4, %[bias1] @ set q1 to bias1\n"
        "vmov s8, %[bias2] @ set q2 to bias2\n"
        "vmov s12,%[bias3] @ set q3 to bias3\n"
        // "vld1.32 {d20-d21}, %[bias] @ load bias data"
        "cmp %[cnt], #1 @ check whether has main loop\n"
        "blt 2f @ jump to pair add\n"
...@@ -355,11 +373,7 @@ bool gemv_int8(const int8_t* A,
        [w2] "+r"(ptr_w2),
        [w3] "+r"(ptr_w3),
        [cnt] "+r"(cnt_loop)
        : [bias0] "r"(bias0),
        [bias1] "r"(bias1),
        [bias2] "r"(bias2),
        [bias3] "r"(bias3),
        [out] "r"(ptr_out)
        : [out] "r"(ptr_out)
        : "cc",
        "memory",
        "q0",
...@@ -382,16 +396,7 @@ bool gemv_int8(const int8_t* A,
      ptr_out[2] += ptr_in[i] * ptr_w2[i];
      ptr_out[3] += ptr_in[i] * ptr_w3[i];
    }
    if (is_relu) {
      ptr_out[0] = ptr_out[0] > 0 ? ptr_out[0] : 0;
      ptr_out[1] = ptr_out[1] > 0 ? ptr_out[1] : 0;
      ptr_out[2] = ptr_out[2] > 0 ? ptr_out[2] : 0;
      ptr_out[3] = ptr_out[3] > 0 ? ptr_out[3] : 0;
    }
    write_gemv_out(ptr_out, out_ptr, scale_ptr);
    write_gemv_out(ptr_out + 1, out_ptr + 1, scale_ptr + 1);
    write_gemv_out(ptr_out + 2, out_ptr + 2, scale_ptr + 2);
    write_gemv_out(ptr_out + 3, out_ptr + 3, scale_ptr + 3);
    write_gemv_out(ptr_out, out_ptr, scale_ptr, bias_ptr, 4, is_relu);
  }
  //! deal with remains
#pragma omp parallel for
...@@ -402,13 +407,11 @@ bool gemv_int8(const int8_t* A,
    const int8_t* ptr_in = data_in;
    const int8_t* ptr_w0 = weights_ptr + (N * j);
    int cnt_loop = cnt;
    int bias0 = is_bias ? bias[j] : 0;
    auto bias_ptr = is_bias ? bias + j : nullptr;
    asm volatile(
        "pld [%[in]] @ preload cache line, input\n"
        "pld [%[w0]] @ preload cache line, weights r0\n"
        "vmov.u32 q0, #0 @ set q0 to 0\n"
        "vmov s0, %[bias0] @ set q0 to bias0\n"
        "cmp %[cnt], #1 @ check whether has main loop\n"
        "blt 2f @ jump to tail\n"
        /* main loop */
...@@ -429,50 +432,258 @@ bool gemv_int8(const int8_t* A,
        /* write output */
        "vst1.32 {d0[0]}, [%[out]] @ save result\n"
        : [in] "+r"(ptr_in), [w0] "+r"(ptr_w0), [cnt] "+r"(cnt_loop)
        : [bias0] "r"(bias0), [out] "r"(ptr_out)
        : [out] "r"(ptr_out)
        : "cc", "memory", "q0", "q1", "q12", "q13");
    for (int i = 0; i < tail; ++i) {
      ptr_out[0] += ptr_in[i] * ptr_w0[i];
    }
    if (is_relu) {
      ptr_out[0] = ptr_out[0] > 0 ? ptr_out[0] : 0;
    }
    write_gemv_out(ptr_out, out_ptr, scale_ptr);
    write_gemv_out(ptr_out, out_ptr, scale_ptr, bias_ptr, 1, is_relu);
  }
#endif  // __aarch64__
  return true;
}
#if defined(__aarch64__) && defined(WITH_ARM_DOTPROD)
template <typename dtype>
bool gemv_int8_sdot(const int8_t* A,
const int8_t* x,
dtype* y,
bool transA,
int M,
int N,
const float* scale,
bool is_bias,
const float* bias,
bool is_relu) {
if (transA) {
LOG(ERROR) << "ERROR: sgemv, transA is not supported now";
return false;
}
dtype* data_out = y;
const int8_t* data_in = x;
const int8_t* weights_ptr = A;
int cnt = N >> 4;
int tail = N & 15;
int size_m = (M >> 3) << 3;
#pragma omp parallel for
for (int j = 0; j < M - 7; j += 8) {
dtype* out_ptr = data_out + j;
const float* scale_ptr = scale + j;
auto bias_ptr = is_bias ? bias + j : nullptr;
int ptr_out[8] = {0, 0, 0, 0, 0, 0, 0, 0};
const int8_t* ptr_in = data_in;
const int8_t* ptr_w0 = weights_ptr + (N * j);
const int8_t* ptr_w1 = ptr_w0 + N;
const int8_t* ptr_w2 = ptr_w1 + N;
const int8_t* ptr_w3 = ptr_w2 + N;
const int8_t* ptr_w4 = ptr_w3 + N;
const int8_t* ptr_w5 = ptr_w4 + N;
const int8_t* ptr_w6 = ptr_w5 + N;
const int8_t* ptr_w7 = ptr_w6 + N;
int cnt_loop = cnt;
if (cnt > 0) {
asm volatile(
"prfm pldl1keep, [%[in]] \n" /* preload din */
"prfm pldl1keep, [%[w0]] \n" /* preload w0 */
"prfm pldl1keep, [%[w1]] \n" /* preload w1 */
"prfm pldl1keep, [%[w2]] \n" /* preload w2 */
"prfm pldl1keep, [%[w3]] \n" /* preload w3 */
"prfm pldl1keep, [%[w4]] \n" /* preload w4 */
"prfm pldl1keep, [%[w5]] \n" /* preload w5 */
"prfm pldl1keep, [%[w6]] \n" /* preload w6 */
"prfm pldl1keep, [%[w7]] \n" /* preload w7 */
"movi v0.4s, #0 \n" /* set out0 to 0 */
"movi v1.4s, #0 \n" /* set out1 to 0 */
"movi v2.4s, #0 \n" /* set out2 to 0 */
"movi v3.4s, #0 \n" /* set out3 to 0 */
"movi v4.4s, #0 \n" /* set out4 to 0 */
"movi v5.4s, #0 \n" /* set out5 to 0 */
"movi v6.4s, #0 \n" /* set out6 to 0 */
"movi v7.4s, #0 \n" /* set out7 to 0 */
/* main loop */
"1: \n" /* main loop */
"ldr q8, [%[in]], #16 \n" /* load input, 16 int8 */
"ldr q9, [%[w0]], #16 \n" /* load w0, 16 int8 */
"ldr q10, [%[w1]], #16 \n" /* load w0, 16 int8 */
"ldr q11, [%[w2]], #16 \n" /* load w0, 16 int8 */
"ldr q12, [%[w3]], #16 \n" /* load w0, 16 int8 */
"ldr q13, [%[w4]], #16 \n" /* load w0, 16 int8 */
"ldr q14, [%[w5]], #16 \n" /* load w0, 16 int8 */
"ldr q15, [%[w6]], #16 \n" /* load w0, 16 int8 */
"ldr q16, [%[w7]], #16 \n" /* load w0, 16 int8 */
".word 0x4e899500 // sdot v0.4s, v8.16b, v9.16b \n" /* out0, out1,
out2, out3
*/
".word 0x4e8a9501 // sdot v1.4s, v8.16b, v10.16b \n" /* out4, out5,
out6, out7
*/
".word 0x4e8b9502 // sdot v2.4s, v8.16b, v11.16b \n" /* out0, out1,
out2, out3
*/
".word 0x4e8c9503 // sdot v3.4s, v8.16b, v12.16b \n" /* out4, out5,
out6, out7
*/
"subs %w[cnt], %w[cnt], #1 \n"
".word 0x4e8d9504 // sdot v4.4s, v8.16b, v13.16b \n" /* out0, out1,
out2, out3
*/
".word 0x4e8e9505 // sdot v5.4s, v8.16b, v14.16b \n" /* out4, out5,
out6, out7
*/
".word 0x4e8f9506 // sdot v6.4s, v8.16b, v15.16b \n" /* out0, out1,
out2, out3
*/
".word 0x4e909507 // sdot v7.4s, v8.16b, v16.16b \n" /* out4, out5,
out6, out7
*/
"bne 1b \n" /* jump to main loop */
/* pair add to final result */
"2: \n" /* reduce to scale */
"addp v10.4s , v0.4s , v1.4s \n" /* pair add to 4 int32*/
"addp v11.4s , v2.4s , v3.4s \n" /* pair add to 4 int32*/
"addp v12.4s , v4.4s , v5.4s \n" /* pair add to 4 int32*/
"addp v13.4s , v6.4s , v7.4s \n" /* pair add to 4 int32*/
"addp v0.4s , v10.4s , v11.4s \n" /* pair add to 4 int32*/
"addp v1.4s , v12.4s , v13.4s \n" /* pair add to 4 int32*/
/* write to output */
"stp q0, q1, [%[out]] \n" /* save result */
: [in] "+r"(ptr_in),
[w0] "+r"(ptr_w0),
[w1] "+r"(ptr_w1),
[w2] "+r"(ptr_w2),
[w3] "+r"(ptr_w3),
[w4] "+r"(ptr_w4),
[w5] "+r"(ptr_w5),
[w6] "+r"(ptr_w6),
[w7] "+r"(ptr_w7),
[cnt] "+r"(cnt_loop)
: [out] "r"(ptr_out)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18");
    }
    for (int i = 0; i < tail; ++i) {
ptr_out[0] += ptr_in[i] * ptr_w0[i];
ptr_out[1] += ptr_in[i] * ptr_w1[i];
ptr_out[2] += ptr_in[i] * ptr_w2[i];
ptr_out[3] += ptr_in[i] * ptr_w3[i];
ptr_out[4] += ptr_in[i] * ptr_w4[i];
ptr_out[5] += ptr_in[i] * ptr_w5[i];
ptr_out[6] += ptr_in[i] * ptr_w6[i];
ptr_out[7] += ptr_in[i] * ptr_w7[i];
}
write_gemv_out(ptr_out, out_ptr, scale_ptr, bias_ptr, 8, is_relu);
}
//! deal with remains
#pragma omp parallel for
for (int j = size_m; j < M; j++) {
// int *ptr_out = data_out + j;
dtype* out_ptr = data_out + j;
const float* scale_ptr = scale + j;
int ptr_out[1] = {0};
const int8_t* ptr_in = data_in;
const int8_t* ptr_w0 = weights_ptr + (N * j);
int cnt_loop = cnt;
auto bias_ptr = is_bias ? bias + j : nullptr;
asm volatile(
"prfm pldl1keep, [%[in]] \n" /* preload din */
"prfm pldl1keep, [%[w0]] \n" /* preload w0 */
"cmp %w[cnt], #1 \n" /* check whether has main loop */
"movi v0.4s, #0 \n" /* set out0 to 0 */
/* check main loop */
"blt 2f \n" /* jump to tail */
/* main loop */
"1: \n" /* main loop */
"ldr q8, [%[in]], #16 \n" /* load input, 16 int8 */
"ldr q9, [%[w0]], #16 \n" /* load w0, 16 int8 */
"subs %w[cnt], %w[cnt], #1 \n" /* sub main loop count */
/* mul, lower 8 int8 * int8 = int16 */
".word 0x4e899500 // sdot v0.4s, v8.16b, v9.16b \n"
"bne 1b \n" /* jump to main loop */
/* pair add to final result */
"2: \n" /* reduce to scale */
"addp v1.4s, v0.4s, v0.4s \n" /* reduction to out0 */
"addp v2.4s, v1.4s, v1.4s \n" /* reduction to out0 */
/* write to output */
"str s2, [%[out]] \n" /* save result */
: [in] "+r"(ptr_in), [w0] "+r"(ptr_w0), [cnt] "+r"(cnt_loop)
: [out] "r"(ptr_out)
: "cc", "memory", "v0", "v8", "v9", "v18");
for (int i = 0; i < tail; ++i) {
ptr_out[0] += ptr_in[i] * ptr_w0[i];
}
write_gemv_out(ptr_out, out_ptr, scale_ptr, bias_ptr, 1, is_relu);
} }
  return true;
}
#endif // __aarch64__ && sdot
template <>
bool gemv_int8<float>(const int8_t* A,
const int8_t* x,
float* y,
bool transA,
int M,
int N,
const float* scale,
bool is_bias,
const float* bias,
bool is_relu,
const ARMContext* ctx) {
#if defined(__aarch64__) && defined(WITH_ARM_DOTPROD)
  if (ctx->has_dot()) {
    return gemv_int8_sdot<float>(
        A, x, y, transA, M, N, scale, is_bias, bias, is_relu);
  } else {
    return gemv_int8_oth<float>(
        A, x, y, transA, M, N, scale, is_bias, bias, is_relu);
  }
#else
  return gemv_int8_oth<float>(
      A, x, y, transA, M, N, scale, is_bias, bias, is_relu);
#endif
}
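The specialization above picks the sdot micro-kernel at runtime when the CPU reports dot-product support and falls back to the generic path otherwise. A hypothetical caller (shapes and values illustrative only, assuming the usual lite ARMContext typedef):

#include <vector>
// y(M) = A(MxN) * x(N) with int8 inputs, float output, per-row scales.
void run_gemv_int8_example(const paddle::lite::ARMContext* ctx) {
  const int M = 8, N = 32;
  std::vector<int8_t> A(M * N, 1), x(N, 1);
  std::vector<float> y(M, 0.f), scale(M, 0.1f);
  paddle::lite::arm::math::gemv_int8<float>(
      A.data(), x.data(), y.data(), false /*transA*/, M, N, scale.data(),
      false /*is_bias*/, nullptr, false /*is_relu*/, ctx);
}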
template bool gemv_int8<float>(const int8_t* A, const int8_t* x, float* y,
                               bool transA, int M, int N, const float* scale,
                               bool is_bias, const int* bias, bool is_relu);
template bool gemv_int8<int>(const int8_t* A, const int8_t* x, int* y,
                             bool transA, int M, int N, const float* scale,
                             bool is_bias, const int* bias, bool is_relu);
template bool gemv_int8<signed char>(const int8_t* A, const int8_t* x,
                                     signed char* y, bool transA, int M, int N,
                                     const float* scale, bool is_bias,
                                     const int* bias, bool is_relu);

template <>
bool gemv_int8<int8_t>(const int8_t* A,
                       const int8_t* x,
                       int8_t* y,
                       bool transA,
                       int M,
                       int N,
                       const float* scale,
                       bool is_bias,
                       const float* bias,
                       bool is_relu,
                       const ARMContext* ctx) {
#if defined(__aarch64__) && defined(WITH_ARM_DOTPROD)
  if (ctx->has_dot()) {
    return gemv_int8_sdot<int8_t>(
        A, x, y, transA, M, N, scale, is_bias, bias, is_relu);
  } else {
    return gemv_int8_oth<int8_t>(
        A, x, y, transA, M, N, scale, is_bias, bias, is_relu);
  }
#else
  return gemv_int8_oth<int8_t>(
      A, x, y, transA, M, N, scale, is_bias, bias, is_relu);
#endif
}
}  // namespace math
}  // namespace arm
......
...@@ -14,7 +14,7 @@
#pragma once
#include <cmath>
#include "lite/core/device_info.h"
#include "lite/core/context.h"

namespace paddle {
namespace lite {
...@@ -30,9 +30,10 @@ bool gemv_int8(const int8_t* A,
               int M,
               int N,
               const float* scale,
               bool is_bias = false,
               const int* bias = nullptr,
               bool is_relu = false);
               bool is_bias,
               const float* bias,
               bool is_relu,
               const ARMContext* ctx);

}  // namespace math
}  // namespace arm
......
@@ -169,7 +169,7 @@ void prepackA(TensorLite *tout,
              int group,
              bool is_trans,
              ARMContext *ctx) {
  int hblock = get_hblock(ctx);
  int m_roundup = hblock * ((m + hblock - 1) / hblock);
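  // e.g. hblock = 8, m = 13: m_roundup = 8 * ((13 + 7) / 8) = 16, so the
  // packed panel always holds a whole number of hblock-row blocks.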
  int group_size_round_up = ((m_roundup * k + 15) / 16) * 16;
  if (tout->numel() < group_size_round_up * group) {
@@ -1516,6 +1516,7 @@ void loadb_trans(
    }
  }
  for (; x > 7; x -= 8) {
    // clang-format off
    asm volatile(
        "ldp q0, q1, [%[inptr0]], #32\n" /* r0, a0~a7 */
        "ldp q2, q3, [%[inptr1]], #32\n" /* r1, b0~b7 */
@@ -1638,40 +1639,12 @@ void loadb_trans(
        [inptr11] "+r"(inptr11),
        [outptr] "+r"(outptr)
        :
        : "v0","v1","v2","v3","v4","v5",
          "v6","v7","v8","v9","v10","v11","v12",
          "v13","v14","v15","v16","v17","v18","v19",
          "v20","v21","v22","v23","v24","v25","v26",
          "v27","v28","v29","v30","v31","cc","memory");
    // clang-format on
  }
  for (; x > 0; x--) {
@@ -2135,7 +2108,7 @@ void sgemm_prepacked_8x12(bool is_transB,
        const float *a_ptr = a_ptr_l;
        int tail = tail_pre;
        int k = k_pre;
        // clang-format off
        asm volatile(
            "prfm pldl1keep, [%[a_ptr]]\n" /* preload a*/
            "ldp q2, q3, [%[bias_ptr]]\n"  /* load bias to q2, q3*/
@@ -2596,40 +2569,13 @@ void sgemm_prepacked_8x12(bool is_transB,
            [relu] "r"(has_relu),
            [has_beta] "r"(has_beta),
            [beta] "r"(beta)
            : "cc","memory",
              "v0","v1","v2","v3","v4","v5","v6","v7",
              "v8","v9","v10","v11","v12","v13",
              "v14","v15","v16","v17","v18","v19",
              "v20","v21","v22","v23","v24","v25",
              "v26","v27","v28","v29","v30","v31");
        // clang-format on
        if (flag_p_remain && (xb == bblocks - 1)) {
          for (int i = 0; i < remain; ++i) {
            *pout0++ = cout0[i];
@@ -2799,6 +2745,7 @@ void sgemm_prepacked_6x8(bool is_transB,
        const float* a_ptr = a_ptr_l;
        int tails = tail_pre;
        int k = k_pre;
        // clang-format off
        asm volatile(
            // sgemm 6x8
            "vld1.32 {d2-d4}, [%[bias_ptr]] @ load bias 6 elements\n"
@@ -2826,7 +2773,7 @@ void sgemm_prepacked_6x8(bool is_transB,
            "pld [%[b_ptr], #320] @ preload b\n"
            "vdup.i32 q11,d3[1] @ out31=0\n"
            "pld [%[b_ptr], #384] @ preload b\n"
            "cmp %[beta], #0\n"
            "beq 11f\n" /* check beta == 0? */
            /* process beta */
            "vdup.32 q3, %[beta]\n" /* beta to vector */
@@ -3082,26 +3029,11 @@ void sgemm_prepacked_6x8(bool is_transB,
            [tails] "+r"(tails)
            : [bias_ptr] "r"(bias_local),
              [relu] "r"(has_relu),
              [beta] "r"(beta)
            : "q0","q1","q2","q3","q4",
              "q5","q6","q7","q8","q9","q10","q11",
              "q12","q13","q14","q15","cc","memory");
        // clang-format on
        if (flag_p_remain && (xb == bblocks - 1)) {
          for (int i = 0; i < remain; ++i) {
@@ -3243,6 +3175,7 @@ void sgemm_prepacked_4x8(bool is_transB,
        const float* a_ptr = a_ptr_l;
        int tails = tail_pre;
        int k = k_pre;
        // clang-format off
        asm volatile(
            "vld1.32 {d4-d5}, [%[bias_ptr]] @ load bias\n"
            "vdup.32 q8, d4[0] @ add bias to out00\n"
@@ -3260,7 +3193,7 @@ void sgemm_prepacked_4x8(bool is_transB,
            "pld [%[b_ptr], #128] @ preload b\n"
            "vdup.32 q15, d5[1] @ add bias to out31\n"
            "pld [%[b_ptr], #192] @ preload b\n"
            "cmp %[beta], #0\n"
            "beq 11f\n" /* check beta == 0? */
            /* process beta */
            "vdup.32 q4, %[beta]\n" /* beta to vector */
@@ -3440,27 +3373,11 @@ void sgemm_prepacked_4x8(bool is_transB,
            [tails] "+r"(tails)
            : [bias_ptr] "r"(bias_local),
              [relu] "r"(has_relu),
              [beta] "r"(beta)
            : "q0","q1","q2","q3",
              "q4","q5","q6","q7","q8","q9","q10",
              "q11","q12","q13","q14","q15","cc","memory");
        // clang-format on
        if (flag_p_remain && (xb == bblocks - 1)) {
          for (int i = 0; i < remain; ++i) {
            *pout0++ = cout0[i];
...
@@ -16,7 +16,6 @@
#include <cmath>
#include "lite/core/context.h"
#include "lite/core/tensor.h"

namespace paddle {
@@ -28,14 +27,14 @@ namespace math {
constexpr int MBLOCK = 8;
constexpr int NBLOCK = 12;
constexpr int KBLOCK = 4;
inline int get_hblock(ARMContext* ctx) { return MBLOCK; }
#else
constexpr int MBLOCK_A73 = 4;
constexpr int MBLOCK_OTH = 6;
constexpr int NBLOCK = 8;
constexpr int KBLOCK = 4;
inline int get_hblock(ARMContext* ctx) {
  if (ctx->arch() == kA73) {
    return MBLOCK_A73;
  } else {
    return MBLOCK_OTH;
...
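For orientation, a hedged mapping from these block sizes to the kernels shown earlier in this diff:

// armv8:              hblock = MBLOCK     = 8 -> sgemm_prepacked_8x12
// armv7, Cortex-A73:  hblock = MBLOCK_A73 = 4 -> sgemm_prepacked_4x8
// armv7, other cores: hblock = MBLOCK_OTH = 6 -> sgemm_prepacked_6x8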
@@ -36,8 +36,7 @@ void sgemm(bool is_transA,
           bool is_bias,
           bool is_relu,
           ARMContext* ctx) {
  int hblock = get_hblock(ctx);
  int m_roundup = hblock * ((M + hblock - 1) / hblock);
  auto packed_A = static_cast<float*>(
...
@@ -43,20 +43,25 @@ class KernelBase {
      const std::map<std::string, const Type*>& input_types,
      const std::string& out_arg)>;

 protected:
  /// Run some initialization before `Run`; it is invoked after `SetParam` and
  /// `SetContext`, so both param_ and context_ are valid.
  virtual void PrepareForRun() {}

  /// Re-run kernel initialization when needed at every run (e.g. when the
  /// input shape has changed).
  virtual void ReInitWhenNeeded() {}

  /// Run the kernel. Before Run, both the param_ and context_ should be valid.
  virtual void Run() = 0;

 public:
  void Launch() {
    /// On the first run, init the kernel and do weight transforms once.
    if (is_first_epoch_) {
      PrepareForRun();
      is_first_epoch_ = false;
    }
    /// Re-init the kernel if needed (e.g. the conv kernels check whether the
    /// input shape changed).
    ReInitWhenNeeded();
    // Reset the workspace to make every kernel in the same thread share the
    // temporary memory.
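
To make the new hook concrete, this is the call order Launch() now guarantees (a sketch of the flow above, not new API; k stands for any concrete kernel):

// k.SetParam(param);             // param_ becomes valid
// k.SetContext(std::move(ctx));  // context_ becomes valid
// k.Launch();  // first call:  PrepareForRun(); ReInitWhenNeeded(); Run();
// k.Launch();  // later calls: ReInitWhenNeeded(); Run();
//
// PrepareForRun() carries the one-time work (e.g. weight transforms), while
// ReInitWhenNeeded() re-checks cheap per-run state such as the input shape.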
...
# for conv op
add_kernel(conv_depthwise ARM basic SRCS conv_depthwise.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(conv_direct ARM basic SRCS conv_direct.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(conv_gemmlike ARM basic SRCS conv_gemmlike.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(conv_winograd ARM basic SRCS conv_winograd.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(conv_compute_arm ARM basic SRCS conv_compute.cc DEPS ${lite_kernel_deps}
conv_depthwise conv_direct conv_gemmlike conv_winograd)
add_kernel(fc_compute_arm ARM basic SRCS fc_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(activation_compute_arm ARM basic SRCS activation_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(mul_compute_arm ARM basic SRCS mul_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(matmul_compute_arm ARM basic SRCS matmul_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(scale_compute_arm ARM basic SRCS scale_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(softmax_compute_arm ARM basic SRCS softmax_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(batch_norm_compute_arm ARM basic SRCS batch_norm_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(elementwise_compute_arm ARM basic SRCS elementwise_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(lrn_compute_arm ARM basic SRCS lrn_compute.cc DEPS ${lite_kernel_deps} math_arm)
@@ -77,11 +84,8 @@ endif()
message(STATUS "compile with lite ARM kernels")

lite_cc_test(test_scale_compute_arm SRCS scale_compute_test.cc DEPS scale_compute_arm)
lite_cc_test(test_softmax_compute_arm SRCS softmax_compute_test.cc DEPS softmax_compute_arm)
lite_cc_test(test_batch_norm_compute_arm SRCS batch_norm_compute_test.cc DEPS batch_norm_compute_arm)
lite_cc_test(test_elementwise_compute_arm SRCS elementwise_compute_test.cc DEPS elementwise_compute_arm)
lite_cc_test(test_lrn_compute_arm SRCS lrn_compute_test.cc DEPS lrn_compute_arm)
...
@@ -13,101 +13,80 @@
// limitations under the License.

#include "lite/kernels/arm/conv_compute.h"
#include <utility>
#include "lite/core/op_registry.h"
#include "lite/core/type_system.h"
#include "lite/kernels/arm/conv_depthwise.h"
#include "lite/kernels/arm/conv_direct.h"
#include "lite/kernels/arm/conv_gemmlike.h"
#include "lite/kernels/arm/conv_winograd.h"

namespace paddle {
namespace lite {
namespace kernels {
namespace arm {

template <>
void ConvCompute<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
  auto& param = this->Param<param_t>();
  auto w_dims = param.filter->dims();

  auto& ctx = this->ctx_->template As<ARMContext>();

  int ic = w_dims[1] * param.groups;
  int oc = w_dims[0];
  int kh = w_dims[2];  // oihw
  int kw = w_dims[3];
  int pad = param.paddings[0];
  int stride = param.strides[0];

  bool kps_equal = (param.paddings[0] == param.paddings[1]) &&
                   (param.strides[0] == param.strides[1]) && (kw == kh);
  bool no_dilation = (param.dilations[0] == 1) && (param.dilations[1] == 1);
  bool flag_dw_3x3 = (kw == 3 && kh == 3 && (stride == 1 || stride == 2));
  bool flag_dw_5x5 =
      (kw == 5 && stride == 1) || (kw == 5 && stride == 2 && pad == 2);
  bool flag_dw = flag_dw_3x3 || flag_dw_5x5;

  /// select conv impl
  if (param.groups == ic && ic == oc && kps_equal && no_dilation && flag_dw) {
    /// dw conv impl
    impl_ = new DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>;
    VLOG(3) << "invoking dw conv";
  } else if (param.groups == 1 && kw == 3 && stride == 1 && kps_equal &&
             no_dilation) {
    if (ic >= 32 && oc >= 32) {
      /// winograd conv impl
      impl_ = new WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>;
      VLOG(3) << "invoking winograd conv";
    } else {
      /// direct conv impl
      impl_ = new DirectConv<PRECISION(kFloat), PRECISION(kFloat)>;
      VLOG(3) << "invoking direct conv";
    }
  } else if (param.groups == 1 && kw == 3 && stride == 2 && kps_equal &&
             no_dilation) {
    /// direct conv impl
    impl_ = new DirectConv<PRECISION(kFloat), PRECISION(kFloat)>;
    VLOG(3) << "invoking direct conv";
  } else {
    impl_ = new GemmLikeConv<PRECISION(kFloat), PRECISION(kFloat)>;
    VLOG(3) << "invoking gemm like conv";
  }
  impl_->SetContext(std::move(this->ctx_));
  impl_->SetParam(param);
  impl_->PrepareForRun();
  is_first_epoch_ = false;
}
template <>
void ConvCompute<PRECISION(kInt8), PRECISION(kFloat)>::PrepareForRun() {
  auto& param = this->Param<param_t>();
  auto w_dims = param.filter->dims();

  auto& ctx = this->ctx_->template As<ARMContext>();

  int ic = param.groups * w_dims[1];
  int oc = w_dims[0];
  int kh = w_dims[2];  // oihw
  int kw = w_dims[3];
  int ph = param.paddings[1];
@@ -115,78 +94,98 @@ void ConvComputeInt8<Ptype_out>::PrepareForRun() {
  int sh = param.strides[1];
  int sw = param.strides[0];

  bool kps_equal = (pw == ph) && (sh == sw) && (kw == kh);
  bool no_dilation = (param.dilations[0] == 1) && (param.dilations[1] == 1);
  bool flag_dw_3x3 = (kw == 3 && kh == 3) && (sw == 1 || sw == 2);
  bool flag_dw_5x5 = (kw == 5 && sw == 1 && ph == 2);
  bool flag_dw = flag_dw_3x3 || flag_dw_5x5;

  if (param.groups == ic && ic == oc && kps_equal && no_dilation && flag_dw) {
    impl_ = new DepthwiseConv<PRECISION(kInt8), PRECISION(kFloat)>;
    VLOG(3) << "Run DepthwiseConv Int8";
  } else if (param.groups == 1 && kw == 3 && (sw == 1 || sw == 2) &&
             kps_equal && no_dilation) {
    impl_ = new DirectConv<PRECISION(kInt8), PRECISION(kFloat)>;
    VLOG(3) << "Run DirectConv Int8";
  } else {
    impl_ = new GemmLikeConv<PRECISION(kInt8), PRECISION(kFloat)>;
    VLOG(3) << "Run GemmLikeConvInt8";
  }
  impl_->SetContext(std::move(this->ctx_));
  impl_->SetParam(param);
  impl_->PrepareForRun();
  is_first_epoch_ = false;
}
template <>
void ConvCompute<PRECISION(kInt8), PRECISION(kInt8)>::PrepareForRun() {
  auto& param = this->Param<param_t>();
  auto w_dims = param.filter->dims();

  auto& ctx = this->ctx_->template As<ARMContext>();

  int ic = w_dims[1] * param.groups;
  int oc = w_dims[0];
  int kh = w_dims[2];  // oihw
  int kw = w_dims[3];
  int ph = param.paddings[1];
  int pw = param.paddings[0];
  int sh = param.strides[1];
  int sw = param.strides[0];

  bool kps_equal = (pw == ph) && (sh == sw) && (kw == kh);
  bool no_dilation = (param.dilations[0] == 1) && (param.dilations[1] == 1);
  bool flag_dw_3x3 = (kw == 3 && kh == 3) && (sw == 1 || sw == 2);
bool flag_dw_5x5 = (kw == 5 && sw == 1 && ph == 2);
bool flag_dw = flag_dw_3x3 || flag_dw_5x5;
if (param.groups == ic && ic == oc && kps_equal && no_dilation && flag_dw) {
impl_ = new DepthwiseConv<PRECISION(kInt8), PRECISION(kInt8)>;
VLOG(3) << "Run DepthwiseConv Int8";
} else if (param.groups == 1 && kw == 3 && (sw == 1 || sw == 2) &&
kps_equal && no_dilation) {
impl_ = new DirectConv<PRECISION(kInt8), PRECISION(kInt8)>;
VLOG(3) << "Run DirectConv Int8";
} else {
impl_ = new GemmLikeConv<PRECISION(kInt8), PRECISION(kInt8)>;
VLOG(3) << "Run GemmLikeConvInt8";
}
impl_->SetContext(std::move(this->ctx_));
impl_->SetParam(param);
impl_->PrepareForRun();
is_first_epoch_ = false;
}
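
A hedged summary of the arithmetic behind the two int8 output modes selected above (the kernels themselves live under lite/backends/arm/math):

// per output channel c, with s_in = input scale and s_w[c] = weight scale:
//   fp32 out: y[c] = acc_int32[c] * (s_in * s_w[c]) + bias_fp32[c]
//   int8 out: q[c] = saturate(round(y[c] / s_out))
// the fp32 bias is added after dequantization, so it never needs an int32
// representation.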
}  // namespace arm
}  // namespace kernels
}  // namespace lite
}  // namespace paddle
typedef paddle::lite::kernels::arm::ConvCompute<PRECISION(kFloat),
                                                PRECISION(kFloat)>
    ConvFp32;
typedef paddle::lite::kernels::arm::ConvCompute<PRECISION(kInt8),
                                                PRECISION(kFloat)>
    ConvInt8_Fp32;
typedef paddle::lite::kernels::arm::ConvCompute<PRECISION(kInt8),
                                                PRECISION(kInt8)>
    ConvInt8_Int8;

REGISTER_LITE_KERNEL(conv2d, kARM, kFloat, kNCHW, ConvFp32, def)
    .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))})
    .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))})
    .BindInput("Filter", {LiteType::GetTensorTy(TARGET(kARM))})
    .BindOutput("Output", {LiteType::GetTensorTy(TARGET(kARM))})
    .Finalize();

REGISTER_LITE_KERNEL(depthwise_conv2d, kARM, kFloat, kNCHW, ConvFp32, def)
    .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))})
    .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))})
    .BindInput("Filter", {LiteType::GetTensorTy(TARGET(kARM))})
    .BindOutput("Output", {LiteType::GetTensorTy(TARGET(kARM))})
    .Finalize();

REGISTER_LITE_KERNEL(conv2d, kARM, kInt8, kNCHW, ConvInt8_Int8, int8_out)
    .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
    .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
    .BindInput("Filter",
@@ -195,13 +194,7 @@ REGISTER_LITE_KERNEL(
               {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
    .Finalize();

REGISTER_LITE_KERNEL(conv2d, kARM, kInt8, kNCHW, ConvInt8_Fp32, fp32_out)
    .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
    .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
    .BindInput("Filter",
@@ -211,12 +204,7 @@
    .Finalize();

REGISTER_LITE_KERNEL(
    depthwise_conv2d, kARM, kInt8, kNCHW, ConvInt8_Int8, int8_out)
    .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
    .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
    .BindInput("Filter",
@@ -226,12 +214,7 @@
    .Finalize();

REGISTER_LITE_KERNEL(
    depthwise_conv2d, kARM, kInt8, kNCHW, ConvInt8_Fp32, fp32_out)
    .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
    .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
    .BindInput("Filter",
...
@@ -15,20 +15,26 @@
#pragma once
#include "lite/backends/arm/math/funcs.h"
#include "lite/core/kernel.h"
#include "lite/operators/conv_op.h"

namespace paddle {
namespace lite {
namespace kernels {
namespace arm {

template <PrecisionType Ptype, PrecisionType OutType>
class ConvCompute : public KernelLite<TARGET(kARM), Ptype> {
 public:
  virtual void PrepareForRun();

  virtual void ReInitWhenNeeded() {
    CHECK(impl_);
    impl_->ReInitWhenNeeded();
  }

  virtual void Run() {
    CHECK(impl_);
    impl_->Run();
  }

  ~ConvCompute() {
    if (impl_ != nullptr) {
@@ -37,28 +43,8 @@ class ConvCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
    }
  }

 private:
  using param_t = operators::ConvParam;
  KernelLite<TARGET(kARM), Ptype>* impl_{nullptr};
};

}  // namespace arm
...
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/conv_compute.h"
#include <gtest/gtest.h>
#include <limits>
#include <memory>
#include <utility>
#include <vector>
#include "lite/backends/arm/math/type_trans.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
static int get_rand(int start, int end) {
int i = rand(); // NOLINT
i = (i % (end - start)) + start;
return i;
}
template <typename Dtype1, typename Dtype2>
static void conv_basic(const Dtype1* din,
Dtype2* dout,
int num,
int chout,
int hout,
int wout,
int chin,
int hin,
int win,
const Dtype1* weights,
const Dtype2* bias,
int group,
int kernel_w,
int kernel_h,
int stride_w,
int stride_h,
int dila_w,
int dila_h,
int pad_w,
int pad_h,
bool flag_bias,
bool flag_relu) {
Dtype2 beta = 0;
auto src_data = din;
auto dst_data_ref = dout;
auto weights_data = weights;
auto with_bias = flag_bias;
auto bias_data = bias;
int in_num = num;
int out_channels = chout;
int out_h = hout;
int out_w = wout;
int in_channel = chin;
int in_h = hin;
int in_w = win;
int out_c_group = out_channels / group;
int in_c_group = in_channel / group;
for (int n = 0; n < in_num; ++n) {
for (int g = 0; g < group; ++g) {
for (int oc = 0; oc < out_c_group; ++oc) {
for (int oh = 0; oh < out_h; ++oh) {
for (int ow = 0; ow < out_w; ++ow) {
int out_idx = n * group * out_c_group * out_h * out_w +
g * out_c_group * out_h * out_w + oc * out_h * out_w +
oh * out_w + ow;
Dtype2 bias_d =
with_bias ? (bias_data[g * out_c_group + oc]) : (Dtype2)0;
dst_data_ref[out_idx] = bias_d; // + dst_data_ref[out_idx] * beta;
for (int ic = 0; ic < in_c_group; ++ic) {
for (int kh = 0; kh < kernel_h; ++kh) {
for (int kw = 0; kw < kernel_w; ++kw) {
int iw = ow * stride_w - pad_w + kw * (dila_w);
int ih = oh * stride_h - pad_h + kh * (dila_h);
if (iw < 0 || iw >= in_w) continue;
if (ih < 0 || ih >= in_h) continue;
int iidx = n * in_channel * in_h * in_w +
g * in_c_group * in_h * in_w + ic * in_h * in_w +
ih * in_w + iw;
int widx =
g * out_c_group * in_c_group * kernel_h * kernel_w +
oc * in_c_group * kernel_h * kernel_w +
ic * kernel_h * kernel_w + kh * kernel_w + kw;
dst_data_ref[out_idx] += src_data[iidx] * weights_data[widx];
}
}
}
if (flag_relu) {
dst_data_ref[out_idx] = dst_data_ref[out_idx] > (Dtype2)0
? dst_data_ref[out_idx]
: (Dtype2)0;
}
}
}
}
}
}
}
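
// conv_basic above is the deliberately unoptimized reference: it walks the
// NCHW output, accumulates input * weight over each receptive field while
// skipping padded positions, then applies the bias and optional ReLU.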
template <typename Dtype1, typename Dtype2>
void conv_compute_ref(const operators::ConvParam& param) {
const Dtype1* din = param.x->data<Dtype1>();
Dtype2* dout = param.output->mutable_data<Dtype2>();
int num = param.x->dims()[0];
int chout = param.output->dims()[1];
int hout = param.output->dims()[2];
int wout = param.output->dims()[3];
int chin = param.x->dims()[1];
int hin = param.x->dims()[2];
int win = param.x->dims()[3];
const Dtype1* weights = param.filter->mutable_data<Dtype1>();
Dtype2* bias = nullptr;
if (param.bias != nullptr) {
bias = param.bias->mutable_data<Dtype2>();
}
int group = param.groups;
  int kernel_w = param.filter->dims()[3];
  int kernel_h = param.filter->dims()[2];
int stride_w = param.strides[0];
int stride_h = param.strides[1];
int dila_w = param.dilations[0];
int dila_h = param.dilations[1];
int pad_w = param.paddings[0];
int pad_h = param.paddings[1];
bool flag_bias = (param.bias != nullptr);
bool flag_relu = param.fuse_relu;
conv_basic(din,
dout,
num,
chout,
hout,
wout,
chin,
hin,
win,
weights,
bias,
group,
kernel_w,
kernel_h,
stride_w,
stride_h,
dila_w,
dila_h,
pad_w,
pad_h,
flag_bias,
flag_relu);
}
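
// The int8 tests below quantize symmetrically through the helpers they call
// (a hedged summary): scale = max|x| / 127 and q = round(x / scale), so an
// int32 accumulator dequantizes with in_scale * w_scale per output channel.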
TEST(conv_arm, retrive_op) {
auto conv = KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>(
"conv2d");
ASSERT_FALSE(conv.empty());
ASSERT_TRUE(conv.front());
}
TEST(conv_arm_int8, retrive_op) {
auto conv =
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kInt8)>("conv2d");
ASSERT_FALSE(conv.empty());
ASSERT_TRUE(conv.front());
}
TEST(conv_arm, init) {
ConvCompute conv;
ASSERT_EQ(conv.precision(), PRECISION(kFloat));
ASSERT_EQ(conv.target(), TARGET(kARM));
}
TEST(conv_arm_int8, init) {
ConvComputeInt8<PRECISION(kFloat)> float_out;
ASSERT_EQ(float_out.precision(), PRECISION(kInt8));
ASSERT_EQ(float_out.target(), TARGET(kARM));
ConvComputeInt8<PRECISION(kInt8)> int8_out;
  ASSERT_EQ(int8_out.precision(), PRECISION(kInt8));
  ASSERT_EQ(int8_out.target(), TARGET(kARM));
}
TEST(conv_arm_int8, int8_int32) {
DeviceInfo::Init();
for (auto n : {2}) {
for (auto ic : {6}) {
for (auto oc : {6}) {
for (auto ih : {9}) {
for (auto iw : {9}) {
for (auto flag_bias : {false, true}) {
for (auto flag_relu : {false, true}) {
for (auto depthwise : {false, /*true*/}) {
for (auto dilation : {1}) {
for (auto stride : {1}) {
for (auto padding : {0}) {
for (auto ks : {1}) {
int group = 1;
if (depthwise) { // depthwise convolution ?
group = oc = ic;
}
const int dks = dilation * (ks - 1) + 1;
int oh = (ih + 2 * padding - dks) / stride + 1;
int ow = (iw + 2 * padding - dks) / stride + 1;
std::vector<int64_t> input_shape = {n, ic, ih, iw};
std::vector<int64_t> filter_shape = {
oc, ic / group, ks, ks};
std::vector<int64_t> output_shape({n, oc, oh, ow});
Tensor input_int8;
Tensor filter_int8;
Tensor output_int32, output_int32_ref;
input_int8.Resize(input_shape);
filter_int8.Resize(filter_shape);
output_int32.Resize(output_shape);
output_int32_ref.Resize(output_shape);
int8_t* input_int8_data =
input_int8.mutable_data<int8_t>();
int8_t* filter_int8_data =
filter_int8.mutable_data<int8_t>();
for (int i = 0; i < input_int8.dims().production();
i++) {
input_int8_data[i] = i % 10 * (i % 3 - 1);
}
for (int i = 0; i < filter_int8.dims().production();
i++) {
filter_int8_data[i] = i % 10 * (i % 3 - 1);
}
operators::ConvParam param;
param.x = &input_int8;
param.filter = &filter_int8;
param.bias = nullptr;
param.fuse_relu = false;
param.paddings = std::vector<int>({padding, padding});
param.strides = std::vector<int>({stride, stride});
param.dilations =
std::vector<int>({dilation, dilation});
param.groups = group;
param.output = &output_int32_ref;
conv_compute_ref<int8_t, int>(param);
param.output = &output_int32;
std::unique_ptr<KernelContext> ctx(new KernelContext);
lite::arm::math::GemmLikeConvInt8<PRECISION(kInt32)>
int8gemm_int32;
int8gemm_int32.init(param, &ctx->As<ARMContext>());
int8gemm_int32.create(param, &ctx->As<ARMContext>());
int8gemm_int32.run(param);
int* output_int32_data =
output_int32.mutable_data<int>();
int* output_int32_ref_data =
output_int32_ref.mutable_data<int>();
for (int i = 0; i < output_int32.dims().production();
i++) {
EXPECT_NEAR(output_int32_data[i],
output_int32_ref_data[i],
1e-3);
}
}
}
}
}
}
}
}
}
}
}
}
}
}
TEST(conv_arm_int8, int8_fp32) {
DeviceInfo::Init();
for (auto n : {2}) {
for (auto ic : {6}) {
for (auto oc : {6}) {
for (auto ih : {9}) {
for (auto iw : {9}) {
for (auto flag_bias : {false, true}) {
for (auto flag_relu : {false, true}) {
for (auto depthwise : {false, /*true*/}) {
for (auto dilation : {1}) {
for (auto stride : {1}) {
for (auto padding : {0}) {
for (auto ks : {1}) {
int group = 1;
if (depthwise) { // depthwise convolution ?
group = oc = ic;
}
LOG(INFO) << "flag_bias: " << flag_bias;
const int dks = dilation * (ks - 1) + 1;
int oh = (ih + 2 * padding - dks) / stride + 1;
int ow = (iw + 2 * padding - dks) / stride + 1;
std::vector<int64_t> input_shape = {n, ic, ih, iw};
std::vector<int64_t> filter_shape = {
oc, ic / group, ks, ks};
std::vector<int64_t> bias_shape({1, oc, 1, 1});
std::vector<int64_t> output_shape({n, oc, oh, ow});
Tensor input_fp32, input_int8;
Tensor filter_fp32, filter_int8;
Tensor bias_fp32, bias_int32;
Tensor output_int32_ref, output_int32;
Tensor output_fp32_ref, output_fp32;
Tensor output_int8_ref, output_int8;
input_fp32.Resize(input_shape);
input_int8.Resize(input_shape);
filter_fp32.Resize(filter_shape);
filter_int8.Resize(filter_shape);
bias_fp32.Resize(bias_shape);
bias_int32.Resize(bias_shape);
output_int32.Resize(output_shape);
output_int32_ref.Resize(output_shape);
output_fp32_ref.Resize(output_shape);
output_fp32.Resize(output_shape);
output_int8_ref.Resize(output_shape);
output_int8.Resize(output_shape);
float* input_fp32_data =
input_fp32.mutable_data<float>();
int8_t* input_int8_data =
input_int8.mutable_data<int8_t>();
float* filter_fp32_data =
filter_fp32.mutable_data<float>();
int8_t* filter_int8_data =
filter_int8.mutable_data<int8_t>();
float* bias_fp32_data =
bias_fp32.mutable_data<float>();
int* bias_int32_data = bias_int32.mutable_data<int>();
for (int i = 0; i < input_fp32.dims().production();
i++) {
input_fp32_data[i] = i % 10 * (i % 3 - 1);
}
for (int i = 0; i < filter_fp32.dims().production();
i++) {
filter_fp32_data[i] = i % 10 * (i % 3 - 1);
}
for (int i = 0; i < bias_fp32.dims().production();
i++) {
bias_fp32_data[i] = i % 10 * (i % 3 - 1);
}
std::vector<float> in_scale;
lite::arm::math::get_tensor_scale<PRECISION(kFloat)>(
input_fp32, &in_scale, -1, 127.f);
lite::arm::math::trans_tensor_fp32_to_int8(
&input_fp32, &input_int8, in_scale[0]);
std::vector<float> w_scale;
lite::arm::math::get_tensor_scale<PRECISION(kFloat)>(
filter_fp32, &w_scale, -1, 127.f);
int axis_size = oc;
int inner_size = ic / group * ks * ks;
w_scale = lite::arm::math::get_tensor_scale_n(
filter_fp32_data, axis_size, inner_size, 127.f);
lite::arm::math::fp32_to_int8(filter_fp32_data,
filter_int8_data,
w_scale.data(),
axis_size,
1,
inner_size);
// lite::arm::math::trans_fp32_bias_to_int32_basic(&bias_fp32,
// &bias_int32, in_scale[0], w_scale);
for (int i = 0; i < bias_int32.dims().production();
i++) {
bias_int32_data[i] = 1;
}
operators::ConvParam param;
param.x = &input_int8;
param.filter = &filter_int8;
if (flag_bias) {
param.bias = &bias_int32;
} else {
param.bias = nullptr;
}
param.fuse_relu = false;
param.paddings = std::vector<int>({padding, padding});
param.strides = std::vector<int>({stride, stride});
param.dilations =
std::vector<int>({dilation, dilation});
param.groups = group;
param.output = &output_int32_ref;
conv_compute_ref<int8_t, int>(param);
int* output_int32_ref_data =
output_int32_ref.mutable_data<int>();
// ============ int8gemm_int32 ============
/*
param.output = &output_int32;
std::unique_ptr<KernelContext> ctx_int32(
new KernelContext);
lite::arm::math::GemmLikeConvInt8<PRECISION(kInt32)>
int8gemm_int32;
int8gemm_int32.init(param,
&ctx_int32->As<ARMContext>());
int8gemm_int32.create(param,
&ctx_int32->As<ARMContext>());
int8gemm_int32.run(param);
int* output_int32_data =
output_int32.mutable_data<int>();
for (int i = 0; i < output_int32.dims().production();
i++) {
EXPECT_NEAR(output_int32_data[i],
output_int32_ref_data[i], 1e-3);
}
*/
// ============ int8gemm_int8 ============
int8_t* output_int8_ref_data =
output_int8_ref.mutable_data<int8_t>();
lite::arm::math::trans_tensor_int32_to_int8(
&output_int32_ref,
&output_int8_ref,
in_scale[0],
1,
w_scale);
param.output = &output_int8;
param.input_scale = in_scale[0];
param.output_scale = 1;
param.weight_scale = w_scale;
std::unique_ptr<KernelContext> ctx_int8(
new KernelContext);
lite::arm::math::GemmLikeConvInt8<PRECISION(kInt8)>
int8gemm_int8;
int8gemm_int8.init(param,
&ctx_int8->As<ARMContext>());
int8gemm_int8.create(param,
&ctx_int8->As<ARMContext>());
int8gemm_int8.run(param);
int8_t* output_int8_data =
output_int8.mutable_data<int8_t>();
for (int i = 0; i < output_int8.dims().production();
i++) {
EXPECT_NEAR(output_int8_data[i],
output_int8_ref_data[i],
1e-3);
}
// ============ int8gemm_float32 ============
float* output_fp32_ref_data =
output_fp32_ref.mutable_data<float>();
lite::arm::math::trans_tensor_int32_to_fp32(
&output_int32_ref,
&output_fp32_ref,
in_scale[0],
w_scale);
param.output = &output_fp32;
param.input_scale = in_scale[0];
param.output_scale = 1;
param.weight_scale = w_scale;
std::unique_ptr<KernelContext> ctx_fp32(
new KernelContext);
lite::arm::math::GemmLikeConvInt8<PRECISION(kFloat)>
int8gemm_fp32;
int8gemm_fp32.init(param,
&ctx_fp32->As<ARMContext>());
int8gemm_fp32.create(param,
&ctx_fp32->As<ARMContext>());
int8gemm_fp32.run(param);
float* output_fp32_data =
output_fp32.mutable_data<float>();
for (int i = 0; i < output_fp32.dims().production();
i++) {
EXPECT_NEAR(output_fp32_data[i],
output_fp32_ref_data[i],
1e-3);
}
}
}
}
}
}
}
}
}
}
}
}
}
}
TEST(conv_direct_int8, compute) {
DeviceInfo::Init();
for (auto n : {1, 2}) {
for (auto ic : {1, 3, 8}) {
for (auto oc : {1, 3, 8}) {
for (auto ih : {5, 15, 28}) {
for (auto iw : {5, 15, 28}) {
for (auto flag_bias : {false, true}) {
for (auto flag_relu : {false, true}) {
for (auto depthwise : {false, /*true*/}) {
for (auto dilation : {1}) {
for (auto stride : {1, 2}) {
for (auto padding : {1}) {
for (auto ks : {3}) {
int group = 1;
if (depthwise) { // depthwise convolution ?
group = oc = ic;
}
const int dks = dilation * (ks - 1) + 1;
int oh = (ih + 2 * padding - dks) / stride + 1;
int ow = (iw + 2 * padding - dks) / stride + 1;
std::vector<int64_t> input_shape = {n, ic, ih, iw};
std::vector<int64_t> filter_shape = {
oc, ic / group, ks, ks};
std::vector<int64_t> bias_shape({1, oc, 1, 1});
std::vector<int64_t> output_shape({n, oc, oh, ow});
Tensor input_fp32, input_int8;
Tensor filter_fp32, filter_int8;
Tensor bias_int32;
Tensor output_int32_ref, output_int32;
Tensor output_fp32_ref, output_fp32;
Tensor output_int8_ref, output_int8;
input_fp32.Resize(input_shape);
input_int8.Resize(input_shape);
filter_fp32.Resize(filter_shape);
filter_int8.Resize(filter_shape);
bias_int32.Resize(bias_shape);
output_int32.Resize(output_shape);
output_int32_ref.Resize(output_shape);
output_fp32_ref.Resize(output_shape);
output_fp32.Resize(output_shape);
output_int8_ref.Resize(output_shape);
output_int8.Resize(output_shape);
float* input_fp32_data =
input_fp32.mutable_data<float>();
int8_t* input_int8_data =
input_int8.mutable_data<int8_t>();
float* filter_fp32_data =
filter_fp32.mutable_data<float>();
int8_t* filter_int8_data =
filter_int8.mutable_data<int8_t>();
int* bias_int32_data =
bias_int32.mutable_data<int32_t>();
for (int i = 0; i < input_fp32.dims().production();
i++) {
input_fp32_data[i] = i % 10 * (i % 3 - 1);
}
for (int i = 0; i < filter_fp32.dims().production();
i++) {
filter_fp32_data[i] = i % 10 * (i % 3 - 1);
}
for (int i = 0; i < bias_int32.dims().production();
i++) {
bias_int32_data[i] = i % 10 * (i % 3 - 1);
}
std::vector<float> in_scale;
lite::arm::math::get_tensor_scale<PRECISION(kFloat)>(
input_fp32, &in_scale, -1, 127.f);
lite::arm::math::trans_tensor_fp32_to_int8(
&input_fp32, &input_int8, in_scale[0]);
std::vector<float> w_scale;
lite::arm::math::get_tensor_scale<PRECISION(kFloat)>(
filter_fp32, &w_scale, -1, 127.f);
int axis_size = oc;
int inner_size = ic / group * ks * ks;
w_scale = lite::arm::math::get_tensor_scale_n(
filter_fp32_data, axis_size, inner_size, 127.f);
lite::arm::math::fp32_to_int8(filter_fp32_data,
filter_int8_data,
w_scale.data(),
axis_size,
1,
inner_size);
operators::ConvParam param;
param.x = &input_int8;
param.filter = &filter_int8;
if (flag_bias) {
param.bias = &bias_int32;
}
param.fuse_relu = false;
param.paddings = std::vector<int>({padding, padding});
param.strides = std::vector<int>({stride, stride});
param.dilations =
std::vector<int>({dilation, dilation});
param.groups = group;
param.output = &output_int32_ref;
conv_compute_ref<int8_t, int>(param);
int* output_int32_ref_data =
output_int32_ref.mutable_data<int>();
// ============ int8direct_int32 ============
param.output = &output_int32;
std::unique_ptr<KernelContext> ctx_int32(
new KernelContext);
lite::arm::math::DirectConvInt8<PRECISION(kInt32)>
int8direct_int32;
int8direct_int32.init(param,
&ctx_int32->As<ARMContext>());
int8direct_int32.create(param,
&ctx_int32->As<ARMContext>());
int8direct_int32.run(param);
int* output_int32_data =
output_int32.mutable_data<int>();
for (int i = 0; i < output_int32.dims().production();
i++) {
EXPECT_NEAR(output_int32_data[i],
output_int32_ref_data[i],
1e-3);
}
// ============ int8direct_int8 ============
int8_t* output_int8_ref_data =
output_int8_ref.mutable_data<int8_t>();
lite::arm::math::trans_tensor_int32_to_int8(
&output_int32_ref,
&output_int8_ref,
in_scale[0],
1,
w_scale);
param.output = &output_int8;
param.input_scale = in_scale[0];
param.output_scale = 1;
param.weight_scale = w_scale;
std::unique_ptr<KernelContext> ctx_int8(
new KernelContext);
lite::arm::math::DirectConvInt8<PRECISION(kInt8)>
int8direct_int8;
int8direct_int8.init(param,
&ctx_int8->As<ARMContext>());
int8direct_int8.create(param,
&ctx_int8->As<ARMContext>());
int8direct_int8.run(param);
int8_t* output_int8_data =
output_int8.mutable_data<int8_t>();
for (int i = 0; i < output_int8.dims().production();
i++) {
EXPECT_NEAR(output_int8_data[i],
output_int8_ref_data[i],
1e-3);
}
// ============ int8direct_float32 ============
float* output_fp32_ref_data =
output_fp32_ref.mutable_data<float>();
lite::arm::math::trans_tensor_int32_to_fp32(
&output_int32_ref,
&output_fp32_ref,
in_scale[0],
w_scale);
param.output = &output_fp32;
param.input_scale = in_scale[0];
param.output_scale = 1;
param.weight_scale = w_scale;
std::unique_ptr<KernelContext> ctx_fp32(
new KernelContext);
lite::arm::math::DirectConvInt8<PRECISION(kFloat)>
int8direct_fp32;
int8direct_fp32.init(param,
&ctx_fp32->As<ARMContext>());
int8direct_fp32.create(param,
&ctx_fp32->As<ARMContext>());
int8direct_fp32.run(param);
float* output_fp32_data =
output_fp32.mutable_data<float>();
for (int i = 0; i < output_fp32.dims().production();
i++) {
EXPECT_NEAR(output_fp32_data[i],
output_fp32_ref_data[i],
1e-3);
}
}
}
}
}
}
}
}
}
}
}
}
}
}
TEST(conv_depthwise_int8, compute) {
DeviceInfo::Init();
for (auto n : {1, 2}) {
for (auto ic : {1, 3, 8}) {
for (auto ih : {5, 15, 28}) {
for (auto iw : {5, 15, 28}) {
for (auto flag_bias : {false, true}) {
for (auto flag_relu : {false, true}) {
for (auto dilation : {1}) {
for (auto stride : {1, 2}) {
for (auto padding : {1, 2}) {
for (auto ks : {3, /*5 */}) {
int group = ic;
int oc = ic;
bool flag_dw_3x3 = (ks == 3) && (padding == 1) &&
(stride == 1 || stride == 2);
bool flag_dw_5x5 =
(ks == 5 && stride == 1 && padding == 2);
bool flag_dw = flag_dw_3x3 || flag_dw_5x5;
if (!flag_dw) continue;
const int dks = dilation * (ks - 1) + 1;
int oh = (ih + 2 * padding - dks) / stride + 1;
int ow = (iw + 2 * padding - dks) / stride + 1;
std::vector<int64_t> input_shape = {n, ic, ih, iw};
std::vector<int64_t> filter_shape = {
oc, ic / group, ks, ks};
std::vector<int64_t> bias_shape({1, oc, 1, 1});
std::vector<int64_t> output_shape({n, oc, oh, ow});
Tensor input_fp32, input_int8;
Tensor filter_fp32, filter_int8;
Tensor bias_int32;
Tensor output_int32_ref, output_int32;
Tensor output_fp32_ref, output_fp32;
Tensor output_int8_ref, output_int8;
input_fp32.Resize(input_shape);
input_int8.Resize(input_shape);
filter_fp32.Resize(filter_shape);
filter_int8.Resize(filter_shape);
bias_int32.Resize(bias_shape);
output_int32.Resize(output_shape);
output_int32_ref.Resize(output_shape);
output_fp32_ref.Resize(output_shape);
output_fp32.Resize(output_shape);
output_int8_ref.Resize(output_shape);
output_int8.Resize(output_shape);
float* input_fp32_data = input_fp32.mutable_data<float>();
int8_t* input_int8_data =
input_int8.mutable_data<int8_t>();
float* filter_fp32_data =
filter_fp32.mutable_data<float>();
int8_t* filter_int8_data =
filter_int8.mutable_data<int8_t>();
int* bias_int32_data = bias_int32.mutable_data<int32_t>();
for (int i = 0; i < input_fp32.dims().production(); i++) {
input_fp32_data[i] = i % 10 * (i % 3 - 1);
}
for (int i = 0; i < filter_fp32.dims().production();
i++) {
filter_fp32_data[i] = i % 10 * (i % 3 - 1);
}
for (int i = 0; i < bias_int32.dims().production(); i++) {
bias_int32_data[i] = i % 10 * (i % 3 - 1);
}
std::vector<float> in_scale;
lite::arm::math::get_tensor_scale<PRECISION(kFloat)>(
input_fp32, &in_scale, -1, 127.f);
lite::arm::math::trans_tensor_fp32_to_int8(
&input_fp32, &input_int8, in_scale[0]);
std::vector<float> w_scale;
lite::arm::math::get_tensor_scale<PRECISION(kFloat)>(
filter_fp32, &w_scale, -1, 127.f);
int axis_size = oc;
int inner_size = ic / group * ks * ks;
w_scale = lite::arm::math::get_tensor_scale_n(
filter_fp32_data, axis_size, inner_size, 127.f);
lite::arm::math::fp32_to_int8(filter_fp32_data,
filter_int8_data,
w_scale.data(),
axis_size,
1,
inner_size);
operators::ConvParam param;
param.x = &input_int8;
param.filter = &filter_int8;
if (flag_bias) {
param.bias = &bias_int32;
}
param.fuse_relu = false;
param.paddings = std::vector<int>({padding, padding});
param.strides = std::vector<int>({stride, stride});
param.dilations = std::vector<int>({dilation, dilation});
param.groups = group;
param.output = &output_int32_ref;
conv_compute_ref<int8_t, int>(param);
int* output_int32_ref_data =
output_int32_ref.mutable_data<int>();
// ============ int8depthwise_int32 ============
param.output = &output_int32;
std::unique_ptr<KernelContext> ctx_int32(
new KernelContext);
lite::arm::math::DepthwiseConvInt8<PRECISION(kInt32)>
int8depthwise_int32;
int8depthwise_int32.init(param,
&ctx_int32->As<ARMContext>());
int8depthwise_int32.create(param,
&ctx_int32->As<ARMContext>());
int8depthwise_int32.run(param);
int* output_int32_data = output_int32.mutable_data<int>();
for (int i = 0; i < output_int32.dims().production();
i++) {
EXPECT_NEAR(output_int32_data[i],
output_int32_ref_data[i],
1e-3);
}
// ============ int8depthwise_int8============
int8_t* output_int8_ref_data =
output_int8_ref.mutable_data<int8_t>();
lite::arm::math::trans_tensor_int32_to_int8(
&output_int32_ref,
&output_int8_ref,
in_scale[0],
1,
w_scale);
param.output = &output_int8;
param.input_scale = in_scale[0];
param.output_scale = 1;
param.weight_scale = w_scale;
std::unique_ptr<KernelContext> ctx_int8(
new KernelContext);
lite::arm::math::DepthwiseConvInt8<PRECISION(kInt8)>
int8depthwise_int8;
int8depthwise_int8.init(param,
&ctx_int8->As<ARMContext>());
int8depthwise_int8.create(param,
&ctx_int8->As<ARMContext>());
int8depthwise_int8.run(param);
int8_t* output_int8_data =
output_int8.mutable_data<int8_t>();
for (int i = 0; i < output_int8.dims().production();
i++) {
EXPECT_NEAR(
output_int8_data[i], output_int8_ref_data[i], 1e-3);
}
// ============int8depthwise_float32 ============
float* output_fp32_ref_data =
output_fp32_ref.mutable_data<float>();
lite::arm::math::trans_tensor_int32_to_fp32(
&output_int32_ref,
&output_fp32_ref,
in_scale[0],
w_scale);
param.output = &output_fp32;
param.input_scale = in_scale[0];
param.output_scale = 1;
param.weight_scale = w_scale;
std::unique_ptr<KernelContext> ctx_fp32(
new KernelContext);
lite::arm::math::DepthwiseConvInt8<PRECISION(kFloat)>
int8depthwise_fp32;
int8depthwise_fp32.init(param,
&ctx_fp32->As<ARMContext>());
int8depthwise_fp32.create(param,
&ctx_fp32->As<ARMContext>());
int8depthwise_fp32.run(param);
float* output_fp32_data =
output_fp32.mutable_data<float>();
for (int i = 0; i < output_fp32.dims().production();
i++) {
EXPECT_NEAR(
output_fp32_data[i], output_fp32_ref_data[i], 1e-3);
}
}
}
}
}
}
}
}
}
}
}
}
TEST(conv_arm, compute) {
DeviceInfo::Init();
#if 1
for (auto n : {2}) {
for (auto ic : {6}) {
for (auto oc : {6}) {
for (auto ih : {9}) {
for (auto iw : {9}) {
for (auto flag_bias : {false, true}) {
for (auto flag_relu : {false, true}) {
for (auto depthwise : {false, true}) {
for (auto dilation : {1}) {
for (auto stride : {1, 2}) {
for (auto padding : {0, 1, 2}) {
for (auto ks : {1, 3, 5}) {
#else
for (auto n : {1, 2}) {
for (auto ic : {6, 32 /*, 128*/}) {
for (auto oc : {6, 32 /*, 128*/}) {
for (auto ih : {9, 18 /*, 56 , 112, 224, 512*/}) {
for (auto iw : {9, 18 /*, 56, 112, 224, 512*/}) {
for (auto flag_bias : {false, true}) {
for (auto flag_relu : {false, true}) {
for (auto depthwise : {false, true}) {
for (auto dilation : {1, 2}) {
for (auto stride : {1, 2}) {
for (auto padding : {0, 1, 2}) {
for (auto ks : {1, 3, 5}) {
#endif
int group = 1;
if (depthwise) { // depthwise convolution ?
group = oc = ic;
}
// get input, filter and output shape
std::vector<int64_t> input_shape = {n, ic, ih, iw};
std::vector<int64_t> filter_shape = {
oc, ic / group, ks, ks};
const int dks = dilation * (ks - 1) + 1;
int oh = (ih + 2 * padding - dks) / stride + 1;
int ow = (iw + 2 * padding - dks) / stride + 1;
std::vector<int64_t> output_shape({n, oc, oh, ow});
// resize input, filter and output
Tensor input;
Tensor filter;
Tensor bias;
Tensor output;
Tensor output_ref;
input.Resize(input_shape);
filter.Resize(filter_shape);
output.Resize(output_shape);
output_ref.Resize(output_shape);
VLOG(3) << "input: " << input.dims();
VLOG(3) << "filter: " << filter.dims()
<< " padding:" << padding
<< " stride:" << stride
<< " dilation:" << dilation;
VLOG(3) << "output: " << output.dims();
auto* input_data = input.mutable_data<float>();
auto* filter_data = filter.mutable_data<float>();
auto* output_data = output.mutable_data<float>();
for (int i = 0; i < input.dims().production(); i++) {
float sign = i % 3 == 0 ? -1.0f : 1.0f;
input_data[i] = sign * static_cast<float>(i % 128);
}
for (int i = 0; i < filter.dims().production(); i++) {
filter_data[i] =
i * 0.001f /
static_cast<float>(filter.dims().production());
}
// prepare kernel params and run
ConvCompute conv;
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<ARMContext>();
conv.SetContext(std::move(ctx));
operators::ConvParam param;
param.x = &input;
param.filter = &filter;
param.output = &output;
param.bias = nullptr;
if (flag_bias) {
bias.Resize({oc});
auto* bias_data = bias.mutable_data<float>();
for (int i = 0; i < bias.dims().production(); i++) {
bias_data[i] = static_cast<float>(i);
}
param.bias = &bias;
}
param.fuse_relu = flag_relu;
param.paddings = std::vector<int>({padding, padding});
param.strides = std::vector<int>({stride, stride});
param.dilations =
std::vector<int>({dilation, dilation});
param.groups = group;
conv.SetParam(param);
conv.Launch();
// invoking ref implementation and compare results
param.output = &output_ref;
conv_compute_ref<float, float>(param);
auto* output_ref_data =
output_ref.mutable_data<float>();
for (int i = 0; i < output.dims().production(); i++) {
EXPECT_NEAR(
output_data[i], output_ref_data[i], 1e-3);
}
}
}
}
}
}
}
}
}
}
}
}
}
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(conv2d, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(depthwise_conv2d, kARM, kFloat, kNCHW, def);
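// USE_LITE_KERNEL links the registration symbols of these kernels into the
// test binary so KernelRegistry::Global().Create() above can find them (a
// hedged note on the macro's purpose; it comes from lite/core/op_registry.h,
// which this file already includes).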
...@@ -12,57 +12,142 @@ ...@@ -12,57 +12,142 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "lite/backends/arm/math/conv_depthwise.h" #include "lite/kernels/arm/conv_depthwise.h"
#include "lite/backends/arm/math/conv_block_utils.h" #include "lite/backends/arm/math/conv_block_utils.h"
#include "lite/backends/arm/math/conv_impl.h" #include "lite/backends/arm/math/conv_impl.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
namespace kernels {
namespace arm { namespace arm {
namespace math {
template <> template <>
bool DepthwiseConv<PRECISION(kFloat)>::create(const operators::ConvParam& param, void DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
ARMContext* ctx) { auto& param = this->Param<param_t>();
this->ctx_ = ctx; CHECK(this->ctx_);
auto x_dims = param.x->dims(); auto& ctx = this->ctx_->template As<ARMContext>();
auto w_dims = param.filter->dims(); auto w_dims = param.filter->dims();
auto o_dims = param.output->dims(); auto kw = w_dims[3];
int iw = x_dims[3]; // nchw
int ic = x_dims[1];
int ow = o_dims[3];
int oc = o_dims[1];
int kw = w_dims[3];
int sw = param.strides[1];
// select dw conv kernel // select dw conv kernel
if (kw == 3) { if (kw == 3) {
VLOG(5) << "invoke 3x3 dw conv"; VLOG(5) << "invoke 3x3 dw conv fp32";
impl_ = conv_depthwise_3x3; /// trans weights
constexpr int cblock = 4;
auto oc = w_dims[0];
auto kh = w_dims[2];
auto cround = ROUNDUP(oc, cblock);
weights_.Resize({cround, 1, kh, kw});
auto w_data = weights_.mutable_data<float>();
auto w_data_in = param.filter->data<float>();
lite::arm::math::conv_trans_weights_numc(
w_data_in, w_data, oc, 1, cblock, kh * kw);
impl_ = lite::arm::math::conv_depthwise_3x3_fp32;
flag_trans_weights_ = true;
} else if (kw == 5) { } else if (kw == 5) {
VLOG(5) << "invoke 5x5 dw conv"; VLOG(5) << "invoke 5x5 dw conv fp32";
this->ctx_->ExtendWorkspace((iw + ow) * sizeof(float)); impl_ = lite::arm::math::conv_depthwise_5x5_fp32;
impl_ = conv_depthwise_5x5; } else {
    LOG(FATAL) << "this type of dw conv is not implemented";
}
}
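// conv_trans_weights_numc repacks the depthwise weights so that a block of
// `cblock` output channels is interleaved per kernel element, letting the
// NEON kernel load one whole channel block per vector, with channels
// zero-padded up to cround. A rough reference of such a repack -- the exact
// layout produced by conv_trans_weights_numc may differ, this is only a
// sketch for the depthwise case (one input channel per group):
void trans_weights_numc_sketch(const float* in, float* out,
                               int oc, int cblock, int ksize) {
  int cround = ((oc + cblock - 1) / cblock) * cblock;  // ROUNDUP(oc, cblock)
  for (int c = 0; c < cround; ++c) {
    int block = c / cblock;
    int lane = c % cblock;
    for (int k = 0; k < ksize; ++k) {
      // channels beyond oc are padding and read as zero
      out[(block * ksize + k) * cblock + lane] =
          (c < oc) ? in[c * ksize + k] : 0.f;
    }
  }
}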
template <>
void DepthwiseConv<PRECISION(kInt8), PRECISION(kFloat)>::PrepareForRun() {
auto& param = this->Param<param_t>();
CHECK(this->ctx_);
auto& ctx = this->ctx_->template As<ARMContext>();
auto w_dims = param.filter->dims();
int kh = w_dims[2];
int kw = w_dims[3];
int oc = w_dims[0];
/// update scale
float in_scale = param.input_scale;
auto& scale = param.weight_scale;
CHECK(scale.size() == 1 || scale.size() == oc)
<< "weights scale size must = filter size or = 1";
w_scale_.resize(oc);
for (int i = 0; i < oc; ++i) {
if (scale.size() == 1) {
w_scale_[i] = scale[0] * in_scale;
} else {
w_scale_[i] = scale[i] * in_scale;
}
}
/// select dw conv kernel
if (kw == 3) {
VLOG(5) << "invoke 3x3 dw conv int8 kernel fp32 out";
impl_ = lite::arm::math::conv_depthwise_3x3_int8_fp32;
int cround = ROUNDUP(w_dims[0], 8);
weights_.Resize({cround / 8, 1, kh * kw, 8});
auto wptr = param.filter->data<int8_t>();
auto wptr_new = weights_.mutable_data<int8_t>();
    lite::arm::math::conv_trans_weights_numc(wptr, wptr_new, oc, 1, 8, kh * kw);
flag_trans_weights_ = true;
} else { } else {
    LOG(ERROR) << "this type dw conv not impl"; LOG(FATAL) << "this type of dw conv is not implemented";
return false;
} }
return true;
} }
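// Why the scales are merged above: with symmetric quantization
// x_q = x_f / in_scale and w_q = w_f / w_scale[c], the int32 accumulator of
// output channel c satisfies acc = sum(x_q * w_q)
// = sum(x_f * w_f) / (in_scale * w_scale[c]), so the fp32 result is
// recovered as acc * (w_scale[c] * in_scale) -- exactly the per-channel
// w_scale_ computed in PrepareForRun. Tiny numeric check with made-up
// values: in_scale = 0.05, x_f = 0.5 -> x_q = 10; w_scale = 0.02,
// w_f = 0.04 -> w_q = 2; acc = 20 and 20 * (0.02 * 0.05) = 0.02 = x_f * w_f.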
template <> template <>
bool DepthwiseConv<PRECISION(kFloat)>::init(const operators::ConvParam& param, void DepthwiseConv<PRECISION(kInt8), PRECISION(kInt8)>::PrepareForRun() {
Context<TARGET(kARM)>* ctx) { auto& param = this->Param<param_t>();
this->ctx_ = ctx; CHECK(this->ctx_);
return create(param, ctx); auto& ctx = this->ctx_->template As<ARMContext>();
auto w_dims = param.filter->dims();
int kw = w_dims[3];
int kh = w_dims[2];
int oc = w_dims[0];
/// update scale
float in_scale = param.input_scale;
float out_scale = param.output_scale;
auto& scale = param.weight_scale;
CHECK(scale.size() == 1 || scale.size() == oc)
<< "weights scale size must = filter size or = 1";
w_scale_.resize(oc);
for (int i = 0; i < oc; ++i) {
if (scale.size() == 1) {
w_scale_[i] = scale[0] * in_scale / out_scale;
} else {
w_scale_[i] = scale[i] * in_scale / out_scale;
}
}
/// update bias
if (param.bias) {
bias_.Resize(param.bias->dims());
auto ptr = bias_.mutable_data<float>();
auto ptr_in = param.bias->data<float>();
for (int i = 0; i < bias_.numel(); ++i) {
ptr[i] = ptr_in[i] / out_scale;
}
flag_trans_bias_ = true;
}
/// select dw conv kernel
if (kw == 3) {
VLOG(5) << "invoke 3x3 dw conv int8 kernel int8 out";
impl_ = lite::arm::math::conv_depthwise_3x3_int8_int8;
int cround = ROUNDUP(w_dims[0], 8);
weights_.Resize({cround / 8, 1, kh * kw, 8});
auto wptr = param.filter->data<int8_t>();
auto wptr_new = weights_.mutable_data<int8_t>();
    lite::arm::math::conv_trans_weights_numc(wptr, wptr_new, oc, 1, 8, kh * kw);
flag_trans_weights_ = true;
} else {
    LOG(FATAL) << "this type of dw conv is not implemented";
}
} }
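// Continuing the scale arithmetic for int8 output: the kernel must emit
// y_q = y_f / out_scale, so the merged factor is additionally divided by
// out_scale, and the fp32 bias is pre-divided by out_scale so it can be
// added directly in the requantized domain:
//   y_q = acc * (w_scale * in_scale / out_scale) + b_f / out_scale
//       = (y_f_without_bias + b_f) / out_scale.
// Rescaling the float bias once in PrepareForRun, instead of quantizing it,
// is the float-bias support this patch introduces.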
template <> template <>
bool DepthwiseConv<PRECISION(kFloat)>::run(const operators::ConvParam& param) { void DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
// start timer auto& param = this->Param<param_t>();
CHECK(this->ctx_);
auto& ctx = this->ctx_->template As<ARMContext>();
const auto* i_data = param.x->data<float>(); const auto* i_data = param.x->data<float>();
const auto* w_data = param.filter->data<float>(); const auto* w_data = flag_trans_weights_ ? weights_.data<float>()
: param.filter->data<float>();
const auto* b_data = param.bias ? param.bias->data<float>() : nullptr; const auto* b_data = param.bias ? param.bias->data<float>() : nullptr;
if (flag_trans_bias_) {
b_data = bias_.data<float>();
}
auto* o_data = param.output->mutable_data<float>(); auto* o_data = param.output->mutable_data<float>();
auto x_dims = param.x->dims(); auto x_dims = param.x->dims();
...@@ -89,111 +174,77 @@ bool DepthwiseConv<PRECISION(kFloat)>::run(const operators::ConvParam& param) { ...@@ -89,111 +174,77 @@ bool DepthwiseConv<PRECISION(kFloat)>::run(const operators::ConvParam& param) {
w_data, w_data,
b_data, b_data,
param, param,
this->ctx_); &ctx,
w_scale_.data());
// timer end
return true;
} }
template <PrecisionType Ptype_out> template <>
bool DepthwiseConvInt8<Ptype_out>::create(const operators::ConvParam& param, void DepthwiseConv<PRECISION(kInt8), PRECISION(kFloat)>::Run() {
ARMContext* ctx) { auto& param = this->Param<param_t>();
this->ctx_ = ctx; CHECK(this->ctx_);
auto& ctx = this->ctx_->template As<ARMContext>();
const auto* i_data = param.x->data<int8_t>();
const auto* w_data = flag_trans_weights_ ? weights_.data<int8_t>()
: param.filter->data<int8_t>();
const auto* b_data = param.bias ? param.bias->data<float>() : nullptr;
if (flag_trans_bias_) {
b_data = bias_.data<float>();
}
auto* o_data = param.output->mutable_data<float>();
auto x_dims = param.x->dims(); auto x_dims = param.x->dims();
auto w_dims = param.filter->dims(); auto w_dims = param.filter->dims();
auto o_dims = param.output->dims(); auto o_dims = param.output->dims();
int ic = x_dims[1];
int ih = x_dims[2];
int iw = x_dims[3]; // nchw int iw = x_dims[3]; // nchw
int oc = o_dims[1]; int ih = x_dims[2];
int ic = x_dims[1];
int bs = x_dims[0];
int oh = o_dims[2]; int oh = o_dims[2];
int ow = o_dims[3]; int ow = o_dims[3];
int kw = w_dims[3]; int oc = o_dims[1];
int sw = param.strides[1];
w_scale_ = param.weight_scale;
//! select dw conv kernel
if (kw == 3) {
tmp_int32_out_.Resize(o_dims);
VLOG(5) << "invoke 3x3 depthwise int8 conv";
impl_ = conv_depthwise_3x3_int8;
} else if (kw == 5) {
// update w_data scale
if (Ptype_out == PRECISION(kFloat) || Ptype_out == PRECISION(kInt8)) {
CHECK_EQ(w_scale_.size(), oc) << "w_data scale size must be oc";
float input_scale = param.input_scale;
float output_scale = param.output_scale;
for (auto& ws : w_scale_) {
ws *= input_scale;
if (Ptype_out == PRECISION(kInt8)) {
ws /= output_scale;
}
}
}
const int wout_round = ((ow + 7) / 8) * 8;
const int win_round = wout_round * sw + 5 - 1;
const int hout_round = ((oh + 2) / 3) * 3;
const int hin_round = hout_round * sw + 5 - 1;
const int tmp_size_out = wout_round * hout_round;
const int tmp_size_in = win_round * hin_round;
const int tmp_size_io_bytes = tmp_size_in + tmp_size_out * sizeof(int);
const int tmp_row_io_bytes = win_round + wout_round * sizeof(int);
const int tmp_size_io_float =
(tmp_size_io_bytes + sizeof(float) - 1) / sizeof(float);
const int tmp_row_io_float =
(tmp_row_io_bytes + sizeof(float) - 1) / sizeof(float);
ctx_->ExtendWorkspace(
(ctx_->threads() * tmp_size_io_float + tmp_row_io_float) *
sizeof(float));
impl_ = conv_depthwise_5x5_int8;
VLOG(5) << "invoke conv_depthwise_5x5 int8 conv";
} else {
LOG(ERROR) << "this type depthwise int8 conv not impl";
return false;
}
return true;
}
template <PrecisionType Ptype_out> impl_(i_data,
bool DepthwiseConvInt8<Ptype_out>::init(const operators::ConvParam& param, o_data,
Context<TARGET(kARM)>* ctx) { bs,
this->ctx_ = ctx; oc,
return create(param, ctx); oh,
ow,
ic,
ih,
iw,
w_data,
b_data,
param,
&ctx,
w_scale_.data());
} }
template <PrecisionType Ptype_out> template <>
bool DepthwiseConvInt8<Ptype_out>::run(const operators::ConvParam& param) { void DepthwiseConv<PRECISION(kInt8), PRECISION(kInt8)>::Run() {
const int8_t* i_data = param.x->data<int8_t>(); auto& param = this->Param<param_t>();
int32_t* o_data = nullptr; CHECK(this->ctx_);
const int8_t* w_data = param.filter->data<int8_t>(); auto& ctx = this->ctx_->template As<ARMContext>();
const int32_t* b_data = param.bias ? param.bias->data<int32_t>() : nullptr; const auto* i_data = param.x->data<int8_t>();
const auto* w_data = flag_trans_weights_ ? weights_.data<int8_t>()
// LOG(INFO) << "input size: " << param.x->memory_size() << " " : param.filter->data<int8_t>();
// << param.input_scale << " " << w_scale_.size(); const auto* b_data = param.bias ? param.bias->data<float>() : nullptr;
if (flag_trans_bias_) {
b_data = bias_.data<float>();
}
auto* o_data = param.output->mutable_data<int8_t>();
auto x_dims = param.x->dims(); auto x_dims = param.x->dims();
auto w_dims = param.filter->dims(); auto w_dims = param.filter->dims();
auto o_dims = param.output->dims(); auto o_dims = param.output->dims();
int bs = x_dims[0];
int ic = x_dims[1];
int ih = x_dims[2];
int iw = x_dims[3]; // nchw int iw = x_dims[3]; // nchw
int oc = o_dims[1]; int ih = x_dims[2];
int ic = x_dims[1];
int bs = x_dims[0];
int oh = o_dims[2]; int oh = o_dims[2];
int ow = o_dims[3]; int ow = o_dims[3];
int kw = w_dims[3]; int oc = o_dims[1];
int sw = param.strides[1];
if (kw == 3 && Ptype_out != PRECISION(kInt32)) {
o_data = tmp_int32_out_.mutable_data<int32_t>();
} else if (kw == 5 || (kw == 3 && Ptype_out == PRECISION(kInt32))) {
o_data = param.output->mutable_data<int32_t>();
} else {
LOG(ERROR) << "this type dw int8 conv not impl";
return false;
}
impl_(i_data, impl_(i_data,
o_data, o_data,
...@@ -207,33 +258,11 @@ bool DepthwiseConvInt8<Ptype_out>::run(const operators::ConvParam& param) { ...@@ -207,33 +258,11 @@ bool DepthwiseConvInt8<Ptype_out>::run(const operators::ConvParam& param) {
w_data, w_data,
b_data, b_data,
param, param,
this->ctx_, &ctx,
Ptype_out,
w_scale_.data()); w_scale_.data());
auto i_scale = param.input_scale;
auto o_scale = param.output_scale;
if (kw == 3) {
if (Ptype_out == PRECISION(kInt8)) {
trans_tensor_dtype<PRECISION(kInt32), PRECISION(kInt8)>(
&tmp_int32_out_, param.output, i_scale, o_scale, w_scale_);
} else if (Ptype_out == PRECISION(kFloat)) {
trans_tensor_dtype<PRECISION(kInt32), PRECISION(kFloat)>(
&tmp_int32_out_, param.output, i_scale, 1.f, w_scale_);
} else if (Ptype_out != PRECISION(kInt32)) {
LOG(ERROR) << "unsupported precision type!!";
return false;
}
}
return true;
} }
template class DepthwiseConvInt8<PRECISION(kInt8)>;
template class DepthwiseConvInt8<PRECISION(kFloat)>;
template class DepthwiseConvInt8<PRECISION(kInt32)>;
} // namespace math
} // namespace arm } // namespace arm
} // namespace kernels
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cmath>
#include <vector>
#include "lite/backends/arm/math/conv_impl.h"
#include "lite/core/context.h"
#include "lite/core/kernel.h"
#include "lite/core/target_wrapper.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
template <PrecisionType Ptype, PrecisionType Otype>
class DepthwiseConv : public KernelLite<TARGET(kARM), Ptype> {
public:
typedef void (*conv_dw_impl)(const void* din,
void* dout,
int num,
int ch_out,
int h_out,
int w_out,
int ch_in,
int h_in,
int w_in,
const void* weights,
const float* bias,
const operators::ConvParam& param,
ARMContext* ctx,
const float* scale);
DepthwiseConv() = default;
~DepthwiseConv() {}
virtual void PrepareForRun();
virtual void Run();
private:
using param_t = operators::ConvParam;
Tensor weights_;
Tensor bias_;
bool flag_trans_weights_{false};
bool flag_trans_bias_{false};
conv_dw_impl impl_{nullptr};
std::vector<float> w_scale_;
};
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/conv_direct.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
template <>
void DirectConv<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<ARMContext>();
const auto* i_data = param.x->data<float>();
const auto* w_data = weights_.data<float>();
const auto* b_data = param.bias ? param.bias->data<float>() : nullptr;
auto* o_data = param.output->mutable_data<float>();
auto x_dims = param.x->dims();
auto w_dims = param.filter->dims();
auto o_dims = param.output->dims();
int iw = x_dims[3]; // nchw
int ih = x_dims[2];
int ic = x_dims[1];
int bs = x_dims[0];
int oh = o_dims[2];
int ow = o_dims[3];
int oc = o_dims[1];
if (param.strides[0] == 1) {
lite::arm::math::conv_3x3s1_direct_fp32(i_data,
o_data,
bs,
oc,
oh,
ow,
ic,
ih,
iw,
w_data,
b_data,
param,
&ctx);
} else {
lite::arm::math::conv_3x3s2_direct_fp32(i_data,
o_data,
bs,
oc,
oh,
ow,
ic,
ih,
iw,
w_data,
b_data,
param,
&ctx);
}
}
template <>
void DirectConv<PRECISION(kInt8), PRECISION(kFloat)>::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<ARMContext>();
const auto* i_data = param.x->data<int8_t>();
const auto* w_data = weights_.data<int8_t>();
const auto* b_data = param.bias ? param.bias->data<float>() : nullptr;
if (flag_trans_bias_) {
b_data = bias_.data<float>();
}
auto* o_data = param.output->mutable_data<float>();
auto x_dims = param.x->dims();
auto w_dims = param.filter->dims();
auto o_dims = param.output->dims();
int iw = x_dims[3]; // nchw
int ih = x_dims[2];
int ic = x_dims[1];
int bs = x_dims[0];
int oh = o_dims[2];
int ow = o_dims[3];
int oc = o_dims[1];
if (param.strides[0] == 1) {
lite::arm::math::conv_3x3s1_direct_int8(i_data,
o_data,
bs,
oc,
oh,
ow,
ic,
ih,
iw,
w_data,
b_data,
param,
&ctx,
w_scale_.data());
} else {
lite::arm::math::conv_3x3s2_direct_int8(i_data,
o_data,
bs,
oc,
oh,
ow,
ic,
ih,
iw,
w_data,
b_data,
param,
&ctx,
w_scale_.data());
}
}
template <>
void DirectConv<PRECISION(kInt8), PRECISION(kInt8)>::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<ARMContext>();
const auto* i_data = param.x->data<int8_t>();
const auto* w_data = weights_.data<int8_t>();
const auto* b_data = param.bias ? param.bias->data<float>() : nullptr;
if (flag_trans_bias_) {
b_data = bias_.data<float>();
}
auto* o_data = param.output->mutable_data<int8_t>();
auto x_dims = param.x->dims();
auto w_dims = param.filter->dims();
auto o_dims = param.output->dims();
int iw = x_dims[3]; // nchw
int ih = x_dims[2];
int ic = x_dims[1];
int bs = x_dims[0];
int oh = o_dims[2];
int ow = o_dims[3];
int oc = o_dims[1];
if (param.strides[0] == 1) {
lite::arm::math::conv_3x3s1_direct_int8(i_data,
o_data,
bs,
oc,
oh,
ow,
ic,
ih,
iw,
w_data,
b_data,
param,
&ctx,
w_scale_.data());
} else {
lite::arm::math::conv_3x3s2_direct_int8(i_data,
o_data,
bs,
oc,
oh,
ow,
ic,
ih,
iw,
w_data,
b_data,
param,
&ctx,
w_scale_.data());
}
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cmath>
#include <vector>
#include "lite/backends/arm/math/funcs.h"
#include "lite/core/context.h"
#include "lite/core/kernel.h"
#include "lite/core/target_wrapper.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
#define ROUNDUP(a, b) ((((a) + (b)-1) / (b)) * (b))
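// e.g. ROUNDUP(10, 4) == ((10 + 3) / 4) * 4 == 12 and ROUNDUP(12, 4) == 12:
// pads a channel count up to the next multiple of the kernel's channel block.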
template <PrecisionType Ptype, PrecisionType OutType>
inline bool direct_conv_trans_weights(
const Tensor* win,
Tensor* wout,
const Tensor* bin,
Tensor* bout,
int stride,
const std::vector<float>& w_scale,
float in_scale,
float out_scale,
std::vector<float>& merge_scale) { // NOLINT
constexpr int cblock = 4;
int oc = win->dims()[0];
int ic = win->dims()[1];
int kh = win->dims()[2];
int kw = win->dims()[3];
int cround = ROUNDUP(oc, cblock);
wout->Resize({cround, ic, kh, kw});
auto w_in_data = win->data<float>();
auto transed_w_data = wout->mutable_data<float>();
lite::arm::math::conv_trans_weights_numc(
w_in_data, transed_w_data, oc, ic, cblock, kh * kw);
return false;
}
template <>
inline bool direct_conv_trans_weights<PRECISION(kInt8), PRECISION(kFloat)>(
const Tensor* win,
Tensor* wout,
const Tensor* bin,
Tensor* bout,
int stride,
const std::vector<float>& w_scale,
float in_scale,
float out_scale,
std::vector<float>& merge_scale) { // NOLINT
int cblock = 4;
if (stride == 2) {
cblock = lite::arm::math::conv_3x3s2_direct_int8_c_num();
}
int oc = win->dims()[0];
int ic = win->dims()[1];
int kh = win->dims()[2];
int kw = win->dims()[3];
int cround = ROUNDUP(oc, cblock);
wout->Resize({cround, ic, kh, kw});
auto w_in_data = win->data<int8_t>();
auto transed_w_data = wout->mutable_data<int8_t>();
lite::arm::math::conv_trans_weights_numc(
w_in_data, transed_w_data, oc, ic, cblock, kh * kw);
/// update scale
CHECK(w_scale.size() == 1 || w_scale.size() == oc)
<< "weights scale size must = filter size or = 1";
merge_scale.resize(oc);
for (int i = 0; i < oc; ++i) {
if (w_scale.size() == 1) {
merge_scale[i] = w_scale[0] * in_scale;
} else {
merge_scale[i] = w_scale[i] * in_scale;
}
}
return false;
}
template <>
inline bool direct_conv_trans_weights<PRECISION(kInt8), PRECISION(kInt8)>(
const Tensor* win,
Tensor* wout,
const Tensor* bin,
Tensor* bout,
int stride,
const std::vector<float>& w_scale,
float in_scale,
float out_scale,
std::vector<float>& merge_scale) { // NOLINT
int cblock = 4;
if (stride == 2) {
cblock = lite::arm::math::conv_3x3s2_direct_int8_c_num();
}
int oc = win->dims()[0];
int ic = win->dims()[1];
int kh = win->dims()[2];
int kw = win->dims()[3];
int cround = ROUNDUP(oc, cblock);
wout->Resize({cround, ic, kh, kw});
auto w_in_data = win->data<int8_t>();
auto transed_w_data = wout->mutable_data<int8_t>();
lite::arm::math::conv_trans_weights_numc(
w_in_data, transed_w_data, oc, ic, cblock, kh * kw);
/// update scale
CHECK(w_scale.size() == 1 || w_scale.size() == oc)
<< "weights scale size must = filter size or = 1";
merge_scale.resize(oc);
float scale = in_scale / out_scale;
for (int i = 0; i < oc; ++i) {
if (w_scale.size() == 1) {
merge_scale[i] = w_scale[0] * scale;
} else {
merge_scale[i] = w_scale[i] * scale;
}
}
/// update bias
if (bin) {
bout->Resize(bin->dims());
auto ptr = bout->mutable_data<float>();
auto ptr_in = bin->data<float>();
for (int i = 0; i < bin->numel(); ++i) {
ptr[i] = ptr_in[i] / out_scale;
}
return true;
}
return false;
}
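// Note: the bool returned by these helpers reports whether the bias was
// rewritten into bout (only the int8 -> int8 specialization does that), and
// DirectConv::PrepareForRun below stores it as flag_trans_bias_.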
/// only supports 3x3s1 and 3x3s2
template <PrecisionType Ptype, PrecisionType OutType>
class DirectConv : public KernelLite<TARGET(kARM), Ptype> {
public:
DirectConv() = default;
~DirectConv() {}
virtual void PrepareForRun() {
auto& param = this->template Param<param_t>();
auto& ctx = this->ctx_->template As<ARMContext>();
auto x_dims = param.x->dims();
auto w_dims = param.filter->dims();
auto o_dims = param.output->dims();
int ic = x_dims[1];
int oc = o_dims[1];
int sw = param.strides[1];
int kw = w_dims[3];
int kh = w_dims[2];
CHECK(sw == 1 || sw == 2)
<< "direct conv only support conv3x3s1 and conv3x3s2";
CHECK(kw == 3 && kh == 3)
<< "direct conv only support conv3x3s1 and conv3x3s2";
flag_trans_bias_ =
direct_conv_trans_weights<Ptype, OutType>(param.filter,
&weights_,
param.bias,
&bias_,
sw,
param.weight_scale,
param.input_scale,
param.output_scale,
w_scale_);
}
virtual void Run();
  /// TODO: support in-place weights transform
protected:
Tensor weights_;
Tensor bias_;
bool flag_trans_weights_{false};
bool flag_trans_bias_{false};
std::vector<float> w_scale_;
private:
using param_t = operators::ConvParam;
};
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/conv_gemmlike.h"
#include <vector>
#include "lite/backends/arm/math/gemm_prepacked_int8.h"
#include "lite/backends/arm/math/packed_sgemm.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
template <>
void GemmLikeConv<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
ReInitWhenNeeded();
}
template <>
void GemmLikeConv<PRECISION(kInt8), PRECISION(kFloat)>::PrepareForRun() {
ReInitWhenNeeded();
auto& param = this->Param<param_t>();
/// update scale
w_scale_ = param.weight_scale;
if (w_scale_.size() != 1 && w_scale_.size() != param.filter->dims()[0]) {
LOG(FATAL) << "weights scale size must equal to filter size";
return;
}
if (w_scale_.size() == 1) {
for (int i = 0; i < param.filter->dims()[0] - 1; ++i) {
w_scale_.push_back(w_scale_[0]);
}
}
float input_scale = param.input_scale;
for (auto& ws : w_scale_) {
ws *= input_scale;
}
}
template <>
void GemmLikeConv<PRECISION(kInt8), PRECISION(kInt8)>::PrepareForRun() {
ReInitWhenNeeded();
auto& param = this->Param<param_t>();
/// update scale
w_scale_ = param.weight_scale;
if (w_scale_.size() != 1 && w_scale_.size() != param.filter->dims()[0]) {
LOG(FATAL) << "weights scale size must equal to filter size";
return;
}
if (w_scale_.size() == 1) {
for (int i = 0; i < param.filter->dims()[0] - 1; ++i) {
w_scale_.push_back(w_scale_[0]);
}
}
float input_scale = param.input_scale;
float output_scale = param.output_scale;
for (auto& ws : w_scale_) {
ws = ws * input_scale / output_scale;
}
//! update bias
if (param.bias) {
bias_.Resize(param.bias->dims());
auto ptr = bias_.mutable_data<float>();
auto ptr_in = param.bias->data<float>();
for (int i = 0; i < bias_.numel(); ++i) {
ptr[i] = ptr_in[i] / param.output_scale;
}
flag_trans_bias_ = true;
}
}
template <>
void GemmLikeConv<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<ARMContext>();
auto weights = param.filter->data<float>();
if (flag_trans_weights_) {
weights = weights_.data<float>();
}
const float* bias = param.bias ? param.bias->data<float>() : nullptr;
if (flag_trans_bias_) {
bias = bias_.data<float>();
}
auto din = param.x->data<float>();
auto dout = param.output->mutable_data<float>();
auto x_dims = param.x->dims();
auto w_dims = param.filter->dims();
auto o_dims = param.output->dims();
int iw = x_dims[3]; // nchw
int ih = x_dims[2];
int ic = x_dims[1];
int bs = x_dims[0];
int oh = o_dims[2];
int ow = o_dims[3];
int oc = o_dims[1];
if (flag_1x1gemm_) {
lite::arm::math::conv1x1s1_gemm(
din, dout, bs, oc, oh, ow, ic, ih, iw, weights, bias, param, &ctx);
} else {
lite::arm::math::conv_im2col_gemm(
din, dout, bs, oc, oh, ow, ic, ih, iw, weights, bias, param, &ctx);
}
}
template <>
void GemmLikeConv<PRECISION(kInt8), PRECISION(kFloat)>::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<ARMContext>();
auto weights = param.filter->data<int8_t>();
if (flag_trans_weights_) {
weights = weights_.data<int8_t>();
}
auto bias = param.bias ? param.bias->data<float>() : nullptr;
if (flag_trans_bias_) {
bias = bias_.data<float>();
}
auto din = param.x->data<int8_t>();
auto dout = param.output->mutable_data<float>();
auto x_dims = param.x->dims();
auto w_dims = param.filter->dims();
auto o_dims = param.output->dims();
int iw = x_dims[3]; // nchw
int ih = x_dims[2];
int ic = x_dims[1];
int bs = x_dims[0];
int oh = o_dims[2];
int ow = o_dims[3];
int oc = o_dims[1];
if (flag_1x1gemm_) {
lite::arm::math::conv1x1s1_gemm_int8(din,
dout,
bs,
oc,
oh,
ow,
ic,
ih,
iw,
weights,
bias,
param,
&ctx,
w_scale_.data());
} else {
lite::arm::math::conv_im2col_gemm_int8(din,
dout,
bs,
oc,
oh,
ow,
ic,
ih,
iw,
weights,
bias,
param,
&ctx,
w_scale_.data());
}
}
template <>
void GemmLikeConv<PRECISION(kInt8), PRECISION(kInt8)>::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<ARMContext>();
auto weights = param.filter->data<int8_t>();
if (flag_trans_weights_) {
weights = weights_.data<int8_t>();
}
auto bias = param.bias ? param.bias->data<float>() : nullptr;
if (flag_trans_bias_) {
bias = bias_.data<float>();
}
auto din = param.x->data<int8_t>();
auto dout = param.output->mutable_data<int8_t>();
auto x_dims = param.x->dims();
auto w_dims = param.filter->dims();
auto o_dims = param.output->dims();
int iw = x_dims[3]; // nchw
int ih = x_dims[2];
int ic = x_dims[1];
int bs = x_dims[0];
int oh = o_dims[2];
int ow = o_dims[3];
int oc = o_dims[1];
if (flag_1x1gemm_) {
lite::arm::math::conv1x1s1_gemm_int8(din,
dout,
bs,
oc,
oh,
ow,
ic,
ih,
iw,
weights,
bias,
param,
&ctx,
w_scale_.data());
} else {
lite::arm::math::conv_im2col_gemm_int8(din,
dout,
bs,
oc,
oh,
ow,
ic,
ih,
iw,
weights,
bias,
param,
&ctx,
w_scale_.data());
}
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cmath>
#include <vector>
#include "lite/backends/arm/math/conv_impl.h"
#include "lite/backends/arm/math/funcs.h"
#include "lite/core/context.h"
#include "lite/core/kernel.h"
#include "lite/core/target_wrapper.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
template <PrecisionType Ptype, PrecisionType Otype>
class GemmLikeConv : public KernelLite<TARGET(kARM), Ptype> {
public:
GemmLikeConv() = default;
~GemmLikeConv() {}
virtual void ReInitWhenNeeded() {
auto& param = this->template Param<param_t>();
CHECK(this->ctx_);
auto& ctx = this->ctx_->template As<ARMContext>();
auto x_dims = param.x->dims();
auto w_dims = param.filter->dims();
auto o_dims = param.output->dims();
if (last_shape_ == x_dims) {
return;
}
int iw = x_dims[3]; // nchw
int ih = x_dims[2];
int ic = x_dims[1];
int ow = o_dims[3];
int oh = o_dims[2];
int oc = o_dims[1];
int kw = w_dims[3];
int kh = w_dims[2];
int sw = param.strides[1];
int sh = param.strides[0];
int pw = param.paddings[1];
int ph = param.paddings[0];
int dw = param.dilations[1];
int dh = param.dilations[0];
int m = oc / param.groups;
int k = ic * kh * kw / param.groups;
int n = oh * ow;
bool kps_equal = (pw == ph) && (sw == sh) && (kw == kh);
bool ks_equal = (sw == sh) && (kw == kh);
//! select conv gemmlike kernel
if (kw == 1 && sw == 1 && pw == 0 && kps_equal) {
//! 1x1s1p0 gemmlike conv
flag_1x1gemm_ = true;
} else {
//! im2col gemmlike conv
flag_1x1gemm_ = false;
ctx.ExtendWorkspace(k * n * sizeof(float));
}
if (!flag_trans_weights_ && n > 1) {
lite::arm::math::trans_gemm_weights<Ptype>(
*(param.filter), weights_, param.groups, &ctx);
flag_trans_weights_ = true;
} else if (n == 1) {
flag_trans_weights_ = false;
}
last_shape_ = x_dims;
}
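  // The 1x1/s1/p0 case maps straight onto a GEMM of m x n with inner dim k,
  // i.e. (oc/g) x (oh*ow) with k = ic*kh*kw/g, and needs no im2col buffer;
  // every other shape first expands the input into a k x n column matrix,
  // which is why the workspace above is extended by k * n floats.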
virtual void PrepareForRun();
virtual void Run();
  /// TODO: support in-place weights transform
protected:
using param_t = operators::ConvParam;
DDim last_shape_;
std::vector<float> w_scale_;
bool flag_1x1gemm_{true};
bool flag_trans_weights_{false};
bool flag_trans_bias_{false};
Tensor weights_;
Tensor bias_;
};
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
...@@ -80,7 +80,7 @@ void Conv2DTransposeCompute::Run() { ...@@ -80,7 +80,7 @@ void Conv2DTransposeCompute::Run() {
int group_size_out = wout * hout * chout / group; int group_size_out = wout * hout * chout / group;
int group_size_coldata = m * n; int group_size_coldata = m * n;
auto& ctx = this->ctx_->template As<ARMContext>(); auto& ctx = this->ctx_->template As<ARMContext>();
int hblock = lite::arm::math::get_hblock(ctx.arch()); int hblock = lite::arm::math::get_hblock(&ctx);
int m_roundup = hblock * ((m + hblock - 1) / hblock); int m_roundup = hblock * ((m + hblock - 1) / hblock);
int group_size_weights = ((m_roundup * k + 15) / 16) * 16; int group_size_weights = ((m_roundup * k + 15) / 16) * 16;
bool flag_1x1s1p1 = (kw == 1) && (kh == 1) && (param.strides[0] == 1) && bool flag_1x1s1p1 = (kw == 1) && (kh == 1) && (param.strides[0] == 1) &&
......
...@@ -12,99 +12,105 @@ ...@@ -12,99 +12,105 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "lite/backends/arm/math/conv_winograd.h" #include "lite/kernels/arm/conv_winograd.h"
#include <vector> #include <vector>
#include "lite/backends/arm/math/conv_impl.h" #include "lite/backends/arm/math/conv_impl.h"
#include "lite/backends/arm/math/packed_sgemm.h" #include "lite/backends/arm/math/packed_sgemm.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
namespace kernels {
namespace arm { namespace arm {
namespace math {
template <> template <>
bool WinogradConv<PRECISION(kFloat)>::create(const operators::ConvParam& param, void WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>::ReInitWhenNeeded() {
ARMContext* ctx) { auto& param = this->Param<param_t>();
this->ctx_ = ctx; auto& ctx = this->ctx_->template As<ARMContext>();
auto x_dims = param.x->dims(); auto x_dims = param.x->dims();
auto w_dims = param.filter->dims(); auto w_dims = param.filter->dims();
auto o_dims = param.output->dims(); auto o_dims = param.output->dims();
int iw = x_dims[3]; // nchw if (last_shape_ == x_dims) {
return;
}
int ic = x_dims[1]; int ic = x_dims[1];
int ow = o_dims[3]; int ow = o_dims[3];
int oh = o_dims[2]; int oh = o_dims[2];
int oc = o_dims[1]; int oc = o_dims[1];
int kw = w_dims[3]; int tile_w = (ow + 5) / 6;
int sw = param.strides[1]; int tile_h = (oh + 5) / 6;
if (kw == 3) { int size_tile = tile_h * tile_w;
is_weights_transed_ = true; int size_trans_channel = 8 * 8 * size_tile;
int tile_w = (ow + 5) / 6; int max_ch = ic > oc ? ic : oc;
int tile_h = (oh + 5) / 6;
int size_tile = tile_h * tile_w; const int n_wino = size_tile;
int size_trans_channel = 8 * 8 * size_tile; ctx.ExtendWorkspace((size_trans_channel * max_ch * 2 + n_wino) *
int max_ch = ic > oc ? ic : oc; sizeof(float));
last_shape_ = x_dims;
const int m_wino = oc;
const int n_wino = size_tile;
int hblock = get_hblock(this->ctx_->arch());
int m_round = hblock * ((m_wino + hblock - 1) / hblock);
weights_trans_.Resize({1, 1, 1, 8 * 8 * m_round * ic});
this->ctx_->ExtendWorkspace((size_trans_channel * max_ch * 2 + n_wino) *
sizeof(float));
auto weights_wino =
static_cast<float*>(malloc(sizeof(float) * 8 * 8 * oc * ic));
void* trans_tmp_ptr = malloc(sizeof(float) * 8 * 8 * oc * ic);
if (weights_wino && trans_tmp_ptr) {
winograd_transform_weights(
weights_wino, param.filter->data<float>(), oc, ic, trans_tmp_ptr);
auto weights_trans = weights_trans_.mutable_data<float>();
for (int i = 0; i < 64; ++i) {
float* packed_weights = weights_trans + i * m_round * ic;
const float* weights_wino_ptr = weights_wino + i * oc * ic;
prepackA(packed_weights,
weights_wino_ptr,
1.f,
ic,
0,
m_wino,
0,
ic,
false,
this->ctx_);
}
impl_ = conv_winograd3x3;
free(trans_tmp_ptr);
free(weights_wino);
return true;
}
free(trans_tmp_ptr);
free(weights_wino);
} else {
LOG(ERROR) << "this type winograd conv not impl";
}
return false;
} }
template <> template <>
bool WinogradConv<PRECISION(kFloat)>::init(const operators::ConvParam& param, void WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
Context<TARGET(kARM)>* ctx) { auto& param = this->Param<param_t>();
this->ctx_ = ctx; auto& ctx = this->ctx_->template As<ARMContext>();
return create(param, ctx);
auto x_dims = param.x->dims();
auto w_dims = param.filter->dims();
auto o_dims = param.output->dims();
last_shape_ = x_dims;
int ic = x_dims[1];
int ow = o_dims[3];
int oh = o_dims[2];
int oc = o_dims[1];
int tile_w = (ow + 5) / 6;
int tile_h = (oh + 5) / 6;
int size_tile = tile_h * tile_w;
int size_trans_channel = 8 * 8 * size_tile;
int max_ch = ic > oc ? ic : oc;
const int m_wino = oc;
const int n_wino = size_tile;
int hblock = lite::arm::math::get_hblock(&ctx);
int m_round = hblock * ((m_wino + hblock - 1) / hblock);
weights_.Resize({1, 1, 1, 8 * 8 * m_round * ic});
ctx.ExtendWorkspace((size_trans_channel * max_ch * 2 + n_wino) *
sizeof(float));
auto weights_wino =
static_cast<float*>(malloc(sizeof(float) * 8 * 8 * oc * ic));
void* trans_tmp_ptr = malloc(sizeof(float) * 8 * 8 * oc * ic);
lite::arm::math::winograd_transform_weights(
weights_wino, param.filter->data<float>(), oc, ic, trans_tmp_ptr);
auto weights_trans = weights_.mutable_data<float>();
for (int i = 0; i < 64; ++i) {
float* packed_weights = weights_trans + i * m_round * ic;
const float* weights_wino_ptr = weights_wino + i * oc * ic;
lite::arm::math::prepackA(packed_weights,
weights_wino_ptr,
1.f,
ic,
0,
m_wino,
0,
ic,
false,
&ctx);
}
free(trans_tmp_ptr);
free(weights_wino);
} }
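// Tiling math for the F(6x6, 3x3) Winograd transform used here: each tile
// yields a 6x6 block of outputs from an 8x8 input patch (6 + 3 - 1 = 8),
// hence tile_w = (ow + 5) / 6 (a ceiling division) and the 8 * 8 = 64
// transformed-coefficient planes packed in the loop above. E.g. ow = 20
// gives tile_w = 25 / 6 = 4 tiles covering 24 output columns.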
template <> template <>
bool WinogradConv<PRECISION(kFloat)>::run(const operators::ConvParam& param) { void WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
// start timer auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<ARMContext>();
const auto* i_data = param.x->data<float>(); const auto* i_data = param.x->data<float>();
const auto* w_data = param.filter->data<float>(); const auto* w_data = weights_.data<float>();
const auto* b_data = param.bias ? param.bias->data<float>() : nullptr; const auto* b_data = param.bias ? param.bias->data<float>() : nullptr;
auto* o_data = param.output->mutable_data<float>(); auto* o_data = param.output->mutable_data<float>();
if (is_weights_transed_) {
w_data = weights_trans_.data<float>();
}
auto x_dims = param.x->dims(); auto x_dims = param.x->dims();
auto w_dims = param.filter->dims(); auto w_dims = param.filter->dims();
auto o_dims = param.output->dims(); auto o_dims = param.output->dims();
...@@ -117,25 +123,11 @@ bool WinogradConv<PRECISION(kFloat)>::run(const operators::ConvParam& param) { ...@@ -117,25 +123,11 @@ bool WinogradConv<PRECISION(kFloat)>::run(const operators::ConvParam& param) {
int ow = o_dims[3]; int ow = o_dims[3];
int oc = o_dims[1]; int oc = o_dims[1];
impl_(i_data, lite::arm::math::conv_winograd3x3(
o_data, i_data, o_data, bs, oc, oh, ow, ic, ih, iw, w_data, b_data, param, &ctx);
bs,
oc,
oh,
ow,
ic,
ih,
iw,
w_data,
b_data,
param,
this->ctx_);
// timer end
return true;
} }
} // namespace math
} // namespace arm } // namespace arm
} // namespace kernels
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
...@@ -17,49 +17,31 @@ ...@@ -17,49 +17,31 @@
#include <cmath> #include <cmath>
#include "lite/backends/arm/math/conv_impl.h" #include "lite/backends/arm/math/conv_impl.h"
#include "lite/core/context.h" #include "lite/core/context.h"
#include "lite/core/kernel.h"
#include "lite/core/target_wrapper.h" #include "lite/core/target_wrapper.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
namespace kernels {
namespace arm { namespace arm {
namespace math {
template <PrecisionType Ptype> /// only support 3x3s1 and 3x3s2
class WinogradConv template <PrecisionType Ptype, PrecisionType OutType>
: public ImplBase<TARGET(kARM), Ptype, operators::ConvParam> { class WinogradConv : public KernelLite<TARGET(kARM), Ptype> {
public: public:
typedef void (*conv_winograd_impl)(const float* din,
float* dout,
int num,
int chout,
int hout,
int wout,
int chin,
int hin,
int win,
const float* weights,
const float* bias,
const operators::ConvParam& param,
Context<TARGET(kARM)>* ctx);
WinogradConv() = default; WinogradConv() = default;
~WinogradConv() {} ~WinogradConv() {}
virtual void PrepareForRun();
virtual bool init(const operators::ConvParam& param, virtual void ReInitWhenNeeded();
Context<TARGET(kARM)>* ctx); virtual void Run();
virtual bool create(const operators::ConvParam& param, protected:
Context<TARGET(kARM)>* ctx); using param_t = operators::ConvParam;
Tensor weights_;
virtual bool run(const operators::ConvParam& param); DDim last_shape_;
private:
conv_winograd_impl impl_{nullptr};
bool is_weights_transed_{false};
Tensor weights_trans_;
}; };
} // namespace math
} // namespace arm } // namespace arm
} // namespace kernels
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
...@@ -26,50 +26,75 @@ namespace lite { ...@@ -26,50 +26,75 @@ namespace lite {
namespace kernels { namespace kernels {
namespace arm { namespace arm {
void FcCompute::PrepareForRun() { /// for fp32 kernel
auto& param = this->Param<operators::FcParam>(); template <>
auto x_dims = param.input->dims(); void FcCompute<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
auto w_dims = param.w->dims(); ReInitWhenNeeded();
}
auto& ctx = this->ctx_->template As<ARMContext>();
CHECK_GE(x_dims.size(), 2UL);
CHECK_EQ(w_dims.size(), 2UL);
CHECK_EQ(param.output->dims().size(), 2UL);
m_ = x_dims.Slice(0, param.in_num_col_dims).production();
k_ = x_dims.Slice(param.in_num_col_dims, x_dims.size()).production();
CHECK_EQ(k_, w_dims[0]);
n_ = w_dims[1];
CHECK_EQ(k_, static_cast<int>(w_dims[0]));
if (m_ == 1) { /// for int8 kernel with fp32 output
if (!transed_weight_) { template <>
transed_weight_ = new Tensor; void FcCompute<PRECISION(kInt8), PRECISION(kFloat)>::PrepareForRun() {
ReInitWhenNeeded();
auto& param = this->template Param<operators::FcParam>();
/// update scale
float input_scale = param.input_scale;
int extend_size = flag_gemm_ ? m_ : n_;
scale_.resize(extend_size);
for (int i = 0; i < extend_size; ++i) {
if (flag_gemm_) {
scale_[i] = param.weight_scale[0] * input_scale;
} else {
scale_[i] = param.weight_scale[i] * input_scale;
} }
transed_weight_->Resize({n_, k_}); }
const auto* w_data = param.w->data<float>(); }
auto* t_data = transed_weight_->mutable_data<float>();
int i = 0;
for (int nn = 0; nn < n_; ++nn) { /// for int8 kernel with int8 output
for (int kk = 0; kk < k_; ++kk) { template <>
t_data[i++] = w_data[kk * n_ + nn]; void FcCompute<PRECISION(kInt8), PRECISION(kInt8)>::PrepareForRun() {
} ReInitWhenNeeded();
auto& param = this->template Param<operators::FcParam>();
/// update scale
scale_ = param.weight_scale;
float input_scale = param.input_scale;
float output_scale = param.output_scale;
int extend_size = flag_gemm_ ? m_ : n_;
scale_.resize(extend_size);
for (int i = 0; i < extend_size; ++i) {
if (flag_gemm_) {
scale_[i] = param.weight_scale[0] * input_scale / output_scale;
} else {
scale_[i] = param.weight_scale[i] * input_scale / output_scale;
} }
} }
/// update bias
if (param.bias) {
bias_.Resize(param.bias->dims());
auto ptr = bias_.mutable_data<float>();
    auto ptr_in = param.bias->data<float>();  // read the original fp32 bias, not the freshly resized bias_
float out_scale = param.output_scale;
for (int i = 0; i < bias_.numel(); ++i) {
ptr[i] = ptr_in[i] / out_scale;
}
flag_trans_bias_ = true;
}
} }
void FcCompute::Run() { template <>
void FcCompute<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
auto& param = this->Param<operators::FcParam>(); auto& param = this->Param<operators::FcParam>();
const auto* i_data = param.input->data<float>();
const auto* w_data = param.w->data<float>();
const auto* b_data = param.bias ? param.bias->data<float>() : nullptr;
auto* o_data = param.output->mutable_data<float>();
auto& ctx = this->ctx_->template As<ARMContext>(); auto& ctx = this->ctx_->template As<ARMContext>();
if (m_ > 1) {
auto i_data = param.input->data<float>();
auto o_data = param.output->mutable_data<float>();
auto w_data =
flag_trans_weights_ ? weights_.data<float>() : param.w->data<float>();
const float* b_data = param.bias ? param.bias->data<float>() : nullptr;
if (flag_trans_bias_) {
b_data = bias_.data<float>();
}
if (flag_gemm_) {
lite::arm::math::sgemm(false, lite::arm::math::sgemm(false,
false, false,
m_, m_,
...@@ -83,7 +108,7 @@ void FcCompute::Run() { ...@@ -83,7 +108,7 @@ void FcCompute::Run() {
0.f, 0.f,
o_data, o_data,
n_, n_,
b_data, nullptr,
false, false,
false, false,
&ctx); &ctx);
...@@ -92,134 +117,117 @@ void FcCompute::Run() { ...@@ -92,134 +117,117 @@ void FcCompute::Run() {
lite::arm::math::fill_bias_fc(o_data, b_data, m_, n_); lite::arm::math::fill_bias_fc(o_data, b_data, m_, n_);
} }
} else { } else {
CHECK(transed_weight_); for (int i = 0; i < m_; ++i) {
const auto* t_data = transed_weight_->data<float>(); auto i_data_batch = i_data + i * k_;
auto o_data_batch = o_data + i * n_;
lite::arm::math::sgemv(t_data, lite::arm::math::sgemv(w_data,
i_data, i_data_batch,
o_data, o_data_batch,
false, false,
n_, n_,
k_, k_,
b_data != nullptr, param.bias != nullptr,
b_data, b_data,
false); false);
}
} }
} }
template <PrecisionType Ptype_out> template <>
void FcComputeInt8<Ptype_out>::PrepareForRun() { void FcCompute<PRECISION(kInt8), PRECISION(kFloat)>::Run() {
auto& param = this->Param<operators::FcParam>(); auto& param = this->Param<operators::FcParam>();
auto x_dims = param.input->dims();
auto w_dims = param.w->dims();
auto& ctx = this->ctx_->template As<ARMContext>(); auto& ctx = this->ctx_->template As<ARMContext>();
if (!tmp_int32_out_) {
tmp_int32_out_ = new Tensor;
tmp_int32_out_->Resize(param.output->dims());
}
CHECK_GE(x_dims.size(), 2UL);
CHECK_EQ(w_dims.size(), 2UL);
CHECK_EQ(param.output->dims().size(), 2UL);
this->m_ = x_dims.Slice(0, param.in_num_col_dims).production();
this->k_ = x_dims.Slice(param.in_num_col_dims, x_dims.size()).production();
this->n_ = w_dims[1];
CHECK_EQ(k_, static_cast<int>(w_dims[0]));
if (this->m_ == 1) { auto i_data = param.input->data<int8_t>();
if (!this->transed_weight_) { auto o_data = param.output->mutable_data<float>();
this->transed_weight_ = new Tensor; auto w_data =
flag_trans_weights_ ? weights_.data<int8_t>() : param.w->data<int8_t>();
const float* b_data = param.bias ? param.bias->data<float>() : nullptr;
if (flag_trans_bias_) {
b_data = bias_.data<float>();
}
if (flag_gemm_) {
lite::arm::math::gemm_s8(false,
false,
m_,
n_,
k_,
i_data,
w_data,
o_data,
nullptr,
false,
false,
scale_.data(),
&ctx);
if (param.bias) {
CHECK_EQ(param.bias->numel(), n_);
lite::arm::math::fill_bias_fc(o_data, b_data, m_, n_);
} }
this->transed_weight_->Resize({this->n_, this->k_}); } else {
const auto* w_data = param.w->template data<int8_t>(); for (int i = 0; i < m_; ++i) {
auto* t_data = this->transed_weight_->template mutable_data<int8_t>(); auto i_data_batch = i_data + i * k_;
int i = 0; auto o_data_batch = o_data + i * n_;
lite::arm::math::gemv_int8(w_data,
for (int nn = 0; nn < this->n_; ++nn) { i_data_batch,
for (int kk = 0; kk < this->k_; ++kk) { o_data_batch,
t_data[i++] = w_data[kk * this->n_ + nn]; false,
} n_,
k_,
scale_.data(),
param.bias != nullptr,
b_data,
false,
&ctx);
} }
} }
if (this->m_ > 1) {
int hblock = lite::arm::math::get_hblock(ctx.arch());
int m_round = hblock * ((this->m_ + hblock - 1) / hblock);
ctx.ExtendWorkspace(m_round * this->k_);
}
bool with_bias = param.bias;
if (with_bias) {
Tensor temp_tensor;
temp_tensor.CopyDataFrom(*param.bias);
lite::arm::math::trans_fp32_bias_to_int32_basic(
&temp_tensor, param.bias, param.input_scale, param.weight_scale);
}
} }
template <PrecisionType Ptype_out> template <>
void FcComputeInt8<Ptype_out>::Run() { void FcCompute<PRECISION(kInt8), PRECISION(kInt8)>::Run() {
auto& param = this->Param<operators::FcParam>(); auto& param = this->Param<operators::FcParam>();
const auto* i_data = param.input->template data<int8_t>();
const auto* w_data = param.w->template data<int8_t>();
const auto* b_data = param.bias ? param.bias->template data<int>() : nullptr;
int* o_data = nullptr;
auto& ctx = this->ctx_->template As<ARMContext>(); auto& ctx = this->ctx_->template As<ARMContext>();
o_data = this->tmp_int32_out_->template mutable_data<int>(); auto i_data = param.input->data<int8_t>();
if (m_ > 1) { auto o_data = param.output->mutable_data<int8_t>();
int8_t* packed_in = auto w_data =
static_cast<int8_t*>(ctx.template workspace_data<int8_t>()) + flag_trans_weights_ ? weights_.data<int8_t>() : param.w->data<int8_t>();
ctx.llc_size() / sizeof(int8_t); const float* b_data = param.bias ? param.bias->data<float>() : nullptr;
lite::arm::math::prepackA_int8( if (flag_trans_bias_) {
packed_in, i_data, k_, 0, m_, 0, k_, false, &ctx); b_data = bias_.data<float>();
lite::arm::math::gemm_prepack_int8(packed_in,
w_data,
b_data,
o_data,
m_,
n_,
k_,
false,
false,
false,
nullptr,
&ctx);
if (param.bias) {
CHECK_EQ(param.bias->numel(), n_);
lite::arm::math::fill_bias_fc(o_data, b_data, m_, n_);
}
} else {
CHECK(transed_weight_);
const auto* t_data = transed_weight_->template data<int8_t>();
lite::arm::math::gemv_int8(t_data,
i_data,
o_data,
false,
n_,
k_,
nullptr,
b_data != nullptr,
b_data,
false);
} }
if (flag_gemm_) {
float i_scale = param.input_scale; CHECK(!param.bias) << "fc int8 kernel with int8 output using gemm kernel "
std::vector<float> weight_scale = param.weight_scale; "must not have bias";
if (Ptype_out == PRECISION(kInt8)) { lite::arm::math::gemm_s8(false,
float o_scale = param.output_scale; false,
param.output->template mutable_data<int8_t>(); m_,
lite::arm::math::trans_tensor_dtype<PRECISION(kInt32), PRECISION(kInt8)>( n_,
tmp_int32_out_, param.output, i_scale, o_scale, weight_scale); k_,
} else if (Ptype_out == PRECISION(kFloat)) { i_data,
param.output->template mutable_data<float>(); w_data,
lite::arm::math::trans_tensor_dtype<PRECISION(kInt32), PRECISION(kFloat)>( o_data,
tmp_int32_out_, param.output, i_scale, 1.f, weight_scale); nullptr,
false,
false,
scale_.data(),
&ctx);
} else { } else {
LOG(ERROR) << "unsupported precision type!!"; for (int i = 0; i < m_; ++i) {
auto i_data_batch = i_data + i * k_;
auto o_data_batch = o_data + i * n_;
lite::arm::math::gemv_int8(w_data,
i_data_batch,
o_data_batch,
false,
n_,
k_,
scale_.data(),
param.bias != nullptr,
b_data,
false,
&ctx);
}
} }
} }
...@@ -228,36 +236,33 @@ void FcComputeInt8<Ptype_out>::Run() { ...@@ -228,36 +236,33 @@ void FcComputeInt8<Ptype_out>::Run() {
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_LITE_KERNEL( typedef paddle::lite::kernels::arm::FcCompute<PRECISION(kFloat),
fc, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::FcCompute, def) PRECISION(kFloat)>
FcCompute_FP32;
typedef paddle::lite::kernels::arm::FcCompute<PRECISION(kInt8),
PRECISION(kFloat)>
FcCompute_int8_fp32;
typedef paddle::lite::kernels::arm::FcCompute<PRECISION(kInt8),
PRECISION(kInt8)>
FcCompute_int8_int8;
REGISTER_LITE_KERNEL(fc, kARM, kFloat, kNCHW, FcCompute_FP32, def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("W", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("W", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize(); .Finalize();
REGISTER_LITE_KERNEL( REGISTER_LITE_KERNEL(fc, kARM, kInt8, kNCHW, FcCompute_int8_int8, int8out)
fc,
kARM,
kInt8,
kNCHW,
paddle::lite::kernels::arm::FcComputeInt8<PRECISION(kInt8)>,
int8out)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))}) .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))}) .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
.BindInput("W", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))}) .BindInput("W", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
.Finalize(); .Finalize();
REGISTER_LITE_KERNEL( REGISTER_LITE_KERNEL(fc, kARM, kInt8, kNCHW, FcCompute_int8_fp32, fp32out)
fc,
kARM,
kInt8,
kNCHW,
paddle::lite::kernels::arm::FcComputeInt8<PRECISION(kFloat)>,
fp32out)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))}) .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))}) .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
.BindInput("W", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))}) .BindInput("W", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
.Finalize(); .Finalize();
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
#pragma once #pragma once
#include <stdint.h> #include <stdint.h>
#include <vector>
#include "lite/backends/arm/math/funcs.h"
#include "lite/backends/arm/math/type_trans.h" #include "lite/backends/arm/math/type_trans.h"
#include "lite/core/kernel.h" #include "lite/core/kernel.h"
...@@ -22,44 +24,108 @@ namespace lite { ...@@ -22,44 +24,108 @@ namespace lite {
namespace kernels { namespace kernels {
namespace arm { namespace arm {
class FcCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> { template <typename Dtype>
public: void naive_transpose(const Dtype* din, Dtype* dout, int m, int n) {
using param_t = operators::FcParam; int k = 0;
for (int i = 0; i < n; ++i) {
for (int j = 0; j < m; ++j) {
dout[k++] = din[j * n + i];
}
}
}
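// Usage sketch: transposing a 2x3 row-major matrix into 3x2, as done below
// to turn the k x n FC weight into the n x k layout the gemv paths expect:
//   float in[6] = {1, 2, 3,
//                  4, 5, 6};
//   float out[6];
//   naive_transpose(in, out, /*m=*/2, /*n=*/3);
//   // out == {1, 4, 2, 5, 3, 6}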
void PrepareForRun() override; template <PrecisionType PType>
void fc_trans_weights(const Tensor& tin, Tensor* tout);
void Run() override; template <>
void fc_trans_weights<PRECISION(kFloat)>(const Tensor& tin, Tensor* tout) {
  CHECK_EQ(tin.dims().size(), 2) << "fc weights must be 2-D";
int m = tin.dims()[0];
int n = tin.dims()[1];
tout->Resize({n, m});
auto ptr_in = tin.data<float>();
auto ptr_out = tout->mutable_data<float>();
naive_transpose(ptr_in, ptr_out, m, n);
}
~FcCompute() override { template <>
if (transed_weight_) { void fc_trans_weights<PRECISION(kInt8)>(const Tensor& tin, Tensor* tout) {
delete transed_weight_; CHECK_EQ(tin.dims().size(), 2) << "fc weights size must = 2";
} int m = tin.dims()[0];
}; int n = tin.dims()[1];
tout->Resize({n, m});
auto ptr_in = tin.data<int8_t>();
auto ptr_out = tout->mutable_data<int8_t>();
naive_transpose(ptr_in, ptr_out, m, n);
}
private: template <PrecisionType PType, PrecisionType OutType>
lite::Tensor* transed_weight_{nullptr}; bool check_fc_use_gemm(int m, const std::vector<float>& scale, bool has_bias) {
int m_, n_, k_; return m > 1;
}; }
template <PrecisionType Ptype_out> template <>
class FcComputeInt8 : public KernelLite<TARGET(kARM), PRECISION(kInt8)> { bool check_fc_use_gemm<PRECISION(kInt8), PRECISION(kFloat)>(
int m, const std::vector<float>& scale, bool has_bias) {
  CHECK(scale.size() > 0) << "Int8 FC param must have weight_scale";
return m > 1 && scale.size() == 1;
}
template <>
bool check_fc_use_gemm<PRECISION(kInt8), PRECISION(kInt8)>(
int m, const std::vector<float>& scale, bool has_bias) {
  CHECK(scale.size() > 0) << "Int8 FC param must have weight_scale";
return m > 1 && scale.size() == 1 && !has_bias;
}
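// A plausible reading of this dispatch (it matches the Run() bodies in
// fc_compute.cc): the per-batch gemv_int8 loop can apply per-output-channel
// scales and a float bias row by row, while the gemm path is only taken with
// a single broadcast weight scale -- and, for int8 output, only without
// bias, since the fp32 bias cannot be added after the in-gemm requantization
// (see the CHECK in FcCompute<kInt8, kInt8>::Run).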
template <PrecisionType PType, PrecisionType OutType>
class FcCompute : public KernelLite<TARGET(kARM), PType> {
public: public:
using param_t = operators::FcParam; using param_t = operators::FcParam;
void PrepareForRun() override; virtual void ReInitWhenNeeded() {
auto& param = this->template Param<operators::FcParam>();
auto x_dims = param.input->dims();
if (last_shape_ == x_dims) {
return;
}
last_shape_ = x_dims;
auto w_dims = param.w->dims();
auto& ctx = this->ctx_->template As<ARMContext>();
void Run() override; CHECK_GE(x_dims.size(), 2UL);
CHECK_EQ(w_dims.size(), 2UL);
CHECK_EQ(param.output->dims().size(), 2UL);
~FcComputeInt8() override { m_ = x_dims.Slice(0, param.in_num_col_dims).production();
if (transed_weight_) { k_ = x_dims.Slice(param.in_num_col_dims, x_dims.size()).production();
delete transed_weight_; CHECK_EQ(k_, w_dims[0]);
n_ = w_dims[1];
CHECK_EQ(k_, static_cast<int>(w_dims[0]));
flag_gemm_ = check_fc_use_gemm<PType, OutType>(
m_, param.weight_scale, param.bias != nullptr);
if (!flag_trans_weights_ && !flag_gemm_) {
flag_trans_weights_ = true;
fc_trans_weights<PType>(*param.w, &weights_);
} }
}; }
virtual void PrepareForRun();
virtual void Run();
~FcCompute() = default;
private: private:
lite::Tensor* transed_weight_{nullptr}; DDim last_shape_;
Tensor* tmp_int32_out_{nullptr}; Tensor weights_;
int m_, n_, k_; Tensor bias_;
bool flag_trans_weights_{false};
bool flag_trans_bias_{false};
bool flag_gemm_{true};
int m_;
int n_;
int k_;
std::vector<float> scale_;
}; };
} // namespace arm } // namespace arm
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/fc_compute.h"
#include <gtest/gtest.h>
#include <algorithm>
#include <iostream>
#include <memory>
#include <random>
#include <utility>
#include <vector>
#include "lite/backends/arm/math/funcs.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
#define A(i, j) a[i * lda + j]
#define B(i, j) b[i * ldb + j]
#define C(i, j) c[i * ldc + j]
template <typename T>
void gemm_bias(const T* a,
const int M,
const int K,
const T* b,
const int K_,
const int N,
T* biases,
T* c) {
EXPECT_TRUE(K_ == K && M > 0 && N > 0 && K > 0);
EXPECT_TRUE(a && b && c);
const int lda = K;
const int ldb = N;
const int ldc = N;
for (int m = 0; m < M; ++m) {
for (int n = 0; n < N; ++n) {
C(m, n) = 0.0f;
for (int k = 0; k < K; ++k) {
C(m, n) += A(m, k) * B(k, n);
}
}
}
if (biases) {
for (int m = 0; m < M; ++m) {
for (int n = 0; n < N; ++n) {
C(m, n) += biases[n];
}
}
}
}
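// Reference semantics of the helper above: C(i, j) = sum over k of
// A(i, k) * B(k, j), with the length-N bias row broadcast across all M rows.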
template <typename T>
void FillData(T* a,
const int n,
const T lower = static_cast<T>(-2.f),
const T upper = static_cast<T>(2.f)) {
static unsigned int seed = 100;
std::mt19937 rng(seed++);
std::uniform_real_distribution<double> uniform_dist(0, 1);
for (int i = 0; i < n; ++i) {
a[i] = static_cast<T>(uniform_dist(rng) * (upper - lower) + lower);
}
}
TEST(fc_arm, retrieve_op) {

auto fc =
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>("fc");
ASSERT_FALSE(fc.empty());
ASSERT_TRUE(fc.front());
}
TEST(fc_arm, init) {
FcCompute<PRECISION(kFloat), PRECISION(kFloat)> fc;
ASSERT_EQ(fc.precision(), PRECISION(kFloat));
ASSERT_EQ(fc.target(), TARGET(kARM));
}
TEST(fc_arm, compare_test) {
using T = float;
for (int m : {1, 2, 3, 4}) {
for (int n : {1, 2, 3, 4}) {
for (int k : {1, 2, 3, 4}) {
for (bool with_bias : {true, false}) {
VLOG(3) << "m: " << m << ", n: " << n << ", k: " << k
<< (with_bias ? ", with bias" : "");
lite::Tensor x, w, b, out, ref;
x.Resize({m, k});
w.Resize({k, n});
b.Resize({1, n});
out.Resize({m, n});
ref.Resize({m, n});
auto* x_data = x.mutable_data<T>();
auto* w_data = w.mutable_data<T>();
auto* b_data = with_bias ? b.mutable_data<T>() : nullptr;
auto* out_data = out.mutable_data<T>();
auto* ref_data = ref.mutable_data<T>();
FillData<T>(x_data, x.dims().production());
FillData<T>(w_data, w.dims().production());
FillData<T>(out_data, out.dims().production(), 0, 0);
FillData<T>(ref_data, ref.dims().production(), 0, 0);
if (with_bias) {
FillData<T>(b_data, b.dims().production());
}
FcCompute<PRECISION(kFloat), PRECISION(kFloat)> fc;
operators::FcParam param;
param.input = &x;
param.w = &w;
param.bias = with_bias ? &b : nullptr;
param.output = &out;
param.in_num_col_dims = 1;
param.in_mat_dims = x.dims();
DeviceInfo::Init();
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<ARMContext>();
fc.SetParam(param);
fc.SetContext(std::move(ctx));
fc.PrepareForRun();
fc.Run();
gemm_bias<T>(x_data, m, k, w_data, k, n, b_data, ref_data);
for (int i = 0; i < out.dims().production(); i++) {
EXPECT_NEAR(out_data[i], ref_data[i], 1e-3);
}
}
}
}
}
}
TEST(fc_arm, num_col_dims) {
using T = float;
for (bool with_bias : {true, false}) {
lite::Tensor x, w, b, out, ref;
x.Resize({1, 2, 3});
w.Resize({3, 4});
b.Resize({1, 4});
out.Resize({2, 4});
ref.Resize({2, 4});
auto* x_data = x.mutable_data<float>();
auto* w_data = w.mutable_data<float>();
auto* b_data = with_bias ? b.mutable_data<T>() : nullptr;
auto* out_data = out.mutable_data<T>();
auto* ref_data = ref.mutable_data<T>();
FillData<T>(x_data, x.dims().production());
FillData<T>(w_data, w.dims().production());
FillData<T>(out_data, out.dims().production(), 0, 0);
FillData<T>(ref_data, ref.dims().production(), 0, 0);
if (with_bias) {
FillData<T>(b_data, b.dims().production());
}
FcCompute<PRECISION(kFloat), PRECISION(kFloat)> fc;
operators::FcParam param;
param.input = &x;
param.w = &w;
param.bias = with_bias ? &b : nullptr;
param.output = &out;
param.in_num_col_dims = 2;
param.in_mat_dims = x.dims();
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<ARMContext>();
DeviceInfo::Init();
fc.SetParam(param);
fc.SetContext(std::move(ctx));
fc.PrepareForRun();
fc.Run();
gemm_bias<T>(x_data, 2, 3, w_data, 3, 4, b_data, ref_data);
for (int i = 0; i < out.dims().production(); i++) {
EXPECT_NEAR(out_data[i], ref_data[i], 1e-3);
}
}
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(fc, kARM, kFloat, kNCHW, def);
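The num_col_dims case above relies on FC flattening its N-D input: m is the product of the leading in_num_col_dims axes and k the product of the remaining ones, so the {1, 2, 3} input with in_num_col_dims = 2 becomes a 2x3 matrix multiplied by the 3x4 weight. A minimal standalone sketch of that flattening (for illustration; mirrors the m_/k_ Slice(...).production() logic in the kernel header):

#include <cstdio>
#include <functional>
#include <numeric>
#include <vector>

int main() {
  std::vector<int64_t> dims{1, 2, 3};
  int col = 2;  // in_num_col_dims
  int64_t m = std::accumulate(dims.begin(), dims.begin() + col,
                              int64_t{1}, std::multiplies<int64_t>());
  int64_t k = std::accumulate(dims.begin() + col, dims.end(),
                              int64_t{1}, std::multiplies<int64_t>());
  std::printf("m=%lld k=%lld\n", (long long)m, (long long)k);  // m=2 k=3
  return 0;
}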
@@ -56,7 +56,7 @@ void MulCompute::Run() {
   } else {
     constexpr bool is_tranposed_y = false;
     auto& ctx = this->ctx_->template As<ARMContext>();
-    int hblock = lite::arm::math::get_hblock(ctx.arch());
+    int hblock = lite::arm::math::get_hblock(&ctx);
     int m_round = hblock * ((m_ + hblock - 1) / hblock);
     ctx.ExtendWorkspace(m_round * k_ * sizeof(float));
......
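The m_round line above is the usual round-up-to-a-multiple idiom, m_round = hblock * ceil(m_ / hblock): with hblock = 8 (for example) and m_ = 37, (37 + 8 - 1) / 8 = 5 panels, so m_round = 40 rows are reserved and the m_round * k_ float workspace holds the packed A panel with no partial-panel tail.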
@@ -82,7 +82,6 @@ struct FcParam {
   lite::Tensor* output{nullptr};
   lite::DDim in_mat_dims;
   int in_num_col_dims{1};
-  bool weight_transposed{false};
   // for int8
   WITH_INT8_CONFIG
 };
......
 add_subdirectory(kernels)
+add_subdirectory(math)
@@ -31,8 +31,6 @@ if(LITE_BUILD_EXTRA)
   lite_cc_test(test_kernel_sequence_pool_compute SRCS sequence_pool_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
   lite_cc_test(test_kernel_reduce_max_compute SRCS reduce_max_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
 endif()
-lite_cc_test(test_sgemm SRCS test_sgemm.cc DEPS ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
 lite_cc_test(test_kernel_pad2d_compute SRCS pad2d_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
 lite_cc_test(test_kernel_prior_box_compute SRCS prior_box_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
 lite_cc_test(test_kernel_negative_compute SRCS negative_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
......
@@ -16,8 +16,8 @@
 #include "lite/api/paddle_use_kernels.h"
 #include "lite/api/paddle_use_ops.h"
 #include "lite/core/arena/framework.h"
-#include "lite/tests/kernels/fill_data.h"
-#include "lite/tests/kernels/test_funcs.h"
+#include "lite/tests/utils/fill_data.h"
+#include "lite/tests/utils/naive_math_impl.h"
 namespace paddle {
 namespace lite {
......
@@ -14,8 +14,8 @@
 #include "lite/api/paddle_use_kernels.h"
 #include "lite/api/paddle_use_ops.h"
 #include "lite/core/arena/framework.h"
-#include "lite/tests/kernels/fill_data.h"
-#include "lite/tests/kernels/test_funcs.h"
+#include "lite/tests/utils/fill_data.h"
+#include "lite/tests/utils/naive_math_impl.h"
 namespace paddle {
 namespace lite {
......
if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA) AND (LITE_WITH_X86 OR LITE_WITH_ARM))
lite_cc_test(sgemm_compute_test SRCS sgemm_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(gemm_int8_compute_test SRCS gemm_int8_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(conv_compute_test SRCS conv_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(conv_int8_compute_test SRCS conv_int8_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
endif()
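Each of these test binaries is driven by the gflags declared in its source rather than positional arguments; a typical on-device invocation (binary path hypothetical) would be ./sgemm_compute_test --basic_test=true --cluster=0 --threads=4, while the default --basic_test=false runs only the fixed-size custom test configured through flags such as --M, --N and --K.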
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include "lite/core/context.h"
#include "lite/operators/op_params.h"
#include "lite/tests/utils/naive_math_impl.h"
#include "lite/tests/utils/tensor_utils.h"
#include "lite/tests/utils/timer.h"
#ifdef LITE_WITH_ARM
#include "lite/kernels/arm/conv_compute.h"
#endif // LITE_WITH_ARM
DEFINE_int32(cluster, 0, "cluster id");
DEFINE_int32(threads, 1, "threads num");
DEFINE_int32(warmup, 0, "warmup times");
DEFINE_int32(repeats, 1, "repeats times");
DEFINE_bool(basic_test, false, "do all tests");
DEFINE_bool(check_result, true, "check the result");
DEFINE_int32(batch, 1, "batch size");
DEFINE_int32(in_channel, 32, "input channel");
DEFINE_int32(in_height, 112, "input height");
DEFINE_int32(in_width, 112, "input width");
DEFINE_int32(out_channel, 32, "output channel");
DEFINE_int32(group, 1, "group");
DEFINE_int32(kernel_h, 3, "kernel height");
DEFINE_int32(kernel_w, 3, "kernel width");
DEFINE_int32(pad_h, 1, "pad height");
DEFINE_int32(pad_w, 1, "pad width");
DEFINE_int32(stride_h, 1, "stride height");
DEFINE_int32(stride_w, 1, "stride width");
DEFINE_int32(dila_h, 1, "dilation height");
DEFINE_int32(dila_w, 1, "dilation width");
DEFINE_bool(flag_relu, true, "do relu");
DEFINE_bool(flag_bias, true, "with bias");
typedef paddle::lite::DDim DDim;
typedef paddle::lite::Tensor Tensor;
typedef paddle::lite::operators::ConvParam ConvParam;
DDim compute_out_dim(const DDim& dim_in,
const paddle::lite::operators::ConvParam& param) {
DDim dim_out = dim_in;
dim_out[1] = param.filter->dims()[0];
auto kernel_h = param.filter->dims()[2];
auto kernel_w = param.filter->dims()[3];
auto h = dim_in[2];
auto w = dim_in[3];
int dila_h = param.dilations[0];
int dila_w = param.dilations[1];
int pad_h = param.paddings[0];
int pad_w = param.paddings[1];
int stride_h = param.strides[0];
int stride_w = param.strides[1];
auto kernel_exten = dila_h * (kernel_h - 1) + 1;
auto hout = (h + 2 * pad_h - kernel_exten) / stride_h + 1;
kernel_exten = dila_w * (kernel_w - 1) + 1;
auto wout = (w + 2 * pad_w - kernel_exten) / stride_w + 1;
dim_out[2] = hout;
dim_out[3] = wout;
return dim_out;
}
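// Shape rule applied above (standard convolution arithmetic): per spatial
// axis, out = (in + 2 * pad - (dila * (kernel - 1) + 1)) / stride + 1, e.g.
// in = 112, kernel = 3, pad = 1, dila = 1, stride = 2
// -> (112 + 2 - 3) / 2 + 1 = 56.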
#ifdef LITE_WITH_ARM
void test_conv_fp32(const std::vector<DDim>& input_dims,
const DDim& weight_dim,
int group,
const std::vector<int>& strides,
const std::vector<int>& pads,
const std::vector<int>& dilas,
bool flag_bias,
bool flag_relu,
const std::vector<int>& thread_num,
const std::vector<int>& cluster_id) {
#ifdef LITE_WITH_ARM
paddle::lite::DeviceInfo::Init();
#endif
ConvParam param;
param.x = new Tensor;
param.x->set_precision(PRECISION(kFloat));
param.filter = new Tensor;
param.filter->Resize(weight_dim);
param.filter->set_precision(PRECISION(kFloat));
if (flag_bias) {
param.bias = new Tensor;
param.bias->Resize({weight_dim[0]});
param.bias->set_precision(PRECISION(kFloat));
}
param.strides = strides;
param.paddings = pads;
param.dilations = dilas;
param.fuse_relu = flag_relu;
param.groups = group;
param.output = new Tensor;
param.output->set_precision(PRECISION(kFloat));
paddle::lite::fill_tensor_rand(*param.filter, -1.f, 1.f);
// paddle::lite::fill_tensor_const(*param.filter, 1.f);
if (flag_bias) {
paddle::lite::fill_tensor_rand(*param.bias, -1.f, 1.f);
// paddle::lite::fill_tensor_const(*param.bias, 1.f);
}
auto wptr = param.filter->data<float>();
auto bias_ptr = flag_bias ? param.bias->data<float>() : nullptr;
for (auto& cls : cluster_id) {
for (auto& th : thread_num) {
paddle::lite::kernels::arm::ConvCompute<PRECISION(kFloat),
PRECISION(kFloat)>
conv;
std::unique_ptr<paddle::lite::KernelContext> ctx1(
new paddle::lite::KernelContext);
auto& ctx = ctx1->As<paddle::lite::ARMContext>();
ctx.SetRunMode(static_cast<paddle::lite_api::PowerMode>(cls), th);
/// set param and context
for (auto& dim_in : input_dims) {
param.x->Resize(dim_in);
DDim out_tmp_dims = compute_out_dim(dim_in, param);
if (out_tmp_dims[2] < 1 || out_tmp_dims[3] < 1) {
continue;
}
param.output->Resize(out_tmp_dims);
break;
}
conv.SetParam(param);
conv.SetContext(std::move(ctx1));
/// prepare for run
conv.PrepareForRun();
for (auto& dim_in : input_dims) {
CHECK_EQ(weight_dim[1] * group, dim_in[1])
<< "input channel must equal to weights channel";
DDim dim_out = compute_out_dim(dim_in, param);
if (dim_out[2] < 1 || dim_out[3] < 1) {
continue;
}
param.x->Resize(dim_in);
param.output->Resize(dim_out);
paddle::lite::fill_tensor_rand(*param.x, -1.f, 1.f);
// paddle::lite::fill_tensor_const(*param.x, 1.f);
auto din = param.x->data<float>();
Tensor tout_basic;
if (FLAGS_check_result) {
tout_basic.set_precision(PRECISION(kFloat));
tout_basic.Resize(dim_out);
fill_tensor_const(tout_basic, 0.f);
auto dout_basic = tout_basic.mutable_data<float>();
conv_basic<float, float>(din,
dout_basic,
dim_in[0],
dim_out[1],
dim_out[2],
dim_out[3],
dim_in[1],
dim_in[2],
dim_in[3],
wptr,
bias_ptr,
group,
weight_dim[3],
weight_dim[2],
strides[1],
strides[0],
dilas[1],
dilas[0],
pads[1],
pads[0],
flag_bias,
flag_relu);
}
/// warm up
for (int i = 0; i < FLAGS_warmup; ++i) {
conv.Launch();
}
/// compute
lite::test::Timer t0;
for (int i = 0; i < FLAGS_repeats; ++i) {
t0.start();
conv.Launch();
t0.end();
}
double gops = 2.0 * dim_out.production() * dim_in[1] * weight_dim[2] *
weight_dim[3] / param.groups;
LOG(INFO) << "conv fp32: input shape: " << dim_in << ", output shape"
<< dim_out << ",running time, avg: " << t0.get_average_ms()
<< ", min time: " << t0.get_min_time()
<< ", total GOPS: " << 1e-9 * gops
<< " GOPS, avg GOPs: " << 1e-6 * gops / t0.get_average_ms()
<< " GOPs, max GOPs: " << 1e-6 * gops / t0.get_min_time();
if (FLAGS_check_result) {
double max_ratio = 0;
double max_diff = 0;
tensor_cmp_host(tout_basic, *param.output, max_ratio, max_diff);
LOG(INFO) << "compare result, max diff: " << max_diff
<< ", max ratio: " << max_ratio;
if (std::abs(max_ratio) > 1e-3f) {
if (max_diff > 5e-4f) {
LOG(WARNING) << "basic result";
print_tensor(tout_basic);
LOG(WARNING) << "saber result";
print_tensor(*param.output);
Tensor tdiff;
tdiff.Resize(tout_basic.dims());
tdiff.set_precision(PRECISION(kFloat));
tensor_diff(tout_basic, *param.output, tdiff);
print_tensor(tdiff);
LOG(FATAL) << "test fp32 conv: input: " << dim_in
<< ", output: " << dim_out
<< ", weight dim: " << weight_dim
<< ", pad: " << pads[0] << ", " << pads[1]
<< ", stride: " << strides[0] << ", " << strides[1]
<< ", dila_: " << dilas[0] << ", " << dilas[1]
<< ", bias: " << (flag_bias ? "true" : "false")
<< ", relu: " << (flag_relu ? "true" : "false")
<< ", threads: " << th << ", cluster: " << cls
<< " failed!!\n";
}
}
}
LOG(INFO) << "test fp32 conv: input: " << dim_in
<< ", output: " << dim_out << ", weight dim: " << weight_dim
<< ", pad: " << pads[0] << ", " << pads[1]
<< ", stride: " << strides[0] << ", " << strides[1]
<< ", dila_: " << dilas[0] << ", " << dilas[1]
<< ", bias: " << (flag_bias ? "true" : "false")
<< ", relu: " << (flag_relu ? "true" : "false")
<< ", threads: " << th << ", cluster: " << cls
<< " successed!!\n";
}
}
}
delete param.x;
delete param.filter;
delete param.output;
delete param.bias;
}
#else
void test_conv_fp32(const std::vector<DDim>& input_dims,
const DDim& weight_dim,
int group,
const std::vector<int>& strides,
const std::vector<int>& pads,
const std::vector<int>& dilas,
bool flag_bias,
bool flag_relu,
const std::vector<int>& thread_num,
const std::vector<int>& cluster_id) {}
#endif // LITE_WITH_ARM
#if 1 /// 3x3dw
TEST(TestConv3x3DW, test_conv3x3_depthwise) {
if (FLAGS_basic_test) {
for (auto& stride : {1, 2}) {
for (auto& pad : {0, 1}) {
for (auto& flag_bias : {false, true}) {
for (auto& flag_relu : {false, true}) {
for (auto& c : {1, 3, 5, 8, 16, 32}) {
std::vector<DDim> dims;
DDim weights_dim({c, 1, 3, 3});
for (auto& batch : {1, 2}) {
for (auto& h : {1, 3, 15, 19, 28, 32, 75}) {
dims.push_back(DDim({batch, c, h, h}));
}
}
test_conv_fp32(dims,
weights_dim,
c,
{stride, stride},
{pad, pad},
{1, 1},
flag_bias,
flag_relu,
{1, 2, 4},
{FLAGS_cluster});
}
}
}
}
}
}
}
#endif /// 3x3dw
#if 1 /// 5x5dw
TEST(TestConv5x5DW, test_conv5x5_depthwise) {
if (FLAGS_basic_test) {
for (auto& stride : {1, 2}) {
for (auto& pad : {0, 1, 2}) {
for (auto& flag_bias : {false, true}) {
for (auto& flag_relu : {false, true}) {
for (auto& c : {1, 3, 5, 8, 16, 32}) {
std::vector<DDim> dims;
DDim weights_dim({c, 1, 5, 5});
for (auto& batch : {1, 2}) {
for (auto& h : {1, 3, 15, 19, 28, 32, 75}) {
dims.push_back(DDim({batch, c, h, h}));
}
}
test_conv_fp32(dims,
weights_dim,
c,
{stride, stride},
{pad, pad},
{1, 1},
flag_bias,
flag_relu,
{1, 2, 4},
{FLAGS_cluster});
}
}
}
}
}
}
}
#endif /// 5x5dw
#if 1 /// conv1x1s1
TEST(TestConv1x1s1, test_conv1x1s1) {
if (FLAGS_basic_test) {
for (auto& cin : {1, 3, 8, 11, 32}) {
for (auto& cout : {1, 5, 16, 37}) {
for (auto& g : {1, 2}) {
for (auto& flag_bias : {false, true}) {
for (auto& flag_relu : {false, true}) {
std::vector<DDim> dims;
if (cin % g != 0 || cout % g != 0) {
continue;
}
DDim weights_dim({cout, cin / g, 1, 1});
for (auto& batch : {1, 2}) {
for (auto& h : {1, 7, 19, 28, 32, 56, 1}) {
dims.push_back(DDim({batch, cin, h, h}));
}
}
test_conv_fp32(dims,
weights_dim,
g,
{1, 1},
{0, 0},
{1, 1},
flag_bias,
flag_relu,
{1, 2, 4},
{FLAGS_cluster});
}
}
}
}
}
}
}
#endif /// conv1x1s1
#if 1 /// conv3x3s1
TEST(TestConv3x3s1, test_conv_3x3s1) {
if (FLAGS_basic_test) {
for (auto& cin : {1, 3, 8, 32, 48}) {
for (auto& cout : {1, 5, 8, 32, 48}) {
for (auto& pad : {1, 2}) {
for (auto& flag_bias : {false, true}) {
for (auto& flag_relu : {false, true}) {
std::vector<DDim> dims;
DDim weights_dim({cout, cin, 3, 3});
for (auto& batch : {1, 2}) {
for (auto& h : {1, 7, 19, 56, 32}) {
dims.push_back(DDim({batch, cin, h, h}));
}
}
test_conv_fp32(dims,
weights_dim,
1,
{1, 1},
{pad, pad},
{1, 1},
flag_bias,
flag_relu,
{1, 2, 4},
{FLAGS_cluster});
}
}
}
}
}
}
}
#endif /// conv3x3s1
#if 1 /// conv3x3s2
TEST(TestConv3x3s2, test_conv_3x3s2) {
if (FLAGS_basic_test) {
for (auto& cin : {1, 3, 8, 32}) {
for (auto& cout : {1, 5, 8, 32}) {
for (auto& pad : {1, 2}) {
for (auto& flag_bias : {false, true}) {
for (auto& flag_relu : {false, true}) {
std::vector<DDim> dims;
DDim weights_dim({cout, cin, 3, 3});
for (auto& batch : {1, 2}) {
for (auto& h : {1, 7, 19, 28, 75, 56, 32}) {
dims.push_back(DDim({batch, cin, h, h}));
}
}
test_conv_fp32(dims,
weights_dim,
1,
{2, 2},
{pad, pad},
{1, 1},
flag_bias,
flag_relu,
{1, 2, 4},
{FLAGS_cluster});
}
}
}
}
}
}
}
#endif /// conv3x3s2
#if 1 /// random param conv
TEST(TestConvRand, test_conv_rand) {
if (FLAGS_basic_test) {
for (auto& cin : {1, 3, 8, 16}) {
for (auto& cout : {1, 5, 8, 16}) {
for (auto& g : {1, 2}) {
for (auto& kw : {1, 2, 3}) {
for (auto& kh : {1, 2, 3}) {
for (auto& stride : {1, 2}) {
for (auto& pad : {0, 1, 2}) {
for (auto& dila : {1, 2}) {
for (auto& flag_bias : {false, true}) {
for (auto& flag_relu : {false, true}) {
if (cin % g != 0 || cout % g != 0) {
continue;
}
std::vector<DDim> dims;
DDim weights_dim({cout, cin / g, kh, kw});
for (auto& batch : {1, 2}) {
for (auto& h : {1, 3, 19, 32, 28}) {
dims.push_back(DDim({batch, cin, h, h}));
}
}
test_conv_fp32(dims,
weights_dim,
g,
{stride, stride},
{pad, pad},
{dila, dila},
flag_bias,
flag_relu,
{1, 2, 4},
{FLAGS_cluster});
}
}
}
}
}
}
}
}
}
}
}
}
#endif /// random param conv
#if 1 /// custom
TEST(TestConvCustom, test_conv_fp32_custom_size) {
CHECK_EQ(FLAGS_in_channel % FLAGS_group, 0)
<< "input channel must be divided by group";
CHECK_EQ(FLAGS_out_channel % FLAGS_group, 0)
<< "num_output must be divided by group";
test_conv_fp32(
{DDim({FLAGS_batch, FLAGS_in_channel, FLAGS_in_height, FLAGS_in_width})},
DDim({FLAGS_out_channel,
FLAGS_in_channel / FLAGS_group,
FLAGS_kernel_h,
FLAGS_kernel_w}),
FLAGS_group,
{FLAGS_stride_h, FLAGS_stride_w},
{FLAGS_pad_h, FLAGS_pad_w},
{FLAGS_dila_h, FLAGS_dila_w},
FLAGS_flag_bias,
FLAGS_flag_relu,
{FLAGS_threads},
{FLAGS_cluster});
}
#endif // custom
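The GOPS figures logged above follow from the per-output cost of a grouped convolution: each output element takes (cin / group) * kh * kw multiply-accumulates, i.e. two floating-point ops apiece, which is where gops = 2 * out.production() * cin * kh * kw / group comes from. For the default 1x32x112x112 input with a 32-channel 3x3 s1 p1 kernel that is 2 * 32 * 112 * 112 * 32 * 9, roughly 0.23 GFLOPs per image.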
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include "lite/core/context.h"
#include "lite/operators/op_params.h"
#include "lite/tests/utils/naive_math_impl.h"
#include "lite/tests/utils/tensor_utils.h"
#include "lite/tests/utils/timer.h"
#ifdef LITE_WITH_ARM
#include "lite/kernels/arm/conv_compute.h"
#endif // LITE_WITH_ARM
DEFINE_int32(cluster, 0, "cluster id");
DEFINE_int32(threads, 1, "threads num");
DEFINE_int32(warmup, 0, "warmup times");
DEFINE_int32(repeats, 1, "repeats times");
DEFINE_bool(basic_test, false, "do all tests");
DEFINE_bool(check_result, true, "check the result");
DEFINE_int32(batch, 1, "batch size");
DEFINE_int32(in_channel, 32, "input channel");
DEFINE_int32(in_height, 112, "input height");
DEFINE_int32(in_width, 112, "input width");
DEFINE_int32(out_channel, 32, "output channel");
DEFINE_int32(group, 1, "group");
DEFINE_int32(kernel_h, 3, "kernel height");
DEFINE_int32(kernel_w, 3, "kernel width");
DEFINE_int32(pad_h, 1, "pad height");
DEFINE_int32(pad_w, 1, "pad width");
DEFINE_int32(stride_h, 1, "stride height");
DEFINE_int32(stride_w, 1, "stride width");
DEFINE_int32(dila_h, 1, "dilation height");
DEFINE_int32(dila_w, 1, "dilation width");
DEFINE_bool(flag_relu, true, "do relu");
DEFINE_bool(flag_bias, true, "with bias");
typedef paddle::lite::DDim DDim;
typedef paddle::lite::Tensor Tensor;
typedef paddle::lite::operators::ConvParam ConvParam;
DDim compute_out_dim(const DDim& dim_in,
const paddle::lite::operators::ConvParam& param) {
DDim dim_out = dim_in;
dim_out[1] = param.filter->dims()[0];
auto kernel_h = param.filter->dims()[2];
auto kernel_w = param.filter->dims()[3];
auto h = dim_in[2];
auto w = dim_in[3];
int dila_h = param.dilations[0];
int dila_w = param.dilations[1];
int pad_h = param.paddings[0];
int pad_w = param.paddings[1];
int stride_h = param.strides[0];
int stride_w = param.strides[1];
auto kernel_exten = dila_h * (kernel_h - 1) + 1;
auto hout = (h + 2 * pad_h - kernel_exten) / stride_h + 1;
kernel_exten = dila_w * (kernel_w - 1) + 1;
auto wout = (w + 2 * pad_w - kernel_exten) / stride_w + 1;
dim_out[2] = hout;
dim_out[3] = wout;
return dim_out;
}
template <paddle::lite::PrecisionType ptype>
void get_conv_param(const DDim& dim_w,
int g,
const std::vector<int>& strides,
const std::vector<int>& pads,
const std::vector<int>& dila,
bool flag_bias,
bool flag_relu,
ConvParam* param) {
param->x = new Tensor;
param->x->set_precision(PRECISION(kInt8));
param->filter = new Tensor;
param->filter->Resize(dim_w);
param->filter->set_precision(PRECISION(kInt8));
if (flag_bias) {
param->bias = new Tensor;
param->bias->Resize({dim_w[0]});
param->bias->set_precision(PRECISION(kFloat));
}
param->strides = strides;
param->paddings = pads;
param->dilations = dila;
param->fuse_relu = flag_relu;
param->groups = g;
param->output = new Tensor;
param->output->set_precision(ptype);
}
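// Note: x and the filter are quantized to kInt8 above, but the bias tensor
// deliberately stays kFloat on both output paths -- the int8 kernels accept
// a float bias and fold the rescaling internally (see the bias handling in
// the gemm_int8 test further below).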
void release_param(ConvParam* param) {
delete param->x;
delete param->filter;
delete param->output;
delete param->bias;
}
#ifdef LITE_WITH_ARM
#include "lite/backends/arm/math/funcs.h"
void test_conv_int8(const std::vector<DDim>& input_dims,
const DDim& weight_dim,
int group,
const std::vector<int>& strides,
const std::vector<int>& pads,
const std::vector<int>& dilas,
bool flag_bias,
bool flag_relu,
const std::vector<int>& thread_num,
const std::vector<int>& cluster_id) {
paddle::lite::DeviceInfo::Init();
ConvParam param_int8_out;
ConvParam param_fp32_out;
get_conv_param<PRECISION(kInt8)>(weight_dim,
group,
strides,
pads,
dilas,
flag_bias,
flag_relu,
&param_int8_out);
get_conv_param<PRECISION(kFloat)>(weight_dim,
group,
strides,
pads,
dilas,
flag_bias,
flag_relu,
&param_fp32_out);
Tensor weight_fp32;
Tensor bias_fp32;
weight_fp32.Resize(weight_dim);
paddle::lite::fill_tensor_rand(*param_int8_out.filter, -127, 127);
param_fp32_out.filter->CopyDataFrom(*param_int8_out.filter);
if (flag_bias) {
auto dim_b = param_int8_out.bias->dims();
bias_fp32.Resize(dim_b);
paddle::lite::fill_tensor_rand(*param_int8_out.bias, -1.f, 1.f);
param_fp32_out.bias->CopyDataFrom(*param_int8_out.bias);
bias_fp32.CopyDataFrom(*param_int8_out.bias);
}
std::vector<float> scale_in{1.f / 127};
std::vector<float> scale_out{weight_dim.count(1, 4) / 127.f};
std::vector<float> scale_w(weight_dim[0], 1.f / 127);
param_int8_out.input_scale = scale_in[0];
param_int8_out.output_scale = scale_out[0];
param_int8_out.weight_scale = scale_w;
param_fp32_out.input_scale = scale_in[0];
param_fp32_out.output_scale = scale_out[0];
param_fp32_out.weight_scale = scale_w;
auto wptr_fp32 = weight_fp32.mutable_data<float>();
auto bptr_fp32 = flag_bias ? bias_fp32.data<float>() : nullptr;
paddle::lite::arm::math::int8_to_fp32(param_int8_out.filter->data<int8_t>(),
wptr_fp32,
scale_w.data(),
weight_dim[0],
1,
weight_dim.count(1, 4));
for (auto& cls : cluster_id) {
for (auto& th : thread_num) {
std::unique_ptr<paddle::lite::KernelContext> ctx1(
new paddle::lite::KernelContext);
std::unique_ptr<paddle::lite::KernelContext> ctx2(
new paddle::lite::KernelContext);
auto& ctx_tmp1 = ctx1->As<paddle::lite::ARMContext>();
ctx_tmp1.SetRunMode(static_cast<paddle::lite_api::PowerMode>(cls), th);
auto& ctx_tmp2 = ctx2->As<paddle::lite::ARMContext>();
ctx_tmp2.SetRunMode(static_cast<paddle::lite_api::PowerMode>(cls), th);
paddle::lite::kernels::arm::ConvCompute<PRECISION(kInt8),
PRECISION(kInt8)>
conv_int8_int8;
paddle::lite::kernels::arm::ConvCompute<PRECISION(kInt8),
PRECISION(kFloat)>
conv_int8_fp32;
conv_int8_int8.SetContext(std::move(ctx1));
conv_int8_fp32.SetContext(std::move(ctx2));
/// set param and context
for (auto& dim_in : input_dims) {
param_int8_out.x->Resize(dim_in);
DDim out_tmp_dims = compute_out_dim(dim_in, param_int8_out);
if (out_tmp_dims[2] < 1 || out_tmp_dims[3] < 1) {
continue;
}
param_fp32_out.x->Resize(dim_in);
param_int8_out.output->Resize(out_tmp_dims);
param_fp32_out.output->Resize(out_tmp_dims);
break;
}
conv_int8_int8.SetParam(param_int8_out);
conv_int8_fp32.SetParam(param_fp32_out);
/// prepare for run
conv_int8_int8.PrepareForRun();
conv_int8_fp32.PrepareForRun();
for (auto& dim_in : input_dims) {
CHECK_EQ(weight_dim[1] * group, dim_in[1])
<< "input channel must equal to weights channel";
DDim dim_out = compute_out_dim(dim_in, param_int8_out);
if (dim_out[2] < 1 || dim_out[3] < 1) {
continue;
}
delete param_fp32_out.output;
param_fp32_out.output = new Tensor;
param_fp32_out.output->set_precision(PRECISION(kFloat));
delete param_int8_out.output;
param_int8_out.output = new Tensor;
param_int8_out.output->set_precision(PRECISION(kInt8));
param_int8_out.x->Resize(dim_in);
param_int8_out.output->Resize(dim_out);
param_fp32_out.x->Resize(dim_in);
param_fp32_out.output->Resize(dim_out);
Tensor tin_fp32;
tin_fp32.Resize(dim_in);
tin_fp32.set_precision(PRECISION(kFloat));
Tensor tout_basic_fp32;
Tensor tout_basic_int8;
paddle::lite::fill_tensor_rand(*param_int8_out.x, -127, 127);
param_fp32_out.x->CopyDataFrom(*param_int8_out.x);
auto din_fp32 = tin_fp32.mutable_data<float>();
paddle::lite::arm::math::int8_to_fp32(param_int8_out.x->data<int8_t>(),
din_fp32,
scale_in.data(),
1,
1,
dim_in.production());
if (FLAGS_check_result) {
tout_basic_fp32.set_precision(PRECISION(kFloat));
tout_basic_fp32.Resize(dim_out);
tout_basic_int8.set_precision(PRECISION(kInt8));
tout_basic_int8.Resize(dim_out);
fill_tensor_const(tout_basic_fp32, 0.f);
auto dout_basic_fp32 = tout_basic_fp32.mutable_data<float>();
auto dout_basic_int8 = tout_basic_int8.mutable_data<int8_t>();
conv_basic<float, float>(din_fp32,
dout_basic_fp32,
dim_in[0],
dim_out[1],
dim_out[2],
dim_out[3],
dim_in[1],
dim_in[2],
dim_in[3],
wptr_fp32,
bptr_fp32,
group,
weight_dim[3],
weight_dim[2],
strides[1],
strides[0],
dilas[1],
dilas[0],
pads[1],
pads[0],
flag_bias,
flag_relu);
paddle::lite::arm::math::fp32_to_int8(dout_basic_fp32,
dout_basic_int8,
scale_out.data(),
1,
1,
dim_out.production());
}
double gops = 2.0 * dim_out.production() * dim_in[1] * weight_dim[2] *
weight_dim[3] / group;
/// warm up
for (int i = 0; i < FLAGS_warmup; ++i) {
conv_int8_int8.Launch();
}
/// compute fp32 output
lite::test::Timer t0;
for (int i = 0; i < FLAGS_repeats; ++i) {
t0.start();
conv_int8_fp32.Launch();
t0.end();
}
LOG(INFO) << "int8 conv, fp32 output: output shape" << dim_out
<< ",running time, avg: " << t0.get_average_ms()
<< ", min time: " << t0.get_min_time()
<< ", total GOPS: " << 1e-9 * gops
<< " GOPS, avg GOPs: " << 1e-6 * gops / t0.get_average_ms()
<< " GOPs, max GOPs: " << 1e-6 * gops / t0.get_min_time();
/// compute int8 output
t0.clear();
for (int i = 0; i < FLAGS_repeats; ++i) {
t0.start();
conv_int8_int8.Launch();
t0.end();
}
LOG(INFO) << "int8 conv, int8 output: output shape" << dim_out
<< ",running time, avg: " << t0.get_average_ms()
<< ", min time: " << t0.get_min_time()
<< ", total GOPS: " << 1e-9 * gops
<< " GOPS, avg GOPs: " << 1e-6 * gops / t0.get_average_ms()
<< " GOPs, max GOPs: " << 1e-6 * gops / t0.get_min_time();
/// compare result fp32 output
if (FLAGS_check_result) {
double max_ratio = 0;
double max_diff = 0;
tensor_cmp_host(
tout_basic_fp32, *param_fp32_out.output, max_ratio, max_diff);
LOG(INFO) << "FP32 compare result, max diff: " << max_diff
<< ", max ratio: " << max_ratio;
if (std::abs(max_ratio) > 1e-5f) {
if (max_diff > 5e-5f) {
LOG(WARNING) << "basic result";
print_tensor(tout_basic_fp32);
LOG(WARNING) << "saber result";
print_tensor(*param_fp32_out.output);
Tensor tdiff;
tdiff.Resize(tout_basic_fp32.dims());
tdiff.set_precision(PRECISION(kFloat));
tensor_diff(tout_basic_fp32, *param_fp32_out.output, tdiff);
print_tensor(tdiff);
release_param(&param_int8_out);
release_param(&param_fp32_out);
LOG(FATAL) << "test int8 conv, fp32 out: input: " << dim_in
<< ", output: " << dim_out
<< ", weight dim: " << weight_dim
<< ", pad: " << pads[0] << ", " << pads[1]
<< ", stride: " << strides[0] << ", " << strides[1]
<< ", dila_: " << dilas[0] << ", " << dilas[1]
<< ", bias: " << (flag_bias ? "true" : "false")
<< ", relu: " << (flag_relu ? "true" : "false")
<< ", threads: " << th << ", cluster: " << cls
<< " failed!!\n";
}
}
}
/// compare result int8 output
if (FLAGS_check_result) {
double max_ratio = 0;
double max_diff = 0;
// ! int8
tensor_cmp_host(
tout_basic_int8, *param_int8_out.output, max_ratio, max_diff);
LOG(INFO) << "int8 compare result, max diff: " << max_diff
<< ", max ratio: " << max_ratio;
if (fabs(max_diff) > 0) {
Tensor tdiff;
tdiff.Resize(tout_basic_int8.dims());
tdiff.set_precision(PRECISION(kInt8));
tensor_diff(tout_basic_int8, *param_int8_out.output, tdiff);
auto ptr = tdiff.data<int8_t>();
auto ptr_basic_fp32 = tout_basic_fp32.data<float>();
float count = 0;
bool check = true;
for (int i = 0; i < tdiff.numel(); ++i) {
if (abs(ptr[i]) > 1) {
check = false;
LOG(ERROR) << "basic float data: " << ptr_basic_fp32[i]
<< ", after scale: "
<< ptr_basic_fp32[i] / scale_out[0];
break;
}
if (ptr[i] != 0) {
LOG(ERROR) << "basic float data: " << ptr_basic_fp32[i]
<< ", after scale: "
<< ptr_basic_fp32[i] / scale_out[0];
count += 1;
}
}
check =
check &&
count < std::max(10, static_cast<int>(0.01 * tdiff.numel()));
if (!check) {
LOG(WARNING) << "int8 basic result";
print_tensor(tout_basic_int8);
LOG(WARNING) << "int8 saber result";
print_tensor(*param_int8_out.output);
LOG(WARNING) << "int8 diff tensor";
print_tensor(tdiff);
release_param(&param_int8_out);
release_param(&param_fp32_out);
LOG(FATAL) << "test int8 conv, int8 out: input: " << dim_in
<< ", output: " << dim_out
<< ", weight dim: " << weight_dim
<< ", pad: " << pads[0] << ", " << pads[1]
<< ", stride: " << strides[0] << ", " << strides[1]
<< ", dila_: " << dilas[0] << ", " << dilas[1]
<< ", bias: " << (flag_bias ? "true" : "false")
<< ", relu: " << (flag_relu ? "true" : "false")
<< ", threads: " << th << ", cluster: " << cls
<< " failed!!\n";
}
}
}
LOG(INFO) << "test int8 conv: input: " << dim_in
<< ", output: " << dim_out << ", weight dim: " << weight_dim
<< ", pad: " << pads[0] << ", " << pads[1]
<< ", stride: " << strides[0] << ", " << strides[1]
<< ", dila_: " << dilas[0] << ", " << dilas[1]
<< ", bias: " << (flag_bias ? "true" : "false")
<< ", relu: " << (flag_relu ? "true" : "false")
<< ", threads: " << th << ", cluster: " << cls
<< " successed!!\n";
}
}
}
release_param(&param_int8_out);
release_param(&param_fp32_out);
}
#else
void test_conv_int8(const std::vector<DDim>& input_dims,
const DDim& weight_dim,
int group,
const std::vector<int>& strides,
const std::vector<int>& pads,
const std::vector<int>& dilas,
bool flag_bias,
bool flag_relu,
const std::vector<int>& thread_num,
const std::vector<int>& cluster_id) {}
#endif // LITE_WITH_ARM
#if 1 /// 3x3dw
TEST(TestConv3x3DWInt8, test_conv3x3_depthwise) {
if (FLAGS_basic_test) {
for (auto& stride : {1, 2}) {
for (auto& pad : {0, 1}) {
for (auto& flag_bias : {false, true}) {
for (auto& flag_relu : {false, true}) {
for (auto& c : {1, 3, 5, 8, 16, 32}) {
std::vector<DDim> dims;
DDim weights_dim({c, 1, 3, 3});
for (auto& batch : {1, 2}) {
for (auto& h : {1, 3, 15, 19, 75, 32, 28}) {
dims.push_back(DDim({batch, c, h, h}));
}
}
test_conv_int8(dims,
weights_dim,
c,
{stride, stride},
{pad, pad},
{1, 1},
flag_bias,
flag_relu,
{1, 2, 4},
{FLAGS_cluster});
}
}
}
}
}
}
}
#endif /// 3x3dw
#if 0 /// 5x5dw
TEST(TestConv5x5DWInt8, test_conv5x5_depthwise) {
if (FLAGS_basic_test) {
for (auto& stride : {1, 2}) {
for (auto& pad : {0, 1, 2}) {
for (auto& flag_bias : {false, true}) {
for (auto& flag_relu : {false, true}) {
for (auto& c : {1, 3, 5, 8, 16, 32}) {
std::vector<DDim> dims;
DDim weights_dim({c, 1, 5, 5});
for (auto& batch : {1, 2}) {
for (auto &h : {1, 3, 15, 19, 28, 32, 75}) {
dims.push_back(DDim({batch, c, h, h}));
}
}
test_conv_int8(dims,
weights_dim,
c,
{stride, stride},
{pad, pad},
{1, 1},
flag_bias,
flag_relu,
{1, 2, 4},
{FLAGS_cluster});
}
}
}
}
}
}
}
#endif /// 5x5dw
#if 1 /// conv1x1s1
TEST(TestConv1x1s1Int8, test_conv1x1s1) {
if (FLAGS_basic_test) {
for (auto& cin : {1, 3, 8, 11, 32}) {
for (auto& cout : {1, 5, 16, 37}) {
for (auto& g : {1, 2}) {
for (auto& flag_bias : {false, true}) {
for (auto& flag_relu : {false, true}) {
std::vector<DDim> dims;
if (cin % g != 0 || cout % g != 0) {
continue;
}
DDim weights_dim({cout, cin / g, 1, 1});
for (auto& batch : {1, 2}) {
for (auto& h : {1, 7, 19, 28, 32, 56, 1}) {
dims.push_back(DDim({batch, cin, h, h}));
}
}
test_conv_int8(dims,
weights_dim,
g,
{1, 1},
{0, 0},
{1, 1},
flag_bias,
flag_relu,
{1, 2, 4},
{FLAGS_cluster});
}
}
}
}
}
}
}
#endif /// conv1x1s1
#if 1 /// conv3x3s1
TEST(TestConv3x3s1Int8, test_conv_3x3s1) {
if (FLAGS_basic_test) {
for (auto& cin : {1, 3, 8, 32, 48}) {
for (auto& cout : {1, 5, 8, 32, 48}) {
for (auto& pad : {1, 2}) {
for (auto& flag_bias : {false, true}) {
for (auto& flag_relu : {false, true}) {
std::vector<DDim> dims;
DDim weights_dim({cout, cin, 3, 3});
for (auto& batch : {1, 2}) {
for (auto& h : {1, 7, 19, 56, 32}) {
dims.push_back(DDim({batch, cin, h, h}));
}
}
test_conv_int8(dims,
weights_dim,
1,
{1, 1},
{pad, pad},
{1, 1},
flag_bias,
flag_relu,
{1, 2, 4},
{FLAGS_cluster});
}
}
}
}
}
}
}
#endif /// conv3x3s1
#if 1 /// conv3x3s2
TEST(TestConv3x3s2Int8, test_conv_3x3s2) {
if (FLAGS_basic_test) {
for (auto& cin : {1, 3, 8, 32}) {
for (auto& cout : {1, 5, 8, 32}) {
for (auto& pad : {1, 2}) {
for (auto& flag_bias : {false, true}) {
for (auto& flag_relu : {false, true}) {
std::vector<DDim> dims;
DDim weights_dim({cout, cin, 3, 3});
for (auto& batch : {1, 2}) {
for (auto& h : {1, 7, 19, 28, 75, 56, 32}) {
dims.push_back(DDim({batch, cin, h, h}));
}
}
test_conv_int8(dims,
weights_dim,
1,
{2, 2},
{pad, pad},
{1, 1},
flag_bias,
flag_relu,
{1, 2, 4},
{FLAGS_cluster});
}
}
}
}
}
}
}
#endif /// conv3x3s2
#if 1 /// random param conv
TEST(TestConvRandInt8, test_conv_rand) {
if (FLAGS_basic_test) {
for (auto& cin : {1, 3, 8, 16}) {
for (auto& cout : {1, 5, 8, 16}) {
for (auto& g : {1, 2}) {
for (auto& kw : {1, 2, 3}) {
for (auto& kh : {1, 2, 3}) {
for (auto& stride : {1, 2}) {
for (auto& pad : {0, 1, 2}) {
for (auto& dila : {1, 2}) {
for (auto& flag_bias : {false, true}) {
for (auto& flag_relu : {false, true}) {
if (cin % g != 0 || cout % g != 0) {
continue;
}
std::vector<DDim> dims;
DDim weights_dim({cout, cin / g, kh, kw});
for (auto& batch : {1, 2}) {
for (auto& h : {1, 3, 19, 32, 28}) {
dims.push_back(DDim({batch, cin, h, h}));
}
}
test_conv_int8(dims,
weights_dim,
g,
{stride, stride},
{pad, pad},
{dila, dila},
flag_bias,
flag_relu,
{1, 2, 4},
{FLAGS_cluster});
}
}
}
}
}
}
}
}
}
}
}
}
#endif /// random param conv
#if 1 /// custom
TEST(TestConvCustomInt8, test_conv_custom_size) {
CHECK_EQ(FLAGS_in_channel % FLAGS_group, 0)
<< "input channel must be divided by group";
CHECK_EQ(FLAGS_out_channel % FLAGS_group, 0)
<< "num_output must be divided by group";
test_conv_int8(
{DDim({FLAGS_batch, FLAGS_in_channel, FLAGS_in_height, FLAGS_in_width})},
DDim({FLAGS_out_channel,
FLAGS_in_channel / FLAGS_group,
FLAGS_kernel_h,
FLAGS_kernel_w}),
FLAGS_group,
{FLAGS_stride_h, FLAGS_stride_w},
{FLAGS_pad_h, FLAGS_pad_w},
{FLAGS_dila_h, FLAGS_dila_w},
FLAGS_flag_bias,
FLAGS_flag_relu,
{FLAGS_threads},
{FLAGS_cluster});
}
#endif // custom
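The harness above round-trips data through symmetric per-tensor quantization: values are divided by a scale and clamped to [-127, 127] on the way in, and multiplied by the scale on the way out. A simplified scalar stand-in for the int8_to_fp32 / fp32_to_int8 pair (round-to-nearest, single scale; not the library's NEON implementation, which also handles per-channel scales):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

int8_t quantize(float v, float scale) {
  float q = std::round(v / scale);
  return static_cast<int8_t>(std::min(127.f, std::max(-127.f, q)));
}
float dequantize(int8_t q, float scale) { return q * scale; }

int main() {
  const float scale = 1.f / 127;  // matches scale_in in the test above
  int8_t q = quantize(0.5f, scale);
  std::printf("q = %d, back = %f\n", q, dequantize(q, scale));
  return 0;
}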
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include "lite/tests/utils/fill_data.h"
#include "lite/tests/utils/naive_math_impl.h"
#ifdef LITE_WITH_ARM
#include "lite/backends/arm/math/funcs.h"
#endif // LITE_WITH_ARM
#include "lite/core/context.h"
#include "lite/core/tensor.h"
#include "lite/tests/utils/tensor_utils.h"
#include "lite/tests/utils/timer.h"
typedef paddle::lite::Tensor Tensor;
DEFINE_int32(cluster, 0, "cluster id");
DEFINE_int32(threads, 1, "threads num");
DEFINE_int32(warmup, 0, "warmup times");
DEFINE_int32(repeats, 1, "repeats times");
DEFINE_bool(basic_test, false, "do all tests");
DEFINE_bool(check_result, true, "check the result");
DEFINE_int32(M, 512, "gemm: M");
DEFINE_int32(N, 512, "gemm: N");
DEFINE_int32(K, 512, "gemm: K");
DEFINE_bool(traA, false, "gemm: A transpose");
DEFINE_bool(traB, false, "gemm: B transpose");
DEFINE_bool(flag_relu, false, "do relu");
DEFINE_bool(flag_bias, false, "with bias");
bool test_gemm_int8(bool tra,
bool trb,
int m,
int n,
int k,
bool has_bias,
bool has_relu,
int cls,
int ths) {
Tensor ta;
Tensor tb;
Tensor tc_int8;
Tensor tc_fp32;
Tensor tc_basic_int8;
Tensor tc_basic_fp32;
Tensor tbias;
int lda = tra ? m : k;
int ldb = trb ? k : n;
int ldc = n;
ta.Resize({m, k});
tb.Resize({k, n});
tc_int8.Resize({m, n});
tc_fp32.Resize({m, n});
tc_basic_int8.Resize({m, n});
tc_basic_fp32.Resize({m, n});
tbias.Resize({m});
ta.set_precision(PRECISION(kInt8));
tb.set_precision(PRECISION(kInt8));
tc_int8.set_precision(PRECISION(kInt8));
tc_fp32.set_precision(PRECISION(kFloat));
tc_basic_int8.set_precision(PRECISION(kInt8));
tc_basic_fp32.set_precision(PRECISION(kFloat));
tbias.set_precision(PRECISION(kFloat));
fill_tensor_rand(ta, -127, 127);
fill_tensor_rand(tb, -127, 127);
fill_tensor_rand(tbias, -1.f, 1.f);
std::vector<float> scale_a(static_cast<size_t>(m), 1.f / 127);
std::vector<float> scale_b = {1.f / 127};
std::vector<float> scale_c = {k / 127.f};
std::vector<float> scale_merge_fp32(static_cast<size_t>(m));
std::vector<float> scale_merge_int8(static_cast<size_t>(m));
for (int j = 0; j < m; ++j) {
scale_merge_fp32[j] = scale_a[j] * scale_b[0];
scale_merge_int8[j] = scale_merge_fp32[j] / scale_c[0];
}
auto da = ta.mutable_data<int8_t>();
auto db = tb.mutable_data<int8_t>();
auto dc_int8 = tc_int8.mutable_data<int8_t>();
auto dc_fp32 = tc_fp32.mutable_data<float>();
auto dc_basic_int8 = tc_basic_int8.mutable_data<int8_t>();
auto dc_basic_fp32 = tc_basic_fp32.mutable_data<float>();
auto dbias = tbias.mutable_data<float>();
LOG(INFO) << "gemm_int8 M: " << m << ", N: " << n << ", K: " << k
<< ", transA: " << (tra ? "true" : "false")
<< ", transB: " << (trb ? "true" : "false")
<< ", relu: " << (has_relu ? "true" : "false")
<< ", bias: " << (has_bias ? "true" : "false");
#ifdef LITE_WITH_ARM
if (FLAGS_check_result) {
Tensor ta_fp32;
Tensor tb_fp32;
ta_fp32.Resize({m, k});
ta_fp32.set_precision(PRECISION(kFloat));
tb_fp32.Resize({k, n});
tb_fp32.set_precision(PRECISION(kFloat));
auto da_fp32 = ta_fp32.mutable_data<float>();
auto db_fp32 = tb_fp32.mutable_data<float>();
paddle::lite::arm::math::int8_to_fp32(
da, da_fp32, scale_a.data(), 1, 1, ta.numel());
paddle::lite::arm::math::int8_to_fp32(
db, db_fp32, scale_b.data(), 1, 1, tb.numel());
basic_gemm(tra,
trb,
m,
n,
k,
1.f,
da_fp32,
lda,
db_fp32,
ldb,
0.f,
dc_basic_fp32,
ldc,
dbias,
has_bias,
has_relu);
paddle::lite::arm::math::fp32_to_int8(dc_basic_fp32,
dc_basic_int8,
scale_c.data(),
1,
1,
tc_basic_fp32.numel());
}
lite::test::Timer t0;
//! compute
double ops = 2.0 * m * n * k;
std::unique_ptr<paddle::lite::KernelContext> ctx1(
new paddle::lite::KernelContext);
auto& ctx = ctx1->As<paddle::lite::ARMContext>();
ctx.SetRunMode(static_cast<paddle::lite_api::PowerMode>(cls), ths);
//! prepack
Tensor tpackedA;
int hblock = paddle::lite::arm::math::get_hblock_int8(&ctx);
int round_up_a = ((hblock + m - 1) / hblock) * hblock;
int round_up_k = 4 * ((k + 3) / 4);
tpackedA.Resize({round_up_a * round_up_k});
paddle::lite::arm::math::prepackA_int8(
tpackedA.mutable_data<int8_t>(), da, lda, 0, m, 0, k, tra, &ctx);
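  // Packing detail: m is rounded up to whole hblock-high panels and k to a
  // multiple of 4, so the int8 micro-kernel can consume fixed-size 4-byte
  // groups of A without a remainder path.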
/// warmup
for (int j = 0; j < FLAGS_warmup; ++j) {
paddle::lite::arm::math::gemm_prepack_int8(tpackedA.data<int8_t>(),
db,
dbias,
dc_fp32,
m,
n,
k,
has_bias,
has_relu,
trb,
scale_merge_fp32.data(),
&ctx);
}
/// int8 output compute
Tensor tbias_int8;
tbias_int8.Resize(tbias.dims());
tbias_int8.set_precision(PRECISION(kFloat));
auto dbias_int8 = tbias_int8.mutable_data<float>();
for (int l = 0; l < tbias_int8.numel(); ++l) {
dbias_int8[l] = dbias[l] / scale_c[0];
}
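  // For the int8-output run the float bias must be expressed in units of the
  // output scale, hence the division by scale_c[0] above.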
for (int i = 0; i < FLAGS_repeats; ++i) {
t0.start();
paddle::lite::arm::math::gemm_prepack_int8(tpackedA.data<int8_t>(),
db,
dbias_int8,
dc_int8,
m,
n,
k,
has_bias,
has_relu,
trb,
scale_merge_int8.data(),
&ctx);
t0.end();
}
LOG(INFO) << "gemm_int8_int8 output: M: " << m << ", N: " << n << ", K: " << k
<< ", cluster: " << cls << ", threads: " << ths
<< ", GOPS: " << ops * 1e-9f
<< " GOPS, avg time: " << t0.get_average_ms()
<< " ms, min time: " << t0.get_min_time()
<< " ms, mean GOPs: " << ops * 1e-6f / t0.get_average_ms()
<< " GOPs, max GOPs: " << ops * 1e-6f / t0.get_min_time()
<< " GOPs";
/// fp32 output compute
t0.clear();
for (int i = 0; i < FLAGS_repeats; ++i) {
t0.start();
paddle::lite::arm::math::gemm_prepack_int8(tpackedA.data<int8_t>(),
db,
dbias,
dc_fp32,
m,
n,
k,
has_bias,
has_relu,
trb,
scale_merge_fp32.data(),
&ctx);
t0.end();
}
LOG(INFO) << "gemm_int8_fp32 output: M: " << m << ", N: " << n << ", K: " << k
<< ", cluster: " << cls << ", threads: " << ths
<< ", GOPS: " << ops * 1e-9f
<< " GOPS, avg time: " << t0.get_average_ms()
<< " ms, min time: " << t0.get_min_time()
<< " ms, mean GOPs: " << ops * 1e-6f / t0.get_average_ms()
<< " GOPs, max GOPs: " << ops * 1e-6f / t0.get_min_time()
<< " GOPs";
if (FLAGS_check_result) {
double max_ratio = 0;
double max_diff = 0;
/// fp32 result
tensor_cmp_host(tc_basic_fp32, tc_fp32, max_ratio, max_diff);
LOG(INFO) << "fp32 compare result, max diff: " << max_diff
<< ", max ratio: " << max_ratio;
if (std::abs(max_ratio) > 1e-4f && std::abs(max_diff) > 5e-5f) {
Tensor tdiff;
tdiff.set_precision(PRECISION(kFloat));
tdiff.Resize(tc_fp32.dims());
tensor_diff(tc_basic_fp32, tc_fp32, tdiff);
LOG(INFO) << "basic result: ";
print_tensor(tc_basic_fp32);
LOG(INFO) << "saber result: ";
print_tensor(tc_fp32);
LOG(INFO) << "diff result: ";
print_tensor(tdiff);
return false;
}
/// int8 result
max_ratio = 0;
max_diff = 0;
tensor_cmp_host(tc_basic_int8, tc_int8, max_ratio, max_diff);
LOG(INFO) << "int8 compare result, max diff: " << max_diff
<< ", max ratio: " << max_ratio;
if (fabs(max_ratio) > 1e-4f) {
Tensor tdiff;
tdiff.Resize(tc_int8.dims());
tdiff.set_precision(PRECISION(kInt8));
tensor_diff(tc_basic_int8, tc_int8, tdiff);
auto ptr = tdiff.data<int8_t>();
auto ptr_basic_fp32 = tc_basic_fp32.data<float>();
float count = 0;
bool check = true;
for (int i = 0; i < tdiff.numel(); ++i) {
if (abs(ptr[i]) > 1) {
check = false;
LOG(ERROR) << "basic float data: " << ptr_basic_fp32[i]
<< ", after scale: " << ptr_basic_fp32[i] / scale_c[0];
break;
}
if (ptr[i] != 0) {
LOG(ERROR) << "basic float data: " << ptr_basic_fp32[i]
<< ", after scale: " << ptr_basic_fp32[i] / scale_c[0];
count += 1;
}
}
check =
check && count < std::max(10, static_cast<int>(0.01 * tdiff.numel()));
if (!check) {
LOG(WARNING) << "int8 basic result";
print_tensor(tc_basic_int8);
LOG(WARNING) << "int8 saber result";
print_tensor(tc_int8);
LOG(WARNING) << "int8 diff tensor";
print_tensor(tdiff);
return false;
}
}
}
#endif
return true;
}
TEST(TestLiteGemmInt8, gemm_prepacked_int8) {
if (FLAGS_basic_test) {
#ifdef LITE_WITH_ARM
paddle::lite::DeviceInfo::Init();
#endif
LOG(INFO) << "run basic sgemm test";
for (auto& m : {1, 3, 8, 32, 397}) {
for (auto& n : {1, 3, 13, 141, 512, 789}) {
for (auto& k : {1, 3, 8, 59, 234}) {
for (auto& tra : {false, true}) {
for (auto& trb : {false, true}) {
for (auto& has_bias : {false, true}) {
for (auto& has_relu : {false, true}) {
for (auto& th : {1, 2, 4}) {
auto flag = test_gemm_int8(tra,
trb,
m,
n,
k,
has_bias,
has_relu,
FLAGS_cluster,
th);
if (flag) {
LOG(INFO) << "test m = " << m << ", n=" << n
<< ", k=" << k
<< ", bias: " << (has_bias ? "true" : "false")
<< ", relu: " << (has_relu ? "true" : "false")
<< ", trans A: " << (tra ? "true" : "false")
<< ", trans B: " << (trb ? "true" : "false")
<< " passed\n";
} else {
LOG(FATAL) << "test m = " << m << ", n=" << n
<< ", k=" << k
<< ", bias: " << (has_bias ? "true" : "false")
<< ", relu: " << (has_relu ? "true" : "false")
<< ", trans A: " << (tra ? "true" : "false")
<< ", trans B: " << (trb ? "true" : "false")
<< " failed\n";
}
}
}
}
}
}
}
}
}
}
}
TEST(TestGemmInt8Custom, gemm_prepacked_int8_custom) {
#ifdef LITE_WITH_ARM
paddle::lite::DeviceInfo::Init();
#endif
auto flag = test_gemm_int8(FLAGS_traA,
FLAGS_traB,
FLAGS_M,
FLAGS_N,
FLAGS_K,
FLAGS_flag_bias,
FLAGS_flag_relu,
FLAGS_cluster,
FLAGS_threads);
if (!flag) {
LOG(FATAL) << "test m = " << FLAGS_M << ", n=" << FLAGS_N
<< ", k=" << FLAGS_K << ", trans A: " << FLAGS_traA
<< ", trans B: " << FLAGS_traB << ", bias: " << FLAGS_flag_bias
<< ", relu: " << FLAGS_flag_relu << " failed!!";
}
LOG(INFO) << "test m = " << FLAGS_M << ", n=" << FLAGS_N << ", k=" << FLAGS_K
<< ", trans A: " << FLAGS_traA << ", trans B: " << FLAGS_traB
<< ", bias: " << FLAGS_flag_bias << ", relu: " << FLAGS_flag_relu
<< " passed!!";
}
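The merged scales used throughout this test fall straight out of symmetric quantization algebra: if A_fp32 = A_int8 * scale_a (per row) and B_fp32 = B_int8 * scale_b, the int32 accumulator for row j becomes float output via the factor scale_a[j] * scale_b, and int8 output by further dividing by the output scale scale_c. A compact restatement (values mirror the test's defaults; illustrative only):

#include <cstdio>

int main() {
  const float scale_a = 1.f / 127;  // per-row scale of A
  const float scale_b = 1.f / 127;  // per-tensor scale of B
  const int k = 234;                // one of the basic-test K values
  const float scale_c = k / 127.f;  // output scale, as in the test
  // int32 accumulator -> fp32 output
  float merge_fp32 = scale_a * scale_b;
  // int32 accumulator -> int8 output (extra requantization step)
  float merge_int8 = merge_fp32 / scale_c;
  std::printf("merge_fp32 = %g, merge_int8 = %g\n", merge_fp32, merge_int8);
  return 0;
}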
...@@ -12,60 +12,43 @@ ...@@ -12,60 +12,43 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// #include <gflags/gflags.h>
// Created by Li,Xiaoyang(SYS) on 2019-07-25. #include <gtest/gtest.h>
// #include "lite/tests/utils/fill_data.h"
#include "lite/tests/utils/naive_math_impl.h"
#include "lite/tests/kernels/fill_data.h"
#include "lite/tests/kernels/test_funcs.h"
#ifdef LITE_WITH_ARM #ifdef LITE_WITH_ARM
#include "lite/backends/arm/math/funcs.h" #include "lite/backends/arm/math/funcs.h"
#endif #endif // LITE_WITH_ARM
#include "lite/core/context.h" #include "lite/core/context.h"
#include "lite/core/tensor.h" #include "lite/core/tensor.h"
int g_cluster = 0; #include "lite/tests/utils/tensor_utils.h"
int g_threads = 1; #include "lite/tests/utils/timer.h"
bool g_basic_test = false; typedef paddle::lite::Tensor Tensor;
int g_M = 512; DEFINE_int32(cluster, 0, "cluster id");
int g_N = 512; DEFINE_int32(threads, 1, "threads num");
int g_K = 512; DEFINE_int32(warmup, 0, "warmup times");
bool g_traA = false; DEFINE_int32(repeats, 1, "repeats times");
bool g_traB = false; DEFINE_bool(basic_test, false, "do all tests");
bool g_flag_relu = false; DEFINE_bool(check_result, true, "check the result");
bool g_flag_bias = false;
int g_test_iter = 1;
int g_warmup_iter = 0;
bool g_compare_result = true;
int g_offset_a = 10; DEFINE_int32(M, 512, "gemm: M");
int g_offset_b = 10; DEFINE_int32(N, 512, "gemm: N");
int g_offset_c = 10; DEFINE_int32(K, 512, "gemm: K");
float g_alpha = 1.f; DEFINE_bool(traA, false, "gemm: A transpose");
float g_beta = 0.f; DEFINE_bool(traB, false, "gemm: B transpose");
const int MALLOC_ALIGN = 16; DEFINE_int32(offset_a, 0, "A offset");
DEFINE_int32(offset_b, 0, "B offset");
DEFINE_int32(offset_c, 0, "C offset");
static void* fast_malloc1(size_t size) { DEFINE_double(alpha, 1.0, "alpha");
size_t offset = sizeof(void*) + MALLOC_ALIGN - 1; DEFINE_double(beta, 0.0, "beta");
char* p;
p = static_cast<char*>(malloc(offset + size));
if (!p) {
return nullptr;
}
void* r = reinterpret_cast<void*>(reinterpret_cast<size_t>(p + offset) &
(~(MALLOC_ALIGN - 1)));
static_cast<void**>(r)[-1] = p;
return r;
}
static void fast_free1(void* ptr) { DEFINE_bool(flag_relu, false, "do relu");
if (ptr) { DEFINE_bool(flag_bias, false, "with bias");
free(static_cast<void**>(ptr)[-1]);
}
}
bool test_sgemm(bool tra, bool test_sgemm(bool tra,
bool trb, bool trb,
...@@ -81,150 +64,171 @@ bool test_sgemm(bool tra, ...@@ -81,150 +64,171 @@ bool test_sgemm(bool tra,
bool has_relu, bool has_relu,
int cls, int cls,
int ths) { int ths) {
size_t size_a = tra ? k * lda : m * lda; int size_a = tra ? k * lda : m * lda;
size_t size_b = trb ? n * ldb : k * ldb; int size_b = trb ? n * ldb : k * ldb;
auto da = static_cast<float*>(fast_malloc1(size_a * sizeof(float))); Tensor ta;
auto db = static_cast<float*>(fast_malloc1(size_b * sizeof(float))); Tensor tb;
auto dc = static_cast<float*>(fast_malloc1(m * ldc * sizeof(float))); Tensor tc;
auto dc_basic = static_cast<float*>(fast_malloc1(m * ldc * sizeof(float))); Tensor tc_basic;
auto dbias = static_cast<float*>(fast_malloc1(m * sizeof(float))); Tensor tc_backup;
Tensor tbias;
fill_data_rand(da, -1.f, 1.f, size_a); ta.Resize({size_a});
fill_data_rand(db, -1.f, 1.f, size_b); tb.Resize({size_b});
fill_data_rand(dbias, -1.f, 1.f, m); tc.Resize({m * ldc});
fill_data_rand(dc, -1.f, 1.f, m * ldc); tc_basic.Resize({m * ldc});
memcpy(dc_basic, dc, sizeof(float) * m * ldc); tc_backup.Resize({m * ldc});
tbias.Resize({m});
LOG(INFO) << "sgemm M: " << m << ", N: " << n << ", K: " << k; ta.set_precision(PRECISION(kFloat));
LOG(INFO) << "strides, lda: " << lda << ", ldb: " << ldb << ", ldc: " << ldc; tb.set_precision(PRECISION(kFloat));
LOG(INFO) << "alpha: " << alpha << ", beta: " << beta; tc.set_precision(PRECISION(kFloat));
LOG(INFO) << "transA: " << (tra ? "true" : "false") tc_basic.set_precision(PRECISION(kFloat));
<< ", transB: " << (trb ? "true" : "false"); tc_backup.set_precision(PRECISION(kFloat));
LOG(INFO) << "relu: " << (has_relu ? "true" : "false") tbias.set_precision(PRECISION(kFloat));
<< ", bias: " << (has_bias ? "true" : "false");
LOG(INFO) << "basic sgemm compute"; fill_tensor_rand(ta, -1.f, 1.f);
basic_gemm(tra, fill_tensor_rand(tb, -1.f, 1.f);
trb, fill_tensor_rand(tbias, -1.f, 1.f);
m, fill_tensor_rand(tc, -1.f, 1.f);
n,
k, auto da = ta.mutable_data<float>();
alpha, auto db = tb.mutable_data<float>();
da, auto dc = tc.mutable_data<float>();
lda, auto dc_basic = tc_basic.mutable_data<float>();
db, auto dc_backup = tc_backup.mutable_data<float>();
ldb, auto dbias = tbias.mutable_data<float>();
beta,
dc_basic, memcpy(dc_basic, dc, sizeof(float) * m * ldc);
ldc, memcpy(dc_backup, dc, sizeof(float) * m * ldc);
dbias,
has_bias,
has_relu);
float max_error = 0.f; LOG(INFO) << "sgemm M: " << m << ", N: " << n << ", K: " << k
float max_ratio = 0.f; << ", strides, lda: " << lda << ", ldb: " << ldb << ", ldc: " << ldc
<< ", alpha: " << alpha << ", beta: " << beta
<< ", transA: " << (tra ? "true" : "false")
<< ", transB: " << (trb ? "true" : "false")
<< ", relu: " << (has_relu ? "true" : "false")
<< ", bias: " << (has_bias ? "true" : "false");
if (FLAGS_check_result) {
basic_gemm(tra,
trb,
m,
n,
k,
alpha,
da,
lda,
db,
ldb,
beta,
dc_basic,
ldc,
dbias,
has_bias,
has_relu);
}
lite::test::Timer t0;
#ifdef LITE_WITH_ARM #ifdef LITE_WITH_ARM
//! compute //! compute
LOG(INFO) << "sgemm compute";
double ops = 2.0 * m * n * k; double ops = 2.0 * m * n * k;
std::unique_ptr<paddle::lite::KernelContext> ctx1( std::unique_ptr<paddle::lite::KernelContext> ctx1(
new paddle::lite::KernelContext); new paddle::lite::KernelContext);
auto& ctx = ctx1->As<paddle::lite::ARMContext>(); auto& ctx = ctx1->As<paddle::lite::ARMContext>();
ctx.SetRunMode(static_cast<paddle::lite_api::PowerMode>(cls), ths);
//! prepack
Tensor tpackedA;
int hblock = paddle::lite::arm::math::get_hblock(&ctx);
int round_up_a = ((hblock + m - 1) / hblock) * hblock;
tpackedA.Resize({round_up_a * k});
paddle::lite::arm::math::prepackA(
tpackedA.mutable_data<float>(), da, alpha, lda, 0, m, 0, k, tra, &ctx);
for (int j = 0; j < FLAGS_warmup; ++j) {
paddle::lite::arm::math::sgemm_prepack(trb,
m,
n,
k,
tpackedA.data<float>(),
db,
ldb,
beta,
dc,
ldc,
dbias,
has_bias,
has_relu,
&ctx);
}
paddle::lite::arm::math::sgemm(tra, for (int i = 0; i < FLAGS_repeats; ++i) {
trb, if (i == FLAGS_repeats - 1) {
m, memcpy(dc, dc_backup, sizeof(float) * m * ldc);
n,
k,
alpha,
da,
lda,
db,
ldb,
beta,
dc,
ldc,
dbias,
has_bias,
has_relu,
&ctx);
for (int i = 0; i < m * ldc; ++i) {
auto error = fabsf(dc[i] - dc_basic[i]);
if (error > max_error) {
max_error = error;
max_ratio = error / fabsf(dc_basic[i]);
} }
t0.start();
paddle::lite::arm::math::sgemm_prepack(trb,
m,
n,
k,
tpackedA.data<float>(),
db,
ldb,
beta,
dc,
ldc,
dbias,
has_bias,
has_relu,
&ctx);
t0.end();
} }
if (max_error > 2e-5f && max_ratio > 2e-5f) { LOG(INFO) << "M: " << m << ", N: " << n << ", K: " << k
LOG(INFO) << "max ratio: " << max_ratio << ", max_error: " << max_error; << ", cluster: " << cls << ", threads: " << ths
LOG(INFO) << "sgemm result:"; << ", GOPS: " << ops * 1e-9f
for (int i = 0; i < m * ldc; ++i) { << " GOPS, avg time: " << t0.get_average_ms()
printf("%f ", dc[i]); << " ms, min time: " << t0.get_min_time()
if ((i + 1) % ldc == 0) { << " ms, mean GOPs: " << ops * 1e-6f / t0.get_average_ms()
printf("\n"); << " GOPs, max GOPs: " << ops * 1e-6f / t0.get_min_time()
} << " GOPs";
}
LOG(INFO) << "basic result:"; if (FLAGS_check_result) {
for (int i = 0; i < m * ldc; ++i) { double max_ratio = 0;
printf("%f ", dc_basic[i]); double max_diff = 0;
if ((i + 1) % ldc == 0) { tensor_cmp_host(tc_basic, tc, max_ratio, max_diff);
printf("\n"); LOG(INFO) << "compare result, max diff: " << max_diff
} << ", max ratio: " << max_ratio;
if (std::abs(max_ratio) > 1e-4f && std::abs(max_diff) > 5e-5f) {
Tensor tdiff;
tdiff.set_precision(PRECISION(kFloat));
tdiff.Resize(tc.dims());
tensor_diff(tc_basic, tc, tdiff);
LOG(INFO) << "a: ";
print_tensor(ta);
LOG(INFO) << "b: ";
print_tensor(tb);
LOG(INFO) << "c: ";
print_tensor(tc_backup);
LOG(INFO) << "basic result: ";
print_tensor(tc_basic);
LOG(INFO) << "saber result: ";
print_tensor(tc);
LOG(INFO) << "diff result: ";
print_tensor(tdiff);
return false;
} }
} }
#endif #endif
fast_free1(da); return true;
fast_free1(db);
fast_free1(dbias);
fast_free1(dc);
fast_free1(dc_basic);
return max_error < 2e-5f || max_ratio < 2e-5f;
} }
TEST(TestSgemm, test_func_sgemm_prepacked) {
  if (FLAGS_basic_test) {
#ifdef LITE_WITH_ARM
    paddle::lite::DeviceInfo::Init();
#endif
    LOG(INFO) << "run basic sgemm test";
    for (auto& m : {1, 3, 8, 32, 397}) {
      for (auto& n : {1, 3, 13, 141, 512, 789}) {
        for (auto& k : {1, 3, 8, 59, 234}) {
          for (auto& tra : {false, true}) {
            for (auto& trb : {false, true}) {
              for (auto& alpha : {1.f, 0.5f}) {
...@@ -254,7 +258,7 @@ void test_func_sgemm_prepacked() {
                        beta,
                        has_bias,
                        has_relu,
                        FLAGS_cluster,
                        th);
                    if (flag) {
                      LOG(INFO)
...@@ -289,65 +293,41 @@ void test_func_sgemm_prepacked() {
  }
}

TEST(TestSgemmCustom, test_func_sgemm_prepacked_custom) {
#ifdef LITE_WITH_ARM
  paddle::lite::DeviceInfo::Init();
#endif
  int lda = FLAGS_K + FLAGS_offset_a;
  if (FLAGS_traA) {
    lda = FLAGS_M + FLAGS_offset_a;
  }
  int ldb = FLAGS_N + FLAGS_offset_b;
  if (FLAGS_traB) {
    ldb = FLAGS_K + FLAGS_offset_b;
  }
  int ldc = FLAGS_N + FLAGS_offset_c;
  auto flag = test_sgemm(FLAGS_traA,
                         FLAGS_traB,
                         FLAGS_M,
                         FLAGS_N,
                         FLAGS_K,
                         lda,
                         ldb,
                         ldc,
                         FLAGS_alpha,
                         FLAGS_beta,
                         FLAGS_flag_bias,
                         FLAGS_flag_relu,
                         FLAGS_cluster,
                         FLAGS_threads);
  if (!flag) {
    LOG(FATAL) << "test m = " << FLAGS_M << ", n=" << FLAGS_N
               << ", k=" << FLAGS_K << ", trans A: " << FLAGS_traA
               << ", trans B: " << FLAGS_traB << ", bias: " << FLAGS_flag_bias
               << ", relu: " << FLAGS_flag_relu << " failed!!";
  }
  LOG(INFO) << "test m = " << FLAGS_M << ", n=" << FLAGS_N << ", k=" << FLAGS_K
            << ", trans A: " << FLAGS_traA << ", trans B: " << FLAGS_traB
            << ", bias: " << FLAGS_flag_bias << ", relu: " << FLAGS_flag_relu
            << " passed!!";
}
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <unistd.h>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <random>
#include "lite/core/tensor.h"
namespace paddle {
namespace lite {
template <typename Dtype>
void fill_tensor_host_const_impl(Dtype* dio, Dtype value, int64_t size) {
for (int64_t i = 0; i < size; ++i) {
dio[i] = value;
}
}
/**
 * \brief Fill the host tensor buffer with a constant value.
 * \param tensor The reference of input tensor.
 * \param value The constant value used to fill the buffer.
 */
void fill_tensor_const(Tensor& tensor, float value) { // NOLINT
int64_t size = tensor.numel();
PrecisionType type = tensor.precision();
switch (type) {
case PRECISION(kInt8):
fill_tensor_host_const_impl(
tensor.mutable_data<int8_t>(), static_cast<signed char>(value), size);
break;
case PRECISION(kInt32):
fill_tensor_host_const_impl(
tensor.mutable_data<int>(), static_cast<int>(value), size);
break;
case PRECISION(kFloat):
fill_tensor_host_const_impl(
tensor.mutable_data<float>(), static_cast<float>(value), size);
break;
default:
LOG(FATAL) << "data type: " << PrecisionRepr(type)
<< " is unsupported now";
}
}
template <typename Dtype>
void fill_tensor_host_rand_impl(Dtype* dio, int64_t size) {
for (int64_t i = 0; i < size; ++i) {
Dtype rand_x = static_cast<Dtype>(rand() % 256); // NOLINT
dio[i] = (rand_x - 128) / 128;
}
}
template <>
void fill_tensor_host_rand_impl<signed char>(signed char* dio, int64_t size) {
for (int64_t i = 0; i < size; ++i) {
dio[i] = rand() % 256 - 128; // NOLINT
}
}
template <>
void fill_tensor_host_rand_impl<unsigned char>(unsigned char* dio,
int64_t size) {
for (int64_t i = 0; i < size; ++i) {
dio[i] = rand() % 256; // NOLINT
}
}
/**
 * \brief Fill the host tensor buffer with random values.
 * \param tensor The reference of input tensor.
 */
void fill_tensor_rand(Tensor& tensor) { // NOLINT
int64_t size = tensor.numel();
PrecisionType type = tensor.precision();
switch (type) {
case PRECISION(kInt8):
fill_tensor_host_rand_impl(tensor.mutable_data<int8_t>(), size);
break;
case PRECISION(kInt32):
fill_tensor_host_rand_impl(tensor.mutable_data<int>(), size);
break;
case PRECISION(kFloat):
fill_tensor_host_rand_impl(tensor.mutable_data<float>(), size);
break;
default:
LOG(FATAL) << "data type: " << PrecisionRepr(type)
<< " is unsupported now";
}
}
template <typename Dtype>
void fill_tensor_host_rand_impl2(Dtype* dio,
Dtype vstart,
Dtype vend,
int64_t size) {
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_real_distribution<float> dis(0, 1.f);
for (int64_t i = 0; i < size; ++i) {
Dtype random_num = static_cast<Dtype>(vstart + (vend - vstart) * dis(gen));
dio[i] = random_num;
}
}
/**
 * \brief Fill the host tensor buffer with random values in [vstart, vend).
 * \param tensor The reference of input tensor.
 */
void fill_tensor_rand(Tensor& tensor, float vstart, float vend) { // NOLINT
int64_t size = tensor.numel();
PrecisionType type = tensor.precision();
switch (type) {
case PRECISION(kInt8):
fill_tensor_host_rand_impl2(tensor.mutable_data<int8_t>(),
static_cast<signed char>(vstart),
static_cast<signed char>(vend),
size);
break;
case PRECISION(kInt32):
fill_tensor_host_rand_impl2(tensor.mutable_data<int>(),
static_cast<int>(vstart),
static_cast<int>(vend),
size);
break;
case PRECISION(kFloat):
fill_tensor_host_rand_impl2(
tensor.mutable_data<float>(), vstart, vend, size);
break;
default:
LOG(FATAL) << "data type: " << PrecisionRepr(type)
<< " is unsupported now";
}
}
template <typename Dtype>
void print_tensor_host_impl(const Dtype* din, int64_t size, int64_t width);
template <>
void print_tensor_host_impl(const float* din, int64_t size, int64_t width) {
for (int i = 0; i < size; ++i) {
printf("%.6f ", din[i]);
if ((i + 1) % width == 0) {
printf("\n");
}
}
printf("\n");
}
template <>
void print_tensor_host_impl(const int* din, int64_t size, int64_t width) {
for (int i = 0; i < size; ++i) {
printf("%d ", din[i]);
if ((i + 1) % width == 0) {
printf("\n");
}
}
printf("\n");
}
template <>
void print_tensor_host_impl(const signed char* din,
int64_t size,
int64_t width) {
for (int i = 0; i < size; ++i) {
printf("%d ", din[i]);
if ((i + 1) % width == 0) {
printf("\n");
}
}
printf("\n");
}
/**
* \brief Print the data in host tensor.
* \param tensor The reference of input tensor.
*/
void print_tensor(const Tensor& tensor) {
printf("host tensor data size: %ld\n", tensor.numel());
int64_t size = tensor.numel();
int64_t width = tensor.dims()[tensor.dims().size() - 1];
PrecisionType type = tensor.precision();
switch (type) {
case PRECISION(kInt8):
print_tensor_host_impl(tensor.data<int8_t>(), size, width);
break;
case PRECISION(kInt32):
print_tensor_host_impl(tensor.data<int>(), size, width);
break;
case PRECISION(kFloat):
print_tensor_host_impl(tensor.data<float>(), size, width);
break;
default:
LOG(FATAL) << "data type: " << PrecisionRepr(type)
<< " is unsupported now";
}
}
template <typename Dtype>
double tensor_mean_value_host_impl(const Dtype* din, int64_t size) {
double sum = 0.0;
for (int64_t i = 0; i < size; ++i) {
sum += din[i];
}
return sum / size;
}
double tensor_mean(const Tensor& tensor) {
int64_t size = tensor.numel();
PrecisionType type = tensor.precision();
switch (type) {
case PRECISION(kInt8):
return tensor_mean_value_host_impl(tensor.data<int8_t>(), size);
case PRECISION(kInt32):
return tensor_mean_value_host_impl(tensor.data<int>(), size);
case PRECISION(kFloat):
return tensor_mean_value_host_impl(tensor.data<float>(), size);
default:
LOG(FATAL) << "data type: " << PrecisionRepr(type)
<< " is unsupported now";
}
return 0.0;
}
template <typename dtype>
void data_diff_kernel(const dtype* src1_truth,
const dtype* src2,
int size,
double& max_ratio, // NOLINT
double& max_diff) { // NOLINT
const double eps = 1e-6f;
max_diff = fabs(src1_truth[0] - src2[0]);
max_ratio = fabs(max_diff) / (std::abs(src1_truth[0]) + eps);
for (int i = 1; i < size; ++i) {
double diff = fabs(src1_truth[i] - src2[i]);
double ratio = fabs(diff) / (std::abs(src1_truth[i]) + eps);
if (max_ratio < ratio) {
max_diff = diff;
max_ratio = ratio;
}
}
}
void tensor_cmp_host(const Tensor& src1_basic,
const Tensor& src2,
double& max_ratio, // NOLINT
double& max_diff) { // NOLINT
max_ratio = 0.;
max_diff = 0.;
int64_t size = src1_basic.numel();
CHECK_EQ(size, src2.numel()) << "ERROR: tensor_cmp_host: wrong shape";
auto ptype1 = PrecisionRepr(src1_basic.precision());
auto ptype2 = PrecisionRepr(src2.precision());
CHECK_EQ(ptype1, ptype2) << "ERROR: tensor_cmp_host: wrong data type";
if (size == 0) return;
switch (src1_basic.precision()) {
case PRECISION(kFloat):
data_diff_kernel(src1_basic.data<float>(),
src2.data<float>(),
size,
max_ratio,
max_diff);
return;
case PRECISION(kInt32):
data_diff_kernel(
src1_basic.data<int>(), src2.data<int>(), size, max_ratio, max_diff);
return;
case PRECISION(kInt8):
data_diff_kernel(src1_basic.data<int8_t>(),
src2.data<int8_t>(),
size,
max_ratio,
max_diff);
return;
default:
LOG(FATAL) << "data type: " << PrecisionRepr(src1_basic.precision())
<< " is unsupported now";
}
}
template <typename dtype>
void tensor_diff_kernel(const dtype* src1,
const dtype* src2,
dtype* dst,
int64_t size) {
for (int i = 0; i < size; ++i) {
dst[i] = src1[i] - src2[i];
}
}
void tensor_diff(const Tensor& t1, const Tensor& t2, Tensor& tdiff) { // NOLINT
int64_t size1 = t1.numel();
int64_t size2 = t2.numel();
int64_t size_out = tdiff.numel();
CHECK_EQ(size1, size2) << "ERROR: tensor_diff: wrong shape";
CHECK_EQ(size1, size_out) << "ERROR: tensor_diff: wrong shape";
auto ptype1 = PrecisionRepr(t1.precision());
auto ptype2 = PrecisionRepr(t2.precision());
auto ptype3 = PrecisionRepr(tdiff.precision());
CHECK_EQ(ptype1, ptype2) << "ERROR: tensor_diff: wrong data type";
CHECK_EQ(ptype1, ptype3) << "ERROR: tensor_diff: wrong data type";
switch (t1.precision()) {
case PRECISION(kFloat):
tensor_diff_kernel(t1.data<float>(),
t2.data<float>(),
tdiff.mutable_data<float>(),
size1);
return;
case PRECISION(kInt32):
  tensor_diff_kernel(
      t1.data<int>(), t2.data<int>(), tdiff.mutable_data<int>(), size1);
  return;
case PRECISION(kInt8):
tensor_diff_kernel(t1.data<int8_t>(),
t2.data<int8_t>(),
tdiff.mutable_data<int8_t>(),
size1);
return;
default:
LOG(FATAL) << "data type: " << ptype1 << " is unsupported now";
}
}
} // namespace lite
} // namespace paddle
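For context, a minimal usage sketch of the helpers above (the include path and the Tensor::CopyDataFrom call are assumptions for illustration, not part of this patch):

#include <cstdio>
#include "lite/tests/utils/tensor_utils.h"  // assumed path of the header above

int main() {
  paddle::lite::Tensor ta, tb;
  ta.set_precision(PRECISION(kFloat));
  tb.set_precision(PRECISION(kFloat));
  ta.Resize({1, 1, 4, 8});
  tb.Resize({1, 1, 4, 8});
  paddle::lite::fill_tensor_rand(ta, -1.f, 1.f);  // uniform values in [-1, 1)
  tb.CopyDataFrom(ta);                            // assumed Tensor API
  double max_ratio = 0.0;
  double max_diff = 0.0;
  paddle::lite::tensor_cmp_host(ta, tb, max_ratio, max_diff);
  // identical buffers, so both metrics should come out as 0
  printf("max diff: %f, max ratio: %f\n", max_diff, max_ratio);
  return 0;
}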
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <chrono> // NOLINT
#include <list>
namespace lite {
namespace test {
class Timer final {
public:
Timer() {}
~Timer() {}
void clear() { ms_time_.clear(); }
void start() { tstart_ = std::chrono::system_clock::now(); }
void end() {
tend_ = std::chrono::system_clock::now();
auto ts =
std::chrono::duration_cast<std::chrono::microseconds>(tend_ - tstart_);
float elapse_ms = 1000.f * static_cast<float>(ts.count()) *
std::chrono::microseconds::period::num /
std::chrono::microseconds::period::den;
ms_time_.push_back(elapse_ms);
}
float get_average_ms() {
if (ms_time_.size() == 0) {
return 0.f;
}
float sum = 0.f;
for (auto i : ms_time_) {
sum += i;
}
return sum / ms_time_.size();
}
float get_sum_ms() {
if (ms_time_.size() == 0) {
return 0.f;
}
float sum = 0.f;
for (auto i : ms_time_) {
sum += i;
}
return sum;
}
// return the tile-th percentile time (tile in [0, 100]).
float get_tile_time(float tile) {
  if (tile < 0 || tile > 100) {
    return -1.f;
  }
  int total_items = static_cast<int>(ms_time_.size());
  if (total_items <= 0) {
    return -2.f;
  }
  ms_time_.sort();
  int pos = static_cast<int>(tile * total_items / 100);
  if (pos >= total_items) {
    pos = total_items - 1;  // clamp so tile == 100 stays in range
  }
auto it = ms_time_.begin();
for (int i = 0; i < pos; ++i) {
++it;
}
return *it;
}
std::list<float> get_time_stat() { return ms_time_; }
float get_min_time() {
ms_time_.sort();
return *ms_time_.begin();
}
float get_max_time() {
ms_time_.sort([](float a, float b) { return a > b; });
return *ms_time_.begin();
}
private:
std::chrono::time_point<std::chrono::system_clock> tstart_;
std::chrono::time_point<std::chrono::system_clock> tend_;
std::list<float> ms_time_;
};
} // namespace test
} // namespace lite
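A sketch of how this Timer is typically driven by the benchmarks above (run_kernel and the header path are placeholders, not part of this patch):

#include <cstdio>
#include "lite/tests/utils/timer.h"  // assumed path of the header above

void run_kernel();  // placeholder for the workload being timed

void benchmark(int warmup, int repeats) {
  lite::test::Timer t0;
  for (int i = 0; i < warmup; ++i) {
    run_kernel();  // warm up caches and CPU clocks; not timed
  }
  for (int i = 0; i < repeats; ++i) {
    t0.start();
    run_kernel();
    t0.end();  // appends one elapsed-milliseconds sample
  }
  printf("avg: %f ms, min: %f ms, p90: %f ms\n",
         t0.get_average_ms(),
         t0.get_min_time(),
         t0.get_tile_time(90.f));
}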
...@@ -15,6 +15,7 @@ readonly NUM_PROC=${LITE_BUILD_THREADS:-4}
# global variables
BUILD_EXTRA=OFF
BUILD_JAVA=ON
BUILD_DIR=$(pwd)

readonly THIRDPARTY_TAR=https://paddle-inference-dist.bj.bcebos.com/PaddleLite/third-party-05b862.tar.gz
...@@ -23,16 +24,18 @@ readonly workspace=$PWD
# for code gen, a source file is generated after a test, but is depended on by some targets in cmake.
# here we fake an empty file to make cmake work.
function prepare_workspace {
    local root_dir=$1
    local build_dir=$2
    # in build directory
    # 1. Prepare gen_code file
    GEN_CODE_PATH_PREFIX=$build_dir/lite/gen_code
    mkdir -p ${GEN_CODE_PATH_PREFIX}
    touch ${GEN_CODE_PATH_PREFIX}/__generated_code__.cc
    # 2. Prepare debug tool
    DEBUG_TOOL_PATH_PREFIX=$build_dir/lite/tools/debug
    mkdir -p ${DEBUG_TOOL_PATH_PREFIX}
    cp $root_dir/lite/tools/debug/analysis_tool.py ${DEBUG_TOOL_PATH_PREFIX}/
}

function prepare_thirdparty {
...@@ -98,21 +101,22 @@ function make_full_publish_so {
    #git submodule update --init --recursive
    prepare_thirdparty

    root_dir=$(pwd)
    build_directory=$BUILD_DIR/build.lite.${os}.${abi}.${lang}

    if [ -d $build_directory ]
    then
        rm -rf $build_directory
    fi
    mkdir -p $build_directory
    cd $build_directory

    if [ ${os} == "armlinux" ]; then
        BUILD_JAVA=OFF
    fi

    prepare_workspace $root_dir $build_directory
    cmake $root_dir \
        ${CMAKE_COMMON_OPTIONS} \
        -DWITH_TESTING=OFF \
        -DLITE_WITH_JAVA=$BUILD_JAVA \
...@@ -132,23 +136,23 @@ function make_all_tests {
    #git submodule update --init --recursive
    prepare_thirdparty

    root_dir=$(pwd)
    build_directory=$BUILD_DIR/build.lite.${os}.${abi}.${lang}
    if [ -d $build_directory ]
    then
        rm -rf $build_directory
    fi
    mkdir -p $build_directory
    cd $build_directory

    prepare_workspace $root_dir $build_directory
    cmake $root_dir \
        ${CMAKE_COMMON_OPTIONS} \
        -DWITH_TESTING=ON \
        -DLITE_BUILD_EXTRA=$BUILD_EXTRA \
        -DARM_TARGET_OS=${os} -DARM_TARGET_ARCH_ABI=${abi} -DARM_TARGET_LANG=${lang}
    make lite_compile_deps -j$NUM_PROC
    cd - > /dev/null
}
...@@ -207,6 +211,7 @@ function print_usage {
    echo
    echo -e "optional argument:"
    echo -e "--build_extra: (OFF|ON); controls whether to publish extra operators and kernels for (sequence-related model such as OCR or NLP)"
    echo -e "--build_dir: directory for building"
    echo
    echo -e "argument choices:"
    echo -e "--arm_os:\t android|ios|ios64"
...@@ -252,6 +257,10 @@ function main {
            BUILD_EXTRA="${i#*=}"
            shift
            ;;
        --build_dir=*)
            BUILD_DIR="${i#*=}"
            shift
            ;;
        tiny_publish)
            make_tiny_publish_so $ARM_OS $ARM_ABI $ARM_LANG $ANDROID_STL
            shift
......
# -*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import sys
import os
import re
def compute_sdot_vec_vec(vd, vn, vm):
    i = 0x4e809400 | int(vd) | (int(vn) << 5) | (int(vm) << 16)
    return '".word 0x{:08x}\\n"'.format(i) + \
        ' /* sdot v{vd}.4s, v{vn}.16b, v{vm}.16b */'.format(
            vd=vd, vn=vn, vm=vm)


def compute_sdot_vec_elem(vd, vn, vm, idx):
    i = 0x4f80e000 | int(vd) | (int(vn) << 5) | (int(vm) << 16) | \
        (int(idx % 2) << 21) | (int(idx / 2) << 11)
    return '".word 0x{:08x}\\n"'.format(i) + \
        ' /* sdot v{vd}.4s, v{vn}.16b, v{vm}.4b[{idx}] */\\\r\n'.format(
            vd=vd, vn=vn, vm=vm, idx=idx)


def match_sdot_pattern(line):
    matched = re.search(
        r'sdot\s+v(.*?).4s\s*,\s*v(.*?).16b\s*,\s*v(.*?).4b\[(.*?)\].*',
        line, re.M | re.I)
    if matched:
        vd = int(matched.group(1))
        vn = int(matched.group(2))
        vm = int(matched.group(3))
        idx = int(matched.group(4))
        return compute_sdot_vec_elem(vd, vn, vm, idx)
    else:
        return line


def parse_file(file_in, file_out):
    out = open(file_out, 'w')
    if os.path.exists(file_in):
        for line in open(file_in):
            out.write(match_sdot_pattern(line))
    else:
        print('input file {} does not exist'.format(file_in))


if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser('convert arm sdot to machine code')
    arg_parser.add_argument('--input_file', type=str, required=True)
    arg_parser.add_argument('--output_file', type=str, required=True)
    args = arg_parser.parse_args()
    print('input file: ', args.input_file)
    print('output file: ', args.output_file)
    parse_file(args.input_file, args.output_file)
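As a sanity check on the lane-indexed encoding above, the same bit layout can be reproduced in a few lines of C++ (a sketch mirroring compute_sdot_vec_elem; the register numbers are arbitrary examples):

#include <cstdint>
#include <cstdio>

// sdot v{vd}.4s, v{vn}.16b, v{vm}.4b[idx]: base opcode 0x4f80e000 with vd in
// bits 0-4, vn in bits 5-9, vm in bits 16-20, and the lane index split across
// bit 21 (idx % 2) and bit 11 (idx / 2), as in the Python script above.
uint32_t encode_sdot_lane(uint32_t vd, uint32_t vn, uint32_t vm, uint32_t idx) {
  return 0x4f80e000u | vd | (vn << 5) | (vm << 16) | ((idx % 2) << 21) |
         ((idx / 2) << 11);
}

int main() {
  // sdot v0.4s, v1.16b, v2.4b[3] should print ".word 0x4fa2e820"
  printf(".word 0x%08x\n", encode_sdot_lane(0, 1, 2, 3));
  return 0;
}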
...@@ -18,6 +18,7 @@
 */

#include "lite/utils/logging.h"
#include <iomanip>

#if defined(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) || \
    defined(LITE_ON_MODEL_OPTIMIZE_TOOL)
......