提交 d54f849e 编写于 作者: Z zhangyang

Merge remote-tracking branch 'upstream/develop' into develop

...@@ -26,8 +26,15 @@ Paddle-Moible是PaddlePaddle组织下的项目,是一个致力于嵌入式平 ...@@ -26,8 +26,15 @@ Paddle-Moible是PaddlePaddle组织下的项目,是一个致力于嵌入式平
- **ARM CPU** - **ARM CPU**
|mobilenet arm v7|1线程|2线程|4线程|
![](http://mms-graph.bj.bcebos.com/paddle-mobile%2F2018_07_29.png) |------------|----|-----|-----|
|麒麟960(ms)|110.586|72.474|49.833|
|||||
|mobilenetssd arm v7|1线程|2线程|4线程|
|麒麟960(ms)|224.464|142.544|96.068|
|||||
|googlenet(v1) arm v7|1线程|2线程|4线程|
|麒麟960(ms)|348.018|242.689|169.998|
arm cpu是paddle-mobile的主要支持方向,cpu的通用性一直是其优势。嵌入式深度学习,需要大量的cpu汇编实现。我们正在紧锣密鼓地编码,为的是能充分利用硬件的每一点加速能力。 arm cpu是paddle-mobile的主要支持方向,cpu的通用性一直是其优势。嵌入式深度学习,需要大量的cpu汇编实现。我们正在紧锣密鼓地编码,为的是能充分利用硬件的每一点加速能力。
arm cpu的优化工作还在进行中,现在使用了常规的cpu优化。在arm a73上paddle-mobile arm-v7现在单核运行一次mobilenet1.0是110+ms,显然这不是我们的最终目标,我们正在用大量的汇编改写,后续性能仍会有巨大提升空间, 目前只支持armv7, 未来我们也会支持armv8。 arm cpu的优化工作还在进行中,现在使用了常规的cpu优化。在arm a73上paddle-mobile arm-v7现在单核运行一次mobilenet1.0是110+ms,显然这不是我们的最终目标,我们正在用大量的汇编改写,后续性能仍会有巨大提升空间, 目前只支持armv7, 未来我们也会支持armv8。
......
...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "cstring"
#include "io/paddle_inference_api.h" #include "io/paddle_inference_api.h"
namespace paddle_mobile { namespace paddle_mobile {
......
...@@ -25,9 +25,9 @@ bool ElementwiseAddReluKernel<FPGA, float>::Init( ...@@ -25,9 +25,9 @@ bool ElementwiseAddReluKernel<FPGA, float>::Init(
const Tensor *input_x = param->InputX(); const Tensor *input_x = param->InputX();
const Tensor *input_y = param->InputY(); const Tensor *input_y = param->InputY();
Tensor *out = param->Out(); Tensor *out = param->Out();
auto input_x_ptr = input_x->data<float>(); auto input_x_ptr = input_x->data<half>();
auto input_y_ptr = input_y->data<float>(); auto input_y_ptr = input_y->data<half>();
auto out_ptr = out->mutable_data<float>(); auto out_ptr = out->mutable_data<half>();
fpga::EWAddArgs ewaddArgs; fpga::EWAddArgs ewaddArgs;
ewaddArgs.relu_enabled = relu_enabled; ewaddArgs.relu_enabled = relu_enabled;
......
...@@ -22,13 +22,13 @@ template <> ...@@ -22,13 +22,13 @@ template <>
bool FusionFcReluKernel<FPGA, float>::Init(FusionFcReluParam *param) { bool FusionFcReluKernel<FPGA, float>::Init(FusionFcReluParam *param) {
bool relu_enabled = true; bool relu_enabled = true;
const Tensor *input_x = param->InputX(); const Tensor *input_x = param->InputX();
auto input_x_ptr = input_x->data<float>(); auto input_x_ptr = input_x->data<half>();
const Tensor *input_y = param->InputY(); const Tensor *input_y = param->InputY();
auto input_y_ptr = input_y->data<float>(); auto input_y_ptr = input_y->data<float>();
const Tensor *input_z = param->InputZ(); const Tensor *input_z = param->InputZ();
auto input_z_ptr = input_z->data<float>(); auto input_z_ptr = input_z->data<float>();
Tensor *out = param->Out(); Tensor *out = param->Out();
auto out_ptr = out->mutable_data<float>(); auto out_ptr = out->mutable_data<half>();
PADDLE_MOBILE_ENFORCE(input_x->dims()[1] == input_y->dims()[0], PADDLE_MOBILE_ENFORCE(input_x->dims()[1] == input_y->dims()[0],
"Image channel should be equal to weight number"); "Image channel should be equal to weight number");
......
...@@ -22,13 +22,13 @@ template <> ...@@ -22,13 +22,13 @@ template <>
bool FusionFcKernel<FPGA, float>::Init(FusionFcParam *param) { bool FusionFcKernel<FPGA, float>::Init(FusionFcParam *param) {
bool relu_enabled = false; bool relu_enabled = false;
const Tensor *input_x = param->InputX(); const Tensor *input_x = param->InputX();
auto input_x_ptr = input_x->data<float>(); auto input_x_ptr = input_x->data<half>();
const Tensor *input_y = param->InputY(); const Tensor *input_y = param->InputY();
auto input_y_ptr = input_y->data<float>(); auto input_y_ptr = input_y->data<float>();
const Tensor *input_z = param->InputZ(); const Tensor *input_z = param->InputZ();
auto input_z_ptr = input_z->data<float>(); auto input_z_ptr = input_z->data<float>();
Tensor *out = param->Out(); Tensor *out = param->Out();
auto out_ptr = out->mutable_data<float>(); auto out_ptr = out->mutable_data<half>();
PADDLE_MOBILE_ENFORCE(input_x->dims()[1] == input_y->dims()[0], PADDLE_MOBILE_ENFORCE(input_x->dims()[1] == input_y->dims()[0],
"Image channel should be equal to weight number"); "Image channel should be equal to weight number");
......
...@@ -22,9 +22,9 @@ namespace operators { ...@@ -22,9 +22,9 @@ namespace operators {
template <> template <>
bool PoolKernel<FPGA, float>::Init(PoolParam *param) { bool PoolKernel<FPGA, float>::Init(PoolParam *param) {
const Tensor *input = param->Input(); const Tensor *input = param->Input();
auto input_ptr = input->data<float>(); auto input_ptr = input->data<half>();
Tensor *output = param->Output(); Tensor *output = param->Output();
auto output_ptr = output->mutable_data<float>(); auto output_ptr = output->mutable_data<half>();
vector<int> ksize = param->Ksize(); vector<int> ksize = param->Ksize();
vector<int> strides = param->Strides(); vector<int> strides = param->Strides();
vector<int> paddings = param->Paddings(); vector<int> paddings = param->Paddings();
......
...@@ -529,6 +529,252 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter, ...@@ -529,6 +529,252 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter,
const float *newscale_data = new_scale->data<float>(); const float *newscale_data = new_scale->data<float>();
const float *newbias_data = new_bias->data<float>(); const float *newbias_data = new_bias->data<float>();
const int batch_size = static_cast<int>(input->dims()[0]);
const int input_channel = static_cast<int>(input->dims()[1]);
const int input_height = static_cast<int>(input->dims()[2]);
const int input_width = static_cast<int>(input->dims()[3]);
const int output_height = static_cast<int>(output->dims()[2]);
const int output_width = static_cast<int>(output->dims()[3]);
const int hxw = input_height * input_width;
const int l = input_height;
float32x4_t vnewbias = vdupq_n_f32(0.0);
float32x4_t vnewscale = vdupq_n_f32(1.0);
float32x4_t vzero = vdupq_n_f32(0);
for (int b = 0; b < batch_size; b++) {
filter_data = filter->data<float>();
for (int c = 0; c < input_channel; c++) {
vnewbias = vdupq_n_f32(newbias_data[c]);
vnewscale = vdupq_n_f32(newscale_data[c]);
float w00 = filter_data[0];
float w01 = filter_data[1];
float w02 = filter_data[2];
float w10 = filter_data[3];
float w11 = filter_data[4];
float w12 = filter_data[5];
float w20 = filter_data[6];
float w21 = filter_data[7];
float w22 = filter_data[8];
output_data[0] = w11 * input_data[0] + w12 * input_data[1] +
w21 * input_data[l] + w22 * input_data[l + 1];
output_data[l - 1] = w10 * input_data[l - 2] + w11 * input_data[l - 1] +
w20 * input_data[2 * l - 2] +
w21 * input_data[2 * l - 1];
output_data[(l - 1) * l] =
w01 * input_data[(l - 2) * l] + w02 * input_data[(l - 2) * l + 1] +
w11 * input_data[(l - 1) * l] + w12 * input_data[(l - 1) * l + 1];
output_data[l * l - 1] = w00 * input_data[(l - 2) * (l + 1)] +
w01 * input_data[(l - 2) * (l + 1) + 1] +
w10 * input_data[l * l - 2] +
w11 * input_data[l * l - 1];
output_data[0] = output_data[0] * newscale_data[c] + newbias_data[c];
output_data[l - 1] =
output_data[l - 1] * newscale_data[c] + newbias_data[c];
output_data[(l - 1) * l] =
output_data[(l - 1) * l] * newscale_data[c] + newbias_data[c];
output_data[l * l - 1] =
output_data[l * l - 1] * newscale_data[c] + newbias_data[c];
if (if_relu) {
output_data[0] = output_data[0] < 0 ? 0 : output_data[0];
output_data[l - 1] = output_data[l - 1] < 0 ? 0 : output_data[l - 1];
output_data[(l - 1) * l] =
output_data[(l - 1) * l] < 0 ? 0 : output_data[(l - 1) * l];
output_data[l * l - 1] =
output_data[l * l - 1] < 0 ? 0 : output_data[l * l - 1];
}
for (int i = 1; i < l - 1; ++i) {
output_data[i * l] =
w01 * input_data[i * l - l] + w02 * input_data[i * l - l + 1] +
w11 * input_data[i * l] + w12 * input_data[i * l + 1] +
w21 * input_data[i * l + l] + w22 * input_data[i * l + l + 1];
output_data[i * l + l - 1] = w00 * input_data[i * l + l - 1 - l - 1] +
w01 * input_data[i * l + l - 1 - l] +
w10 * input_data[i * l + l - 1 - 1] +
w11 * input_data[i * l + l - 1] +
w20 * input_data[i * l + l - 1 + l - 1] +
w21 * input_data[i * l + l - 1 + l];
output_data[i * l] =
output_data[i * l] * newscale_data[c] + newbias_data[c];
output_data[i * l + l - 1] =
output_data[i * l + l - 1] * newscale_data[c] + newbias_data[c];
if (if_relu) {
output_data[i * l] = output_data[i * l] < 0 ? 0 : output_data[i * l];
output_data[i * l + l - 1] =
output_data[i * l + l - 1] < 0 ? 0 : output_data[i * l + l - 1];
}
}
int m;
for (m = 1; m < output_width - 4; m += 4) {
float *output_ptr = output_data + m;
float32x4_t in0, in1, in2, in3, tmp0, tmp1, tmp2, tmp3, out0;
in0 = vld1q_f32(input_data + m - 1);
in1 = vld1q_f32(input_data + m + 3);
in2 = vld1q_f32(input_data + input_width + m - 1);
in3 = vld1q_f32(input_data + input_width + m + 3);
tmp0 = vextq_f32(in0, in1, 1);
tmp1 = vextq_f32(in0, in1, 2);
tmp2 = vextq_f32(in2, in3, 1);
tmp3 = vextq_f32(in2, in3, 2);
out0 = vmulq_n_f32(in0, w10);
out0 = vmlaq_n_f32(out0, tmp0, w11);
out0 = vmlaq_n_f32(out0, tmp1, w12);
out0 = vmlaq_n_f32(out0, in2, w20);
out0 = vmlaq_n_f32(out0, tmp2, w21);
out0 = vmlaq_n_f32(out0, tmp3, w22);
out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) {
out0 = vmaxq_f32(out0, vzero);
}
vst1q_f32(output_ptr, out0);
}
for (m = 1; (m + 3) < output_width - 1; m = m + 4) {
}
for (int j = m; j < output_width - 1; j++) {
output_data[j] = input_data[j - 1] * w10 + input_data[j] * w11 +
input_data[j + 1] * w12 +
input_data[input_width + j - 1] * w20 +
input_data[input_width + j] * w21 +
input_data[input_width + j + 1] * w22;
output_data[j] = output_data[j] * newscale_data[c] + newbias_data[c];
if (if_relu) {
output_data[j] = output_data[j] < 0 ? 0 : output_data[j];
}
}
for (m = 1; (m + 3) < output_width - 1; m = m + 4) {
float *output_ptr =
output_data + (output_height - 1) * output_width + m;
float32x4_t in0, in1, in2, in3, tmp0, tmp1, tmp2, tmp3, out0;
in0 = vld1q_f32(input_data + (output_height - 2) * input_width + m - 1);
in1 = vld1q_f32(input_data + (output_height - 2) * input_width + m + 3);
in2 = vld1q_f32(input_data + (output_height - 1) * input_width + m - 1);
in3 = vld1q_f32(input_data + (output_height - 1) * input_width + m + 3);
tmp0 = vextq_f32(in0, in1, 1);
tmp1 = vextq_f32(in0, in1, 2);
tmp2 = vextq_f32(in2, in3, 1);
tmp3 = vextq_f32(in2, in3, 2);
out0 = vmulq_n_f32(in0, w00);
out0 = vmlaq_n_f32(out0, tmp0, w01);
out0 = vmlaq_n_f32(out0, tmp1, w02);
out0 = vmlaq_n_f32(out0, in2, w10);
out0 = vmlaq_n_f32(out0, tmp2, w11);
out0 = vmlaq_n_f32(out0, tmp3, w12);
out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) {
out0 = vmaxq_f32(out0, vzero);
}
vst1q_f32(output_ptr, out0);
}
for (m = 1; (m + 3) < output_width - 1; m = m + 4) {
}
for (int j = m; j < output_width - 1; j++) {
output_data[(output_height - 1) * input_width + j] =
input_data[(output_height - 2) * input_width + j - 1] * w00 +
input_data[(output_height - 2) * input_width + j] * w01 +
input_data[(output_height - 2) * input_width + j + 1] * w02 +
input_data[(output_height - 1) * input_width + j - 1] * w10 +
input_data[(output_height - 1) * input_width + j] * w11 +
input_data[(output_height - 1) * input_width + j + 1] * w12;
output_data[(output_height - 1) * output_width + j] =
output_data[(output_height - 1) * output_width + j] *
newscale_data[c] +
newbias_data[c];
if (if_relu) {
output_data[(output_height - 1) * output_width + j] =
output_data[(output_height - 1) * output_width + j] < 0
? 0
: output_data[(output_height - 1) * output_width + j];
}
}
#pragma omp parallel for
for (int i = 1; i < output_height - 1; i++) {
for (int m = 1; (m + 3) < output_width - 1; m = m + 4) {
float *output_ptr = output_data + i * output_width + m;
float32x4_t in0, in1, in2, in3, in4, in5, tmp0, tmp1, tmp2, tmp3,
tmp4, tmp5, out0;
in0 = vld1q_f32(input_data + (i - 1) * input_width + m - 1);
in1 = vld1q_f32(input_data + (i - 1) * input_width + m + 3);
in2 = vld1q_f32(input_data + i * input_width + m - 1);
in3 = vld1q_f32(input_data + i * input_width + m + 3);
in4 = vld1q_f32(input_data + (i + 1) * input_width + m - 1);
in5 = vld1q_f32(input_data + (i + 1) * input_width + m + 3);
tmp0 = vextq_f32(in0, in1, 1);
tmp1 = vextq_f32(in0, in1, 2);
tmp2 = vextq_f32(in2, in3, 1);
tmp3 = vextq_f32(in2, in3, 2);
tmp4 = vextq_f32(in4, in5, 1);
tmp5 = vextq_f32(in4, in5, 2);
out0 = vmulq_n_f32(in0, w00);
out0 = vmlaq_n_f32(out0, tmp0, w01);
out0 = vmlaq_n_f32(out0, tmp1, w02);
out0 = vmlaq_n_f32(out0, in2, w10);
out0 = vmlaq_n_f32(out0, tmp2, w11);
out0 = vmlaq_n_f32(out0, tmp3, w12);
out0 = vmlaq_n_f32(out0, in4, w20);
out0 = vmlaq_n_f32(out0, tmp4, w21);
out0 = vmlaq_n_f32(out0, tmp5, w22);
out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) {
out0 = vmaxq_f32(out0, vzero);
}
vst1q_f32(output_ptr, out0);
}
int m;
for (m = 1; (m + 3) < output_width - 1; m = m + 4) {
}
for (int j = m; j < output_width - 1; j++) {
output_data[i * output_width + j] =
input_data[(i - 1) * input_width + j - 1] * w00 +
input_data[(i - 1) * input_width + j] * w01 +
input_data[(i - 1) * input_width + j + 1] * w02 +
input_data[(i)*input_width + j - 1] * w10 +
input_data[(i)*input_width + j] * w11 +
input_data[(i)*input_width + j + 1] * w12 +
input_data[(i + 1) * input_width + j - 1] * w20 +
input_data[(i + 1) * input_width + j] * w21 +
input_data[(i + 1) * input_width + j + 1] * w22;
output_data[i * output_width + j] =
newscale_data[c] * output_data[i * output_width + j] +
newbias_data[c];
if (if_relu) {
output_data[i * output_width + j] =
output_data[i * output_width + j] < 0
? 0
: output_data[i * output_width + j];
}
}
}
input_data = input_data + hxw;
output_data = output_data + hxw;
filter_data = filter_data + 9;
}
}
/*
const float *input_data = input->data<float>();
const float *filter_data = filter->data<float>();
float *output_data = output->data<float>();
const float *newscale_data = new_scale->data<float>();
const float *newbias_data = new_bias->data<float>();
const int h = static_cast<int>(input->dims()[2]); const int h = static_cast<int>(input->dims()[2]);
const int w = static_cast<int>(input->dims()[3]); const int w = static_cast<int>(input->dims()[3]);
const int l = h; const int l = h;
...@@ -605,8 +851,8 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter, ...@@ -605,8 +851,8 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter,
output_data[i * l + l - 1] * newscale_data[j] + newbias_data[j]; output_data[i * l + l - 1] * newscale_data[j] + newbias_data[j];
if (if_relu) { if (if_relu) {
output_data[i * l] = output_data[i * l] < 0 ? 0 : output_data[i * l]; output_data[i * l] = output_data[i * l] < 0 ? 0 : output_data[i *
output_data[i * l + l - 1] = l]; output_data[i * l + l - 1] =
output_data[i * l + l - 1] < 0 ? 0 : output_data[i * l + l - 1]; output_data[i * l + l - 1] < 0 ? 0 : output_data[i * l + l - 1];
} }
} }
...@@ -738,6 +984,7 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter, ...@@ -738,6 +984,7 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter,
} }
// mid // mid
for (int i = 0; i < l - 2; ++i) { for (int i = 0; i < l - 2; ++i) {
auto output_ptr = output_data + (i + 1) * l + 1; auto output_ptr = output_data + (i + 1) * l + 1;
input_tmp = input_data + i * l; input_tmp = input_data + i * l;
...@@ -820,6 +1067,7 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter, ...@@ -820,6 +1067,7 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter,
filter_data_tmp += 9; filter_data_tmp += 9;
} }
} }
*/
#endif #endif
} }
......
...@@ -33,6 +33,14 @@ float *packedA; ...@@ -33,6 +33,14 @@ float *packedA;
float *packedB; float *packedB;
float *packedC; float *packedC;
float *zero; float *zero;
// Pointer to a matrix-packing routine. The parameter list matches the
// PackMatrixA_* / PackMatrixB_* functions in this file:
// (rows-or-k, k-or-cols, tail, source matrix, leading dimension, destination).
using FnPack = void (*)(int, int, int, const float *, int, float *);
// Pointer to an AddDot micro-kernel: (k, packed a, packed b, c, ldc).
using FnAddDot = void (*)(int, const float *, const float *, float *, int);

// File-level dispatch slots for the A-packing, B-packing and AddDot routines.
// NOTE(review): where these are assigned is not visible in this listing;
// presumably the Sgemm entry points select the implementations — confirm.
FnPack procPackA;
FnPack procPackB;
FnAddDot procAddDot;
/* /*
// 将A矩阵分块复制到连续内存(ColMajor) // 将A矩阵分块复制到连续内存(ColMajor)
void PackMatrixA(int m, int k, int m_tail, const float *A, int lda, void PackMatrixA(int m, int k, int m_tail, const float *A, int lda,
...@@ -135,30 +143,32 @@ void PackMatrixA_4r(int m, int k, int m_tail, const float *A, int lda, ...@@ -135,30 +143,32 @@ void PackMatrixA_4r(int m, int k, int m_tail, const float *A, int lda,
void PackMatrixA_6r(int m, int k, int m_tail, const float *A, int lda, void PackMatrixA_6r(int m, int k, int m_tail, const float *A, int lda,
float *buffer) { float *buffer) {
const float *a0, *a1, *a2, *a3, *a4, *a5; const int i_length = m - m_tail;
for (int i = 0; i < m - m_tail; i += MR) { for (int i = 0; i < i_length; i += MR) {
a0 = A + i * lda; const float *a0 = A + i * lda;
a1 = A + (i + 1) * lda; const float *a1 = A + (i + 1) * lda;
a2 = A + (i + 2) * lda; const float *a2 = A + (i + 2) * lda;
a3 = A + (i + 3) * lda; const float *a3 = A + (i + 3) * lda;
a4 = A + (i + 4) * lda; const float *a4 = A + (i + 4) * lda;
a5 = A + (i + 5) * lda; const float *a5 = A + (i + 5) * lda;
float *local_buffer = buffer + i * k;
for (int j = 0; j < k; ++j) { for (int j = 0; j < k; ++j) {
*buffer++ = *a0++; *local_buffer++ = *a0++;
*buffer++ = *a1++; *local_buffer++ = *a1++;
*buffer++ = *a2++; *local_buffer++ = *a2++;
*buffer++ = *a3++; *local_buffer++ = *a3++;
*buffer++ = *a4++; *local_buffer++ = *a4++;
*buffer++ = *a5++; *local_buffer++ = *a5++;
} }
} }
if (m_tail != 0) { if (m_tail != 0) {
a0 = &A(m - m_tail, 0); const float *a0 = &A(i_length, 0);
a1 = a0 + lda; const float *a1 = a0 + lda;
a2 = a0 + 2 * lda; const float *a2 = a0 + 2 * lda;
a3 = a0 + 3 * lda; const float *a3 = a0 + 3 * lda;
a4 = a0 + 4 * lda; const float *a4 = a0 + 4 * lda;
a5 = a0 + 5 * lda; const float *a5 = a0 + 5 * lda;
float *local_buffer = buffer + i_length * k;
switch (m_tail) { switch (m_tail) {
case 1: case 1:
a1 = zero; a1 = zero;
...@@ -175,48 +185,105 @@ void PackMatrixA_6r(int m, int k, int m_tail, const float *A, int lda, ...@@ -175,48 +185,105 @@ void PackMatrixA_6r(int m, int k, int m_tail, const float *A, int lda,
break; break;
} }
for (int j = 0; j < k; ++j) { for (int j = 0; j < k; ++j) {
*buffer++ = *a0++; *local_buffer++ = *a0++;
*buffer++ = *a1++; *local_buffer++ = *a1++;
*buffer++ = *a2++; *local_buffer++ = *a2++;
*buffer++ = *a3++; *local_buffer++ = *a3++;
*buffer++ = *a4++; *local_buffer++ = *a4++;
*buffer++ = *a5++; *local_buffer++ = *a5++;
}
}
}
// Copies matrix A into `buffer` in panels of six rows, with the full panels
// handled by an OpenMP parallel loop.
//
// m       number of rows of A to pack
// k       number of elements copied from each row
// m_tail  number of leftover rows after the full panels (m - m_tail rows are
//         covered by the main loop)
// A, lda  source matrix and the stride between consecutive rows
// buffer  destination; within a panel the six rows are interleaved, i.e. for
//         each j the values a0[j], a1[j], ..., a5[j] are stored back to back
//
// NOTE(review): the loop advances by MR but always reads six rows, so it
// presumably expects MR == 6 (MR is defined outside this listing) — confirm.
void PackMatrixA_omp_6r(int m, int k, int m_tail, const float *A, int lda,
float *buffer) {
const int i_length = m - m_tail;
// Each iteration derives its own write position from i (buffer + i * k)
// rather than advancing a shared pointer, so iterations are independent.
#pragma omp parallel for
for (int i = 0; i < i_length; i += MR) {
const float *a0 = A + i * lda;
const float *a1 = A + (i + 1) * lda;
const float *a2 = A + (i + 2) * lda;
const float *a3 = A + (i + 3) * lda;
const float *a4 = A + (i + 4) * lda;
const float *a5 = A + (i + 5) * lda;
float *local_buffer = buffer + i * k;
for (int j = 0; j < k; ++j) {
*local_buffer++ = *a0++;
*local_buffer++ = *a1++;
*local_buffer++ = *a2++;
*local_buffer++ = *a3++;
*local_buffer++ = *a4++;
*local_buffer++ = *a5++;
}
}
// Leftover rows: packed serially into the panel that follows the full ones.
if (m_tail != 0) {
// A(r, c) is a macro defined outside this listing; presumably it indexes
// A as A[r * lda + c] — confirm.
const float *a0 = &A(i_length, 0);
const float *a1 = a0 + lda;
const float *a2 = a0 + 2 * lda;
const float *a3 = a0 + 3 * lda;
const float *a4 = a0 + 4 * lda;
const float *a5 = a0 + 5 * lda;
float *local_buffer = buffer + i_length * k;
// Rows that do not exist are redirected to the global `zero` buffer before
// anything is read through them. The cases fall through on purpose:
// m_tail == 1 redirects a1..a5, m_tail == 2 redirects a2..a5, and so on.
// NOTE(review): each redirected pointer is read k times, so `zero` is
// assumed to hold at least k floats — confirm against its allocation.
switch (m_tail) {
case 1:
a1 = zero;
case 2:
a2 = zero;
case 3:
a3 = zero;
case 4:
a4 = zero;
case 5:
a5 = zero;
break;
default:
break;
}
for (int j = 0; j < k; ++j) {
*local_buffer++ = *a0++;
*local_buffer++ = *a1++;
*local_buffer++ = *a2++;
*local_buffer++ = *a3++;
*local_buffer++ = *a4++;
*local_buffer++ = *a5++;
} }
} }
} }
void PackMatrixA_8r(int m, int k, int m_tail, const float *A, int lda, void PackMatrixA_8r(int m, int k, int m_tail, const float *A, int lda,
float *buffer) { float *buffer) {
const float *a0, *a1, *a2, *a3, *a4, *a5, *a6, *a7; const int i_length = m - m_tail;
for (int i = 0; i < m - m_tail; i += MR) { for (int i = 0; i < i_length; i += MR) {
a0 = A + i * lda; const float *a0 = A + i * lda;
a1 = A + (i + 1) * lda; const float *a1 = A + (i + 1) * lda;
a2 = A + (i + 2) * lda; const float *a2 = A + (i + 2) * lda;
a3 = A + (i + 3) * lda; const float *a3 = A + (i + 3) * lda;
a4 = A + (i + 4) * lda; const float *a4 = A + (i + 4) * lda;
a5 = A + (i + 5) * lda; const float *a5 = A + (i + 5) * lda;
a6 = A + (i + 6) * lda; const float *a6 = A + (i + 6) * lda;
a7 = A + (i + 7) * lda; const float *a7 = A + (i + 7) * lda;
float *local_buffer = buffer + i * k;
for (int j = 0; j < k; ++j) { for (int j = 0; j < k; ++j) {
*buffer++ = *a0++; *local_buffer++ = *a0++;
*buffer++ = *a1++; *local_buffer++ = *a1++;
*buffer++ = *a2++; *local_buffer++ = *a2++;
*buffer++ = *a3++; *local_buffer++ = *a3++;
*buffer++ = *a4++; *local_buffer++ = *a4++;
*buffer++ = *a5++; *local_buffer++ = *a5++;
*buffer++ = *a6++; *local_buffer++ = *a6++;
*buffer++ = *a7++; *local_buffer++ = *a7++;
} }
} }
if (m_tail != 0) { if (m_tail != 0) {
a0 = &A(m - m_tail, 0); const float *a0 = &A(i_length, 0);
a1 = a0 + lda; const float *a1 = a0 + lda;
a2 = a0 + 2 * lda; const float *a2 = a0 + 2 * lda;
a3 = a0 + 3 * lda; const float *a3 = a0 + 3 * lda;
a4 = a0 + 4 * lda; const float *a4 = a0 + 4 * lda;
a5 = a0 + 5 * lda; const float *a5 = a0 + 5 * lda;
a6 = a0 + 6 * lda; const float *a6 = a0 + 6 * lda;
a7 = a0 + 7 * lda; const float *a7 = a0 + 7 * lda;
float *local_buffer = buffer + i_length * k;
switch (m_tail) { switch (m_tail) {
case 1: case 1:
a1 = zero; a1 = zero;
...@@ -237,14 +304,81 @@ void PackMatrixA_8r(int m, int k, int m_tail, const float *A, int lda, ...@@ -237,14 +304,81 @@ void PackMatrixA_8r(int m, int k, int m_tail, const float *A, int lda,
break; break;
} }
for (int j = 0; j < k; ++j) { for (int j = 0; j < k; ++j) {
*buffer++ = *a0++; *local_buffer++ = *a0++;
*buffer++ = *a1++; *local_buffer++ = *a1++;
*buffer++ = *a2++; *local_buffer++ = *a2++;
*buffer++ = *a3++; *local_buffer++ = *a3++;
*buffer++ = *a4++; *local_buffer++ = *a4++;
*buffer++ = *a5++; *local_buffer++ = *a5++;
*buffer++ = *a6++; *local_buffer++ = *a6++;
*buffer++ = *a7++; *local_buffer++ = *a7++;
}
}
}
// Copies matrix A into `buffer` in panels of eight rows, with the full panels
// handled by an OpenMP parallel loop.
//
// m       number of rows of A to pack
// k       number of elements copied from each row
// m_tail  number of leftover rows after the full panels (m - m_tail rows are
//         covered by the main loop)
// A, lda  source matrix and the stride between consecutive rows
// buffer  destination; within a panel the eight rows are interleaved, i.e.
//         for each j the values a0[j], a1[j], ..., a7[j] are stored in order
//
// NOTE(review): the loop advances by MR but always reads eight rows, so it
// presumably expects MR == 8 (MR is defined outside this listing) — confirm.
void PackMatrixA_omp_8r(int m, int k, int m_tail, const float *A, int lda,
float *buffer) {
const int i_length = m - m_tail;
// Each iteration derives its own write position from i (buffer + i * k)
// rather than advancing a shared pointer, so iterations are independent.
#pragma omp parallel for
for (int i = 0; i < i_length; i += MR) {
const float *a0 = A + i * lda;
const float *a1 = A + (i + 1) * lda;
const float *a2 = A + (i + 2) * lda;
const float *a3 = A + (i + 3) * lda;
const float *a4 = A + (i + 4) * lda;
const float *a5 = A + (i + 5) * lda;
const float *a6 = A + (i + 6) * lda;
const float *a7 = A + (i + 7) * lda;
float *local_buffer = buffer + i * k;
for (int j = 0; j < k; ++j) {
*local_buffer++ = *a0++;
*local_buffer++ = *a1++;
*local_buffer++ = *a2++;
*local_buffer++ = *a3++;
*local_buffer++ = *a4++;
*local_buffer++ = *a5++;
*local_buffer++ = *a6++;
*local_buffer++ = *a7++;
}
}
// Leftover rows: packed serially into the panel that follows the full ones.
if (m_tail != 0) {
// A(r, c) is a macro defined outside this listing; presumably it indexes
// A as A[r * lda + c] — confirm.
const float *a0 = &A(i_length, 0);
const float *a1 = a0 + lda;
const float *a2 = a0 + 2 * lda;
const float *a3 = a0 + 3 * lda;
const float *a4 = a0 + 4 * lda;
const float *a5 = a0 + 5 * lda;
const float *a6 = a0 + 6 * lda;
const float *a7 = a0 + 7 * lda;
float *local_buffer = buffer + i_length * k;
// Rows that do not exist are redirected to the global `zero` buffer before
// anything is read through them. The cases fall through on purpose:
// m_tail == 1 redirects a1..a7, m_tail == 2 redirects a2..a7, and so on.
// NOTE(review): each redirected pointer is read k times, so `zero` is
// assumed to hold at least k floats — confirm against its allocation.
switch (m_tail) {
case 1:
a1 = zero;
case 2:
a2 = zero;
case 3:
a3 = zero;
case 4:
a4 = zero;
case 5:
a5 = zero;
case 6:
a6 = zero;
case 7:
a7 = zero;
break;
default:
break;
}
for (int j = 0; j < k; ++j) {
*local_buffer++ = *a0++;
*local_buffer++ = *a1++;
*local_buffer++ = *a2++;
*local_buffer++ = *a3++;
*local_buffer++ = *a4++;
*local_buffer++ = *a5++;
*local_buffer++ = *a6++;
*local_buffer++ = *a7++;
} }
} }
} }
...@@ -252,48 +386,102 @@ void PackMatrixA_8r(int m, int k, int m_tail, const float *A, int lda, ...@@ -252,48 +386,102 @@ void PackMatrixA_8r(int m, int k, int m_tail, const float *A, int lda,
// 将B矩阵分块复制到连续内存(RowMajor) // 将B矩阵分块复制到连续内存(RowMajor)
void PackMatrixB_8c(int k, int n, int n_tail, const float *B, int ldb, void PackMatrixB_8c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer) { float *buffer) {
const float *b0; const int j_length = n - n_tail;
for (int j = 0; j < n - n_tail; j += NR) { for (int j = 0; j < j_length; j += NR) {
float *local_buffer = buffer + j * k;
for (int i = 0; i < k; ++i) { for (int i = 0; i < k; ++i) {
b0 = &B(i, j); const float *b0 = &B(i, j);
#if __ARM_NEON #if __ARM_NEON
#if __aarch64__ #if __aarch64__
asm volatile( asm volatile(
"prfm pldl1keep, [%[b0]] \n\t" "prfm pldl1keep, [%[b0]] \n\t"
"ld1 {v0.4s, v1.4s}, [%[b0]] \n\t" "ld1 {v0.4s, v1.4s}, [%[b0]] \n\t"
"st1 {v0.4s, v1.4s}, [%[buffer]], #32 \n\t" "st1 {v0.4s, v1.4s}, [%[local_buffer]], #32 \n\t"
: [buffer] "+r"(buffer) : [local_buffer] "+r"(local_buffer)
: [b0] "r"(b0) : [b0] "r"(b0)
: "memory", "v0", "v1"); : "memory", "v0", "v1");
#else #else
asm volatile( asm volatile(
"pld [%[b0]] \n\t" "pld [%[b0]] \n\t"
"vld1.32 {q0, q1}, [%[b0]] \n\t" "vld1.32 {q0, q1}, [%[b0]] \n\t"
"vst1.32 {q0, q1}, [%[buffer]]! \n\t" "vst1.32 {q0, q1}, [%[local_buffer]]! \n\t"
: [buffer] "+r"(buffer) : [local_buffer] "+r"(local_buffer)
: [b0] "r"(b0) : [b0] "r"(b0)
: "memory", "q0", "q1"); : "memory", "q0", "q1");
#endif // __aarch64__ #endif // __aarch64__
#else #else
*buffer++ = *b0++; *local_buffer++ = *b0++;
*buffer++ = *b0++; *local_buffer++ = *b0++;
*buffer++ = *b0++; *local_buffer++ = *b0++;
*buffer++ = *b0++; *local_buffer++ = *b0++;
*buffer++ = *b0++; *local_buffer++ = *b0++;
*buffer++ = *b0++; *local_buffer++ = *b0++;
*buffer++ = *b0++; *local_buffer++ = *b0++;
*buffer++ = *b0++; *local_buffer++ = *b0++;
#endif // __ARM_NEON #endif // __ARM_NEON
} }
} }
if (n_tail != 0) { if (n_tail != 0) {
float *local_buffer = buffer + j_length * k;
for (int i = 0; i < k; ++i) { for (int i = 0; i < k; ++i) {
b0 = &B(i, n - n_tail); const float *b0 = &B(i, j_length);
for (int j = n - n_tail; j < n; ++j) { for (int j = j_length; j < n; ++j) {
*buffer++ = *b0++; *local_buffer++ = *b0++;
} }
for (int j = n; j < n + (NR - n_tail); ++j) { for (int j = n; j < j_length + NR; ++j) {
*buffer++ = 0; *local_buffer++ = 0;
}
}
}
}
// Copies matrix B into `buffer` in panels of eight columns, with the full
// panels handled by an OpenMP parallel loop.
//
// k       number of rows of B that are copied
// n       number of columns of B to pack
// n_tail  number of leftover columns after the full panels (n - n_tail
//         columns are covered by the main loop)
// B, ldb  source matrix; ldb is not referenced directly in this body and is
//         presumably consumed by the B(i, j) macro defined outside this
//         listing — confirm
// buffer  destination; a panel stores, row after row, the eight consecutive
//         values starting at B(i, j)
//
// NOTE(review): the loop advances by NR but always copies eight values per
// row, so it presumably expects NR == 8 — confirm.
void PackMatrixB_omp_8c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer) {
const int j_length = n - n_tail;
// Each iteration derives its own write position from j (buffer + j * k)
// rather than advancing a shared pointer, so iterations are independent.
#pragma omp parallel for
for (int j = 0; j < j_length; j += NR) {
float *local_buffer = buffer + j * k;
for (int i = 0; i < k; ++i) {
const float *b0 = &B(i, j);
#if __ARM_NEON
#if __aarch64__
// AArch64: prefetch the row, load two 4-float vectors (8 floats) from b0,
// store them and post-increment local_buffer by 32 bytes.
asm volatile(
"prfm pldl1keep, [%[b0]] \n\t"
"ld1 {v0.4s, v1.4s}, [%[b0]] \n\t"
"st1 {v0.4s, v1.4s}, [%[local_buffer]], #32 \n\t"
: [local_buffer] "+r"(local_buffer)
: [b0] "r"(b0)
: "memory", "v0", "v1");
#else
// 32-bit ARM: same copy using q0/q1; the "!" writes the advanced address
// back into local_buffer.
asm volatile(
"pld [%[b0]] \n\t"
"vld1.32 {q0, q1}, [%[b0]] \n\t"
"vst1.32 {q0, q1}, [%[local_buffer]]! \n\t"
: [local_buffer] "+r"(local_buffer)
: [b0] "r"(b0)
: "memory", "q0", "q1");
#endif // __aarch64__
#else
// Scalar fallback: copy the eight values one by one.
*local_buffer++ = *b0++;
*local_buffer++ = *b0++;
*local_buffer++ = *b0++;
*local_buffer++ = *b0++;
*local_buffer++ = *b0++;
*local_buffer++ = *b0++;
*local_buffer++ = *b0++;
*local_buffer++ = *b0++;
#endif // __ARM_NEON
}
}
// Leftover columns: packed serially into the panel after the full ones.
if (n_tail != 0) {
float *local_buffer = buffer + j_length * k;
for (int i = 0; i < k; ++i) {
const float *b0 = &B(i, j_length);
// Copy the n_tail real values of this row ...
for (int j = j_length; j < n; ++j) {
*local_buffer++ = *b0++;
}
// ... then pad with zeros so the row occupies NR slots in the panel.
for (int j = n; j < j_length + NR; ++j) {
*local_buffer++ = 0;
} }
} }
} }
...@@ -302,27 +490,60 @@ void PackMatrixB_8c(int k, int n, int n_tail, const float *B, int ldb, ...@@ -302,27 +490,60 @@ void PackMatrixB_8c(int k, int n, int n_tail, const float *B, int ldb,
#if __aarch64__ #if __aarch64__
void PackMatrixB_12c(int k, int n, int n_tail, const float *B, int ldb, void PackMatrixB_12c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer) { float *buffer) {
const float *b0; const int j_length = n - n_tail;
for (int j = 0; j < n - n_tail; j += NR) { for (int j = 0; j < j_length; j += NR) {
float *local_buffer = buffer + j * k;
for (int i = 0; i < k; ++i) { for (int i = 0; i < k; ++i) {
b0 = &B(i, j); const float *b0 = &B(i, j);
asm volatile( asm volatile(
"prfm pldl2keep, [%[b0], #64] \n\t" "prfm pldl2keep, [%[b0], #64] \n\t"
"ld1 {v0.4s, v1.4s, v2.4s}, [%[b0]] \n\t" "ld1 {v0.4s, v1.4s, v2.4s}, [%[b0]] \n\t"
"st1 {v0.4s, v1.4s, v2.4s}, [%[buffer]], #48 \n\t" "st1 {v0.4s, v1.4s, v2.4s}, [%[local_buffer]], #48 \n\t"
: [buffer] "+r"(buffer) : [local_buffer] "+r"(local_buffer)
: [b0] "r"(b0) : [b0] "r"(b0)
: "memory", "v0", "v1", "v2"); : "memory", "v0", "v1", "v2");
} }
} }
if (n_tail != 0) { if (n_tail != 0) {
float *local_buffer = buffer + j_length * k;
for (int i = 0; i < k; ++i) { for (int i = 0; i < k; ++i) {
b0 = &B(i, n - n_tail); const float *b0 = &B(i, j_length);
for (int j = n - n_tail; j < n; ++j) { for (int j = j_length; j < n; ++j) {
*buffer++ = *b0++; *local_buffer++ = *b0++;
} }
for (int j = n; j < n + (NR - n_tail); ++j) { for (int j = n; j < j_length + NR; ++j) {
*buffer++ = 0; *local_buffer++ = 0;
}
}
}
}
// Copies matrix B into `buffer` in panels of twelve columns, with the full
// panels handled by an OpenMP parallel loop. The body uses AArch64 assembly
// unconditionally.
//
// k       number of rows of B that are copied
// n       number of columns of B to pack
// n_tail  number of leftover columns after the full panels
// B, ldb  source matrix; ldb is not referenced directly in this body and is
//         presumably consumed by the B(i, j) macro defined outside this
//         listing — confirm
// buffer  destination; a panel stores, row after row, the twelve consecutive
//         values starting at B(i, j)
//
// NOTE(review): the loop advances by NR but always copies twelve values per
// row, so it presumably expects NR == 12 — confirm.
void PackMatrixB_omp_12c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer) {
const int j_length = n - n_tail;
// Each iteration derives its own write position from j (buffer + j * k)
// rather than advancing a shared pointer, so iterations are independent.
#pragma omp parallel for
for (int j = 0; j < j_length; j += NR) {
float *local_buffer = buffer + j * k;
for (int i = 0; i < k; ++i) {
const float *b0 = &B(i, j);
// Prefetch 64 bytes ahead of b0, load three 4-float vectors (12 floats),
// store them and post-increment local_buffer by 48 bytes.
asm volatile(
"prfm pldl2keep, [%[b0], #64] \n\t"
"ld1 {v0.4s, v1.4s, v2.4s}, [%[b0]] \n\t"
"st1 {v0.4s, v1.4s, v2.4s}, [%[local_buffer]], #48 \n\t"
: [local_buffer] "+r"(local_buffer)
: [b0] "r"(b0)
: "memory", "v0", "v1", "v2");
}
}
// Leftover columns: packed serially into the panel after the full ones.
if (n_tail != 0) {
float *local_buffer = buffer + j_length * k;
for (int i = 0; i < k; ++i) {
const float *b0 = &B(i, j_length);
// Copy the n_tail real values of this row ...
for (int j = j_length; j < n; ++j) {
*local_buffer++ = *b0++;
}
// ... then pad with zeros so the row occupies NR slots in the panel.
for (int j = n; j < j_length + NR; ++j) {
*local_buffer++ = 0;
} }
} }
} }
...@@ -330,27 +551,60 @@ void PackMatrixB_12c(int k, int n, int n_tail, const float *B, int ldb, ...@@ -330,27 +551,60 @@ void PackMatrixB_12c(int k, int n, int n_tail, const float *B, int ldb,
void PackMatrixB_16c(int k, int n, int n_tail, const float *B, int ldb, void PackMatrixB_16c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer) { float *buffer) {
const float *b0; const int j_length = n - n_tail;
for (int j = 0; j < n - n_tail; j += NR) { for (int j = 0; j < n - n_tail; j += NR) {
float *local_buffer = buffer + j * k;
for (int i = 0; i < k; ++i) { for (int i = 0; i < k; ++i) {
b0 = &B(i, j); const float *b0 = &B(i, j);
asm volatile( asm volatile(
"prfm pldl2keep, [%[b0], #64] \n\t" "prfm pldl2keep, [%[b0], #64] \n\t"
"ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[b0]] \n\t" "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[b0]] \n\t"
"st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[buffer]], #64 \n\t" "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[local_buffer]], #64 \n\t"
: [buffer] "+r"(buffer) : [local_buffer] "+r"(local_buffer)
: [b0] "r"(b0) : [b0] "r"(b0)
: "memory", "v0", "v1", "v2", "v3"); : "memory", "v0", "v1", "v2", "v3");
} }
} }
if (n_tail != 0) { if (n_tail != 0) {
float *local_buffer = buffer + j_length * k;
for (int i = 0; i < k; ++i) { for (int i = 0; i < k; ++i) {
b0 = &B(i, n - n_tail); const float *b0 = &B(i, j_length);
for (int j = n - n_tail; j < n; ++j) { for (int j = j_length; j < n; ++j) {
*buffer++ = *b0++; *local_buffer++ = *b0++;
} }
for (int j = n; j < n + (NR - n_tail); ++j) { for (int j = n; j < j_length + NR; ++j) {
*buffer++ = 0; *local_buffer++ = 0;
}
}
}
}
// Copies matrix B into `buffer` in panels of sixteen columns, with the full
// panels handled by an OpenMP parallel loop. The body uses AArch64 assembly
// unconditionally.
//
// k       number of rows of B that are copied
// n       number of columns of B to pack
// n_tail  number of leftover columns after the full panels
// B, ldb  source matrix; ldb is not referenced directly in this body and is
//         presumably consumed by the B(i, j) macro defined outside this
//         listing — confirm
// buffer  destination; a panel stores, row after row, the sixteen consecutive
//         values starting at B(i, j)
//
// NOTE(review): the loop advances by NR but always copies sixteen values per
// row, so it presumably expects NR == 16 — confirm.
void PackMatrixB_omp_16c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer) {
const int j_length = n - n_tail;
// Each iteration derives its own write position from j (buffer + j * k)
// rather than advancing a shared pointer, so iterations are independent.
// NOTE(review): the bound below is spelled `n - n_tail` while the sibling
// packers use `j_length`; the two are the same value.
#pragma omp parallel for
for (int j = 0; j < n - n_tail; j += NR) {
float *local_buffer = buffer + j * k;
for (int i = 0; i < k; ++i) {
const float *b0 = &B(i, j);
// Prefetch 64 bytes ahead of b0, load four 4-float vectors (16 floats),
// store them and post-increment local_buffer by 64 bytes.
asm volatile(
"prfm pldl2keep, [%[b0], #64] \n\t"
"ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[b0]] \n\t"
"st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[local_buffer]], #64 \n\t"
: [local_buffer] "+r"(local_buffer)
: [b0] "r"(b0)
: "memory", "v0", "v1", "v2", "v3");
}
}
// Leftover columns: packed serially into the panel after the full ones.
if (n_tail != 0) {
float *local_buffer = buffer + j_length * k;
for (int i = 0; i < k; ++i) {
const float *b0 = &B(i, j_length);
// Copy the n_tail real values of this row ...
for (int j = j_length; j < n; ++j) {
*local_buffer++ = *b0++;
}
// ... then pad with zeros so the row occupies NR slots in the panel.
for (int j = n; j < j_length + NR; ++j) {
*local_buffer++ = 0;
} }
} }
} }
...@@ -2244,6 +2498,27 @@ void AddDot4x4(int k, const float *a, const float *b, float *c, int ldc) { ...@@ -2244,6 +2498,27 @@ void AddDot4x4(int k, const float *a, const float *b, float *c, int ldc) {
} }
} }
// Placeholder definitions: every function below has an empty body and does
// nothing when called.
// NOTE(review): presumably these exist only so the symbols declared in the
// header resolve in a build configuration where the real kernels are not
// implemented -- confirm which preprocessor branch this is before relying on
// any of them.
void AddDot4x8(int k, const float *a, const float *b, float *c, int ldc) {}
void WriteBasic(int mc, int nc, float *c, float *C, int ldc) {}
void WriteWithAlphaBeta(int mc, int nc, float *c, float *C, int ldc) {}
void WriteWithAdd(int mc, int nc, float *c, float *C, int ldc) {}
void WriteWithAddV1(int mc, int nc, float *c, float *C, int ldc, float *bias) {}
void WriteWithAddRelu(int mc, int nc, float *c, float *C, int ldc) {}
void WriteWithAddReluV1(int mc, int nc, float *c, float *C, int ldc,
float *bias) {}
void WriteWithBn(int mc, int nc, float *c, float *C, int ldc, float *new_scale,
float *new_bias) {}
void WriteWithBnRelu(int mc, int nc, float *c, float *C, int ldc,
float *new_scale, float *new_bias) {}
#endif // __ARM_NEON #endif // __ARM_NEON
// 32位 float 矩阵乘法 // 32位 float 矩阵乘法
...@@ -2373,6 +2648,221 @@ void SgemmWithBn(int m, int n, int k, float alpha, const float *A, int lda, ...@@ -2373,6 +2648,221 @@ void SgemmWithBn(int m, int n, int k, float alpha, const float *A, int lda,
paddle_mobile::memory::Free(zero); paddle_mobile::memory::Free(zero);
} }
// 32位 float 矩阵乘法
void Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc,
bool relu, float *bias) {
#ifdef _OPENMP
int max_threads = omp_get_max_threads();
#else
int max_threads = 1;
#endif
int L1 = 32 * 1024;
KC = k;
if (m > n) {
// 对 A 分块
MC = L1 / (KC * sizeof(float));
int mblock_num = (m + MC - 1) / MC;
MC = (m + mblock_num - 1) / mblock_num;
MC = (MC + MR - 1) / MR * MR;
// 补齐 B
NC = (n + NR - 1) / NR * NR;
#if __aarch64__
procPackA = PackMatrixA_6r;
procPackB = PackMatrixB_omp_16c;
procAddDot = AddDot6x16;
#else
procPackA = PackMatrixA_6r;
procPackB = PackMatrixB_omp_8c;
procAddDot = AddDot6x8;
#endif
packedB = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * KC * NC));
procPackB(KC, NC, NC % NR, B, ldb, packedB);
packedA = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * MC * KC * max_threads));
} else {
// 对 B 分块
NC = L1 / (KC * sizeof(float));
int nblock_num = (n + NC - 1) / NC;
NC = (n + nblock_num - 1) / nblock_num;
NC = (NC + NR - 1) / NR * NR;
// 补齐 A
MC = (m + MR - 1) / MR * MR;
#if __aarch64__
procPackA = PackMatrixA_omp_6r;
procPackB = PackMatrixB_16c;
procAddDot = AddDot6x16;
#else
procPackA = PackMatrixA_omp_6r;
procPackB = PackMatrixB_8c;
procAddDot = AddDot6x8;
#endif
packedA = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * MC * KC));
procPackA(MC, KC, MC % MR, A, lda, packedA);
packedB = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * KC * NC * max_threads));
}
zero = static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * KC));
memset(static_cast<void *>(zero), 0, sizeof(float) * KC);
packedC = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * MC * NC * max_threads));
if (m > n) {
#pragma omp parallel for
for (int i = 0; i < m; i += MC) {
#ifdef _OPENMP
int local_threads = omp_get_thread_num();
#else
int local_threads = 0;
#endif
int mc;
mc = s_min(m - i, MC);
float *local_A = packedA + MC * KC * local_threads;
float *local_C = packedC + MC * NC * local_threads;
procPackA(mc, KC, mc % MR, &A(i, 0), lda, local_A);
InnerKernelWithBias(mc, n, alpha, local_A, packedB, beta, local_C,
&C(i, 0), ldc, relu, bias + i);
}
} else {
#pragma omp parallel for
for (int j = 0; j < n; j += NC) {
#ifdef _OPENMP
int local_threads = omp_get_thread_num();
#else
int local_threads = 0;
#endif
int nc;
nc = s_min(n - j, NC);
float *local_B = packedB + KC * NC * local_threads;
float *local_C = packedC + MC * NC * local_threads;
procPackB(KC, nc, nc % NR, &B(0, j), ldb, local_B);
InnerKernelWithBias(m, nc, alpha, packedA, local_B, beta, local_C,
&C(0, j), ldc, relu, bias);
}
}
paddle_mobile::memory::Free(packedA);
paddle_mobile::memory::Free(packedB);
paddle_mobile::memory::Free(packedC);
paddle_mobile::memory::Free(zero);
}
// 32-bit float matrix multiplication followed by batch normalization of the
// result (OpenMP multi-threaded version).
//
// Uses the same blocking scheme as Sgemm_omp:
//   * m > n : A is blocked by rows (MC rows per block); B is packed once up
//             front and shared by every thread.
//   * m <= n: B is blocked by columns (NC columns per block); A is packed
//             once up front and shared by every thread.
// Each thread packs its own block into a private slice of the scratch buffers
// (selected with omp_get_thread_num()) and then calls InnerKernelWithBn.
//
// alpha, beta, relu, new_scale and new_bias are forwarded to
// InnerKernelWithBn; in the m > n case new_scale and new_bias are advanced to
// the first row of the block.
void SgemmWithBn_omp(int m, int n, int k, float alpha, const float *A, int lda,
                     const float *B, int ldb, float beta, float *C, int ldc,
                     bool relu, float *new_scale, float *new_bias) {
#ifdef _OPENMP
  int max_threads = omp_get_max_threads();
#else
  int max_threads = 1;
#endif

  // Byte budget used to size one block of the blocked operand.
  int L1 = 32 * 1024;
  KC = k;
  if (m > n) {
    // Block A along its rows.
    MC = L1 / (KC * sizeof(float));
    if (MC == 0) {
      // A single row of A already exceeds the budget. Fall back to the
      // smallest legal block instead of dividing by zero below.
      MC = MR;
    } else {
      int mblock_num = (m + MC - 1) / MC;
      MC = (m + mblock_num - 1) / mblock_num;
      MC = (MC + MR - 1) / MR * MR;
    }
    // Round the width of B up to a whole number of NR-wide panels.
    NC = (n + NR - 1) / NR * NR;

#if __aarch64__
    procPackA = PackMatrixA_6r;
    procPackB = PackMatrixB_omp_16c;
    procAddDot = AddDot6x16;
#else
    procPackA = PackMatrixA_6r;
    procPackB = PackMatrixB_omp_8c;
    procAddDot = AddDot6x8;
#endif

    packedB = static_cast<float *>(
        paddle_mobile::memory::Alloc(sizeof(float) * KC * NC));
    // Pack with the real column count and its tail so the last partial panel
    // is zero-filled by the pack routine. Passing the padded NC (tail == 0)
    // would make it read NC - n floats past the end of every row of B.
    procPackB(KC, n, n % NR, B, ldb, packedB);
    // One private A block per thread.
    packedA = static_cast<float *>(
        paddle_mobile::memory::Alloc(sizeof(float) * MC * KC * max_threads));
  } else {
    // Block B along its columns.
    NC = L1 / (KC * sizeof(float));
    if (NC == 0) {
      // Same guard as above: avoid a division by zero for very large k.
      NC = NR;
    } else {
      int nblock_num = (n + NC - 1) / NC;
      NC = (n + nblock_num - 1) / nblock_num;
      NC = (NC + NR - 1) / NR * NR;
    }
    // Round the height of A up to a whole number of MR-high panels.
    MC = (m + MR - 1) / MR * MR;

#if __aarch64__
    procPackA = PackMatrixA_omp_6r;
    procPackB = PackMatrixB_16c;
    procAddDot = AddDot6x16;
#else
    procPackA = PackMatrixA_omp_6r;
    procPackB = PackMatrixB_8c;
    procAddDot = AddDot6x8;
#endif

    packedA = static_cast<float *>(
        paddle_mobile::memory::Alloc(sizeof(float) * MC * KC));
    // Pack the real row count with its tail (same convention as the per-block
    // procPackA call in the loop below) instead of the padded MC, which would
    // read rows that do not exist in A.
    procPackA(m, KC, m % MR, A, lda, packedA);
    // One private B block per thread.
    packedB = static_cast<float *>(
        paddle_mobile::memory::Alloc(sizeof(float) * KC * NC * max_threads));
  }
  zero = static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * KC));
  memset(static_cast<void *>(zero), 0, sizeof(float) * KC);
  // One private output tile per thread.
  packedC = static_cast<float *>(
      paddle_mobile::memory::Alloc(sizeof(float) * MC * NC * max_threads));

  if (m > n) {
#pragma omp parallel for
    for (int i = 0; i < m; i += MC) {
#ifdef _OPENMP
      int local_threads = omp_get_thread_num();
#else
      int local_threads = 0;
#endif

      int mc;
      mc = s_min(m - i, MC);
      float *local_A = packedA + MC * KC * local_threads;
      float *local_C = packedC + MC * NC * local_threads;
      procPackA(mc, KC, mc % MR, &A(i, 0), lda, local_A);
      InnerKernelWithBn(mc, n, alpha, local_A, packedB, beta, local_C, &C(i, 0),
                        ldc, relu, new_scale + i, new_bias + i);
    }
  } else {
#pragma omp parallel for
    for (int j = 0; j < n; j += NC) {
#ifdef _OPENMP
      int local_threads = omp_get_thread_num();
#else
      int local_threads = 0;
#endif

      int nc;
      nc = s_min(n - j, NC);
      float *local_B = packedB + KC * NC * local_threads;
      float *local_C = packedC + MC * NC * local_threads;
      procPackB(KC, nc, nc % NR, &B(0, j), ldb, local_B);
      InnerKernelWithBn(m, nc, alpha, packedA, local_B, beta, local_C, &C(0, j),
                        ldc, relu, new_scale, new_bias);
    }
  }

  paddle_mobile::memory::Free(packedA);
  paddle_mobile::memory::Free(packedB);
  paddle_mobile::memory::Free(packedC);
  paddle_mobile::memory::Free(zero);
}
void AddDot6x8(int k, const float *a, const float *b, float *c, int ldc) { void AddDot6x8(int k, const float *a, const float *b, float *c, int ldc) {
#if __ARM_NEON #if __ARM_NEON
#if __aarch64__ #if __aarch64__
......
...@@ -50,6 +50,10 @@ void PackMatrixA_6r(int m, int k, int m_tail, const float *A, int lda, ...@@ -50,6 +50,10 @@ void PackMatrixA_6r(int m, int k, int m_tail, const float *A, int lda,
float *buffer); float *buffer);
void PackMatrixA_8r(int m, int k, int m_tail, const float *A, int lda, void PackMatrixA_8r(int m, int k, int m_tail, const float *A, int lda,
float *buffer); float *buffer);
// Variants of the A-packing routines with the same parameter list as
// PackMatrixA_6r / PackMatrixA_8r.
// NOTE(review): presumably these parallelise the packing loop with OpenMP, as
// PackMatrixB_omp_16c does for B -- confirm against the definitions.
void PackMatrixA_omp_6r(int m, int k, int m_tail, const float *A, int lda,
float *buffer);
void PackMatrixA_omp_8r(int m, int k, int m_tail, const float *A, int lda,
float *buffer);
// 将 B 矩阵分块复制到连续内存(RowMajor) // 将 B 矩阵分块复制到连续内存(RowMajor)
void PackMatrixB_8c(int k, int n, int n_tail, const float *B, int ldb, void PackMatrixB_8c(int k, int n, int n_tail, const float *B, int ldb,
...@@ -58,6 +62,12 @@ void PackMatrixB_12c(int k, int n, int n_tail, const float *B, int ldb, ...@@ -58,6 +62,12 @@ void PackMatrixB_12c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer); float *buffer);
void PackMatrixB_16c(int k, int n, int n_tail, const float *B, int ldb, void PackMatrixB_16c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer); float *buffer);
// Variants of the B-packing routines with the same parameter list as
// PackMatrixB_8c / PackMatrixB_12c / PackMatrixB_16c.
// PackMatrixB_omp_16c runs its loop over full column panels under
// "#pragma omp parallel for" and zero-fills the last partial panel when
// n_tail is non-zero.
// NOTE(review): the 8c and 12c variants presumably follow the same pattern --
// confirm against their definitions.
void PackMatrixB_omp_8c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer);
void PackMatrixB_omp_12c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer);
void PackMatrixB_omp_16c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer);
// 分块矩阵乘法 // 分块矩阵乘法
void InnerKernel(int mc, int nc, float alpha, const float *a, const float *b, void InnerKernel(int mc, int nc, float alpha, const float *a, const float *b,
...@@ -136,6 +146,16 @@ void SgemmWithBn(int m, int n, int k, float alpha, const float *A, int lda, ...@@ -136,6 +146,16 @@ void SgemmWithBn(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc, const float *B, int ldb, float beta, float *C, int ldc,
bool relu, float *new_scale, float *new_bias); bool relu, float *new_scale, float *new_bias);
// 32-bit float matrix multiplication (OpenMP multi-threaded version).
void Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc,
bool relu, float *bias);
// 32-bit float matrix multiplication followed by batch normalization of the
// result (OpenMP multi-threaded version).
void SgemmWithBn_omp(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc,
bool relu, float *new_scale, float *new_bias);
} // namespace math } // namespace math
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
...@@ -42,8 +42,13 @@ void matmul<float>(const framework::Tensor &matrix_a, bool trans_a, ...@@ -42,8 +42,13 @@ void matmul<float>(const framework::Tensor &matrix_a, bool trans_a,
int N = dim_out[1]; int N = dim_out[1];
int K = (!trans_a) ? dim_a[1] : dim_a[0]; int K = (!trans_a) ? dim_a[1] : dim_a[0];
#ifdef _OPENMP
Sgemm_omp(M, N, K, alpha, matrix_a.data<float>(), K, matrix_b.data<float>(),
N, beta, matrix_out->data<float>(), N, relu, bias);
#else
Sgemm(M, N, K, alpha, matrix_a.data<float>(), K, matrix_b.data<float>(), N, Sgemm(M, N, K, alpha, matrix_a.data<float>(), K, matrix_b.data<float>(), N,
beta, matrix_out->data<float>(), N, relu, bias); beta, matrix_out->data<float>(), N, relu, bias);
#endif
} }
template <> template <>
...@@ -70,10 +75,17 @@ void matmulWithBn<float>(const framework::Tensor &matrix_a, bool trans_a, ...@@ -70,10 +75,17 @@ void matmulWithBn<float>(const framework::Tensor &matrix_a, bool trans_a,
int N = dim_out[1]; int N = dim_out[1];
int K = (!trans_a) ? dim_a[1] : dim_a[0]; int K = (!trans_a) ? dim_a[1] : dim_a[0];
#ifdef _OPENMP
SgemmWithBn_omp(M, N, K, alpha, matrix_a.data<float>(), K,
matrix_b.data<float>(), N, beta, matrix_out->data<float>(), N,
relu, new_scale->data<float>() + group,
new_bias->data<float>() + group);
#else
SgemmWithBn(M, N, K, alpha, matrix_a.data<float>(), K, matrix_b.data<float>(), SgemmWithBn(M, N, K, alpha, matrix_a.data<float>(), K, matrix_b.data<float>(),
N, beta, matrix_out->data<float>(), N, relu, N, beta, matrix_out->data<float>(), N, relu,
new_scale->data<float>() + group, new_scale->data<float>() + group,
new_bias->data<float>() + group); new_bias->data<float>() + group);
#endif
} }
} // namespace math } // namespace math
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册