Commit fbeb3e20 authored by qnqinan

Merge remote-tracking branch 'origin/develop' into develop

@@ -61,7 +61,14 @@ struct PaddleMobileException : public std::exception {
}
#else
#define PADDLE_MOBILE_THROW_EXCEPTION(...)
#define PADDLE_MOBILE_ENFORCE(stat, ...)
#define PADDLE_MOBILE_ENFORCE(stat, ...) \
{ \
if (stat) { \
} else { \
} \
}
#endif
} // namespace paddle_mobile
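With exceptions disabled, the new no-op branch of PADDLE_MOBILE_ENFORCE still evaluates its condition inside an empty if/else, so the checked expression stays referenced and call sites compile unchanged in both configurations. A minimal sketch of a call site (mirroring the fusion_fc usage later in this commit):

    int axis = 1;
    PADDLE_MOBILE_ENFORCE(axis == 1, " to fit broadcast, axis = 1. ");
    // Exceptions enabled: the macro throws PaddleMobileException when the condition fails.
    // Exceptions disabled: the condition (axis == 1) is evaluated and execution continues.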
@@ -31,12 +31,7 @@ void ConvAddBasic(const FusionConvAddParam &param) {
Tensor bias = *param.Bias();
int axis = param.Axis();
Tensor *output = param.Output();
math::expand_bias(bias, axis, output->dims());
float *output_data = output->data<float>();
float *biase_data = bias.data<float>();
for (int k = 0; k < output->numel(); ++k) {
output_data[k] = biase_data[k];
}
int groups = param.Groups();
std::vector<int> strides = param.Strides();
@@ -113,7 +108,7 @@ void ConvAddBasic(const FusionConvAddParam &param) {
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
math::matmul<float>(filter_slice, false, col_matrix, false,
static_cast<float>(1), &out_slice,
static_cast<float>(1));
static_cast<float>(1), false, biase_data);
}
}
}
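The deleted expand_bias/copy loop used to pre-fill the output with the broadcast bias so that the following matmul (beta = 1) accumulated onto it; the rewritten path instead forwards biase_data into matmul, and the bias is added during the GEMM write-back. A scalar sketch of the fused result, with hypothetical names (the real work happens in WriteWithAddV1 / WriteWithAddReluV1 further down in this commit):

    // One bias value per output channel, i.e. per row of the output matrix slice.
    for (int row = 0; row < out_channels; ++row) {
      for (int col = 0; col < out_spatial; ++col) {
        float v = conv_result[row * ldc + col] + biase_data[row];
        C[row * ldc + col] = relu ? (v > 0.f ? v : 0.f) : v;  // relu only in ConvAddReluCompute
      }
    }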
...
@@ -32,12 +32,7 @@ void ConvAddReluCompute(const FusionConvAddReluParam &param) {
Tensor bias = *param.Bias();
int axis = param.Axis();
Tensor *output = param.Output();
math::expand_bias(bias, axis, output->dims());
float *output_data = output->data<float>();
float *biase_data = bias.data<float>();
for (int k = 0; k < output->numel(); ++k) {
output_data[k] = biase_data[k];
}
int groups = param.Groups();
std::vector<int> strides = param.Strides();
@@ -115,7 +110,7 @@ void ConvAddReluCompute(const FusionConvAddReluParam &param) {
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
math::matmul<float>(filter_slice, false, col_matrix, false,
static_cast<float>(1), &out_slice,
static_cast<float>(1), true);
static_cast<float>(1), true, biase_data);
}
}
}
...
@@ -30,6 +30,7 @@ void FusionFcCompute(const FusionFcParam &param) {
int axis = param.Axis();
Tensor *out = param.Out();
auto *out_data = out->mutable_data<float>();
float *bias_data = out->mutable_data<float>();
const Tensor x_matrix =
input_x->dims().size() > 2
? framework::ReshapeToMatrix(*input_x, param.XNumColDims())
@@ -47,18 +48,18 @@ void FusionFcCompute(const FusionFcParam &param) {
PADDLE_MOBILE_ENFORCE(out_dim[1] == input_z->dims()[0],
" out_dim.size must be 2.");
axis = (axis == -1 ? out_dim.size() - input_z->dims().size() : axis);
PADDLE_MOBILE_ENFORCE(axis == 1, " to fit broadcast, axis = 1. ")
PADDLE_MOBILE_ENFORCE(axis == 1, " to fit broadcast, axis = 1. ");
int64_t classes = input_z->numel();
for (int i = 0; i < out_dim[0]; i++) {
memory::Copy(out_data + i * classes, input_z_data, sizeof(float) * classes);
}
for (int i = 0; i < out->numel(); i++) {
DLOG << out_data[i];
}
// for (int i = 0; i < out->numel(); i++) {
// DLOG << out_data[i];
// }
math::matmul<float>(x_matrix, false, y_matrix, false, static_cast<float>(1),
out, static_cast<float>(1));
out, static_cast<float>(1), false, bias_data);
PADDLE_MOBILE_ENFORCE(out_dim.size() == 2, " out_dim.size must be 2.");
// if (out_dim.size() != 2) {
// out->Resize(out_dim);
...
@@ -392,6 +392,42 @@ void InnerKernel(int mc, int nc, float alpha, const float *a, const float *b,
}
}
// Blocked matrix multiplication
void InnerKernelWithBias(int mc, int nc, float alpha, const float *a,
const float *b, float beta, float *c, float *C,
int ldc, bool relu, float *bias) {
#pragma omp parallel for
for (int j = 0; j < nc; j += NR) {
for (int i = 0; i < mc; i += MR) {
#if __aarch64__
// AddDot8x12(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
AddDot6x16(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
#else
// AddDot4x4(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
// AddDot4x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
AddDot6x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
#endif
}
}
if (alpha != 1) {
WriteWithAlphaBeta(mc, nc, c, C, ldc);
return;
}
if (beta == 0) {
WriteBasic(mc, nc, c, C, ldc);
return;
}
if (beta == 1 && !relu) {
WriteWithAddV1(mc, nc, c, C, ldc, bias);
return;
}
if (beta == 1 && relu) {
WriteWithAddReluV1(mc, nc, c, C, ldc, bias);
return;
}
}
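In the fused conv+add paths above, alpha and beta are both 1, so the dispatch at the end of InnerKernelWithBias selects one of the two new bias-aware writers:

    //   relu == false -> WriteWithAddV1(mc, nc, c, C, ldc, bias);      // C = c + bias, bias broadcast along each row
    //   relu == true  -> WriteWithAddReluV1(mc, nc, c, C, ldc, bias);  // C = max(c + bias, 0)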
// Blocked matrix multiplication
void InnerKernelWithBn(int mc, int nc, float alpha, const float *a,
const float *b, float beta, float *c, float *C, int ldc,
@@ -577,6 +613,43 @@ void WriteWithAdd(int mc, int nc, float *c, float *C, int ldc) {
}
}
}
// C = A * B + bias
void WriteWithAddV1(int mc, int nc, float *c, float *C, int ldc, float *bias) {
int nc1 = nc / 4;
int _nc1 = nc % 4;
float *c_ptr, *C_ptr;
float32x4_t cv;
float32x4_t biasv;
for (int i = 0; i < mc; ++i) {
c_ptr = c + i * NC;
C_ptr = C + i * ldc;
biasv = vld1q_dup_f32(bias + i);
for (int j = 0; j < nc1; ++j) {
cv = vld1q_f32(c_ptr);
cv = vaddq_f32(cv, biasv);
vst1q_f32(C_ptr, cv);
c_ptr += 4;
C_ptr += 4;
}
if (_nc1 != 0) {
cv = vld1q_f32(c_ptr);
cv = vaddq_f32(cv, biasv);
if (_nc1 >= 1) {
vst1q_lane_f32(C_ptr, cv, 0);
C_ptr++;
}
if (_nc1 >= 2) {
vst1q_lane_f32(C_ptr, cv, 1);
C_ptr++;
}
if (_nc1 >= 3) {
vst1q_lane_f32(C_ptr, cv, 2);
C_ptr++;
}
}
}
}
// C = A * B + C, relu(C)
void WriteWithAddRelu(int mc, int nc, float *c, float *C, int ldc) {
@@ -619,6 +692,48 @@ void WriteWithAddRelu(int mc, int nc, float *c, float *C, int ldc) {
}
}
// C = A * B + bias, relu(C)
void WriteWithAddReluV1(int mc, int nc, float *c, float *C, int ldc,
float *bias) {
int nc1 = nc / 4;
int _nc1 = nc % 4;
float *c_ptr, *C_ptr;
float32x4_t cv;
float32x4_t biasv;
float32x4_t zero = vdupq_n_f32(0.0);
for (int i = 0; i < mc; ++i) {
c_ptr = c + i * NC;
C_ptr = C + i * ldc;
biasv = vld1q_dup_f32(bias + i);
for (int j = 0; j < nc1; ++j) {
cv = vld1q_f32(c_ptr);
cv = vaddq_f32(cv, biasv);
cv = vmaxq_f32(cv, zero);
vst1q_f32(C_ptr, cv);
c_ptr += 4;
C_ptr += 4;
}
if (_nc1 != 0) {
cv = vld1q_f32(c_ptr);
cv = vaddq_f32(cv, biasv);
cv = vmaxq_f32(cv, zero);
if (_nc1 >= 1) {
vst1q_lane_f32(C_ptr, cv, 0);
C_ptr++;
}
if (_nc1 >= 2) {
vst1q_lane_f32(C_ptr, cv, 1);
C_ptr++;
}
if (_nc1 >= 3) {
vst1q_lane_f32(C_ptr, cv, 2);
C_ptr++;
}
}
}
}
// C = A * B, batchnorm(C)
void WriteWithBn(int mc, int nc, float *c, float *C, int ldc, float *new_scale,
float *new_bias) {
@@ -1448,6 +1563,44 @@ void WriteWithAdd(int mc, int nc, float *c, float *C, int ldc) {
}
}
// C = A * B + bias
void WriteWithAddV1(int mc, int nc, float *c, float *C, int ldc, float *bias) {
int nc1 = nc / 4;
int _nc1 = nc % 4;
float *c_ptr, *C_ptr;
float32x4_t cv;
float32x4_t biasv;
for (int i = 0; i < mc; ++i) {
c_ptr = c + i * NC;
C_ptr = C + i * ldc;
biasv = vld1q_dup_f32(bias + i);
for (int j = 0; j < nc1; ++j) {
cv = vld1q_f32(c_ptr);
cv = vaddq_f32(cv, biasv);
vst1q_f32(C_ptr, cv);
c_ptr += 4;
C_ptr += 4;
}
if (_nc1 != 0) {
cv = vld1q_f32(c_ptr);
cv = vaddq_f32(cv, biasv);
if (_nc1 >= 1) {
vst1q_lane_f32(C_ptr, cv, 0);
C_ptr++;
}
if (_nc1 >= 2) {
vst1q_lane_f32(C_ptr, cv, 1);
C_ptr++;
}
if (_nc1 >= 3) {
vst1q_lane_f32(C_ptr, cv, 2);
C_ptr++;
}
}
}
}
// C = A * B + C, relu(C)
void WriteWithAddRelu(int mc, int nc, float *c, float *C, int ldc) {
int nc1 = nc / 16;
@@ -1522,6 +1675,48 @@ void WriteWithAddRelu(int mc, int nc, float *c, float *C, int ldc) {
}
}
// C = A * B + bias, relu(C)
void WriteWithAddReluV1(int mc, int nc, float *c, float *C, int ldc,
float *bias) {
int nc1 = nc / 4;
int _nc1 = nc % 4;
float *c_ptr, *C_ptr;
float32x4_t cv;
float32x4_t biasv;
float32x4_t zero = vdupq_n_f32(0.0);
for (int i = 0; i < mc; ++i) {
c_ptr = c + i * NC;
C_ptr = C + i * ldc;
biasv = vld1q_dup_f32(bias + i);
for (int j = 0; j < nc1; ++j) {
cv = vld1q_f32(c_ptr);
cv = vaddq_f32(cv, biasv);
cv = vmaxq_f32(cv, zero);
vst1q_f32(C_ptr, cv);
c_ptr += 4;
C_ptr += 4;
}
if (_nc1 != 0) {
cv = vld1q_f32(c_ptr);
cv = vaddq_f32(cv, biasv);
cv = vmaxq_f32(cv, zero);
if (_nc1 >= 1) {
vst1q_lane_f32(C_ptr, cv, 0);
C_ptr++;
}
if (_nc1 >= 2) {
vst1q_lane_f32(C_ptr, cv, 1);
C_ptr++;
}
if (_nc1 >= 3) {
vst1q_lane_f32(C_ptr, cv, 2);
C_ptr++;
}
}
}
}
// C = A * B, batchnorm(C)
void WriteWithBn(int mc, int nc, float *c, float *C, int ldc, float *scale,
float *bias) {
@@ -2053,7 +2248,8 @@ void AddDot4x4(int k, const float *a, const float *b, float *c, int ldc) {
// 32-bit float matrix multiplication
void Sgemm(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc, bool relu) {
const float *B, int ldb, float beta, float *C, int ldc, bool relu,
float *bias) {
// L1 data cache is 32 KiB (per Cortex-A57, Cortex-A72, Cortex-A73)
// L2 cache is 0.5~4 MiB (Cortex-A72 cluster)
int L1 = 32 * 1024;
@@ -2103,8 +2299,8 @@ void Sgemm(int m, int n, int k, float alpha, const float *A, int lda,
#else
PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA);
#endif
InnerKernel(mc, nc, alpha, packedA, packedB, beta, packedC, &C(i, j), ldc,
relu);
InnerKernelWithBias(mc, nc, alpha, packedA, packedB, beta, packedC,
&C(i, j), ldc, relu, bias + i);
}
}
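Sgemm now always routes through InnerKernelWithBias and offsets the bias pointer by the row-block start, keeping the per-row bias aligned with the rows of C that each block writes. A sketch of the indexing this relies on (assuming bias holds one float per output row):

    // For the block of mc rows starting at row i:
    //   InnerKernelWithBias(..., &C(i, j), ldc, relu, bias + i);
    //   WriteWithAddV1 / WriteWithAddReluV1 then add (bias + i)[row] to every
    //   element of output row (i + row), for 0 <= row < mc.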
...
@@ -62,6 +62,9 @@ void PackMatrixB_16c(int k, int n, int n_tail, const float *B, int ldb,
// Blocked matrix multiplication
void InnerKernel(int mc, int nc, float alpha, const float *a, const float *b,
float beta, float *c, float *C, int ldc, bool relu);
void InnerKernelWithBias(int mc, int nc, float alpha, const float *a,
const float *b, float beta, float *c, float *C,
int ldc, bool relu, float *bias);
void InnerKernelWithBn(int mc, int nc, float alpha, const float *a,
const float *b, float beta, float *c, float *C, int ldc,
@@ -91,8 +94,13 @@ void WriteBasic(int mc, int nc, float *c, float *C, int ldc);
void WriteWithAlphaBeta(int mc, int nc, float *c, float *C, int ldc);
// C = A * B + C
void WriteWithAdd(int mc, int nc, float *c, float *C, int ldc);
// C = A * B + bias
void WriteWithAddV1(int mc, int nc, float *c, float *C, int ldc, float *bias);
// C = A * B + C, relu(C)
void WriteWithAddRelu(int mc, int nc, float *c, float *C, int ldc);
// C = A * B + bias, relu(C)
void WriteWithAddReluV1(int mc, int nc, float *c, float *C, int ldc,
float *bias);
// C = A * B, batchnorm(C)
void WriteWithBn(int mc, int nc, float *c, float *C, int ldc, float *new_scale,
float *new_bias);
@@ -120,7 +128,8 @@ void VecWriteWithBnRelu(int n, float *c, float *C, int ldc, float *new_scale,
// 32-bit float matrix multiplication
void Sgemm(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc, bool relu);
const float *B, int ldb, float beta, float *C, int ldc, bool relu,
float *bias);
// 32-bit float matrix multiplication, applying batchnorm to the result
void SgemmWithBn(int m, int n, int k, float alpha, const float *A, int lda,
...
@@ -22,7 +22,8 @@ namespace math {
template <>
void matmul<float>(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b, float alpha,
framework::Tensor *matrix_out, float beta, bool relu) {
framework::Tensor *matrix_out, float beta, bool relu,
float *bias) {
auto dim_a = matrix_a.dims();
auto dim_b = matrix_b.dims();
auto dim_out = matrix_out->dims();
@@ -42,7 +43,7 @@ void matmul<float>(const framework::Tensor &matrix_a, bool trans_a,
int K = (!trans_a) ? dim_a[1] : dim_a[0];
Sgemm(M, N, K, alpha, matrix_a.data<float>(), K, matrix_b.data<float>(), N,
beta, matrix_out->data<float>(), N, relu);
beta, matrix_out->data<float>(), N, relu, bias);
}
template <>
...
@@ -21,11 +21,11 @@ namespace paddle_mobile {
namespace operators {
namespace math {
// matrix multiply with continuous memory
template <typename T>
void matmul(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b, T alpha,
framework::Tensor *matrix_out, T beta, bool relu = false);
framework::Tensor *matrix_out, T beta, bool relu = false,
float *bias = nullptr);
template <typename T>
void matmulWithBn(const framework::Tensor &matrix_a, bool trans_a,
...
@@ -49,9 +49,9 @@ int main() {
auto time1 = time();
for (int j = 0; j < 10; ++j) {
paddle_mobile::operators::math::matmul<float>(aa, false, bb, false,
static_cast<float>(1), &cc,
static_cast<float>(0), false);
paddle_mobile::operators::math::matmul<float>(
aa, false, bb, false, static_cast<float>(1), &cc, static_cast<float>(0),
false, biasptr);
// paddle_mobile::operators::math::matmulWithBn<float>(
// aa, false, bb, false, static_cast<float>(1), &cc,
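The updated benchmark call passes biasptr, whose allocation is elided in this hunk. A minimal sketch of a buffer matching the one-float-per-output-row convention used elsewhere in this commit (hypothetical, not from the diff; assumes m is the row count used above):

    std::vector<float> bias_buf(m, 0.5f);  // one bias value per output row
    float *biasptr = bias_buf.data();

Note that this call uses beta = 0, so InnerKernelWithBias takes the WriteBasic path and the bias is not applied in the benchmark itself.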
...