From cd1b6c089e67af1f9398648225dd2d92ea177e34 Mon Sep 17 00:00:00 2001 From: hjchen2 Date: Tue, 12 Mar 2019 16:19:50 +0800 Subject: [PATCH] Optimize vector-matrix and matrix-vector multiply --- .../kernel/arm/convolution/conv_kernel.cpp | 2 - .../arm/convolution/dwconv_bn_relu_kernel.cpp | 2 - .../kernel/central-arm-func/conv_arm_func.cpp | 38 ++-- src/operators/math/depthwise_conv3x3.cpp | 69 ++++--- src/operators/math/depthwise_conv3x3_int8.cpp | 3 +- src/operators/math/depthwise_conv5x5.cpp | 25 ++- src/operators/math/gemm/cblas.cc | 14 +- src/operators/math/gemm/executor.h | 5 +- src/operators/math/gemm/gemm_kernel.h | 194 ++++++++++++++++++ src/operators/math/gemm/pack_kernel.h | 37 +++- src/operators/math/gemm/strategy.h | 17 +- src/operators/math/math.h | 12 ++ test/common/test_gemm_accuracy.cpp | 101 +++++---- 13 files changed, 370 insertions(+), 149 deletions(-) diff --git a/src/operators/kernel/arm/convolution/conv_kernel.cpp b/src/operators/kernel/arm/convolution/conv_kernel.cpp index 1c6ac2015d..a819aa5021 100644 --- a/src/operators/kernel/arm/convolution/conv_kernel.cpp +++ b/src/operators/kernel/arm/convolution/conv_kernel.cpp @@ -18,8 +18,6 @@ limitations under the License. */ #include "operators/kernel/arm/convolution/conv_common.h" #include "operators/kernel/central-arm-func/conv_arm_func.h" -#include - namespace paddle_mobile { namespace operators { diff --git a/src/operators/kernel/arm/convolution/dwconv_bn_relu_kernel.cpp b/src/operators/kernel/arm/convolution/dwconv_bn_relu_kernel.cpp index 748845e23e..063d51330e 100644 --- a/src/operators/kernel/arm/convolution/dwconv_bn_relu_kernel.cpp +++ b/src/operators/kernel/arm/convolution/dwconv_bn_relu_kernel.cpp @@ -65,14 +65,12 @@ void DWConvBNReluKernel::Compute( case ConvParam::EXEC_DEPTHWISE3x3S2_FLOAT: DepthwiseConv3x3(param); break; -#ifndef __aarch64__ case ConvParam::EXEC_DEPTHWISE5x5_FLOAT: DepthwiseConv5x5(param); break; case ConvParam::EXEC_WINOGRAD3X3_FLOAT: WinogradConv3x3<8, 3>(param); break; -#endif // __aarch64__ case ConvParam::EXEC_GEMM_FLOAT: GemmConv(param); break; diff --git a/src/operators/kernel/central-arm-func/conv_arm_func.cpp b/src/operators/kernel/central-arm-func/conv_arm_func.cpp index 495963d470..c34bd1f5d9 100644 --- a/src/operators/kernel/central-arm-func/conv_arm_func.cpp +++ b/src/operators/kernel/central-arm-func/conv_arm_func.cpp @@ -190,18 +190,22 @@ void DepthwiseConv3x3(const ConvParam ¶m) { Tensor *output = param.Output(); output->mutable_data(); - for (int i = 0; i < batch_size; i++) { - Tensor in_batch = input->Slice(i, i + 1); - Tensor out_batch = output->Slice(i, i + 1); - if (strides[0] == 1) { + if (strides[0] == 1) { + for (int i = 0; i < batch_size; i++) { + Tensor in_batch = input->Slice(i, i + 1); + Tensor out_batch = output->Slice(i, i + 1); math::DepthwiseConv3x3S1(in_batch, *filter, paddings, &out_batch); - } else if (strides[0] == 2) { + } + } else if (strides[0] == 2) { + for (int i = 0; i < batch_size; i++) { + Tensor in_batch = input->Slice(i, i + 1); + Tensor out_batch = output->Slice(i, i + 1); math::DepthwiseConv3x3S2(in_batch, *filter, paddings, &out_batch); - } else { - GemmConv(param); } + } else { + GemmConv(param); } } @@ -215,16 +219,16 @@ void DepthwiseConv5x5(const ConvParam ¶m) { Tensor *output = param.Output(); output->mutable_data(); - // if (strides[0] == 1) { - // for (int i = 0; i < batch_size; i++) { - // Tensor in_batch = input->Slice(i, i + 1); - // Tensor out_batch = output->Slice(i, i + 1); - // math::DepthwiseConv5x5S1(in_batch, *filter, paddings, - // &out_batch); - // } - // } else { - GemmConv(param); - // } + if (strides[0] == 1) { + for (int i = 0; i < batch_size; i++) { + Tensor in_batch = input->Slice(i, i + 1); + Tensor out_batch = output->Slice(i, i + 1); + math::DepthwiseConv5x5S1(in_batch, *filter, paddings, + &out_batch); + } + } else { + GemmConv(param); + } } template void GemmConv(const ConvParam ¶m); diff --git a/src/operators/math/depthwise_conv3x3.cpp b/src/operators/math/depthwise_conv3x3.cpp index fe571918ba..62fae35060 100644 --- a/src/operators/math/depthwise_conv3x3.cpp +++ b/src/operators/math/depthwise_conv3x3.cpp @@ -73,8 +73,11 @@ inline void DepthwiseConv3x3NormalRow(const float *input, const float *filter, const int h_start = h_in_start > 0 ? h_in_start : 0; const int h_end = h_in_end < input_h ? h_in_end : input_h; - const int valid_w_start = (padding_w + Stride_w - 1) / Stride_w; - const int valid_w_end = (input_w + padding_w - 3) / Stride_w + 1; + int valid_w_start = (padding_w + Stride_w - 1) / Stride_w; + int valid_w_end = (input_w + padding_w - 3) / Stride_w + 1; + if (valid_w_end < valid_w_start) { + valid_w_end = valid_w_start; + } // const int valid_w_end = output_w - valid_w_start; float *output_ptr = output + h_output * output_w; // border left @@ -120,7 +123,7 @@ inline void DepthwiseConv3x3NormalRow(const float *input, const float *filter, vst1_f32(output_ptr0, vget_low_f32(_sum)); break; case 1: - vst1_lane_f32(output_ptr0, vget_low_f32(_sum), 0); + vst1q_lane_f32(output_ptr0, _sum, 0); break; } } @@ -136,20 +139,21 @@ void DepthwiseConv3x3S1(const framework::Tensor &input, const float *input_data = input.data(); const float *filter_data = filter.data(); float *out_data = output->mutable_data(); - int input_h = input.dims()[2]; - int input_w = input.dims()[3]; - int output_h = output->dims()[2]; - int output_w = output->dims()[3]; - int padding_h = paddings[0]; - int padding_w = paddings[1]; - int image_size = input_h * input_w; - int out_image_size = output_h * output_w; - int valid_h_start = padding_h; - int valid_h_end = output_h - valid_h_start; - int valid_h = valid_h_end - valid_h_start; - int valid_w_start = padding_w; - int valid_w_end = output_w - valid_w_start; - int valid_w = valid_w_end - valid_w_start; + + const int input_h = input.dims()[2]; + const int input_w = input.dims()[3]; + const int output_h = output->dims()[2]; + const int output_w = output->dims()[3]; + const int padding_h = paddings[0]; + const int padding_w = paddings[1]; + const int image_size = input_h * input_w; + const int out_image_size = output_h * output_w; + const int valid_h_start = padding_h; + const int valid_h_end = output_h - valid_h_start; + const int valid_h = valid_h_end - valid_h_start; + const int valid_w_start = padding_w; + const int valid_w_end = output_w - valid_w_start; + const int valid_w = valid_w_end - valid_w_start; #pragma omp parallel for for (int g = 0; g < input.dims()[1]; ++g) { @@ -643,21 +647,22 @@ void DepthwiseConv3x3S2(const framework::Tensor &input, const float *input_data = input.data(); const float *filter_data = filter.data(); float *out_data = output->mutable_data(); - int input_h = input.dims()[2]; - int input_w = input.dims()[3]; - int output_h = output->dims()[2]; - int output_w = output->dims()[3]; - int padding_h = paddings[0]; - int padding_w = paddings[1]; - int image_size = input_h * input_w; - int out_image_size = output_h * output_w; - int valid_h_start = (padding_h + 1) / 2; - int valid_h_end = (input_h + padding_h - 1) / 2; - int valid_h = valid_h_end - valid_h_start; - int valid_w_start = (padding_w + 1) / 2; - int valid_w_end = (input_w + padding_w - 1) / 2; - int valid_w = valid_w_end - valid_w_start; - int input_w_start = 2 * valid_w_start - padding_w; + + const int input_h = input.dims()[2]; + const int input_w = input.dims()[3]; + const int output_h = output->dims()[2]; + const int output_w = output->dims()[3]; + const int padding_h = paddings[0]; + const int padding_w = paddings[1]; + const int image_size = input_h * input_w; + const int out_image_size = output_h * output_w; + const int valid_h_start = (padding_h + 1) / 2; + const int valid_h_end = (input_h + padding_h - 1) / 2; + const int valid_h = valid_h_end - valid_h_start; + const int valid_w_start = (padding_w + 1) / 2; + const int valid_w_end = (input_w + padding_w - 1) / 2; + const int valid_w = valid_w_end - valid_w_start; + const int input_w_start = 2 * valid_w_start - padding_w; #pragma omp parallel for for (int g = 0; g < input.dims()[1]; ++g) { diff --git a/src/operators/math/depthwise_conv3x3_int8.cpp b/src/operators/math/depthwise_conv3x3_int8.cpp index 76262c76fb..b8d7939bad 100644 --- a/src/operators/math/depthwise_conv3x3_int8.cpp +++ b/src/operators/math/depthwise_conv3x3_int8.cpp @@ -69,9 +69,8 @@ inline void DepthwiseConv3x3NormalRow(const int8_t *input, const int8_t *filter, // border left DEPTHWISE_CONV_NORMAL_BORDER(0, valid_w_start) // middle - int remain_start = valid_w_start; int output_tiles = (valid_w_end - valid_w_start) / 6; - remain_start = valid_w_start + output_tiles * 6; + int remain_start = valid_w_start + output_tiles * 6; int32x4_t _sum0, _sum1; int16x8_t _y[3]; for (int w = 0; w < output_tiles * 6; w += 6) { diff --git a/src/operators/math/depthwise_conv5x5.cpp b/src/operators/math/depthwise_conv5x5.cpp index 99ddfc9249..99c08c185c 100644 --- a/src/operators/math/depthwise_conv5x5.cpp +++ b/src/operators/math/depthwise_conv5x5.cpp @@ -16,6 +16,7 @@ limitations under the License. */ #include "operators/math/depthwise_conv5x5.h" #include +#include namespace paddle_mobile { namespace operators { @@ -48,7 +49,7 @@ inline void Depth5x5NormalRowLoadInput<2>(const float *input, float32x4_t *y) { y[4] = vextq_f32(y[0], y[0], 2); } -#define DEPTHWISE_CONV_NORMAL_BORDER(start, end) \ +#define DEPTHWISE_CONV5X5_NORMAL_BORDER(start, end) \ for (int w = start; w < end; ++w) { \ const int w_in_start = -padding_w + w * Stride_w; \ const int w_in_end = w_in_start + 5; \ @@ -77,10 +78,14 @@ inline void DepthwiseConv5x5NormalRow(const float *input, const float *filter, const int h_end = h_in_end < input_h ? h_in_end : input_h; int valid_w_start = (padding_w + Stride_w - 1) / Stride_w; - int valid_w_end = output_w - valid_w_start; + int valid_w_end = (input_w + padding_w - 5) / Stride_w + 1; + if (valid_w_end < valid_w_start) { + valid_w_end = valid_w_start; + } float *output_ptr = output + h_output * output_w; + // border left - DEPTHWISE_CONV_NORMAL_BORDER(0, valid_w_start) + DEPTHWISE_CONV5X5_NORMAL_BORDER(0, valid_w_start) // middle int output_tiles = (valid_w_end - valid_w_start) >> 2; float32x4_t _sum, _x[5]; @@ -120,20 +125,18 @@ inline void DepthwiseConv5x5NormalRow(const float *input, const float *filter, _sum = vmlaq_lane_f32(_sum, _x[4], vget_high_f32(ker[index]), 1); } switch (remain) { - case 1: - vst1_lane_f32(output_ptr0, vget_low_f32(_sum), 0); - break; + case 3: + vst1q_lane_f32(output_ptr0 + 2, _sum, 2); case 2: vst1_f32(output_ptr0, vget_low_f32(_sum)); break; - case 3: - vst1_f32(output_ptr0, vget_low_f32(_sum)); - vst1_lane_f32(output_ptr0 + 2, vget_high_f32(_sum), 0); + case 1: + vst1q_lane_f32(output_ptr0, _sum, 0); break; } } // border right - DEPTHWISE_CONV_NORMAL_BORDER(valid_w_end, output_w) + DEPTHWISE_CONV5X5_NORMAL_BORDER(valid_w_end, output_w) } template <> @@ -161,7 +164,7 @@ void DepthwiseConv5x5S1(const framework::Tensor &input, const int valid_w = valid_w_end - valid_w_start; #pragma omp parallel for - for (int g = 0; g < input.dims()[1]; ++g) { + for (int g = 0; g < output->dims()[1]; ++g) { const float *input_ptr = input_data + g * image_size; const float *filter_ptr = filter_data + g * 25; float *output_ptr = out_data + g * out_image_size; diff --git a/src/operators/math/gemm/cblas.cc b/src/operators/math/gemm/cblas.cc index 0cda7197f7..6dc04d1b4e 100644 --- a/src/operators/math/gemm/cblas.cc +++ b/src/operators/math/gemm/cblas.cc @@ -27,12 +27,14 @@ void cblas_sgemm(const bool transA, const bool transB, const int M, const int N, const int K, const float alpha, const float *A, const int lda, const float *B, const int ldb, const float beta, float *C, const int ldc) { - // if (N == 1) { - // return cblas_sgemv(transA, M, K, alpha, A, lda, B, beta, C); - // } - - GemmExecutor exec(transA, transB, M, N, K); - exec(alpha, A, lda, B, ldb, beta, C, ldc); + if (N == 1) { + return cblas_sgemv(transA, M, K, alpha, A, lda, B, beta, C); + } else if (M == 1) { + return cblas_sgemv(!transB, N, K, alpha, B, ldb, A, beta, C); + } else { + GemmExecutor exec(transA, transB, M, N, K); + exec(alpha, A, lda, B, ldb, beta, C, ldc); + } } void cblas_sgemv(const bool trans, const int M, const int N, const float alpha, diff --git a/src/operators/math/gemm/executor.h b/src/operators/math/gemm/executor.h index c629471c7c..ddbed3dbdf 100644 --- a/src/operators/math/gemm/executor.h +++ b/src/operators/math/gemm/executor.h @@ -239,11 +239,11 @@ class GemvExecutor : public Executor { public: GemvExecutor(const bool transA, const int M, const int N) - : Executor(), M_(M), N_(N) {} + : Executor(), M_(M), N_(N), trans_(transA) {} void operator()(const float alpha, const Itype *A, const int lda, const Itype *B, const float beta, Otype *C) { - // strategy_.kernel(); + strategy_.kernel(trans_, M_, N_, alpha, A, lda, B, beta, C); } virtual ~GemvExecutor() {} @@ -251,6 +251,7 @@ class GemvExecutor : public Executor { private: const unsigned int M_; const unsigned int N_; + const bool trans_; Strategy strategy_; }; diff --git a/src/operators/math/gemm/gemm_kernel.h b/src/operators/math/gemm/gemm_kernel.h index d96fd43f52..11a9ec008f 100644 --- a/src/operators/math/gemm/gemm_kernel.h +++ b/src/operators/math/gemm/gemm_kernel.h @@ -17,6 +17,7 @@ limitations under the License. */ #if defined(__ARM_NEON__) || defined(__ARM_NEON) #include +#include "operators/math/math.h" namespace paddle_mobile { namespace operators { @@ -325,6 +326,199 @@ void sgemm_6x8(const float *lhs, const float *rhs, const int k, float *output, } #endif // __aarch64__ +void sgemv_notrans_mx1(const int M, const int N, const float alpha, + const float *A, const int lda, const float *B, + const float beta, float *C) { + uint32_t mask[4] = {0, 1, 2, 3}; + int remain_n = N & 0x3; + uint32x4_t vmask = vcltq_u32(vld1q_u32(mask), vdupq_n_u32(remain_n)); + float32x4_t _sum0, _sum1, _sum2, _sum3; + float32x4_t _valpha = vdupq_n_f32(alpha); + + #pragma omp parallel for + for (int m = 0; m < M - 3; m += 4) { + const float *in0 = A + m * lda; + const float *in1 = in0 + lda; + const float *in2 = in1 + lda; + const float *in3 = in2 + lda; + float *output = C + m; + _sum0 = vdupq_n_f32(0.f); + _sum1 = vdupq_n_f32(0.f); + _sum2 = vdupq_n_f32(0.f); + _sum3 = vdupq_n_f32(0.f); + + int n = 0; + for (; n < N - 3; n += 4) { + float32x4_t _r0 = vld1q_f32(in0 + n); + float32x4_t _r1 = vld1q_f32(in1 + n); + float32x4_t _r2 = vld1q_f32(in2 + n); + float32x4_t _r3 = vld1q_f32(in3 + n); + float32x4_t _b = vld1q_f32(B + n); + _sum0 = vmlaq_f32(_sum0, _r0, _b); + _sum1 = vmlaq_f32(_sum1, _r1, _b); + _sum2 = vmlaq_f32(_sum2, _r2, _b); + _sum3 = vmlaq_f32(_sum3, _r3, _b); + } + if (n < N) { + float32x4_t _r0 = vld1q_f32(in0 + n); + float32x4_t _r1 = vld1q_f32(in1 + n); + float32x4_t _r2 = vld1q_f32(in2 + n); + float32x4_t _r3 = vld1q_f32(in3 + n); + float32x4_t _b = vld1q_f32(B + n); + _r0 = vandq_f32_u32(_r0, vmask); + _r1 = vandq_f32_u32(_r1, vmask); + _r2 = vandq_f32_u32(_r2, vmask); + _r3 = vandq_f32_u32(_r3, vmask); + _b = vandq_f32_u32(_b, vmask); + _sum0 = vmlaq_f32(_sum0, _r0, _b); + _sum1 = vmlaq_f32(_sum1, _r1, _b); + _sum2 = vmlaq_f32(_sum2, _r2, _b); + _sum3 = vmlaq_f32(_sum3, _r3, _b); + } + _sum0 = vpaddq_f32(_sum0, _sum1); + _sum2 = vpaddq_f32(_sum2, _sum3); + _sum0 = vpaddq_f32(_sum0, _sum2); + _sum0 = vmulq_f32(_sum0, _valpha); + if (beta != 0.f) { + _sum2 = vmulq_n_f32(vld1q_f32(output), beta); + _sum0 = vaddq_f32(_sum0, _sum2); + } + // restore + vst1q_f32(output, _sum0); + } + // remain m + for (int m = (M & 0xfffc); m < M; ++m) { + const float *in0 = A + m * lda; + float *output = C + m; + _sum0 = vdupq_n_f32(0.f); + + int n = 0; + for (; n < N - 3; n += 4) { + float32x4_t _r0 = vld1q_f32(in0 + n); + float32x4_t _b = vld1q_f32(B + n); + _sum0 = vmlaq_f32(_sum0, _r0, _b); + } + if (n < N) { + float32x4_t _r0 = vld1q_f32(in0 + n); + float32x4_t _b = vld1q_f32(B + n); + _r0 = vandq_f32_u32(_r0, vmask); + _b = vandq_f32_u32(_b, vmask); + _sum0 = vmlaq_f32(_sum0, _r0, _b); + } + _sum0 = vpaddq_f32(_sum0, _sum0); + _sum0 = vmulq_f32(_sum0, _valpha); + if (beta != 0.f) { + _sum2 = vmulq_n_f32(vld1q_f32(output), beta); + _sum0 = vpaddq_f32(_sum0, _sum2); + } + // restore + *output = vgetq_lane_f32(_sum0, 0) + vgetq_lane_f32(_sum0, 1); + } +} + +void sgemv_trans_mx1(const int M, const int N, const float alpha, + const float *A, const int lda, const float *B, + const float beta, float *C) { + float32x4_t _valpha = vdupq_n_f32(alpha); + if (beta == 0.f) { + float32x4_t vzero = vdupq_n_f32(0.f); + for (int m = 0; m < M - 3; m += 4) { + vst1q_f32(C + m, vzero); + } + for (int m = (M & 0xfffc); m < M; ++m) { + C[m] = 0.f; + } + } else { + float32x4_t vbeta = vdupq_n_f32(beta); + for (int m = 0; m < M - 3; m += 4) { + float32x4_t _vc = vld1q_f32(C + m); + _vc = vmulq_f32(_vc, vbeta); + vst1q_f32(C + m, _vc); + } + for (int m = (M & 0xfffc); m < M; ++m) { + C[m] *= beta; + } + } + + #pragma omp parallel for + for (int n = 0; n < N - 3; n += 4) { + const float *in0 = A + n * lda; + const float *in1 = in0 + lda; + const float *in2 = in1 + lda; + const float *in3 = in2 + lda; + float32x4_t _b = vld1q_f32(B + n); + float32x4_t _sum0; + int m = 0; + for (; m < M - 3; m += 4) { + float32x4_t _r0 = vld1q_f32(in0 + m); + float32x4_t _r1 = vld1q_f32(in1 + m); + float32x4_t _r2 = vld1q_f32(in2 + m); + float32x4_t _r3 = vld1q_f32(in3 + m); + float32x4_t _vc = vld1q_f32(C + m); + + _sum0 = vmulq_lane_f32(_r0, vget_low_f32(_b), 0); + _sum0 = vmlaq_lane_f32(_sum0, _r1, vget_low_f32(_b), 1); + _sum0 = vmlaq_lane_f32(_sum0, _r2, vget_high_f32(_b), 0); + _sum0 = vmlaq_lane_f32(_sum0, _r3, vget_high_f32(_b), 1); + _sum0 = vmulq_f32(_sum0, _valpha); + _sum0 = vaddq_f32(_sum0, _vc); + vst1q_f32(C + m, _sum0); + } + if (m < M) { + float32x4_t _r0 = vld1q_f32(in0 + m); + float32x4_t _r1 = vld1q_f32(in1 + m); + float32x4_t _r2 = vld1q_f32(in2 + m); + float32x4_t _r3 = vld1q_f32(in3 + m); + float32x4_t _vc = vld1q_f32(C + m); + + _sum0 = vmulq_lane_f32(_r0, vget_low_f32(_b), 0); + _sum0 = vmlaq_lane_f32(_sum0, _r1, vget_low_f32(_b), 1); + _sum0 = vmlaq_lane_f32(_sum0, _r2, vget_high_f32(_b), 0); + _sum0 = vmlaq_lane_f32(_sum0, _r3, vget_high_f32(_b), 1); + _sum0 = vmulq_f32(_sum0, _valpha); + _sum0 = vaddq_f32(_sum0, _vc); + switch (M - m) { + case 3: + vst1q_lane_f32(C + m + 2, _sum0, 2); + case 2: + vst1_f32(C + m, vget_low_f32(_sum0)); + break; + case 1: + vst1q_lane_f32(C + m, _sum0, 0); + break; + } + } + } + // remain n + for (int n = (N & 0xfffc); n < N; ++n) { + const float *in0 = A + n * lda; + float32x4_t _b = vld1q_dup_f32(B + n); + float32x4_t _sum0; + int m = 0; + for (; m < M - 3; m += 4) { + float32x4_t _r0 = vld1q_f32(in0 + m); + _sum0 = vld1q_f32(C + m); + _r0 = vmulq_f32(_r0, _b); + _r0 = vmulq_f32(_valpha, _r0); + _sum0 = vaddq_f32(_sum0, _r0); + vst1q_f32(C + m, _sum0); + } + for (; m < M; ++m) { + C[m] += alpha * (in0[m] * B[n]); + } + } +} + +void sgemv_mx1(const bool trans, const int M, const int N, const float alpha, + const float *A, const int lda, const float *B, const float beta, + float *C) { + if (trans) { + sgemv_trans_mx1(M, N, alpha, A, lda, B, beta, C); + } else { + sgemv_notrans_mx1(M, N, alpha, A, lda, B, beta, C); + } +} + } // namespace math } // namespace operators } // namespace paddle_mobile diff --git a/src/operators/math/gemm/pack_kernel.h b/src/operators/math/gemm/pack_kernel.h index 598bf3248d..b1f6a9d35e 100644 --- a/src/operators/math/gemm/pack_kernel.h +++ b/src/operators/math/gemm/pack_kernel.h @@ -20,15 +20,12 @@ limitations under the License. */ #ifdef _OPENMP #include #endif +#include "operators/math/math.h" namespace paddle_mobile { namespace operators { namespace math { -inline float32x4_t vandq_f32_u32(float32x4_t x, uint32x4_t mask) { - return vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(x), mask)); -} - void pack_lhs_6r(const int m, const int k, const float *A, const int lda, float *output, const bool unroll) { uint32_t mask[8] = {0, 1, 2, 3, 4, 5, 4, 5}; @@ -218,15 +215,21 @@ void pack_lhs_6r(const int m, const int k, const float *A, const int lda, vst1q_f32(out_ptr + 18, _d3); vst1_f32(out_ptr + 22, vget_high_f32(_d5)); + a0 += 4; + a1 += 4; + a2 += 4; + a3 += 4; + a4 += 4; + a5 += 4; out_ptr += 24; #else asm volatile( - "vld1.32 {d0-d1}, [%[a0]] \n" - "vld1.32 {d2-d3}, [%[a1]] \n" - "vld1.32 {d4-d5}, [%[a2]] \n" - "vld1.32 {d6-d7}, [%[a3]] \n" - "vld1.32 {d8-d9}, [%[a4]] \n" - "vld1.32 {d10-d11}, [%[a5]] \n" + "vld1.32 {d0-d1}, [%[a0]]! \n" + "vld1.32 {d2-d3}, [%[a1]]! \n" + "vld1.32 {d4-d5}, [%[a2]]! \n" + "vld1.32 {d6-d7}, [%[a3]]! \n" + "vld1.32 {d8-d9}, [%[a4]]! \n" + "vld1.32 {d10-d11}, [%[a5]]! \n" "vtrn.32 q0, q1 \n" "vtrn.32 q2, q3 \n" "vtrn.32 q4, q5 \n" @@ -255,6 +258,20 @@ void pack_lhs_6r(const int m, const int k, const float *A, const int lda, #endif } // remain k + switch (remain_m) { + case 1: + a1 = zerobuff; + case 2: + a2 = zerobuff; + case 3: + a3 = zerobuff; + case 4: + a4 = zerobuff; + case 5: + a5 = zerobuff; + default: + break; + } for (; lk < k; ++lk) { *out_ptr++ = *a0++; *out_ptr++ = *a1++; diff --git a/src/operators/math/gemm/strategy.h b/src/operators/math/gemm/strategy.h index 51417a3b4b..11e24fb1c3 100644 --- a/src/operators/math/gemm/strategy.h +++ b/src/operators/math/gemm/strategy.h @@ -88,19 +88,12 @@ struct SgemvStrategy { typedef float Itype; typedef float Otype; - typedef void (*kern_type)(const Itype *, const Itype *, const int, Otype *, - const int); - kern_type kernel; - - static int out_width() { return 1; } + typedef void (*kernelFunc)(const bool, const int, const int, const float, + const Itype *, const int, const Itype *, + const float, Otype *); + kernelFunc kernel; - static int out_height() { -#if __aarch64__ - return 12; -#else - return 6; -#endif - } + SgemvStrategy() { kernel = sgemv_mx1; } }; struct I8o32gemvStrategy { diff --git a/src/operators/math/math.h b/src/operators/math/math.h index 3f9245351d..8ff5019e31 100644 --- a/src/operators/math/math.h +++ b/src/operators/math/math.h @@ -327,4 +327,16 @@ static inline float32x4_t pow_ps(float32x4_t a, float32x4_t b) { return exp_ps(vmulq_f32(b, log_ps(a))); } +#ifndef __aarch64__ +inline float32x4_t vpaddq_f32(float32x4_t r0, float32x4_t r1) { + float32x2_t sum0 = vpadd_f32(vget_low_f32(r0), vget_high_f32(r0)); + float32x2_t sum1 = vpadd_f32(vget_low_f32(r1), vget_high_f32(r1)); + return vcombine_f32(sum0, sum1); +} +#endif + +inline float32x4_t vandq_f32_u32(float32x4_t x, uint32x4_t mask) { + return vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(x), mask)); +} + #endif // __ARM_NEON__ diff --git a/test/common/test_gemm_accuracy.cpp b/test/common/test_gemm_accuracy.cpp index 93cea2fd36..174459d3f5 100644 --- a/test/common/test_gemm_accuracy.cpp +++ b/test/common/test_gemm_accuracy.cpp @@ -18,7 +18,7 @@ limitations under the License. */ #include "../test_helper.h" #include "common/log.h" #include "memory/t_malloc.h" -#include "operators/math/gemm.h" +#include "operators/math/gemm/cblas.h" #define a(i, j) a[(i)*lda + (j)] #define b(i, j) b[(i)*ldb + (j)] @@ -36,10 +36,12 @@ void print_matrix(int m, int n, int ldc, float *c) { std::cout << std::endl; } -int do_sgemm(int m, int n, int k, bool relu, int t1, int t2, int pr) { - int lda = k; - int ldb = n; - int ldc = n; +int do_sgemm(int m, int n, int k, int pr) { + const float alpha = 1.f; + const float beta = 0.f; + const int lda = k; + const int ldb = n; + const int ldc = n; float *a = static_cast(paddle_mobile::memory::Alloc(sizeof(float) * m * k)); @@ -49,24 +51,19 @@ int do_sgemm(int m, int n, int k, bool relu, int t1, int t2, int pr) { static_cast(paddle_mobile::memory::Alloc(sizeof(float) * m * n)); float *c1 = static_cast(paddle_mobile::memory::Alloc(sizeof(float) * m * n)); - float *scale = - static_cast(paddle_mobile::memory::Alloc(sizeof(float) * m)); - float *bias = - static_cast(paddle_mobile::memory::Alloc(sizeof(float) * m)); - srand(unsigned(time(0))); + std::mt19937 rng(111); + std::uniform_real_distribution uniform_dist(0, 1); + const float lower = -10.f; + const float upper = 10.f; + for (int i = 0; i < m * k; ++i) { - a[i] = t1 + rand() % t2; + a[i] = static_cast(uniform_dist(rng) * (upper - lower) + lower); } for (int i = 0; i < k * n; ++i) { - b[i] = t1 + rand() % t2; - } - for (int i = 0; i < m; ++i) { - scale[i] = t1 + rand() % t2; - } - for (int i = 0; i < m; ++i) { - bias[i] = t1 + rand() % t2; + b[i] = static_cast(uniform_dist(rng) * (upper - lower) + lower); } + memcpy(c, c1, sizeof(float) * m * n); for (int i = 0; i < m; ++i) { for (int j = 0; j < n; ++j) { @@ -74,25 +71,20 @@ int do_sgemm(int m, int n, int k, bool relu, int t1, int t2, int pr) { for (int p = 0; p < k; p++) { r += a(i, p) * b(p, j); } - r *= scale[i]; - r += bias[i]; - if (relu && (r < 0)) { - r = 0; - } - c1(i, j) = r; + c1(i, j) = alpha * r; } } - paddle_mobile::operators::math::Gemm gemm; - gemm.SgemmWithBn(m, n, k, 1, a, lda, b, ldb, 0.3, c, ldc, relu, scale, bias, - nullptr); - int eq = 0; - int neq = 0; + std::cout << "run cblas_sgemm..." << std::endl; + paddle_mobile::operators::math::cblas_sgemm(false, false, m, n, k, alpha, a, + lda, b, ldb, 0.f, c, ldc); + + std::cout << "compare results..." << std::endl; for (int i = 0; i < m * n; ++i) { - if (static_cast(c[i]) == static_cast(c1[i])) { - ++eq; - } else { - ++neq; + if (abs(c[i] - c1[i]) >= 1e-2) { + std::cout << "c[" << i << "] != c1[" << i << "]: " << c[i] << " vs " + << c1[i] << std::endl; + exit(1); } } @@ -107,33 +99,36 @@ int do_sgemm(int m, int n, int k, bool relu, int t1, int t2, int pr) { print_matrix(m, n, ldc, c1); } - std::cout << "mnk=" << m << " " << n << " " << k << " relu=" << relu - << " eq=" << eq << " neq=" << neq << std::endl; - - PADDLE_MOBILE_ENFORCE(neq == 0, "The execution of do_sgemm is failed!"); - paddle_mobile::memory::Free(a); paddle_mobile::memory::Free(b); paddle_mobile::memory::Free(c); paddle_mobile::memory::Free(c1); - paddle_mobile::memory::Free(scale); - paddle_mobile::memory::Free(bias); return 0; } -int main() { - do_sgemm(9, 9, 9, true, 10, 10, 10); - do_sgemm(10, 6, 12, false, 10, 10, 0); - do_sgemm(512, 256, 384, false, 10, 10, 0); - do_sgemm(1366, 768, 256, false, 10, 10, 0); - do_sgemm(1255, 755, 333, false, 10, 10, 0); - do_sgemm(555, 777, 999, false, 10, 10, 0); - - do_sgemm(10, 6, 12, true, -4, 10, 0); - do_sgemm(512, 256, 384, true, -4, 10, 0); - do_sgemm(1366, 768, 256, true, -4, 10, 0); - do_sgemm(1255, 755, 333, true, -4, 10, 0); - do_sgemm(555, 777, 999, true, -4, 10, 0); +int main(int argc, char *argv[]) { + do_sgemm(1, 1, 1, 1); + + do_sgemm(9, 9, 1, 1); + do_sgemm(999, 99, 1, 0); + do_sgemm(999, 1, 1, 0); + do_sgemm(1, 9, 9, 1); + do_sgemm(1, 99, 999, 0); + do_sgemm(1, 1, 999, 0); + + do_sgemm(9, 9, 9, 1); + do_sgemm(10, 6, 12, 1); + do_sgemm(512, 256, 384, 0); + do_sgemm(1366, 768, 256, 0); + do_sgemm(1255, 755, 333, 0); + do_sgemm(555, 777, 999, 0); + + do_sgemm(10, 6, 12, 1); + do_sgemm(512, 256, 384, 0); + do_sgemm(1366, 768, 256, 0); + do_sgemm(1255, 755, 333, 0); + do_sgemm(555, 777, 999, 0); + return 0; } -- GitLab