Commit f851d211 authored by yangfei

add gemm merge function C = A * B + bias

Parent 150ccc98
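This commit collapses the separate `Sgemm`/`SgemmWithBias` and `matmul`/`matmulWithBias` pairs into single entry points that always take a trailing `float *bias` alongside the existing `relu` flag. As a reading aid, here is a naive reference of the semantics the merged routine provides; this is an illustrative sketch only, not the blocked NEON kernel edited below, and treating `bias` as one value per output row (output channel) is an assumption read off the conv call sites:

#include <algorithm>

// Reference semantics of the merged GEMM: C = alpha*A*B + beta*C, then an
// optional per-row bias add and an optional ReLU. Row-major, ld* = row stride.
// Whether the optimized kernel applies the bias on every path is not shown
// in this diff; this sketch simply applies it whenever bias is non-null.
void SgemmRef(int m, int n, int k, float alpha, const float *A, int lda,
              const float *B, int ldb, float beta, float *C, int ldc,
              bool relu, const float *bias) {
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      float acc = 0.f;
      for (int p = 0; p < k; ++p) acc += A[i * lda + p] * B[p * ldb + j];
      float v = alpha * acc + beta * C[i * ldc + j];
      if (bias != nullptr) v += bias[i];  // assumed: one bias per output row
      if (relu) v = std::max(v, 0.f);
      C[i * ldc + j] = v;
    }
  }
}

The hunks below update each call site and then delete the duplicated kernel.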
@@ -106,9 +106,9 @@ void ConvAddBasic(const FusionConvAddParam &param) {
       // gemm
       Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
       Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
-      math::matmulWithBias<float>(filter_slice, false, col_matrix, false,
-                                  static_cast<float>(1), &out_slice,
-                                  static_cast<float>(1), false, biase_data);
+      math::matmul<float>(filter_slice, false, col_matrix, false,
+                          static_cast<float>(1), &out_slice,
+                          static_cast<float>(1), false, biase_data);
     }
   }
 }
......
@@ -109,9 +109,9 @@ void ConvAddReluCompute(const FusionConvAddReluParam &param) {
       // gemm
       Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
       Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
-      math::matmulWithBias<float>(filter_slice, false, col_matrix, false,
-                                  static_cast<float>(1), &out_slice,
-                                  static_cast<float>(1), true, biase_data);
+      math::matmul<float>(filter_slice, false, col_matrix, false,
+                          static_cast<float>(1), &out_slice,
+                          static_cast<float>(1), true, biase_data);
     }
   }
 }
......
@@ -30,6 +30,7 @@ inline void ConvBasic(const ConvParam &param) {
   Tensor filter = *param.Filter();
   Tensor *output = param.Output();
   output->mutable_data<float>();
+  float *bias_data = output->mutable_data<float>();
   int groups = param.Groups();
   std::vector<int> strides = param.Strides();
   std::vector<int> paddings = param.Paddings();
@@ -106,7 +107,7 @@ inline void ConvBasic(const ConvParam &param) {
     Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
     math::matmul<float>(filter_slice, false, col_matrix, false,
                         static_cast<float>(1), &out_slice,
-                        static_cast<float>(0));
+                        static_cast<float>(0), false, bias_data);
     }
   }
 }
......
@@ -28,6 +28,7 @@ void FusionFcCompute(const FusionFcParam &param) {
   int axis = param.Axis();
   Tensor *out = param.Out();
   auto *out_data = out->mutable_data<float>();
+  float *bias_data = out->mutable_data<float>();
   const Tensor x_matrix =
       input_x->dims().size() > 2
           ? framework::ReshapeToMatrix(*input_x, param.XNumColDims())
@@ -56,7 +57,7 @@ void FusionFcCompute(const FusionFcParam &param) {
   //   DLOG << out_data[i];
   // }
   math::matmul<float>(x_matrix, false, y_matrix, false, static_cast<float>(1),
-                      out, static_cast<float>(1));
+                      out, static_cast<float>(1), false, bias_data);
   PADDLE_MOBILE_ENFORCE(out_dim.size() == 2, " out_dim.size must be 2.");
   //  if (out_dim.size() != 2) {
   //    out->Resize(out_dim);
......
@@ -59,6 +59,7 @@ void MulCompute(const MulParam &param) {
   const Tensor *input_y = param.InputY();
   Tensor *out = param.Out();
   out->mutable_data<float>();
+  float *bias_data = out->mutable_data<float>();
   const Tensor x_matrix =
       input_x->dims().size() > 2
           ? framework::ReshapeToMatrix(*input_x, param.XNumColDims())
@@ -72,7 +73,7 @@ void MulCompute(const MulParam &param) {
     out->Resize({x_matrix.dims()[0], y_matrix.dims()[1]});
   }
   math::matmul<float>(x_matrix, false, y_matrix, false, static_cast<float>(1),
-                      out, static_cast<float>(0));
+                      out, static_cast<float>(0), false, bias_data);
   if (out_dim.size() != 2) {
     out->Resize(out_dim);
   }
......
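In the three hunks above (ConvBasic, FusionFcCompute, MulCompute) there is no separate bias tensor, so each call site fills the widened parameter list with a pointer into the output tensor itself. A condensed view of the pattern, hedged because the diff does not show how the kernel's write-back dispatches on these flags:

// Placeholder pattern used where no real bias exists: the pointer aliases the
// result buffer and is presumably never read by the write-back selected for
// these flag combinations (relu == false); otherwise the alias would be unsafe.
float *bias_data = out->mutable_data<float>();  // same storage as the result
math::matmul<float>(x_matrix, false, y_matrix, false, static_cast<float>(1),
                    out, static_cast<float>(0), false, bias_data);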
@@ -2248,69 +2248,8 @@ void AddDot4x4(int k, const float *a, const float *b, float *c, int ldc) {
 // 32-bit float matrix multiplication
 void Sgemm(int m, int n, int k, float alpha, const float *A, int lda,
-           const float *B, int ldb, float beta, float *C, int ldc, bool relu) {
-  // L1 data cache is 32 kib (Per Contex-A57, Contex-A72, Contex-A73)
-  // L2 cache is 0.5~4 Mib (Contex-A72 cluster)
-  int L1 = 32 * 1024;
-  int L2 = 0.5 * 1024 * 1024;
-
-  KC = k;
-  MC = L1 / (KC * sizeof(float));
-  NC = L2 / (KC * sizeof(float));
-
-  // make sure MC is multiple of MR, and NC is multiple of NR
-  int mblock_num = (m + MC - 1) / MC;
-  MC = (m + mblock_num - 1) / mblock_num;
-  MC = (MC + MR - 1) / MR * MR;
-  //  DLOG << "mblock_num = " << mblock_num << ", MC = " << MC << "\n";
-
-  int nblock_num = (n + NC - 1) / NC;
-  NC = (n + nblock_num - 1) / nblock_num;
-  NC = (NC + NR - 1) / NR * NR;
-  //  DLOG << "nblock_num = " << nblock_num << ", NC = " << NC << "\n";
-
-  packedA = static_cast<float *>(
-      paddle_mobile::memory::Alloc(sizeof(float) * MC * KC));
-  packedB = static_cast<float *>(
-      paddle_mobile::memory::Alloc(sizeof(float) * KC * NC));
-  packedC = static_cast<float *>(
-      paddle_mobile::memory::Alloc(sizeof(float) * MC * NC));
-  zero = static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * KC));
-
-  for (int l = 0; l < KC; ++l) {
-    zero[l] = 0;
-  }
-
-  int mc, nc;
-  for (int j = 0; j < n; j += NC) {
-    nc = s_min(n - j, NC);
-#if __aarch64__
-    // PackMatrixB_12c(KC, nc, nc % NR, &B(0, j), ldb, packedB);
-    PackMatrixB_16c(KC, nc, nc % NR, &B(0, j), ldb, packedB);
-#else
-    PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, packedB);
-#endif
-    for (int i = 0; i < m; i += MC) {
-      mc = s_min(m - i, MC);
-#if __aarch64__
-      PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA);
-      // PackMatrixA_8r(mc, KC, mc % MR, &A(i, 0), lda, packedA);
-#else
-      PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA);
-#endif
-      InnerKernel(mc, nc, alpha, packedA, packedB, beta, packedC, &C(i, j), ldc,
-                  relu);
-    }
-  }
-
-  paddle_mobile::memory::Free(packedA);
-  paddle_mobile::memory::Free(packedB);
-  paddle_mobile::memory::Free(packedC);
-  paddle_mobile::memory::Free(zero);
-}
-
-void SgemmWithBias(int m, int n, int k, float alpha, const float *A, int lda,
-                   const float *B, int ldb, float beta, float *C, int ldc,
-                   bool relu, float *bias) {
+           const float *B, int ldb, float beta, float *C, int ldc, bool relu,
+           float *bias) {
   // L1 data cache is 32 kib (Per Contex-A57, Contex-A72, Contex-A73)
   // L2 cache is 0.5~4 Mib (Contex-A72 cluster)
   int L1 = 32 * 1024;
......
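The surviving body keeps the cache-blocking setup that the deleted duplicate also carried: KC spans the full k, MC sizes an MC x KC panel of A for L1, NC sizes a KC x NC panel of B for L2, each then rounded to a multiple of the micro-kernel tile. To make that arithmetic concrete, here is a small standalone program reproducing it for one shape; MR = 6 and NR = 8 are assumptions read off the `PackMatrixA_6r`/`PackMatrixB_8c` helper names (the aarch64 path packs B 16 wide):

#include <cstdio>

int main() {
  const int L1 = 32 * 1024, L2 = 512 * 1024;   // 0.5 MiB, as in the comment
  const int MR = 6, NR = 8;                    // assumed micro-kernel tiles
  const int m = 100, n = 300, k = 256;
  int KC = k;
  int MC = L1 / (KC * (int)sizeof(float));     // 32768 / 1024 = 32
  int NC = L2 / (KC * (int)sizeof(float));     // 524288 / 1024 = 512
  int mblock_num = (m + MC - 1) / MC;          // 4 blocks over m
  MC = (m + mblock_num - 1) / mblock_num;      // 25
  MC = (MC + MR - 1) / MR * MR;                // rounded up to 30
  int nblock_num = (n + NC - 1) / NC;          // 1 block over n
  NC = (n + nblock_num - 1) / nblock_num;      // 300
  NC = (NC + NR - 1) / NR * NR;                // rounded up to 304
  printf("KC=%d MC=%d NC=%d\n", KC, MC, NC);   // KC=256 MC=30 NC=304
  return 0;
}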
@@ -128,10 +128,8 @@ void VecWriteWithBnRelu(int n, float *c, float *C, int ldc, float *new_scale,
 // 32-bit float matrix multiplication
 void Sgemm(int m, int n, int k, float alpha, const float *A, int lda,
-           const float *B, int ldb, float beta, float *C, int ldc, bool relu);
-
-void SgemmWithBias(int m, int n, int k, float alpha, const float *A, int lda,
-                   const float *B, int ldb, float beta, float *C, int ldc,
-                   bool relu, float *bias);
+           const float *B, int ldb, float beta, float *C, int ldc, bool relu,
+           float *bias);

 // 32-bit float matrix multiplication, with batch normalization applied to the result
 void SgemmWithBn(int m, int n, int k, float alpha, const float *A, int lda,
......
@@ -22,7 +22,8 @@ namespace math {
 template <>
 void matmul<float>(const framework::Tensor &matrix_a, bool trans_a,
                    const framework::Tensor &matrix_b, bool trans_b, float alpha,
-                   framework::Tensor *matrix_out, float beta, bool relu) {
+                   framework::Tensor *matrix_out, float beta, bool relu,
+                   float *bias) {
   auto dim_a = matrix_a.dims();
   auto dim_b = matrix_b.dims();
   auto dim_out = matrix_out->dims();
@@ -42,34 +43,7 @@ void matmul<float>(const framework::Tensor &matrix_a, bool trans_a,
   int K = (!trans_a) ? dim_a[1] : dim_a[0];

   Sgemm(M, N, K, alpha, matrix_a.data<float>(), K, matrix_b.data<float>(), N,
-        beta, matrix_out->data<float>(), N, relu);
-}
-
-template <>
-void matmulWithBias<float>(const framework::Tensor &matrix_a, bool trans_a,
-                           const framework::Tensor &matrix_b, bool trans_b,
-                           float alpha, framework::Tensor *matrix_out,
-                           float beta, bool relu, float *bias) {
-  auto dim_a = matrix_a.dims();
-  auto dim_b = matrix_b.dims();
-  auto dim_out = matrix_out->dims();
-  //  PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 &&
-  //  dim_out.size() == 2,
-  //                 "The input and output of matmul be matrix");
-  //
-  //  PADDLE_ENFORCE(platform::is_cpu_place(matrix_a.place()) &&
-  //                     platform::is_cpu_place(matrix_b.place()) &&
-  //                     platform::is_cpu_place(matrix_out->place()),
-  //                 "Matrix must all be in CPUPlace");
-
-  int M = dim_out[0];
-  int N = dim_out[1];
-  int K = (!trans_a) ? dim_a[1] : dim_a[0];
-
-  SgemmWithBias(M, N, K, alpha, matrix_a.data<float>(), K,
-                matrix_b.data<float>(), N, beta, matrix_out->data<float>(), N,
-                relu, bias);
+        beta, matrix_out->data<float>(), N, relu, bias);
 }

 template <>
......
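With the duplicate specialization gone, every caller funnels into the single `Sgemm` above. A minimal, hypothetical direct call with raw row-major buffers; the include path and all values here are illustrative, not taken from the commit:

#include <vector>
#include "operators/math/gemm.h"  // assumed include path for the declaration

// Hypothetical use of the merged kernel: C = A*B, plus a per-row bias if the
// write-back applies it (not shown in this diff), no ReLU.
void example() {
  const int m = 6, n = 8, k = 4;
  std::vector<float> A(m * k, 1.f), B(k * n, 1.f), C(m * n, 0.f);
  std::vector<float> bias(m, 0.5f);  // one value per output row
  paddle_mobile::operators::math::Sgemm(m, n, k, /*alpha=*/1.f, A.data(), k,
                                        B.data(), n, /*beta=*/0.f, C.data(), n,
                                        /*relu=*/false, bias.data());
}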
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

 #pragma once

 #include <cmath>
 #include "framework/tensor.h"
@@ -21,16 +21,10 @@ namespace paddle_mobile {
 namespace operators {
 namespace math {

-// matrix multiply with continuous memory
 template <typename T>
 void matmul(const framework::Tensor &matrix_a, bool trans_a,
             const framework::Tensor &matrix_b, bool trans_b, T alpha,
-            framework::Tensor *matrix_out, T beta, bool relu = false);
-
-template <typename T>
-void matmulWithBias(const framework::Tensor &matrix_a, bool trans_a,
-                    const framework::Tensor &matrix_b, bool trans_b, T alpha,
-                    framework::Tensor *matrix_out, T beta, bool relu,
-                    float *bias);
+            framework::Tensor *matrix_out, T beta, bool relu, float *bias);

 template <typename T>
 void matmulWithBn(const framework::Tensor &matrix_a, bool trans_a,
......
@@ -49,9 +49,9 @@ int main() {
   auto time1 = time();

   for (int j = 0; j < 10; ++j) {
-    paddle_mobile::operators::math::matmul<float>(aa, false, bb, false,
-                                                  static_cast<float>(1), &cc,
-                                                  static_cast<float>(0), false);
+    paddle_mobile::operators::math::matmul<float>(
+        aa, false, bb, false, static_cast<float>(1), &cc, static_cast<float>(0),
+        false, ccptr);

     //    paddle_mobile::operators::math::matmulWithBn<float>(
     //        aa, false, bb, false, static_cast<float>(1), &cc,
......
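Note that the gemm benchmark must now supply the trailing arguments as well; `ccptr` presumably comes from an unchanged part of the test that this diff does not show, so its definition is left elided here.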