diff --git a/src/operators/kernel/central-arm-func/conv_add_arm_func.h b/src/operators/kernel/central-arm-func/conv_add_arm_func.h
index d163b8862a082a06742b7344e34a5adb0b0cc871..c01a068fb9732b64da4097844736f7484fdfcab9 100644
--- a/src/operators/kernel/central-arm-func/conv_add_arm_func.h
+++ b/src/operators/kernel/central-arm-func/conv_add_arm_func.h
@@ -31,12 +31,7 @@ void ConvAddBasic(const FusionConvAddParam &param) {
   Tensor bias = *param.Bias();
   int axis = param.Axis();
   Tensor *output = param.Output();
-  math::expand_bias(bias, axis, output->dims());
-  float *output_data = output->data<float>();
   float *biase_data = bias.data<float>();
-  for (int k = 0; k < output->numel(); ++k) {
-    output_data[k] = biase_data[k];
-  }

   int groups = param.Groups();
   std::vector<int> strides = param.Strides();
@@ -113,7 +108,7 @@ void ConvAddBasic(const FusionConvAddParam &param) {
       Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
       math::matmul<float>(filter_slice, false, col_matrix, false,
                           static_cast<float>(1), &out_slice,
-                          static_cast<float>(1));
+                          static_cast<float>(1), false, biase_data);
     }
   }
 }
diff --git a/src/operators/kernel/central-arm-func/conv_add_relu_arm_func.h b/src/operators/kernel/central-arm-func/conv_add_relu_arm_func.h
index 177c275224b3ccbd5fa31efc2fab4bfa8033b752..7b019b60db98d87e4de9315e96fedca7929d4add 100644
--- a/src/operators/kernel/central-arm-func/conv_add_relu_arm_func.h
+++ b/src/operators/kernel/central-arm-func/conv_add_relu_arm_func.h
@@ -32,12 +32,12 @@ void ConvAddReluCompute(const FusionConvAddReluParam &param) {
   Tensor bias = *param.Bias();
   int axis = param.Axis();
   Tensor *output = param.Output();
-  math::expand_bias(bias, axis, output->dims());
+  // math::expand_bias(bias, axis, output->dims());
   float *output_data = output->data<float>();
   float *biase_data = bias.data<float>();
-  for (int k = 0; k < output->numel(); ++k) {
-    output_data[k] = biase_data[k];
-  }
+  // for (int k = 0; k < output->numel(); ++k) {
+  //   output_data[k] = biase_data[k];
+  // }

   int groups = param.Groups();
   std::vector<int> strides = param.Strides();
@@ -115,7 +115,7 @@ void ConvAddReluCompute(const FusionConvAddReluParam &param) {
       Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
       math::matmul<float>(filter_slice, false, col_matrix, false,
                           static_cast<float>(1), &out_slice,
-                          static_cast<float>(1), true);
+                          static_cast<float>(1), true, biase_data);
     }
   }
 }
diff --git a/src/operators/kernel/central-arm-func/conv_arm_func.h b/src/operators/kernel/central-arm-func/conv_arm_func.h
index 33caded3afaaf125bac9108f2fafeda3d3c2049f..41acb973409d9655ae47a8655c1cb527e9563775 100644
--- a/src/operators/kernel/central-arm-func/conv_arm_func.h
+++ b/src/operators/kernel/central-arm-func/conv_arm_func.h
@@ -30,6 +30,7 @@ inline void ConvBasic(const ConvParam &param) {
   Tensor filter = *param.Filter();
   Tensor *output = param.Output();
   output->mutable_data<float>();
+  float *bias_data = output->mutable_data<float>();
   int groups = param.Groups();
   std::vector<int> strides = param.Strides();
   std::vector<int> paddings = param.Paddings();
@@ -106,7 +107,7 @@ inline void ConvBasic(const ConvParam &param) {
       Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
       math::matmul<float>(filter_slice, false, col_matrix, false,
                           static_cast<float>(1), &out_slice,
-                          static_cast<float>(0));
+                          static_cast<float>(0), false, bias_data);
     }
   }
 }
diff --git a/src/operators/kernel/central-arm-func/fusion_fc_arm_func.h b/src/operators/kernel/central-arm-func/fusion_fc_arm_func.h
index 431124feb4f7baddf102dcbfad5e53b0c2002dda..4a689dfc18e3b8677faa61b5c90cb46321f3f4c3 100644
--- a/src/operators/kernel/central-arm-func/fusion_fc_arm_func.h
+++ b/src/operators/kernel/central-arm-func/fusion_fc_arm_func.h
@@ -30,6 +30,7 @@ void FusionFcCompute(const FusionFcParam &param) {
   int axis = param.Axis();
   Tensor *out = param.Out();
   auto *out_data = out->mutable_data<float>();
+  float *bias_data = out->mutable_data<float>();
   const Tensor x_matrix =
       input_x->dims().size() > 2
           ? framework::ReshapeToMatrix(*input_x, param.XNumColDims())
@@ -47,18 +48,18 @@ void FusionFcCompute(const FusionFcParam &param) {
   PADDLE_MOBILE_ENFORCE(out_dim[1] == input_z->dims()[0],
                         " out_dim.size must be 2.");
   axis = (axis == -1 ? out_dim.size() - input_z->dims().size() : axis);
-  PADDLE_MOBILE_ENFORCE(axis == 1, " to fit broadcast, axis = 1. ")
+  PADDLE_MOBILE_ENFORCE(axis == 1, " to fit broadcast, axis = 1. ");

   int64_t classes = input_z->numel();
   for (int i = 0; i < out_dim[0]; i++) {
     memory::Copy(out_data + i * classes, input_z_data, sizeof(float) * classes);
   }

-  for (int i = 0; i < out->numel(); i++) {
-    DLOG << out_data[i];
-  }
+  // for (int i = 0; i < out->numel(); i++) {
+  //   DLOG << out_data[i];
+  // }
   math::matmul<float>(x_matrix, false, y_matrix, false, static_cast<float>(1),
-                      out, static_cast<float>(1));
+                      out, static_cast<float>(1), false, bias_data);
   PADDLE_MOBILE_ENFORCE(out_dim.size() == 2, " out_dim.size must be 2.");
   //  if (out_dim.size() != 2) {
   //    out->Resize(out_dim);
diff --git a/src/operators/kernel/central-arm-func/mul_arm_func.h b/src/operators/kernel/central-arm-func/mul_arm_func.h
index d2da67afe1d2eb746971a2443bdb449eb2b66ec4..341759a96e1e7216fb9550596d3d3533dd0ab80a 100644
--- a/src/operators/kernel/central-arm-func/mul_arm_func.h
+++ b/src/operators/kernel/central-arm-func/mul_arm_func.h
@@ -59,6 +59,7 @@ void MulCompute(const MulParam &param) {
   const Tensor *input_y = param.InputY();
   Tensor *out = param.Out();
   out->mutable_data<float>();
+  float *bias_data = out->mutable_data<float>();
   const Tensor x_matrix =
       input_x->dims().size() > 2
           ? framework::ReshapeToMatrix(*input_x, param.XNumColDims())
@@ -72,7 +73,7 @@ void MulCompute(const MulParam &param) {
     out->Resize({x_matrix.dims()[0], y_matrix.dims()[1]});
   }
   math::matmul<float>(x_matrix, false, y_matrix, false, static_cast<float>(1),
-                      out, static_cast<float>(0));
+                      out, static_cast<float>(0), false, bias_data);
   if (out_dim.size() != 2) {
     out->Resize(out_dim);
   }
diff --git a/src/operators/math/gemm.cpp b/src/operators/math/gemm.cpp
index 20d71907ff9e391d97ce75e38b6e08dc1286a9a3..ef1625b72c54b168eb3b58a4126d2500fbfe561f 100644
--- a/src/operators/math/gemm.cpp
+++ b/src/operators/math/gemm.cpp
@@ -373,9 +373,9 @@ void InnerKernel(int mc, int nc, float alpha, const float *a, const float *b,
 #endif
     }
   }
-
   if (alpha != 1) {
     WriteWithAlphaBeta(mc, nc, c, C, ldc);
+    return;
   }

   if (beta == 0) {
@@ -392,6 +392,42 @@ void InnerKernel(int mc, int nc, float alpha, const float *a, const float *b,
   }
 }

+// 分块矩阵乘法
+void InnerKernelWithBias(int mc, int nc, float alpha, const float *a,
+                         const float *b, float beta, float *c, float *C,
+                         int ldc, bool relu, float *bias) {
+#pragma omp parallel for
+  for (int j = 0; j < nc; j += NR) {
+    for (int i = 0; i < mc; i += MR) {
+#if __aarch64__
+      // AddDot8x12(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
+      AddDot6x16(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
+#else
+      // AddDot4x4(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
+      // AddDot4x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
+      AddDot6x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
+#endif
+    }
+  }
+
+  if (alpha != 1) {
+    WriteWithAlphaBeta(mc, nc, c, C, ldc);
+    return;
+  }
+  if (beta == 0) {
+    WriteBasic(mc, nc, c, C, ldc);
+    return;
+  }
+  if (beta == 1 && !relu) {
+    WriteWithAddV1(mc, nc, c, C, ldc, bias);
+    return;
+  }
+  if (beta == 1 && relu) {
+    WriteWithAddReluV1(mc, nc, c, C, ldc, bias);
+    return;
+  }
+}
+
 // 分块矩阵乘法
 void InnerKernelWithBn(int mc, int nc, float alpha, const float *a,
                        const float *b, float beta, float *c, float *C, int ldc,
@@ -577,6 +613,43 @@ void WriteWithAdd(int mc, int nc, float *c, float *C, int ldc) {
     }
   }
 }
+// C = A * B + bias
+void WriteWithAddV1(int mc, int nc, float *c, float *C, int ldc, float *bias) {
+  int nc1 = nc / 4;
+  int _nc1 = nc % 4;
+
+  float *c_ptr, *C_ptr;
+  float32x4_t cv;
+  float32x4_t biasv;
+  for (int i = 0; i < mc; ++i) {
+    c_ptr = c + i * NC;
+    C_ptr = C + i * ldc;
+    biasv = vld1q_dup_f32(bias + i);
+    for (int j = 0; j < nc1; ++j) {
+      cv = vld1q_f32(c_ptr);
+      cv = vaddq_f32(cv, biasv);
+      vst1q_f32(C_ptr, cv);
+      c_ptr += 4;
+      C_ptr += 4;
+    }
+    if (_nc1 != 0) {
+      cv = vld1q_f32(c_ptr);
+      cv = vaddq_f32(cv, biasv);
+      if (_nc1 >= 1) {
+        vst1q_lane_f32(C_ptr, cv, 0);
+        C_ptr++;
+      }
+      if (_nc1 >= 2) {
+        vst1q_lane_f32(C_ptr, cv, 1);
+        C_ptr++;
+      }
+      if (_nc1 >= 3) {
+        vst1q_lane_f32(C_ptr, cv, 2);
+        C_ptr++;
+      }
+    }
+  }
+}

 // C = A * B + C, relu(C)
 void WriteWithAddRelu(int mc, int nc, float *c, float *C, int ldc) {
@@ -619,6 +692,48 @@ void WriteWithAddRelu(int mc, int nc, float *c, float *C, int ldc) {
   }
 }

+// C = A * B + bias, relu(C)
+void WriteWithAddReluV1(int mc, int nc, float *c, float *C, int ldc,
+                        float *bias) {
+  int nc1 = nc / 4;
+  int _nc1 = nc % 4;
+
+  float *c_ptr, *C_ptr;
+  float32x4_t cv;
+  float32x4_t biasv;
+  float32x4_t zero = vdupq_n_f32(0.0);
+  for (int i = 0; i < mc; ++i) {
+    c_ptr = c + i * NC;
+    C_ptr = C + i * ldc;
+    biasv = vld1q_dup_f32(bias + i);
+    for (int j = 0; j < nc1; ++j) {
+      cv = vld1q_f32(c_ptr);
+      cv = vaddq_f32(cv, biasv);
+      cv = vmaxq_f32(cv, zero);
+      vst1q_f32(C_ptr, cv);
+      c_ptr += 4;
+      C_ptr += 4;
+    }
+    if (_nc1 != 0) {
+      cv = vld1q_f32(c_ptr);
+      cv = vaddq_f32(cv, biasv);
+      cv = vmaxq_f32(cv, zero);
+      if (_nc1 >= 1) {
+        vst1q_lane_f32(C_ptr, cv, 0);
+        C_ptr++;
+      }
+      if (_nc1 >= 2) {
+        vst1q_lane_f32(C_ptr, cv, 1);
+        C_ptr++;
+      }
+      if (_nc1 >= 3) {
+        vst1q_lane_f32(C_ptr, cv, 2);
+        C_ptr++;
+      }
+    }
+  }
+}
+
 // C = A * B, batchnorm(C)
 void WriteWithBn(int mc, int nc, float *c, float *C, int ldc, float *new_scale,
                  float *new_bias) {
@@ -1448,6 +1563,44 @@ void WriteWithAdd(int mc, int nc, float *c, float *C, int ldc) {
   }
 }

+// C = A * B + bias
+void WriteWithAddV1(int mc, int nc, float *c, float *C, int ldc, float *bias) {
+  int nc1 = nc / 4;
+  int _nc1 = nc % 4;
+
+  float *c_ptr, *C_ptr;
+  float32x4_t cv;
+  float32x4_t biasv;
+  for (int i = 0; i < mc; ++i) {
+    c_ptr = c + i * NC;
+    C_ptr = C + i * ldc;
+    biasv = vld1q_dup_f32(bias + i);
+    for (int j = 0; j < nc1; ++j) {
+      cv = vld1q_f32(c_ptr);
+      cv = vaddq_f32(cv, biasv);
+      vst1q_f32(C_ptr, cv);
+      c_ptr += 4;
+      C_ptr += 4;
+    }
+    if (_nc1 != 0) {
+      cv = vld1q_f32(c_ptr);
+      cv = vaddq_f32(cv, biasv);
+      if (_nc1 >= 1) {
+        vst1q_lane_f32(C_ptr, cv, 0);
+        C_ptr++;
+      }
+      if (_nc1 >= 2) {
+        vst1q_lane_f32(C_ptr, cv, 1);
+        C_ptr++;
+      }
+      if (_nc1 >= 3) {
+        vst1q_lane_f32(C_ptr, cv, 2);
+        C_ptr++;
+      }
+    }
+  }
+}
+
 // C = A * B + C, relu(C)
 void WriteWithAddRelu(int mc, int nc, float *c, float *C, int ldc) {
   int nc1 = nc / 16;
@@ -1522,6 +1675,48 @@ void WriteWithAddRelu(int mc, int nc, float *c, float *C, int ldc) {
   }
 }

+// C = A * B + bias, relu(C)
+void WriteWithAddReluV1(int mc, int nc, float *c, float *C, int ldc,
+                        float *bias) {
+  int nc1 = nc / 4;
+  int _nc1 = nc % 4;
+
+  float *c_ptr, *C_ptr;
+  float32x4_t cv;
+  float32x4_t biasv;
+  float32x4_t zero = vdupq_n_f32(0.0);
+  for (int i = 0; i < mc; ++i) {
+    c_ptr = c + i * NC;
+    C_ptr = C + i * ldc;
+    biasv = vld1q_dup_f32(bias + i);
+    for (int j = 0; j < nc1; ++j) {
+      cv = vld1q_f32(c_ptr);
+      cv = vaddq_f32(cv, biasv);
+      cv = vmaxq_f32(cv, zero);
+      vst1q_f32(C_ptr, cv);
+      c_ptr += 4;
+      C_ptr += 4;
+    }
+    if (_nc1 != 0) {
+      cv = vld1q_f32(c_ptr);
+      cv = vaddq_f32(cv, biasv);
+      cv = vmaxq_f32(cv, zero);
+      if (_nc1 >= 1) {
+        vst1q_lane_f32(C_ptr, cv, 0);
+        C_ptr++;
+      }
+      if (_nc1 >= 2) {
+        vst1q_lane_f32(C_ptr, cv, 1);
+        C_ptr++;
+      }
+      if (_nc1 >= 3) {
+        vst1q_lane_f32(C_ptr, cv, 2);
+        C_ptr++;
+      }
+    }
+  }
+}
+
 // C = A * B, batchnorm(C)
 void WriteWithBn(int mc, int nc, float *c, float *C, int ldc, float *scale,
                  float *bias) {
@@ -2053,7 +2248,8 @@ void AddDot4x4(int k, const float *a, const float *b, float *c, int ldc) {

 // 32位 float 矩阵乘法
 void Sgemm(int m, int n, int k, float alpha, const float *A, int lda,
-           const float *B, int ldb, float beta, float *C, int ldc, bool relu) {
+           const float *B, int ldb, float beta, float *C, int ldc, bool relu,
+           float *bias) {
   // L1 data cache is 32 kib (Per Contex-A57, Contex-A72, Contex-A73)
   // L2 cache is 0.5~4 Mib (Contex-A72 cluster)
   int L1 = 32 * 1024;
@@ -2103,8 +2299,8 @@ void Sgemm(int m, int n, int k, float alpha, const float *A, int lda,
 #else
       PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA);
 #endif
-      InnerKernel(mc, nc, alpha, packedA, packedB, beta, packedC, &C(i, j), ldc,
-                  relu);
+      InnerKernelWithBias(mc, nc, alpha, packedA, packedB, beta, packedC,
+                          &C(i, j), ldc, relu, bias + i);
     }
   }

diff --git a/src/operators/math/gemm.h b/src/operators/math/gemm.h
index a9593b15ae73f46aa287028ba74efdb0d303fdde..625fce0323580545c1655c1d3c325f995aa054f2 100644
--- a/src/operators/math/gemm.h
+++ b/src/operators/math/gemm.h
@@ -62,6 +62,9 @@ void PackMatrixB_16c(int k, int n, int n_tail, const float *B, int ldb,
 // 分块矩阵乘法
 void InnerKernel(int mc, int nc, float alpha, const float *a, const float *b,
                  float beta, float *c, float *C, int ldc, bool relu);
+void InnerKernelWithBias(int mc, int nc, float alpha, const float *a,
+                         const float *b, float beta, float *c, float *C,
+                         int ldc, bool relu, float *bias);

 void InnerKernelWithBn(int mc, int nc, float alpha, const float *a,
                        const float *b, float beta, float *c, float *C, int ldc,
@@ -91,8 +94,13 @@ void WriteBasic(int mc, int nc, float *c, float *C, int ldc);
 // C = alpha * A * B + beta * C
 void WriteWithAlphaBeta(int mc, int nc, float *c, float *C, int ldc);
 // C = A * B + C
 void WriteWithAdd(int mc, int nc, float *c, float *C, int ldc);
+// C = A * B + bias
+void WriteWithAddV1(int mc, int nc, float *c, float *C, int ldc, float *bias);
 // C = A * B + C, relu(C)
 void WriteWithAddRelu(int mc, int nc, float *c, float *C, int ldc);
+// C = A * B + bias, relu(C)
+void WriteWithAddReluV1(int mc, int nc, float *c, float *C, int ldc,
+                        float *bias);
 // C = A * B, batchnorm(C)
 void WriteWithBn(int mc, int nc, float *c, float *C, int ldc, float *new_scale,
                  float *new_bias);
@@ -120,7 +128,8 @@ void VecWriteWithBnRelu(int n, float *c, float *C, int ldc, float *new_scale,

 // 32位 float 矩阵乘法
 void Sgemm(int m, int n, int k, float alpha, const float *A, int lda,
-           const float *B, int ldb, float beta, float *C, int ldc, bool relu);
+           const float *B, int ldb, float beta, float *C, int ldc, bool relu,
+           float *bias);

 // 32位 float 矩阵乘法, 并对结果进行 batchnrom
 void SgemmWithBn(int m, int n, int k, float alpha, const float *A, int lda,
diff --git a/src/operators/math/math_function.cpp b/src/operators/math/math_function.cpp
index d881014ccb3f29393ca73fa0e7f4792d4c0d65c7..9ac8d79e89b7a577f0a89807dc96c9f368fed6de 100644
--- a/src/operators/math/math_function.cpp
+++ b/src/operators/math/math_function.cpp
@@ -22,7 +22,8 @@ namespace math {
 template <>
 void matmul<float>(const framework::Tensor &matrix_a, bool trans_a,
                    const framework::Tensor &matrix_b, bool trans_b, float alpha,
-                   framework::Tensor *matrix_out, float beta, bool relu) {
+                   framework::Tensor *matrix_out, float beta, bool relu,
+                   float *bias) {
   auto dim_a = matrix_a.dims();
   auto dim_b = matrix_b.dims();
   auto dim_out = matrix_out->dims();
@@ -42,7 +43,7 @@ void matmul<float>(const framework::Tensor &matrix_a, bool trans_a,
   int K = (!trans_a) ? dim_a[1] : dim_a[0];

   Sgemm(M, N, K, alpha, matrix_a.data<float>(), K, matrix_b.data<float>(), N,
-        beta, matrix_out->data<float>(), N, relu);
+        beta, matrix_out->data<float>(), N, relu, bias);
 }

 template <>
diff --git a/src/operators/math/math_function.h b/src/operators/math/math_function.h
index b5179458a2bf9e6817366c7bd4ea1f536fd21642..74a3f5b8f58f5817c3de426d723a273a8a041614 100644
--- a/src/operators/math/math_function.h
+++ b/src/operators/math/math_function.h
@@ -21,11 +21,11 @@ namespace paddle_mobile {
 namespace operators {
 namespace math {

-// matrix multiply with continuous memory
 template <typename T>
 void matmul(const framework::Tensor &matrix_a, bool trans_a,
             const framework::Tensor &matrix_b, bool trans_b, T alpha,
-            framework::Tensor *matrix_out, T beta, bool relu = false);
+            framework::Tensor *matrix_out, T beta, bool relu = false,
+            float *bias = nullptr);

 template <typename T>
 void matmulWithBn(const framework::Tensor &matrix_a, bool trans_a,
diff --git a/test/common/test_gemm_perf.cpp b/test/common/test_gemm_perf.cpp
index 260236e24ea44a6fc5708d4d0dac239252d28945..c505c61fce21775136a368949a451999b97b3069 100644
--- a/test/common/test_gemm_perf.cpp
+++ b/test/common/test_gemm_perf.cpp
@@ -49,9 +49,9 @@ int main() {

   auto time1 = time();
   for (int j = 0; j < 10; ++j) {
-    paddle_mobile::operators::math::matmul<float>(aa, false, bb, false,
-                                                  static_cast<float>(1), &cc,
-                                                  static_cast<float>(0), false);
+    paddle_mobile::operators::math::matmul<float>(
+        aa, false, bb, false, static_cast<float>(1), &cc, static_cast<float>(0),
+        false, biasptr);

     // paddle_mobile::operators::math::matmulWithBn<float>(
     //     aa, false, bb, false, static_cast<float>(1), &cc,
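Note on the bias path added above: `InnerKernelWithBias` only routes into the new `WriteWithAddV1` / `WriteWithAddReluV1` epilogues when `beta == 1`, which is why `ConvAddBasic` and `ConvAddReluCompute` pass `beta = 1` together with `biase_data`, while `ConvBasic` and `MulCompute` keep `beta = 0` and their `bias_data` pointer is effectively ignored. Each epilogue broadcasts one bias value per output row of the C tile (`bias + i`, i.e. per output channel of the im2col GEMM) and optionally applies ReLU. The scalar sketch below is only an illustration of that write-back semantics; the function name and the explicit `nc_stride` parameter are ours (it stands in for the `NC` macro used in gemm.cpp) and are not part of the patch.

```cpp
// Scalar reference for the NEON WriteWithAddV1 / WriteWithAddReluV1 epilogues:
// C[i][j] = c[i][j] + bias[i], optionally followed by ReLU.
// `c` is the mc x nc packed result tile with row stride `nc_stride`
// (the NC macro in gemm.cpp), `C` is the user output with row stride `ldc`,
// and `bias` holds one value per output row (= output channel).
void WriteTileWithBiasRef(int mc, int nc, int nc_stride, const float *c,
                          float *C, int ldc, const float *bias, bool relu) {
  for (int i = 0; i < mc; ++i) {
    const float b = bias[i];  // broadcast along the row, like vld1q_dup_f32
    for (int j = 0; j < nc; ++j) {
      float v = c[i * nc_stride + j] + b;
      C[i * ldc + j] = (relu && v < 0.0f) ? 0.0f : v;
    }
  }
}
```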