From 7db044ac589e89fe68732ad26e836516ca9a5914 Mon Sep 17 00:00:00 2001
From: hjchen2
Date: Wed, 19 Dec 2018 16:51:22 +0800
Subject: [PATCH] Refactor batch norm, and fix deconv memory leak

---
 src/common/types.cpp                      |   6 +-
 src/common/types.h                        |   3 +-
 .../central-arm-func/batchnorm_arm_func.h | 312 +++---
 .../multiclass_nms_arm_func.h             |   5 -
 src/operators/math/math_function.cpp      |   5 +-
 5 files changed, 54 insertions(+), 277 deletions(-)

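Reviewer note (this section sits before the patch body and is ignored by
git am): the refactored kernel folds the four batch-norm tensors into a
single per-channel multiply-add before touching the data. A minimal
standalone sketch of that identity follows; it is illustrative only, with
made-up values, and is not part of this patch:

    // Folding batch norm:  (x - mean) / sqrt(var + eps) * gamma + beta
    //                   == scale * x + bias
    // with  scale = gamma / sqrt(var + eps),  bias = beta - mean * scale.
    #include <cmath>
    #include <cstdio>

    int main() {
      const float eps = 1e-5f;
      // One channel's statistics and affine parameters (made-up values).
      float mean = 0.5f, var = 4.0f, gamma = 1.5f, beta = -1.0f;
      float x = 2.0f;  // one input element

      // Direct form, as the spec writes it.
      float direct = (x - mean) / std::sqrt(var + eps) * gamma + beta;

      // Folded form: scale and bias are computed once per channel,
      // then every element costs one multiply-add.
      float scale = gamma / std::sqrt(var + eps);
      float bias = beta - mean * scale;
      float folded = scale * x + bias;

      std::printf("direct=%f folded=%f\n", direct, folded);  // same value
      return 0;
    }

The folded form is what lets the NEON path below reduce each element to a
single vmlaq_f32 (fused multiply-accumulate) per lane.
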
diff --git a/src/common/types.cpp b/src/common/types.cpp
index 6ecc62cfbe..420c789e3f 100644
--- a/src/common/types.cpp
+++ b/src/common/types.cpp
@@ -22,6 +22,8 @@ const char *G_OP_TYPE_BATCHNORM = "batch_norm";
 const char *G_OP_TYPE_BOX_CODER = "box_coder";
 const char *G_OP_TYPE_CONCAT = "concat";
 const char *G_OP_TYPE_ELEMENTWISE_ADD = "elementwise_add";
+const char *G_OP_TYPE_ELEMENTWISE_SUB = "elementwise_sub";
+const char *G_OP_TYPE_ELEMENTWISE_MUL = "elementwise_mul";
 const char *G_OP_TYPE_FILL_CONSTANT = "fill_constant";
 const char *G_OP_TYPE_FUSION_CONV_ADD_RELU = "fusion_conv_add_relu";
 const char *G_OP_TYPE_FUSION_CONV_ADD_PRELU = "fusion_conv_add_prelu";
@@ -67,7 +69,6 @@ const char *G_OP_TYPE_CRF = "crf_decoding";
 const char *G_OP_TYPE_BILINEAR_INTERP = "bilinear_interp";
 const char *G_OP_TYPE_FLATTEN = "flatten";
 const char *G_OP_TYPE_SHAPE = "shape";
-const char *G_OP_TYPE_ELEMENTWISE_MUL = "elementwise_mul";
 const char *G_OP_TYPE_SUM = "sum";
 const char *G_OP_TYPE_TOP_K = "top_k";
 const char *G_OP_TYPE_CAST = "cast";
@@ -102,6 +103,8 @@ std::unordered_map<
     {G_OP_TYPE_SIGMOID, {{"X"}, {"Out"}}},
     {G_OP_TYPE_MUL, {{"X"}, {"Out"}}},
     {G_OP_TYPE_ELEMENTWISE_ADD, {{"X", "Y"}, {"Out"}}},
+    {G_OP_TYPE_ELEMENTWISE_SUB, {{"X", "Y"}, {"Out"}}},
+    {G_OP_TYPE_ELEMENTWISE_MUL, {{"X", "Y"}, {"Out"}}},
     {G_OP_TYPE_POOL2D, {{"X"}, {"Out"}}},
     {G_OP_TYPE_BATCHNORM, {{"X"}, {"Y"}}},
     {G_OP_TYPE_LRN, {{"X"}, {"Out"}}},
@@ -146,7 +149,6 @@ std::unordered_map<
     {G_OP_TYPE_SUM, {{"X"}, {"Out"}}},
     {G_OP_TYPE_TOP_K, {{"X"}, {"Out", "Indices"}}},
     {G_OP_TYPE_CAST, {{"X"}, {"Out"}}},
-    {G_OP_TYPE_ELEMENTWISE_MUL, {{"X", "Y"}, {"Out"}}},
     {G_OP_TYPE_QUANTIZE, {{"X"}, {"Out", "OutScale"}}},
     {G_OP_TYPE_DEQUANTIZE, {{"X", "Scale"}, {"Out"}}},
     {G_OP_TYPE_FUSION_DEQUANT_BN, {{"X", "Scale"}, {"Out"}}},
diff --git a/src/common/types.h b/src/common/types.h
index e1e1a94900..c12e5b6a26 100644
--- a/src/common/types.h
+++ b/src/common/types.h
@@ -112,6 +112,8 @@ extern const char *G_OP_TYPE_BATCHNORM;
 extern const char *G_OP_TYPE_BOX_CODER;
 extern const char *G_OP_TYPE_CONCAT;
 extern const char *G_OP_TYPE_ELEMENTWISE_ADD;
+extern const char *G_OP_TYPE_ELEMENTWISE_SUB;
+extern const char *G_OP_TYPE_ELEMENTWISE_MUL;
 extern const char *G_OP_TYPE_FUSION_CONV_ADD_RELU;
 extern const char *G_OP_TYPE_FUSION_CONV_ADD_PRELU;
 extern const char *G_OP_TYPE_FUSION_CONV_ADD_ADD_PRELU;
@@ -149,7 +151,6 @@ extern const char *G_OP_TYPE_FUSION_CONV_BN;
 extern const char *G_OP_TYPE_CONV_TRANSPOSE;
 extern const char *G_OP_TYPE_PRELU;
 extern const char *G_OP_TYPE_SUM;
-extern const char *G_OP_TYPE_ELEMENTWISE_MUL;
 extern const char *G_OP_TYPE_TOP_K;
 extern const char *G_OP_TYPE_CAST;
 
diff --git a/src/operators/kernel/central-arm-func/batchnorm_arm_func.h b/src/operators/kernel/central-arm-func/batchnorm_arm_func.h
index 1723835a6a..300cd32a69 100644
--- a/src/operators/kernel/central-arm-func/batchnorm_arm_func.h
+++ b/src/operators/kernel/central-arm-func/batchnorm_arm_func.h
@@ -18,283 +18,63 @@ limitations under the License. */
 
 #include <cmath>
 #include "operators/op_param.h"
+#if defined(__ARM_NEON__) || defined(__ARM_NEON)
+#include <arm_neon.h>
+#endif  // __ARM_NEON__
 
 namespace paddle_mobile {
 namespace operators {
 
 template <typename P>
 void BatchnormCompute(const BatchNormParam<CPU> &param) {
-  const Tensor *input_x = param.InputX();
-  auto input_x_ptr = input_x->data<float>();
-  const auto &x_dims = input_x->dims();
-  const int N = x_dims[0];
-  const int C = x_dims[1];
-  const int H = x_dims[2];
-  const int W = x_dims[3];
-  const int stride0 = C * H * W;
-  const int stride1 = H * W;
-  const int stride2 = W;
-  Tensor *out = param.OutputY();
-  auto out_ptr = out->mutable_data<float>();
   const float epsilon = param.Epsilon();
-  const Tensor *mean = param.InputMean();
-  const Tensor *variance = param.InputVariance();
-  const Tensor *scale = param.InputScale();
-  const Tensor *bias = param.InputBias();
-  auto mean_ptr = mean->data<float>();
-  auto variance_ptr = variance->data<float>();
-  auto scale_ptr = scale->data<float>();
-  auto bias_ptr = bias->data<float>();
-
-  // Tensor inv_std;
-  // auto inv_std_ptr = inv_std.mutable_data<float>(make_ddim({C}));
-
-  PADDLE_MOBILE_ENFORCE(C == variance->numel(),
-                        "C must equal to variance.numel()");
-
-  int HXW = H * W;
-
-#if __ARM_NEON
-#if __aarch64__
-  float *inv_std_ptr = new float[C];
-  for (int i = 0; i < C; i++) {
-    inv_std_ptr[i] =
-        1 / static_cast<float>(pow((variance_ptr[i] + epsilon), 0.5));
-  }
-
-  Tensor new_scale;
-  auto new_scale_ptr = new_scale.mutable_data<float>(framework::make_ddim({C}));
-  Tensor new_bias;
-  auto new_bias_ptr = new_bias.mutable_data<float>(framework::make_ddim({C}));
-
-  /// ((x - est_mean) * (inv_var) * scale + bias equal to
-  /// (x * inv_var * scale) + (bias - est_mean * inv_var * scale)
-  for (int i = 0; i < C; i++) {
-    new_scale_ptr[i] = inv_std_ptr[i] * scale_ptr[i];
-    new_bias_ptr[i] = bias_ptr[i] - mean_ptr[i] * inv_std_ptr[i] * scale_ptr[i];
-    {
-      for (int n = 0; n < N; n++) {
-        for (int h = 0; h < H; h++) {
-          int tmp_index = n * stride0 + i * stride1 + h * stride2;
-          for (int w = 0; w < W; w++) {
-            int index = tmp_index + w;
-            out_ptr[index] =
-                input_x_ptr[index] * new_scale_ptr[i] + new_bias_ptr[i];
-          }
-        }
-      }
-    }
-  }
-  delete[] inv_std_ptr;
-#else
-
-  if (HXW > 32) {
-    int NXC = N * C;
-    float *inv_std_ptr = new float[NXC * 4];
-    float *volatile new_scale_ptr = new float[NXC * 4];
-    float *volatile new_bias_ptr = new float[NXC * 4];
-
-    /// std = (var + epsilon).sqrt();
-    /// inv_std = 1 / std;
-    for (int i = 0; i < C * 4; i += 4) {
-      int index = i / 4;
-      inv_std_ptr[i] =
-          1 / static_cast<float>(pow((variance_ptr[index] + epsilon), 0.5));
-      inv_std_ptr[i + 1] = inv_std_ptr[i];
-      inv_std_ptr[i + 2] = inv_std_ptr[i];
-      inv_std_ptr[i + 3] = inv_std_ptr[i];
-
-      new_scale_ptr[i] = inv_std_ptr[i] * scale_ptr[index];
-      new_scale_ptr[i + 1] = new_scale_ptr[i];
-      new_scale_ptr[i + 2] = new_scale_ptr[i];
-      new_scale_ptr[i + 3] = new_scale_ptr[i];
-
-      new_bias_ptr[i] =
-          bias_ptr[index] - mean_ptr[index] * inv_std_ptr[i] * scale_ptr[index];
-
-      new_bias_ptr[i + 1] = new_bias_ptr[i];
-      new_bias_ptr[i + 2] = new_bias_ptr[i];
-      new_bias_ptr[i + 3] = new_bias_ptr[i];
-    }
-
-    for (int j = C * 4; j < NXC * 4; ++j) {
-      new_scale_ptr[j] = new_scale_ptr[j - C * 4];
-      new_bias_ptr[j] = new_bias_ptr[j - C * 4];
-    }
-
-    asm volatile(
-        "subs %[N], %[N], #1                    \n\t"
-        "blt        end_n_%=                    \n\t"
-        "loop_n_%=:                             \n\t"
-
-        "subs %[C], %[C], #1                    \n\t"
-        "blt        end_c_%=                    \n\t"
-        "loop_c_%=:                             \n\t"
-
-        "vld1.32 {q9}, [%[new_scale_ptr]]!      \n\t"
-        "vld1.32 {q10}, [%[new_bias_ptr]]!      \n\t"
-
-        "mov r6, %[HXW]                         \n\t"
-
-        "subs r6, r6, #32                       \n\t"
-        "blt        end_hw_%=                   \n\t"
-        "loop_hw_%=:                            \n\t"
-
-        "vld1.32 {q1, q2}, [%[input_x_ptr]]!    \n\t"
-        "vld1.32 {q3, q4}, [%[input_x_ptr]]!    \n\t"
-        "vld1.32 {q5, q6}, [%[input_x_ptr]]!    \n\t"
-        "vld1.32 {q7, q8}, [%[input_x_ptr]]!    \n\t"
-
-        "vmul.f32 q1, q1, q9                    \n\t"
-        "vmul.f32 q2, q2, q9                    \n\t"
-        "vmul.f32 q3, q3, q9                    \n\t"
-        "vmul.f32 q4, q4, q9                    \n\t"
-
-        "vmul.f32 q5, q5, q9                    \n\t"
-        "vmul.f32 q6, q6, q9                    \n\t"
-        "vmul.f32 q7, q7, q9                    \n\t"
-        "vmul.f32 q8, q8, q9                    \n\t"
-
-        "vadd.f32 q1, q1, q10                   \n\t"
-        "vadd.f32 q2, q2, q10                   \n\t"
-        "vadd.f32 q3, q3, q10                   \n\t"
-        "vadd.f32 q4, q4, q10                   \n\t"
-        "vadd.f32 q5, q5, q10                   \n\t"
-        "vadd.f32 q6, q6, q10                   \n\t"
-        "vadd.f32 q7, q7, q10                   \n\t"
-        "vadd.f32 q8, q8, q10                   \n\t"
-
-        "vst1.32 {q1, q2}, [%[out_ptr]]!        \n\t"
-        "vst1.32 {q3, q4}, [%[out_ptr]]!        \n\t"
-        "vst1.32 {q5, q6}, [%[out_ptr]]!        \n\t"
-        "vst1.32 {q7, q8}, [%[out_ptr]]!        \n\t"
-
-        "subs r6, r6, #32                       \n\t"
-        "bge        loop_hw_%=                  \n\t"
-        "end_hw_%=:                             \n\t"
-
-        "cmp r6, #0                             \n\t"
-        "bge        end_remainder_%=            \n\t"
-        "mov r5, #4                             \n\t"
-        "mul r6, r6, r5                         \n\t"
-        "add %[input_x_ptr], %[input_x_ptr], r6 \n\t"
-
-        "vld1.32 {q1, q2}, [%[input_x_ptr]]!    \n\t"
-        "vld1.32 {q3, q4}, [%[input_x_ptr]]!    \n\t"
-        "vld1.32 {q5, q6}, [%[input_x_ptr]]!    \n\t"
-        "vld1.32 {q7, q8}, [%[input_x_ptr]]!    \n\t"
-
-        "vmul.f32 q1, q1, q9                    \n\t"
-        "vmul.f32 q2, q2, q9                    \n\t"
-        "vmul.f32 q3, q3, q9                    \n\t"
-        "vmul.f32 q4, q4, q9                    \n\t"
-        "vmul.f32 q5, q5, q9                    \n\t"
-        "vmul.f32 q6, q6, q9                    \n\t"
-        "vmul.f32 q7, q7, q9                    \n\t"
-        "vmul.f32 q8, q8, q9                    \n\t"
-        "vadd.f32 q1, q1, q10                   \n\t"
-        "vadd.f32 q2, q2, q10                   \n\t"
-        "vadd.f32 q3, q3, q10                   \n\t"
-        "vadd.f32 q4, q4, q10                   \n\t"
-        "vadd.f32 q5, q5, q10                   \n\t"
-        "vadd.f32 q6, q6, q10                   \n\t"
-        "vadd.f32 q7, q7, q10                   \n\t"
-        "vadd.f32 q8, q8, q10                   \n\t"
-
-        "add %[out_ptr], %[out_ptr], r6         \n\t"
-        "vst1.32 {q1, q2}, [%[out_ptr]]!        \n\t"
-        "vst1.32 {q3, q4}, [%[out_ptr]]!        \n\t"
-        "vst1.32 {q5, q6}, [%[out_ptr]]!        \n\t"
-        "vst1.32 {q7, q8}, [%[out_ptr]]!        \n\t"
-
-        "end_remainder_%=:                      \n\t"
-
-        "subs %[C], %[C], #1                    \n\t"
-        "bge        loop_c_%=                   \n\t"
-        "end_c_%=:                              \n\t"
-
-        "subs %[N], %[N], #1                    \n\t"
-        "bge        loop_n_%=                   \n\t"
-        "end_n_%=:                              \n\t"
-        :
-        : [input_x_ptr] "r"(input_x_ptr), [out_ptr] "r"(out_ptr),
-          [new_scale_ptr] "r"(new_scale_ptr), [new_bias_ptr] "r"(new_bias_ptr),
-          [N] "r"(N), [C] "r"(C), [HXW] "r"(HXW)
-        : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9",
-          "q10", "r5", "r6");
-
-    delete[] inv_std_ptr;
-    delete[] new_scale_ptr;
-    delete[] new_bias_ptr;
-
-  } else {
-    float *inv_std_ptr = new float[C];
-    for (int i = 0; i < C; i++) {
-      inv_std_ptr[i] =
-          1 / static_cast<float>(pow((variance_ptr[i] + epsilon), 0.5));
-    }
-
-    Tensor new_scale;
-    auto new_scale_ptr =
-        new_scale.mutable_data<float>(framework::make_ddim({C}));
-    Tensor new_bias;
-    auto new_bias_ptr = new_bias.mutable_data<float>(framework::make_ddim({C}));
-
-    /// ((x - est_mean) * (inv_var) * scale + bias equal to
-    /// (x * inv_var * scale) + (bias - est_mean * inv_var * scale)
-    for (int i = 0; i < C; i++) {
-      new_scale_ptr[i] = inv_std_ptr[i] * scale_ptr[i];
-      new_bias_ptr[i] =
-          bias_ptr[i] - mean_ptr[i] * inv_std_ptr[i] * scale_ptr[i];
-      {
-        for (int n = 0; n < N; n++) {
-          for (int h = 0; h < H; h++) {
-            int tmp_index = n * stride0 + i * stride1 + h * stride2;
-            for (int w = 0; w < W; w++) {
-              int index = tmp_index + w;
-              out_ptr[index] =
-                  input_x_ptr[index] * new_scale_ptr[i] + new_bias_ptr[i];
-            }
-          }
-        }
-      }
-    }
-
-    delete[] inv_std_ptr;
-  }
-#endif
-#else
-  float *inv_std_ptr = new float[C];
-  for (int i = 0; i < C; i++) {
-    inv_std_ptr[i] =
-        1 / static_cast<float>(pow((variance_ptr[i] + epsilon), 0.5));
-  }
-
-  Tensor new_scale;
-  auto new_scale_ptr = new_scale.mutable_data<float>(framework::make_ddim({C}));
-  Tensor new_bias;
-  auto new_bias_ptr = new_bias.mutable_data<float>(framework::make_ddim({C}));
-
-  /// ((x - est_mean) * (inv_var) * scale + bias equal to
-  /// (x * inv_var * scale) + (bias - est_mean * inv_var * scale)
-  for (int i = 0; i < C; i++) {
-    new_scale_ptr[i] = inv_std_ptr[i] * scale_ptr[i];
-    new_bias_ptr[i] = bias_ptr[i] - mean_ptr[i] * inv_std_ptr[i] * scale_ptr[i];
-    {
-      for (int n = 0; n < N; n++) {
-        for (int h = 0; h < H; h++) {
-          int tmp_index = n * stride0 + i * stride1 + h * stride2;
-          for (int w = 0; w < W; w++) {
-            int index = tmp_index + w;
-            out_ptr[index] =
-                input_x_ptr[index] * new_scale_ptr[i] + new_bias_ptr[i];
-          }
-        }
-      }
-    }
-  }
-  delete[] inv_std_ptr;
-#endif
+  const float *mean_ptr = param.InputMean()->data<float>();
+  const float *variance_ptr = param.InputVariance()->data<float>();
+  const float *scale_ptr = param.InputScale()->data<float>();
+  const float *bias_ptr = param.InputBias()->data<float>();
+
+  const framework::Tensor *input = param.InputX();
+  const float *input_ptr = input->data<float>();
+  framework::Tensor *output = param.OutputY();
+  float *output_ptr = output->mutable_data<float>();
+  size_t spatial_size = output->dims()[2] * output->dims()[3];
+  int channels = output->dims()[1];
+
+  #pragma omp parallel for collapse(2)
+  for (int batch = 0; batch < output->dims()[0]; ++batch) {
+    for (int c = 0; c < channels; ++c) {
+      float inv_scale = 1.f / (std::sqrt(variance_ptr[c] + epsilon));
+      float bias = bias_ptr[c] - inv_scale * scale_ptr[c] * mean_ptr[c];
+      float scale = inv_scale * scale_ptr[c];
+      size_t offset = (batch * channels + c) * spatial_size;
+      const float *x = input_ptr + offset;
+      float *y = output_ptr + offset;
+      size_t remain = spatial_size;
+#if defined(__ARM_NEON__) || defined(__ARM_NEON)
+      int loop = spatial_size >> 4;
+      remain = spatial_size & 0xF;
+      float32x4_t __scale = vdupq_n_f32(scale);
+      float32x4_t __bias = vdupq_n_f32(bias);
+      for (int k = 0; k < loop; ++k, x += 16, y += 16) {
+        float32x4_t r0 = vld1q_f32(x);
+        float32x4_t r1 = vld1q_f32(x + 4);
+        float32x4_t r2 = vld1q_f32(x + 8);
+        float32x4_t r3 = vld1q_f32(x + 12);
+        r0 = vmlaq_f32(__bias, __scale, r0);
+        r1 = vmlaq_f32(__bias, __scale, r1);
+        r2 = vmlaq_f32(__bias, __scale, r2);
+        r3 = vmlaq_f32(__bias, __scale, r3);
+        vst1q_f32(y, r0);
+        vst1q_f32(y + 4, r1);
+        vst1q_f32(y + 8, r2);
+        vst1q_f32(y + 12, r3);
+      }
+#endif  // __ARM_NEON__
+      for (int k = 0; k < remain; ++k) {
+        y[k] = scale * x[k] + bias;
+      }
+    }
+  }
 }
 
 }  // namespace operators
diff --git a/src/operators/kernel/central-arm-func/multiclass_nms_arm_func.h b/src/operators/kernel/central-arm-func/multiclass_nms_arm_func.h
index 533edd69b6..b021a574d6 100644
--- a/src/operators/kernel/central-arm-func/multiclass_nms_arm_func.h
+++ b/src/operators/kernel/central-arm-func/multiclass_nms_arm_func.h
@@ -294,11 +294,6 @@ void MultiClassNMSCompute(const MultiClassNMSParam<CPU>& param) {
       }
     }
   }
-
-  // framework::LoD lod;
-  // lod.emplace_back(batch_starts);
-  //
-  // outs->set_lod(lod);
 }
 
 }  // namespace operators
diff --git a/src/operators/math/math_function.cpp b/src/operators/math/math_function.cpp
index d672dbc607..b1e49e377b 100644
--- a/src/operators/math/math_function.cpp
+++ b/src/operators/math/math_function.cpp
@@ -56,14 +56,13 @@ void MatMul(const framework::Tensor &matrix_a, bool trans_a,
   int N = dim_out[1];
   int K = (!trans_a) ? dim_a[1] : dim_a[0];
   Gemm gemm;
-
   if (trans_a) {
+    framework::Tensor matrix_trans;
     int numel = matrix_a.numel();
     int m = matrix_a.dims()[0];
     int n = matrix_a.dims()[1];
     float *tmp = (float *)(matrix_a.data<float>());  // NOLINT
-    float *a = static_cast<float *>(
-        paddle_mobile::memory::Alloc(sizeof(float) * numel));
+    float *a = matrix_trans.mutable_data<float>(matrix_a.dims());
     int index = 0;
     for (int j = 0; j < n; j++) {
       for (int i = 0; i < m; i++) {
-- 
GitLab
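
Post-script on the math_function.cpp hunk: the leak fixed here came from
allocating the transposed copy of matrix_a with a raw
paddle_mobile::memory::Alloc that no code path ever freed. Backing the
scratch buffer with a function-local framework::Tensor ties its lifetime
to the enclosing scope instead. A minimal sketch of the same RAII idea;
the ScratchBuffer type below is hypothetical, not a paddle-mobile API:

    #include <cstdlib>

    // Old pattern: the caller owns the raw allocation and must free it
    // explicitly; MatMul never did, leaking once per transposed call.
    float *alloc_scratch(int numel) {
      return static_cast<float *>(std::malloc(sizeof(float) * numel));
    }

    // New pattern: an owning object releases the buffer in its
    // destructor, so early returns and exceptions cannot leak it.
    struct ScratchBuffer {
      explicit ScratchBuffer(int numel)
          : data(static_cast<float *>(std::malloc(sizeof(float) * numel))) {}
      ScratchBuffer(const ScratchBuffer &) = delete;
      ScratchBuffer &operator=(const ScratchBuffer &) = delete;
      ~ScratchBuffer() { std::free(data); }
      float *data;
    };

    void transposed_matmul_path(int numel) {
      ScratchBuffer scratch(numel);  // freed automatically at scope exit
      // ... transpose into scratch.data, then run the GEMM ...
    }

    int main() {
      transposed_matmul_path(64);
      return 0;
    }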