Unverified commit db2ab554 authored by H HappyAngel, committed by GitHub

fix conv_3x3s1_dw v7-compute nan problem (#4309)

* fix conv_3x3s1_dw v7-compute nan. test=develop

* fix compute. test=develop

* set sgemm basic_test to false. test=develop
Parent 0108c64e
@@ -25,6 +25,73 @@ namespace paddle {
namespace lite {
namespace arm {
namespace math {
void conv_3x3s1_depthwise_fp32_bias(const float* i_data,
float* o_data,
int bs,
int oc,
int oh,
int ow,
int ic,
int ih,
int win,
const float* weights,
const float* bias,
float* relu_ptr,
float* six_ptr,
float* scale_ptr,
const operators::ConvParam& param,
ARMContext* ctx);
void conv_3x3s1_depthwise_fp32_relu(const float* i_data,
float* o_data,
int bs,
int oc,
int oh,
int ow,
int ic,
int ih,
int win,
const float* weights,
const float* bias,
float* relu_ptr,
float* six_ptr,
float* scale_ptr,
const operators::ConvParam& param,
ARMContext* ctx);
void conv_3x3s1_depthwise_fp32_relu6(const float* i_data,
float* o_data,
int bs,
int oc,
int oh,
int ow,
int ic,
int ih,
int win,
const float* weights,
const float* bias,
float* relu_ptr,
float* six_ptr,
float* scale_ptr,
const operators::ConvParam& param,
ARMContext* ctx);
void conv_3x3s1_depthwise_fp32_leakyRelu(const float* i_data,
float* o_data,
int bs,
int oc,
int oh,
int ow,
int ic,
int ih,
int win,
const float* weights,
const float* bias,
float* relu_ptr,
float* six_ptr,
float* scale_ptr,
const operators::ConvParam& param,
ARMContext* ctx);
// clang-format off
#ifdef __aarch64__
#define COMPUTE \
@@ -335,7 +402,6 @@ namespace math {
"ldr r0, [%[outl]] @ load outc00 to r0\n" \
"vmla.f32 q12, q5, q0 @ w8 * inr32\n" \
"vmla.f32 q13, q5, q1 @ w8 * inr33\n" \
"ldr r5, [%[outl], #36] @ load flag_relu to r5\n" \
"vmla.f32 q14, q5, q2 @ w8 * inr34\n" \
"vmla.f32 q15, q5, q3 @ w8 * inr35\n" \
"ldr r1, [%[outl], #4] @ load outc10 to r1\n" \
@@ -406,7 +472,6 @@ namespace math {
"vtrn.32 q10, q11 @ r0: q10: a2a3c2c3, q11: b2b3d2d3\n" \
"vtrn.32 q12, q13 @ r1: q12: a0a1c0c1, q13: b0b1d0d1\n" \
"vtrn.32 q14, q15 @ r1: q14: a2a3c2c3, q15: b2b3d2d3\n" \
"ldr r5, [%[outl], #20] @ load outc11 to r5\n" \
"vswp d17, d20 @ r0: q8 : a0a1a2a3, q10: c0c1c2c3 \n" \
"vswp d19, d22 @ r0: q9 : b0b1b2b3, q11: d0d1d2d3 \n" \
"vswp d25, d28 @ r1: q12: a0a1a2a3, q14: c0c1c2c3 \n" \
@@ -417,12 +482,13 @@ namespace math {
"vst1.32 {d18-d19}, [r1] @ save outc10\n" \
"vst1.32 {d20-d21}, [r2] @ save outc20\n" \
"vst1.32 {d22-d23}, [r3] @ save outc30\n" \
"ldr r0, [%[outl], #20] @ load outc11 to r5\n" \
"ldr r1, [%[outl], #24] @ load outc21 to r0\n" \
"ldr r2, [%[outl], #28] @ load outc31 to r1\n" \
"vst1.32 {d24-d25}, [r4] @ save outc01\n" \
"vst1.32 {d26-d27}, [r5] @ save outc11\n" \
"ldr r0, [%[outl], #24] @ load outc21 to r0\n" \
"ldr r1, [%[outl], #28] @ load outc31 to r1\n" \
"vst1.32 {d28-d29}, [r0] @ save outc21\n" \
"vst1.32 {d30-d31}, [r1] @ save outc31\n" \
"vst1.32 {d26-d27}, [r0] @ save outc11\n" \
"vst1.32 {d28-d29}, [r1] @ save outc21\n" \
"vst1.32 {d30-d31}, [r2] @ save outc31\n" \
"b 3f @ branch end\n" \
"2: \n" \
"vst1.32 {d16-d17}, [%[out0]]! @ save remain to pre_out\n" \
@@ -436,31 +502,256 @@ namespace math {
"3: \n"
#endif
// clang-format on
void act_switch_3x3s1(const float* inr0,
const float* inr1,
const float* inr2,
const float* inr3,
float* out0,
const float* weight_c,
float flag_mask,
void* outl_ptr,
float32x4_t w0,
float32x4_t w1,
float32x4_t w2,
float32x4_t w3,
float32x4_t w4,
float32x4_t w5,
float32x4_t w6,
float32x4_t w7,
float32x4_t w8,
float32x4_t vbias,
const operators::ActivationParam act_param) {
bool has_active = act_param.has_active;
if (has_active) {
void conv_3x3s1_depthwise_fp32(const float* i_data,
float* o_data,
int bs,
int oc,
int oh,
int ow,
int ic,
int ih,
int win,
const float* weights,
const float* bias,
const operators::ConvParam& param,
const operators::ActivationParam act_param,
ARMContext* ctx) {
float six_ptr[4] = {0.f, 0.f, 0.f, 0.f};
float scale_ptr[4] = {1.f, 1.f, 1.f, 1.f};
float relu_ptr[4] = {0.f, 0.f, 0.f, 0.f};
if (act_param.has_active) {
switch (act_param.active_type) {
case lite_api::ActivationType::kRelu:
conv_3x3s1_depthwise_fp32_relu(i_data,
o_data,
bs,
oc,
oh,
ow,
ic,
ih,
win,
weights,
bias,
relu_ptr,
six_ptr,
scale_ptr,
param,
ctx);
break;
case lite_api::ActivationType::kRelu6:
six_ptr[0] = act_param.Relu_clipped_coef;
six_ptr[1] = act_param.Relu_clipped_coef;
six_ptr[2] = act_param.Relu_clipped_coef;
six_ptr[3] = act_param.Relu_clipped_coef;
conv_3x3s1_depthwise_fp32_relu6(i_data,
o_data,
bs,
oc,
oh,
ow,
ic,
ih,
win,
weights,
bias,
relu_ptr,
six_ptr,
scale_ptr,
param,
ctx);
break;
case lite_api::ActivationType::kLeakyRelu:
scale_ptr[0] = act_param.Leaky_relu_alpha;
scale_ptr[1] = act_param.Leaky_relu_alpha;
scale_ptr[2] = act_param.Leaky_relu_alpha;
scale_ptr[3] = act_param.Leaky_relu_alpha;
conv_3x3s1_depthwise_fp32_leakyRelu(i_data,
o_data,
bs,
oc,
oh,
ow,
ic,
ih,
win,
weights,
bias,
relu_ptr,
six_ptr,
scale_ptr,
param,
ctx);
break;
default:
LOG(FATAL) << "this act_type: "
<< static_cast<int>(act_param.active_type)
<< " fuse not support";
}
} else {
conv_3x3s1_depthwise_fp32_bias(i_data,
o_data,
bs,
oc,
oh,
ow,
ic,
ih,
win,
weights,
bias,
relu_ptr,
six_ptr,
scale_ptr,
param,
ctx);
}
}
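
Editor's note: the dispatcher above resolves the fused activation once per call and hands prepared relu/six/scale vectors to a specialized kernel, keeping the inner loops branch-free. A minimal standalone sketch of the same pattern follows; ActKind and run_kernel are hypothetical stand-ins, not the Paddle-Lite API.

#include <cstdio>

enum class ActKind { kNone, kRelu, kRelu6, kLeakyRelu };

// Hypothetical stand-in for one of the specialized kernels above.
void run_kernel(const char* name, const float* relu_v, const float* six_v,
                const float* scale_v) {
  std::printf("%s six=%.1f scale=%.1f\n", name, six_v[0], scale_v[0]);
}

void dispatch(ActKind act, float six, float alpha) {
  // Broadcast the activation constants into 4-lane arrays, as the real
  // dispatcher does with relu_ptr/six_ptr/scale_ptr.
  float relu_v[4] = {0.f, 0.f, 0.f, 0.f};
  float six_v[4] = {0.f, 0.f, 0.f, 0.f};
  float scale_v[4] = {1.f, 1.f, 1.f, 1.f};
  switch (act) {
    case ActKind::kRelu:
      run_kernel("relu", relu_v, six_v, scale_v);
      break;
    case ActKind::kRelu6:
      for (int i = 0; i < 4; ++i) six_v[i] = six;
      run_kernel("relu6", relu_v, six_v, scale_v);
      break;
    case ActKind::kLeakyRelu:
      for (int i = 0; i < 4; ++i) scale_v[i] = alpha;
      run_kernel("leaky_relu", relu_v, six_v, scale_v);
      break;
    default:
      run_kernel("bias", relu_v, six_v, scale_v);
  }
}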
void conv_3x3s1_depthwise_fp32_bias(const float* i_data,
float* o_data,
int bs,
int oc,
int oh,
int ow,
int ic,
int ih,
int win,
const float* weights,
const float* bias,
float* relu_ptr,
float* six_ptr,
float* scale_ptr,
const operators::ConvParam& param,
ARMContext* ctx) {
int threads = ctx->threads();
auto paddings = *param.paddings;
const int pad_h = paddings[0];
const int pad_w = paddings[2];
const int out_c_block = 4;
const int out_h_kernel = 2;
const int out_w_kernel = 4;
const int win_ext = ow + 2;
const int ow_round = ROUNDUP(ow, 4);
const int win_round = ROUNDUP(win_ext, 4);
const int hin_round = oh + 2;
const int prein_size = win_round * hin_round * out_c_block;
auto workspace_size =
threads * prein_size + win_round /*tmp zero*/ + ow_round /*tmp writer*/;
ctx->ExtendWorkspace(sizeof(float) * workspace_size);
bool flag_bias = param.bias != nullptr;
/// get workspace
float* ptr_zero = ctx->workspace_data<float>();
memset(ptr_zero, 0, sizeof(float) * win_round);
float* ptr_write = ptr_zero + win_round;
int size_in_channel = win * ih;
int size_out_channel = ow * oh;
int ws = -pad_w;
int we = ws + win_round;
int hs = -pad_h;
int he = hs + hin_round;
int w_loop = ow_round / 4;
auto remain = w_loop * 4 - ow;
bool flag_remain = remain > 0;
remain = 4 - remain;
remain = remain > 0 ? remain : 0;
int row_len = win_round * out_c_block;
for (int n = 0; n < bs; ++n) {
const float* din_batch = i_data + n * ic * size_in_channel;
float* dout_batch = o_data + n * oc * size_out_channel;
#pragma omp parallel for num_threads(threads)
for (int c = 0; c < oc; c += out_c_block) {
#ifdef ARM_WITH_OMP
float* pre_din = ptr_write + ow_round + omp_get_thread_num() * prein_size;
#else
float* pre_din = ptr_write + ow_round;
#endif
/// const array size
float pre_out[out_c_block * out_w_kernel * out_h_kernel]; // NOLINT
prepack_input_nxwc4_dw(
din_batch, pre_din, c, hs, he, ws, we, ic, win, ih, ptr_zero);
const float* weight_c = weights + c * 9; // kernel_w * kernel_h
float* dout_c00 = dout_batch + c * size_out_channel;
float bias_local[4] = {0, 0, 0, 0};
if (flag_bias) {
bias_local[0] = bias[c];
bias_local[1] = bias[c + 1];
bias_local[2] = bias[c + 2];
bias_local[3] = bias[c + 3];
}
float32x4_t vbias = vld1q_f32(bias_local);
#ifdef __aarch64__
asm volatile(COMPUTE RELU STORE
float32x4_t w0 = vld1q_f32(weight_c); // w0, v23
float32x4_t w1 = vld1q_f32(weight_c + 4); // w1, v24
float32x4_t w2 = vld1q_f32(weight_c + 8); // w2, v25
float32x4_t w3 = vld1q_f32(weight_c + 12); // w3, v26
float32x4_t w4 = vld1q_f32(weight_c + 16); // w4, v27
float32x4_t w5 = vld1q_f32(weight_c + 20); // w5, v28
float32x4_t w6 = vld1q_f32(weight_c + 24); // w6, v29
float32x4_t w7 = vld1q_f32(weight_c + 28); // w7, v30
float32x4_t w8 = vld1q_f32(weight_c + 32); // w8, v31
#endif
for (int h = 0; h < oh; h += out_h_kernel) {
float* outc00 = dout_c00 + h * ow;
float* outc01 = outc00 + ow;
float* outc10 = outc00 + size_out_channel;
float* outc11 = outc10 + ow;
float* outc20 = outc10 + size_out_channel;
float* outc21 = outc20 + ow;
float* outc30 = outc20 + size_out_channel;
float* outc31 = outc30 + ow;
const float* inr0 = pre_din + h * row_len;
const float* inr1 = inr0 + row_len;
const float* inr2 = inr1 + row_len;
const float* inr3 = inr2 + row_len;
if (c + out_c_block > oc) {
switch (c + out_c_block - oc) {
case 3:  // fall through: outc10..outc31 all redirect to ptr_write
outc10 = ptr_write;
outc11 = ptr_write;
case 2:  // fall through: outc20..outc31 redirect to ptr_write
outc20 = ptr_write;
outc21 = ptr_write;
case 1:  // outc30/outc31 redirect to ptr_write
outc30 = ptr_write;
outc31 = ptr_write;
default:
break;
}
}
if (h + out_h_kernel > oh) {
outc01 = ptr_write;
outc11 = ptr_write;
outc21 = ptr_write;
outc31 = ptr_write;
}
float* outl[] = {outc00,
outc10,
outc20,
outc30,
outc01,
outc11,
outc21,
outc31,
reinterpret_cast<float*>(bias_local),
reinterpret_cast<float*>(relu_ptr),
reinterpret_cast<float*>(six_ptr),
reinterpret_cast<float*>(scale_ptr)};
void* outl_ptr = reinterpret_cast<void*>(outl);
for (int w = 0; w < w_loop; ++w) {
bool flag_mask = (w == w_loop - 1) && flag_remain;
float* out0 = pre_out;
#ifdef __aarch64__
asm volatile(COMPUTE STORE
: [inr0] "+r"(inr0),
[inr1] "+r"(inr1),
[inr2] "+r"(inr2),
@@ -509,9 +800,7 @@ void act_switch_3x3s1(const float* inr0,
"x6",
"x7");
#else
#if 1 // def LITE_WITH_ARM_CLANG
#else
asm volatile(COMPUTE RELU STORE
asm volatile(COMPUTE STORE
: [r0] "+r"(inr0),
[r1] "+r"(inr1),
[r2] "+r"(inr2),
@@ -541,102 +830,175 @@ void act_switch_3x3s1(const float* inr0,
"r1",
"r2",
"r3",
"r4",
"r5");
"r4");
#endif
#endif
break;
case lite_api::ActivationType::kRelu6:
#ifdef __aarch64__
asm volatile(COMPUTE RELU RELU6 STORE
: [inr0] "+r"(inr0),
[inr1] "+r"(inr1),
[inr2] "+r"(inr2),
[inr3] "+r"(inr3),
[out] "+r"(out0)
: [w0] "w"(w0),
[w1] "w"(w1),
[w2] "w"(w2),
[w3] "w"(w3),
[w4] "w"(w4),
[w5] "w"(w5),
[w6] "w"(w6),
[w7] "w"(w7),
[w8] "w"(w8),
[vbias] "w"(vbias),
[outl] "r"(outl_ptr),
[flag_mask] "r"(flag_mask)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"x0",
"x1",
"x2",
"x3",
"x4",
"x5",
"x6",
"x7");
#else
#if 1 // def LITE_WITH_ARM_CLANG
outl[0] += 4;
outl[1] += 4;
outl[2] += 4;
outl[3] += 4;
outl[4] += 4;
outl[5] += 4;
outl[6] += 4;
outl[7] += 4;
if (flag_mask) {
memcpy(outl[0] - 4, pre_out, remain * sizeof(float));
memcpy(outl[1] - 4, pre_out + 4, remain * sizeof(float));
memcpy(outl[2] - 4, pre_out + 8, remain * sizeof(float));
memcpy(outl[3] - 4, pre_out + 12, remain * sizeof(float));
memcpy(outl[4] - 4, pre_out + 16, remain * sizeof(float));
memcpy(outl[5] - 4, pre_out + 20, remain * sizeof(float));
memcpy(outl[6] - 4, pre_out + 24, remain * sizeof(float));
memcpy(outl[7] - 4, pre_out + 28, remain * sizeof(float));
}
}
}
}
}
}
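
Editor's note: each kernel rounds the output width up to a multiple of four, computes the last block of a row into the pre_out scratch, and copies back only the valid lanes. A hedged sketch of that masked-tail pattern, with illustrative helper names not taken from the source:

#include <cstring>

// Same rounding as ROUNDUP(x, 4) above.
inline int round_up4(int x) { return (x + 3) / 4 * 4; }

// Copy one row of `ow` floats from a 4-lane-padded scratch buffer,
// mirroring the w_loop / flag_remain / remain logic in the kernels.
void copy_row_masked(float* dst, const float* scratch, int ow) {
  int w_loop = round_up4(ow) / 4;
  int overrun = w_loop * 4 - ow;               // lanes past the row end
  int remain = overrun > 0 ? 4 - overrun : 4;  // valid lanes in last block
  for (int w = 0; w < w_loop; ++w) {
    bool last_partial = (w == w_loop - 1) && overrun > 0;
    size_t n = last_partial ? remain : 4;
    std::memcpy(dst + w * 4, scratch + w * 4, n * sizeof(float));
  }
}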
void conv_3x3s1_depthwise_fp32_relu(const float* i_data,
float* o_data,
int bs,
int oc,
int oh,
int ow,
int ic,
int ih,
int win,
const float* weights,
const float* bias,
float* relu_ptr,
float* six_ptr,
float* scale_ptr,
const operators::ConvParam& param,
ARMContext* ctx) {
int threads = ctx->threads();
auto paddings = *param.paddings;
const int pad_h = paddings[0];
const int pad_w = paddings[2];
const int out_c_block = 4;
const int out_h_kernel = 2;
const int out_w_kernel = 4;
const int win_ext = ow + 2;
const int ow_round = ROUNDUP(ow, 4);
const int win_round = ROUNDUP(win_ext, 4);
const int hin_round = oh + 2;
const int prein_size = win_round * hin_round * out_c_block;
auto workspace_size =
threads * prein_size + win_round /*tmp zero*/ + ow_round /*tmp writer*/;
ctx->ExtendWorkspace(sizeof(float) * workspace_size);
bool flag_bias = param.bias != nullptr;
/// get workspace
float* ptr_zero = ctx->workspace_data<float>();
memset(ptr_zero, 0, sizeof(float) * win_round);
float* ptr_write = ptr_zero + win_round;
int size_in_channel = win * ih;
int size_out_channel = ow * oh;
int ws = -pad_w;
int we = ws + win_round;
int hs = -pad_h;
int he = hs + hin_round;
int w_loop = ow_round / 4;
auto remain = w_loop * 4 - ow;
bool flag_remain = remain > 0;
remain = 4 - remain;
remain = remain > 0 ? remain : 0;
int row_len = win_round * out_c_block;
for (int n = 0; n < bs; ++n) {
const float* din_batch = i_data + n * ic * size_in_channel;
float* dout_batch = o_data + n * oc * size_out_channel;
#pragma omp parallel for num_threads(threads)
for (int c = 0; c < oc; c += out_c_block) {
#ifdef ARM_WITH_OMP
float* pre_din = ptr_write + ow_round + omp_get_thread_num() * prein_size;
#else
asm volatile(COMPUTE RELU RELU6 STORE
: [r0] "+r"(inr0),
[r1] "+r"(inr1),
[r2] "+r"(inr2),
[r3] "+r"(inr3),
[out0] "+r"(out0),
[wc0] "+r"(weight_c)
: [flag_mask] "r"(flag_mask), [outl] "r"(outl_ptr)
: "cc",
"memory",
"q0",
"q1",
"q2",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15",
"r0",
"r1",
"r2",
"r3",
"r4",
"r5");
float* pre_din = ptr_write + ow_round;
#endif
/// const array size
float pre_out[out_c_block * out_w_kernel * out_h_kernel]; // NOLINT
prepack_input_nxwc4_dw(
din_batch, pre_din, c, hs, he, ws, we, ic, win, ih, ptr_zero);
const float* weight_c = weights + c * 9; // kernel_w * kernel_h
float* dout_c00 = dout_batch + c * size_out_channel;
float bias_local[4] = {0, 0, 0, 0};
if (flag_bias) {
bias_local[0] = bias[c];
bias_local[1] = bias[c + 1];
bias_local[2] = bias[c + 2];
bias_local[3] = bias[c + 3];
}
float32x4_t vbias = vld1q_f32(bias_local);
#ifdef __aarch64__
float32x4_t w0 = vld1q_f32(weight_c); // w0, v23
float32x4_t w1 = vld1q_f32(weight_c + 4); // w1, v24
float32x4_t w2 = vld1q_f32(weight_c + 8); // w2, v25
float32x4_t w3 = vld1q_f32(weight_c + 12); // w3, v26
float32x4_t w4 = vld1q_f32(weight_c + 16); // w4, v27
float32x4_t w5 = vld1q_f32(weight_c + 20); // w5, v28
float32x4_t w6 = vld1q_f32(weight_c + 24); // w6, v29
float32x4_t w7 = vld1q_f32(weight_c + 28); // w7, v30
float32x4_t w8 = vld1q_f32(weight_c + 32); // w8, v31
#endif
for (int h = 0; h < oh; h += out_h_kernel) {
float* outc00 = dout_c00 + h * ow;
float* outc01 = outc00 + ow;
float* outc10 = outc00 + size_out_channel;
float* outc11 = outc10 + ow;
float* outc20 = outc10 + size_out_channel;
float* outc21 = outc20 + ow;
float* outc30 = outc20 + size_out_channel;
float* outc31 = outc30 + ow;
const float* inr0 = pre_din + h * row_len;
const float* inr1 = inr0 + row_len;
const float* inr2 = inr1 + row_len;
const float* inr3 = inr2 + row_len;
if (c + out_c_block > oc) {
switch (c + out_c_block - oc) {
case 3:  // fall through: outc10..outc31 all redirect to ptr_write
outc10 = ptr_write;
outc11 = ptr_write;
case 2:  // fall through: outc20..outc31 redirect to ptr_write
outc20 = ptr_write;
outc21 = ptr_write;
case 1:  // outc30/outc31 redirect to ptr_write
outc30 = ptr_write;
outc31 = ptr_write;
default:
break;
case lite_api::ActivationType::kLeakyRelu:
}
}
if (h + out_h_kernel > oh) {
outc01 = ptr_write;
outc11 = ptr_write;
outc21 = ptr_write;
outc31 = ptr_write;
}
float* outl[] = {outc00,
outc10,
outc20,
outc30,
outc01,
outc11,
outc21,
outc31,
reinterpret_cast<float*>(bias_local),
reinterpret_cast<float*>(relu_ptr),
reinterpret_cast<float*>(six_ptr),
reinterpret_cast<float*>(scale_ptr)};
void* outl_ptr = reinterpret_cast<void*>(outl);
for (int w = 0; w < w_loop; ++w) {
bool flag_mask = (w == w_loop - 1) && flag_remain;
float* out0 = pre_out;
#ifdef __aarch64__
asm volatile(COMPUTE LEAKY_RELU STORE
asm volatile(COMPUTE RELU STORE
: [inr0] "+r"(inr0),
[inr1] "+r"(inr1),
[inr2] "+r"(inr2),
@@ -685,9 +1047,7 @@ void act_switch_3x3s1(const float* inr0,
"x6",
"x7");
#else
#if 1 // def LITE_WITH_ARM_CLANG
#else
asm volatile(COMPUTE LEAKY_RELU STORE
asm volatile(COMPUTE RELU STORE
: [r0] "+r"(inr0),
[r1] "+r"(inr1),
[r2] "+r"(inr2),
@@ -717,19 +1077,175 @@ void act_switch_3x3s1(const float* inr0,
"r1",
"r2",
"r3",
"r4",
"r5");
"r4");
#endif
outl[0] += 4;
outl[1] += 4;
outl[2] += 4;
outl[3] += 4;
outl[4] += 4;
outl[5] += 4;
outl[6] += 4;
outl[7] += 4;
if (flag_mask) {
memcpy(outl[0] - 4, pre_out, remain * sizeof(float));
memcpy(outl[1] - 4, pre_out + 4, remain * sizeof(float));
memcpy(outl[2] - 4, pre_out + 8, remain * sizeof(float));
memcpy(outl[3] - 4, pre_out + 12, remain * sizeof(float));
memcpy(outl[4] - 4, pre_out + 16, remain * sizeof(float));
memcpy(outl[5] - 4, pre_out + 20, remain * sizeof(float));
memcpy(outl[6] - 4, pre_out + 24, remain * sizeof(float));
memcpy(outl[7] - 4, pre_out + 28, remain * sizeof(float));
}
}
}
}
}
}
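
Editor's note: on the aarch64 path each kernel preloads the nine 3x3 taps as float32x4_t vectors, one vector per tap covering four channels; the `weights + c * 9` indexing implies a c4-packed [c/4][9][4] layout. A small sketch of that load under the same assumption:

#include <arm_neon.h>

struct DwTaps3x3 {
  float32x4_t w[9];  // w[k] holds tap k for four consecutive channels
};

// Load the nine c4-packed taps for channel block c (c a multiple of 4),
// as the aarch64 code above does with w0..w8.
inline DwTaps3x3 load_taps_c4(const float* weights, int c) {
  const float* weight_c = weights + c * 9;
  DwTaps3x3 t;
  for (int k = 0; k < 9; ++k) {
    t.w[k] = vld1q_f32(weight_c + 4 * k);
  }
  return t;
}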
void conv_3x3s1_depthwise_fp32_relu6(const float* i_data,
float* o_data,
int bs,
int oc,
int oh,
int ow,
int ic,
int ih,
int win,
const float* weights,
const float* bias,
float* relu_ptr,
float* six_ptr,
float* scale_ptr,
const operators::ConvParam& param,
ARMContext* ctx) {
int threads = ctx->threads();
auto paddings = *param.paddings;
const int pad_h = paddings[0];
const int pad_w = paddings[2];
const int out_c_block = 4;
const int out_h_kernel = 2;
const int out_w_kernel = 4;
const int win_ext = ow + 2;
const int ow_round = ROUNDUP(ow, 4);
const int win_round = ROUNDUP(win_ext, 4);
const int hin_round = oh + 2;
const int prein_size = win_round * hin_round * out_c_block;
auto workspace_size =
threads * prein_size + win_round /*tmp zero*/ + ow_round /*tmp writer*/;
ctx->ExtendWorkspace(sizeof(float) * workspace_size);
bool flag_bias = param.bias != nullptr;
/// get workspace
float* ptr_zero = ctx->workspace_data<float>();
memset(ptr_zero, 0, sizeof(float) * win_round);
float* ptr_write = ptr_zero + win_round;
int size_in_channel = win * ih;
int size_out_channel = ow * oh;
int ws = -pad_w;
int we = ws + win_round;
int hs = -pad_h;
int he = hs + hin_round;
int w_loop = ow_round / 4;
auto remain = w_loop * 4 - ow;
bool flag_remain = remain > 0;
remain = 4 - remain;
remain = remain > 0 ? remain : 0;
int row_len = win_round * out_c_block;
for (int n = 0; n < bs; ++n) {
const float* din_batch = i_data + n * ic * size_in_channel;
float* dout_batch = o_data + n * oc * size_out_channel;
#pragma omp parallel for num_threads(threads)
for (int c = 0; c < oc; c += out_c_block) {
#ifdef ARM_WITH_OMP
float* pre_din = ptr_write + ow_round + omp_get_thread_num() * prein_size;
#else
float* pre_din = ptr_write + ow_round;
#endif
break;
/// const array size
float pre_out[out_c_block * out_w_kernel * out_h_kernel]; // NOLINT
prepack_input_nxwc4_dw(
din_batch, pre_din, c, hs, he, ws, we, ic, win, ih, ptr_zero);
const float* weight_c = weights + c * 9; // kernel_w * kernel_h
float* dout_c00 = dout_batch + c * size_out_channel;
float bias_local[4] = {0, 0, 0, 0};
if (flag_bias) {
bias_local[0] = bias[c];
bias_local[1] = bias[c + 1];
bias_local[2] = bias[c + 2];
bias_local[3] = bias[c + 3];
}
float32x4_t vbias = vld1q_f32(bias_local);
#ifdef __aarch64__
float32x4_t w0 = vld1q_f32(weight_c); // w0, v23
float32x4_t w1 = vld1q_f32(weight_c + 4); // w1, v24
float32x4_t w2 = vld1q_f32(weight_c + 8); // w2, v25
float32x4_t w3 = vld1q_f32(weight_c + 12); // w3, v26
float32x4_t w4 = vld1q_f32(weight_c + 16); // w4, v27
float32x4_t w5 = vld1q_f32(weight_c + 20); // w5, v28
float32x4_t w6 = vld1q_f32(weight_c + 24); // w6, v29
float32x4_t w7 = vld1q_f32(weight_c + 28); // w7, v30
float32x4_t w8 = vld1q_f32(weight_c + 32); // w8, v31
#endif
for (int h = 0; h < oh; h += out_h_kernel) {
float* outc00 = dout_c00 + h * ow;
float* outc01 = outc00 + ow;
float* outc10 = outc00 + size_out_channel;
float* outc11 = outc10 + ow;
float* outc20 = outc10 + size_out_channel;
float* outc21 = outc20 + ow;
float* outc30 = outc20 + size_out_channel;
float* outc31 = outc30 + ow;
const float* inr0 = pre_din + h * row_len;
const float* inr1 = inr0 + row_len;
const float* inr2 = inr1 + row_len;
const float* inr3 = inr2 + row_len;
if (c + out_c_block > oc) {
switch (c + out_c_block - oc) {
case 3:  // fall through: outc10..outc31 all redirect to ptr_write
outc10 = ptr_write;
outc11 = ptr_write;
case 2:  // fall through: outc20..outc31 redirect to ptr_write
outc20 = ptr_write;
outc21 = ptr_write;
case 1:  // outc30/outc31 redirect to ptr_write
outc30 = ptr_write;
outc31 = ptr_write;
default:
LOG(FATAL) << "this act_type: "
<< static_cast<int>(act_param.active_type)
<< " fuse not support";
break;
}
} else {
}
if (h + out_h_kernel > oh) {
outc01 = ptr_write;
outc11 = ptr_write;
outc21 = ptr_write;
outc31 = ptr_write;
}
float* outl[] = {outc00,
outc10,
outc20,
outc30,
outc01,
outc11,
outc21,
outc31,
reinterpret_cast<float*>(bias_local),
reinterpret_cast<float*>(relu_ptr),
reinterpret_cast<float*>(six_ptr),
reinterpret_cast<float*>(scale_ptr)};
void* outl_ptr = reinterpret_cast<void*>(outl);
for (int w = 0; w < w_loop; ++w) {
bool flag_mask = (w == w_loop - 1) && flag_remain;
float* out0 = pre_out;
#ifdef __aarch64__
asm volatile(COMPUTE STORE
asm volatile(COMPUTE RELU RELU6 STORE
: [inr0] "+r"(inr0),
[inr1] "+r"(inr1),
[inr2] "+r"(inr2),
@@ -778,9 +1294,7 @@ void act_switch_3x3s1(const float* inr0,
"x6",
"x7");
#else
#if 1 // def LITE_WITH_ARM_CLANG
#else
asm volatile(COMPUTE STORE
asm volatile(COMPUTE RELU RELU6 STORE
: [r0] "+r"(inr0),
[r1] "+r"(inr1),
[r2] "+r"(inr2),
@@ -810,13 +1324,33 @@ void act_switch_3x3s1(const float* inr0,
"r1",
"r2",
"r3",
"r4",
"r5");
#endif
"r4");
#endif
outl[0] += 4;
outl[1] += 4;
outl[2] += 4;
outl[3] += 4;
outl[4] += 4;
outl[5] += 4;
outl[6] += 4;
outl[7] += 4;
if (flag_mask) {
memcpy(outl[0] - 4, pre_out, remain * sizeof(float));
memcpy(outl[1] - 4, pre_out + 4, remain * sizeof(float));
memcpy(outl[2] - 4, pre_out + 8, remain * sizeof(float));
memcpy(outl[3] - 4, pre_out + 12, remain * sizeof(float));
memcpy(outl[4] - 4, pre_out + 16, remain * sizeof(float));
memcpy(outl[5] - 4, pre_out + 20, remain * sizeof(float));
memcpy(outl[6] - 4, pre_out + 24, remain * sizeof(float));
memcpy(outl[7] - 4, pre_out + 28, remain * sizeof(float));
}
}
}
}
}
}
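
Editor's note: a recurring change in the armv7 blocks of this commit is that the rewritten store sequence no longer touches r5, so "r5" is dropped from each clobber list. For context, a tiny illustrative example of the extended-asm contract involved: every register a template writes must appear as an output operand or in the clobber list, and nothing more.

// Illustrative ARM32 VFP example only, unrelated to the kernels above.
inline float vadd_f32_asm(float a, float b) {
  float out;
  asm("vadd.f32 %0, %1, %2"
      : "=t"(out)         // "t": single-precision VFP register
      : "t"(a), "t"(b));  // all written registers are operands,
  return out;             // so the clobber list can stay empty
}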
void conv_3x3s1_depthwise_fp32(const float* i_data,
void conv_3x3s1_depthwise_fp32_leakyRelu(const float* i_data,
float* o_data,
int bs,
int oc,
@@ -827,8 +1361,10 @@ void conv_3x3s1_depthwise_fp32(const float* i_data,
int win,
const float* weights,
const float* bias,
float* relu_ptr,
float* six_ptr,
float* scale_ptr,
const operators::ConvParam& param,
const operators::ActivationParam act_param,
ARMContext* ctx) {
int threads = ctx->threads();
@@ -869,31 +1405,6 @@ void conv_3x3s1_depthwise_fp32(const float* i_data,
remain = remain > 0 ? remain : 0;
int row_len = win_round * out_c_block;
float six_ptr[4] = {0.f, 0.f, 0.f, 0.f};
float scale_ptr[4] = {1.f, 1.f, 1.f, 1.f};
float relu_ptr[4] = {0.f, 0.f, 0.f, 0.f};
if (act_param.has_active) {
switch (act_param.active_type) {
case lite_api::ActivationType::kRelu:
break;
case lite_api::ActivationType::kRelu6:
six_ptr[0] = act_param.Relu_clipped_coef;
six_ptr[1] = act_param.Relu_clipped_coef;
six_ptr[2] = act_param.Relu_clipped_coef;
six_ptr[3] = act_param.Relu_clipped_coef;
break;
case lite_api::ActivationType::kLeakyRelu:
scale_ptr[0] = act_param.Leaky_relu_alpha;
scale_ptr[1] = act_param.Leaky_relu_alpha;
scale_ptr[2] = act_param.Leaky_relu_alpha;
scale_ptr[3] = act_param.Leaky_relu_alpha;
break;
default:
LOG(FATAL) << "this act_type: "
<< static_cast<int>(act_param.active_type)
<< " fuse not support";
}
}
for (int n = 0; n < bs; ++n) {
const float* din_batch = i_data + n * ic * size_in_channel;
float* dout_batch = o_data + n * oc * size_out_channel;
@@ -944,13 +1455,13 @@ void conv_3x3s1_depthwise_fp32(const float* i_data,
const float* inr3 = inr2 + row_len;
if (c + out_c_block > oc) {
switch (c + out_c_block - oc) {
case 3:
case 3: // outc10-outc30 is ptr_write and extra
outc10 = ptr_write;
outc11 = ptr_write;
case 2:
case 2: // outc20-outc30 is ptr_write and extra
outc20 = ptr_write;
outc21 = ptr_write;
case 1:
case 1: // outc30 is ptr_write and extra
outc30 = ptr_write;
outc31 = ptr_write;
default:
@@ -981,48 +1492,86 @@ void conv_3x3s1_depthwise_fp32(const float* i_data,
bool flag_mask = (w == w_loop - 1) && flag_remain;
float* out0 = pre_out;
#ifdef __aarch64__
act_switch_3x3s1(inr0,
inr1,
inr2,
inr3,
out0,
weight_c,
flag_mask,
outl_ptr,
w0,
w1,
w2,
w3,
w4,
w5,
w6,
w7,
w8,
vbias,
act_param);
#else
#if 1 // def LITE_WITH_ARM_CLANG
asm volatile(COMPUTE LEAKY_RELU STORE
: [inr0] "+r"(inr0),
[inr1] "+r"(inr1),
[inr2] "+r"(inr2),
[inr3] "+r"(inr3),
[out] "+r"(out0)
: [w0] "w"(w0),
[w1] "w"(w1),
[w2] "w"(w2),
[w3] "w"(w3),
[w4] "w"(w4),
[w5] "w"(w5),
[w6] "w"(w6),
[w7] "w"(w7),
[w8] "w"(w8),
[vbias] "w"(vbias),
[outl] "r"(outl_ptr),
[flag_mask] "r"(flag_mask)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"x0",
"x1",
"x2",
"x3",
"x4",
"x5",
"x6",
"x7");
#else
act_switch_3x3s1(inr0,
inr1,
inr2,
inr3,
out0,
weight_c,
flag_mask,
outl_ptr,
vbias,
vbias,
vbias,
vbias,
vbias,
vbias,
vbias,
vbias,
vbias,
vbias,
act_param);
#endif
asm volatile(COMPUTE LEAKY_RELU STORE
: [r0] "+r"(inr0),
[r1] "+r"(inr1),
[r2] "+r"(inr2),
[r3] "+r"(inr3),
[out0] "+r"(out0),
[wc0] "+r"(weight_c)
: [flag_mask] "r"(flag_mask), [outl] "r"(outl_ptr)
: "cc",
"memory",
"q0",
"q1",
"q2",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15",
"r0",
"r1",
"r2",
"r3",
"r4");
#endif
outl[0] += 4;
outl[1] += 4;
@@ -1032,10 +1581,6 @@ void conv_3x3s1_depthwise_fp32(const float* i_data,
outl[5] += 4;
outl[6] += 4;
outl[7] += 4;
inr0 += 16;
inr1 += 16;
inr2 += 16;
inr3 += 16;
if (flag_mask) {
memcpy(outl[0] - 4, pre_out, remain * sizeof(float));
memcpy(outl[1] - 4, pre_out + 4, remain * sizeof(float));
@@ -1051,7 +1596,6 @@ void conv_3x3s1_depthwise_fp32(const float* i_data,
}
}
}
} // namespace math
} // namespace arm
} // namespace lite
......
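
Editor's note: one pattern shared by all four kernels in the file above is worth isolating. When the final channel block extends past oc, the surplus output pointers are redirected to the ptr_write scratch row so the fixed 4-channel stores never write out of bounds; the switch fallthrough is intentional. A standalone sketch:

// Surplus channel pointers alias a scratch row; fallthrough intended.
void redirect_tail_channels(float* outc[4], int c, int oc, float* scratch) {
  if (c + 4 > oc) {
    switch (c + 4 - oc) {
      case 3:
        outc[1] = scratch;  // fall through
      case 2:
        outc[2] = scratch;  // fall through
      case 1:
        outc[3] = scratch;
      default:
        break;
    }
  }
}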
@@ -620,8 +620,10 @@ void conv_depthwise_3x3_fp32(const void* din,
int pad = pad_w;
bool flag_bias = param.bias != nullptr;
bool pads_less = ((paddings[1] < 2) && (paddings[3] < 2));
bool ch_four = ch_in <= 4 * w_in;
if (stride == 1) {
if (pads_less && (pad_h == pad_w) && (pad < 2)) { // support pad = [0, 1]
if (ch_four && pads_less && (pad_h == pad_w) &&
(pad < 2)) { // support pad = [0, 1]
conv_depthwise_3x3s1_fp32(reinterpret_cast<const float*>(din),
reinterpret_cast<float*>(dout),
num,
@@ -638,7 +640,6 @@ void conv_depthwise_3x3_fp32(const void* din,
act_param,
ctx);
} else {
#ifdef __aarch64__
conv_3x3s1_depthwise_fp32(reinterpret_cast<const float*>(din),
reinterpret_cast<float*>(dout),
num,
@@ -653,30 +654,10 @@ void conv_depthwise_3x3_fp32(const void* din,
param,
act_param,
ctx);
#else
#ifdef LITE_WITH_ARM_CLANG
LOG(FATAL) << "fp32 depthwise conv3x3s1px doesnot support in v7-clang, "
"this can run in basic";
#else
conv_3x3s1_depthwise_fp32(reinterpret_cast<const float*>(din),
reinterpret_cast<float*>(dout),
num,
ch_out,
h_out,
w_out,
ch_in,
h_in,
w_in,
reinterpret_cast<const float*>(weights),
bias,
param,
act_param,
ctx);
#endif
#endif
}
} else if (stride == 2) {
if (pads_less && pad_h == pad_w && (pad < 2)) { // support pad = [0, 1]
if (ch_four && pads_less && pad_h == pad_w &&
(pad < 2)) { // support pad = [0, 1]
conv_depthwise_3x3s2_fp32(reinterpret_cast<const float*>(din),
reinterpret_cast<float*>(dout),
num,
......
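
Editor's note: the new ch_four guard above routes shapes with many channels relative to width away from the padded s1/s2 kernels and into the generic path. A condensed sketch of the routing condition; the paddings order [top, bottom, left, right] is inferred from the indices used in the source:

// Mirrors the guard in conv_depthwise_3x3_fp32; layout assumption noted.
bool use_padded_3x3_kernel(int ch_in, int w_in, const int paddings[4]) {
  bool ch_four = ch_in <= 4 * w_in;  // new condition in this commit
  bool pads_less = (paddings[1] < 2) && (paddings[3] < 2);
  int pad_h = paddings[0];
  int pad_w = paddings[2];
  return ch_four && pads_less && (pad_h == pad_w) && (pad_w < 2);
}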
@@ -59,12 +59,6 @@ void ConvCompute<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
bool flag_dw_3x3 = (kw == 3) && (kh == 3) && (stride == 1 || stride == 2);
bool flag_dw_5x5 = (kw == 5) && (kh == 5) && (stride == 1 || stride == 2);
#ifdef __aarch64__
#else
bool flag =
(stride == 1 && (paddings[0] > 1 || paddings[2] > 1)) ? false : true;
flag_dw_3x3 = flag_dw_3x3 && flag;
#endif
bool flag_dw = flag_dw_3x3 || flag_dw_5x5;
/// select conv impl
......
@@ -28,11 +28,15 @@ void DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
auto& ctx = this->ctx_->template As<ARMContext>();
auto w_dims = param.filter->dims();
auto kw = w_dims[3];
auto channel = w_dims[0];
auto hin = param.x->dims()[2];
auto win = param.x->dims()[3];
auto paddings = *param.paddings;
bool ch_four = channel <= 4 * win;
// select dw conv kernel
if (kw == 3) {
bool pads_less = ((paddings[1] < 2) && (paddings[3] < 2));
if (pads_less && paddings[0] == paddings[2] &&
if (ch_four && pads_less && paddings[0] == paddings[2] &&
(paddings[0] == 0 || paddings[0] == 1)) {
flag_trans_weights_ = false;
} else {
@@ -398,6 +402,14 @@ void DepthwiseConv<PRECISION(kInt8), PRECISION(kInt8)>::Run() {
w_scale_.data());
}
#ifdef LITE_WITH_PROFILE
template <>
void DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>::
SetProfileRuntimeKernelInfo(paddle::lite::profile::OpCharacter* ch) {
ch->kernel_func_name = kernel_func_name_;
}
#endif
} // namespace arm
} // namespace kernels
} // namespace lite
......
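
Editor's note: the new profiling specialization above simply reports the kernel function name recorded at dispatch time. A minimal sketch of the pattern, with hypothetical types standing in for the Paddle-Lite ones:

#include <string>

struct OpCharacter {
  std::string kernel_func_name;  // hypothetical mirror of profile::OpCharacter
};

struct DwConvFloatKernel {
  std::string kernel_func_name_{"conv_3x3s1_depthwise_fp32"};
  void SetProfileRuntimeKernelInfo(OpCharacter* ch) {
    ch->kernel_func_name = kernel_func_name_;
  }
};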
@@ -39,6 +39,7 @@ DEFINE_int32(power_mode,
DEFINE_int32(threads, 1, "threads num");
DEFINE_int32(warmup, 0, "warmup times");
DEFINE_int32(repeats, 1, "repeats times");
#ifdef LITE_WITH_ARM
// sgemm_test will not run unless it's
// on the arm backend.
......