提交 f99c34c8 编写于 作者: T TianXiaogang 提交者: yiicy

add winograd f23 implement (#2584)

上级 fbb0d3b5
......@@ -24,29 +24,48 @@ namespace paddle {
namespace lite {
namespace arm {
namespace math {
void input_trans_c4(const float* src,
int src_stride,
float* dest,
int dest_stride);
void output_trans_c4(const float* src,
int src_stride,
float* dest,
int dest_stride);
void output_trans_c4_post(const float* src,
int src_stride,
float* dest,
int dest_stride,
float* bias_value,
bool has_relu);
void weight_trans_c4(
void input_trans_c4_8x8(const float* src,
int src_stride,
float* dest,
int dest_stride);
void output_trans_c4_6x8(const float* src,
int src_stride,
float* dest,
int dest_stride);
void output_trans_c4_post_6x8(const float* src,
int src_stride,
float* dest,
int dest_stride,
float* bias_value,
bool has_relu);
void input_trans_c4_4x4(const float* src,
int src_stride,
int src_h_stride,
float* dest,
int dest_stride,
int dest_h_stride);
void output_trans_c4_post_2x4(const float* src,
int src_stride,
int src_h_stride,
float* dest,
int dest_stride,
int dest_h_stride,
float* bias_value,
bool has_relu);
void weight_trans_c4_8x8(
float* dest, const float* src, int ic, int oc, void* workspace);
void weight_trans_c4_4x4(
float* dest, const float* src, int ic, int oc, void* workspace);
/*
*The following function conv_compute_6x6_3x3 is base on
*The following function conv_compute_6x6_3x3 and conv_compute_2x2_3x3[_small] is
*base on
*MNN[https://github.com/alibaba/MNN]
*
*Copyright © 2018, Alibaba Group Holding Limited
*/
// F(6,3)
void conv_compute_6x6_3x3(const float* input,
float* output,
int num,
......@@ -75,11 +94,14 @@ void conv_compute_6x6_3x3(const float* input,
int tile_w = (wout + 5) / 6;
int tile_h = (hout + 5) / 6;
int size_tile = tile_h * tile_w;
float zero_ptr[8];
memset(zero_ptr, 0, 8 * sizeof(float));
int w_pad = win + pad_w * 2;
int h_pad = hin + pad_h * 2;
const int zero_len = w_pad;
float zero_ptr[zero_len]; // NOLINT
memset(zero_ptr, 0, zero_len * sizeof(float));
float* input_c4 = tmp_work_space;
int new_h_stride = w_pad * 4;
int new_c_stride = new_h_stride * h_pad;
......@@ -88,9 +110,6 @@ void conv_compute_6x6_3x3(const float* input,
int oc_4_stride = wout * hout * 4;
int tile_block = 8;
#ifdef __aarch64__
tile_block = 16;
#endif
int block_count = (size_tile + tile_block - 1) / tile_block;
int threads = ctx->threads();
......@@ -102,7 +121,8 @@ void conv_compute_6x6_3x3(const float* input,
// begin compute
for (int ni = 0; ni < num; ++ni) {
// trans input to c4
// trans input to c4
#pragma omp parallel for num_threads(threads)
for (int i = 0; i < ic_4; ++i) {
prepack_input_nxwc4_dw(input + ni * in_n_stride,
input_c4 + i * new_c_stride,
......@@ -161,14 +181,14 @@ void conv_compute_6x6_3x3(const float* input,
const float* src_ci = src_ptr + ci * ic_4_stride;
for (int i = 0; i < 8; ++i) {
const float* ci_ptr = src_ci + i * w_pad * 4;
input_trans_c4(ci_ptr, 4, trans_tmp_data + i * 4, 32);
input_trans_c4_8x8(ci_ptr, 4, trans_tmp_data + i * 4, 32);
}
float* dst_ci = dst_ptr + ci * tile_count * 4;
for (int i = 0; i < 8; ++i) {
input_trans_c4(trans_tmp_data + i * 32,
4,
dst_ci + i * b_gi_stride * 8,
b_gi_stride);
input_trans_c4_8x8(trans_tmp_data + i * 32,
4,
dst_ci + i * b_gi_stride * 8,
b_gi_stride);
}
}
} else {
......@@ -189,14 +209,14 @@ void conv_compute_6x6_3x3(const float* input,
// trans
for (int i = 0; i < 8; ++i) {
float* ci_ptr = trans_remain_tmp_data + i * 32;
input_trans_c4(ci_ptr, 4, trans_tmp_data + i * 4, 32);
input_trans_c4_8x8(ci_ptr, 4, trans_tmp_data + i * 4, 32);
}
float* dst_ci = dst_ptr + ci * tile_count * 4;
for (int i = 0; i < 8; ++i) {
input_trans_c4(trans_tmp_data + i * 32,
4,
dst_ci + i * b_gi_stride * 8,
b_gi_stride);
input_trans_c4_8x8(trans_tmp_data + i * 32,
4,
dst_ci + i * b_gi_stride * 8,
b_gi_stride);
}
} // for ci_4
}
......@@ -213,16 +233,8 @@ void conv_compute_6x6_3x3(const float* input,
float* origin_C = dst_temp_data + gi * c_gi_stride;
float* origin_B = b_ptr + gi * b_gi_stride;
const float* origin_A = weight + gi * w_gi_stride;
sgemm_prepack_c4_small(oc_4 * 4,
tile_count,
ic_4 * 4,
origin_A,
origin_B,
origin_C,
nullptr,
false,
false,
ctx);
sgemm_prepack_c4_small(
oc_4 * 4, tile_count, ic_4 * 4, origin_A, origin_B, origin_C, ctx);
}
//*/
//*
......@@ -258,18 +270,18 @@ void conv_compute_6x6_3x3(const float* input,
float* dst_ci = dst_ptr + ci * oc_4_stride;
float* src_ci = src_ptr + ci * tile_count * 4;
for (int i = 0; i < 8; ++i) {
output_trans_c4(src_ci + i * c_gi_stride * 8,
c_gi_stride,
trans_tmp_data + i * 4,
32);
output_trans_c4_6x8(src_ci + i * c_gi_stride * 8,
c_gi_stride,
trans_tmp_data + i * 4,
32);
}
for (int i = 0; i < ey; ++i) {
output_trans_c4_post(trans_tmp_data + i * 32,
4,
trans_remain_tmp_data + i * 24,
4,
bias_value,
param.fuse_relu);
output_trans_c4_post_6x8(trans_tmp_data + i * 32,
4,
trans_remain_tmp_data + i * 24,
4,
bias_value,
param.fuse_relu);
}
write_to_output_c4_fp32(trans_remain_tmp_data,
output_ptr,
......@@ -297,18 +309,18 @@ void conv_compute_6x6_3x3(const float* input,
float* dst_ci = dst_ptr + ci * oc_4_stride;
float* src_ci = src_ptr + ci * tile_count * 4;
for (int i = 0; i < 8; ++i) {
output_trans_c4(src_ci + i * c_gi_stride * 8,
c_gi_stride,
trans_tmp_data + i * 4,
32);
output_trans_c4_6x8(src_ci + i * c_gi_stride * 8,
c_gi_stride,
trans_tmp_data + i * 4,
32);
}
for (int i = 0; i < ey; ++i) {
output_trans_c4_post(trans_tmp_data + i * 32,
4,
trans_remain_tmp_data + i * 24,
4,
bias_value,
param.fuse_relu);
output_trans_c4_post_6x8(trans_tmp_data + i * 32,
4,
trans_remain_tmp_data + i * 24,
4,
bias_value,
param.fuse_relu);
}
// copy to dest
memset(trans_tmp_data, 0, 144 * sizeof(float));
......@@ -338,10 +350,522 @@ void conv_compute_6x6_3x3(const float* input,
} // for num
} // conv_compute
void output_trans_c4(const float* src,
int src_stride,
float* dest,
int dest_stride) {
// F(2,3)
void conv_compute_2x2_3x3(const float* input,
float* output,
int num,
int chout,
int hout,
int wout,
int chin,
int hin,
int win,
const float* weight,
const float* bias,
const operators::ConvParam& param,
ARMContext* ctx) {
const int pad_h = (*param.paddings)[0];
const int pad_w = (*param.paddings)[2];
float* tmp_work_space =
ctx->workspace_data<float>() + ctx->llc_size() / sizeof(float);
int in_n_stride = chin * hin * win;
int out_n_stride = chout * hout * wout;
int ic_stride = win * hin;
int oc_stride = wout * hout;
int ic_4 = (chin + 3) / 4;
int oc_4 = (chout + 3) / 4;
int tile_w = (wout + 1) / 2;
int tile_h = (hout + 1) / 2;
int size_tile = tile_h * tile_w;
int w_pad = win + pad_w * 2;
int h_pad = hin + pad_h * 2;
const int zero_len = w_pad;
float zero_ptr[zero_len]; // NOLINT
memset(zero_ptr, 0, zero_len * sizeof(float));
float* input_c4 = tmp_work_space;
int new_h_stride = w_pad * 4;
int new_c_stride = new_h_stride * h_pad;
int ic_4_stride = w_pad * h_pad * 4;
int oc_4_stride = wout * hout * 4;
int tile_block = 8;
int block_count = (size_tile + tile_block - 1) / tile_block;
int threads = ctx->threads();
float* g_tmp_data = tmp_work_space + ic_4 * new_c_stride;
int tmp_data_thread_stride = tile_block * (oc_4 + ic_4) * 64;
memset(g_tmp_data, 0, threads * tmp_data_thread_stride * sizeof(float));
float* g_trans_tmp_data = g_tmp_data + threads * tmp_data_thread_stride;
float* g_trans_remain_tmp_data = g_trans_tmp_data + threads * 64;
// begin compute
for (int ni = 0; ni < num; ++ni) {
// trans input to c4
#pragma omp parallel for num_threads(threads)
for (int i = 0; i < ic_4; ++i) {
prepack_input_nxwc4_dw(input + ni * in_n_stride,
input_c4 + i * new_c_stride,
i * 4,
-pad_h,
hin + pad_h,
-pad_w,
win + pad_w,
chin,
win,
hin,
zero_ptr);
}
float* output_ptr = output + ni * out_n_stride;
const float* weight_ptr = weight;
const float* bias_ptr = bias;
#pragma omp parallel for num_threads(threads)
for (int tbi = 0; tbi < block_count; ++tbi) {
#ifdef ARM_WITH_OMP
float* tmp_data =
g_tmp_data + omp_get_thread_num() * tmp_data_thread_stride;
float* trans_tmp_data = g_trans_tmp_data + omp_get_thread_num() * 64;
float* trans_remain_tmp_data =
g_trans_remain_tmp_data + omp_get_thread_num() * 64;
#else
float* tmp_data = g_tmp_data;
float* trans_tmp_data = g_trans_tmp_data;
float* trans_remain_tmp_data = g_trans_remain_tmp_data;
#endif
int tile_index = tbi * tile_block;
int tile_remain = size_tile - tile_index;
int tile_count = tile_remain > tile_block ? tile_block : tile_remain;
// input trans
int c_gi_stride = tile_count * oc_4 * 4;
int b_gi_stride = tile_count * ic_4 * 4;
//*
for (int ti = 0; ti < tile_count; ++ti) {
int index = tile_index + ti;
int tw_index = index % tile_w;
int th_index = index / tile_w;
int src_x = tw_index + tw_index;
int src_y = th_index + th_index;
int ex = src_x + 4 > w_pad ? w_pad - src_x : 4;
int ey = src_y + 4 > h_pad ? h_pad - src_y : 4;
float* dst_ptr = tmp_data + ti * 4;
const float* src_ptr = input_c4 + (src_y * w_pad + src_x) * 4;
if (ex == 4 && ey == 4) {
// trans input
for (int ci = 0; ci < ic_4; ++ci) {
const float* src_ci = src_ptr + ci * ic_4_stride;
float* dst_ci = dst_ptr + ci * tile_count * 4;
input_trans_c4_4x4(
src_ci, 4, w_pad * 4, dst_ci, b_gi_stride, b_gi_stride * 4);
}
} else {
// trans remain input
int x_size = ex;
for (int ci = 0; ci < ic_4; ++ci) {
const float* src_ci = src_ptr + ci * ic_4_stride;
// pad
memset(trans_remain_tmp_data, 0, 64 * sizeof(float));
if (x_size > 0) {
for (int yi = 0; yi < ey; ++yi) {
float* dst_yi = trans_remain_tmp_data + yi * 16;
const float* src_yi = src_ci + w_pad * yi * 4;
memcpy(dst_yi, src_yi, x_size * sizeof(float) * 4);
}
}
// trans
float* dst_ci = dst_ptr + ci * tile_count * 4;
input_trans_c4_4x4(trans_remain_tmp_data,
4,
16,
dst_ci,
b_gi_stride,
b_gi_stride * 4);
} // for ci_4
}
}
//*/
// input trans end
// *begin compute dot
// *
//*
float* dst_temp_data = tmp_data + tile_block * ic_4 * 64;
float* b_ptr = tmp_data;
int w_gi_stride = ic_4 * oc_4 * 16;
for (int gi = 0; gi < 16; ++gi) {
float* origin_C = dst_temp_data + gi * c_gi_stride;
float* origin_B = b_ptr + gi * b_gi_stride;
const float* origin_A = weight + gi * w_gi_stride;
sgemm_prepack_c4_small(
oc_4 * 4, tile_count, ic_4 * 4, origin_A, origin_B, origin_C, ctx);
}
//*/
//*
// output trans
float bias_value[4];
memset(bias_value, 0, 4 * sizeof(float));
for (int ti = 0; ti < tile_count; ++ti) {
int index = tile_index + ti;
int tw_index = index % tile_w;
int th_index = index / tile_w;
int dst_x = tw_index * 2;
int dst_y = th_index * 2;
int ex = dst_x + 2 > wout ? wout - dst_x : 2;
int ey = dst_y + 2 > hout ? hout - dst_y : 2;
float* dst_ptr = output + (dst_y * wout + dst_x) * 4;
float* src_ptr = dst_temp_data + ti * 4;
if (ex == 2) {
// trans output
for (int ci = 0; ci < oc_4; ++ci) {
if (param.bias) {
bias_value[0] = bias[ci * 4];
bias_value[1] = bias[ci * 4 + 1];
bias_value[2] = bias[ci * 4 + 2];
bias_value[3] = bias[ci * 4 + 3];
}
float* dst_ci = dst_ptr + ci * oc_4_stride;
float* src_ci = src_ptr + ci * tile_count * 4;
output_trans_c4_post_2x4(src_ci,
c_gi_stride,
c_gi_stride * 4,
trans_remain_tmp_data,
4,
8,
bias_value,
param.fuse_relu);
write_to_output_c4_fp32(trans_remain_tmp_data,
output_ptr,
ci * 4,
ci * 4 + 4,
dst_y,
dst_y + ey,
dst_x,
dst_x + ex,
chout,
hout,
wout,
false,
zero_ptr);
}
} else {
for (int ci = 0; ci < oc_4; ++ci) {
if (param.bias) {
bias_value[0] = bias[ci * 4];
bias_value[1] = bias[ci * 4 + 1];
bias_value[2] = bias[ci * 4 + 2];
bias_value[3] = bias[ci * 4 + 3];
}
// trans output
float* dst_ci = dst_ptr + ci * oc_4_stride;
float* src_ci = src_ptr + ci * tile_count * 4;
output_trans_c4_post_2x4(src_ci,
c_gi_stride,
c_gi_stride * 4,
trans_remain_tmp_data,
4,
8,
bias_value,
param.fuse_relu);
// copy to dest
memset(trans_tmp_data, 0, 16 * sizeof(float));
for (int i = 0; i < ey; ++i) {
memcpy(trans_tmp_data + i * ex * 4,
trans_remain_tmp_data + i * 8,
ex * sizeof(float) * 4);
}
write_to_output_c4_fp32(trans_tmp_data,
output_ptr,
ci * 4,
ci * 4 + 4,
dst_y,
dst_y + ey,
dst_x,
dst_x + ex,
chout,
hout,
wout,
false,
zero_ptr);
}
}
}
//*/
} // for block_count
} // for num
} // conv_compute
void conv_compute_2x2_3x3_small(const float* input,
float* output,
int num,
int chout,
int hout,
int wout,
int chin,
int hin,
int win,
const float* weight,
const float* bias,
const operators::ConvParam& param,
ARMContext* ctx) {
const int pad_h = (*param.paddings)[0];
const int pad_w = (*param.paddings)[2];
float* tmp_work_space =
ctx->workspace_data<float>() + ctx->llc_size() / sizeof(float);
int in_n_stride = chin * hin * win;
int out_n_stride = chout * hout * wout;
int ic_stride = win * hin;
int oc_stride = wout * hout;
int ic_4 = (chin + 3) / 4;
int oc_4 = (chout + 3) / 4;
int tile_w = (wout + 1) / 2;
int tile_h = (hout + 1) / 2;
int size_tile = tile_h * tile_w;
int w_pad = win + pad_w * 2;
int h_pad = hin + pad_h * 2;
const int zero_len = w_pad;
float zero_ptr[zero_len]; // NOLINT
memset(zero_ptr, 0, zero_len * sizeof(float));
float* input_c4 = tmp_work_space;
int new_h_stride = w_pad * 4;
int new_c_stride = new_h_stride * h_pad;
int ic_4_stride = w_pad * h_pad * 4;
int oc_4_stride = wout * hout * 4;
int tile_block = 8;
int block_count = (size_tile + tile_block - 1) / tile_block;
int threads = ctx->threads();
float* g_tmp_data = tmp_work_space + ic_4 * new_c_stride;
int tmp_data_thread_stride = tile_block * (oc_4 + ic_4) * 64;
memset(g_tmp_data, 0, tmp_data_thread_stride * sizeof(float));
float* g_trans_tmp_data = g_tmp_data + tmp_data_thread_stride;
float* g_trans_remain_tmp_data = g_trans_tmp_data + 64;
// begin compute
for (int ni = 0; ni < num; ++ni) {
// trans input to c4
#pragma omp parallel for num_threads(threads)
for (int i = 0; i < ic_4; ++i) {
prepack_input_nxwc4_dw(input + ni * in_n_stride,
input_c4 + i * new_c_stride,
i * 4,
-pad_h,
hin + pad_h,
-pad_w,
win + pad_w,
chin,
win,
hin,
zero_ptr);
}
float* output_ptr = output + ni * out_n_stride;
const float* weight_ptr = weight;
const float* bias_ptr = bias;
for (int tbi = 0; tbi < block_count; ++tbi) {
float* tmp_data = g_tmp_data;
float* trans_tmp_data = g_trans_tmp_data;
float* trans_remain_tmp_data = g_trans_remain_tmp_data;
int tile_index = tbi * tile_block;
int tile_remain = size_tile - tile_index;
int tile_count = tile_remain > tile_block ? tile_block : tile_remain;
// input trans
int c_gi_stride = tile_count * oc_4 * 4;
int b_gi_stride = tile_count * ic_4 * 4;
//*
for (int ti = 0; ti < tile_count; ++ti) {
int index = tile_index + ti;
int tw_index = index % tile_w;
int th_index = index / tile_w;
int src_x = tw_index + tw_index;
int src_y = th_index + th_index;
int ex = src_x + 4 > w_pad ? w_pad - src_x : 4;
int ey = src_y + 4 > h_pad ? h_pad - src_y : 4;
float* dst_ptr = tmp_data + ti * 4;
const float* src_ptr = input_c4 + (src_y * w_pad + src_x) * 4;
if (ex == 4 && ey == 4) {
// trans input
for (int ci = 0; ci < ic_4; ++ci) {
const float* src_ci = src_ptr + ci * ic_4_stride;
float* dst_ci = dst_ptr + ci * tile_count * 4;
input_trans_c4_4x4(
src_ci, 4, w_pad * 4, dst_ci, b_gi_stride, b_gi_stride * 4);
}
} else {
// trans remain input
int x_size = ex;
for (int ci = 0; ci < ic_4; ++ci) {
const float* src_ci = src_ptr + ci * ic_4_stride;
// pad
memset(trans_remain_tmp_data, 0, 64 * sizeof(float));
if (x_size > 0) {
for (int yi = 0; yi < ey; ++yi) {
float* dst_yi = trans_remain_tmp_data + yi * 16;
const float* src_yi = src_ci + w_pad * yi * 4;
memcpy(dst_yi, src_yi, x_size * sizeof(float) * 4);
}
}
float* dst_ci = dst_ptr + ci * tile_count * 4;
input_trans_c4_4x4(trans_remain_tmp_data,
4,
16,
dst_ci,
b_gi_stride,
b_gi_stride * 4);
} // for ci_4
}
}
//*/
// input trans end
// *begin compute dot
// *
//*
float* dst_temp_data = tmp_data + tile_block * ic_4 * 64;
float* b_ptr = tmp_data;
int w_gi_stride = ic_4 * oc_4 * 16;
#pragma omp parallel for num_threads(threads)
for (int gi = 0; gi < 16; ++gi) {
float* origin_C = dst_temp_data + gi * c_gi_stride;
float* origin_B = b_ptr + gi * b_gi_stride;
const float* origin_A = weight + gi * w_gi_stride;
sgemm_prepack_c4_small(
oc_4 * 4, tile_count, ic_4 * 4, origin_A, origin_B, origin_C, ctx);
}
//*/
//*
// output trans
float bias_value[4];
memset(bias_value, 0, 4 * sizeof(float));
for (int ti = 0; ti < tile_count; ++ti) {
int index = tile_index + ti;
int tw_index = index % tile_w;
int th_index = index / tile_w;
int dst_x = tw_index * 2;
int dst_y = th_index * 2;
int ex = dst_x + 2 > wout ? wout - dst_x : 2;
int ey = dst_y + 2 > hout ? hout - dst_y : 2;
float* dst_ptr = output + (dst_y * wout + dst_x) * 4;
float* src_ptr = dst_temp_data + ti * 4;
if (ex == 2) {
// trans output
for (int ci = 0; ci < oc_4; ++ci) {
if (param.bias) {
bias_value[0] = bias[ci * 4];
bias_value[1] = bias[ci * 4 + 1];
bias_value[2] = bias[ci * 4 + 2];
bias_value[3] = bias[ci * 4 + 3];
}
float* dst_ci = dst_ptr + ci * oc_4_stride;
float* src_ci = src_ptr + ci * tile_count * 4;
output_trans_c4_post_2x4(src_ci,
c_gi_stride,
c_gi_stride * 4,
trans_remain_tmp_data,
4,
8,
bias_value,
param.fuse_relu);
write_to_output_c4_fp32(trans_remain_tmp_data,
output_ptr,
ci * 4,
ci * 4 + 4,
dst_y,
dst_y + ey,
dst_x,
dst_x + ex,
chout,
hout,
wout,
false,
zero_ptr);
}
} else {
for (int ci = 0; ci < oc_4; ++ci) {
if (param.bias) {
bias_value[0] = bias[ci * 4];
bias_value[1] = bias[ci * 4 + 1];
bias_value[2] = bias[ci * 4 + 2];
bias_value[3] = bias[ci * 4 + 3];
}
// trans output
float* dst_ci = dst_ptr + ci * oc_4_stride;
float* src_ci = src_ptr + ci * tile_count * 4;
output_trans_c4_post_2x4(src_ci,
c_gi_stride,
c_gi_stride * 4,
trans_remain_tmp_data,
4,
8,
bias_value,
param.fuse_relu);
// copy to dest
memset(trans_tmp_data, 0, 16 * sizeof(float));
for (int i = 0; i < ey; ++i) {
memcpy(trans_tmp_data + i * ex * 4,
trans_remain_tmp_data + i * 8,
ex * sizeof(float) * 4);
}
write_to_output_c4_fp32(trans_tmp_data,
output_ptr,
ci * 4,
ci * 4 + 4,
dst_y,
dst_y + ey,
dst_x,
dst_x + ex,
chout,
hout,
wout,
false,
zero_ptr);
}
}
}
//*/
} // for block_count
} // for num
} // conv_compute
void output_trans_c4_6x8(const float* src,
int src_stride,
float* dest,
int dest_stride) {
const float32x4_t src0 = vld1q_f32(src);
const float32x4_t src1 = vld1q_f32(src + src_stride);
const float32x4_t src2 = vld1q_f32(src + src_stride * 2);
......@@ -381,12 +905,13 @@ void output_trans_c4(const float* src,
vst1q_f32(dest + dest_stride * 4, dest4);
vst1q_f32(dest + dest_stride * 5, dest5);
}
void output_trans_c4_post(const float* src,
int src_stride,
float* dest,
int dest_stride,
float* bias_value,
bool has_relu = false) {
void output_trans_c4_post_6x8(const float* src,
int src_stride,
float* dest,
int dest_stride,
float* bias_value,
bool has_relu = false) {
const float32x4_t src0 = vld1q_f32(src);
const float32x4_t src1 = vld1q_f32(src + src_stride);
const float32x4_t src2 = vld1q_f32(src + src_stride * 2);
......@@ -447,10 +972,10 @@ void output_trans_c4_post(const float* src,
vst1q_f32(dest + dest_stride * 5, dest5);
}
void input_trans_c4(const float* src,
int src_stride,
float* dest,
int dest_stride) {
void input_trans_c4_8x8(const float* src,
int src_stride,
float* dest,
int dest_stride) {
float32x4_t src0 = vld1q_f32(src);
float32x4_t src1 = vld1q_f32(src + src_stride);
float32x4_t src2 = vld1q_f32(src + src_stride * 2);
......@@ -497,7 +1022,165 @@ void input_trans_c4(const float* src,
vst1q_f32(dest + dest_stride * 6, dst6);
vst1q_f32(dest + dest_stride * 7, dst7);
}
void weight_trans_c4(
// BT=[1, 0, -1, 0,
// 0, 1, 1, 0,
// 0, -1, 1, 0,
// 0, 1, 0, -1]
void input_trans_c4_4x4(const float* src,
int src_stride,
int src_h_stride,
float* dest,
int dest_stride,
int dest_h_stride) {
float32x4_t src00 = vld1q_f32(src);
float32x4_t src01 = vld1q_f32(src + src_stride);
float32x4_t src02 = vld1q_f32(src + src_stride + src_stride);
float32x4_t src03 = vld1q_f32(src + src_stride + src_stride + src_stride);
src += src_h_stride;
float32x4_t src10 = vld1q_f32(src);
float32x4_t src11 = vld1q_f32(src + src_stride);
float32x4_t src12 = vld1q_f32(src + src_stride + src_stride);
float32x4_t src13 = vld1q_f32(src + src_stride + src_stride + src_stride);
src += src_h_stride;
float32x4_t src20 = vld1q_f32(src);
float32x4_t src21 = vld1q_f32(src + src_stride);
float32x4_t src22 = vld1q_f32(src + src_stride + src_stride);
float32x4_t src23 = vld1q_f32(src + src_stride + src_stride + src_stride);
src += src_h_stride;
float32x4_t src30 = vld1q_f32(src);
float32x4_t src31 = vld1q_f32(src + src_stride);
float32x4_t src32 = vld1q_f32(src + src_stride + src_stride);
float32x4_t src33 = vld1q_f32(src + src_stride + src_stride + src_stride);
float32x4_t dst00 = vsubq_f32(src00, src02);
float32x4_t dst10 = vaddq_f32(src01, src02);
float32x4_t dst20 = vsubq_f32(src02, src01);
float32x4_t dst30 = vsubq_f32(src01, src03);
float32x4_t dst01 = vsubq_f32(src10, src12);
float32x4_t dst11 = vaddq_f32(src11, src12);
float32x4_t dst21 = vsubq_f32(src12, src11);
float32x4_t dst31 = vsubq_f32(src11, src13);
float32x4_t dst02 = vsubq_f32(src20, src22);
float32x4_t dst12 = vaddq_f32(src21, src22);
float32x4_t dst22 = vsubq_f32(src22, src21);
float32x4_t dst32 = vsubq_f32(src21, src23);
float32x4_t dst03 = vsubq_f32(src30, src32);
float32x4_t dst13 = vaddq_f32(src31, src32);
float32x4_t dst23 = vsubq_f32(src32, src31);
float32x4_t dst33 = vsubq_f32(src31, src33);
float32x4_t dest00 = vsubq_f32(dst00, dst02);
float32x4_t dest10 = vaddq_f32(dst01, dst02);
float32x4_t dest20 = vsubq_f32(dst02, dst01);
float32x4_t dest30 = vsubq_f32(dst01, dst03);
float32x4_t dest01 = vsubq_f32(dst10, dst12);
float32x4_t dest11 = vaddq_f32(dst11, dst12);
float32x4_t dest21 = vsubq_f32(dst12, dst11);
float32x4_t dest31 = vsubq_f32(dst11, dst13);
float32x4_t dest02 = vsubq_f32(dst20, dst22);
float32x4_t dest12 = vaddq_f32(dst21, dst22);
float32x4_t dest22 = vsubq_f32(dst22, dst21);
float32x4_t dest32 = vsubq_f32(dst21, dst23);
float32x4_t dest03 = vsubq_f32(dst30, dst32);
float32x4_t dest13 = vaddq_f32(dst31, dst32);
float32x4_t dest23 = vsubq_f32(dst32, dst31);
float32x4_t dest33 = vsubq_f32(dst31, dst33);
vst1q_f32(dest, dest00);
vst1q_f32(dest + dest_stride, dest10);
vst1q_f32(dest + dest_stride + dest_stride, dest20);
vst1q_f32(dest + dest_stride + dest_stride + dest_stride, dest30);
dest += dest_h_stride;
vst1q_f32(dest, dest01);
vst1q_f32(dest + dest_stride, dest11);
vst1q_f32(dest + dest_stride + dest_stride, dest21);
vst1q_f32(dest + dest_stride + dest_stride + dest_stride, dest31);
dest += dest_h_stride;
vst1q_f32(dest, dest02);
vst1q_f32(dest + dest_stride, dest12);
vst1q_f32(dest + dest_stride + dest_stride, dest22);
vst1q_f32(dest + dest_stride + dest_stride + dest_stride, dest32);
dest += dest_h_stride;
vst1q_f32(dest, dest03);
vst1q_f32(dest + dest_stride, dest13);
vst1q_f32(dest + dest_stride + dest_stride, dest23);
vst1q_f32(dest + dest_stride + dest_stride + dest_stride, dest33);
}
// AT=[1, 1, 1, 0,
// 0, 1, -1, -1]
void output_trans_c4_post_2x4(const float* src,
int src_stride,
int src_h_stride,
float* dest,
int dest_stride,
int dest_h_stride,
float* bias_value,
bool has_relu) {
float32x4_t src00 = vld1q_f32(src);
float32x4_t src01 = vld1q_f32(src + src_stride);
float32x4_t src02 = vld1q_f32(src + src_stride + src_stride);
float32x4_t src03 = vld1q_f32(src + src_stride + src_stride + src_stride);
src += src_h_stride;
float32x4_t src10 = vld1q_f32(src);
float32x4_t src11 = vld1q_f32(src + src_stride);
float32x4_t src12 = vld1q_f32(src + src_stride + src_stride);
float32x4_t src13 = vld1q_f32(src + src_stride + src_stride + src_stride);
src += src_h_stride;
float32x4_t src20 = vld1q_f32(src);
float32x4_t src21 = vld1q_f32(src + src_stride);
float32x4_t src22 = vld1q_f32(src + src_stride + src_stride);
float32x4_t src23 = vld1q_f32(src + src_stride + src_stride + src_stride);
src += src_h_stride;
float32x4_t src30 = vld1q_f32(src);
float32x4_t src31 = vld1q_f32(src + src_stride);
float32x4_t src32 = vld1q_f32(src + src_stride + src_stride);
float32x4_t src33 = vld1q_f32(src + src_stride + src_stride + src_stride);
float32x4_t dst00 = vaddq_f32(vaddq_f32(src00, src01), src02);
float32x4_t dst10 = vsubq_f32(vsubq_f32(src01, src02), src03);
float32x4_t dst01 = vaddq_f32(vaddq_f32(src10, src11), src12);
float32x4_t dst11 = vsubq_f32(vsubq_f32(src11, src12), src13);
float32x4_t dst02 = vaddq_f32(vaddq_f32(src20, src21), src22);
float32x4_t dst12 = vsubq_f32(vsubq_f32(src21, src22), src23);
float32x4_t dst03 = vaddq_f32(vaddq_f32(src30, src31), src32);
float32x4_t dst13 = vsubq_f32(vsubq_f32(src31, src32), src33);
float32x4_t dest00 = vaddq_f32(vaddq_f32(dst00, dst01), dst02);
float32x4_t dest10 = vsubq_f32(vsubq_f32(dst01, dst02), dst03);
float32x4_t dest01 = vaddq_f32(vaddq_f32(dst10, dst11), dst12);
float32x4_t dest11 = vsubq_f32(vsubq_f32(dst11, dst12), dst13);
if (bias_value) {
float32x4_t bias = vld1q_f32(bias_value);
dest00 = vaddq_f32(dest00, bias);
dest10 = vaddq_f32(dest10, bias);
dest01 = vaddq_f32(dest01, bias);
dest11 = vaddq_f32(dest11, bias);
}
if (has_relu) {
float32x4_t zeros = vdupq_n_f32(0);
dest00 = vmaxq_f32(dest00, zeros);
dest10 = vmaxq_f32(dest10, zeros);
dest01 = vmaxq_f32(dest01, zeros);
dest11 = vmaxq_f32(dest11, zeros);
}
vst1q_f32(dest, dest00);
vst1q_f32(dest + dest_stride, dest10);
dest += dest_h_stride;
vst1q_f32(dest, dest01);
vst1q_f32(dest + dest_stride, dest11);
}
void weight_trans_c4_8x8(
float* dest, const float* din, int ch_in, int ch_out, void* workspace) {
const float coeff[8][3] = {{1.0f, 0.0f, 0.0f},
{-2.0f / 9, -2.0f / 9, -2.0f / 9},
......@@ -558,6 +1241,63 @@ void weight_trans_c4(
}
}
void weight_trans_c4_4x4(
float* dest, const float* din, int ch_in, int ch_out, void* workspace) {
const float coeff[4][3] = {{1.0f, 0.0f, 0.0f},
{0.5f, 0.5f, 0.5f},
{0.5f, -0.5f, 0.5f},
{0.0f, 0.0f, 1.0f}};
float* ptr_out = static_cast<float*>(workspace);
for (int i = 0; i < ch_out; i++) {
for (int j = 0; j < ch_in; j++) {
const float* kernel0 =
static_cast<const float*>(din) + (i * ch_in + j) * 9;
float* ptr_channel = ptr_out + (i * ch_in + j) * 16;
//! transform kernel, transposed
const float* k0 = kernel0;
const float* k1 = kernel0 + 3;
const float* k2 = kernel0 + 6;
//! h
float tmp[4][3];
for (int i = 0; i < 4; i++) {
tmp[i][0] =
k0[0] * coeff[i][0] + k0[1] * coeff[i][1] + k0[2] * coeff[i][2];
tmp[i][1] =
k1[0] * coeff[i][0] + k1[1] * coeff[i][1] + k1[2] * coeff[i][2];
tmp[i][2] =
k2[0] * coeff[i][0] + k2[1] * coeff[i][1] + k2[2] * coeff[i][2];
}
//! v
for (int j = 0; j < 4; j++) {
float* tmpp = &tmp[j][0];
for (int i = 0; i < 4; i++) {
ptr_channel[j * 4 + i] = tmpp[0] * coeff[i][0] +
tmpp[1] * coeff[i][1] +
tmpp[2] * coeff[i][2];
}
}
}
}
int oc_pad = (ch_out + 3) / 4 * 4;
int ic_pad = (ch_in + 3) / 4 * 4;
int c_stride = ic_pad * oc_pad;
for (int i = 0; i < ch_out * ch_in * 16; ++i) {
int new_c = i % 16;
int new_oc = i / ch_in / 16 / 4;
int new_ic = i / 16 % (ch_in * 4) % ch_in;
int new_inner = i / ch_in / 16 % 4;
int dest_ind =
new_c * c_stride + new_oc * ic_pad * 4 + new_ic * 4 + new_inner;
dest[dest_ind] = ptr_out[i];
}
}
} // namespace math
} // namespace arm
} // namespace lite
......
......@@ -316,7 +316,9 @@ void fill_bias_int8(int* tensor,
int channel_size);
// new winograd
void weight_trans_c4(
void weight_trans_c4_8x8(
float* dest, const float* src, int ic, int oc, void* workspace);
void weight_trans_c4_4x4(
float* dest, const float* src, int ic, int oc, void* workspace);
void conv_compute_6x6_3x3(const float* input,
float* output,
......@@ -331,6 +333,32 @@ void conv_compute_6x6_3x3(const float* input,
const float* bias,
const operators::ConvParam& param,
ARMContext* ctx);
void conv_compute_2x2_3x3(const float* input,
float* output,
int num,
int chout,
int hout,
int wout,
int chin,
int hin,
int win,
const float* weight,
const float* bias,
const operators::ConvParam& param,
ARMContext* ctx);
void conv_compute_2x2_3x3_small(const float* input,
float* output,
int num,
int chout,
int hout,
int wout,
int chin,
int hin,
int win,
const float* weight,
const float* bias,
const operators::ConvParam& param,
ARMContext* ctx);
} // namespace math
} // namespace arm
} // namespace lite
......
......@@ -695,7 +695,6 @@ void sgemm_prepack_c4_common(int M,
}
}
}
void sgemm_prepack_c4_small(int M,
int N,
int K,
......@@ -1146,6 +1145,540 @@ void sgemm_prepack_c4_small(int M,
}
}
void sgemm_prepack_c4_small(int M,
int N,
int K,
const float* A_packed,
const float* B,
float* C,
ARMContext* ctx) {
const int m_round = (M + 3) / 4 * 4;
const int k_round = (K + 3) / 4 * 4;
const int mloop = m_round >> 2;
const int lda = 4 * k_round;
const int ldb_byte = 4 * N * sizeof(float);
const int kcnt = k_round >> 2;
#ifdef __aarch64__
float32x4_t vzero = vdupq_n_f32(0.f);
#endif
for (int m = 0; m < mloop; ++m) {
const float* b = B;
int n = N;
#ifdef __aarch64__
for (; n > 7; n -= 8) {
int cnt = kcnt;
const float* a_ptr = A_packed;
const float* b_ptr = b;
// clang-format off
asm volatile(
"0:\n"
/* load a0, a1 */
"ld1 {v16.4s, v17.4s}, [%[a]], #32 \n"
/* load b0, b1 */
"ld1 {v0.4s, v1.4s}, [%[b]], #32 \n"
/* load b2, b3 */
"ld1 {v2.4s, v3.4s}, [%[b]], #32 \n"
/* load a2, a3 */
"fmul v8.4s, v16.4s, v0.s[0] \n"
"fmul v9.4s, v16.4s, v1.s[0] \n"
"fmul v10.4s, v16.4s, v2.s[0] \n"
"fmul v11.4s, v16.4s, v3.s[0] \n"
"ld1 {v18.4s, v19.4s}, [%[a]], #32 \n"
"prfm pldl1keep, [%[b]] \n"
"fmla v8.4s, v17.4s, v0.s[1] \n"
"fmla v9.4s, v17.4s, v1.s[1] \n"
"fmla v10.4s, v17.4s, v2.s[1] \n"
"fmla v11.4s, v17.4s, v3.s[1] \n"
/* load b4, b5 */
"ld1 {v4.4s, v5.4s}, [%[b]], #32 \n"
"fmla v8.4s, v18.4s, v0.s[2] \n"
"fmla v9.4s, v18.4s, v1.s[2] \n"
"fmla v10.4s, v18.4s, v2.s[2] \n"
"fmla v11.4s, v18.4s, v3.s[2] \n"
/* load b6, b7 */
"ld1 {v6.4s, v7.4s}, [%[b]], #32 \n"
"fmla v8.4s, v19.4s, v0.s[3] \n"
"fmla v9.4s, v19.4s, v1.s[3] \n"
"fmla v10.4s, v19.4s, v2.s[3] \n"
"fmla v11.4s, v19.4s, v3.s[3] \n"
"sub %[b], %[b], #128 \n"
"fmul v12.4s, v16.4s, v4.s[0] \n"
"fmul v13.4s, v16.4s, v5.s[0] \n"
"fmul v14.4s, v16.4s, v6.s[0] \n"
"fmul v15.4s, v16.4s, v7.s[0] \n"
"add %[b], %[b], %[ldb] \n"
"fmla v12.4s, v17.4s, v4.s[1] \n"
"fmla v13.4s, v17.4s, v5.s[1] \n"
"fmla v14.4s, v17.4s, v6.s[1] \n"
"fmla v15.4s, v17.4s, v7.s[1] \n"
/* load a0, a1 */
"ld1 {v16.4s, v17.4s}, [%[a]], #32 \n"
"fmla v12.4s, v18.4s, v4.s[2] \n"
"fmla v13.4s, v18.4s, v5.s[2] \n"
"fmla v14.4s, v18.4s, v6.s[2] \n"
"fmla v15.4s, v18.4s, v7.s[2] \n"
/* load b0, b1 */
"ld1 {v0.4s, v1.4s}, [%[b]], #32 \n"
"fmla v12.4s, v19.4s, v4.s[3] \n"
"fmla v13.4s, v19.4s, v5.s[3] \n"
"fmla v14.4s, v19.4s, v6.s[3] \n"
"fmla v15.4s, v19.4s, v7.s[3] \n"
"subs %w[cnt], %w[cnt], #1 \n"
"beq 2f \n"
"1:\n"
/* load b2, b3 */
"ld1 {v2.4s, v3.4s}, [%[b]], #32 \n"
"fmla v8.4s, v16.4s, v0.s[0] \n"
"fmla v9.4s, v16.4s, v1.s[0] \n"
"fmla v10.4s, v16.4s, v2.s[0] \n"
"fmla v11.4s, v16.4s, v3.s[0] \n"
/* load a2, a3 */
"ld1 {v18.4s, v19.4s}, [%[a]], #32 \n"
"prfm pldl1keep, [%[b]] \n"
"fmla v8.4s, v17.4s, v0.s[1] \n"
"fmla v9.4s, v17.4s, v1.s[1] \n"
"fmla v10.4s, v17.4s, v2.s[1] \n"
"fmla v11.4s, v17.4s, v3.s[1] \n"
/* load b4, b5 */
"ld1 {v4.4s, v5.4s}, [%[b]], #32 \n"
"fmla v8.4s, v18.4s, v0.s[2] \n"
"fmla v9.4s, v18.4s, v1.s[2] \n"
"fmla v10.4s, v18.4s, v2.s[2] \n"
"fmla v11.4s, v18.4s, v3.s[2] \n"
/* load b6, b7 */
"ld1 {v6.4s, v7.4s}, [%[b]], #32 \n"
"fmla v8.4s, v19.4s, v0.s[3] \n"
"fmla v9.4s, v19.4s, v1.s[3] \n"
"fmla v10.4s, v19.4s, v2.s[3] \n"
"fmla v11.4s, v19.4s, v3.s[3] \n"
"sub %[b], %[b], #128 \n"
"fmla v12.4s, v16.4s, v4.s[0] \n"
"fmla v13.4s, v16.4s, v5.s[0] \n"
"fmla v14.4s, v16.4s, v6.s[0] \n"
"fmla v15.4s, v16.4s, v7.s[0] \n"
"add %[b], %[b], %[ldb] \n"
"fmla v12.4s, v17.4s, v4.s[1] \n"
"fmla v13.4s, v17.4s, v5.s[1] \n"
"fmla v14.4s, v17.4s, v6.s[1] \n"
"fmla v15.4s, v17.4s, v7.s[1] \n"
/* load a0, a1 */
"ld1 {v16.4s, v17.4s}, [%[a]], #32 \n"
"fmla v12.4s, v18.4s, v4.s[2] \n"
"fmla v13.4s, v18.4s, v5.s[2] \n"
"fmla v14.4s, v18.4s, v6.s[2] \n"
"fmla v15.4s, v18.4s, v7.s[2] \n"
/* load b0, b1 */
"ld1 {v0.4s, v1.4s}, [%[b]], #32 \n"
"fmla v12.4s, v19.4s, v4.s[3] \n"
"fmla v13.4s, v19.4s, v5.s[3] \n"
"fmla v14.4s, v19.4s, v6.s[3] \n"
"fmla v15.4s, v19.4s, v7.s[3] \n"
"subs %w[cnt], %w[cnt], #1 \n"
"bne 1b \n"
"2:\n"
"st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[c]], #64 \n"
"st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%[c]], #64 \n"
: [a] "+r" (a_ptr),
[b] "+r" (b_ptr),
[c] "+r" (C),
[cnt] "+r" (cnt)
: [ldb] "r" (ldb_byte),
[vzero] "w" (vzero)
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
"v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18",
"v19", "cc", "memory"
);
b += 4 * 8;
}
for (; n > 3; n -= 4) {
int cnt = kcnt;
const float* a_ptr = A_packed;
const float* b_ptr = b;
asm volatile(
"0:\n"
/* load a0, a1 */
"ld1 {v16.4s, v17.4s}, [%[a]], #32 \n"
/* load b0-b3 */
"ld1 {v0.4s, v1.4s}, [%[b]], #32 \n"
"ld1 {v2.4s, v3.4s}, [%[b]], #32 \n"
"fmul v8.4s, v16.4s, v0.s[0] \n"
"fmul v9.4s, v16.4s, v1.s[0] \n"
"fmul v10.4s, v16.4s, v2.s[0] \n"
"fmul v11.4s, v16.4s, v3.s[0] \n"
/* load a2, a3 */
"ld1 {v18.4s, v19.4s}, [%[a]], #32 \n"
"sub %[b], %[b], #64 \n"
"fmla v8.4s, v17.4s, v0.s[1] \n"
"fmla v9.4s, v17.4s, v1.s[1] \n"
"fmla v10.4s, v17.4s, v2.s[1] \n"
"fmla v11.4s, v17.4s, v3.s[1] \n"
"add %[b], %[b], %[ldb] \n"
"fmla v8.4s, v18.4s, v0.s[2] \n"
"fmla v9.4s, v18.4s, v1.s[2] \n"
"fmla v10.4s, v18.4s, v2.s[2] \n"
"fmla v11.4s, v18.4s, v3.s[2] \n"
/* load a0, a1 */
"ld1 {v16.4s, v17.4s}, [%[a]], #32 \n"
"fmla v8.4s, v19.4s, v0.s[3] \n"
"fmla v9.4s, v19.4s, v1.s[3] \n"
"fmla v10.4s, v19.4s, v2.s[3] \n"
"fmla v11.4s, v19.4s, v3.s[3] \n"
"subs %w[cnt], %w[cnt], #1 \n"
"beq 2f \n"
"1:\n"
/* load b0-b3 */
"ld1 {v0.4s, v1.4s}, [%[b]], #32 \n"
"ld1 {v2.4s, v3.4s}, [%[b]], #32 \n"
"fmla v8.4s, v16.4s, v0.s[0] \n"
"fmla v9.4s, v16.4s, v1.s[0] \n"
"fmla v10.4s, v16.4s, v2.s[0] \n"
"fmla v11.4s, v16.4s, v3.s[0] \n"
/* load a2, a3 */
"ld1 {v18.4s, v19.4s}, [%[a]], #32 \n"
"sub %[b], %[b], #64 \n"
"fmla v8.4s, v17.4s, v0.s[1] \n"
"fmla v9.4s, v17.4s, v1.s[1] \n"
"fmla v10.4s, v17.4s, v2.s[1] \n"
"fmla v11.4s, v17.4s, v3.s[1] \n"
"add %[b], %[b], %[ldb] \n"
"fmla v8.4s, v18.4s, v0.s[2] \n"
"fmla v9.4s, v18.4s, v1.s[2] \n"
"fmla v10.4s, v18.4s, v2.s[2] \n"
"fmla v11.4s, v18.4s, v3.s[2] \n"
/* load a0, a1 */
"ld1 {v16.4s, v17.4s}, [%[a]], #32 \n"
"fmla v8.4s, v19.4s, v0.s[3] \n"
"fmla v9.4s, v19.4s, v1.s[3] \n"
"fmla v10.4s, v19.4s, v2.s[3] \n"
"fmla v11.4s, v19.4s, v3.s[3] \n"
"subs %w[cnt], %w[cnt], #1 \n"
"bne 1b \n"
"2:\n"
"st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[c]], #64 \n"
: [a] "+r" (a_ptr),
[b] "+r" (b_ptr),
[c] "+r" (C),
[cnt] "+r" (cnt)
: [ldb] "r" (ldb_byte),
[vzero] "w" (vzero)
: "v0", "v1", "v2", "v3", "v8", "v9",
"v10", "v11", "v16", "v17", "v18",
"v19", "cc", "memory"
);
b += 4 * 4;
}
for (; n > 0; n--) {
int cnt = kcnt;
const float* a_ptr = A_packed;
const float* b_ptr = b;
asm volatile(
"0:\n"
/* load a0, a1 */
"ld1 {v16.4s, v17.4s}, [%[a]], #32 \n"
/* load b0 */
"ld1 {v0.4s}, [%[b]], #16 \n"
"fmul v8.4s, v16.4s, v0.s[0] \n"
"fmul v9.4s, v17.4s, v0.s[1] \n"
/* load a2, a3 */
"ld1 {v18.4s, v19.4s}, [%[a]], #32 \n"
"sub %[b], %[b], #16 \n"
"subs %w[cnt], %w[cnt], #1 \n"
"add %[b], %[b], %[ldb] \n"
"fmla v8.4s, v18.4s, v0.s[2] \n"
"fmla v9.4s, v19.4s, v0.s[3] \n"
/* load a0, a1 */
"ld1 {v16.4s, v17.4s}, [%[a]], #32 \n"
"beq 2f \n"
"1:\n"
/* load b0 */
"ld1 {v0.4s}, [%[b]], #16 \n"
"fmla v8.4s, v16.4s, v0.s[0] \n"
"fmla v9.4s, v17.4s, v0.s[1] \n"
/* load a2, a3 */
"ld1 {v18.4s, v19.4s}, [%[a]], #32 \n"
"sub %[b], %[b], #16 \n"
"subs %w[cnt], %w[cnt], #1 \n"
"add %[b], %[b], %[ldb] \n"
"fmla v8.4s, v18.4s, v0.s[2] \n"
"fmla v9.4s, v19.4s, v0.s[3] \n"
/* load a0, a1 */
"ld1 {v16.4s, v17.4s}, [%[a]], #32 \n"
"bne 1b \n"
"fadd v8.4s, v8.4s, v9.4s \n"
"2:\n"
"st1 {v8.4s}, [%[c]], #16 \n"
: [a] "+r" (a_ptr),
[b] "+r" (b_ptr),
[c] "+r" (C),
[cnt] "+r" (cnt)
: [ldb] "r" (ldb_byte),
[vzero] "w" (vzero)
: "v0", "v8", "v9", "v16", "v17",
"v18", "v19", "cc", "memory"
);
b += 4;
}
#else
for (; n > 7; n -= 8) {
int cnt = kcnt;
const float* a_ptr = A_packed;
const float* b_ptr = b;
// clang-format off
asm volatile(
"0:\n"
/* load a0, a1 */
"vld1.32 {d8-d11}, [%[a]]! \n"
"vld1.32 {d0-d3}, [%[b]]! \n"
/* load b2, b3 */
"vld1.32 {d4-d7}, [%[b]]! \n"
"vmul.f32 q8, q4, d0[0] \n"
"vmul.f32 q9, q4, d2[0] \n"
"vmul.f32 q10, q4, d4[0] \n"
"vmul.f32 q11, q4, d6[0] \n"
/* load a2, a3 */
"vld1.32 {d12-d15}, [%[a]]! \n"
"pld [%[b]] \n"
"vmla.f32 q8, q5, d0[1] \n"
"vmla.f32 q9, q5, d2[1] \n"
"vmla.f32 q10, q5, d4[1] \n"
"vmla.f32 q11, q5, d6[1] \n"
"subs %[cnt], %[cnt], #1 \n"
"vmla.f32 q8, q6, d1[0] \n"
"vmla.f32 q9, q6, d3[0] \n"
"vmla.f32 q10, q6, d5[0] \n"
"vmla.f32 q11, q6, d7[0] \n"
"pld [%[b], #64] \n"
"vmla.f32 q8, q7, d1[1] \n"
"vmla.f32 q9, q7, d3[1] \n"
/* load b4, b5 */
"vld1.32 {d0-d3}, [%[b]]! \n"
"vmla.f32 q10, q7, d5[1] \n"
"vmla.f32 q11, q7, d7[1] \n"
/* load b6, b7 */
"vld1.32 {d4-d7}, [%[b]]! \n"
"vmul.f32 q12, q4, d0[0] \n"
"vmul.f32 q13, q4, d2[0] \n"
"vmul.f32 q14, q4, d4[0] \n"
"vmul.f32 q15, q4, d6[0] \n"
"sub %[b], %[b], #128 \n"
"vmla.f32 q12, q5, d0[1] \n"
"vmla.f32 q13, q5, d2[1] \n"
"vmla.f32 q14, q5, d4[1] \n"
"vmla.f32 q15, q5, d6[1] \n"
"add %[b], %[b], %[ldb] \n"
"vmla.f32 q12, q6, d1[0] \n"
"vmla.f32 q13, q6, d3[0] \n"
"vmla.f32 q14, q6, d5[0] \n"
"vmla.f32 q15, q6, d7[0] \n"
/* load a0, a1 */
"vld1.32 {d8-d11}, [%[a]]! \n"
"vmla.f32 q12, q7, d1[1] \n"
"vmla.f32 q13, q7, d3[1] \n"
/* load b0, b1 */
"vld1.32 {d0-d3}, [%[b]]! \n"
"vmla.f32 q14, q7, d5[1] \n"
"vmla.f32 q15, q7, d7[1] \n"
"beq 2f \n"
"1:\n"
/* load b2, b3 */
"vld1.32 {d4-d7}, [%[b]]! \n"
"vmla.f32 q8, q4, d0[0] \n"
"vmla.f32 q9, q4, d2[0] \n"
"vmla.f32 q10, q4, d4[0] \n"
"vmla.f32 q11, q4, d6[0] \n"
/* load a2, a3 */
"vld1.32 {d12-d15}, [%[a]]! \n"
"pld [%[b]] \n"
"vmla.f32 q8, q5, d0[1] \n"
"vmla.f32 q9, q5, d2[1] \n"
"vmla.f32 q10, q5, d4[1] \n"
"vmla.f32 q11, q5, d6[1] \n"
"subs %[cnt], %[cnt], #1 \n"
"vmla.f32 q8, q6, d1[0] \n"
"vmla.f32 q9, q6, d3[0] \n"
"vmla.f32 q10, q6, d5[0] \n"
"vmla.f32 q11, q6, d7[0] \n"
"pld [%[b], #64] \n"
"vmla.f32 q8, q7, d1[1] \n"
"vmla.f32 q9, q7, d3[1] \n"
/* load b4, b5 */
"vld1.32 {d0-d3}, [%[b]]! \n"
"vmla.f32 q10, q7, d5[1] \n"
"vmla.f32 q11, q7, d7[1] \n"
/* load b6, b7 */
"vld1.32 {d4-d7}, [%[b]]! \n"
"vmla.f32 q12, q4, d0[0] \n"
"vmla.f32 q13, q4, d2[0] \n"
"vmla.f32 q14, q4, d4[0] \n"
"vmla.f32 q15, q4, d6[0] \n"
"sub %[b], %[b], #128 \n"
"vmla.f32 q12, q5, d0[1] \n"
"vmla.f32 q13, q5, d2[1] \n"
"vmla.f32 q14, q5, d4[1] \n"
"vmla.f32 q15, q5, d6[1] \n"
"add %[b], %[b], %[ldb] \n"
"vmla.f32 q12, q6, d1[0] \n"
"vmla.f32 q13, q6, d3[0] \n"
"vmla.f32 q14, q6, d5[0] \n"
"vmla.f32 q15, q6, d7[0] \n"
/* load a0, a1 */
"vld1.32 {d8-d11}, [%[a]]! \n"
"vmla.f32 q12, q7, d1[1] \n"
"vmla.f32 q13, q7, d3[1] \n"
/* load b0, b1 */
"vld1.32 {d0-d3}, [%[b]]! \n"
"vmla.f32 q14, q7, d5[1] \n"
"vmla.f32 q15, q7, d7[1] \n"
"bne 1b \n"
"2:\n"
"vst1.32 {d16-d19}, [%[c]]! \n"
"vst1.32 {d20-d23}, [%[c]]! \n"
"vst1.32 {d24-d27}, [%[c]]! \n"
"vst1.32 {d28-d31}, [%[c]]! \n"
: [a] "+r" (a_ptr),
[b] "+r" (b_ptr),
[c] "+r" (C),
[cnt] "+r" (cnt)
: [ldb] "r" (ldb_byte)
: "q0", "q1", "q2", "q3", "q4", "q5",
"q6", "q7", "q8", "q9", "q10", "q11",
"q12", "q13", "q14", "q15", "cc", "memory"
);
b += 4 * 8;
}
for (; n > 3; n -= 4) {
int cnt = kcnt;
const float* a_ptr = A_packed;
const float* b_ptr = b;
asm volatile(
"0:\n"
/* load a0, a1 */
"vld1.32 {d8-d11}, [%[a]]! \n"
/* load b0-b3 */
"vld1.32 {d0-d3}, [%[b]]! \n"
"vld1.32 {d4-d7}, [%[b]]! \n"
"vmul.f32 q8, q4, d0[0] \n"
"vmul.f32 q9, q4, d2[0] \n"
"vmul.f32 q10, q4, d4[0] \n"
"vmul.f32 q11, q4, d6[0] \n"
/* load a2, a3 */
"vld1.32 {d12-d15}, [%[a]]!\n"
"sub %[b], %[b], #64 \n"
"vmla.f32 q8, q5, d0[1] \n"
"vmla.f32 q9, q5, d2[1] \n"
"vmla.f32 q10, q5, d4[1] \n"
"vmla.f32 q11, q5, d6[1] \n"
"add %[b], %[b], %[ldb] \n"
"vmla.f32 q8, q6, d1[0] \n"
"vmla.f32 q9, q6, d3[0] \n"
"vmla.f32 q10, q6, d5[0] \n"
"vmla.f32 q11, q6, d7[0] \n"
/* load a0, a1 */
"vld1.32 {d8-d11}, [%[a]]! \n"
"vmla.f32 q8, q7, d1[1] \n"
"vmla.f32 q9, q7, d3[1] \n"
"vmla.f32 q10, q7, d5[1] \n"
"vmla.f32 q11, q7, d7[1] \n"
"subs %[cnt], %[cnt], #1 \n"
"beq 2f \n"
"1:\n"
/* load b0-b3 */
"vld1.32 {d0-d3}, [%[b]]! \n"
"vld1.32 {d4-d7}, [%[b]]! \n"
"vmla.f32 q8, q4, d0[0] \n"
"vmla.f32 q9, q4, d2[0] \n"
"vmla.f32 q10, q4, d4[0] \n"
"vmla.f32 q11, q4, d6[0] \n"
/* load a2, a3 */
"vld1.32 {d12-d15}, [%[a]]!\n"
"sub %[b], %[b], #64 \n"
"vmla.f32 q8, q5, d0[1] \n"
"vmla.f32 q9, q5, d2[1] \n"
"vmla.f32 q10, q5, d4[1] \n"
"vmla.f32 q11, q5, d6[1] \n"
"add %[b], %[b], %[ldb] \n"
"vmla.f32 q8, q6, d1[0] \n"
"vmla.f32 q9, q6, d3[0] \n"
"vmla.f32 q10, q6, d5[0] \n"
"vmla.f32 q11, q6, d7[0] \n"
/* load a0, a1 */
"vld1.32 {d8-d11}, [%[a]]! \n"
"vmla.f32 q8, q7, d1[1] \n"
"vmla.f32 q9, q7, d3[1] \n"
"vmla.f32 q10, q7, d5[1] \n"
"vmla.f32 q11, q7, d7[1] \n"
"subs %[cnt], %[cnt], #1 \n"
"bne 1b \n"
"2:\n"
"vst1.32 {d16-d19}, [%[c]]!\n"
"vst1.32 {d20-d23}, [%[c]]!\n"
: [a] "+r" (a_ptr),
[b] "+r" (b_ptr),
[c] "+r" (C),
[cnt] "+r" (cnt)
: [ldb] "r" (ldb_byte)
: "q0", "q1", "q2", "q3", "q4", "q5",
"q6", "q7", "q8", "q9", "q10", "q11",
"q12", "q13", "cc", "memory"
);
b += 4 * 4;
}
for (; n > 0; n--) {
int cnt = kcnt;
const float* a_ptr = A_packed;
const float* b_ptr = b;
asm volatile(
"0:\n"
/* load a0, a1 */
"vld1.32 {d2-d5}, [%[a]]! \n"
/* load b0 */
"vld1.32 {d0-d1}, [%[b]]! \n"
"vmul.f32 q5, q1, d0[0] \n"
"vmul.f32 q6, q2, d0[1] \n"
/* load a2, a3 */
"vld1.32 {d6-d9}, [%[a]]! \n"
"sub %[b], %[b], #16 \n"
"subs %[cnt], %[cnt], #1 \n"
"add %[b], %[b], %[ldb] \n"
"vmla.f32 q5, q3, d1[0] \n"
"vmla.f32 q6, q4, d1[1] \n"
/* load a0, a1 */
"vld1.32 {d2-d5}, [%[a]]! \n"
"beq 2f \n"
"1:\n"
/* load b0 */
"vld1.32 {d0-d1}, [%[b]]! \n"
"vmla.f32 q5, q1, d0[0] \n"
"vmla.f32 q6, q2, d0[1] \n"
/* load a2, a3 */
"vld1.32 {d6-d9}, [%[a]]! \n"
"sub %[b], %[b], #16 \n"
"subs %[cnt], %[cnt], #1 \n"
"add %[b], %[b], %[ldb] \n"
"vmla.f32 q5, q3, d1[0] \n"
"vmla.f32 q6, q4, d1[1] \n"
/* load a0, a1 */
"vld1.32 {d2-d5}, [%[a]]! \n"
"bne 1b \n"
"vadd.f32 q5, q5, q6 \n"
"2:\n"
"vst1.32 {d10-d11}, [%[c]]!\n"
: [a] "+r" (a_ptr),
[b] "+r" (b_ptr),
[c] "+r" (C),
[cnt] "+r" (cnt)
: [ldb] "r" (ldb_byte)
: "q0", "q1", "q2", "q3", "q4",
"q5", "q6", "q7", "q8", "cc", "memory"
);
// clang-format on
b += 4;
}
#endif
A_packed += lda;
}
}
void sgemm_prepack_c4(int M,
int N,
int K,
......
......@@ -47,6 +47,13 @@ void sgemm_prepack_c4_small(int M,
bool has_bias,
bool has_relu,
ARMContext* ctx);
void sgemm_prepack_c4_small(int M,
int N,
int K,
const float* A_packed,
const float* B,
float* C,
ARMContext* ctx);
} // namespace math
} // namespace arm
} // namespace lite
......
......@@ -68,19 +68,9 @@ void ConvCompute<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
VLOG(3) << "invoking dw conv";
} else if (param.groups == 1 && kw == 3 && stride == 1 && kps_equal &&
no_dilation) {
bool use_winograd =
(threads == 1 && oc >= 4 && ic >= 4 && hout >= 6 && wout >= 6 &&
pads_equal) ||
(oc >= 32 && ic >= 32 && hout >= 16 && wout >= 16 && pads_equal);
if (use_winograd) {
/// winograd conv impl
impl_ = new WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>;
VLOG(3) << "invoking winograd conv";
} else {
/// direct conv impl
impl_ = new DirectConv<PRECISION(kFloat), PRECISION(kFloat)>;
VLOG(3) << "invoking direct conv";
}
/// winograd conv impl
impl_ = new WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>;
VLOG(3) << "invoking winograd conv";
} else if (param.groups == 1 && kw == 3 && stride == 2 &&
chin * chout < 4 * hin * win && kps_equal && no_dilation) {
/// direct conv impl
......
......@@ -43,79 +43,47 @@ void WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>::ReInitWhenNeeded() {
int oh = o_dims[2];
int ow = o_dims[3];
int tile_block = 8;
#ifdef __aarch64__
tile_block = 16;
#endif
int parallel_threads =
(((ow + 5) / 6) * ((oh + 5) / 6) + tile_block - 1) / tile_block;
if (threads <= 2 && parallel_threads >= threads) {
if (last_kernel_is_c4_ == 1) {
choose_small_ = ow * oh / (tile_block * threads) < 36 ? true : false;
if (choose_small_) {
wino_iw = 4;
if (last_function_ == 0) {
return;
}
last_kernel_is_c4_ = 1;
auto pad = *(param.paddings);
int pad_h = pad[0];
int pad_w = pad[2];
int oc_pad = (oc + 3) / 4 * 4;
int ic_pad = (ic + 3) / 4 * 4;
const int new_input_size =
(ic + 3) / 4 * 4 * (ih + pad_h * 2) * (iw + pad_w * 2);
const int temp_size =
(tile_block * ((ic + 3) / 4 + (oc + 3) / 4) * 256 + 512) * threads;
ctx.ExtendWorkspace((temp_size + new_input_size) * sizeof(float));
weights_.Resize({1, 1, 1, 64 * oc_pad * ic_pad});
ctx.ExtendWorkspace((temp_size + new_input_size) * sizeof(float));
void* trans_tmp_ptr = malloc(sizeof(float) * 8 * 8 * oc * ic);
auto weights_data_ = weights_.mutable_data<float>();
lite::arm::math::weight_trans_c4(
weights_data_, param.filter->data<float>(), ic, oc, trans_tmp_ptr);
free(trans_tmp_ptr);
last_function_ = 0;
} else {
if (last_kernel_is_c4_ == 0) {
wino_iw = 8;
if (last_function_ == 1) {
return;
}
last_kernel_is_c4_ = 0;
int tile_w = (ow + 5) / 6;
int tile_h = (oh + 5) / 6;
int size_tile = tile_h * tile_w;
int size_trans_channel = 8 * 8 * size_tile;
int max_ch = ic > oc ? ic : oc;
const int n_wino = size_tile;
ctx.ExtendWorkspace((size_trans_channel * max_ch * 2 + n_wino) *
sizeof(float));
const int m_wino = oc;
int hblock = lite::arm::math::get_hblock(&ctx);
int m_round = hblock * ((m_wino + hblock - 1) / hblock);
weights_.Resize({1, 1, 1, 8 * 8 * m_round * ic});
ctx.ExtendWorkspace((size_trans_channel * max_ch * 2 + n_wino) *
sizeof(float));
auto weights_wino =
static_cast<float*>(malloc(sizeof(float) * 8 * 8 * oc * ic));
void* trans_tmp_ptr = malloc(sizeof(float) * 8 * 8 * oc * ic);
lite::arm::math::winograd_transform_weights(
weights_wino, param.filter->data<float>(), oc, ic, trans_tmp_ptr);
auto weights_trans = weights_.mutable_data<float>();
for (int i = 0; i < 64; ++i) {
float* packed_weights = weights_trans + i * m_round * ic;
const float* weights_wino_ptr = weights_wino + i * oc * ic;
lite::arm::math::prepackA(packed_weights,
weights_wino_ptr,
1.f,
ic,
0,
m_wino,
0,
ic,
false,
&ctx);
}
free(trans_tmp_ptr);
free(weights_wino);
last_function_ = 1;
}
auto pad = *(param.paddings);
int pad_h = pad[0];
int pad_w = pad[2];
int oc_pad = (oc + 3) / 4 * 4;
int ic_pad = (ic + 3) / 4 * 4;
const int new_input_size =
(ic + 3) / 4 * 4 * (ih + pad_h * 2) * (iw + pad_w * 2);
const int temp_size =
(tile_block * ((ic + 3) / 4 + (oc + 3) / 4) * 4 * wino_iw * wino_iw +
8 * wino_iw * wino_iw) *
threads;
ctx.ExtendWorkspace((temp_size + new_input_size) * sizeof(float));
weights_.Resize({1, 1, 1, wino_iw * wino_iw * oc_pad * ic_pad});
ctx.ExtendWorkspace((temp_size + new_input_size) * sizeof(float));
void* trans_tmp_ptr = malloc(sizeof(float) * wino_iw * wino_iw * oc * ic);
auto weights_data_ = weights_.mutable_data<float>();
if (!choose_small_) {
lite::arm::math::weight_trans_c4_8x8(
weights_data_, param.filter->data<float>(), ic, oc, trans_tmp_ptr);
} else {
lite::arm::math::weight_trans_c4_4x4(
weights_data_, param.filter->data<float>(), ic, oc, trans_tmp_ptr);
}
free(trans_tmp_ptr);
last_shape_ = x_dims;
}
......@@ -145,14 +113,7 @@ void WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
int ow = o_dims[3];
int oc = o_dims[1];
int tile_block = 8;
#ifdef __aarch64__
tile_block = 16;
#endif
int threads = ctx.threads();
int parallel_threads =
(((ow + 5) / 6) * ((oh + 5) / 6) + tile_block - 1) / tile_block;
if (threads <= 2 && parallel_threads >= threads) {
if (!choose_small_) {
lite::arm::math::conv_compute_6x6_3x3(i_data,
o_data,
bs,
......@@ -167,19 +128,38 @@ void WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
param,
&ctx);
} else {
lite::arm::math::conv_winograd3x3(i_data,
o_data,
bs,
oc,
oh,
ow,
ic,
ih,
iw,
w_data,
b_data,
param,
&ctx);
int tile_block = 8;
int block_count =
(((ow + 1) / 2) * ((oh + 1) / 2) + tile_block - 1) / tile_block;
if (block_count != 1) {
lite::arm::math::conv_compute_2x2_3x3(i_data,
o_data,
bs,
oc,
oh,
ow,
ic,
ih,
iw,
w_data,
b_data,
param,
&ctx);
} else {
lite::arm::math::conv_compute_2x2_3x3_small(i_data,
o_data,
bs,
oc,
oh,
ow,
ic,
ih,
iw,
w_data,
b_data,
param,
&ctx);
}
}
}
......
......@@ -40,7 +40,9 @@ class WinogradConv : public KernelLite<TARGET(kARM), Ptype> {
Tensor weights_;
DDim last_shape_;
int workspace_size_{0};
int last_kernel_is_c4_{-1};
int last_function_{-1};
bool choose_small_{false};
int wino_iw{8};
};
} // namespace arm
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册