Commit 150ccc98 authored by yangfei

Add a fused GEMM function: C = A * B + bias (with optional ReLU), folding the bias add into the GEMM write-back instead of pre-filling the output tensor with the broadcast bias.

Parent a9537ca4
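
For reference, here is a self-contained scalar model of what the new SgemmWithBias path computes with alpha = 1 and beta = 1 (a sketch of the semantics only, not the optimized NEON code; SgemmWithBiasRef and main are illustrative names). It makes explicit what the removed pre-fill loops in the conv kernels below used to do: bias holds one value per output row (channel), added during the write-back, with an optional fused ReLU.

#include <algorithm>
#include <cstdio>
#include <vector>

// Scalar reference for C = A * B + bias, optionally relu(C).
// A is m x k, B is k x n, C is m x n (all row-major, packed);
// bias holds one value per output row, matching how SgemmWithBias
// passes `bias + i` to InnerKernelWithBias for each row block.
void SgemmWithBiasRef(int m, int n, int k, const float *A, const float *B,
                      float *C, bool relu, const float *bias) {
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      float acc = bias[i];
      for (int p = 0; p < k; ++p) acc += A[i * k + p] * B[p * n + j];
      C[i * n + j] = relu ? std::max(acc, 0.0f) : acc;
    }
  }
}

int main() {
  const int m = 2, n = 3, k = 2;
  std::vector<float> A = {1, 2, 3, 4};          // 2 x 2
  std::vector<float> B = {1, 0, -1, 0, 1, -1};  // 2 x 3
  std::vector<float> bias = {0.5f, -10.0f};     // one value per output row
  std::vector<float> C(m * n);
  SgemmWithBiasRef(m, n, k, A.data(), B.data(), C.data(), /*relu=*/true,
                   bias.data());
  for (float v : C) std::printf("%g ", v);  // second row clamps to 0
  std::printf("\n");
}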
......@@ -31,12 +31,7 @@ void ConvAddBasic(const FusionConvAddParam &param) {
Tensor bias = *param.Bias();
int axis = param.Axis();
Tensor *output = param.Output();
math::expand_bias(bias, axis, output->dims());
float *output_data = output->data<float>();
float *biase_data = bias.data<float>();
for (int k = 0; k < output->numel(); ++k) {
output_data[k] = biase_data[k];
}
int groups = param.Groups();
std::vector<int> strides = param.Strides();
......@@ -111,9 +106,9 @@ void ConvAddBasic(const FusionConvAddParam &param) {
// gemm
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
math::matmul<float>(filter_slice, false, col_matrix, false,
static_cast<float>(1), &out_slice,
static_cast<float>(1));
math::matmulWithBias<float>(filter_slice, false, col_matrix, false,
static_cast<float>(1), &out_slice,
static_cast<float>(1), false, biase_data);
}
}
}
......
......@@ -28,12 +28,12 @@ void ConvAddReluCompute(const FusionConvAddReluParam &param) {
Tensor bias = *param.Bias();
int axis = param.Axis();
Tensor *output = param.Output();
math::expand_bias(bias, axis, output->dims());
// math::expand_bias(bias, axis, output->dims());
float *output_data = output->data<float>();
float *biase_data = bias.data<float>();
for (int k = 0; k < output->numel(); ++k) {
output_data[k] = biase_data[k];
}
// for (int k = 0; k < output->numel(); ++k) {
// output_data[k] = biase_data[k];
// }
int groups = param.Groups();
std::vector<int> strides = param.Strides();
......@@ -109,9 +109,9 @@ void ConvAddReluCompute(const FusionConvAddReluParam &param) {
// gemm
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
math::matmul<float>(filter_slice, false, col_matrix, false,
static_cast<float>(1), &out_slice,
static_cast<float>(1), true);
math::matmulWithBias<float>(filter_slice, false, col_matrix, false,
static_cast<float>(1), &out_slice,
static_cast<float>(1), true, biase_data);
}
}
}
......
......@@ -45,16 +45,16 @@ void FusionFcCompute(const FusionFcParam &param) {
PADDLE_MOBILE_ENFORCE(out_dim[1] == input_z->dims()[0],
" out_dim.size must be 2.");
axis = (axis == -1 ? out_dim.size() - input_z->dims().size() : axis);
PADDLE_MOBILE_ENFORCE(axis == 1, " to fit broadcast, axis = 1. ")
PADDLE_MOBILE_ENFORCE(axis == 1, " to fit broadcast, axis = 1. ");
int64_t classes = input_z->numel();
for (int i = 0; i < out_dim[0]; i++) {
memory::Copy(out_data + i * classes, input_z_data, sizeof(float) * classes);
}
for (int i = 0; i < out->numel(); i++) {
DLOG << out_data[i];
}
// for (int i = 0; i < out->numel(); i++) {
// DLOG << out_data[i];
// }
math::matmul<float>(x_matrix, false, y_matrix, false, static_cast<float>(1),
out, static_cast<float>(1));
PADDLE_MOBILE_ENFORCE(out_dim.size() == 2, " out_dim.size must be 2.");
......
......@@ -373,9 +373,9 @@ void InnerKernel(int mc, int nc, float alpha, const float *a, const float *b,
#endif
}
}
if (alpha != 1) {
WriteWithAlphaBeta(mc, nc, c, C, ldc);
return;
}
if (beta == 0) {
......@@ -392,6 +392,42 @@ void InnerKernel(int mc, int nc, float alpha, const float *a, const float *b,
}
}
// blocked matrix multiplication
void InnerKernelWithBias(int mc, int nc, float alpha, const float *a,
const float *b, float beta, float *c, float *C,
int ldc, bool relu, float *bias) {
#pragma omp parallel for
for (int j = 0; j < nc; j += NR) {
for (int i = 0; i < mc; i += MR) {
#if __aarch64__
// AddDot8x12(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
AddDot6x16(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
#else
// AddDot4x4(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
// AddDot4x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
AddDot6x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
#endif
}
}
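// pick the write-back routine from alpha/beta/relu; the bias is only fused
// on the beta == 1 paths below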
if (alpha != 1) {
WriteWithAlphaBeta(mc, nc, c, C, ldc);
return;
}
if (beta == 0) {
WriteBasic(mc, nc, c, C, ldc);
return;
}
if (beta == 1 && !relu) {
WriteWithAddV1(mc, nc, c, C, ldc, bias);
return;
}
if (beta == 1 && relu) {
WriteWithAddReluV1(mc, nc, c, C, ldc, bias);
return;
}
}
// blocked matrix multiplication
void InnerKernelWithBn(int mc, int nc, float alpha, const float *a,
const float *b, float beta, float *c, float *C, int ldc,
......@@ -577,6 +613,43 @@ void WriteWithAdd(int mc, int nc, float *c, float *C, int ldc) {
}
}
}
// C = A * B + bias
void WriteWithAddV1(int mc, int nc, float *c, float *C, int ldc, float *bias) {
int nc1 = nc / 4;
int _nc1 = nc % 4;
float *c_ptr, *C_ptr;
float32x4_t cv;
float32x4_t biasv;
for (int i = 0; i < mc; ++i) {
c_ptr = c + i * NC;
C_ptr = C + i * ldc;
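// bias holds one value per output row: load bias[i] and broadcast it to all four lanes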
biasv = vld1q_dup_f32(bias + i);
for (int j = 0; j < nc1; ++j) {
cv = vld1q_f32(c_ptr);
cv = vaddq_f32(cv, biasv);
vst1q_f32(C_ptr, cv);
c_ptr += 4;
C_ptr += 4;
}
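// tail: store the remaining nc % 4 columns lane by lane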
if (_nc1 != 0) {
cv = vld1q_f32(c_ptr);
cv = vaddq_f32(cv, biasv);
if (_nc1 >= 1) {
vst1q_lane_f32(C_ptr, cv, 0);
C_ptr++;
}
if (_nc1 >= 2) {
vst1q_lane_f32(C_ptr, cv, 1);
C_ptr++;
}
if (_nc1 >= 3) {
vst1q_lane_f32(C_ptr, cv, 2);
C_ptr++;
}
}
}
}
// C = A * B + C, relu(C)
void WriteWithAddRelu(int mc, int nc, float *c, float *C, int ldc) {
......@@ -619,6 +692,48 @@ void WriteWithAddRelu(int mc, int nc, float *c, float *C, int ldc) {
}
}
// C = A * B + bias, relu(C)
void WriteWithAddReluV1(int mc, int nc, float *c, float *C, int ldc,
float *bias) {
int nc1 = nc / 4;
int _nc1 = nc % 4;
float *c_ptr, *C_ptr;
float32x4_t cv;
float32x4_t biasv;
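// all-zero vector used as the lower clamp for the fused ReLU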
float32x4_t zero = vdupq_n_f32(0.0);
for (int i = 0; i < mc; ++i) {
c_ptr = c + i * NC;
C_ptr = C + i * ldc;
biasv = vld1q_dup_f32(bias + i);
for (int j = 0; j < nc1; ++j) {
cv = vld1q_f32(c_ptr);
cv = vaddq_f32(cv, biasv);
cv = vmaxq_f32(cv, zero);
vst1q_f32(C_ptr, cv);
c_ptr += 4;
C_ptr += 4;
}
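// tail: add bias and apply ReLU, then store the remaining lanes one at a time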
if (_nc1 != 0) {
cv = vld1q_f32(c_ptr);
cv = vaddq_f32(cv, biasv);
cv = vmaxq_f32(cv, zero);
if (_nc1 >= 1) {
vst1q_lane_f32(C_ptr, cv, 0);
C_ptr++;
}
if (_nc1 >= 2) {
vst1q_lane_f32(C_ptr, cv, 1);
C_ptr++;
}
if (_nc1 >= 3) {
vst1q_lane_f32(C_ptr, cv, 2);
C_ptr++;
}
}
}
}
// C = A * B, batchnorm(C)
void WriteWithBn(int mc, int nc, float *c, float *C, int ldc, float *new_scale,
float *new_bias) {
......@@ -1448,6 +1563,44 @@ void WriteWithAdd(int mc, int nc, float *c, float *C, int ldc) {
}
}
// C = A * B + bias
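// (second definition of WriteWithAddV1 in this file; identical logic to the copy above)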
void WriteWithAddV1(int mc, int nc, float *c, float *C, int ldc, float *bias) {
int nc1 = nc / 4;
int _nc1 = nc % 4;
float *c_ptr, *C_ptr;
float32x4_t cv;
float32x4_t biasv;
for (int i = 0; i < mc; ++i) {
c_ptr = c + i * NC;
C_ptr = C + i * ldc;
biasv = vld1q_dup_f32(bias + i);
for (int j = 0; j < nc1; ++j) {
cv = vld1q_f32(c_ptr);
cv = vaddq_f32(cv, biasv);
vst1q_f32(C_ptr, cv);
c_ptr += 4;
C_ptr += 4;
}
if (_nc1 != 0) {
cv = vld1q_f32(c_ptr);
cv = vaddq_f32(cv, biasv);
if (_nc1 >= 1) {
vst1q_lane_f32(C_ptr, cv, 0);
C_ptr++;
}
if (_nc1 >= 2) {
vst1q_lane_f32(C_ptr, cv, 1);
C_ptr++;
}
if (_nc1 >= 3) {
vst1q_lane_f32(C_ptr, cv, 2);
C_ptr++;
}
}
}
}
// C = A * B + C, relu(C)
void WriteWithAddRelu(int mc, int nc, float *c, float *C, int ldc) {
int nc1 = nc / 16;
......@@ -1522,6 +1675,48 @@ void WriteWithAddRelu(int mc, int nc, float *c, float *C, int ldc) {
}
}
// C = A * B + bias, relu(C)
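// (second definition of WriteWithAddReluV1; identical logic to the copy above)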
void WriteWithAddReluV1(int mc, int nc, float *c, float *C, int ldc,
float *bias) {
int nc1 = nc / 4;
int _nc1 = nc % 4;
float *c_ptr, *C_ptr;
float32x4_t cv;
float32x4_t biasv;
float32x4_t zero = vdupq_n_f32(0.0);
for (int i = 0; i < mc; ++i) {
c_ptr = c + i * NC;
C_ptr = C + i * ldc;
biasv = vld1q_dup_f32(bias + i);
for (int j = 0; j < nc1; ++j) {
cv = vld1q_f32(c_ptr);
cv = vaddq_f32(cv, biasv);
cv = vmaxq_f32(cv, zero);
vst1q_f32(C_ptr, cv);
c_ptr += 4;
C_ptr += 4;
}
if (_nc1 != 0) {
cv = vld1q_f32(c_ptr);
cv = vaddq_f32(cv, biasv);
cv = vmaxq_f32(cv, zero);
if (_nc1 >= 1) {
vst1q_lane_f32(C_ptr, cv, 0);
C_ptr++;
}
if (_nc1 >= 2) {
vst1q_lane_f32(C_ptr, cv, 1);
C_ptr++;
}
if (_nc1 >= 3) {
vst1q_lane_f32(C_ptr, cv, 2);
C_ptr++;
}
}
}
}
// C = A * B, batchnorm(C)
void WriteWithBn(int mc, int nc, float *c, float *C, int ldc, float *scale,
float *bias) {
......@@ -2113,6 +2308,68 @@ void Sgemm(int m, int n, int k, float alpha, const float *A, int lda,
paddle_mobile::memory::Free(packedC);
paddle_mobile::memory::Free(zero);
}
void SgemmWithBias(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc,
bool relu, float *bias) {
// L1 data cache is 32 KiB (per Cortex-A57, Cortex-A72, Cortex-A73)
// L2 cache is 0.5~4 MiB (Cortex-A72 cluster)
int L1 = 32 * 1024;
int L2 = 0.5 * 1024 * 1024;
KC = k;
MC = L1 / (KC * sizeof(float));
NC = L2 / (KC * sizeof(float));
// make sure MC is a multiple of MR and NC is a multiple of NR
int mblock_num = (m + MC - 1) / MC;
MC = (m + mblock_num - 1) / mblock_num;
MC = (MC + MR - 1) / MR * MR;
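// e.g. with MR = 6 (matching PackMatrixA_6r), m = 100, initial MC = 64:
// mblock_num = (100 + 63) / 64 = 2, MC = (100 + 1) / 2 = 50,
// then rounded up to a multiple of MR: MC = 54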
// DLOG << "mblock_num = " << mblock_num << ", MC = " << MC << "\n";
int nblock_num = (n + NC - 1) / NC;
NC = (n + nblock_num - 1) / nblock_num;
NC = (NC + NR - 1) / NR * NR;
// DLOG << "nblock_num = " << nblock_num << ", NC = " << NC << "\n";
packedA = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * MC * KC));
packedB = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * KC * NC));
packedC = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * MC * NC));
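// zero-filled scratch (length KC); the packing helpers use it to pad partial tiles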
zero = static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * KC));
for (int l = 0; l < KC; ++l) {
zero[l] = 0;
}
int mc, nc;
for (int j = 0; j < n; j += NC) {
nc = s_min(n - j, NC);
#if __aarch64__
// PackMatrixB_12c(KC, nc, nc % NR, &B(0, j), ldb, packedB);
PackMatrixB_16c(KC, nc, nc % NR, &B(0, j), ldb, packedB);
#else
PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, packedB);
#endif
for (int i = 0; i < m; i += MC) {
mc = s_min(m - i, MC);
#if __aarch64__
PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA);
// PackMatrixA_8r(mc, KC, mc % MR, &A(i, 0), lda, packedA);
#else
PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA);
#endif
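// bias is indexed by output row, so pass this row block's slice (bias + i)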
InnerKernelWithBias(mc, nc, alpha, packedA, packedB, beta, packedC,
&C(i, j), ldc, relu, bias + i);
}
}
paddle_mobile::memory::Free(packedA);
paddle_mobile::memory::Free(packedB);
paddle_mobile::memory::Free(packedC);
paddle_mobile::memory::Free(zero);
}
void SgemmWithBn(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc,
......
......@@ -62,6 +62,9 @@ void PackMatrixB_16c(int k, int n, int n_tail, const float *B, int ldb,
// blocked matrix multiplication
void InnerKernel(int mc, int nc, float alpha, const float *a, const float *b,
float beta, float *c, float *C, int ldc, bool relu);
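// blocked matrix multiplication with the bias add fused into the write-back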
void InnerKernelWithBias(int mc, int nc, float alpha, const float *a,
const float *b, float beta, float *c, float *C,
int ldc, bool relu, float *bias);
void InnerKernelWithBn(int mc, int nc, float alpha, const float *a,
const float *b, float beta, float *c, float *C, int ldc,
......@@ -91,8 +94,13 @@ void WriteBasic(int mc, int nc, float *c, float *C, int ldc);
void WriteWithAlphaBeta(int mc, int nc, float *c, float *C, int ldc);
// C = A * B + C
void WriteWithAdd(int mc, int nc, float *c, float *C, int ldc);
// C = A * B + bias
void WriteWithAddV1(int mc, int nc, float *c, float *C, int ldc, float *bias);
// C = A * B + C, relu(C)
void WriteWithAddRelu(int mc, int nc, float *c, float *C, int ldc);
// C = A * B + bias, relu(C)
void WriteWithAddReluV1(int mc, int nc, float *c, float *C, int ldc,
float *bias);
// C = A * B, batchnorm(C)
void WriteWithBn(int mc, int nc, float *c, float *C, int ldc, float *new_scale,
float *new_bias);
......@@ -121,6 +129,9 @@ void VecWriteWithBnRelu(int n, float *c, float *C, int ldc, float *new_scale,
// 32-bit float matrix multiplication
void Sgemm(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc, bool relu);
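// 32-bit float matrix multiplication with a fused per-row bias (and optional relu)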
void SgemmWithBias(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc,
bool relu, float *bias);
// 32-bit float matrix multiplication, applying batchnorm to the result
void SgemmWithBn(int m, int n, int k, float alpha, const float *A, int lda,
......
......@@ -44,6 +44,33 @@ void matmul<float>(const framework::Tensor &matrix_a, bool trans_a,
Sgemm(M, N, K, alpha, matrix_a.data<float>(), K, matrix_b.data<float>(), N,
beta, matrix_out->data<float>(), N, relu);
}
template <>
void matmulWithBias<float>(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b,
float alpha, framework::Tensor *matrix_out,
float beta, bool relu, float *bias) {
auto dim_a = matrix_a.dims();
auto dim_b = matrix_b.dims();
auto dim_out = matrix_out->dims();
// PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2,
//                "The input and output of matmul must be matrices");
//
// PADDLE_ENFORCE(platform::is_cpu_place(matrix_a.place()) &&
//                platform::is_cpu_place(matrix_b.place()) &&
//                platform::is_cpu_place(matrix_out->place()),
//                "Matrix must all be in CPUPlace");
int M = dim_out[0];
int N = dim_out[1];
int K = (!trans_a) ? dim_a[1] : dim_a[0];
SgemmWithBias(M, N, K, alpha, matrix_a.data<float>(), K,
matrix_b.data<float>(), N, beta, matrix_out->data<float>(), N,
relu, bias);
}
template <>
void matmulWithBn<float>(const framework::Tensor &matrix_a, bool trans_a,
......
......@@ -26,6 +26,11 @@ template <typename T>
void matmul(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b, T alpha,
framework::Tensor *matrix_out, T beta, bool relu = false);
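// matrix multiplication with fused bias: matrix_out = matrix_a * matrix_b + bias,
// optionally relu; bias holds one value per output row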
template <typename T>
void matmulWithBias(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b, T alpha,
framework::Tensor *matrix_out, T beta, bool relu,
float *bias);
template <typename T>
void matmulWithBn(const framework::Tensor &matrix_a, bool trans_a,
......