diff --git a/lite/backends/arm/math/CMakeLists.txt b/lite/backends/arm/math/CMakeLists.txt index 244467d62492bc3017ebdb6144b49ccb9fcd30c1..88c449e6a9d8b8078802e90dded5db1162459d3f 100644 --- a/lite/backends/arm/math/CMakeLists.txt +++ b/lite/backends/arm/math/CMakeLists.txt @@ -127,8 +127,10 @@ if (NOT HAS_ARM_MATH_LIB_DIR) anchor_generator.cc split_merge_lod_tenosr.cc reduce_prod.cc + reduce_sum.cc lstm.cc clip.cc pixel_shuffle.cc + scatter.cc DEPS ${lite_kernel_deps} context tensor) endif() diff --git a/lite/backends/arm/math/conv3x3s1px_depthwise_fp32.cc b/lite/backends/arm/math/conv3x3s1px_depthwise_fp32.cc index c998ddc3a34c2f6194a5156b7d04b7a9db3fbcef..b4539db98c3ffb1a143c38dd3c4dd9e9924bd63e 100644 --- a/lite/backends/arm/math/conv3x3s1px_depthwise_fp32.cc +++ b/lite/backends/arm/math/conv3x3s1px_depthwise_fp32.cc @@ -25,6 +25,73 @@ namespace paddle { namespace lite { namespace arm { namespace math { +void conv_3x3s1_depthwise_fp32_bias(const float* i_data, + float* o_data, + int bs, + int oc, + int oh, + int ow, + int ic, + int ih, + int win, + const float* weights, + const float* bias, + float* relu_ptr, + float* six_ptr, + float* scale_ptr, + const operators::ConvParam& param, + ARMContext* ctx); + +void conv_3x3s1_depthwise_fp32_relu(const float* i_data, + float* o_data, + int bs, + int oc, + int oh, + int ow, + int ic, + int ih, + int win, + const float* weights, + const float* bias, + float* relu_ptr, + float* six_ptr, + float* scale_ptr, + const operators::ConvParam& param, + ARMContext* ctx); + +void conv_3x3s1_depthwise_fp32_relu6(const float* i_data, + float* o_data, + int bs, + int oc, + int oh, + int ow, + int ic, + int ih, + int win, + const float* weights, + const float* bias, + float* relu_ptr, + float* six_ptr, + float* scale_ptr, + const operators::ConvParam& param, + ARMContext* ctx); + +void conv_3x3s1_depthwise_fp32_leakyRelu(const float* i_data, + float* o_data, + int bs, + int oc, + int oh, + int ow, + int ic, + int ih, + int win, + const float* weights, + const float* bias, + float* relu_ptr, + float* six_ptr, + float* scale_ptr, + const operators::ConvParam& param, + ARMContext* ctx); // clang-format off #ifdef __aarch64__ #define COMPUTE \ @@ -335,7 +402,6 @@ namespace math { "ldr r0, [%[outl]] @ load outc00 to r0\n" \ "vmla.f32 q12, q5, q0 @ w8 * inr32\n" \ "vmla.f32 q13, q5, q1 @ w8 * inr33\n" \ - "ldr r5, [%[outl], #36] @ load flag_relu to r5\n" \ "vmla.f32 q14, q5, q2 @ w8 * inr34\n" \ "vmla.f32 q15, q5, q3 @ w8 * inr35\n" \ "ldr r1, [%[outl], #4] @ load outc10 to r1\n" \ @@ -406,7 +472,6 @@ namespace math { "vtrn.32 q10, q11 @ r0: q10: a2a3c2c3, q11: b2b3d2d3\n" \ "vtrn.32 q12, q13 @ r1: q12: a0a1c0c1, q13: b0b1d0d1\n" \ "vtrn.32 q14, q15 @ r1: q14: a2a3c2c3, q15: b2b3d2d3\n" \ - "ldr r5, [%[outl], #20] @ load outc11 to r5\n" \ "vswp d17, d20 @ r0: q8 : a0a1a2a3, q10: c0c1c2c3 \n" \ "vswp d19, d22 @ r0: q9 : b0b1b2b3, q11: d0d1d2d3 \n" \ "vswp d25, d28 @ r1: q12: a0a1a2a3, q14: c0c1c2c3 \n" \ @@ -417,12 +482,13 @@ namespace math { "vst1.32 {d18-d19}, [r1] @ save outc10\n" \ "vst1.32 {d20-d21}, [r2] @ save outc20\n" \ "vst1.32 {d22-d23}, [r3] @ save outc30\n" \ + "ldr r0, [%[outl], #20] @ load outc11 to r5\n" \ + "ldr r1, [%[outl], #24] @ load outc21 to r0\n" \ + "ldr r2, [%[outl], #28] @ load outc31 to r1\n" \ "vst1.32 {d24-d25}, [r4] @ save outc01\n" \ - "vst1.32 {d26-d27}, [r5] @ save outc11\n" \ - "ldr r0, [%[outl], #24] @ load outc21 to r0\n" \ - "ldr r1, [%[outl], #28] @ load outc31 to r1\n" \ - "vst1.32 {d28-d29}, [r0] @ save outc21\n" \ - 
"vst1.32 {d30-d31}, [r1] @ save outc31\n" \ + "vst1.32 {d26-d27}, [r0] @ save outc11\n" \ + "vst1.32 {d28-d29}, [r1] @ save outc21\n" \ + "vst1.32 {d30-d31}, [r2] @ save outc31\n" \ "b 3f @ branch end\n" \ "2: \n" \ "vst1.32 {d16-d17}, [%[out0]]! @ save remain to pre_out\n" \ @@ -436,291 +502,86 @@ namespace math { "3: \n" #endif // clang-format on -void act_switch_3x3s1(const float* inr0, - const float* inr1, - const float* inr2, - const float* inr3, - float* out0, - const float* weight_c, - float flag_mask, - void* outl_ptr, - float32x4_t w0, - float32x4_t w1, - float32x4_t w2, - float32x4_t w3, - float32x4_t w4, - float32x4_t w5, - float32x4_t w6, - float32x4_t w7, - float32x4_t w8, - float32x4_t vbias, - const operators::ActivationParam act_param) { - bool has_active = act_param.has_active; - if (has_active) { +void conv_3x3s1_depthwise_fp32(const float* i_data, + float* o_data, + int bs, + int oc, + int oh, + int ow, + int ic, + int ih, + int win, + const float* weights, + const float* bias, + const operators::ConvParam& param, + const operators::ActivationParam act_param, + ARMContext* ctx) { + float six_ptr[4] = {0.f, 0.f, 0.f, 0.f}; + float scale_ptr[4] = {1.f, 1.f, 1.f, 1.f}; + float relu_ptr[4] = {0.f, 0.f, 0.f, 0.f}; + if (act_param.has_active) { switch (act_param.active_type) { case lite_api::ActivationType::kRelu: -#ifdef __aarch64__ - asm volatile(COMPUTE RELU STORE - : [inr0] "+r"(inr0), - [inr1] "+r"(inr1), - [inr2] "+r"(inr2), - [inr3] "+r"(inr3), - [out] "+r"(out0) - : [w0] "w"(w0), - [w1] "w"(w1), - [w2] "w"(w2), - [w3] "w"(w3), - [w4] "w"(w4), - [w5] "w"(w5), - [w6] "w"(w6), - [w7] "w"(w7), - [w8] "w"(w8), - [vbias] "w"(vbias), - [outl] "r"(outl_ptr), - [flag_mask] "r"(flag_mask) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22", - "x0", - "x1", - "x2", - "x3", - "x4", - "x5", - "x6", - "x7"); -#else -#if 1 // def LITE_WITH_ARM_CLANG -#else - asm volatile(COMPUTE RELU STORE - : [r0] "+r"(inr0), - [r1] "+r"(inr1), - [r2] "+r"(inr2), - [r3] "+r"(inr3), - [out0] "+r"(out0), - [wc0] "+r"(weight_c) - : [flag_mask] "r"(flag_mask), [outl] "r"(outl_ptr) - : "cc", - "memory", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15", - "r0", - "r1", - "r2", - "r3", - "r4", - "r5"); -#endif -#endif + conv_3x3s1_depthwise_fp32_relu(i_data, + o_data, + bs, + oc, + oh, + ow, + ic, + ih, + win, + weights, + bias, + relu_ptr, + six_ptr, + scale_ptr, + param, + ctx); break; case lite_api::ActivationType::kRelu6: -#ifdef __aarch64__ - asm volatile(COMPUTE RELU RELU6 STORE - : [inr0] "+r"(inr0), - [inr1] "+r"(inr1), - [inr2] "+r"(inr2), - [inr3] "+r"(inr3), - [out] "+r"(out0) - : [w0] "w"(w0), - [w1] "w"(w1), - [w2] "w"(w2), - [w3] "w"(w3), - [w4] "w"(w4), - [w5] "w"(w5), - [w6] "w"(w6), - [w7] "w"(w7), - [w8] "w"(w8), - [vbias] "w"(vbias), - [outl] "r"(outl_ptr), - [flag_mask] "r"(flag_mask) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22", - "x0", - "x1", - "x2", - "x3", - "x4", - "x5", - "x6", - "x7"); -#else -#if 1 // def LITE_WITH_ARM_CLANG -#else - asm volatile(COMPUTE RELU RELU6 STORE - : [r0] "+r"(inr0), - [r1] "+r"(inr1), - [r2] "+r"(inr2), - [r3] "+r"(inr3), - [out0] "+r"(out0), - [wc0] "+r"(weight_c) - : [flag_mask] 
"r"(flag_mask), [outl] "r"(outl_ptr) - : "cc", - "memory", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15", - "r0", - "r1", - "r2", - "r3", - "r4", - "r5"); -#endif -#endif + six_ptr[0] = act_param.Relu_clipped_coef; + six_ptr[1] = act_param.Relu_clipped_coef; + six_ptr[2] = act_param.Relu_clipped_coef; + six_ptr[3] = act_param.Relu_clipped_coef; + conv_3x3s1_depthwise_fp32_relu6(i_data, + o_data, + bs, + oc, + oh, + ow, + ic, + ih, + win, + weights, + bias, + relu_ptr, + six_ptr, + scale_ptr, + param, + ctx); break; case lite_api::ActivationType::kLeakyRelu: -#ifdef __aarch64__ - asm volatile(COMPUTE LEAKY_RELU STORE - : [inr0] "+r"(inr0), - [inr1] "+r"(inr1), - [inr2] "+r"(inr2), - [inr3] "+r"(inr3), - [out] "+r"(out0) - : [w0] "w"(w0), - [w1] "w"(w1), - [w2] "w"(w2), - [w3] "w"(w3), - [w4] "w"(w4), - [w5] "w"(w5), - [w6] "w"(w6), - [w7] "w"(w7), - [w8] "w"(w8), - [vbias] "w"(vbias), - [outl] "r"(outl_ptr), - [flag_mask] "r"(flag_mask) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22", - "x0", - "x1", - "x2", - "x3", - "x4", - "x5", - "x6", - "x7"); -#else -#if 1 // def LITE_WITH_ARM_CLANG -#else - asm volatile(COMPUTE LEAKY_RELU STORE - : [r0] "+r"(inr0), - [r1] "+r"(inr1), - [r2] "+r"(inr2), - [r3] "+r"(inr3), - [out0] "+r"(out0), - [wc0] "+r"(weight_c) - : [flag_mask] "r"(flag_mask), [outl] "r"(outl_ptr) - : "cc", - "memory", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15", - "r0", - "r1", - "r2", - "r3", - "r4", - "r5"); -#endif -#endif + scale_ptr[0] = act_param.Leaky_relu_alpha; + scale_ptr[1] = act_param.Leaky_relu_alpha; + scale_ptr[2] = act_param.Leaky_relu_alpha; + scale_ptr[3] = act_param.Leaky_relu_alpha; + conv_3x3s1_depthwise_fp32_leakyRelu(i_data, + o_data, + bs, + oc, + oh, + ow, + ic, + ih, + win, + weights, + bias, + relu_ptr, + six_ptr, + scale_ptr, + param, + ctx); break; default: LOG(FATAL) << "this act_type: " @@ -728,108 +589,289 @@ void act_switch_3x3s1(const float* inr0, << " fuse not support"; } } else { -#ifdef __aarch64__ - asm volatile(COMPUTE STORE - : [inr0] "+r"(inr0), - [inr1] "+r"(inr1), - [inr2] "+r"(inr2), - [inr3] "+r"(inr3), - [out] "+r"(out0) - : [w0] "w"(w0), - [w1] "w"(w1), - [w2] "w"(w2), - [w3] "w"(w3), - [w4] "w"(w4), - [w5] "w"(w5), - [w6] "w"(w6), - [w7] "w"(w7), - [w8] "w"(w8), - [vbias] "w"(vbias), - [outl] "r"(outl_ptr), - [flag_mask] "r"(flag_mask) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22", - "x0", - "x1", - "x2", - "x3", - "x4", - "x5", - "x6", - "x7"); -#else -#if 1 // def LITE_WITH_ARM_CLANG + conv_3x3s1_depthwise_fp32_bias(i_data, + o_data, + bs, + oc, + oh, + ow, + ic, + ih, + win, + weights, + bias, + relu_ptr, + six_ptr, + scale_ptr, + param, + ctx); + } +} + +void conv_3x3s1_depthwise_fp32_bias(const float* i_data, + float* o_data, + int bs, + int oc, + int oh, + int ow, + int ic, + int ih, + int win, + const float* weights, + const float* bias, + float* relu_ptr, + float* six_ptr, + float* scale_ptr, + const operators::ConvParam& param, + ARMContext* ctx) { + int threads = ctx->threads(); + + auto paddings = *param.paddings; + const int pad_h = paddings[0]; + const int pad_w = 
paddings[2]; + + const int out_c_block = 4; + const int out_h_kernel = 2; + const int out_w_kernel = 4; + const int win_ext = ow + 2; + const int ow_round = ROUNDUP(ow, 4); + const int win_round = ROUNDUP(win_ext, 4); + const int hin_round = oh + 2; + const int prein_size = win_round * hin_round * out_c_block; + auto workspace_size = + threads * prein_size + win_round /*tmp zero*/ + ow_round /*tmp writer*/; + ctx->ExtendWorkspace(sizeof(float) * workspace_size); + + bool flag_bias = param.bias != nullptr; + + /// get workspace + LOG(INFO) << "conv_3x3s1_depthwise_fp32_bias: "; + float* ptr_zero = ctx->workspace_data(); + memset(ptr_zero, 0, sizeof(float) * win_round); + float* ptr_write = ptr_zero + win_round; + + int size_in_channel = win * ih; + int size_out_channel = ow * oh; + + int ws = -pad_w; + int we = ws + win_round; + int hs = -pad_h; + int he = hs + hin_round; + int w_loop = ow_round / 4; + auto remain = w_loop * 4 - ow; + bool flag_remain = remain > 0; + remain = 4 - remain; + remain = remain > 0 ? remain : 0; + int row_len = win_round * out_c_block; + + for (int n = 0; n < bs; ++n) { + const float* din_batch = i_data + n * ic * size_in_channel; + float* dout_batch = o_data + n * oc * size_out_channel; +#pragma omp parallel for num_threads(threads) + for (int c = 0; c < oc; c += out_c_block) { +#ifdef ARM_WITH_OMP + float* pre_din = ptr_write + ow_round + omp_get_thread_num() * prein_size; #else - asm volatile(COMPUTE STORE - : [r0] "+r"(inr0), - [r1] "+r"(inr1), - [r2] "+r"(inr2), - [r3] "+r"(inr3), - [out0] "+r"(out0), - [wc0] "+r"(weight_c) - : [flag_mask] "r"(flag_mask), [outl] "r"(outl_ptr) - : "cc", - "memory", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15", - "r0", - "r1", - "r2", - "r3", - "r4", - "r5"); + float* pre_din = ptr_write + ow_round; #endif + /// const array size + float pre_out[out_c_block * out_w_kernel * out_h_kernel]; // NOLINT + prepack_input_nxwc4_dw( + din_batch, pre_din, c, hs, he, ws, we, ic, win, ih, ptr_zero); + const float* weight_c = weights + c * 9; // kernel_w * kernel_h + float* dout_c00 = dout_batch + c * size_out_channel; + float bias_local[4] = {0, 0, 0, 0}; + if (flag_bias) { + bias_local[0] = bias[c]; + bias_local[1] = bias[c + 1]; + bias_local[2] = bias[c + 2]; + bias_local[3] = bias[c + 3]; + } + float32x4_t vbias = vld1q_f32(bias_local); +#ifdef __aarch64__ + float32x4_t w0 = vld1q_f32(weight_c); // w0, v23 + float32x4_t w1 = vld1q_f32(weight_c + 4); // w1, v24 + float32x4_t w2 = vld1q_f32(weight_c + 8); // w2, v25 + float32x4_t w3 = vld1q_f32(weight_c + 12); // w3, v26 + float32x4_t w4 = vld1q_f32(weight_c + 16); // w4, v27 + float32x4_t w5 = vld1q_f32(weight_c + 20); // w5, v28 + float32x4_t w6 = vld1q_f32(weight_c + 24); // w6, v29 + float32x4_t w7 = vld1q_f32(weight_c + 28); // w7, v30 + float32x4_t w8 = vld1q_f32(weight_c + 32); // w8, v31 +#endif + for (int h = 0; h < oh; h += out_h_kernel) { + float* outc00 = dout_c00 + h * ow; + float* outc01 = outc00 + ow; + float* outc10 = outc00 + size_out_channel; + float* outc11 = outc10 + ow; + float* outc20 = outc10 + size_out_channel; + float* outc21 = outc20 + ow; + float* outc30 = outc20 + size_out_channel; + float* outc31 = outc30 + ow; + const float* inr0 = pre_din + h * row_len; + const float* inr1 = inr0 + row_len; + const float* inr2 = inr1 + row_len; + const float* inr3 = inr2 + row_len; + if (c + out_c_block > oc) { + switch (c + out_c_block - oc) { + case 3: // outc10-outc30 is ptr_write and 
extra + outc10 = ptr_write; + outc11 = ptr_write; + case 2: // outc20-outc30 is ptr_write and extra + outc20 = ptr_write; + outc21 = ptr_write; + case 1: // outc30 is ptr_write and extra + outc30 = ptr_write; + outc31 = ptr_write; + default: + break; + } + } + if (h + out_h_kernel > oh) { + outc01 = ptr_write; + outc11 = ptr_write; + outc21 = ptr_write; + outc31 = ptr_write; + } + + float* outl[] = {outc00, + outc10, + outc20, + outc30, + outc01, + outc11, + outc21, + outc31, + reinterpret_cast(bias_local), + reinterpret_cast(relu_ptr), + reinterpret_cast(six_ptr), + reinterpret_cast(scale_ptr)}; + void* outl_ptr = reinterpret_cast(outl); + for (int w = 0; w < w_loop; ++w) { + bool flag_mask = (w == w_loop - 1) && flag_remain; + float* out0 = pre_out; +#ifdef __aarch64__ + asm volatile(COMPUTE STORE + : [inr0] "+r"(inr0), + [inr1] "+r"(inr1), + [inr2] "+r"(inr2), + [inr3] "+r"(inr3), + [out] "+r"(out0) + : [w0] "w"(w0), + [w1] "w"(w1), + [w2] "w"(w2), + [w3] "w"(w3), + [w4] "w"(w4), + [w5] "w"(w5), + [w6] "w"(w6), + [w7] "w"(w7), + [w8] "w"(w8), + [vbias] "w"(vbias), + [outl] "r"(outl_ptr), + [flag_mask] "r"(flag_mask) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "x0", + "x1", + "x2", + "x3", + "x4", + "x5", + "x6", + "x7"); +#else + asm volatile(COMPUTE STORE + : [r0] "+r"(inr0), + [r1] "+r"(inr1), + [r2] "+r"(inr2), + [r3] "+r"(inr3), + [out0] "+r"(out0), + [wc0] "+r"(weight_c) + : [flag_mask] "r"(flag_mask), [outl] "r"(outl_ptr) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15", + "r0", + "r1", + "r2", + "r3", + "r4"); #endif + outl[0] += 4; + outl[1] += 4; + outl[2] += 4; + outl[3] += 4; + outl[4] += 4; + outl[5] += 4; + outl[6] += 4; + outl[7] += 4; + if (flag_mask) { + memcpy(outl[0] - 4, pre_out, remain * sizeof(float)); + memcpy(outl[1] - 4, pre_out + 4, remain * sizeof(float)); + memcpy(outl[2] - 4, pre_out + 8, remain * sizeof(float)); + memcpy(outl[3] - 4, pre_out + 12, remain * sizeof(float)); + memcpy(outl[4] - 4, pre_out + 16, remain * sizeof(float)); + memcpy(outl[5] - 4, pre_out + 20, remain * sizeof(float)); + memcpy(outl[6] - 4, pre_out + 24, remain * sizeof(float)); + memcpy(outl[7] - 4, pre_out + 28, remain * sizeof(float)); + } + } + } + } } } -void conv_3x3s1_depthwise_fp32(const float* i_data, - float* o_data, - int bs, - int oc, - int oh, - int ow, - int ic, - int ih, - int win, - const float* weights, - const float* bias, - const operators::ConvParam& param, - const operators::ActivationParam act_param, - ARMContext* ctx) { + +void conv_3x3s1_depthwise_fp32_relu(const float* i_data, + float* o_data, + int bs, + int oc, + int oh, + int ow, + int ic, + int ih, + int win, + const float* weights, + const float* bias, + float* relu_ptr, + float* six_ptr, + float* scale_ptr, + const operators::ConvParam& param, + ARMContext* ctx) { int threads = ctx->threads(); auto paddings = *param.paddings; @@ -869,31 +911,6 @@ void conv_3x3s1_depthwise_fp32(const float* i_data, remain = remain > 0 ? 
remain : 0; int row_len = win_round * out_c_block; - float six_ptr[4] = {0.f, 0.f, 0.f, 0.f}; - float scale_ptr[4] = {1.f, 1.f, 1.f, 1.f}; - float relu_ptr[4] = {0.f, 0.f, 0.f, 0.f}; - if (act_param.has_active) { - switch (act_param.active_type) { - case lite_api::ActivationType::kRelu: - break; - case lite_api::ActivationType::kRelu6: - six_ptr[0] = act_param.Relu_clipped_coef; - six_ptr[1] = act_param.Relu_clipped_coef; - six_ptr[2] = act_param.Relu_clipped_coef; - six_ptr[3] = act_param.Relu_clipped_coef; - break; - case lite_api::ActivationType::kLeakyRelu: - scale_ptr[0] = act_param.Leaky_relu_alpha; - scale_ptr[1] = act_param.Leaky_relu_alpha; - scale_ptr[2] = act_param.Leaky_relu_alpha; - scale_ptr[3] = act_param.Leaky_relu_alpha; - break; - default: - LOG(FATAL) << "this act_type: " - << static_cast(act_param.active_type) - << " fuse not support"; - } - } for (int n = 0; n < bs; ++n) { const float* din_batch = i_data + n * ic * size_in_channel; float* dout_batch = o_data + n * oc * size_out_channel; @@ -944,13 +961,13 @@ void conv_3x3s1_depthwise_fp32(const float* i_data, const float* inr3 = inr2 + row_len; if (c + out_c_block > oc) { switch (c + out_c_block - oc) { - case 3: + case 3: // outc10-outc30 is ptr_write and extra outc10 = ptr_write; outc11 = ptr_write; - case 2: + case 2: // outc20-outc30 is ptr_write and extra outc20 = ptr_write; outc21 = ptr_write; - case 1: + case 1: // outc30 is ptr_write and extra outc30 = ptr_write; outc31 = ptr_write; default: @@ -981,48 +998,86 @@ void conv_3x3s1_depthwise_fp32(const float* i_data, bool flag_mask = (w == w_loop - 1) && flag_remain; float* out0 = pre_out; #ifdef __aarch64__ - act_switch_3x3s1(inr0, - inr1, - inr2, - inr3, - out0, - weight_c, - flag_mask, - outl_ptr, - w0, - w1, - w2, - w3, - w4, - w5, - w6, - w7, - w8, - vbias, - act_param); + asm volatile(COMPUTE RELU STORE + : [inr0] "+r"(inr0), + [inr1] "+r"(inr1), + [inr2] "+r"(inr2), + [inr3] "+r"(inr3), + [out] "+r"(out0) + : [w0] "w"(w0), + [w1] "w"(w1), + [w2] "w"(w2), + [w3] "w"(w3), + [w4] "w"(w4), + [w5] "w"(w5), + [w6] "w"(w6), + [w7] "w"(w7), + [w8] "w"(w8), + [vbias] "w"(vbias), + [outl] "r"(outl_ptr), + [flag_mask] "r"(flag_mask) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "x0", + "x1", + "x2", + "x3", + "x4", + "x5", + "x6", + "x7"); #else -#if 1 // def LITE_WITH_ARM_CLANG -#else - act_switch_3x3s1(inr0, - inr1, - inr2, - inr3, - out0, - weight_c, - flag_mask, - outl_ptr, - vbias, - vbias, - vbias, - vbias, - vbias, - vbias, - vbias, - vbias, - vbias, - vbias, - act_param); -#endif + asm volatile(COMPUTE RELU STORE + : [r0] "+r"(inr0), + [r1] "+r"(inr1), + [r2] "+r"(inr2), + [r3] "+r"(inr3), + [out0] "+r"(out0), + [wc0] "+r"(weight_c) + : [flag_mask] "r"(flag_mask), [outl] "r"(outl_ptr) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15", + "r0", + "r1", + "r2", + "r3", + "r4"); #endif outl[0] += 4; outl[1] += 4; @@ -1032,10 +1087,6 @@ void conv_3x3s1_depthwise_fp32(const float* i_data, outl[5] += 4; outl[6] += 4; outl[7] += 4; - inr0 += 16; - inr1 += 16; - inr2 += 16; - inr3 += 16; if (flag_mask) { memcpy(outl[0] - 4, pre_out, remain * sizeof(float)); memcpy(outl[1] - 4, pre_out + 4, remain * sizeof(float)); @@ -1052,6 +1103,499 @@ void conv_3x3s1_depthwise_fp32(const float* i_data, } } +void 
conv_3x3s1_depthwise_fp32_relu6(const float* i_data, + float* o_data, + int bs, + int oc, + int oh, + int ow, + int ic, + int ih, + int win, + const float* weights, + const float* bias, + float* relu_ptr, + float* six_ptr, + float* scale_ptr, + const operators::ConvParam& param, + ARMContext* ctx) { + int threads = ctx->threads(); + + auto paddings = *param.paddings; + const int pad_h = paddings[0]; + const int pad_w = paddings[2]; + + const int out_c_block = 4; + const int out_h_kernel = 2; + const int out_w_kernel = 4; + const int win_ext = ow + 2; + const int ow_round = ROUNDUP(ow, 4); + const int win_round = ROUNDUP(win_ext, 4); + const int hin_round = oh + 2; + const int prein_size = win_round * hin_round * out_c_block; + auto workspace_size = + threads * prein_size + win_round /*tmp zero*/ + ow_round /*tmp writer*/; + ctx->ExtendWorkspace(sizeof(float) * workspace_size); + + bool flag_bias = param.bias != nullptr; + + /// get workspace + float* ptr_zero = ctx->workspace_data(); + memset(ptr_zero, 0, sizeof(float) * win_round); + float* ptr_write = ptr_zero + win_round; + + int size_in_channel = win * ih; + int size_out_channel = ow * oh; + + int ws = -pad_w; + int we = ws + win_round; + int hs = -pad_h; + int he = hs + hin_round; + int w_loop = ow_round / 4; + auto remain = w_loop * 4 - ow; + bool flag_remain = remain > 0; + remain = 4 - remain; + remain = remain > 0 ? remain : 0; + int row_len = win_round * out_c_block; + + for (int n = 0; n < bs; ++n) { + const float* din_batch = i_data + n * ic * size_in_channel; + float* dout_batch = o_data + n * oc * size_out_channel; +#pragma omp parallel for num_threads(threads) + for (int c = 0; c < oc; c += out_c_block) { +#ifdef ARM_WITH_OMP + float* pre_din = ptr_write + ow_round + omp_get_thread_num() * prein_size; +#else + float* pre_din = ptr_write + ow_round; +#endif + /// const array size + float pre_out[out_c_block * out_w_kernel * out_h_kernel]; // NOLINT + prepack_input_nxwc4_dw( + din_batch, pre_din, c, hs, he, ws, we, ic, win, ih, ptr_zero); + const float* weight_c = weights + c * 9; // kernel_w * kernel_h + float* dout_c00 = dout_batch + c * size_out_channel; + float bias_local[4] = {0, 0, 0, 0}; + if (flag_bias) { + bias_local[0] = bias[c]; + bias_local[1] = bias[c + 1]; + bias_local[2] = bias[c + 2]; + bias_local[3] = bias[c + 3]; + } + float32x4_t vbias = vld1q_f32(bias_local); +#ifdef __aarch64__ + float32x4_t w0 = vld1q_f32(weight_c); // w0, v23 + float32x4_t w1 = vld1q_f32(weight_c + 4); // w1, v24 + float32x4_t w2 = vld1q_f32(weight_c + 8); // w2, v25 + float32x4_t w3 = vld1q_f32(weight_c + 12); // w3, v26 + float32x4_t w4 = vld1q_f32(weight_c + 16); // w4, v27 + float32x4_t w5 = vld1q_f32(weight_c + 20); // w5, v28 + float32x4_t w6 = vld1q_f32(weight_c + 24); // w6, v29 + float32x4_t w7 = vld1q_f32(weight_c + 28); // w7, v30 + float32x4_t w8 = vld1q_f32(weight_c + 32); // w8, v31 +#endif + for (int h = 0; h < oh; h += out_h_kernel) { + float* outc00 = dout_c00 + h * ow; + float* outc01 = outc00 + ow; + float* outc10 = outc00 + size_out_channel; + float* outc11 = outc10 + ow; + float* outc20 = outc10 + size_out_channel; + float* outc21 = outc20 + ow; + float* outc30 = outc20 + size_out_channel; + float* outc31 = outc30 + ow; + const float* inr0 = pre_din + h * row_len; + const float* inr1 = inr0 + row_len; + const float* inr2 = inr1 + row_len; + const float* inr3 = inr2 + row_len; + if (c + out_c_block > oc) { + switch (c + out_c_block - oc) { + case 3: // outc10-outc30 is ptr_write and extra + outc10 = ptr_write; + 
outc11 = ptr_write; + case 2: // outc20-outc30 is ptr_write and extra + outc20 = ptr_write; + outc21 = ptr_write; + case 1: // outc30 is ptr_write and extra + outc30 = ptr_write; + outc31 = ptr_write; + default: + break; + } + } + if (h + out_h_kernel > oh) { + outc01 = ptr_write; + outc11 = ptr_write; + outc21 = ptr_write; + outc31 = ptr_write; + } + + float* outl[] = {outc00, + outc10, + outc20, + outc30, + outc01, + outc11, + outc21, + outc31, + reinterpret_cast(bias_local), + reinterpret_cast(relu_ptr), + reinterpret_cast(six_ptr), + reinterpret_cast(scale_ptr)}; + void* outl_ptr = reinterpret_cast(outl); + for (int w = 0; w < w_loop; ++w) { + bool flag_mask = (w == w_loop - 1) && flag_remain; + float* out0 = pre_out; +#ifdef __aarch64__ + asm volatile(COMPUTE RELU RELU6 STORE + : [inr0] "+r"(inr0), + [inr1] "+r"(inr1), + [inr2] "+r"(inr2), + [inr3] "+r"(inr3), + [out] "+r"(out0) + : [w0] "w"(w0), + [w1] "w"(w1), + [w2] "w"(w2), + [w3] "w"(w3), + [w4] "w"(w4), + [w5] "w"(w5), + [w6] "w"(w6), + [w7] "w"(w7), + [w8] "w"(w8), + [vbias] "w"(vbias), + [outl] "r"(outl_ptr), + [flag_mask] "r"(flag_mask) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "x0", + "x1", + "x2", + "x3", + "x4", + "x5", + "x6", + "x7"); +#else + asm volatile(COMPUTE RELU RELU6 STORE + : [r0] "+r"(inr0), + [r1] "+r"(inr1), + [r2] "+r"(inr2), + [r3] "+r"(inr3), + [out0] "+r"(out0), + [wc0] "+r"(weight_c) + : [flag_mask] "r"(flag_mask), [outl] "r"(outl_ptr) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15", + "r0", + "r1", + "r2", + "r3", + "r4"); +#endif + outl[0] += 4; + outl[1] += 4; + outl[2] += 4; + outl[3] += 4; + outl[4] += 4; + outl[5] += 4; + outl[6] += 4; + outl[7] += 4; + if (flag_mask) { + memcpy(outl[0] - 4, pre_out, remain * sizeof(float)); + memcpy(outl[1] - 4, pre_out + 4, remain * sizeof(float)); + memcpy(outl[2] - 4, pre_out + 8, remain * sizeof(float)); + memcpy(outl[3] - 4, pre_out + 12, remain * sizeof(float)); + memcpy(outl[4] - 4, pre_out + 16, remain * sizeof(float)); + memcpy(outl[5] - 4, pre_out + 20, remain * sizeof(float)); + memcpy(outl[6] - 4, pre_out + 24, remain * sizeof(float)); + memcpy(outl[7] - 4, pre_out + 28, remain * sizeof(float)); + } + } + } + } + } +} + +void conv_3x3s1_depthwise_fp32_leakyRelu(const float* i_data, + float* o_data, + int bs, + int oc, + int oh, + int ow, + int ic, + int ih, + int win, + const float* weights, + const float* bias, + float* relu_ptr, + float* six_ptr, + float* scale_ptr, + const operators::ConvParam& param, + ARMContext* ctx) { + int threads = ctx->threads(); + + auto paddings = *param.paddings; + const int pad_h = paddings[0]; + const int pad_w = paddings[2]; + + const int out_c_block = 4; + const int out_h_kernel = 2; + const int out_w_kernel = 4; + const int win_ext = ow + 2; + const int ow_round = ROUNDUP(ow, 4); + const int win_round = ROUNDUP(win_ext, 4); + const int hin_round = oh + 2; + const int prein_size = win_round * hin_round * out_c_block; + auto workspace_size = + threads * prein_size + win_round /*tmp zero*/ + ow_round /*tmp writer*/; + ctx->ExtendWorkspace(sizeof(float) * workspace_size); + + bool flag_bias = param.bias != nullptr; + + /// get workspace + float* ptr_zero = ctx->workspace_data(); + memset(ptr_zero, 0, sizeof(float) * win_round); + float* ptr_write = ptr_zero + 
win_round; + + int size_in_channel = win * ih; + int size_out_channel = ow * oh; + + int ws = -pad_w; + int we = ws + win_round; + int hs = -pad_h; + int he = hs + hin_round; + int w_loop = ow_round / 4; + auto remain = w_loop * 4 - ow; + bool flag_remain = remain > 0; + remain = 4 - remain; + remain = remain > 0 ? remain : 0; + int row_len = win_round * out_c_block; + + for (int n = 0; n < bs; ++n) { + const float* din_batch = i_data + n * ic * size_in_channel; + float* dout_batch = o_data + n * oc * size_out_channel; +#pragma omp parallel for num_threads(threads) + for (int c = 0; c < oc; c += out_c_block) { +#ifdef ARM_WITH_OMP + float* pre_din = ptr_write + ow_round + omp_get_thread_num() * prein_size; +#else + float* pre_din = ptr_write + ow_round; +#endif + /// const array size + float pre_out[out_c_block * out_w_kernel * out_h_kernel]; // NOLINT + prepack_input_nxwc4_dw( + din_batch, pre_din, c, hs, he, ws, we, ic, win, ih, ptr_zero); + const float* weight_c = weights + c * 9; // kernel_w * kernel_h + float* dout_c00 = dout_batch + c * size_out_channel; + float bias_local[4] = {0, 0, 0, 0}; + if (flag_bias) { + bias_local[0] = bias[c]; + bias_local[1] = bias[c + 1]; + bias_local[2] = bias[c + 2]; + bias_local[3] = bias[c + 3]; + } + float32x4_t vbias = vld1q_f32(bias_local); +#ifdef __aarch64__ + float32x4_t w0 = vld1q_f32(weight_c); // w0, v23 + float32x4_t w1 = vld1q_f32(weight_c + 4); // w1, v24 + float32x4_t w2 = vld1q_f32(weight_c + 8); // w2, v25 + float32x4_t w3 = vld1q_f32(weight_c + 12); // w3, v26 + float32x4_t w4 = vld1q_f32(weight_c + 16); // w4, v27 + float32x4_t w5 = vld1q_f32(weight_c + 20); // w5, v28 + float32x4_t w6 = vld1q_f32(weight_c + 24); // w6, v29 + float32x4_t w7 = vld1q_f32(weight_c + 28); // w7, v30 + float32x4_t w8 = vld1q_f32(weight_c + 32); // w8, v31 +#endif + for (int h = 0; h < oh; h += out_h_kernel) { + float* outc00 = dout_c00 + h * ow; + float* outc01 = outc00 + ow; + float* outc10 = outc00 + size_out_channel; + float* outc11 = outc10 + ow; + float* outc20 = outc10 + size_out_channel; + float* outc21 = outc20 + ow; + float* outc30 = outc20 + size_out_channel; + float* outc31 = outc30 + ow; + const float* inr0 = pre_din + h * row_len; + const float* inr1 = inr0 + row_len; + const float* inr2 = inr1 + row_len; + const float* inr3 = inr2 + row_len; + if (c + out_c_block > oc) { + switch (c + out_c_block - oc) { + case 3: // outc10-outc30 is ptr_write and extra + outc10 = ptr_write; + outc11 = ptr_write; + case 2: // outc20-outc30 is ptr_write and extra + outc20 = ptr_write; + outc21 = ptr_write; + case 1: // outc30 is ptr_write and extra + outc30 = ptr_write; + outc31 = ptr_write; + default: + break; + } + } + if (h + out_h_kernel > oh) { + outc01 = ptr_write; + outc11 = ptr_write; + outc21 = ptr_write; + outc31 = ptr_write; + } + + float* outl[] = {outc00, + outc10, + outc20, + outc30, + outc01, + outc11, + outc21, + outc31, + reinterpret_cast(bias_local), + reinterpret_cast(relu_ptr), + reinterpret_cast(six_ptr), + reinterpret_cast(scale_ptr)}; + void* outl_ptr = reinterpret_cast(outl); + for (int w = 0; w < w_loop; ++w) { + bool flag_mask = (w == w_loop - 1) && flag_remain; + float* out0 = pre_out; +#ifdef __aarch64__ + asm volatile(COMPUTE LEAKY_RELU STORE + : [inr0] "+r"(inr0), + [inr1] "+r"(inr1), + [inr2] "+r"(inr2), + [inr3] "+r"(inr3), + [out] "+r"(out0) + : [w0] "w"(w0), + [w1] "w"(w1), + [w2] "w"(w2), + [w3] "w"(w3), + [w4] "w"(w4), + [w5] "w"(w5), + [w6] "w"(w6), + [w7] "w"(w7), + [w8] "w"(w8), + [vbias] "w"(vbias), + [outl] 
"r"(outl_ptr), + [flag_mask] "r"(flag_mask) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "x0", + "x1", + "x2", + "x3", + "x4", + "x5", + "x6", + "x7"); +#else + asm volatile(COMPUTE LEAKY_RELU STORE + : [r0] "+r"(inr0), + [r1] "+r"(inr1), + [r2] "+r"(inr2), + [r3] "+r"(inr3), + [out0] "+r"(out0), + [wc0] "+r"(weight_c) + : [flag_mask] "r"(flag_mask), [outl] "r"(outl_ptr) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15", + "r0", + "r1", + "r2", + "r3", + "r4"); +#endif + outl[0] += 4; + outl[1] += 4; + outl[2] += 4; + outl[3] += 4; + outl[4] += 4; + outl[5] += 4; + outl[6] += 4; + outl[7] += 4; + if (flag_mask) { + memcpy(outl[0] - 4, pre_out, remain * sizeof(float)); + memcpy(outl[1] - 4, pre_out + 4, remain * sizeof(float)); + memcpy(outl[2] - 4, pre_out + 8, remain * sizeof(float)); + memcpy(outl[3] - 4, pre_out + 12, remain * sizeof(float)); + memcpy(outl[4] - 4, pre_out + 16, remain * sizeof(float)); + memcpy(outl[5] - 4, pre_out + 20, remain * sizeof(float)); + memcpy(outl[6] - 4, pre_out + 24, remain * sizeof(float)); + memcpy(outl[7] - 4, pre_out + 28, remain * sizeof(float)); + } + } + } + } + } +} } // namespace math } // namespace arm } // namespace lite diff --git a/lite/backends/arm/math/conv_impl.cc b/lite/backends/arm/math/conv_impl.cc index 2bad1f997f457429c013c11a1dce35eb43dc26da..fa2f85311b3ff4247d52505d750566ec80e47256 100644 --- a/lite/backends/arm/math/conv_impl.cc +++ b/lite/backends/arm/math/conv_impl.cc @@ -620,8 +620,10 @@ void conv_depthwise_3x3_fp32(const void* din, int pad = pad_w; bool flag_bias = param.bias != nullptr; bool pads_less = ((paddings[1] < 2) && (paddings[3] < 2)); + bool ch_four = ch_in <= 4 * w_in; if (stride == 1) { - if (pads_less && (pad_h == pad_w) && (pad < 2)) { // support pad = [0, 1] + if (ch_four && pads_less && (pad_h == pad_w) && + (pad < 2)) { // support pad = [0, 1] conv_depthwise_3x3s1_fp32(reinterpret_cast(din), reinterpret_cast(dout), num, @@ -638,7 +640,6 @@ void conv_depthwise_3x3_fp32(const void* din, act_param, ctx); } else { -#ifdef __aarch64__ conv_3x3s1_depthwise_fp32(reinterpret_cast(din), reinterpret_cast(dout), num, @@ -653,30 +654,10 @@ void conv_depthwise_3x3_fp32(const void* din, param, act_param, ctx); -#else -#ifdef LITE_WITH_ARM_CLANG - LOG(FATAL) << "fp32 depthwise conv3x3s1px doesnot support in v7-clang, " - "this can run in basic"; -#else - conv_3x3s1_depthwise_fp32(reinterpret_cast(din), - reinterpret_cast(dout), - num, - ch_out, - h_out, - w_out, - ch_in, - h_in, - w_in, - reinterpret_cast(weights), - bias, - param, - act_param, - ctx); -#endif -#endif } } else if (stride == 2) { - if (pads_less && pad_h == pad_w && (pad < 2)) { // support pad = [0, 1] + if (ch_four && pads_less && pad_h == pad_w && + (pad < 2)) { // support pad = [0, 1] conv_depthwise_3x3s2_fp32(reinterpret_cast(din), reinterpret_cast(dout), num, diff --git a/lite/backends/arm/math/funcs.h b/lite/backends/arm/math/funcs.h index 2e52bd1e285b7493148a5a779bffcfcfd1336722..f1ac1d63a1b40e2ead5e976e0bffe6c435a2545b 100644 --- a/lite/backends/arm/math/funcs.h +++ b/lite/backends/arm/math/funcs.h @@ -53,7 +53,9 @@ #include "lite/backends/arm/math/reduce_max.h" #include "lite/backends/arm/math/reduce_mean.h" #include "lite/backends/arm/math/reduce_prod.h" +#include "lite/backends/arm/math/reduce_sum.h" 
#include "lite/backends/arm/math/scale.h" +#include "lite/backends/arm/math/scatter.h" #include "lite/backends/arm/math/sequence_expand.h" #include "lite/backends/arm/math/sequence_pool.h" #include "lite/backends/arm/math/sequence_pool_grad.h" @@ -357,6 +359,15 @@ inline float32x4_t pow_ps(float32x4_t a, float32x4_t b) { return exp_ps(vmulq_f32(b, log_ps(a))); } +inline float32x4_t vpaddq_f32(float32x4_t a, float32x4_t b) { + float32x4_t vrst; + vrst[0] = a[0] + a[1]; + vrst[1] = a[2] + a[3]; + vrst[2] = b[0] + b[1]; + vrst[3] = b[2] + b[3]; + return vrst; +} + template void fill_bias_fc( T* tensor, const T* bias, int num, int channel, bool flag_relu); diff --git a/lite/backends/arm/math/reduce_sum.cc b/lite/backends/arm/math/reduce_sum.cc new file mode 100644 index 0000000000000000000000000000000000000000..b563887e8619e29e40d85699b6979713aae8c0a2 --- /dev/null +++ b/lite/backends/arm/math/reduce_sum.cc @@ -0,0 +1,385 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "lite/backends/arm/math/reduce_sum.h" +#include "lite/backends/arm/math/funcs.h" +#include "lite/core/tensor.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +template <> +void reduce_sum_n(const float* src, + float* dst, + int num_in, + int channel_in, + int height_in, + int width_in) { + int chw_size = channel_in * height_in * width_in; + if (num_in == 1) { + memcpy(dst, src, sizeof(float) * chw_size); + } else { + int cnt_n = num_in >> 2; + int remain_n = num_in & 3; + int cnt_chw = chw_size >> 3; + int cnt_rem = chw_size & 7; + int stride = chw_size << 2; + int stride_c = 0; + for (int c = 0; c < cnt_chw; c++) { + float32x4_t vsum0 = vdupq_n_f32(0.f); + float32x4_t vsum1 = vdupq_n_f32(0.f); + const float* din_ptr0 = src + stride_c; + const float* din_ptr1 = din_ptr0 + chw_size; + const float* din_ptr2 = din_ptr1 + chw_size; + const float* din_ptr3 = din_ptr2 + chw_size; + for (int n = 0; n < cnt_n; n++) { + float32x4_t va0 = vld1q_f32(din_ptr0); + float32x4_t vb0 = vld1q_f32(din_ptr1); + float32x4_t va1 = vld1q_f32(din_ptr0 + 4); + float32x4_t vb1 = vld1q_f32(din_ptr1 + 4); + float32x4_t vc0 = vld1q_f32(din_ptr2); + float32x4_t vd0 = vld1q_f32(din_ptr3); + float32x4_t vs00 = vaddq_f32(va0, vb0); + float32x4_t vc1 = vld1q_f32(din_ptr2 + 4); + float32x4_t vs10 = vaddq_f32(va1, vb1); + float32x4_t vd1 = vld1q_f32(din_ptr3 + 4); + float32x4_t vs01 = vaddq_f32(vc0, vd0); + vsum0 = vaddq_f32(vsum0, vs00); + float32x4_t vs11 = vaddq_f32(vc1, vd1); + vsum1 = vaddq_f32(vsum1, vs10); + din_ptr0 += stride; + din_ptr1 += stride; + vsum0 = vaddq_f32(vsum0, vs01); + din_ptr2 += stride; + din_ptr3 += stride; + vsum1 = vaddq_f32(vsum1, vs11); + } + for (int n = 0; n < remain_n; n++) { + float32x4_t va0 = vld1q_f32(din_ptr0); + float32x4_t va1 = vld1q_f32(din_ptr0 + 4); + vsum0 = vaddq_f32(vsum0, va0); + din_ptr0 += chw_size; + vsum1 = vaddq_f32(vsum1, va1); + } + vst1q_f32(dst, vsum0); + dst += 4; + stride_c += 8; + vst1q_f32(dst, vsum1); + dst += 4; 
+ } + if (cnt_rem > 3) { + float32x4_t vsum0 = vdupq_n_f32(0.f); + const float* din_ptr0 = src + stride_c; + const float* din_ptr1 = din_ptr0 + chw_size; + const float* din_ptr2 = din_ptr1 + chw_size; + const float* din_ptr3 = din_ptr2 + chw_size; + for (int n = 0; n < cnt_n; n++) { + float32x4_t va0 = vld1q_f32(din_ptr0); + float32x4_t vb0 = vld1q_f32(din_ptr1); + float32x4_t vc0 = vld1q_f32(din_ptr2); + float32x4_t vd0 = vld1q_f32(din_ptr3); + float32x4_t vs00 = vaddq_f32(va0, vb0); + float32x4_t vs01 = vaddq_f32(vc0, vd0); + vsum0 = vaddq_f32(vsum0, vs00); + din_ptr0 += stride; + din_ptr1 += stride; + vsum0 = vaddq_f32(vsum0, vs01); + din_ptr2 += stride; + din_ptr3 += stride; + } + for (int n = 0; n < remain_n; n++) { + float32x4_t va0 = vld1q_f32(din_ptr0); + din_ptr0 += chw_size; + vsum0 = vaddq_f32(vsum0, va0); + } + stride_c += 4; + vst1q_f32(dst, vsum0); + dst += 4; + cnt_rem -= 4; + } + for (int c = 0; c < cnt_rem; c++) { + const float* din_ptr0 = src + stride_c; + const float* din_ptr1 = din_ptr0 + chw_size; + const float* din_ptr2 = din_ptr1 + chw_size; + const float* din_ptr3 = din_ptr2 + chw_size; + float sum = 0.0; + for (int n = 0; n < cnt_n; n++) { + float tmp0 = din_ptr0[0] + din_ptr1[0]; + float tmp1 = din_ptr2[0] + din_ptr3[0]; + din_ptr0 += stride; + din_ptr1 += stride; + sum += tmp0; + din_ptr2 += stride; + din_ptr3 += stride; + sum += tmp1; + } + for (int n = 0; n < remain_n; n++) { + sum += din_ptr0[0]; + din_ptr0 += chw_size; + } + stride_c++; + dst[0] = sum; + dst++; + } + } +} + +template <> +void reduce_sum_c(const float* src, + float* dst, + int num_in, + int channel_in, + int height_in, + int width_in) { + int hw_size = height_in * width_in; + int chw_size = hw_size * channel_in; + for (int n = 0; n < num_in; ++n) { + reduce_sum_n(src, dst, channel_in, 1, height_in, width_in); + src += chw_size; + dst += hw_size; + } +} + +template <> +void reduce_sum_h(const float* src, + float* dst, + int num_in, + int channel_in, + int height_in, + int width_in) { + int nc_size = num_in * channel_in; + int hw_size = height_in * width_in; + for (int n = 0; n < nc_size; ++n) { + reduce_sum_n(src, dst, height_in, 1, 1, width_in); + src += hw_size; + dst += width_in; + } +} + +template <> +void reduce_sum_w(const float* src, + float* dst, + int num_in, + int channel_in, + int height_in, + int width_in) { + int nch_size = num_in * channel_in * height_in; + int cnt_w = width_in >> 3; + int cnt_n = nch_size >> 2; + int rem_w = width_in & 7; + int rem_n = nch_size & 3; + int stride = 0; + int stride_n = width_in << 2; + for (int n = 0; n < cnt_n; n++) { + const float* din_ptr0 = src + stride; + const float* din_ptr1 = din_ptr0 + width_in; + const float* din_ptr2 = din_ptr1 + width_in; + const float* din_ptr3 = din_ptr2 + width_in; + float32x4_t vsum = vdupq_n_f32(0.f); + int tmp = rem_w; + for (int w = 0; w < cnt_w; w++) { + float32x4_t va0 = vld1q_f32(din_ptr0); + float32x4_t va1 = vld1q_f32(din_ptr0 + 4); + float32x4_t vb0 = vld1q_f32(din_ptr1); + float32x4_t vb1 = vld1q_f32(din_ptr1 + 4); + float32x4_t vc0 = vld1q_f32(din_ptr2); + float32x4_t vc1 = vld1q_f32(din_ptr2 + 4); + float32x4_t vs0 = vaddq_f32(va0, va1); + float32x4_t vd0 = vld1q_f32(din_ptr3); + float32x4_t vs1 = vaddq_f32(vb0, vb1); + float32x4_t vd1 = vld1q_f32(din_ptr3 + 4); + float32x4_t vs2 = vaddq_f32(vc0, vc1); + din_ptr0 += 8; + float32x4_t vs3 = vaddq_f32(vd0, vd1); + din_ptr1 += 8; + float32x4_t vs00 = vpaddq_f32(vs0, vs1); + din_ptr2 += 8; + float32x4_t vs01 = vpaddq_f32(vs2, vs3); + din_ptr3 += 8; + 
float32x4_t vs = vpaddq_f32(vs00, vs01); + vsum = vaddq_f32(vs, vsum); + } + if (tmp > 3) { + float32x4_t va0 = vld1q_f32(din_ptr0); + float32x4_t vb0 = vld1q_f32(din_ptr1); + float32x4_t vc0 = vld1q_f32(din_ptr2); + float32x4_t vd0 = vld1q_f32(din_ptr3); + din_ptr0 += 4; + din_ptr1 += 4; + float32x4_t vs00 = vpaddq_f32(va0, vb0); + float32x4_t vs01 = vpaddq_f32(vc0, vd0); + din_ptr2 += 4; + din_ptr3 += 4; + float32x4_t vs = vpaddq_f32(vs00, vs01); + vsum = vaddq_f32(vs, vsum); + tmp -= 4; + } + for (int w = 0; w < tmp; w++) { + vsum[0] += *din_ptr0++; + vsum[1] += *din_ptr1++; + vsum[2] += *din_ptr2++; + vsum[3] += *din_ptr3++; + } + stride += stride_n; + vst1q_f32(dst, vsum); + dst += 4; + } + if (rem_n > 1) { + const float* din_ptr0 = src + stride; + const float* din_ptr1 = din_ptr0 + width_in; + float32x4_t vsum = vdupq_n_f32(0.f); + for (int w = 0; w < cnt_w; w++) { + float32x4_t va0 = vld1q_f32(din_ptr0); + din_ptr0 += 4; + float32x4_t vb0 = vld1q_f32(din_ptr1); + din_ptr1 += 4; + float32x4_t va1 = vld1q_f32(din_ptr0); + float32x4_t vb1 = vld1q_f32(din_ptr1); + float32x4_t vs0 = vpaddq_f32(va0, vb0); + din_ptr0 += 4; + float32x4_t vs1 = vpaddq_f32(va1, vb1); + din_ptr1 += 4; + float32x4_t vs00 = vpaddq_f32(vs0, vs1); + vsum = vaddq_f32(vs00, vsum); + } + int tmp = rem_w; + if (tmp > 3) { + float32x4_t va0 = vld1q_f32(din_ptr0); + float32x4_t vb0 = vld1q_f32(din_ptr1); + din_ptr0 += 4; + din_ptr1 += 4; + float32x4_t vs00 = vpaddq_f32(va0, vb0); + tmp -= 4; + vsum[0] += vs00[0]; + vsum[2] += vs00[1]; + vsum[1] += vs00[2]; + vsum[3] += vs00[3]; + } + vsum[0] += vsum[2]; + vsum[1] += vsum[3]; + for (int w = 0; w < tmp; w++) { + vsum[0] += *din_ptr0++; + vsum[1] += *din_ptr1++; + } + stride += width_in; + *dst++ = vsum[0]; + stride += width_in; + *dst++ = vsum[1]; + rem_n -= 2; + } + for (int n = 0; n < rem_n; n++) { + const float* din_ptr0 = src + stride; + float32x4_t vsum = vdupq_n_f32(0.f); + for (int w = 0; w < cnt_w; w++) { + float32x4_t va0 = vld1q_f32(din_ptr0); + float32x4_t va1 = vld1q_f32(din_ptr0 + 4); + float32x4_t vs0 = vaddq_f32(va0, va1); + din_ptr0 += 8; + vsum = vaddq_f32(vs0, vsum); + } + if (rem_w > 3) { + float32x4_t va0 = vld1q_f32(din_ptr0); + din_ptr0 += 4; + vsum = vaddq_f32(vsum, va0); + rem_w -= 4; + } + vsum[1] += vsum[2]; + for (int w = 0; w < rem_w; w++) { + vsum[0] += *din_ptr0++; + } + vsum[1] += vsum[3]; + vsum[0] += vsum[1]; + *dst++ = vsum[0]; + } +} + +template <> +void reduce_sum_all(const float* src, float* dst, int all_size) { + int cnt_n = all_size >> 4; + int rem_n = all_size & 15; + int cnt_rem = rem_n >> 2; + int rem_rem = rem_n & 3; + float32x4_t vsum = vdupq_n_f32(0.f); + for (int n = 0; n < cnt_n; n++) { + float32x4_t va0 = vld1q_f32(src); + float32x4_t va1 = vld1q_f32(src + 4); + float32x4_t va2 = vld1q_f32(src + 8); + float32x4_t va3 = vld1q_f32(src + 12); + src += 16; + float32x4_t vs0 = vaddq_f32(va0, va1); + float32x4_t vs1 = vaddq_f32(va2, va3); + float32x4_t vs = vpaddq_f32(vs0, vs1); + vsum = vaddq_f32(vsum, vs); + } + for (int n = 0; n < cnt_rem; n++) { + float32x4_t va0 = vld1q_f32(src); + src += 4; + vsum = vaddq_f32(vsum, va0); + } + vsum[1] += vsum[2]; + for (int n = 0; n < rem_rem; n++) { + vsum[0] += *src++; + } + vsum[1] += vsum[3]; + vsum[0] += vsum[1]; + dst[0] = vsum[0]; +} + +template <> +void reduce_sum_nc(const float* src, + float* dst, + int num_in, + int channel_in, + int height_in, + int width_in) { + // reduce nc. 
+ int num = num_in * channel_in; + int size = height_in * width_in; + reduce_sum_n(src, dst, num, size, 1, 1); +} + +template <> +void reduce_sum_ch(const float* src, + float* dst, + int num_in, + int channel_in, + int height_in, + int width_in) { + int ch_size = channel_in * height_in; + int chw_size = ch_size * width_in; + for (int n = 0; n < num_in; n++) { + reduce_sum_n(src, dst, ch_size, 1, 1, width_in); + src += chw_size; + dst += width_in; + } +} + +template <> +void reduce_sum_hw(const float* src, + float* dst, + int num_in, + int channel_in, + int height_in, + int width_in) { + int hw_size = height_in * width_in; + int nc_size = num_in * channel_in; + reduce_sum_w(src, dst, nc_size, 1, 1, hw_size); +} + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/backends/arm/math/reduce_sum.h b/lite/backends/arm/math/reduce_sum.h new file mode 100644 index 0000000000000000000000000000000000000000..74e0b6dc75d17ca5a79c4b46c8535c7f30ec1c08 --- /dev/null +++ b/lite/backends/arm/math/reduce_sum.h @@ -0,0 +1,84 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +template +void reduce_sum_n(const T* src, + T* dst, + int num_in, + int channel_in, + int height_in, + int width_in); + +template +void reduce_sum_c(const T* src, + T* dst, + int num_in, + int channel_in, + int height_in, + int width_in); + +template +void reduce_sum_h(const T* src, + T* dst, + int num_in, + int channel_in, + int height_in, + int width_in); + +template +void reduce_sum_w(const T* src, + T* dst, + int num_in, + int channel_in, + int height_in, + int width_in); + +template +void reduce_sum_nc(const T* src, + T* dst, + int num_in, + int channel_in, + int height_in, + int width_in); + +template +void reduce_sum_ch(const T* src, + T* dst, + int num_in, + int channel_in, + int height_in, + int width_in); + +template +void reduce_sum_hw(const T* src, + T* dst, + int num_in, + int channel_in, + int height_in, + int width_in); + +template +void reduce_sum_all(const T* src, T* dst, int all_size); + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/backends/arm/math/scatter.cc b/lite/backends/arm/math/scatter.cc new file mode 100644 index 0000000000000000000000000000000000000000..c9250a9bfa3fcfbdac2a8942aeff3bd28b4bc381 --- /dev/null +++ b/lite/backends/arm/math/scatter.cc @@ -0,0 +1,72 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. */ + +#include "lite/backends/arm/math/scatter.h" +#include "lite/backends/arm/math/funcs.h" +#include "lite/core/tensor.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +template <> +void scatter(const int64_t* indexs, + const float* src, + float* dst, + int index_size, + int num, + int size, + bool overwrite) { + for (int i = 0; i < num; i++) { + const float* din = src + indexs[i] * size; + memcpy(dst, din, sizeof(float) * size); + dst += size; + } + if (overwrite) { + for (int i = num; i < index_size; i++) { + const float* din = src + indexs[i] * size; + float* dout = dst + indexs[i] * size; + memcpy(dout, din, sizeof(float) * size); + } + } else { + int cnt = size >> 3; + int rem = size & 7; + for (int i = num; i < index_size; i++) { + const float* din = src + indexs[i] * size; + float* dout = dst + indexs[i] * size; + for (int j = 0; j < cnt; j++) { + float32x4_t va0 = vld1q_f32(din); + float32x4_t vb0 = vld1q_f32(dout); + float32x4_t va1 = vld1q_f32(din + 4); + float32x4_t vb1 = vld1q_f32(dout + 4); + vb0 = vaddq_f32(va0, vb0); + vb1 = vaddq_f32(va1, vb1); + din += 8; + vst1q_f32(dout, vb0); + vst1q_f32(dout + 4, vb1); + dout += 8; + } + for (int j = 0; j < rem; j++) { + dout[0] += *din++; + dout++; + } + } + } +} + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/backends/arm/math/scatter.h b/lite/backends/arm/math/scatter.h new file mode 100644 index 0000000000000000000000000000000000000000..3d145367189eb61e7fdfbd5b20a55f5397ae702b --- /dev/null +++ b/lite/backends/arm/math/scatter.h @@ -0,0 +1,34 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once +#include + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +template +void scatter(const int64_t* indexs, + const T* updates, + T* dst, + int index_size, + int num, + int size, + bool overwrite); +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/demo/cxx/test_cv/test_img_prepross.cc b/lite/demo/cxx/test_cv/test_img_prepross.cc index 1fe632d387cb5ed7a94ad1fcc37d4313b452d368..0e00a02260f11a05dd73d8e3850c3967533e243b 100644 --- a/lite/demo/cxx/test_cv/test_img_prepross.cc +++ b/lite/demo/cxx/test_cv/test_img_prepross.cc @@ -128,7 +128,7 @@ bool test_convert(bool cv_run, for (int i = 0; i < test_iter; i++) { clock_t begin = clock(); // resize default linear - image_preprocess.imageConvert(src, resize_lite); + image_preprocess.image_convert(src, resize_lite); clock_t end = clock(); to_lite += (end - begin); } @@ -226,7 +226,7 @@ bool test_flip(bool cv_run, for (int i = 0; i < test_iter; i++) { clock_t begin = clock(); // resize default linear - image_preprocess.imageFlip(src, resize_lite); + image_preprocess.image_flip(src, resize_lite); clock_t end = clock(); to_lite += (end - begin); } @@ -330,7 +330,7 @@ bool test_rotate(bool cv_run, for (int i = 0; i < test_iter; i++) { clock_t begin = clock(); // resize default linear - image_preprocess.imageRotate(src, resize_lite); + image_preprocess.image_rotate(src, resize_lite); clock_t end = clock(); to_lite += (end - begin); } @@ -426,7 +426,7 @@ bool test_resize(bool cv_run, for (int i = 0; i < test_iter; i++) { clock_t begin = clock(); // resize default linear - image_preprocess.imageResize(src, resize_lite); + image_preprocess.image_resize(src, resize_lite); clock_t end = clock(); to_lite += (end - begin); } @@ -526,7 +526,7 @@ bool test_crop(bool cv_run, std::cout << "lite compute:" << std::endl; for (int i = 0; i < test_iter; i++) { clock_t begin = clock(); - image_preprocess.imageCrop( + image_preprocess.image_crop( src, resize_lite, dstFormat, srcw, srch, left_x, left_y, dstw, dsth); clock_t end = clock(); to_lite += (end - begin); diff --git a/lite/demo/cxx/test_cv/test_model_cv.cc b/lite/demo/cxx/test_cv/test_model_cv.cc index caa085eecb81e54859c1bdd5cd7c0654175b7a9a..6da35ea26f13384fc663b7103d4f082ae96587bd 100644 --- a/lite/demo/cxx/test_cv/test_model_cv.cc +++ b/lite/demo/cxx/test_cv/test_model_cv.cc @@ -88,13 +88,13 @@ void pre_process(const cv::Mat& img, int width, int height, Tensor dstTensor) { uint8_t* rgb_ptr = new uint8_t[img.cols * img.rows * 3]; uint8_t* resize_ptr = new uint8_t[width * height * 3]; // do convert bgr--rgb - img_process.imageConvert(img_ptr, rgb_ptr); + img_process.image_convert(img_ptr, rgb_ptr); // do resize - img_process.imageResize(rgb_ptr, resize_ptr); + img_process.image_resize(rgb_ptr, resize_ptr); // data--tensor and normalize float means[3] = {103.94f, 116.78f, 123.68f}; float scales[3] = {0.017f, 0.017f, 0.017f}; - img_process.image2Tensor( + img_process.image_to_tensor( resize_ptr, &dstTensor, LayoutType::kNCHW, means, scales); float* data = dstTensor.mutable_data(); #else diff --git a/lite/kernels/arm/CMakeLists.txt b/lite/kernels/arm/CMakeLists.txt index ad5988c10bd7650f3fcb9c759c73117954d22dd7..40cb03872da810d54ecede0f42b996f96fbfe422 100644 --- a/lite/kernels/arm/CMakeLists.txt +++ b/lite/kernels/arm/CMakeLists.txt @@ -68,6 +68,7 @@ add_kernel(sequence_conv_compute_arm ARM extra SRCS sequence_conv_compute.cc DEP add_kernel(layer_norm_compute_arm ARM extra SRCS layer_norm_compute.cc DEPS 
${lite_kernel_deps} math_arm) add_kernel(gather_compute_arm ARM extra SRCS gather_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(reduce_prod_compute_arm ARM extra SRCS reduce_prod_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(reduce_sum_compute_arm ARM extra SRCS reduce_sum_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(split_lod_tensor_compute_arm ARM extra SRCS split_lod_tensor_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(merge_lod_tensor_compute_arm ARM extra SRCS merge_lod_tensor_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(anchor_generator_compute_arm ARM extra SRCS anchor_generator_compute.cc DEPS ${lite_kernel_deps} math_arm) @@ -79,8 +80,10 @@ add_kernel(collect_fpn_proposals_compute_arm ARM extra SRCS collect_fpn_proposal add_kernel(distribute_fpn_proposals_compute_arm ARM extra SRCS distribute_fpn_proposals_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(clip_compute_arm ARM extra SRCS clip_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(pixel_shuffle_compute_arm ARM extra SRCS pixel_shuffle_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(scatter_compute_arm ARM extra SRCS scatter_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(sequence_expand_as_compute_arm ARM extra SRCS sequence_expand_as_compute.cc DEPS ${lite_kernel_deps} math_arm) + # for OCR specific add_kernel(gru_unit_compute_arm ARM extra SRCS gru_unit_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(gru_compute_arm ARM extra SRCS gru_compute.cc DEPS ${lite_kernel_deps} math_arm) diff --git a/lite/kernels/arm/conv_compute.cc b/lite/kernels/arm/conv_compute.cc index 54e67de5abbfc88f64a50b07335d2527d9738206..ba7837cfff312a15c9ec769ab4e8ac16d0945f4d 100644 --- a/lite/kernels/arm/conv_compute.cc +++ b/lite/kernels/arm/conv_compute.cc @@ -59,12 +59,6 @@ void ConvCompute::PrepareForRun() { bool flag_dw_3x3 = (kw == 3) && (kh == 3) && (stride == 1 || stride == 2); bool flag_dw_5x5 = (kw == 5) && (kh == 5) && (stride == 1 || stride == 2); -#ifdef __aarch64__ -#else - bool flag = - (stride == 1 && (paddings[0] > 1 || paddings[2] > 1)) ? 
false : true; - flag_dw_3x3 = flag_dw_3x3 && flag; -#endif bool flag_dw = flag_dw_3x3 || flag_dw_5x5; /// select conv impl diff --git a/lite/kernels/arm/conv_depthwise.cc b/lite/kernels/arm/conv_depthwise.cc index 3558eb22fbd4863771bf2b6b2e62e51b75a1227e..e34da16acdd71b490b0a233513f525668618a288 100644 --- a/lite/kernels/arm/conv_depthwise.cc +++ b/lite/kernels/arm/conv_depthwise.cc @@ -28,11 +28,15 @@ void DepthwiseConv::PrepareForRun() { auto& ctx = this->ctx_->template As(); auto w_dims = param.filter->dims(); auto kw = w_dims[3]; + auto channel = w_dims[0]; + auto hin = param.x->dims()[2]; + auto win = param.x->dims()[3]; auto paddings = *param.paddings; + bool ch_four = channel <= 4 * win; // select dw conv kernel if (kw == 3) { bool pads_less = ((paddings[1] < 2) && (paddings[3] < 2)); - if (pads_less && paddings[0] == paddings[2] && + if (ch_four && pads_less && paddings[0] == paddings[2] && (paddings[0] == 0 || paddings[0] == 1)) { flag_trans_weights_ = false; } else { @@ -398,6 +402,14 @@ void DepthwiseConv::Run() { w_scale_.data()); } +#ifdef LITE_WITH_PROFILE +template <> +void DepthwiseConv:: + SetProfileRuntimeKernelInfo(paddle::lite::profile::OpCharacter* ch) { + ch->kernel_func_name = kernel_func_name_; +} +#endif + } // namespace arm } // namespace kernels } // namespace lite diff --git a/lite/kernels/arm/reduce_sum_compute.cc b/lite/kernels/arm/reduce_sum_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..261ed2b6a3f7ab0ea794f8e98392594afe0ad16c --- /dev/null +++ b/lite/kernels/arm/reduce_sum_compute.cc @@ -0,0 +1,114 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "lite/kernels/arm/reduce_sum_compute.h" +#include +#include +#include "lite/backends/arm/math/funcs.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +void ReduceSumCompute::Run() { + auto& param = this->template Param(); + auto* input = param.x->template data(); + auto x_dims = param.x->dims(); + int x_rank = x_dims.size(); + auto* output = param.output->template mutable_data(); + std::vector dim = param.dim; + bool keep_dim = param.keep_dim; + bool reduce_all = param.reduce_all; + + if (!dim.empty()) { + for (int i = 0; i < dim.size(); i++) { + if (dim[i] < 0) { + dim[i] += x_rank; + } + } + } + + if (reduce_all) { + lite::arm::math::reduce_sum_all(input, output, x_dims.production()); + } else { + int n_in = 1; + int c_in = 1; + int h_in = 1; + int w_in = 1; + switch (x_dims.size()) { + case 4: + w_in = x_dims[3]; + case 3: + h_in = x_dims[2]; + case 2: + c_in = x_dims[1]; + case 1: + n_in = x_dims[0]; + break; + default: + LOG(FATAL) << "x_dims.size is " << x_dims.size() + << ", which should not be over than 4."; + } + + if (dim.size() == 1) { + switch (dim[0]) { + case 0: + lite::arm::math::reduce_sum_n(input, output, n_in, c_in, h_in, w_in); + break; + case 1: + lite::arm::math::reduce_sum_c(input, output, n_in, c_in, h_in, w_in); + break; + case 2: + lite::arm::math::reduce_sum_h(input, output, n_in, c_in, h_in, w_in); + break; + case 3: + lite::arm::math::reduce_sum_w(input, output, n_in, c_in, h_in, w_in); + break; + default: + LOG(FATAL) << "dim[0] is " << dim[0] + << ", which should be less than 4."; + } + } else if (dim.size() == 2) { + if (dim[0] == 0 && dim[1] == 1) { + lite::arm::math::reduce_sum_nc(input, output, n_in, c_in, h_in, w_in); + } else if (dim[0] == 1 && dim[1] == 2) { + lite::arm::math::reduce_sum_ch(input, output, n_in, c_in, h_in, w_in); + } else if (dim[0] == 2 && dim[1] == 3) { + lite::arm::math::reduce_sum_hw(input, output, n_in, c_in, h_in, w_in); + } else { + LOG(FATAL) + << "Only support the values of the dim are 0,1 1,2 or 2,3 for now."; + } + } else { + LOG(FATAL) << "dim's size: " << dim.size() + << " over than 2, which is not supported now!!"; + } + } +} + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(reduce_sum, + kARM, + kFloat, + kNCHW, + paddle::lite::kernels::arm::ReduceSumCompute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))}) + .Finalize(); diff --git a/lite/kernels/arm/reduce_sum_compute.h b/lite/kernels/arm/reduce_sum_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..15dcc90b6474220fa7193967f14542bb102ef7a3 --- /dev/null +++ b/lite/kernels/arm/reduce_sum_compute.h @@ -0,0 +1,36 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once +#include +#include "lite/backends/arm/math/type_trans.h" +#include "lite/core/kernel.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +class ReduceSumCompute : public KernelLite { + public: + void Run() override; + + virtual ~ReduceSumCompute() = default; +}; + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/arm/scatter_compute.cc b/lite/kernels/arm/scatter_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..8d3a512975c26d356405deb8ae9ff58093507425 --- /dev/null +++ b/lite/kernels/arm/scatter_compute.cc @@ -0,0 +1,63 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/arm/scatter_compute.h" +#include "lite/backends/arm/math/funcs.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +void ScatterCompute::Run() { + auto& param = this->template Param(); + const float* updates_data = param.updates->template data(); + const int64_t* indexs_data = param.indexs->template data(); + float* output_data = param.output->template mutable_data(); + bool overwrite = param.overwrite; + int index_size = param.indexs->dims()[0]; + auto in_dims = param.x->dims(); + int num = 1; + for (int i = 1; i < in_dims.size(); i++) { + num *= in_dims[i]; + } + lite::arm::math::scatter(indexs_data, + updates_data, + output_data, + index_size, + in_dims[0], + num, + overwrite); + if (!param.x->lod().empty()) { + param.output->set_lod(param.x->lod()); + } +} + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(scatter, + kARM, + kFloat, + kNCHW, + paddle::lite::kernels::arm::ScatterCompute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))}) + .BindInput("Ids", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))}) + .BindInput("Updates", + {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))}) + .Finalize(); diff --git a/lite/kernels/arm/scatter_compute.h b/lite/kernels/arm/scatter_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..5ee37cf55dd3e9f81582ffdcc5bdf96fa8cc25a8 --- /dev/null +++ b/lite/kernels/arm/scatter_compute.h @@ -0,0 +1,34 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "lite/core/kernel.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +class ScatterCompute : public KernelLite { + public: + void Run() override; + + virtual ~ScatterCompute() = default; +}; + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/operators/CMakeLists.txt b/lite/operators/CMakeLists.txt index 02377aad498a47cff50c3a595f6fb1634a56b5ff..6cdf815a6f03f0e36b95acc4f8e6f15dc64b4de2 100644 --- a/lite/operators/CMakeLists.txt +++ b/lite/operators/CMakeLists.txt @@ -121,6 +121,7 @@ add_operator(max_pool_with_index_op extra SRCS max_pool_with_index_op.cc DEPS ${ add_operator(pixel_shuffle_op extra SRCS pixel_shuffle_op.cc DEPS ${op_DEPS}) add_operator(clip_op extra SRCS clip_op.cc DEPS ${op_DEPS}) add_operator(print_op extra SRCS print_op.cc DEPS ${op_DEPS}) +add_operator(scatter extra SRCS scatter_op.cc DEPS ${op_DEPS}) # for OCR specific add_operator(while_op extra SRCS while_op.cc DEPS ${op_DEPS}) diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h index 494ee823827fc6d71f0c41824ee7f9e52bdbb3f4..33da913d2e13d290ef42a40955c7cdc13fd855b3 100644 --- a/lite/operators/op_params.h +++ b/lite/operators/op_params.h @@ -294,6 +294,16 @@ struct ScaleParam : ParamBase { } }; +// For Scatter OP +struct ScatterParam : ParamBase { + lite::Tensor* x{}; + lite::Tensor* indexs{}; + lite::Tensor* updates{}; + lite::Tensor* output{}; + + bool overwrite{true}; +}; + // For Softmax op struct SoftmaxParam : ParamBase { lite::Tensor* x{}; diff --git a/lite/operators/scatter_op.cc b/lite/operators/scatter_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..20a0dcb6be409c87e828e168321716adf69011e4 --- /dev/null +++ b/lite/operators/scatter_op.cc @@ -0,0 +1,66 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "lite/operators/scatter_op.h" +#include "lite/core/op_registry.h" +namespace paddle { +namespace lite { +namespace operators { + +bool ScatterOp::CheckShape() const { + CHECK_OR_FALSE(param_.x); + CHECK_OR_FALSE(param_.output); + return true; +} + +bool ScatterOp::InferShapeImpl() const { + auto index_dims = param_.indexs->dims(); + auto update_dims = param_.updates->dims(); + auto input_dims = param_.x->dims(); + for (int i = 1; i < update_dims.size(); i++) { + CHECK_EQ_OR_FALSE(update_dims[i], input_dims[i]); + } + CHECK_EQ_OR_FALSE(index_dims.size(), 1L); + param_.output->Resize(input_dims); + return true; +} + +bool ScatterOp::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) { + AttachParam(¶m_); + auto x = op_desc.Input("X").front(); + auto indexs = op_desc.Input("Ids").front(); + auto updates = op_desc.Input("Updates").front(); + auto output = op_desc.Output("Out").front(); + if (op_desc.HasAttr("overwrite")) { + param_.overwrite = op_desc.GetAttr("overwrite"); + } else { + param_.overwrite = true; + } + param_.x = scope->FindVar(x)->GetMutable(); + param_.indexs = scope->FindVar(indexs)->GetMutable(); + param_.updates = scope->FindVar(updates)->GetMutable(); + param_.output = scope->FindMutableTensor(output); + + CHECK(param_.x); + CHECK(param_.indexs); + CHECK(param_.updates); + CHECK(param_.output); + return true; +} + +} // namespace operators +} // namespace lite +} // namespace paddle + +REGISTER_LITE_OP(scatter, paddle::lite::operators::ScatterOp); diff --git a/lite/operators/scatter_op.h b/lite/operators/scatter_op.h new file mode 100644 index 0000000000000000000000000000000000000000..419a5308ef76ee99987945dffb50549ca6bd4842 --- /dev/null +++ b/lite/operators/scatter_op.h @@ -0,0 +1,55 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once +#include +#include +#include "lite/core/op_lite.h" +#include "lite/core/scope.h" +#include "lite/utils/all.h" + +namespace paddle { +namespace lite { +namespace operators { + +class ScatterOp : public OpLite { + public: + ScatterOp() {} + explicit ScatterOp(const std::string &op_type) : OpLite(op_type) {} + + bool CheckShape() const override; + + bool InferShapeImpl() const override; + + bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + + std::string DebugString() const override { return "Scatter"; } + +#ifdef LITE_WITH_PROFILE + void GetOpRuntimeInfo(paddle::lite::profile::OpCharacter *ch) { + ch->input_shape = ch->DimToStr(param_.x->dims()); + ch->output_shape = ch->DimToStr(param_.output->dims()); + ch->macs = param_.x->numel() * 1.f; + } +#endif + + private: + mutable ScatterParam param_; +}; + +} // namespace operators +} // namespace lite +} // namespace paddle diff --git a/lite/tests/cv/anakin/bgra_to_tensor_hwc.cc b/lite/tests/cv/anakin/bgra_to_tensor_hwc.cc index daab2f3ce559cd9583839918c79bf50109275d71..4e24f87a1d8bbddf00f898185b71b8bd312f902c 100644 --- a/lite/tests/cv/anakin/bgra_to_tensor_hwc.cc +++ b/lite/tests/cv/anakin/bgra_to_tensor_hwc.cc @@ -30,7 +30,7 @@ void bgra_to_tensor_hwc(const uint8_t* bgr, float b_scales = scales[2]; int dim8 = width >> 3; - int remain = wwidth - (dim8 << 3); + int remain = width - (dim8 << 3); float32x4_t vrmean = vdupq_n_f32(r_means); float32x4_t vgmean = vdupq_n_f32(g_means); diff --git a/lite/tests/cv/image_convert_test.cc b/lite/tests/cv/image_convert_test.cc index b1302f3396fa17471d4252e27897ec44c0110342..921c1b360c72c318d892fc39249186bd5674948a 100644 --- a/lite/tests/cv/image_convert_test.cc +++ b/lite/tests/cv/image_convert_test.cc @@ -293,53 +293,53 @@ void test_img(const std::vector& cluster_id, // LOG(INFO) << "image convert saber compute"; t_convert.Start(); - // 方法一: image_preprocess.imageCovert(src, lite_dst); - image_preprocess.imageConvert( + // method1: image_preprocess.image_convert(src, lite_dst); + image_preprocess.image_convert( src, lite_dst, (ImageFormat)srcFormat, (ImageFormat)dstFormat); t_convert.Stop(); // LOG(INFO) << "image resize saber compute"; t_resize.Start(); - // 方法一:image_preprocess.imageResize(lite_dst, resize_tmp); - image_preprocess.imageResize(lite_dst, - resize_tmp, - (ImageFormat)dstFormat, - srcw, - srch, - dstw, - dsth); + // method1:image_preprocess.image_resize(lite_dst, resize_tmp); + image_preprocess.image_resize(lite_dst, + resize_tmp, + (ImageFormat)dstFormat, + srcw, + srch, + dstw, + dsth); t_resize.Stop(); // LOG(INFO) << "image rotate saber compute"; t_rotate.Start(); - // 方法一: image_preprocess.imageRotate(resize_tmp, tv_out_ratote); - image_preprocess.imageRotate(resize_tmp, - tv_out_ratote, - (ImageFormat)dstFormat, - dstw, - dsth, - rotate); + // method1: image_preprocess.image_rotate(resize_tmp, tv_out_ratote); + image_preprocess.image_rotate(resize_tmp, + tv_out_ratote, + (ImageFormat)dstFormat, + dstw, + dsth, + rotate); t_rotate.Stop(); // LOG(INFO) << "image flip saber compute"; t_flip.Start(); - // 方法一: image_preprocess.imageFlip(resize_tmp, tv_out_flip); - image_preprocess.imageFlip( + // method1: image_preprocess.image_flip(resize_tmp, tv_out_flip); + image_preprocess.image_flip( resize_tmp, tv_out_flip, (ImageFormat)dstFormat, dstw, dsth, flip); t_flip.Stop(); // LOG(INFO) << "image to tensor compute"; t_tensor.Start(); - // 方法一: 
image_preprocess.image2Tensor( + // method1: image_preprocess.image_to_tensor( // resize_tmp, &dst_tensor, layout, means, scales); - image_preprocess.image2Tensor(resize_tmp, - &dst_tensor, - (ImageFormat)dstFormat, - dstw, - dsth, - layout, - means, - scales); + image_preprocess.image_to_tensor(resize_tmp, + &dst_tensor, + (ImageFormat)dstFormat, + dstw, + dsth, + layout, + means, + scales); t_tensor.Stop(); t1.Stop(); } @@ -680,7 +680,7 @@ void test_rotate(const std::vector& cluster_id, for (int i = 0; i < test_iter; ++i) { t_rotate.Start(); - image_preprocess.imageRotate(src, lite_dst); + image_preprocess.image_rotate(src, lite_dst); t_rotate.Stop(); } LOG(INFO) << "image rotate avg time : " << t_rotate.LapTimes().Avg() @@ -847,7 +847,7 @@ void test_flip(const std::vector& cluster_id, for (int i = 0; i < test_iter; ++i) { t_rotate.Start(); - image_preprocess.imageFlip(src, lite_dst); + image_preprocess.image_flip(src, lite_dst); t_rotate.Stop(); } LOG(INFO) << "image flip avg time : " << t_rotate.LapTimes().Avg() @@ -1016,7 +1016,7 @@ void test_resize(const std::vector& cluster_id, for (int i = 0; i < test_iter; ++i) { t_rotate.Start(); - image_preprocess.imageResize(src, lite_dst); + image_preprocess.image_resize(src, lite_dst); t_rotate.Stop(); } LOG(INFO) << "image Resize avg time : " << t_rotate.LapTimes().Avg() @@ -1191,7 +1191,7 @@ void test_convert(const std::vector& cluster_id, for (int i = 0; i < test_iter; ++i) { t_rotate.Start(); - image_preprocess.imageConvert(src, lite_dst); + image_preprocess.image_convert(src, lite_dst); t_rotate.Stop(); } LOG(INFO) << "image Convert avg time : " << t_rotate.LapTimes().Avg() diff --git a/lite/tests/cv/image_profiler_test.cc b/lite/tests/cv/image_profiler_test.cc index c440940bc22791ff71f5c2dffd47fd8ee31366fe..074f2e6ce8744e5c3563e4e7e56f3694f7ac5576 100644 --- a/lite/tests/cv/image_profiler_test.cc +++ b/lite/tests/cv/image_profiler_test.cc @@ -163,7 +163,7 @@ void test_convert(const std::vector& cluster_id, for (int i = 0; i < test_iter; ++i) { t_lite.Start(); - image_preprocess.imageConvert(src, lite_dst); + image_preprocess.image_convert(src, lite_dst); t_lite.Stop(); } LOG(INFO) << "image Convert avg time : " << t_lite.LapTimes().Avg() @@ -284,7 +284,7 @@ void test_resize(const std::vector& cluster_id, for (int i = 0; i < test_iter; ++i) { t_rotate.Start(); - image_preprocess.imageResize(src, lite_dst); + image_preprocess.image_resize(src, lite_dst); t_rotate.Stop(); } LOG(INFO) << "image Resize avg time : " << t_rotate.LapTimes().Avg() @@ -405,7 +405,7 @@ void test_flip(const std::vector& cluster_id, for (int i = 0; i < test_iter; ++i) { t_lite.Start(); - image_preprocess.imageFlip(src, lite_dst); + image_preprocess.image_flip(src, lite_dst); t_lite.Stop(); } LOG(INFO) << "image flip avg time : " << t_lite.LapTimes().Avg() @@ -523,7 +523,7 @@ void test_rotate(const std::vector& cluster_id, for (int i = 0; i < test_iter; ++i) { t_lite.Start(); - image_preprocess.imageRotate(src, lite_dst); + image_preprocess.image_rotate(src, lite_dst); t_lite.Stop(); } LOG(INFO) << "image rotate avg time : " << t_lite.LapTimes().Avg() @@ -667,14 +667,14 @@ void test_to_tensor(const std::vector& cluster_id, for (int i = 0; i < test_iter; ++i) { t_lite.Start(); - image_preprocess.image2Tensor(src, - &dst_tensor, - (ImageFormat)dstFormat, - dstw, - dsth, - layout, - means, - scales); + image_preprocess.image_to_tensor(src, + &dst_tensor, + (ImageFormat)dstFormat, + dstw, + dsth, + layout, + means, + scales); t_lite.Stop(); } LOG(INFO) << "image tensor 
avg time : " << t_lite.LapTimes().Avg() diff --git a/lite/tests/kernels/CMakeLists.txt b/lite/tests/kernels/CMakeLists.txt index b5ffe94cee83c5a51ccaf9e1d98b53bae2a49020..00fec722eb926e27492ad9c2dbeb4bff754a56de 100644 --- a/lite/tests/kernels/CMakeLists.txt +++ b/lite/tests/kernels/CMakeLists.txt @@ -66,6 +66,7 @@ if(LITE_BUILD_EXTRA) lite_cc_test(test_kernel_ctc_align_compute SRCS ctc_align_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${bm_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_clip_compute SRCS clip_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${bm_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_pixel_shuffle_compute SRCS pixel_shuffle_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${bm_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) + lite_cc_test(test_kernel_scatter_compute SRCS scatter_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${bm_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_sequence_expand_as_compute SRCS sequence_expand_as_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${bm_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) # for training kernel diff --git a/lite/tests/kernels/interp_compute_test.cc b/lite/tests/kernels/interp_compute_test.cc index 16bc735f816943e38c03b22c6f04ac5701132191..f512808632f3d99153c1ca93c94e3edc679b9c96 100644 --- a/lite/tests/kernels/interp_compute_test.cc +++ b/lite/tests/kernels/interp_compute_test.cc @@ -420,7 +420,6 @@ void TestInterpAlignMode(Place place, float abs_error = 2e-5) { if (place == TARGET(kARM) && align_mode == 1 && !align_corners) { continue; } - // align_mode = 0 && align_corners = false NOT supported in Huawei // Ascend NPU DDK if (place == TARGET(kHuaweiAscendNPU) && align_mode == 0 && !align_corners) { diff --git a/lite/tests/kernels/reduce_sum_compute_test.cc b/lite/tests/kernels/reduce_sum_compute_test.cc index 18490e2f9e2a8c98c2d54ac599a34d0c42e7d825..c38132a1a084a5e133afdb273ed89680454fa385 100644 --- a/lite/tests/kernels/reduce_sum_compute_test.cc +++ b/lite/tests/kernels/reduce_sum_compute_test.cc @@ -340,10 +340,10 @@ TEST(ReduceSum, precision) { Place place(TARGET(kX86)); test_reduce_sum(place); #endif - // #ifdef LITE_WITH_ARM - // Place place(TARGET(kARM)); - // test_reduce_sum(place); - // #endif +#ifdef LITE_WITH_ARM + Place place(TARGET(kARM)); + test_reduce_sum(place); +#endif } } // namespace lite diff --git a/lite/tests/kernels/scatter_compute_test.cc b/lite/tests/kernels/scatter_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..a2d82b38d986deafb619d61e97e20be759c48b98 --- /dev/null +++ b/lite/tests/kernels/scatter_compute_test.cc @@ -0,0 +1,161 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include
+#include "lite/api/paddle_use_kernels.h"
+#include "lite/api/paddle_use_ops.h"
+#include "lite/core/arena/framework.h"
+
+namespace paddle {
+namespace lite {
+
+void scatter_basic(const int64_t* indexs,
+                   const float* src,
+                   float* dst,
+                   int index_size,
+                   int num,
+                   int size,
+                   bool overwrite) {
+  for (int i = 0; i < num; i++) {
+    const float* din = src + indexs[i] * size;
+    memcpy(dst, din, sizeof(float) * size);
+    dst += size;
+  }
+  if (overwrite) {
+    for (int i = num; i < index_size; i++) {
+      const float* din = src + indexs[i] * size;
+      float* dout = dst + indexs[i] * size;
+      memcpy(dout, din, sizeof(float) * size);
+    }
+  } else {
+    for (int i = num; i < index_size; i++) {
+      const float* din = src + indexs[i] * size;
+      float* dout = dst + indexs[i] * size;
+      for (int j = 0; j < size; j++) {
+        dout[j] += din[j];
+      }
+    }
+  }
+}
+
+class ScatterComputeTester : public arena::TestCase {
+ protected:
+  // common attributes for this op.
+  std::string input_ = "x";
+  std::string indexs_ = "indexs";
+  std::string updates_ = "updates";
+  std::string output_ = "out";
+  DDim up_dims_{{1}};
+  DDim id_dims_{{1}};
+  DDim x_dims_{{1}};
+  int index_size_ = 0;
+  bool overwrite_ = false;
+
+ public:
+  ScatterComputeTester(const Place& place,
+                       const std::string& alias,
+                       DDim up_dims,
+                       DDim id_dims,
+                       DDim x_dims,
+                       bool overwrite,
+                       int index_size)
+      : TestCase(place, alias),
+        up_dims_(up_dims),
+        id_dims_(id_dims),
+        x_dims_(x_dims),
+        index_size_(index_size),
+        overwrite_(overwrite) {}
+
+  void RunBaseline(Scope* scope) override {
+    auto* indexs_t = scope->FindMutableTensor(indexs_);
+    auto* updates_t = scope->FindMutableTensor(updates_);
+    const auto* indexs_data = indexs_t->data<int64_t>();
+    const auto* updates_data = updates_t->data<float>();
+    auto* out = scope->NewTensor(output_);
+
+    out->Resize(x_dims_);
+
+    auto* out_data = out->mutable_data<float>();
+    int in_n = x_dims_[0];
+    int in_c = x_dims_[1];
+    int in_h = x_dims_[2];
+    int in_w = x_dims_[3];
+    int size = in_c * in_h * in_w;
+
+    scatter_basic(indexs_data,
+                  updates_data,
+                  out_data,
+                  index_size_,
+                  in_n,
+                  size,
+                  overwrite_);
+  }
+
+  void PrepareOpDesc(cpp::OpDesc* op_desc) {
+    op_desc->SetType("scatter");
+    op_desc->SetInput("X", {input_});
+    op_desc->SetInput("Ids", {indexs_});
+    op_desc->SetInput("Updates", {updates_});
+    op_desc->SetOutput("Out", {output_});
+    op_desc->SetAttr("overwrite", overwrite_);
+  }
+
+  void PrepareData() override {
+    std::vector<float> data(x_dims_.production());
+    for (int i = 0; i < x_dims_.production(); i++) {
+      data[i] = i * 1.0;
+    }
+    SetCommonTensor(input_, x_dims_, data.data());
+    std::vector<float> update(up_dims_.production());
+    for (int i = 0; i < up_dims_.production(); i++) {
+      update[i] = i * 1.0;
+    }
+    SetCommonTensor(updates_, up_dims_, update.data());
+    std::vector<int64_t> index(id_dims_.production());
+    for (int i = 0; i < id_dims_.production(); i++) {
+      index[i] = i;
+    }
+    SetCommonTensor(indexs_, id_dims_, index.data());
+  }
+};
+
+void test_scatter(Place place) {
+  for (auto n : {1, 3}) {
+    for (auto c : {1, 2}) {
+      for (auto h : {1, 3}) {
+        for (auto w : {1, 3}) {
+          for (bool overwrite : {false, true}) {
+            auto x_dims = DDim(std::vector<int64_t>({n, c, h, w}));
+            auto up_dims = DDim(std::vector<int64_t>({n, c, h, w}));
+            auto id_dims = DDim(std::vector<int64_t>({n}));
+            std::unique_ptr<arena::TestCase> tester(new ScatterComputeTester(
+                place, "def", up_dims, id_dims, x_dims, overwrite, n));
+            arena::Arena arena(std::move(tester), place, 2e-5);
+            arena.TestPrecision();
+          }
+        }
+      }
+    }
+  }
+}
+
+TEST(Scatter, precision) {
+#ifdef LITE_WITH_ARM
+  Place place(TARGET(kARM));
+  test_scatter(place);
+#endif
+}
+
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/tests/math/sgemm_compute_test.cc b/lite/tests/math/sgemm_compute_test.cc
index b3ca5ec6ed9876141f8e3d49451b2a9d0fda6269..e5cfe9d5588dbccd1eed51b983e5e926dfd4cdc6 100644
--- a/lite/tests/math/sgemm_compute_test.cc
+++ b/lite/tests/math/sgemm_compute_test.cc
@@ -39,7 +39,13 @@ DEFINE_int32(power_mode,
 DEFINE_int32(threads, 1, "threads num");
 DEFINE_int32(warmup, 0, "warmup times");
 DEFINE_int32(repeats, 1, "repeats times");
+
+#ifdef LITE_WITH_ARM
 DEFINE_bool(basic_test, true, "do all tests");
+#else
+DEFINE_bool(basic_test, false, "do all tests");
+#endif
+
 DEFINE_bool(check_result, true, "check the result");
 DEFINE_int32(M, 512, "gemm: M");
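
Editor's note: the snippet below is a minimal, self-contained sketch of the overwrite/accumulate scatter rule that the new scatter op and its test exercise; it is not part of the patch and does not reproduce the ARM kernel's exact handling of the leading rows. The helper name scatter_ref and the row-major [index_size x size] layout of updates are illustrative assumptions.

// Illustrative sketch only (not from the patch): row i of `updates` goes to
// row indexs[i] of `dst`; with overwrite=false, repeated indices accumulate.
#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

void scatter_ref(const int64_t* indexs, const float* updates, float* dst,
                 int index_size, int size, bool overwrite) {
  for (int i = 0; i < index_size; ++i) {
    const float* src_row = updates + i * size;
    float* dst_row = dst + indexs[i] * size;
    if (overwrite) {
      // last write to a duplicated index wins
      std::memcpy(dst_row, src_row, sizeof(float) * size);
    } else {
      // duplicated indices are summed
      for (int j = 0; j < size; ++j) dst_row[j] += src_row[j];
    }
  }
}

int main() {
  std::vector<int64_t> ids = {1, 0, 1};
  std::vector<float> updates = {1, 1, 2, 2, 3, 3};  // three rows of size 2
  std::vector<float> out(4, 0.f);                   // two output rows of size 2
  scatter_ref(ids.data(), updates.data(), out.data(), 3, 2, /*overwrite=*/false);
  for (float v : out) std::cout << v << " ";  // prints: 2 2 4 4
  std::cout << "\n";
  return 0;
}

With overwrite=true the duplicated index 1 would instead end up holding {3, 3}, which matches the default behaviour selected by the op's "overwrite" attribute (true when the attribute is absent).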