Unverified commit db2ab554, authored by H HappyAngel, committed by GitHub

fix conv_3x3s1_dw v7-compute nan problem (#4309)

* fix conv_3x3s1_dw v7-compute nan. test=develop

* fix compute. test=develop

* set sgemm basic_test is false. test=develop
Parent 0108c64e
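The core of the fix: the per-tile activation dispatch (act_switch_3x3s1, deleted below) is replaced by four dedicated kernels — conv_3x3s1_depthwise_fp32_bias / _relu / _relu6 / _leakyRelu — selected once per call, with the activation constants passed down as small four-lane arrays (relu_ptr, six_ptr, scale_ptr). As a reading aid, here is a minimal scalar sketch of the fused activations those kernels implement; the enum and function below are illustrative only, not PaddleLite APIs.

#include <algorithm>

// Illustrative scalar reference of the fused activations applied by the
// specialized kernels, assuming the parameter layout used in this patch:
// relu_ptr = {0,0,0,0}, six_ptr = relu6 clip broadcast to 4 lanes,
// scale_ptr = leaky-relu slope broadcast to 4 lanes.
enum class Act { kNone, kRelu, kRelu6, kLeakyRelu };

inline float apply_act(float v, Act act, float six, float slope) {
  switch (act) {
    case Act::kRelu:
      return std::max(v, 0.f);
    case Act::kRelu6:
      return std::min(std::max(v, 0.f), six);
    case Act::kLeakyRelu:
      return v >= 0.f ? v : v * slope;
    default:  // bias-only kernel: store the accumulator unchanged
      return v;
  }
}

Each NEON/assembly kernel below applies exactly one of these cases, so no per-tile flag has to be read back from memory inside the inner loop.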
...@@ -25,6 +25,73 @@ namespace paddle {
namespace lite {
namespace arm {
namespace math {
void conv_3x3s1_depthwise_fp32_bias(const float* i_data,
float* o_data,
int bs,
int oc,
int oh,
int ow,
int ic,
int ih,
int win,
const float* weights,
const float* bias,
float* relu_ptr,
float* six_ptr,
float* scale_ptr,
const operators::ConvParam& param,
ARMContext* ctx);
void conv_3x3s1_depthwise_fp32_relu(const float* i_data,
float* o_data,
int bs,
int oc,
int oh,
int ow,
int ic,
int ih,
int win,
const float* weights,
const float* bias,
float* relu_ptr,
float* six_ptr,
float* scale_ptr,
const operators::ConvParam& param,
ARMContext* ctx);
void conv_3x3s1_depthwise_fp32_relu6(const float* i_data,
float* o_data,
int bs,
int oc,
int oh,
int ow,
int ic,
int ih,
int win,
const float* weights,
const float* bias,
float* relu_ptr,
float* six_ptr,
float* scale_ptr,
const operators::ConvParam& param,
ARMContext* ctx);
void conv_3x3s1_depthwise_fp32_leakyRelu(const float* i_data,
float* o_data,
int bs,
int oc,
int oh,
int ow,
int ic,
int ih,
int win,
const float* weights,
const float* bias,
float* relu_ptr,
float* six_ptr,
float* scale_ptr,
const operators::ConvParam& param,
ARMContext* ctx);
// clang-format off
#ifdef __aarch64__
#define COMPUTE \
...@@ -335,7 +402,6 @@ namespace math {
"ldr r0, [%[outl]] @ load outc00 to r0\n" \
"vmla.f32 q12, q5, q0 @ w8 * inr32\n" \
"vmla.f32 q13, q5, q1 @ w8 * inr33\n" \
"ldr r5, [%[outl], #36] @ load flag_relu to r5\n" \
"vmla.f32 q14, q5, q2 @ w8 * inr34\n" \
"vmla.f32 q15, q5, q3 @ w8 * inr35\n" \
"ldr r1, [%[outl], #4] @ load outc10 to r1\n" \
...@@ -406,7 +472,6 @@ namespace math {
"vtrn.32 q10, q11 @ r0: q10: a2a3c2c3, q11: b2b3d2d3\n" \
"vtrn.32 q12, q13 @ r1: q12: a0a1c0c1, q13: b0b1d0d1\n" \
"vtrn.32 q14, q15 @ r1: q14: a2a3c2c3, q15: b2b3d2d3\n" \
"ldr r5, [%[outl], #20] @ load outc11 to r5\n" \
"vswp d17, d20 @ r0: q8 : a0a1a2a3, q10: c0c1c2c3 \n" \
"vswp d19, d22 @ r0: q9 : b0b1b2b3, q11: d0d1d2d3 \n" \
"vswp d25, d28 @ r1: q12: a0a1a2a3, q14: c0c1c2c3 \n" \
...@@ -417,12 +482,13 @@ namespace math {
"vst1.32 {d18-d19}, [r1] @ save outc10\n" \
"vst1.32 {d20-d21}, [r2] @ save outc20\n" \
"vst1.32 {d22-d23}, [r3] @ save outc30\n" \
"vst1.32 {d24-d25}, [r4] @ save outc01\n" \
"vst1.32 {d26-d27}, [r5] @ save outc11\n" \
"ldr r0, [%[outl], #24] @ load outc21 to r0\n" \
"ldr r1, [%[outl], #28] @ load outc31 to r1\n" \
"vst1.32 {d28-d29}, [r0] @ save outc21\n" \
"vst1.32 {d30-d31}, [r1] @ save outc31\n" \
"ldr r0, [%[outl], #20] @ load outc11 to r5\n" \
"ldr r1, [%[outl], #24] @ load outc21 to r0\n" \
"ldr r2, [%[outl], #28] @ load outc31 to r1\n" \
"vst1.32 {d26-d27}, [r0] @ save outc11\n" \
"vst1.32 {d28-d29}, [r1] @ save outc21\n" \
"vst1.32 {d30-d31}, [r2] @ save outc31\n" \
"b 3f @ branch end\n" \
"2: \n" \
"vst1.32 {d16-d17}, [%[out0]]! @ save remain to pre_out\n" \
...@@ -436,291 +502,86 @@ namespace math {
"3: \n"
#endif
// clang-format on
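For reference, the ldr offsets in the armv7 store sequence above index the per-tile pointer table (outl) that each kernel builds further down. The layout sketched below is inferred from the outl[] initializer in the kernels that follow and assumes armv7's 4-byte pointers; it is documentation, not code from the patch.

// Hypothetical mirror of the outl[] table passed to the assembly via
// [%[outl]]; field order matches the initializer in the kernels below,
// byte offset = index * 4 on armv7.
struct OutlTable {
  float* outc00;  // #0   channel c+0, output row h
  float* outc10;  // #4   channel c+1, output row h
  float* outc20;  // #8   channel c+2, output row h
  float* outc30;  // #12  channel c+3, output row h
  float* outc01;  // #16  channel c+0, output row h+1
  float* outc11;  // #20  channel c+1, output row h+1
  float* outc21;  // #24  channel c+2, output row h+1
  float* outc31;  // #28  channel c+3, output row h+1
  float* bias;    // #32  bias_local
  float* relu;    // #36  relu_ptr (the removed load above read a flag_relu here)
  float* six;     // #40  six_ptr
  float* scale;   // #44  scale_ptr
};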
void act_switch_3x3s1(const float* inr0,
const float* inr1,
const float* inr2,
const float* inr3,
float* out0,
const float* weight_c,
float flag_mask,
void* outl_ptr,
float32x4_t w0,
float32x4_t w1,
float32x4_t w2,
float32x4_t w3,
float32x4_t w4,
float32x4_t w5,
float32x4_t w6,
float32x4_t w7,
float32x4_t w8,
float32x4_t vbias,
const operators::ActivationParam act_param) {
bool has_active = act_param.has_active;
if (has_active) {
void conv_3x3s1_depthwise_fp32(const float* i_data,
float* o_data,
int bs,
int oc,
int oh,
int ow,
int ic,
int ih,
int win,
const float* weights,
const float* bias,
const operators::ConvParam& param,
const operators::ActivationParam act_param,
ARMContext* ctx) {
float six_ptr[4] = {0.f, 0.f, 0.f, 0.f};
float scale_ptr[4] = {1.f, 1.f, 1.f, 1.f};
float relu_ptr[4] = {0.f, 0.f, 0.f, 0.f};
if (act_param.has_active) {
switch (act_param.active_type) {
case lite_api::ActivationType::kRelu:
conv_3x3s1_depthwise_fp32_relu(i_data,
o_data,
bs,
oc,
oh,
ow,
ic,
ih,
win,
weights,
bias,
relu_ptr,
six_ptr,
scale_ptr,
param,
ctx);
#ifdef __aarch64__
asm volatile(COMPUTE RELU STORE
: [inr0] "+r"(inr0),
[inr1] "+r"(inr1),
[inr2] "+r"(inr2),
[inr3] "+r"(inr3),
[out] "+r"(out0)
: [w0] "w"(w0),
[w1] "w"(w1),
[w2] "w"(w2),
[w3] "w"(w3),
[w4] "w"(w4),
[w5] "w"(w5),
[w6] "w"(w6),
[w7] "w"(w7),
[w8] "w"(w8),
[vbias] "w"(vbias),
[outl] "r"(outl_ptr),
[flag_mask] "r"(flag_mask)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"x0",
"x1",
"x2",
"x3",
"x4",
"x5",
"x6",
"x7");
#else
#if 1 // def LITE_WITH_ARM_CLANG
#else
asm volatile(COMPUTE RELU STORE
: [r0] "+r"(inr0),
[r1] "+r"(inr1),
[r2] "+r"(inr2),
[r3] "+r"(inr3),
[out0] "+r"(out0),
[wc0] "+r"(weight_c)
: [flag_mask] "r"(flag_mask), [outl] "r"(outl_ptr)
: "cc",
"memory",
"q0",
"q1",
"q2",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15",
"r0",
"r1",
"r2",
"r3",
"r4",
"r5");
#endif
#endif
break;
case lite_api::ActivationType::kRelu6:
six_ptr[0] = act_param.Relu_clipped_coef;
six_ptr[1] = act_param.Relu_clipped_coef;
six_ptr[2] = act_param.Relu_clipped_coef;
six_ptr[3] = act_param.Relu_clipped_coef;
conv_3x3s1_depthwise_fp32_relu6(i_data,
o_data,
bs,
oc,
oh,
ow,
ic,
ih,
win,
weights,
bias,
relu_ptr,
six_ptr,
scale_ptr,
param,
ctx);
#ifdef __aarch64__
asm volatile(COMPUTE RELU RELU6 STORE
: [inr0] "+r"(inr0),
[inr1] "+r"(inr1),
[inr2] "+r"(inr2),
[inr3] "+r"(inr3),
[out] "+r"(out0)
: [w0] "w"(w0),
[w1] "w"(w1),
[w2] "w"(w2),
[w3] "w"(w3),
[w4] "w"(w4),
[w5] "w"(w5),
[w6] "w"(w6),
[w7] "w"(w7),
[w8] "w"(w8),
[vbias] "w"(vbias),
[outl] "r"(outl_ptr),
[flag_mask] "r"(flag_mask)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"x0",
"x1",
"x2",
"x3",
"x4",
"x5",
"x6",
"x7");
#else
#if 1 // def LITE_WITH_ARM_CLANG
#else
asm volatile(COMPUTE RELU RELU6 STORE
: [r0] "+r"(inr0),
[r1] "+r"(inr1),
[r2] "+r"(inr2),
[r3] "+r"(inr3),
[out0] "+r"(out0),
[wc0] "+r"(weight_c)
: [flag_mask] "r"(flag_mask), [outl] "r"(outl_ptr)
: "cc",
"memory",
"q0",
"q1",
"q2",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15",
"r0",
"r1",
"r2",
"r3",
"r4",
"r5");
#endif
#endif
break;
case lite_api::ActivationType::kLeakyRelu:
scale_ptr[0] = act_param.Leaky_relu_alpha;
scale_ptr[1] = act_param.Leaky_relu_alpha;
scale_ptr[2] = act_param.Leaky_relu_alpha;
scale_ptr[3] = act_param.Leaky_relu_alpha;
conv_3x3s1_depthwise_fp32_leakyRelu(i_data,
o_data,
bs,
oc,
oh,
ow,
ic,
ih,
win,
weights,
bias,
relu_ptr,
six_ptr,
scale_ptr,
param,
ctx);
#ifdef __aarch64__
asm volatile(COMPUTE LEAKY_RELU STORE
: [inr0] "+r"(inr0),
[inr1] "+r"(inr1),
[inr2] "+r"(inr2),
[inr3] "+r"(inr3),
[out] "+r"(out0)
: [w0] "w"(w0),
[w1] "w"(w1),
[w2] "w"(w2),
[w3] "w"(w3),
[w4] "w"(w4),
[w5] "w"(w5),
[w6] "w"(w6),
[w7] "w"(w7),
[w8] "w"(w8),
[vbias] "w"(vbias),
[outl] "r"(outl_ptr),
[flag_mask] "r"(flag_mask)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"x0",
"x1",
"x2",
"x3",
"x4",
"x5",
"x6",
"x7");
#else
#if 1 // def LITE_WITH_ARM_CLANG
#else
asm volatile(COMPUTE LEAKY_RELU STORE
: [r0] "+r"(inr0),
[r1] "+r"(inr1),
[r2] "+r"(inr2),
[r3] "+r"(inr3),
[out0] "+r"(out0),
[wc0] "+r"(weight_c)
: [flag_mask] "r"(flag_mask), [outl] "r"(outl_ptr)
: "cc",
"memory",
"q0",
"q1",
"q2",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15",
"r0",
"r1",
"r2",
"r3",
"r4",
"r5");
#endif
#endif
break;
default:
LOG(FATAL) << "this act_type: "
...@@ -728,108 +589,289 @@ void act_switch_3x3s1(const float* inr0,
<< " fuse not support";
}
} else {
conv_3x3s1_depthwise_fp32_bias(i_data,
o_data,
bs,
oc,
oh,
ow,
ic,
ih,
win,
weights,
bias,
relu_ptr,
six_ptr,
scale_ptr,
param,
ctx);
}
}
#ifdef __aarch64__
asm volatile(COMPUTE STORE
: [inr0] "+r"(inr0),
[inr1] "+r"(inr1),
[inr2] "+r"(inr2),
[inr3] "+r"(inr3),
[out] "+r"(out0)
: [w0] "w"(w0),
[w1] "w"(w1),
[w2] "w"(w2),
[w3] "w"(w3),
[w4] "w"(w4),
[w5] "w"(w5),
[w6] "w"(w6),
[w7] "w"(w7),
[w8] "w"(w8),
[vbias] "w"(vbias),
[outl] "r"(outl_ptr),
[flag_mask] "r"(flag_mask)
: "cc", void conv_3x3s1_depthwise_fp32_bias(const float* i_data,
"memory", float* o_data,
"v0", int bs,
"v1", int oc,
"v2", int oh,
"v3", int ow,
"v4", int ic,
"v5", int ih,
"v6", int win,
"v7", const float* weights,
"v8", const float* bias,
"v9", float* relu_ptr,
"v10", float* six_ptr,
"v11", float* scale_ptr,
"v15", const operators::ConvParam& param,
"v16", ARMContext* ctx) {
"v17", int threads = ctx->threads();
"v18",
"v19", auto paddings = *param.paddings;
"v20", const int pad_h = paddings[0];
"v21", const int pad_w = paddings[2];
"v22",
"x0", const int out_c_block = 4;
"x1", const int out_h_kernel = 2;
"x2", const int out_w_kernel = 4;
"x3", const int win_ext = ow + 2;
"x4", const int ow_round = ROUNDUP(ow, 4);
"x5", const int win_round = ROUNDUP(win_ext, 4);
"x6", const int hin_round = oh + 2;
"x7"); const int prein_size = win_round * hin_round * out_c_block;
#else auto workspace_size =
#if 1 // def LITE_WITH_ARM_CLANG threads * prein_size + win_round /*tmp zero*/ + ow_round /*tmp writer*/;
ctx->ExtendWorkspace(sizeof(float) * workspace_size);
bool flag_bias = param.bias != nullptr;
/// get workspace
LOG(INFO) << "conv_3x3s1_depthwise_fp32_bias: ";
float* ptr_zero = ctx->workspace_data<float>();
memset(ptr_zero, 0, sizeof(float) * win_round);
float* ptr_write = ptr_zero + win_round;
int size_in_channel = win * ih;
int size_out_channel = ow * oh;
int ws = -pad_w;
int we = ws + win_round;
int hs = -pad_h;
int he = hs + hin_round;
int w_loop = ow_round / 4;
auto remain = w_loop * 4 - ow;
bool flag_remain = remain > 0;
remain = 4 - remain;
remain = remain > 0 ? remain : 0;
int row_len = win_round * out_c_block;
for (int n = 0; n < bs; ++n) {
const float* din_batch = i_data + n * ic * size_in_channel;
float* dout_batch = o_data + n * oc * size_out_channel;
#pragma omp parallel for num_threads(threads)
for (int c = 0; c < oc; c += out_c_block) {
#ifdef ARM_WITH_OMP
float* pre_din = ptr_write + ow_round + omp_get_thread_num() * prein_size;
#else
float* pre_din = ptr_write + ow_round;
asm volatile(COMPUTE STORE
: [r0] "+r"(inr0),
[r1] "+r"(inr1),
[r2] "+r"(inr2),
[r3] "+r"(inr3),
[out0] "+r"(out0),
[wc0] "+r"(weight_c)
: [flag_mask] "r"(flag_mask), [outl] "r"(outl_ptr)
: "cc",
"memory",
"q0",
"q1",
"q2",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15",
"r0",
"r1",
"r2",
"r3",
"r4",
"r5");
#endif
/// const array size
float pre_out[out_c_block * out_w_kernel * out_h_kernel]; // NOLINT
prepack_input_nxwc4_dw(
din_batch, pre_din, c, hs, he, ws, we, ic, win, ih, ptr_zero);
const float* weight_c = weights + c * 9; // kernel_w * kernel_h
float* dout_c00 = dout_batch + c * size_out_channel;
float bias_local[4] = {0, 0, 0, 0};
if (flag_bias) {
bias_local[0] = bias[c];
bias_local[1] = bias[c + 1];
bias_local[2] = bias[c + 2];
bias_local[3] = bias[c + 3];
}
float32x4_t vbias = vld1q_f32(bias_local);
#ifdef __aarch64__
float32x4_t w0 = vld1q_f32(weight_c); // w0, v23
float32x4_t w1 = vld1q_f32(weight_c + 4); // w1, v24
float32x4_t w2 = vld1q_f32(weight_c + 8); // w2, v25
float32x4_t w3 = vld1q_f32(weight_c + 12); // w3, v26
float32x4_t w4 = vld1q_f32(weight_c + 16); // w4, v27
float32x4_t w5 = vld1q_f32(weight_c + 20); // w5, v28
float32x4_t w6 = vld1q_f32(weight_c + 24); // w6, v29
float32x4_t w7 = vld1q_f32(weight_c + 28); // w7, v30
float32x4_t w8 = vld1q_f32(weight_c + 32); // w8, v31
#endif
for (int h = 0; h < oh; h += out_h_kernel) {
float* outc00 = dout_c00 + h * ow;
float* outc01 = outc00 + ow;
float* outc10 = outc00 + size_out_channel;
float* outc11 = outc10 + ow;
float* outc20 = outc10 + size_out_channel;
float* outc21 = outc20 + ow;
float* outc30 = outc20 + size_out_channel;
float* outc31 = outc30 + ow;
const float* inr0 = pre_din + h * row_len;
const float* inr1 = inr0 + row_len;
const float* inr2 = inr1 + row_len;
const float* inr3 = inr2 + row_len;
if (c + out_c_block > oc) {
switch (c + out_c_block - oc) {
case 3: // outc10-outc30 is ptr_write and extra
outc10 = ptr_write;
outc11 = ptr_write;
case 2: // outc20-outc30 is ptr_write and extra
outc20 = ptr_write;
outc21 = ptr_write;
case 1: // outc30 is ptr_write and extra
outc30 = ptr_write;
outc31 = ptr_write;
default:
break;
}
}
if (h + out_h_kernel > oh) {
outc01 = ptr_write;
outc11 = ptr_write;
outc21 = ptr_write;
outc31 = ptr_write;
}
float* outl[] = {outc00,
outc10,
outc20,
outc30,
outc01,
outc11,
outc21,
outc31,
reinterpret_cast<float*>(bias_local),
reinterpret_cast<float*>(relu_ptr),
reinterpret_cast<float*>(six_ptr),
reinterpret_cast<float*>(scale_ptr)};
void* outl_ptr = reinterpret_cast<void*>(outl);
for (int w = 0; w < w_loop; ++w) {
bool flag_mask = (w == w_loop - 1) && flag_remain;
float* out0 = pre_out;
#ifdef __aarch64__
asm volatile(COMPUTE STORE
: [inr0] "+r"(inr0),
[inr1] "+r"(inr1),
[inr2] "+r"(inr2),
[inr3] "+r"(inr3),
[out] "+r"(out0)
: [w0] "w"(w0),
[w1] "w"(w1),
[w2] "w"(w2),
[w3] "w"(w3),
[w4] "w"(w4),
[w5] "w"(w5),
[w6] "w"(w6),
[w7] "w"(w7),
[w8] "w"(w8),
[vbias] "w"(vbias),
[outl] "r"(outl_ptr),
[flag_mask] "r"(flag_mask)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"x0",
"x1",
"x2",
"x3",
"x4",
"x5",
"x6",
"x7");
#else
asm volatile(COMPUTE STORE
: [r0] "+r"(inr0),
[r1] "+r"(inr1),
[r2] "+r"(inr2),
[r3] "+r"(inr3),
[out0] "+r"(out0),
[wc0] "+r"(weight_c)
: [flag_mask] "r"(flag_mask), [outl] "r"(outl_ptr)
: "cc",
"memory",
"q0",
"q1",
"q2",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15",
"r0",
"r1",
"r2",
"r3",
"r4");
#endif #endif
outl[0] += 4;
outl[1] += 4;
outl[2] += 4;
outl[3] += 4;
outl[4] += 4;
outl[5] += 4;
outl[6] += 4;
outl[7] += 4;
if (flag_mask) {
memcpy(outl[0] - 4, pre_out, remain * sizeof(float));
memcpy(outl[1] - 4, pre_out + 4, remain * sizeof(float));
memcpy(outl[2] - 4, pre_out + 8, remain * sizeof(float));
memcpy(outl[3] - 4, pre_out + 12, remain * sizeof(float));
memcpy(outl[4] - 4, pre_out + 16, remain * sizeof(float));
memcpy(outl[5] - 4, pre_out + 20, remain * sizeof(float));
memcpy(outl[6] - 4, pre_out + 24, remain * sizeof(float));
memcpy(outl[7] - 4, pre_out + 28, remain * sizeof(float));
}
}
}
}
}
}
void conv_3x3s1_depthwise_fp32(const float* i_data,
float* o_data,
int bs,
int oc,
int oh,
int ow,
int ic,
int ih,
int win,
const float* weights,
const float* bias,
const operators::ConvParam& param,
const operators::ActivationParam act_param,
ARMContext* ctx) {
void conv_3x3s1_depthwise_fp32_relu(const float* i_data,
float* o_data,
int bs,
int oc,
int oh,
int ow,
int ic,
int ih,
int win,
const float* weights,
const float* bias,
float* relu_ptr,
float* six_ptr,
float* scale_ptr,
const operators::ConvParam& param,
ARMContext* ctx) {
int threads = ctx->threads();
auto paddings = *param.paddings;
...@@ -869,31 +911,6 @@ void conv_3x3s1_depthwise_fp32(const float* i_data,
remain = remain > 0 ? remain : 0;
int row_len = win_round * out_c_block;
float six_ptr[4] = {0.f, 0.f, 0.f, 0.f};
float scale_ptr[4] = {1.f, 1.f, 1.f, 1.f};
float relu_ptr[4] = {0.f, 0.f, 0.f, 0.f};
if (act_param.has_active) {
switch (act_param.active_type) {
case lite_api::ActivationType::kRelu:
break;
case lite_api::ActivationType::kRelu6:
six_ptr[0] = act_param.Relu_clipped_coef;
six_ptr[1] = act_param.Relu_clipped_coef;
six_ptr[2] = act_param.Relu_clipped_coef;
six_ptr[3] = act_param.Relu_clipped_coef;
break;
case lite_api::ActivationType::kLeakyRelu:
scale_ptr[0] = act_param.Leaky_relu_alpha;
scale_ptr[1] = act_param.Leaky_relu_alpha;
scale_ptr[2] = act_param.Leaky_relu_alpha;
scale_ptr[3] = act_param.Leaky_relu_alpha;
break;
default:
LOG(FATAL) << "this act_type: "
<< static_cast<int>(act_param.active_type)
<< " fuse not support";
}
}
for (int n = 0; n < bs; ++n) { for (int n = 0; n < bs; ++n) {
const float* din_batch = i_data + n * ic * size_in_channel; const float* din_batch = i_data + n * ic * size_in_channel;
float* dout_batch = o_data + n * oc * size_out_channel; float* dout_batch = o_data + n * oc * size_out_channel;
...@@ -944,13 +961,13 @@ void conv_3x3s1_depthwise_fp32(const float* i_data, ...@@ -944,13 +961,13 @@ void conv_3x3s1_depthwise_fp32(const float* i_data,
const float* inr3 = inr2 + row_len; const float* inr3 = inr2 + row_len;
if (c + out_c_block > oc) { if (c + out_c_block > oc) {
switch (c + out_c_block - oc) { switch (c + out_c_block - oc) {
case 3: case 3: // outc10-outc30 is ptr_write and extra
outc10 = ptr_write; outc10 = ptr_write;
outc11 = ptr_write; outc11 = ptr_write;
case 2: case 2: // outc20-outc30 is ptr_write and extra
outc20 = ptr_write; outc20 = ptr_write;
outc21 = ptr_write; outc21 = ptr_write;
case 1: case 1: // outc30 is ptr_write and extra
outc30 = ptr_write; outc30 = ptr_write;
outc31 = ptr_write; outc31 = ptr_write;
default: default:
...@@ -981,48 +998,86 @@ void conv_3x3s1_depthwise_fp32(const float* i_data, ...@@ -981,48 +998,86 @@ void conv_3x3s1_depthwise_fp32(const float* i_data,
bool flag_mask = (w == w_loop - 1) && flag_remain; bool flag_mask = (w == w_loop - 1) && flag_remain;
float* out0 = pre_out; float* out0 = pre_out;
#ifdef __aarch64__ #ifdef __aarch64__
act_switch_3x3s1(inr0,
inr1,
inr2,
inr3,
out0,
weight_c,
flag_mask,
outl_ptr,
w0,
w1,
w2,
w3,
w4,
w5,
w6,
w7,
w8,
vbias,
act_param);
asm volatile(COMPUTE RELU STORE
: [inr0] "+r"(inr0),
[inr1] "+r"(inr1),
[inr2] "+r"(inr2),
[inr3] "+r"(inr3),
[out] "+r"(out0)
: [w0] "w"(w0),
[w1] "w"(w1),
[w2] "w"(w2),
[w3] "w"(w3),
[w4] "w"(w4),
[w5] "w"(w5),
[w6] "w"(w6),
[w7] "w"(w7),
[w8] "w"(w8),
[vbias] "w"(vbias),
[outl] "r"(outl_ptr),
[flag_mask] "r"(flag_mask)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"x0",
"x1",
"x2",
"x3",
"x4",
"x5",
"x6",
"x7");
#else
#if 1 // def LITE_WITH_ARM_CLANG
#else
act_switch_3x3s1(inr0,
inr1,
inr2,
inr3,
out0,
weight_c,
flag_mask,
outl_ptr,
vbias,
vbias,
vbias,
vbias,
vbias,
vbias,
vbias,
vbias,
vbias,
vbias,
act_param);
#endif
asm volatile(COMPUTE RELU STORE
: [r0] "+r"(inr0),
[r1] "+r"(inr1),
[r2] "+r"(inr2),
[r3] "+r"(inr3),
[out0] "+r"(out0),
[wc0] "+r"(weight_c)
: [flag_mask] "r"(flag_mask), [outl] "r"(outl_ptr)
: "cc",
"memory",
"q0",
"q1",
"q2",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15",
"r0",
"r1",
"r2",
"r3",
"r4");
#endif
outl[0] += 4;
outl[1] += 4;
...@@ -1032,10 +1087,6 @@ void conv_3x3s1_depthwise_fp32(const float* i_data,
outl[5] += 4;
outl[6] += 4;
outl[7] += 4;
inr0 += 16;
inr1 += 16;
inr2 += 16;
inr3 += 16;
if (flag_mask) {
memcpy(outl[0] - 4, pre_out, remain * sizeof(float));
memcpy(outl[1] - 4, pre_out + 4, remain * sizeof(float));
...@@ -1052,6 +1103,499 @@ void conv_3x3s1_depthwise_fp32(const float* i_data,
}
}
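The masked tail store above leans on a small amount of index arithmetic that is easy to misread: the width loop always emits four outputs per step into pre_out, and the final step copies back only the valid ones. A minimal self-contained sketch with an illustrative width (the value 10 is an assumption, not from the patch):

#include <cstdio>

// Tail-handling arithmetic used by the depthwise kernels above.
int main() {
  const int ow = 10;                        // example output width (assumption)
  const int ow_round = ((ow + 3) / 4) * 4;  // ROUNDUP(ow, 4) -> 12
  const int w_loop = ow_round / 4;          // 3 width tiles
  int remain = w_loop * 4 - ow;             // 2 padded lanes in the last tile
  const bool flag_remain = remain > 0;      // last tile takes the memcpy path
  remain = 4 - remain;                      // 2 valid outputs to copy back
  std::printf("w_loop=%d flag_remain=%d remain=%d\n", w_loop, flag_remain, remain);
  return 0;
}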
void conv_3x3s1_depthwise_fp32_relu6(const float* i_data,
float* o_data,
int bs,
int oc,
int oh,
int ow,
int ic,
int ih,
int win,
const float* weights,
const float* bias,
float* relu_ptr,
float* six_ptr,
float* scale_ptr,
const operators::ConvParam& param,
ARMContext* ctx) {
int threads = ctx->threads();
auto paddings = *param.paddings;
const int pad_h = paddings[0];
const int pad_w = paddings[2];
const int out_c_block = 4;
const int out_h_kernel = 2;
const int out_w_kernel = 4;
const int win_ext = ow + 2;
const int ow_round = ROUNDUP(ow, 4);
const int win_round = ROUNDUP(win_ext, 4);
const int hin_round = oh + 2;
const int prein_size = win_round * hin_round * out_c_block;
auto workspace_size =
threads * prein_size + win_round /*tmp zero*/ + ow_round /*tmp writer*/;
ctx->ExtendWorkspace(sizeof(float) * workspace_size);
bool flag_bias = param.bias != nullptr;
/// get workspace
float* ptr_zero = ctx->workspace_data<float>();
memset(ptr_zero, 0, sizeof(float) * win_round);
float* ptr_write = ptr_zero + win_round;
int size_in_channel = win * ih;
int size_out_channel = ow * oh;
int ws = -pad_w;
int we = ws + win_round;
int hs = -pad_h;
int he = hs + hin_round;
int w_loop = ow_round / 4;
auto remain = w_loop * 4 - ow;
bool flag_remain = remain > 0;
remain = 4 - remain;
remain = remain > 0 ? remain : 0;
int row_len = win_round * out_c_block;
for (int n = 0; n < bs; ++n) {
const float* din_batch = i_data + n * ic * size_in_channel;
float* dout_batch = o_data + n * oc * size_out_channel;
#pragma omp parallel for num_threads(threads)
for (int c = 0; c < oc; c += out_c_block) {
#ifdef ARM_WITH_OMP
float* pre_din = ptr_write + ow_round + omp_get_thread_num() * prein_size;
#else
float* pre_din = ptr_write + ow_round;
#endif
/// const array size
float pre_out[out_c_block * out_w_kernel * out_h_kernel]; // NOLINT
prepack_input_nxwc4_dw(
din_batch, pre_din, c, hs, he, ws, we, ic, win, ih, ptr_zero);
const float* weight_c = weights + c * 9; // kernel_w * kernel_h
float* dout_c00 = dout_batch + c * size_out_channel;
float bias_local[4] = {0, 0, 0, 0};
if (flag_bias) {
bias_local[0] = bias[c];
bias_local[1] = bias[c + 1];
bias_local[2] = bias[c + 2];
bias_local[3] = bias[c + 3];
}
float32x4_t vbias = vld1q_f32(bias_local);
#ifdef __aarch64__
float32x4_t w0 = vld1q_f32(weight_c); // w0, v23
float32x4_t w1 = vld1q_f32(weight_c + 4); // w1, v24
float32x4_t w2 = vld1q_f32(weight_c + 8); // w2, v25
float32x4_t w3 = vld1q_f32(weight_c + 12); // w3, v26
float32x4_t w4 = vld1q_f32(weight_c + 16); // w4, v27
float32x4_t w5 = vld1q_f32(weight_c + 20); // w5, v28
float32x4_t w6 = vld1q_f32(weight_c + 24); // w6, v29
float32x4_t w7 = vld1q_f32(weight_c + 28); // w7, v30
float32x4_t w8 = vld1q_f32(weight_c + 32); // w8, v31
#endif
for (int h = 0; h < oh; h += out_h_kernel) {
float* outc00 = dout_c00 + h * ow;
float* outc01 = outc00 + ow;
float* outc10 = outc00 + size_out_channel;
float* outc11 = outc10 + ow;
float* outc20 = outc10 + size_out_channel;
float* outc21 = outc20 + ow;
float* outc30 = outc20 + size_out_channel;
float* outc31 = outc30 + ow;
const float* inr0 = pre_din + h * row_len;
const float* inr1 = inr0 + row_len;
const float* inr2 = inr1 + row_len;
const float* inr3 = inr2 + row_len;
if (c + out_c_block > oc) {
switch (c + out_c_block - oc) {
case 3: // outc10-outc30 is ptr_write and extra
outc10 = ptr_write;
outc11 = ptr_write;
case 2: // outc20-outc30 is ptr_write and extra
outc20 = ptr_write;
outc21 = ptr_write;
case 1: // outc30 is ptr_write and extra
outc30 = ptr_write;
outc31 = ptr_write;
default:
break;
}
}
if (h + out_h_kernel > oh) {
outc01 = ptr_write;
outc11 = ptr_write;
outc21 = ptr_write;
outc31 = ptr_write;
}
float* outl[] = {outc00,
outc10,
outc20,
outc30,
outc01,
outc11,
outc21,
outc31,
reinterpret_cast<float*>(bias_local),
reinterpret_cast<float*>(relu_ptr),
reinterpret_cast<float*>(six_ptr),
reinterpret_cast<float*>(scale_ptr)};
void* outl_ptr = reinterpret_cast<void*>(outl);
for (int w = 0; w < w_loop; ++w) {
bool flag_mask = (w == w_loop - 1) && flag_remain;
float* out0 = pre_out;
#ifdef __aarch64__
asm volatile(COMPUTE RELU RELU6 STORE
: [inr0] "+r"(inr0),
[inr1] "+r"(inr1),
[inr2] "+r"(inr2),
[inr3] "+r"(inr3),
[out] "+r"(out0)
: [w0] "w"(w0),
[w1] "w"(w1),
[w2] "w"(w2),
[w3] "w"(w3),
[w4] "w"(w4),
[w5] "w"(w5),
[w6] "w"(w6),
[w7] "w"(w7),
[w8] "w"(w8),
[vbias] "w"(vbias),
[outl] "r"(outl_ptr),
[flag_mask] "r"(flag_mask)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"x0",
"x1",
"x2",
"x3",
"x4",
"x5",
"x6",
"x7");
#else
asm volatile(COMPUTE RELU RELU6 STORE
: [r0] "+r"(inr0),
[r1] "+r"(inr1),
[r2] "+r"(inr2),
[r3] "+r"(inr3),
[out0] "+r"(out0),
[wc0] "+r"(weight_c)
: [flag_mask] "r"(flag_mask), [outl] "r"(outl_ptr)
: "cc",
"memory",
"q0",
"q1",
"q2",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15",
"r0",
"r1",
"r2",
"r3",
"r4");
#endif
outl[0] += 4;
outl[1] += 4;
outl[2] += 4;
outl[3] += 4;
outl[4] += 4;
outl[5] += 4;
outl[6] += 4;
outl[7] += 4;
if (flag_mask) {
memcpy(outl[0] - 4, pre_out, remain * sizeof(float));
memcpy(outl[1] - 4, pre_out + 4, remain * sizeof(float));
memcpy(outl[2] - 4, pre_out + 8, remain * sizeof(float));
memcpy(outl[3] - 4, pre_out + 12, remain * sizeof(float));
memcpy(outl[4] - 4, pre_out + 16, remain * sizeof(float));
memcpy(outl[5] - 4, pre_out + 20, remain * sizeof(float));
memcpy(outl[6] - 4, pre_out + 24, remain * sizeof(float));
memcpy(outl[7] - 4, pre_out + 28, remain * sizeof(float));
}
}
}
}
}
}
void conv_3x3s1_depthwise_fp32_leakyRelu(const float* i_data,
float* o_data,
int bs,
int oc,
int oh,
int ow,
int ic,
int ih,
int win,
const float* weights,
const float* bias,
float* relu_ptr,
float* six_ptr,
float* scale_ptr,
const operators::ConvParam& param,
ARMContext* ctx) {
int threads = ctx->threads();
auto paddings = *param.paddings;
const int pad_h = paddings[0];
const int pad_w = paddings[2];
const int out_c_block = 4;
const int out_h_kernel = 2;
const int out_w_kernel = 4;
const int win_ext = ow + 2;
const int ow_round = ROUNDUP(ow, 4);
const int win_round = ROUNDUP(win_ext, 4);
const int hin_round = oh + 2;
const int prein_size = win_round * hin_round * out_c_block;
auto workspace_size =
threads * prein_size + win_round /*tmp zero*/ + ow_round /*tmp writer*/;
ctx->ExtendWorkspace(sizeof(float) * workspace_size);
bool flag_bias = param.bias != nullptr;
/// get workspace
float* ptr_zero = ctx->workspace_data<float>();
memset(ptr_zero, 0, sizeof(float) * win_round);
float* ptr_write = ptr_zero + win_round;
int size_in_channel = win * ih;
int size_out_channel = ow * oh;
int ws = -pad_w;
int we = ws + win_round;
int hs = -pad_h;
int he = hs + hin_round;
int w_loop = ow_round / 4;
auto remain = w_loop * 4 - ow;
bool flag_remain = remain > 0;
remain = 4 - remain;
remain = remain > 0 ? remain : 0;
int row_len = win_round * out_c_block;
for (int n = 0; n < bs; ++n) {
const float* din_batch = i_data + n * ic * size_in_channel;
float* dout_batch = o_data + n * oc * size_out_channel;
#pragma omp parallel for num_threads(threads)
for (int c = 0; c < oc; c += out_c_block) {
#ifdef ARM_WITH_OMP
float* pre_din = ptr_write + ow_round + omp_get_thread_num() * prein_size;
#else
float* pre_din = ptr_write + ow_round;
#endif
/// const array size
float pre_out[out_c_block * out_w_kernel * out_h_kernel]; // NOLINT
prepack_input_nxwc4_dw(
din_batch, pre_din, c, hs, he, ws, we, ic, win, ih, ptr_zero);
const float* weight_c = weights + c * 9; // kernel_w * kernel_h
float* dout_c00 = dout_batch + c * size_out_channel;
float bias_local[4] = {0, 0, 0, 0};
if (flag_bias) {
bias_local[0] = bias[c];
bias_local[1] = bias[c + 1];
bias_local[2] = bias[c + 2];
bias_local[3] = bias[c + 3];
}
float32x4_t vbias = vld1q_f32(bias_local);
#ifdef __aarch64__
float32x4_t w0 = vld1q_f32(weight_c); // w0, v23
float32x4_t w1 = vld1q_f32(weight_c + 4); // w1, v24
float32x4_t w2 = vld1q_f32(weight_c + 8); // w2, v25
float32x4_t w3 = vld1q_f32(weight_c + 12); // w3, v26
float32x4_t w4 = vld1q_f32(weight_c + 16); // w4, v27
float32x4_t w5 = vld1q_f32(weight_c + 20); // w5, v28
float32x4_t w6 = vld1q_f32(weight_c + 24); // w6, v29
float32x4_t w7 = vld1q_f32(weight_c + 28); // w7, v30
float32x4_t w8 = vld1q_f32(weight_c + 32); // w8, v31
#endif
for (int h = 0; h < oh; h += out_h_kernel) {
float* outc00 = dout_c00 + h * ow;
float* outc01 = outc00 + ow;
float* outc10 = outc00 + size_out_channel;
float* outc11 = outc10 + ow;
float* outc20 = outc10 + size_out_channel;
float* outc21 = outc20 + ow;
float* outc30 = outc20 + size_out_channel;
float* outc31 = outc30 + ow;
const float* inr0 = pre_din + h * row_len;
const float* inr1 = inr0 + row_len;
const float* inr2 = inr1 + row_len;
const float* inr3 = inr2 + row_len;
if (c + out_c_block > oc) {
switch (c + out_c_block - oc) {
case 3: // outc10-outc30 is ptr_write and extra
outc10 = ptr_write;
outc11 = ptr_write;
case 2: // outc20-outc30 is ptr_write and extra
outc20 = ptr_write;
outc21 = ptr_write;
case 1: // outc30 is ptr_write and extra
outc30 = ptr_write;
outc31 = ptr_write;
default:
break;
}
}
if (h + out_h_kernel > oh) {
outc01 = ptr_write;
outc11 = ptr_write;
outc21 = ptr_write;
outc31 = ptr_write;
}
float* outl[] = {outc00,
outc10,
outc20,
outc30,
outc01,
outc11,
outc21,
outc31,
reinterpret_cast<float*>(bias_local),
reinterpret_cast<float*>(relu_ptr),
reinterpret_cast<float*>(six_ptr),
reinterpret_cast<float*>(scale_ptr)};
void* outl_ptr = reinterpret_cast<void*>(outl);
for (int w = 0; w < w_loop; ++w) {
bool flag_mask = (w == w_loop - 1) && flag_remain;
float* out0 = pre_out;
#ifdef __aarch64__
asm volatile(COMPUTE LEAKY_RELU STORE
: [inr0] "+r"(inr0),
[inr1] "+r"(inr1),
[inr2] "+r"(inr2),
[inr3] "+r"(inr3),
[out] "+r"(out0)
: [w0] "w"(w0),
[w1] "w"(w1),
[w2] "w"(w2),
[w3] "w"(w3),
[w4] "w"(w4),
[w5] "w"(w5),
[w6] "w"(w6),
[w7] "w"(w7),
[w8] "w"(w8),
[vbias] "w"(vbias),
[outl] "r"(outl_ptr),
[flag_mask] "r"(flag_mask)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"x0",
"x1",
"x2",
"x3",
"x4",
"x5",
"x6",
"x7");
#else
asm volatile(COMPUTE LEAKY_RELU STORE
: [r0] "+r"(inr0),
[r1] "+r"(inr1),
[r2] "+r"(inr2),
[r3] "+r"(inr3),
[out0] "+r"(out0),
[wc0] "+r"(weight_c)
: [flag_mask] "r"(flag_mask), [outl] "r"(outl_ptr)
: "cc",
"memory",
"q0",
"q1",
"q2",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15",
"r0",
"r1",
"r2",
"r3",
"r4");
#endif
outl[0] += 4;
outl[1] += 4;
outl[2] += 4;
outl[3] += 4;
outl[4] += 4;
outl[5] += 4;
outl[6] += 4;
outl[7] += 4;
if (flag_mask) {
memcpy(outl[0] - 4, pre_out, remain * sizeof(float));
memcpy(outl[1] - 4, pre_out + 4, remain * sizeof(float));
memcpy(outl[2] - 4, pre_out + 8, remain * sizeof(float));
memcpy(outl[3] - 4, pre_out + 12, remain * sizeof(float));
memcpy(outl[4] - 4, pre_out + 16, remain * sizeof(float));
memcpy(outl[5] - 4, pre_out + 20, remain * sizeof(float));
memcpy(outl[6] - 4, pre_out + 24, remain * sizeof(float));
memcpy(outl[7] - 4, pre_out + 28, remain * sizeof(float));
}
}
}
}
}
}
} // namespace math
} // namespace arm
} // namespace lite
......
...@@ -620,8 +620,10 @@ void conv_depthwise_3x3_fp32(const void* din,
int pad = pad_w;
bool flag_bias = param.bias != nullptr;
bool pads_less = ((paddings[1] < 2) && (paddings[3] < 2));
bool ch_four = ch_in <= 4 * w_in;
if (stride == 1) {
if (pads_less && (pad_h == pad_w) && (pad < 2)) { // support pad = [0, 1]
if (ch_four && pads_less && (pad_h == pad_w) &&
(pad < 2)) { // support pad = [0, 1]
conv_depthwise_3x3s1_fp32(reinterpret_cast<const float*>(din),
reinterpret_cast<float*>(dout),
num,
...@@ -638,7 +640,6 @@
act_param,
ctx);
} else {
#ifdef __aarch64__
conv_3x3s1_depthwise_fp32(reinterpret_cast<const float*>(din),
reinterpret_cast<float*>(dout),
num,
...@@ -653,30 +654,10 @@
param,
act_param,
ctx);
#else
#ifdef LITE_WITH_ARM_CLANG
LOG(FATAL) << "fp32 depthwise conv3x3s1px doesnot support in v7-clang, "
"this can run in basic";
#else
conv_3x3s1_depthwise_fp32(reinterpret_cast<const float*>(din),
reinterpret_cast<float*>(dout),
num,
ch_out,
h_out,
w_out,
ch_in,
h_in,
w_in,
reinterpret_cast<const float*>(weights),
bias,
param,
act_param,
ctx);
#endif
#endif
}
} else if (stride == 2) {
if (pads_less && pad_h == pad_w && (pad < 2)) { // support pad = [0, 1]
if (ch_four && pads_less && pad_h == pad_w &&
(pad < 2)) { // support pad = [0, 1]
conv_depthwise_3x3s2_fp32(reinterpret_cast<const float*>(din),
reinterpret_cast<float*>(dout),
num,
......
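The routing change above adds a channel/width ratio check: the specialized 3x3 depthwise kernels are only taken when ch_in <= 4 * w_in in addition to the existing padding constraints. A small hedged sketch of that selection predicate (illustrative struct and function names, not PaddleLite APIs):

// Illustrative routing predicate distilled from the diff above: choose the
// specialized conv_depthwise_3x3sX_fp32 path only when the channel count is
// small relative to the input width and the padding is symmetric and <= 1;
// otherwise fall back to the generic conv_3x3s1_depthwise_fp32 path.
struct DwRouteArgs {
  int ch_in, w_in;        // input channels and input width
  int pad_top, pad_left;  // paddings[0] and paddings[2] in the diff
  bool pads_less;         // paddings[1] < 2 && paddings[3] < 2
};

inline bool use_specialized_3x3_dw(const DwRouteArgs& a) {
  bool ch_four = a.ch_in <= 4 * a.w_in;
  return ch_four && a.pads_less && a.pad_top == a.pad_left && a.pad_top < 2;
}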
...@@ -59,12 +59,6 @@ void ConvCompute<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
bool flag_dw_3x3 = (kw == 3) && (kh == 3) && (stride == 1 || stride == 2);
bool flag_dw_5x5 = (kw == 5) && (kh == 5) && (stride == 1 || stride == 2);
#ifdef __aarch64__
#else
bool flag =
(stride == 1 && (paddings[0] > 1 || paddings[2] > 1)) ? false : true;
flag_dw_3x3 = flag_dw_3x3 && flag;
#endif
bool flag_dw = flag_dw_3x3 || flag_dw_5x5;
/// select conv impl
......
...@@ -28,11 +28,15 @@ void DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
auto& ctx = this->ctx_->template As<ARMContext>();
auto w_dims = param.filter->dims();
auto kw = w_dims[3];
auto channel = w_dims[0];
auto hin = param.x->dims()[2];
auto win = param.x->dims()[3];
auto paddings = *param.paddings;
bool ch_four = channel <= 4 * win;
// select dw conv kernel
if (kw == 3) {
bool pads_less = ((paddings[1] < 2) && (paddings[3] < 2));
if (pads_less && paddings[0] == paddings[2] &&
if (ch_four && pads_less && paddings[0] == paddings[2] &&
(paddings[0] == 0 || paddings[0] == 1)) {
flag_trans_weights_ = false;
} else {
...@@ -398,6 +402,14 @@ void DepthwiseConv<PRECISION(kInt8), PRECISION(kInt8)>::Run() {
w_scale_.data());
}
#ifdef LITE_WITH_PROFILE
template <>
void DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>::
SetProfileRuntimeKernelInfo(paddle::lite::profile::OpCharacter* ch) {
ch->kernel_func_name = kernel_func_name_;
}
#endif
} // namespace arm
} // namespace kernels
} // namespace lite
......
...@@ -39,6 +39,7 @@ DEFINE_int32(power_mode,
DEFINE_int32(threads, 1, "threads num");
DEFINE_int32(warmup, 0, "warmup times");
DEFINE_int32(repeats, 1, "repeats times");
#ifdef LITE_WITH_ARM
// sgemm_test wiil not be operated except that it's
// on arm backend.
......