提交 cd1b6c08 编写于 作者: H hjchen2

Optimize vector-matrix and matrix-vector multiply

上级 1d475a2c
......@@ -18,8 +18,6 @@ limitations under the License. */
#include "operators/kernel/arm/convolution/conv_common.h"
#include "operators/kernel/central-arm-func/conv_arm_func.h"
#include <iostream>
namespace paddle_mobile {
namespace operators {
......
......@@ -65,14 +65,12 @@ void DWConvBNReluKernel<CPU, float>::Compute(
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2_FLOAT:
DepthwiseConv3x3<float, float>(param);
break;
#ifndef __aarch64__
case ConvParam<CPU>::EXEC_DEPTHWISE5x5_FLOAT:
DepthwiseConv5x5<float, float>(param);
break;
case ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT:
WinogradConv3x3<8, 3>(param);
break;
#endif // __aarch64__
case ConvParam<CPU>::EXEC_GEMM_FLOAT:
GemmConv<float, float>(param);
break;
......
......@@ -190,19 +190,23 @@ void DepthwiseConv3x3(const ConvParam<CPU> &param) {
Tensor *output = param.Output();
output->mutable_data<Otype>();
if (strides[0] == 1) {
for (int i = 0; i < batch_size; i++) {
Tensor in_batch = input->Slice(i, i + 1);
Tensor out_batch = output->Slice(i, i + 1);
if (strides[0] == 1) {
math::DepthwiseConv3x3S1<Itype, Otype>(in_batch, *filter, paddings,
&out_batch);
}
} else if (strides[0] == 2) {
for (int i = 0; i < batch_size; i++) {
Tensor in_batch = input->Slice(i, i + 1);
Tensor out_batch = output->Slice(i, i + 1);
math::DepthwiseConv3x3S2<Itype, Otype>(in_batch, *filter, paddings,
&out_batch);
}
} else {
GemmConv<Itype, Otype>(param);
}
}
}
template <typename Itype, typename Otype>
......@@ -215,16 +219,16 @@ void DepthwiseConv5x5(const ConvParam<CPU> &param) {
Tensor *output = param.Output();
output->mutable_data<Otype>();
// if (strides[0] == 1) {
// for (int i = 0; i < batch_size; i++) {
// Tensor in_batch = input->Slice(i, i + 1);
// Tensor out_batch = output->Slice(i, i + 1);
// math::DepthwiseConv5x5S1<Itype, Otype>(in_batch, *filter, paddings,
// &out_batch);
// }
// } else {
if (strides[0] == 1) {
for (int i = 0; i < batch_size; i++) {
Tensor in_batch = input->Slice(i, i + 1);
Tensor out_batch = output->Slice(i, i + 1);
math::DepthwiseConv5x5S1<Itype, Otype>(in_batch, *filter, paddings,
&out_batch);
}
} else {
GemmConv<Itype, Otype>(param);
// }
}
}
template void GemmConv<float, float>(const ConvParam<CPU> &param);
......
......@@ -73,8 +73,11 @@ inline void DepthwiseConv3x3NormalRow(const float *input, const float *filter,
const int h_start = h_in_start > 0 ? h_in_start : 0;
const int h_end = h_in_end < input_h ? h_in_end : input_h;
const int valid_w_start = (padding_w + Stride_w - 1) / Stride_w;
const int valid_w_end = (input_w + padding_w - 3) / Stride_w + 1;
int valid_w_start = (padding_w + Stride_w - 1) / Stride_w;
int valid_w_end = (input_w + padding_w - 3) / Stride_w + 1;
if (valid_w_end < valid_w_start) {
valid_w_end = valid_w_start;
}
// const int valid_w_end = output_w - valid_w_start;
float *output_ptr = output + h_output * output_w;
// border left
......@@ -120,7 +123,7 @@ inline void DepthwiseConv3x3NormalRow(const float *input, const float *filter,
vst1_f32(output_ptr0, vget_low_f32(_sum));
break;
case 1:
vst1_lane_f32(output_ptr0, vget_low_f32(_sum), 0);
vst1q_lane_f32(output_ptr0, _sum, 0);
break;
}
}
......@@ -136,20 +139,21 @@ void DepthwiseConv3x3S1<float, float>(const framework::Tensor &input,
const float *input_data = input.data<float>();
const float *filter_data = filter.data<float>();
float *out_data = output->mutable_data<float>();
int input_h = input.dims()[2];
int input_w = input.dims()[3];
int output_h = output->dims()[2];
int output_w = output->dims()[3];
int padding_h = paddings[0];
int padding_w = paddings[1];
int image_size = input_h * input_w;
int out_image_size = output_h * output_w;
int valid_h_start = padding_h;
int valid_h_end = output_h - valid_h_start;
int valid_h = valid_h_end - valid_h_start;
int valid_w_start = padding_w;
int valid_w_end = output_w - valid_w_start;
int valid_w = valid_w_end - valid_w_start;
const int input_h = input.dims()[2];
const int input_w = input.dims()[3];
const int output_h = output->dims()[2];
const int output_w = output->dims()[3];
const int padding_h = paddings[0];
const int padding_w = paddings[1];
const int image_size = input_h * input_w;
const int out_image_size = output_h * output_w;
const int valid_h_start = padding_h;
const int valid_h_end = output_h - valid_h_start;
const int valid_h = valid_h_end - valid_h_start;
const int valid_w_start = padding_w;
const int valid_w_end = output_w - valid_w_start;
const int valid_w = valid_w_end - valid_w_start;
#pragma omp parallel for
for (int g = 0; g < input.dims()[1]; ++g) {
......@@ -643,21 +647,22 @@ void DepthwiseConv3x3S2<float, float>(const framework::Tensor &input,
const float *input_data = input.data<float>();
const float *filter_data = filter.data<float>();
float *out_data = output->mutable_data<float>();
int input_h = input.dims()[2];
int input_w = input.dims()[3];
int output_h = output->dims()[2];
int output_w = output->dims()[3];
int padding_h = paddings[0];
int padding_w = paddings[1];
int image_size = input_h * input_w;
int out_image_size = output_h * output_w;
int valid_h_start = (padding_h + 1) / 2;
int valid_h_end = (input_h + padding_h - 1) / 2;
int valid_h = valid_h_end - valid_h_start;
int valid_w_start = (padding_w + 1) / 2;
int valid_w_end = (input_w + padding_w - 1) / 2;
int valid_w = valid_w_end - valid_w_start;
int input_w_start = 2 * valid_w_start - padding_w;
const int input_h = input.dims()[2];
const int input_w = input.dims()[3];
const int output_h = output->dims()[2];
const int output_w = output->dims()[3];
const int padding_h = paddings[0];
const int padding_w = paddings[1];
const int image_size = input_h * input_w;
const int out_image_size = output_h * output_w;
const int valid_h_start = (padding_h + 1) / 2;
const int valid_h_end = (input_h + padding_h - 1) / 2;
const int valid_h = valid_h_end - valid_h_start;
const int valid_w_start = (padding_w + 1) / 2;
const int valid_w_end = (input_w + padding_w - 1) / 2;
const int valid_w = valid_w_end - valid_w_start;
const int input_w_start = 2 * valid_w_start - padding_w;
#pragma omp parallel for
for (int g = 0; g < input.dims()[1]; ++g) {
......
......@@ -69,9 +69,8 @@ inline void DepthwiseConv3x3NormalRow(const int8_t *input, const int8_t *filter,
// border left
DEPTHWISE_CONV_NORMAL_BORDER(0, valid_w_start)
// middle
int remain_start = valid_w_start;
int output_tiles = (valid_w_end - valid_w_start) / 6;
remain_start = valid_w_start + output_tiles * 6;
int remain_start = valid_w_start + output_tiles * 6;
int32x4_t _sum0, _sum1;
int16x8_t _y[3];
for (int w = 0; w < output_tiles * 6; w += 6) {
......
......@@ -16,6 +16,7 @@ limitations under the License. */
#include "operators/math/depthwise_conv5x5.h"
#include <arm_neon.h>
#include <iostream>
namespace paddle_mobile {
namespace operators {
......@@ -48,7 +49,7 @@ inline void Depth5x5NormalRowLoadInput<2>(const float *input, float32x4_t *y) {
y[4] = vextq_f32(y[0], y[0], 2);
}
#define DEPTHWISE_CONV_NORMAL_BORDER(start, end) \
#define DEPTHWISE_CONV5X5_NORMAL_BORDER(start, end) \
for (int w = start; w < end; ++w) { \
const int w_in_start = -padding_w + w * Stride_w; \
const int w_in_end = w_in_start + 5; \
......@@ -77,10 +78,14 @@ inline void DepthwiseConv5x5NormalRow(const float *input, const float *filter,
const int h_end = h_in_end < input_h ? h_in_end : input_h;
int valid_w_start = (padding_w + Stride_w - 1) / Stride_w;
int valid_w_end = output_w - valid_w_start;
int valid_w_end = (input_w + padding_w - 5) / Stride_w + 1;
if (valid_w_end < valid_w_start) {
valid_w_end = valid_w_start;
}
float *output_ptr = output + h_output * output_w;
// border left
DEPTHWISE_CONV_NORMAL_BORDER(0, valid_w_start)
DEPTHWISE_CONV5X5_NORMAL_BORDER(0, valid_w_start)
// middle
int output_tiles = (valid_w_end - valid_w_start) >> 2;
float32x4_t _sum, _x[5];
......@@ -120,20 +125,18 @@ inline void DepthwiseConv5x5NormalRow(const float *input, const float *filter,
_sum = vmlaq_lane_f32(_sum, _x[4], vget_high_f32(ker[index]), 1);
}
switch (remain) {
case 1:
vst1_lane_f32(output_ptr0, vget_low_f32(_sum), 0);
break;
case 3:
vst1q_lane_f32(output_ptr0 + 2, _sum, 2);
case 2:
vst1_f32(output_ptr0, vget_low_f32(_sum));
break;
case 3:
vst1_f32(output_ptr0, vget_low_f32(_sum));
vst1_lane_f32(output_ptr0 + 2, vget_high_f32(_sum), 0);
case 1:
vst1q_lane_f32(output_ptr0, _sum, 0);
break;
}
}
// border right
DEPTHWISE_CONV_NORMAL_BORDER(valid_w_end, output_w)
DEPTHWISE_CONV5X5_NORMAL_BORDER(valid_w_end, output_w)
}
template <>
......@@ -161,7 +164,7 @@ void DepthwiseConv5x5S1<float, float>(const framework::Tensor &input,
const int valid_w = valid_w_end - valid_w_start;
#pragma omp parallel for
for (int g = 0; g < input.dims()[1]; ++g) {
for (int g = 0; g < output->dims()[1]; ++g) {
const float *input_ptr = input_data + g * image_size;
const float *filter_ptr = filter_data + g * 25;
float *output_ptr = out_data + g * out_image_size;
......
......@@ -27,12 +27,14 @@ void cblas_sgemm(const bool transA, const bool transB, const int M, const int N,
const int K, const float alpha, const float *A, const int lda,
const float *B, const int ldb, const float beta, float *C,
const int ldc) {
// if (N == 1) {
// return cblas_sgemv(transA, M, K, alpha, A, lda, B, beta, C);
// }
if (N == 1) {
return cblas_sgemv(transA, M, K, alpha, A, lda, B, beta, C);
} else if (M == 1) {
return cblas_sgemv(!transB, N, K, alpha, B, ldb, A, beta, C);
} else {
GemmExecutor<SgemmStrategy> exec(transA, transB, M, N, K);
exec(alpha, A, lda, B, ldb, beta, C, ldc);
}
}
void cblas_sgemv(const bool trans, const int M, const int N, const float alpha,
......
......@@ -239,11 +239,11 @@ class GemvExecutor : public Executor {
public:
GemvExecutor(const bool transA, const int M, const int N)
: Executor(), M_(M), N_(N) {}
: Executor(), M_(M), N_(N), trans_(transA) {}
void operator()(const float alpha, const Itype *A, const int lda,
const Itype *B, const float beta, Otype *C) {
// strategy_.kernel();
strategy_.kernel(trans_, M_, N_, alpha, A, lda, B, beta, C);
}
virtual ~GemvExecutor() {}
......@@ -251,6 +251,7 @@ class GemvExecutor : public Executor {
private:
const unsigned int M_;
const unsigned int N_;
const bool trans_;
Strategy strategy_;
};
......
......@@ -17,6 +17,7 @@ limitations under the License. */
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
#include <arm_neon.h>
#include "operators/math/math.h"
namespace paddle_mobile {
namespace operators {
......@@ -325,6 +326,199 @@ void sgemm_6x8(const float *lhs, const float *rhs, const int k, float *output,
}
#endif // __aarch64__
void sgemv_notrans_mx1(const int M, const int N, const float alpha,
const float *A, const int lda, const float *B,
const float beta, float *C) {
uint32_t mask[4] = {0, 1, 2, 3};
int remain_n = N & 0x3;
uint32x4_t vmask = vcltq_u32(vld1q_u32(mask), vdupq_n_u32(remain_n));
float32x4_t _sum0, _sum1, _sum2, _sum3;
float32x4_t _valpha = vdupq_n_f32(alpha);
#pragma omp parallel for
for (int m = 0; m < M - 3; m += 4) {
const float *in0 = A + m * lda;
const float *in1 = in0 + lda;
const float *in2 = in1 + lda;
const float *in3 = in2 + lda;
float *output = C + m;
_sum0 = vdupq_n_f32(0.f);
_sum1 = vdupq_n_f32(0.f);
_sum2 = vdupq_n_f32(0.f);
_sum3 = vdupq_n_f32(0.f);
int n = 0;
for (; n < N - 3; n += 4) {
float32x4_t _r0 = vld1q_f32(in0 + n);
float32x4_t _r1 = vld1q_f32(in1 + n);
float32x4_t _r2 = vld1q_f32(in2 + n);
float32x4_t _r3 = vld1q_f32(in3 + n);
float32x4_t _b = vld1q_f32(B + n);
_sum0 = vmlaq_f32(_sum0, _r0, _b);
_sum1 = vmlaq_f32(_sum1, _r1, _b);
_sum2 = vmlaq_f32(_sum2, _r2, _b);
_sum3 = vmlaq_f32(_sum3, _r3, _b);
}
if (n < N) {
float32x4_t _r0 = vld1q_f32(in0 + n);
float32x4_t _r1 = vld1q_f32(in1 + n);
float32x4_t _r2 = vld1q_f32(in2 + n);
float32x4_t _r3 = vld1q_f32(in3 + n);
float32x4_t _b = vld1q_f32(B + n);
_r0 = vandq_f32_u32(_r0, vmask);
_r1 = vandq_f32_u32(_r1, vmask);
_r2 = vandq_f32_u32(_r2, vmask);
_r3 = vandq_f32_u32(_r3, vmask);
_b = vandq_f32_u32(_b, vmask);
_sum0 = vmlaq_f32(_sum0, _r0, _b);
_sum1 = vmlaq_f32(_sum1, _r1, _b);
_sum2 = vmlaq_f32(_sum2, _r2, _b);
_sum3 = vmlaq_f32(_sum3, _r3, _b);
}
_sum0 = vpaddq_f32(_sum0, _sum1);
_sum2 = vpaddq_f32(_sum2, _sum3);
_sum0 = vpaddq_f32(_sum0, _sum2);
_sum0 = vmulq_f32(_sum0, _valpha);
if (beta != 0.f) {
_sum2 = vmulq_n_f32(vld1q_f32(output), beta);
_sum0 = vaddq_f32(_sum0, _sum2);
}
// restore
vst1q_f32(output, _sum0);
}
// remain m
for (int m = (M & 0xfffc); m < M; ++m) {
const float *in0 = A + m * lda;
float *output = C + m;
_sum0 = vdupq_n_f32(0.f);
int n = 0;
for (; n < N - 3; n += 4) {
float32x4_t _r0 = vld1q_f32(in0 + n);
float32x4_t _b = vld1q_f32(B + n);
_sum0 = vmlaq_f32(_sum0, _r0, _b);
}
if (n < N) {
float32x4_t _r0 = vld1q_f32(in0 + n);
float32x4_t _b = vld1q_f32(B + n);
_r0 = vandq_f32_u32(_r0, vmask);
_b = vandq_f32_u32(_b, vmask);
_sum0 = vmlaq_f32(_sum0, _r0, _b);
}
_sum0 = vpaddq_f32(_sum0, _sum0);
_sum0 = vmulq_f32(_sum0, _valpha);
if (beta != 0.f) {
_sum2 = vmulq_n_f32(vld1q_f32(output), beta);
_sum0 = vpaddq_f32(_sum0, _sum2);
}
// restore
*output = vgetq_lane_f32(_sum0, 0) + vgetq_lane_f32(_sum0, 1);
}
}
void sgemv_trans_mx1(const int M, const int N, const float alpha,
const float *A, const int lda, const float *B,
const float beta, float *C) {
float32x4_t _valpha = vdupq_n_f32(alpha);
if (beta == 0.f) {
float32x4_t vzero = vdupq_n_f32(0.f);
for (int m = 0; m < M - 3; m += 4) {
vst1q_f32(C + m, vzero);
}
for (int m = (M & 0xfffc); m < M; ++m) {
C[m] = 0.f;
}
} else {
float32x4_t vbeta = vdupq_n_f32(beta);
for (int m = 0; m < M - 3; m += 4) {
float32x4_t _vc = vld1q_f32(C + m);
_vc = vmulq_f32(_vc, vbeta);
vst1q_f32(C + m, _vc);
}
for (int m = (M & 0xfffc); m < M; ++m) {
C[m] *= beta;
}
}
#pragma omp parallel for
for (int n = 0; n < N - 3; n += 4) {
const float *in0 = A + n * lda;
const float *in1 = in0 + lda;
const float *in2 = in1 + lda;
const float *in3 = in2 + lda;
float32x4_t _b = vld1q_f32(B + n);
float32x4_t _sum0;
int m = 0;
for (; m < M - 3; m += 4) {
float32x4_t _r0 = vld1q_f32(in0 + m);
float32x4_t _r1 = vld1q_f32(in1 + m);
float32x4_t _r2 = vld1q_f32(in2 + m);
float32x4_t _r3 = vld1q_f32(in3 + m);
float32x4_t _vc = vld1q_f32(C + m);
_sum0 = vmulq_lane_f32(_r0, vget_low_f32(_b), 0);
_sum0 = vmlaq_lane_f32(_sum0, _r1, vget_low_f32(_b), 1);
_sum0 = vmlaq_lane_f32(_sum0, _r2, vget_high_f32(_b), 0);
_sum0 = vmlaq_lane_f32(_sum0, _r3, vget_high_f32(_b), 1);
_sum0 = vmulq_f32(_sum0, _valpha);
_sum0 = vaddq_f32(_sum0, _vc);
vst1q_f32(C + m, _sum0);
}
if (m < M) {
float32x4_t _r0 = vld1q_f32(in0 + m);
float32x4_t _r1 = vld1q_f32(in1 + m);
float32x4_t _r2 = vld1q_f32(in2 + m);
float32x4_t _r3 = vld1q_f32(in3 + m);
float32x4_t _vc = vld1q_f32(C + m);
_sum0 = vmulq_lane_f32(_r0, vget_low_f32(_b), 0);
_sum0 = vmlaq_lane_f32(_sum0, _r1, vget_low_f32(_b), 1);
_sum0 = vmlaq_lane_f32(_sum0, _r2, vget_high_f32(_b), 0);
_sum0 = vmlaq_lane_f32(_sum0, _r3, vget_high_f32(_b), 1);
_sum0 = vmulq_f32(_sum0, _valpha);
_sum0 = vaddq_f32(_sum0, _vc);
switch (M - m) {
case 3:
vst1q_lane_f32(C + m + 2, _sum0, 2);
case 2:
vst1_f32(C + m, vget_low_f32(_sum0));
break;
case 1:
vst1q_lane_f32(C + m, _sum0, 0);
break;
}
}
}
// remain n
for (int n = (N & 0xfffc); n < N; ++n) {
const float *in0 = A + n * lda;
float32x4_t _b = vld1q_dup_f32(B + n);
float32x4_t _sum0;
int m = 0;
for (; m < M - 3; m += 4) {
float32x4_t _r0 = vld1q_f32(in0 + m);
_sum0 = vld1q_f32(C + m);
_r0 = vmulq_f32(_r0, _b);
_r0 = vmulq_f32(_valpha, _r0);
_sum0 = vaddq_f32(_sum0, _r0);
vst1q_f32(C + m, _sum0);
}
for (; m < M; ++m) {
C[m] += alpha * (in0[m] * B[n]);
}
}
}
void sgemv_mx1(const bool trans, const int M, const int N, const float alpha,
const float *A, const int lda, const float *B, const float beta,
float *C) {
if (trans) {
sgemv_trans_mx1(M, N, alpha, A, lda, B, beta, C);
} else {
sgemv_notrans_mx1(M, N, alpha, A, lda, B, beta, C);
}
}
} // namespace math
} // namespace operators
} // namespace paddle_mobile
......
......@@ -20,15 +20,12 @@ limitations under the License. */
#ifdef _OPENMP
#include <omp.h>
#endif
#include "operators/math/math.h"
namespace paddle_mobile {
namespace operators {
namespace math {
inline float32x4_t vandq_f32_u32(float32x4_t x, uint32x4_t mask) {
return vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(x), mask));
}
void pack_lhs_6r(const int m, const int k, const float *A, const int lda,
float *output, const bool unroll) {
uint32_t mask[8] = {0, 1, 2, 3, 4, 5, 4, 5};
......@@ -218,15 +215,21 @@ void pack_lhs_6r(const int m, const int k, const float *A, const int lda,
vst1q_f32(out_ptr + 18, _d3);
vst1_f32(out_ptr + 22, vget_high_f32(_d5));
a0 += 4;
a1 += 4;
a2 += 4;
a3 += 4;
a4 += 4;
a5 += 4;
out_ptr += 24;
#else
asm volatile(
"vld1.32 {d0-d1}, [%[a0]] \n"
"vld1.32 {d2-d3}, [%[a1]] \n"
"vld1.32 {d4-d5}, [%[a2]] \n"
"vld1.32 {d6-d7}, [%[a3]] \n"
"vld1.32 {d8-d9}, [%[a4]] \n"
"vld1.32 {d10-d11}, [%[a5]] \n"
"vld1.32 {d0-d1}, [%[a0]]! \n"
"vld1.32 {d2-d3}, [%[a1]]! \n"
"vld1.32 {d4-d5}, [%[a2]]! \n"
"vld1.32 {d6-d7}, [%[a3]]! \n"
"vld1.32 {d8-d9}, [%[a4]]! \n"
"vld1.32 {d10-d11}, [%[a5]]! \n"
"vtrn.32 q0, q1 \n"
"vtrn.32 q2, q3 \n"
"vtrn.32 q4, q5 \n"
......@@ -255,6 +258,20 @@ void pack_lhs_6r(const int m, const int k, const float *A, const int lda,
#endif
}
// remain k
switch (remain_m) {
case 1:
a1 = zerobuff;
case 2:
a2 = zerobuff;
case 3:
a3 = zerobuff;
case 4:
a4 = zerobuff;
case 5:
a5 = zerobuff;
default:
break;
}
for (; lk < k; ++lk) {
*out_ptr++ = *a0++;
*out_ptr++ = *a1++;
......
......@@ -88,19 +88,12 @@ struct SgemvStrategy {
typedef float Itype;
typedef float Otype;
typedef void (*kern_type)(const Itype *, const Itype *, const int, Otype *,
const int);
kern_type kernel;
static int out_width() { return 1; }
typedef void (*kernelFunc)(const bool, const int, const int, const float,
const Itype *, const int, const Itype *,
const float, Otype *);
kernelFunc kernel;
static int out_height() {
#if __aarch64__
return 12;
#else
return 6;
#endif
}
SgemvStrategy() { kernel = sgemv_mx1; }
};
struct I8o32gemvStrategy {
......
......@@ -327,4 +327,16 @@ static inline float32x4_t pow_ps(float32x4_t a, float32x4_t b) {
return exp_ps(vmulq_f32(b, log_ps(a)));
}
#ifndef __aarch64__
inline float32x4_t vpaddq_f32(float32x4_t r0, float32x4_t r1) {
float32x2_t sum0 = vpadd_f32(vget_low_f32(r0), vget_high_f32(r0));
float32x2_t sum1 = vpadd_f32(vget_low_f32(r1), vget_high_f32(r1));
return vcombine_f32(sum0, sum1);
}
#endif
inline float32x4_t vandq_f32_u32(float32x4_t x, uint32x4_t mask) {
return vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(x), mask));
}
#endif // __ARM_NEON__
......@@ -18,7 +18,7 @@ limitations under the License. */
#include "../test_helper.h"
#include "common/log.h"
#include "memory/t_malloc.h"
#include "operators/math/gemm.h"
#include "operators/math/gemm/cblas.h"
#define a(i, j) a[(i)*lda + (j)]
#define b(i, j) b[(i)*ldb + (j)]
......@@ -36,10 +36,12 @@ void print_matrix(int m, int n, int ldc, float *c) {
std::cout << std::endl;
}
int do_sgemm(int m, int n, int k, bool relu, int t1, int t2, int pr) {
int lda = k;
int ldb = n;
int ldc = n;
int do_sgemm(int m, int n, int k, int pr) {
const float alpha = 1.f;
const float beta = 0.f;
const int lda = k;
const int ldb = n;
const int ldc = n;
float *a =
static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * m * k));
......@@ -49,24 +51,19 @@ int do_sgemm(int m, int n, int k, bool relu, int t1, int t2, int pr) {
static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * m * n));
float *c1 =
static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * m * n));
float *scale =
static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * m));
float *bias =
static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * m));
srand(unsigned(time(0)));
std::mt19937 rng(111);
std::uniform_real_distribution<double> uniform_dist(0, 1);
const float lower = -10.f;
const float upper = 10.f;
for (int i = 0; i < m * k; ++i) {
a[i] = t1 + rand() % t2;
a[i] = static_cast<float>(uniform_dist(rng) * (upper - lower) + lower);
}
for (int i = 0; i < k * n; ++i) {
b[i] = t1 + rand() % t2;
}
for (int i = 0; i < m; ++i) {
scale[i] = t1 + rand() % t2;
}
for (int i = 0; i < m; ++i) {
bias[i] = t1 + rand() % t2;
b[i] = static_cast<float>(uniform_dist(rng) * (upper - lower) + lower);
}
memcpy(c, c1, sizeof(float) * m * n);
for (int i = 0; i < m; ++i) {
for (int j = 0; j < n; ++j) {
......@@ -74,25 +71,20 @@ int do_sgemm(int m, int n, int k, bool relu, int t1, int t2, int pr) {
for (int p = 0; p < k; p++) {
r += a(i, p) * b(p, j);
}
r *= scale[i];
r += bias[i];
if (relu && (r < 0)) {
r = 0;
}
c1(i, j) = r;
c1(i, j) = alpha * r;
}
}
paddle_mobile::operators::math::Gemm gemm;
gemm.SgemmWithBn(m, n, k, 1, a, lda, b, ldb, 0.3, c, ldc, relu, scale, bias,
nullptr);
int eq = 0;
int neq = 0;
std::cout << "run cblas_sgemm..." << std::endl;
paddle_mobile::operators::math::cblas_sgemm(false, false, m, n, k, alpha, a,
lda, b, ldb, 0.f, c, ldc);
std::cout << "compare results..." << std::endl;
for (int i = 0; i < m * n; ++i) {
if (static_cast<int>(c[i]) == static_cast<int>(c1[i])) {
++eq;
} else {
++neq;
if (abs(c[i] - c1[i]) >= 1e-2) {
std::cout << "c[" << i << "] != c1[" << i << "]: " << c[i] << " vs "
<< c1[i] << std::endl;
exit(1);
}
}
......@@ -107,33 +99,36 @@ int do_sgemm(int m, int n, int k, bool relu, int t1, int t2, int pr) {
print_matrix(m, n, ldc, c1);
}
std::cout << "mnk=" << m << " " << n << " " << k << " relu=" << relu
<< " eq=" << eq << " neq=" << neq << std::endl;
PADDLE_MOBILE_ENFORCE(neq == 0, "The execution of do_sgemm is failed!");
paddle_mobile::memory::Free(a);
paddle_mobile::memory::Free(b);
paddle_mobile::memory::Free(c);
paddle_mobile::memory::Free(c1);
paddle_mobile::memory::Free(scale);
paddle_mobile::memory::Free(bias);
return 0;
}
int main() {
do_sgemm(9, 9, 9, true, 10, 10, 10);
do_sgemm(10, 6, 12, false, 10, 10, 0);
do_sgemm(512, 256, 384, false, 10, 10, 0);
do_sgemm(1366, 768, 256, false, 10, 10, 0);
do_sgemm(1255, 755, 333, false, 10, 10, 0);
do_sgemm(555, 777, 999, false, 10, 10, 0);
do_sgemm(10, 6, 12, true, -4, 10, 0);
do_sgemm(512, 256, 384, true, -4, 10, 0);
do_sgemm(1366, 768, 256, true, -4, 10, 0);
do_sgemm(1255, 755, 333, true, -4, 10, 0);
do_sgemm(555, 777, 999, true, -4, 10, 0);
int main(int argc, char *argv[]) {
do_sgemm(1, 1, 1, 1);
do_sgemm(9, 9, 1, 1);
do_sgemm(999, 99, 1, 0);
do_sgemm(999, 1, 1, 0);
do_sgemm(1, 9, 9, 1);
do_sgemm(1, 99, 999, 0);
do_sgemm(1, 1, 999, 0);
do_sgemm(9, 9, 9, 1);
do_sgemm(10, 6, 12, 1);
do_sgemm(512, 256, 384, 0);
do_sgemm(1366, 768, 256, 0);
do_sgemm(1255, 755, 333, 0);
do_sgemm(555, 777, 999, 0);
do_sgemm(10, 6, 12, 1);
do_sgemm(512, 256, 384, 0);
do_sgemm(1366, 768, 256, 0);
do_sgemm(1255, 755, 333, 0);
do_sgemm(555, 777, 999, 0);
return 0;
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册