提交 cd30eb8a 编写于 作者: E eclipsycn 提交者: GitHub

Merge pull request #531 from smilejames/develop

Optimize Gemm: fuse add relu batchnorm op, dynamic block, add AddDot4x8, optimize memory write back.
此差异已折叠。
......@@ -19,12 +19,8 @@ limitations under the License. */
#define B(i, j) B[(i)*ldb + (j)]
#define C(i, j) C[(i)*ldc + (j)]
// 分块计算的块大小,mc 与 kc 分别对应分块计算时的 m 与 k
#define MC 128
#define KC 128
#define NC 1024
#define MR 4
#define NR 4
#define NR 8
#define s_min(i, j) ((i) < (j) ? (i) : (j))
......@@ -49,28 +45,66 @@ void PackMatrixB_(int k, int n, int n_tail, const float *B, int ldb,
float *buffer);
// 分块矩阵乘法
void InnerKernel(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc,
int first_time);
void InnerKernel(int mc, int nc, float alpha, const float *a, const float *b,
float beta, float *c, float *C, int ldc, bool relu);
void InnerKernelWithBn(int mc, int nc, float alpha, const float *a,
const float *b, float beta, float *c, float *C, int ldc,
bool relu, float *new_scale, float *new_bias);
// 向量矩阵乘法 (M = 1)
void VectorKernel(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc);
// 计算一个更小的 4 * 4 的 C 矩阵分块
void AddDot4x4(int k, float alpha, const float *A, int lda, const float *B,
int ldb, float beta, float *C, int ldc, int mc, int nc);
void AddDot4x4_relu(int k, float alpha, const float *a, int lda, const float *b,
int ldb, float beta, float *C, int ldc, int mc, int nc,
bool relu);
const float *B, int ldb, float beta, float *C, int ldc,
bool relu);
void VectorKernelWithBn(int m, int n, int k, float alpha, const float *A,
int lda, const float *B, int ldb, float beta, float *C,
int ldc, bool relu, float *new_scale, float *new_bias);
// 计算一个更小的 C 矩阵分块
void AddDot4x4(int k, const float *a, const float *b, float *c, int ldc);
void AddDot4x8(int k, const float *a, const float *b, float *c, int ldc);
// 分块矩阵乘法结果回写
// C = A * B
void WriteBasic(int mc, int nc, float *c, float *C, int ldc);
// C = alpha * A * B + beta * C
void WriteWithAlphaBeta(int mc, int nc, float *c, float *C, int ldc);
// C = A * B + C
void WriteWithAdd(int mc, int nc, float *c, float *C, int ldc);
// C = A * B + C, relu(C)
void WriteWithAddRelu(int mc, int nc, float *c, float *C, int ldc);
// C = A * B, batchnorm(C)
void WriteWithBn(int mc, int nc, float *c, float *C, int ldc, float *new_scale,
float *new_bias);
// C = A * B, batchnorm(C), relu(C)
void WriteWithBnRelu(int mc, int nc, float *c, float *C, int ldc,
float *new_scale, float *new_bias);
// 向量矩阵乘法结果回写
// C = A * B
void VecWriteBasic(int n, float *c, float *C, int ldc);
// C = alpha * A * B + beta * C
void VecWriteWithAlphaBeta(int n, float *c, float *C, int ldc);
// C = A * B + C
void VecWriteWithAdd(int n, float *c, float *C, int ldc);
// C = A * B + C, relu(C)
void VecWriteWithAddRelu(int n, float *c, float *C, int ldc);
// C = A * B, batchnorm(C)
void VecWriteWithBn(int n, float *c, float *C, int ldc, float *new_scale,
float *new_bias);
// C = A * B, batchnorm(C), relu(C)
void VecWriteWithBnRelu(int n, float *c, float *C, int ldc, float *new_scale,
float *new_bias);
// 32位 float 矩阵乘法
void sgemm(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc);
void Sgemm(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc, bool relu);
void sgemm_relu(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc);
// 32位 float 矩阵乘法, 并对结果进行 batchnrom
void SgemmWithBn(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc,
bool relu, float *new_scale, float *new_bias);
// 64位 double 矩阵乘法
void dgemm(int m, int n, int k, float alpha, const double *A, int lda,
......
......@@ -39,22 +39,18 @@ void matmul<float>(const framework::Tensor &matrix_a, bool trans_a,
int M = dim_out[0];
int N = dim_out[1];
int K = (trans_a == false) ? dim_a[1] : dim_a[0];
int K = (!trans_a) ? dim_a[1] : dim_a[0];
if (relu) {
sgemm_relu(M, N, K, alpha, matrix_a.data<float>(), K,
matrix_b.data<float>(), N, beta, matrix_out->data<float>(), N);
} else {
sgemm(M, N, K, alpha, matrix_a.data<float>(), K, matrix_b.data<float>(), N,
beta, matrix_out->data<float>(), N);
}
Sgemm(M, N, K, alpha, matrix_a.data<float>(), K, matrix_b.data<float>(), N,
beta, matrix_out->data<float>(), N, relu);
}
template <>
void matmul<double>(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b,
double alpha, framework::Tensor *matrix_out, double beta,
bool relu) {
void matmulWithBn<float>(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b,
float alpha, framework::Tensor *matrix_out, float beta,
bool relu, framework::Tensor *new_scale,
framework::Tensor *new_bias) {
auto dim_a = matrix_a.dims();
auto dim_b = matrix_b.dims();
auto dim_out = matrix_out->dims();
......@@ -71,7 +67,11 @@ void matmul<double>(const framework::Tensor &matrix_a, bool trans_a,
int M = dim_out[0];
int N = dim_out[1];
int K = (trans_a == false) ? dim_a[1] : dim_a[0];
int K = (!trans_a) ? dim_a[1] : dim_a[0];
SgemmWithBn(M, N, K, alpha, matrix_a.data<float>(), K, matrix_b.data<float>(),
N, beta, matrix_out->data<float>(), N, relu,
new_scale->data<float>(), new_bias->data<float>());
}
} // namespace math
......
......@@ -26,6 +26,12 @@ template <typename T>
void matmul(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b, T alpha,
framework::Tensor *matrix_out, T beta, bool relu = false);
template <typename T>
void matmulWithBn(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b, T alpha,
framework::Tensor *matrix_out, T beta, bool relu,
framework::Tensor *new_scale, framework::Tensor *new_bias);
} // namespace math
} // namespace operators
} // namespace paddle_mobile
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册