Commit c2f25a82 authored by zhangyang

Merge remote-tracking branch 'upstream/develop' into develop

...@@ -26,8 +26,15 @@ Paddle-Moible is a project under the PaddlePaddle organization, dedicated to embedded…
 - **ARM CPU**

-![](http://mms-graph.bj.bcebos.com/paddle-mobile%2F2018_07_29.png)
+|mobilenet arm v7|1 thread|2 threads|4 threads|
+|------------|----|-----|-----|
+|Kirin 960 (ms)|110.586|72.474|49.833|
+|||||
+|mobilenetssd arm v7|1 thread|2 threads|4 threads|
+|Kirin 960 (ms)|224.464|142.544|96.068|
+|||||
+|googlenet(v1) arm v7|1 thread|2 threads|4 threads|
+|Kirin 960 (ms)|348.018|242.689|169.998|

 The ARM CPU is paddle-mobile's primary target, and the CPU's generality has always been its strength. Embedded deep learning needs a large amount of hand-written CPU assembly, and we are coding intensively to exploit every bit of acceleration the hardware offers.
 ARM CPU optimization is still in progress; only conventional CPU optimizations are applied so far. On an ARM A73, paddle-mobile arm-v7 currently runs a single-core MobileNet 1.0 inference in 110+ ms. This is clearly not the final goal: large parts are being rewritten in assembly, so there is still substantial headroom. Only armv7 is supported for now; armv8 support will follow.
......
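The NEON/assembly work referred to above is what the ARM kernels later in this diff build on. As a rough illustration only (not code from this commit; the function name and buffers are hypothetical), the sketch below shows the basic 4-lane load / multiply-accumulate / store pattern, using the same intrinsics (vld1q_f32, vmulq_n_f32, vmlaq_n_f32, vst1q_f32) the depthwise-conv kernel below relies on:

```cpp
#include <arm_neon.h>

// Hypothetical helper: out[i] = a[i] * w0 + b[i] * w1, for n a multiple of 4.
// Same vld1q / vmulq_n / vmlaq_n / vst1q pattern as the 3x3 depthwise kernel.
void scaled_add_f32(const float *a, const float *b, float w0, float w1,
                    float *out, int n) {
  for (int i = 0; i < n; i += 4) {
    float32x4_t va = vld1q_f32(a + i);      // load 4 floats from a
    float32x4_t vb = vld1q_f32(b + i);      // load 4 floats from b
    float32x4_t acc = vmulq_n_f32(va, w0);  // acc = a * w0
    acc = vmlaq_n_f32(acc, vb, w1);         // acc += b * w1
    vst1q_f32(out + i, acc);                // store 4 results
  }
}
```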
...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
+#include "cstring"
 #include "io/paddle_inference_api.h"
 namespace paddle_mobile {
......
...@@ -25,9 +25,9 @@ bool ElementwiseAddReluKernel<FPGA, float>::Init(
   const Tensor *input_x = param->InputX();
   const Tensor *input_y = param->InputY();
   Tensor *out = param->Out();
-  auto input_x_ptr = input_x->data<float>();
-  auto input_y_ptr = input_y->data<float>();
-  auto out_ptr = out->mutable_data<float>();
+  auto input_x_ptr = input_x->data<half>();
+  auto input_y_ptr = input_y->data<half>();
+  auto out_ptr = out->mutable_data<half>();
   fpga::EWAddArgs ewaddArgs;
   ewaddArgs.relu_enabled = relu_enabled;
......
...@@ -22,13 +22,13 @@ template <>
 bool FusionFcReluKernel<FPGA, float>::Init(FusionFcReluParam *param) {
   bool relu_enabled = true;
   const Tensor *input_x = param->InputX();
-  auto input_x_ptr = input_x->data<float>();
+  auto input_x_ptr = input_x->data<half>();
   const Tensor *input_y = param->InputY();
   auto input_y_ptr = input_y->data<float>();
   const Tensor *input_z = param->InputZ();
   auto input_z_ptr = input_z->data<float>();
   Tensor *out = param->Out();
-  auto out_ptr = out->mutable_data<float>();
+  auto out_ptr = out->mutable_data<half>();
   PADDLE_MOBILE_ENFORCE(input_x->dims()[1] == input_y->dims()[0],
                         "Image channel should be equal to weight number");
......
...@@ -22,13 +22,13 @@ template <>
 bool FusionFcKernel<FPGA, float>::Init(FusionFcParam *param) {
   bool relu_enabled = false;
   const Tensor *input_x = param->InputX();
-  auto input_x_ptr = input_x->data<float>();
+  auto input_x_ptr = input_x->data<half>();
   const Tensor *input_y = param->InputY();
   auto input_y_ptr = input_y->data<float>();
   const Tensor *input_z = param->InputZ();
   auto input_z_ptr = input_z->data<float>();
   Tensor *out = param->Out();
-  auto out_ptr = out->mutable_data<float>();
+  auto out_ptr = out->mutable_data<half>();
   PADDLE_MOBILE_ENFORCE(input_x->dims()[1] == input_y->dims()[0],
                         "Image channel should be equal to weight number");
......
...@@ -22,9 +22,9 @@ namespace operators {
 template <>
 bool PoolKernel<FPGA, float>::Init(PoolParam *param) {
   const Tensor *input = param->Input();
-  auto input_ptr = input->data<float>();
+  auto input_ptr = input->data<half>();
   Tensor *output = param->Output();
-  auto output_ptr = output->mutable_data<float>();
+  auto output_ptr = output->mutable_data<half>();
   vector<int> ksize = param->Ksize();
   vector<int> strides = param->Strides();
   vector<int> paddings = param->Paddings();
......
...@@ -529,42 +529,42 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter,
   const float *newscale_data = new_scale->data<float>();
   const float *newbias_data = new_bias->data<float>();
-  const int h = static_cast<int>(input->dims()[2]);
-  const int w = static_cast<int>(input->dims()[3]);
-  const int l = h;
   const int batch_size = static_cast<int>(input->dims()[0]);
-  const int c = static_cast<int>(input->dims()[1]);
-  const int hxw = h * w;
+  const int input_channel = static_cast<int>(input->dims()[1]);
+  const int input_height = static_cast<int>(input->dims()[2]);
+  const int input_width = static_cast<int>(input->dims()[3]);
+  const int output_height = static_cast<int>(output->dims()[2]);
+  const int output_width = static_cast<int>(output->dims()[3]);
+  const int hxw = input_height * input_width;
+  const int l = input_height;
   float32x4_t vnewbias = vdupq_n_f32(0.0);
   float32x4_t vnewscale = vdupq_n_f32(1.0);
   float32x4_t vzero = vdupq_n_f32(0);
-  for (int b = 0; b < batch_size; ++b) {
-    const float *filter_data_tmp = filter_data;
-    for (int j = 0; j < c; ++j) {
-      vnewbias = vdupq_n_f32(newbias_data[j]);
-      vnewscale = vdupq_n_f32(newscale_data[j]);
-      int l_mid = l - 2;  // l=1->l_mid=-1,l=2->l_mid=0
-      float w00 = filter_data_tmp[0];
-      float w01 = filter_data_tmp[1];
-      float w02 = filter_data_tmp[2];
-      float w10 = filter_data_tmp[3];
-      float w11 = filter_data_tmp[4];
-      float w12 = filter_data_tmp[5];
-      float w20 = filter_data_tmp[6];
-      float w21 = filter_data_tmp[7];
-      float w22 = filter_data_tmp[8];
+  for (int b = 0; b < batch_size; b++) {
+    filter_data = filter->data<float>();
+    for (int c = 0; c < input_channel; c++) {
+      vnewbias = vdupq_n_f32(newbias_data[c]);
+      vnewscale = vdupq_n_f32(newscale_data[c]);
+      float w00 = filter_data[0];
+      float w01 = filter_data[1];
+      float w02 = filter_data[2];
+      float w10 = filter_data[3];
+      float w11 = filter_data[4];
+      float w12 = filter_data[5];
+      float w20 = filter_data[6];
+      float w21 = filter_data[7];
+      float w22 = filter_data[8];
       output_data[0] = w11 * input_data[0] + w12 * input_data[1] +
                        w21 * input_data[l] + w22 * input_data[l + 1];
       output_data[l - 1] = w10 * input_data[l - 2] + w11 * input_data[l - 1] +
                            w20 * input_data[2 * l - 2] +
                            w21 * input_data[2 * l - 1];
       output_data[(l - 1) * l] =
           w01 * input_data[(l - 2) * l] + w02 * input_data[(l - 2) * l + 1] +
           w11 * input_data[(l - 1) * l] + w12 * input_data[(l - 1) * l + 1];
...@@ -572,13 +572,13 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter,
                                w01 * input_data[(l - 2) * (l + 1) + 1] +
                                w10 * input_data[l * l - 2] +
                                w11 * input_data[l * l - 1];
-      output_data[0] = output_data[0] * newscale_data[j] + newbias_data[j];
-      output_data[l - 1] =
-          output_data[l - 1] * newscale_data[j] + newbias_data[j];
-      output_data[(l - 1) * l] =
-          output_data[(l - 1) * l] * newscale_data[j] + newbias_data[j];
-      output_data[l * l - 1] =
-          output_data[l * l - 1] * newscale_data[j] + newbias_data[j];
+      output_data[0] = output_data[0] * newscale_data[c] + newbias_data[c];
+      output_data[l - 1] =
+          output_data[l - 1] * newscale_data[c] + newbias_data[c];
+      output_data[(l - 1) * l] =
+          output_data[(l - 1) * l] * newscale_data[c] + newbias_data[c];
+      output_data[l * l - 1] =
+          output_data[l * l - 1] * newscale_data[c] + newbias_data[c];
       if (if_relu) {
         output_data[0] = output_data[0] < 0 ? 0 : output_data[0];
...@@ -593,6 +593,7 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter,
             w01 * input_data[i * l - l] + w02 * input_data[i * l - l + 1] +
             w11 * input_data[i * l] + w12 * input_data[i * l + 1] +
             w21 * input_data[i * l + l] + w22 * input_data[i * l + l + 1];
+
         output_data[i * l + l - 1] = w00 * input_data[i * l + l - 1 - l - 1] +
                                      w01 * input_data[i * l + l - 1 - l] +
                                      w10 * input_data[i * l + l - 1 - 1] +
...@@ -600,9 +601,9 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter,
                                      w20 * input_data[i * l + l - 1 + l - 1] +
                                      w21 * input_data[i * l + l - 1 + l];
-        output_data[i * l] =
-            output_data[i * l] * newscale_data[j] + newbias_data[j];
-        output_data[i * l + l - 1] =
-            output_data[i * l + l - 1] * newscale_data[j] + newbias_data[j];
+        output_data[i * l] =
+            output_data[i * l] * newscale_data[c] + newbias_data[c];
+        output_data[i * l + l - 1] =
+            output_data[i * l + l - 1] * newscale_data[c] + newbias_data[c];
         if (if_relu) {
           output_data[i * l] = output_data[i * l] < 0 ? 0 : output_data[i * l];
...@@ -611,28 +612,19 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter,
         }
       }
-      // top 1 row and bottom 1 row
-      const float *input_tmp = input_data;
-      float32x4_t in0, in1, in2, in3, in4, in5, in6, in7, tmp0, tmp1, tmp2,
-          tmp3, tmp4, tmp5, out0;
-      in0 = vld1q_f32(input_tmp);
-      in2 = vld1q_f32(input_tmp + l);
-      const float *input_tmp_end = input_tmp + (l - 2) * l;
-      in4 = vld1q_f32(input_tmp_end);
-      in6 = vld1q_f32(input_tmp_end + l);
-      int c_mid = l_mid;
-      auto output_ptr = output_data + 1;
-      for (; c_mid > 3; c_mid -= 4) {
-        in1 = vld1q_f32(input_tmp + 4);
-        in3 = vld1q_f32(input_tmp + l + 4);
+      int m;
+      for (m = 1; m < output_width - 4; m += 4) {
+        float *output_ptr = output_data + m;
+        float32x4_t in0, in1, in2, in3, tmp0, tmp1, tmp2, tmp3, out0;
+        in0 = vld1q_f32(input_data + m - 1);
+        in1 = vld1q_f32(input_data + m + 3);
+        in2 = vld1q_f32(input_data + input_width + m - 1);
+        in3 = vld1q_f32(input_data + input_width + m + 3);
         tmp0 = vextq_f32(in0, in1, 1);
         tmp1 = vextq_f32(in0, in1, 2);
         tmp2 = vextq_f32(in2, in3, 1);
         tmp3 = vextq_f32(in2, in3, 2);
         out0 = vmulq_n_f32(in0, w10);
         out0 = vmlaq_n_f32(out0, tmp0, w11);
         out0 = vmlaq_n_f32(out0, tmp1, w12);
...@@ -644,182 +636,438 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter, ...@@ -644,182 +636,438 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter,
out0 = vmaxq_f32(out0, vzero); out0 = vmaxq_f32(out0, vzero);
} }
vst1q_f32(output_ptr, out0); vst1q_f32(output_ptr, out0);
}
for (m = 1; (m + 3) < output_width - 1; m = m + 4) {
}
for (int j = m; j < output_width - 1; j++) {
output_data[j] = input_data[j - 1] * w10 + input_data[j] * w11 +
input_data[j + 1] * w12 +
input_data[input_width + j - 1] * w20 +
input_data[input_width + j] * w21 +
input_data[input_width + j + 1] * w22;
output_data[j] = output_data[j] * newscale_data[c] + newbias_data[c];
in5 = vld1q_f32(input_tmp_end + 4); if (if_relu) {
in7 = vld1q_f32(input_tmp_end + l + 4); output_data[j] = output_data[j] < 0 ? 0 : output_data[j];
}
}
tmp0 = vextq_f32(in4, in5, 1); for (m = 1; (m + 3) < output_width - 1; m = m + 4) {
tmp1 = vextq_f32(in4, in5, 2); float *output_ptr =
tmp2 = vextq_f32(in6, in7, 1); output_data + (output_height - 1) * output_width + m;
tmp3 = vextq_f32(in6, in7, 2);
out0 = vmulq_n_f32(in4, w00); float32x4_t in0, in1, in2, in3, tmp0, tmp1, tmp2, tmp3, out0;
in0 = vld1q_f32(input_data + (output_height - 2) * input_width + m - 1);
in1 = vld1q_f32(input_data + (output_height - 2) * input_width + m + 3);
in2 = vld1q_f32(input_data + (output_height - 1) * input_width + m - 1);
in3 = vld1q_f32(input_data + (output_height - 1) * input_width + m + 3);
tmp0 = vextq_f32(in0, in1, 1);
tmp1 = vextq_f32(in0, in1, 2);
tmp2 = vextq_f32(in2, in3, 1);
tmp3 = vextq_f32(in2, in3, 2);
out0 = vmulq_n_f32(in0, w00);
out0 = vmlaq_n_f32(out0, tmp0, w01); out0 = vmlaq_n_f32(out0, tmp0, w01);
out0 = vmlaq_n_f32(out0, tmp1, w02); out0 = vmlaq_n_f32(out0, tmp1, w02);
out0 = vmlaq_n_f32(out0, in6, w10); out0 = vmlaq_n_f32(out0, in2, w10);
out0 = vmlaq_n_f32(out0, tmp2, w11); out0 = vmlaq_n_f32(out0, tmp2, w11);
out0 = vmlaq_n_f32(out0, tmp3, w12); out0 = vmlaq_n_f32(out0, tmp3, w12);
out0 = vmlaq_f32(vnewbias, vnewscale, out0); out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) { if (if_relu) {
out0 = vmaxq_f32(out0, vzero); out0 = vmaxq_f32(out0, vzero);
} }
vst1q_f32(output_ptr + (l - 1) * l, out0); vst1q_f32(output_ptr, out0);
// can optimize to each 8 stride.
input_tmp += 4;
input_tmp_end += 4;
output_ptr += 4;
in0 = in1;
in2 = in3;
in4 = in5;
in6 = in7;
} }
for (m = 1; (m + 3) < output_width - 1; m = m + 4) {
}
for (int j = m; j < output_width - 1; j++) {
output_data[(output_height - 1) * input_width + j] =
input_data[(output_height - 2) * input_width + j - 1] * w00 +
input_data[(output_height - 2) * input_width + j] * w01 +
input_data[(output_height - 2) * input_width + j + 1] * w02 +
input_data[(output_height - 1) * input_width + j - 1] * w10 +
input_data[(output_height - 1) * input_width + j] * w11 +
input_data[(output_height - 1) * input_width + j + 1] * w12;
output_data[(output_height - 1) * output_width + j] =
output_data[(output_height - 1) * output_width + j] *
newscale_data[c] +
newbias_data[c];
// top right pad if (if_relu) {
float32x4_t pad0 = vdupq_n_f32(input_data[l - 1]); output_data[(output_height - 1) * output_width + j] =
float32x4_t pad1 = vdupq_n_f32(input_data[2 * l - 1]); output_data[(output_height - 1) * output_width + j] < 0
? 0
tmp0 = vextq_f32(in0, pad0, 1); : output_data[(output_height - 1) * output_width + j];
tmp1 = vextq_f32(in0, pad0, 2); }
tmp2 = vextq_f32(in2, pad1, 1);
tmp3 = vextq_f32(in2, pad1, 2);
out0 = vmulq_n_f32(in0, w10);
out0 = vmlaq_n_f32(out0, tmp0, w11);
out0 = vmlaq_n_f32(out0, tmp1, w12);
out0 = vmlaq_n_f32(out0, in2, w20);
out0 = vmlaq_n_f32(out0, tmp2, w21);
out0 = vmlaq_n_f32(out0, tmp3, w22);
out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) {
out0 = vmaxq_f32(out0, vzero);
} }
for (int i = 0; i < c_mid; ++i) { #pragma omp parallel for
if (i == 0) { for (int i = 1; i < output_height - 1; i++) {
vst1q_lane_f32(output_ptr + i, out0, 0); for (int m = 1; (m + 3) < output_width - 1; m = m + 4) {
float *output_ptr = output_data + i * output_width + m;
float32x4_t in0, in1, in2, in3, in4, in5, tmp0, tmp1, tmp2, tmp3,
tmp4, tmp5, out0;
in0 = vld1q_f32(input_data + (i - 1) * input_width + m - 1);
in1 = vld1q_f32(input_data + (i - 1) * input_width + m + 3);
in2 = vld1q_f32(input_data + i * input_width + m - 1);
in3 = vld1q_f32(input_data + i * input_width + m + 3);
in4 = vld1q_f32(input_data + (i + 1) * input_width + m - 1);
in5 = vld1q_f32(input_data + (i + 1) * input_width + m + 3);
tmp0 = vextq_f32(in0, in1, 1);
tmp1 = vextq_f32(in0, in1, 2);
tmp2 = vextq_f32(in2, in3, 1);
tmp3 = vextq_f32(in2, in3, 2);
tmp4 = vextq_f32(in4, in5, 1);
tmp5 = vextq_f32(in4, in5, 2);
out0 = vmulq_n_f32(in0, w00);
out0 = vmlaq_n_f32(out0, tmp0, w01);
out0 = vmlaq_n_f32(out0, tmp1, w02);
out0 = vmlaq_n_f32(out0, in2, w10);
out0 = vmlaq_n_f32(out0, tmp2, w11);
out0 = vmlaq_n_f32(out0, tmp3, w12);
out0 = vmlaq_n_f32(out0, in4, w20);
out0 = vmlaq_n_f32(out0, tmp4, w21);
out0 = vmlaq_n_f32(out0, tmp5, w22);
out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) {
out0 = vmaxq_f32(out0, vzero);
}
vst1q_f32(output_ptr, out0);
} }
if (i == 1) { int m;
vst1q_lane_f32(output_ptr + i, out0, 1); for (m = 1; (m + 3) < output_width - 1; m = m + 4) {
} }
if (i == 2) {
vst1q_lane_f32(output_ptr + i, out0, 2); for (int j = m; j < output_width - 1; j++) {
output_data[i * output_width + j] =
input_data[(i - 1) * input_width + j - 1] * w00 +
input_data[(i - 1) * input_width + j] * w01 +
input_data[(i - 1) * input_width + j + 1] * w02 +
input_data[(i)*input_width + j - 1] * w10 +
input_data[(i)*input_width + j] * w11 +
input_data[(i)*input_width + j + 1] * w12 +
input_data[(i + 1) * input_width + j - 1] * w20 +
input_data[(i + 1) * input_width + j] * w21 +
input_data[(i + 1) * input_width + j + 1] * w22;
output_data[i * output_width + j] =
newscale_data[c] * output_data[i * output_width + j] +
newbias_data[c];
if (if_relu) {
output_data[i * output_width + j] =
output_data[i * output_width + j] < 0
? 0
: output_data[i * output_width + j];
}
} }
} }
// bottom right pad input_data = input_data + hxw;
float32x4_t pad2 = vdupq_n_f32(input_data[l * l - 1 - l]); output_data = output_data + hxw;
float32x4_t pad3 = vdupq_n_f32(input_data[l * l - 1]); filter_data = filter_data + 9;
}
}
tmp0 = vextq_f32(in4, pad2, 1); /*
tmp1 = vextq_f32(in4, pad2, 2); const float *input_data = input->data<float>();
tmp2 = vextq_f32(in6, pad3, 1); const float *filter_data = filter->data<float>();
tmp3 = vextq_f32(in6, pad3, 2); float *output_data = output->data<float>();
const float *newscale_data = new_scale->data<float>();
const float *newbias_data = new_bias->data<float>();
const int h = static_cast<int>(input->dims()[2]);
const int w = static_cast<int>(input->dims()[3]);
const int l = h;
const int batch_size = static_cast<int>(input->dims()[0]);
const int c = static_cast<int>(input->dims()[1]);
const int hxw = h * w;
float32x4_t vnewbias = vdupq_n_f32(0.0);
float32x4_t vnewscale = vdupq_n_f32(1.0);
float32x4_t vzero = vdupq_n_f32(0);
for (int b = 0; b < batch_size; ++b) {
const float *filter_data_tmp = filter_data;
for (int j = 0; j < c; ++j) {
vnewbias = vdupq_n_f32(newbias_data[j]);
vnewscale = vdupq_n_f32(newscale_data[j]);
int l_mid = l - 2; // l=1->l_mid=-1,l=2->l_mid=0
float w00 = filter_data_tmp[0];
float w01 = filter_data_tmp[1];
float w02 = filter_data_tmp[2];
float w10 = filter_data_tmp[3];
float w11 = filter_data_tmp[4];
float w12 = filter_data_tmp[5];
float w20 = filter_data_tmp[6];
float w21 = filter_data_tmp[7];
float w22 = filter_data_tmp[8];
output_data[0] = w11 * input_data[0] + w12 * input_data[1] +
w21 * input_data[l] + w22 * input_data[l + 1];
output_data[l - 1] = w10 * input_data[l - 2] + w11 * input_data[l - 1] +
w20 * input_data[2 * l - 2] +
w21 * input_data[2 * l - 1];
out0 = vmulq_n_f32(in4, w00); output_data[(l - 1) * l] =
out0 = vmlaq_n_f32(out0, tmp0, w01); w01 * input_data[(l - 2) * l] + w02 * input_data[(l - 2) * l + 1] +
out0 = vmlaq_n_f32(out0, tmp1, w02); w11 * input_data[(l - 1) * l] + w12 * input_data[(l - 1) * l + 1];
out0 = vmlaq_n_f32(out0, in6, w10); output_data[l * l - 1] = w00 * input_data[(l - 2) * (l + 1)] +
out0 = vmlaq_n_f32(out0, tmp2, w11); w01 * input_data[(l - 2) * (l + 1) + 1] +
out0 = vmlaq_n_f32(out0, tmp3, w12); w10 * input_data[l * l - 2] +
out0 = vmlaq_f32(vnewbias, vnewscale, out0); w11 * input_data[l * l - 1];
if (if_relu) { output_data[0] = output_data[0] * newscale_data[j] + newbias_data[j];
out0 = vmaxq_f32(out0, vzero); output_data[l - 1] =
} output_data[l - 1] * newscale_data[j] + newbias_data[j];
for (int i = 0; i < c_mid; ++i) { output_data[(l - 1) * l] =
if (i == 0) { output_data[(l - 1) * l] * newscale_data[j] + newbias_data[j];
vst1q_lane_f32(output_ptr + (l - 1) * l + i, out0, 0); output_data[l * l - 1] =
} output_data[l * l - 1] * newscale_data[j] + newbias_data[j];
if (i == 1) {
vst1q_lane_f32(output_ptr + (l - 1) * l + i, out0, 1); if (if_relu) {
output_data[0] = output_data[0] < 0 ? 0 : output_data[0];
output_data[l - 1] = output_data[l - 1] < 0 ? 0 : output_data[l - 1];
output_data[(l - 1) * l] =
output_data[(l - 1) * l] < 0 ? 0 : output_data[(l - 1) * l];
output_data[l * l - 1] =
output_data[l * l - 1] < 0 ? 0 : output_data[l * l - 1];
} }
if (i == 2) { for (int i = 1; i < l - 1; ++i) {
vst1q_lane_f32(output_ptr + (l - 1) * l + i, out0, 2); output_data[i * l] =
w01 * input_data[i * l - l] + w02 * input_data[i * l - l + 1] +
w11 * input_data[i * l] + w12 * input_data[i * l + 1] +
w21 * input_data[i * l + l] + w22 * input_data[i * l + l + 1];
output_data[i * l + l - 1] = w00 * input_data[i * l + l - 1 - l - 1] +
w01 * input_data[i * l + l - 1 - l] +
w10 * input_data[i * l + l - 1 - 1] +
w11 * input_data[i * l + l - 1] +
w20 * input_data[i * l + l - 1 + l - 1] +
w21 * input_data[i * l + l - 1 + l];
output_data[i * l] =
output_data[i * l] * newscale_data[j] + newbias_data[j];
output_data[i * l + l - 1] =
output_data[i * l + l - 1] * newscale_data[j] + newbias_data[j];
if (if_relu) {
output_data[i * l] = output_data[i * l] < 0 ? 0 : output_data[i *
l]; output_data[i * l + l - 1] =
output_data[i * l + l - 1] < 0 ? 0 : output_data[i * l + l - 1];
}
} }
}
// mid
for (int i = 0; i < l - 2; ++i) { // top 1 row and bottom 1 row
auto output_ptr = output_data + (i + 1) * l + 1; const float *input_tmp = input_data;
input_tmp = input_data + i * l;
auto in0_tmp = vld1q_f32(input_tmp); float32x4_t in0, in1, in2, in3, in4, in5, in6, in7, tmp0, tmp1, tmp2,
auto in2_tmp = vld1q_f32(input_tmp + l); tmp3, tmp4, tmp5, out0;
auto in4_tmp = vld1q_f32(input_tmp + l + l); in0 = vld1q_f32(input_tmp);
c_mid = l_mid; in2 = vld1q_f32(input_tmp + l);
const float *input_tmp_end = input_tmp + (l - 2) * l;
in4 = vld1q_f32(input_tmp_end);
in6 = vld1q_f32(input_tmp_end + l);
int c_mid = l_mid;
auto output_ptr = output_data + 1;
for (; c_mid > 3; c_mid -= 4) { for (; c_mid > 3; c_mid -= 4) {
auto in1_tmp = vld1q_f32(input_tmp + 4); in1 = vld1q_f32(input_tmp + 4);
auto in3_tmp = vld1q_f32(input_tmp + l + 4); in3 = vld1q_f32(input_tmp + l + 4);
auto in5_tmp = vld1q_f32(input_tmp + l + l + 4);
tmp0 = vextq_f32(in0_tmp, in1_tmp, 1); tmp0 = vextq_f32(in0, in1, 1);
tmp1 = vextq_f32(in0_tmp, in1_tmp, 2); tmp1 = vextq_f32(in0, in1, 2);
tmp2 = vextq_f32(in2_tmp, in3_tmp, 1);
tmp3 = vextq_f32(in2_tmp, in3_tmp, 2);
tmp4 = vextq_f32(in4_tmp, in5_tmp, 1);
tmp5 = vextq_f32(in4_tmp, in5_tmp, 2);
out0 = vmulq_n_f32(in0_tmp, w00); tmp2 = vextq_f32(in2, in3, 1);
tmp3 = vextq_f32(in2, in3, 2);
out0 = vmulq_n_f32(in0, w10);
out0 = vmlaq_n_f32(out0, tmp0, w11);
out0 = vmlaq_n_f32(out0, tmp1, w12);
out0 = vmlaq_n_f32(out0, in2, w20);
out0 = vmlaq_n_f32(out0, tmp2, w21);
out0 = vmlaq_n_f32(out0, tmp3, w22);
out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) {
out0 = vmaxq_f32(out0, vzero);
}
vst1q_f32(output_ptr, out0);
in5 = vld1q_f32(input_tmp_end + 4);
in7 = vld1q_f32(input_tmp_end + l + 4);
tmp0 = vextq_f32(in4, in5, 1);
tmp1 = vextq_f32(in4, in5, 2);
tmp2 = vextq_f32(in6, in7, 1);
tmp3 = vextq_f32(in6, in7, 2);
out0 = vmulq_n_f32(in4, w00);
out0 = vmlaq_n_f32(out0, tmp0, w01); out0 = vmlaq_n_f32(out0, tmp0, w01);
out0 = vmlaq_n_f32(out0, tmp1, w02); out0 = vmlaq_n_f32(out0, tmp1, w02);
out0 = vmlaq_n_f32(out0, in2_tmp, w10); out0 = vmlaq_n_f32(out0, in6, w10);
out0 = vmlaq_n_f32(out0, tmp2, w11); out0 = vmlaq_n_f32(out0, tmp2, w11);
out0 = vmlaq_n_f32(out0, tmp3, w12); out0 = vmlaq_n_f32(out0, tmp3, w12);
out0 = vmlaq_n_f32(out0, in4_tmp, w20);
out0 = vmlaq_n_f32(out0, tmp4, w21);
out0 = vmlaq_n_f32(out0, tmp5, w22);
out0 = vmlaq_f32(vnewbias, vnewscale, out0); out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) { if (if_relu) {
out0 = vmaxq_f32(out0, vzero); out0 = vmaxq_f32(out0, vzero);
} }
vst1q_f32(output_ptr, out0); vst1q_f32(output_ptr + (l - 1) * l, out0);
output_ptr += 4; // can optimize to each 8 stride.
input_tmp += 4; input_tmp += 4;
in0_tmp = in1_tmp; input_tmp_end += 4;
in2_tmp = in3_tmp; output_ptr += 4;
in4_tmp = in5_tmp; in0 = in1;
in2 = in3;
in4 = in5;
in6 = in7;
} }
float32x4_t pad0 = vdupq_n_f32(input_data[i * l + l - 1]); // top right pad
float32x4_t pad1 = vdupq_n_f32(input_data[i * l + l - 1 + l]); float32x4_t pad0 = vdupq_n_f32(input_data[l - 1]);
float32x4_t pad2 = vdupq_n_f32(input_data[i * l + l - 1 + l + l]); float32x4_t pad1 = vdupq_n_f32(input_data[2 * l - 1]);
tmp0 = vextq_f32(in0_tmp, pad0, 1); tmp0 = vextq_f32(in0, pad0, 1);
tmp1 = vextq_f32(in0_tmp, pad0, 2); tmp1 = vextq_f32(in0, pad0, 2);
tmp2 = vextq_f32(in2_tmp, pad1, 1); tmp2 = vextq_f32(in2, pad1, 1);
tmp3 = vextq_f32(in2_tmp, pad1, 2); tmp3 = vextq_f32(in2, pad1, 2);
tmp4 = vextq_f32(in4_tmp, pad2, 1);
tmp5 = vextq_f32(in4_tmp, pad2, 2);
out0 = vmulq_n_f32(in0_tmp, w00); out0 = vmulq_n_f32(in0, w10);
out0 = vmlaq_n_f32(out0, tmp0, w11);
out0 = vmlaq_n_f32(out0, tmp1, w12);
out0 = vmlaq_n_f32(out0, in2, w20);
out0 = vmlaq_n_f32(out0, tmp2, w21);
out0 = vmlaq_n_f32(out0, tmp3, w22);
out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) {
out0 = vmaxq_f32(out0, vzero);
}
for (int i = 0; i < c_mid; ++i) {
if (i == 0) {
vst1q_lane_f32(output_ptr + i, out0, 0);
}
if (i == 1) {
vst1q_lane_f32(output_ptr + i, out0, 1);
}
if (i == 2) {
vst1q_lane_f32(output_ptr + i, out0, 2);
}
}
// bottom right pad
float32x4_t pad2 = vdupq_n_f32(input_data[l * l - 1 - l]);
float32x4_t pad3 = vdupq_n_f32(input_data[l * l - 1]);
tmp0 = vextq_f32(in4, pad2, 1);
tmp1 = vextq_f32(in4, pad2, 2);
tmp2 = vextq_f32(in6, pad3, 1);
tmp3 = vextq_f32(in6, pad3, 2);
out0 = vmulq_n_f32(in4, w00);
out0 = vmlaq_n_f32(out0, tmp0, w01); out0 = vmlaq_n_f32(out0, tmp0, w01);
out0 = vmlaq_n_f32(out0, tmp1, w02); out0 = vmlaq_n_f32(out0, tmp1, w02);
out0 = vmlaq_n_f32(out0, in2_tmp, w10); out0 = vmlaq_n_f32(out0, in6, w10);
out0 = vmlaq_n_f32(out0, tmp2, w11); out0 = vmlaq_n_f32(out0, tmp2, w11);
out0 = vmlaq_n_f32(out0, tmp3, w12); out0 = vmlaq_n_f32(out0, tmp3, w12);
out0 = vmlaq_n_f32(out0, in4_tmp, w20);
out0 = vmlaq_n_f32(out0, tmp4, w21);
out0 = vmlaq_n_f32(out0, tmp5, w22);
out0 = vmlaq_f32(vnewbias, vnewscale, out0); out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) { if (if_relu) {
out0 = vmaxq_f32(out0, vzero); out0 = vmaxq_f32(out0, vzero);
} }
for (int i = 0; i < c_mid; ++i) { for (int i = 0; i < c_mid; ++i) {
if (i == 0) { if (i == 0) {
vst1q_lane_f32(output_ptr + i, out0, 0); vst1q_lane_f32(output_ptr + (l - 1) * l + i, out0, 0);
} }
if (i == 1) { if (i == 1) {
vst1q_lane_f32(output_ptr + i, out0, 1); vst1q_lane_f32(output_ptr + (l - 1) * l + i, out0, 1);
} }
if (i == 2) { if (i == 2) {
vst1q_lane_f32(output_ptr + i, out0, 2); vst1q_lane_f32(output_ptr + (l - 1) * l + i, out0, 2);
}
}
// mid
for (int i = 0; i < l - 2; ++i) {
auto output_ptr = output_data + (i + 1) * l + 1;
input_tmp = input_data + i * l;
auto in0_tmp = vld1q_f32(input_tmp);
auto in2_tmp = vld1q_f32(input_tmp + l);
auto in4_tmp = vld1q_f32(input_tmp + l + l);
c_mid = l_mid;
for (; c_mid > 3; c_mid -= 4) {
auto in1_tmp = vld1q_f32(input_tmp + 4);
auto in3_tmp = vld1q_f32(input_tmp + l + 4);
auto in5_tmp = vld1q_f32(input_tmp + l + l + 4);
tmp0 = vextq_f32(in0_tmp, in1_tmp, 1);
tmp1 = vextq_f32(in0_tmp, in1_tmp, 2);
tmp2 = vextq_f32(in2_tmp, in3_tmp, 1);
tmp3 = vextq_f32(in2_tmp, in3_tmp, 2);
tmp4 = vextq_f32(in4_tmp, in5_tmp, 1);
tmp5 = vextq_f32(in4_tmp, in5_tmp, 2);
out0 = vmulq_n_f32(in0_tmp, w00);
out0 = vmlaq_n_f32(out0, tmp0, w01);
out0 = vmlaq_n_f32(out0, tmp1, w02);
out0 = vmlaq_n_f32(out0, in2_tmp, w10);
out0 = vmlaq_n_f32(out0, tmp2, w11);
out0 = vmlaq_n_f32(out0, tmp3, w12);
out0 = vmlaq_n_f32(out0, in4_tmp, w20);
out0 = vmlaq_n_f32(out0, tmp4, w21);
out0 = vmlaq_n_f32(out0, tmp5, w22);
out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) {
out0 = vmaxq_f32(out0, vzero);
}
vst1q_f32(output_ptr, out0);
output_ptr += 4;
input_tmp += 4;
in0_tmp = in1_tmp;
in2_tmp = in3_tmp;
in4_tmp = in5_tmp;
}
float32x4_t pad0 = vdupq_n_f32(input_data[i * l + l - 1]);
float32x4_t pad1 = vdupq_n_f32(input_data[i * l + l - 1 + l]);
float32x4_t pad2 = vdupq_n_f32(input_data[i * l + l - 1 + l + l]);
tmp0 = vextq_f32(in0_tmp, pad0, 1);
tmp1 = vextq_f32(in0_tmp, pad0, 2);
tmp2 = vextq_f32(in2_tmp, pad1, 1);
tmp3 = vextq_f32(in2_tmp, pad1, 2);
tmp4 = vextq_f32(in4_tmp, pad2, 1);
tmp5 = vextq_f32(in4_tmp, pad2, 2);
out0 = vmulq_n_f32(in0_tmp, w00);
out0 = vmlaq_n_f32(out0, tmp0, w01);
out0 = vmlaq_n_f32(out0, tmp1, w02);
out0 = vmlaq_n_f32(out0, in2_tmp, w10);
out0 = vmlaq_n_f32(out0, tmp2, w11);
out0 = vmlaq_n_f32(out0, tmp3, w12);
out0 = vmlaq_n_f32(out0, in4_tmp, w20);
out0 = vmlaq_n_f32(out0, tmp4, w21);
out0 = vmlaq_n_f32(out0, tmp5, w22);
out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) {
out0 = vmaxq_f32(out0, vzero);
}
for (int i = 0; i < c_mid; ++i) {
if (i == 0) {
vst1q_lane_f32(output_ptr + i, out0, 0);
}
if (i == 1) {
vst1q_lane_f32(output_ptr + i, out0, 1);
}
if (i == 2) {
vst1q_lane_f32(output_ptr + i, out0, 2);
}
} }
} }
output_data += hxw;
input_data += hxw;
filter_data_tmp += 9;
} }
output_data += hxw;
input_data += hxw;
filter_data_tmp += 9;
} }
} */
#endif
}
......
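To help untangle the vectorized diff above: per output element the kernel computes a 3x3 depthwise convolution, then the fused batch-norm affine transform new_scale[c] * x + new_bias[c], then an optional ReLU clamp. The plain scalar sketch below is a reference for that contract on a single channel (my own illustration, not code from this commit), assuming padding 1 and stride 1 so input and output share the same height and width:

```cpp
#include <algorithm>

// Scalar reference for one channel of DepthwiseConvAddBNRelu3x3s1p1
// (pad = 1, stride = 1, so the output has the input's height and width).
void depthwise3x3_bn_relu_ref(const float *in, const float *w9, float scale,
                              float bias, bool relu, int height, int width,
                              float *out) {
  for (int oh = 0; oh < height; ++oh) {
    for (int ow = 0; ow < width; ++ow) {
      float acc = 0.f;
      for (int kh = 0; kh < 3; ++kh) {
        for (int kw = 0; kw < 3; ++kw) {
          int ih = oh + kh - 1;  // pad = 1
          int iw = ow + kw - 1;
          if (ih >= 0 && ih < height && iw >= 0 && iw < width) {
            acc += w9[kh * 3 + kw] * in[ih * width + iw];
          }
        }
      }
      acc = acc * scale + bias;  // fused batch-norm affine transform
      out[oh * width + ow] = relu ? std::max(acc, 0.f) : acc;
    }
  }
}
```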
...@@ -33,6 +33,14 @@ float *packedA;
 float *packedB;
 float *packedC;
 float *zero;
+
+typedef void (*FnPack)(int, int, int, const float *, int, float *);
+typedef void (*FnAddDot)(int, const float *, const float *, float *, int);
+
+FnPack procPackA;
+FnPack procPackB;
+FnAddDot procAddDot;
+
 /*
 // Copy blocks of matrix A into contiguous memory (ColMajor)
 void PackMatrixA(int m, int k, int m_tail, const float *A, int lda,
...@@ -135,30 +143,32 @@ void PackMatrixA_4r(int m, int k, int m_tail, const float *A, int lda,
void PackMatrixA_6r(int m, int k, int m_tail, const float *A, int lda, void PackMatrixA_6r(int m, int k, int m_tail, const float *A, int lda,
float *buffer) { float *buffer) {
const float *a0, *a1, *a2, *a3, *a4, *a5; const int i_length = m - m_tail;
for (int i = 0; i < m - m_tail; i += MR) { for (int i = 0; i < i_length; i += MR) {
a0 = A + i * lda; const float *a0 = A + i * lda;
a1 = A + (i + 1) * lda; const float *a1 = A + (i + 1) * lda;
a2 = A + (i + 2) * lda; const float *a2 = A + (i + 2) * lda;
a3 = A + (i + 3) * lda; const float *a3 = A + (i + 3) * lda;
a4 = A + (i + 4) * lda; const float *a4 = A + (i + 4) * lda;
a5 = A + (i + 5) * lda; const float *a5 = A + (i + 5) * lda;
float *local_buffer = buffer + i * k;
for (int j = 0; j < k; ++j) { for (int j = 0; j < k; ++j) {
*buffer++ = *a0++; *local_buffer++ = *a0++;
*buffer++ = *a1++; *local_buffer++ = *a1++;
*buffer++ = *a2++; *local_buffer++ = *a2++;
*buffer++ = *a3++; *local_buffer++ = *a3++;
*buffer++ = *a4++; *local_buffer++ = *a4++;
*buffer++ = *a5++; *local_buffer++ = *a5++;
} }
} }
if (m_tail != 0) { if (m_tail != 0) {
a0 = &A(m - m_tail, 0); const float *a0 = &A(i_length, 0);
a1 = a0 + lda; const float *a1 = a0 + lda;
a2 = a0 + 2 * lda; const float *a2 = a0 + 2 * lda;
a3 = a0 + 3 * lda; const float *a3 = a0 + 3 * lda;
a4 = a0 + 4 * lda; const float *a4 = a0 + 4 * lda;
a5 = a0 + 5 * lda; const float *a5 = a0 + 5 * lda;
float *local_buffer = buffer + i_length * k;
switch (m_tail) { switch (m_tail) {
case 1: case 1:
a1 = zero; a1 = zero;
...@@ -175,48 +185,105 @@ void PackMatrixA_6r(int m, int k, int m_tail, const float *A, int lda,
break; break;
} }
for (int j = 0; j < k; ++j) { for (int j = 0; j < k; ++j) {
*buffer++ = *a0++; *local_buffer++ = *a0++;
*buffer++ = *a1++; *local_buffer++ = *a1++;
*buffer++ = *a2++; *local_buffer++ = *a2++;
*buffer++ = *a3++; *local_buffer++ = *a3++;
*buffer++ = *a4++; *local_buffer++ = *a4++;
*buffer++ = *a5++; *local_buffer++ = *a5++;
}
}
}
void PackMatrixA_omp_6r(int m, int k, int m_tail, const float *A, int lda,
float *buffer) {
const int i_length = m - m_tail;
#pragma omp parallel for
for (int i = 0; i < i_length; i += MR) {
const float *a0 = A + i * lda;
const float *a1 = A + (i + 1) * lda;
const float *a2 = A + (i + 2) * lda;
const float *a3 = A + (i + 3) * lda;
const float *a4 = A + (i + 4) * lda;
const float *a5 = A + (i + 5) * lda;
float *local_buffer = buffer + i * k;
for (int j = 0; j < k; ++j) {
*local_buffer++ = *a0++;
*local_buffer++ = *a1++;
*local_buffer++ = *a2++;
*local_buffer++ = *a3++;
*local_buffer++ = *a4++;
*local_buffer++ = *a5++;
}
}
if (m_tail != 0) {
const float *a0 = &A(i_length, 0);
const float *a1 = a0 + lda;
const float *a2 = a0 + 2 * lda;
const float *a3 = a0 + 3 * lda;
const float *a4 = a0 + 4 * lda;
const float *a5 = a0 + 5 * lda;
float *local_buffer = buffer + i_length * k;
switch (m_tail) {
case 1:
a1 = zero;
case 2:
a2 = zero;
case 3:
a3 = zero;
case 4:
a4 = zero;
case 5:
a5 = zero;
break;
default:
break;
}
for (int j = 0; j < k; ++j) {
*local_buffer++ = *a0++;
*local_buffer++ = *a1++;
*local_buffer++ = *a2++;
*local_buffer++ = *a3++;
*local_buffer++ = *a4++;
*local_buffer++ = *a5++;
} }
} }
} }
void PackMatrixA_8r(int m, int k, int m_tail, const float *A, int lda, void PackMatrixA_8r(int m, int k, int m_tail, const float *A, int lda,
float *buffer) { float *buffer) {
const float *a0, *a1, *a2, *a3, *a4, *a5, *a6, *a7; const int i_length = m - m_tail;
for (int i = 0; i < m - m_tail; i += MR) { for (int i = 0; i < i_length; i += MR) {
a0 = A + i * lda; const float *a0 = A + i * lda;
a1 = A + (i + 1) * lda; const float *a1 = A + (i + 1) * lda;
a2 = A + (i + 2) * lda; const float *a2 = A + (i + 2) * lda;
a3 = A + (i + 3) * lda; const float *a3 = A + (i + 3) * lda;
a4 = A + (i + 4) * lda; const float *a4 = A + (i + 4) * lda;
a5 = A + (i + 5) * lda; const float *a5 = A + (i + 5) * lda;
a6 = A + (i + 6) * lda; const float *a6 = A + (i + 6) * lda;
a7 = A + (i + 7) * lda; const float *a7 = A + (i + 7) * lda;
float *local_buffer = buffer + i * k;
for (int j = 0; j < k; ++j) { for (int j = 0; j < k; ++j) {
*buffer++ = *a0++; *local_buffer++ = *a0++;
*buffer++ = *a1++; *local_buffer++ = *a1++;
*buffer++ = *a2++; *local_buffer++ = *a2++;
*buffer++ = *a3++; *local_buffer++ = *a3++;
*buffer++ = *a4++; *local_buffer++ = *a4++;
*buffer++ = *a5++; *local_buffer++ = *a5++;
*buffer++ = *a6++; *local_buffer++ = *a6++;
*buffer++ = *a7++; *local_buffer++ = *a7++;
} }
} }
if (m_tail != 0) { if (m_tail != 0) {
a0 = &A(m - m_tail, 0); const float *a0 = &A(i_length, 0);
a1 = a0 + lda; const float *a1 = a0 + lda;
a2 = a0 + 2 * lda; const float *a2 = a0 + 2 * lda;
a3 = a0 + 3 * lda; const float *a3 = a0 + 3 * lda;
a4 = a0 + 4 * lda; const float *a4 = a0 + 4 * lda;
a5 = a0 + 5 * lda; const float *a5 = a0 + 5 * lda;
a6 = a0 + 6 * lda; const float *a6 = a0 + 6 * lda;
a7 = a0 + 7 * lda; const float *a7 = a0 + 7 * lda;
float *local_buffer = buffer + i_length * k;
switch (m_tail) { switch (m_tail) {
case 1: case 1:
a1 = zero; a1 = zero;
...@@ -237,14 +304,81 @@ void PackMatrixA_8r(int m, int k, int m_tail, const float *A, int lda,
break; break;
} }
for (int j = 0; j < k; ++j) { for (int j = 0; j < k; ++j) {
*buffer++ = *a0++; *local_buffer++ = *a0++;
*buffer++ = *a1++; *local_buffer++ = *a1++;
*buffer++ = *a2++; *local_buffer++ = *a2++;
*buffer++ = *a3++; *local_buffer++ = *a3++;
*buffer++ = *a4++; *local_buffer++ = *a4++;
*buffer++ = *a5++; *local_buffer++ = *a5++;
*buffer++ = *a6++; *local_buffer++ = *a6++;
*buffer++ = *a7++; *local_buffer++ = *a7++;
}
}
}
void PackMatrixA_omp_8r(int m, int k, int m_tail, const float *A, int lda,
float *buffer) {
const int i_length = m - m_tail;
#pragma omp parallel for
for (int i = 0; i < i_length; i += MR) {
const float *a0 = A + i * lda;
const float *a1 = A + (i + 1) * lda;
const float *a2 = A + (i + 2) * lda;
const float *a3 = A + (i + 3) * lda;
const float *a4 = A + (i + 4) * lda;
const float *a5 = A + (i + 5) * lda;
const float *a6 = A + (i + 6) * lda;
const float *a7 = A + (i + 7) * lda;
float *local_buffer = buffer + i * k;
for (int j = 0; j < k; ++j) {
*local_buffer++ = *a0++;
*local_buffer++ = *a1++;
*local_buffer++ = *a2++;
*local_buffer++ = *a3++;
*local_buffer++ = *a4++;
*local_buffer++ = *a5++;
*local_buffer++ = *a6++;
*local_buffer++ = *a7++;
}
}
if (m_tail != 0) {
const float *a0 = &A(i_length, 0);
const float *a1 = a0 + lda;
const float *a2 = a0 + 2 * lda;
const float *a3 = a0 + 3 * lda;
const float *a4 = a0 + 4 * lda;
const float *a5 = a0 + 5 * lda;
const float *a6 = a0 + 6 * lda;
const float *a7 = a0 + 7 * lda;
float *local_buffer = buffer + i_length * k;
switch (m_tail) {
case 1:
a1 = zero;
case 2:
a2 = zero;
case 3:
a3 = zero;
case 4:
a4 = zero;
case 5:
a5 = zero;
case 6:
a6 = zero;
case 7:
a7 = zero;
break;
default:
break;
}
for (int j = 0; j < k; ++j) {
*local_buffer++ = *a0++;
*local_buffer++ = *a1++;
*local_buffer++ = *a2++;
*local_buffer++ = *a3++;
*local_buffer++ = *a4++;
*local_buffer++ = *a5++;
*local_buffer++ = *a6++;
*local_buffer++ = *a7++;
} }
} }
} }
...@@ -252,48 +386,102 @@ void PackMatrixA_8r(int m, int k, int m_tail, const float *A, int lda,
// Copy blocks of matrix B into contiguous memory (RowMajor)
void PackMatrixB_8c(int k, int n, int n_tail, const float *B, int ldb, void PackMatrixB_8c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer) { float *buffer) {
const float *b0; const int j_length = n - n_tail;
for (int j = 0; j < n - n_tail; j += NR) { for (int j = 0; j < j_length; j += NR) {
float *local_buffer = buffer + j * k;
for (int i = 0; i < k; ++i) { for (int i = 0; i < k; ++i) {
b0 = &B(i, j); const float *b0 = &B(i, j);
#if __ARM_NEON #if __ARM_NEON
#if __aarch64__ #if __aarch64__
asm volatile( asm volatile(
"prfm pldl1keep, [%[b0]] \n\t" "prfm pldl1keep, [%[b0]] \n\t"
"ld1 {v0.4s, v1.4s}, [%[b0]] \n\t" "ld1 {v0.4s, v1.4s}, [%[b0]] \n\t"
"st1 {v0.4s, v1.4s}, [%[buffer]], #32 \n\t" "st1 {v0.4s, v1.4s}, [%[local_buffer]], #32 \n\t"
: [buffer] "+r"(buffer) : [local_buffer] "+r"(local_buffer)
: [b0] "r"(b0) : [b0] "r"(b0)
: "memory", "v0", "v1"); : "memory", "v0", "v1");
#else #else
asm volatile( asm volatile(
"pld [%[b0]] \n\t" "pld [%[b0]] \n\t"
"vld1.32 {q0, q1}, [%[b0]] \n\t" "vld1.32 {q0, q1}, [%[b0]] \n\t"
"vst1.32 {q0, q1}, [%[buffer]]! \n\t" "vst1.32 {q0, q1}, [%[local_buffer]]! \n\t"
: [buffer] "+r"(buffer) : [local_buffer] "+r"(local_buffer)
: [b0] "r"(b0) : [b0] "r"(b0)
: "memory", "q0", "q1"); : "memory", "q0", "q1");
#endif // __aarch64__ #endif // __aarch64__
#else #else
*buffer++ = *b0++; *local_buffer++ = *b0++;
*buffer++ = *b0++; *local_buffer++ = *b0++;
*buffer++ = *b0++; *local_buffer++ = *b0++;
*buffer++ = *b0++; *local_buffer++ = *b0++;
*buffer++ = *b0++; *local_buffer++ = *b0++;
*buffer++ = *b0++; *local_buffer++ = *b0++;
*buffer++ = *b0++; *local_buffer++ = *b0++;
*buffer++ = *b0++; *local_buffer++ = *b0++;
#endif // __ARM_NEON #endif // __ARM_NEON
} }
} }
if (n_tail != 0) { if (n_tail != 0) {
float *local_buffer = buffer + j_length * k;
for (int i = 0; i < k; ++i) { for (int i = 0; i < k; ++i) {
b0 = &B(i, n - n_tail); const float *b0 = &B(i, j_length);
for (int j = n - n_tail; j < n; ++j) { for (int j = j_length; j < n; ++j) {
*buffer++ = *b0++; *local_buffer++ = *b0++;
} }
for (int j = n; j < n + (NR - n_tail); ++j) { for (int j = n; j < j_length + NR; ++j) {
*buffer++ = 0; *local_buffer++ = 0;
}
}
}
}
void PackMatrixB_omp_8c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer) {
const int j_length = n - n_tail;
#pragma omp parallel for
for (int j = 0; j < j_length; j += NR) {
float *local_buffer = buffer + j * k;
for (int i = 0; i < k; ++i) {
const float *b0 = &B(i, j);
#if __ARM_NEON
#if __aarch64__
asm volatile(
"prfm pldl1keep, [%[b0]] \n\t"
"ld1 {v0.4s, v1.4s}, [%[b0]] \n\t"
"st1 {v0.4s, v1.4s}, [%[local_buffer]], #32 \n\t"
: [local_buffer] "+r"(local_buffer)
: [b0] "r"(b0)
: "memory", "v0", "v1");
#else
asm volatile(
"pld [%[b0]] \n\t"
"vld1.32 {q0, q1}, [%[b0]] \n\t"
"vst1.32 {q0, q1}, [%[local_buffer]]! \n\t"
: [local_buffer] "+r"(local_buffer)
: [b0] "r"(b0)
: "memory", "q0", "q1");
#endif // __aarch64__
#else
*local_buffer++ = *b0++;
*local_buffer++ = *b0++;
*local_buffer++ = *b0++;
*local_buffer++ = *b0++;
*local_buffer++ = *b0++;
*local_buffer++ = *b0++;
*local_buffer++ = *b0++;
*local_buffer++ = *b0++;
#endif // __ARM_NEON
}
}
if (n_tail != 0) {
float *local_buffer = buffer + j_length * k;
for (int i = 0; i < k; ++i) {
const float *b0 = &B(i, j_length);
for (int j = j_length; j < n; ++j) {
*local_buffer++ = *b0++;
}
for (int j = n; j < j_length + NR; ++j) {
*local_buffer++ = 0;
} }
} }
} }
...@@ -302,27 +490,60 @@ void PackMatrixB_8c(int k, int n, int n_tail, const float *B, int ldb,
#if __aarch64__
void PackMatrixB_12c(int k, int n, int n_tail, const float *B, int ldb, void PackMatrixB_12c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer) { float *buffer) {
const float *b0; const int j_length = n - n_tail;
for (int j = 0; j < n - n_tail; j += NR) { for (int j = 0; j < j_length; j += NR) {
float *local_buffer = buffer + j * k;
for (int i = 0; i < k; ++i) { for (int i = 0; i < k; ++i) {
b0 = &B(i, j); const float *b0 = &B(i, j);
asm volatile( asm volatile(
"prfm pldl2keep, [%[b0], #64] \n\t" "prfm pldl2keep, [%[b0], #64] \n\t"
"ld1 {v0.4s, v1.4s, v2.4s}, [%[b0]] \n\t" "ld1 {v0.4s, v1.4s, v2.4s}, [%[b0]] \n\t"
"st1 {v0.4s, v1.4s, v2.4s}, [%[buffer]], #48 \n\t" "st1 {v0.4s, v1.4s, v2.4s}, [%[local_buffer]], #48 \n\t"
: [buffer] "+r"(buffer) : [local_buffer] "+r"(local_buffer)
: [b0] "r"(b0) : [b0] "r"(b0)
: "memory", "v0", "v1", "v2"); : "memory", "v0", "v1", "v2");
} }
} }
if (n_tail != 0) { if (n_tail != 0) {
float *local_buffer = buffer + j_length * k;
for (int i = 0; i < k; ++i) { for (int i = 0; i < k; ++i) {
b0 = &B(i, n - n_tail); const float *b0 = &B(i, j_length);
for (int j = n - n_tail; j < n; ++j) { for (int j = j_length; j < n; ++j) {
*buffer++ = *b0++; *local_buffer++ = *b0++;
} }
for (int j = n; j < n + (NR - n_tail); ++j) { for (int j = n; j < j_length + NR; ++j) {
*buffer++ = 0; *local_buffer++ = 0;
}
}
}
}
void PackMatrixB_omp_12c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer) {
const int j_length = n - n_tail;
#pragma omp parallel for
for (int j = 0; j < j_length; j += NR) {
float *local_buffer = buffer + j * k;
for (int i = 0; i < k; ++i) {
const float *b0 = &B(i, j);
asm volatile(
"prfm pldl2keep, [%[b0], #64] \n\t"
"ld1 {v0.4s, v1.4s, v2.4s}, [%[b0]] \n\t"
"st1 {v0.4s, v1.4s, v2.4s}, [%[local_buffer]], #48 \n\t"
: [local_buffer] "+r"(local_buffer)
: [b0] "r"(b0)
: "memory", "v0", "v1", "v2");
}
}
if (n_tail != 0) {
float *local_buffer = buffer + j_length * k;
for (int i = 0; i < k; ++i) {
const float *b0 = &B(i, j_length);
for (int j = j_length; j < n; ++j) {
*local_buffer++ = *b0++;
}
for (int j = n; j < j_length + NR; ++j) {
*local_buffer++ = 0;
} }
} }
} }
...@@ -330,27 +551,60 @@ void PackMatrixB_12c(int k, int n, int n_tail, const float *B, int ldb,
void PackMatrixB_16c(int k, int n, int n_tail, const float *B, int ldb, void PackMatrixB_16c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer) { float *buffer) {
const float *b0; const int j_length = n - n_tail;
for (int j = 0; j < n - n_tail; j += NR) { for (int j = 0; j < n - n_tail; j += NR) {
float *local_buffer = buffer + j * k;
for (int i = 0; i < k; ++i) { for (int i = 0; i < k; ++i) {
b0 = &B(i, j); const float *b0 = &B(i, j);
asm volatile( asm volatile(
"prfm pldl2keep, [%[b0], #64] \n\t" "prfm pldl2keep, [%[b0], #64] \n\t"
"ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[b0]] \n\t" "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[b0]] \n\t"
"st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[buffer]], #64 \n\t" "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[local_buffer]], #64 \n\t"
: [buffer] "+r"(buffer) : [local_buffer] "+r"(local_buffer)
: [b0] "r"(b0) : [b0] "r"(b0)
: "memory", "v0", "v1", "v2", "v3"); : "memory", "v0", "v1", "v2", "v3");
} }
} }
if (n_tail != 0) { if (n_tail != 0) {
float *local_buffer = buffer + j_length * k;
for (int i = 0; i < k; ++i) { for (int i = 0; i < k; ++i) {
b0 = &B(i, n - n_tail); const float *b0 = &B(i, j_length);
for (int j = n - n_tail; j < n; ++j) { for (int j = j_length; j < n; ++j) {
*buffer++ = *b0++; *local_buffer++ = *b0++;
} }
for (int j = n; j < n + (NR - n_tail); ++j) { for (int j = n; j < j_length + NR; ++j) {
*buffer++ = 0; *local_buffer++ = 0;
}
}
}
}
void PackMatrixB_omp_16c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer) {
const int j_length = n - n_tail;
#pragma omp parallel for
for (int j = 0; j < n - n_tail; j += NR) {
float *local_buffer = buffer + j * k;
for (int i = 0; i < k; ++i) {
const float *b0 = &B(i, j);
asm volatile(
"prfm pldl2keep, [%[b0], #64] \n\t"
"ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[b0]] \n\t"
"st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[local_buffer]], #64 \n\t"
: [local_buffer] "+r"(local_buffer)
: [b0] "r"(b0)
: "memory", "v0", "v1", "v2", "v3");
}
}
if (n_tail != 0) {
float *local_buffer = buffer + j_length * k;
for (int i = 0; i < k; ++i) {
const float *b0 = &B(i, j_length);
for (int j = j_length; j < n; ++j) {
*local_buffer++ = *b0++;
}
for (int j = n; j < j_length + NR; ++j) {
*local_buffer++ = 0;
} }
} }
} }
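The PackMatrixB_8c / PackMatrixB_omp_8c, _12c and _16c routines above all follow the same scheme: copy NR-wide column panels of B into one contiguous buffer, panel by panel, zero-padding the last panel when n is not a multiple of NR. A simplified scalar sketch of that scheme (illustrative only; NR is a function parameter here, whereas the real code uses the NR macro plus NEON/asm copies):

```cpp
// Simplified column-panel packing: B is k x n, row-major with leading
// dimension ldb; buffer receives ceil(n / NR) panels of k * NR floats each.
void pack_b_panels(int k, int n, int NR, const float *B, int ldb,
                   float *buffer) {
  for (int j = 0; j < n; j += NR) {
    float *panel = buffer + j * k;  // start of this k x NR panel
    for (int i = 0; i < k; ++i) {
      for (int jj = 0; jj < NR; ++jj) {
        // zero-pad columns past n so the micro-kernel can assume full panels
        panel[i * NR + jj] = (j + jj < n) ? B[i * ldb + j + jj] : 0.f;
      }
    }
  }
}
```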
...@@ -2244,6 +2498,27 @@ void AddDot4x4(int k, const float *a, const float *b, float *c, int ldc) {
  }
}
void AddDot4x8(int k, const float *a, const float *b, float *c, int ldc) {}
void WriteBasic(int mc, int nc, float *c, float *C, int ldc) {}
void WriteWithAlphaBeta(int mc, int nc, float *c, float *C, int ldc) {}
void WriteWithAdd(int mc, int nc, float *c, float *C, int ldc) {}
void WriteWithAddV1(int mc, int nc, float *c, float *C, int ldc, float *bias) {}
void WriteWithAddRelu(int mc, int nc, float *c, float *C, int ldc) {}
void WriteWithAddReluV1(int mc, int nc, float *c, float *C, int ldc,
float *bias) {}
void WriteWithBn(int mc, int nc, float *c, float *C, int ldc, float *new_scale,
float *new_bias) {}
void WriteWithBnRelu(int mc, int nc, float *c, float *C, int ldc,
float *new_scale, float *new_bias) {}
#endif  // __ARM_NEON
// 32-bit float matrix multiplication
...@@ -2373,6 +2648,221 @@ void SgemmWithBn(int m, int n, int k, float alpha, const float *A, int lda,
  paddle_mobile::memory::Free(zero);
}
// 32-bit float matrix multiplication
void Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc,
bool relu, float *bias) {
#ifdef _OPENMP
int max_threads = omp_get_max_threads();
#else
int max_threads = 1;
#endif
int L1 = 32 * 1024;
KC = k;
if (m > n) {
// Block (tile) A
MC = L1 / (KC * sizeof(float));
int mblock_num = (m + MC - 1) / MC;
MC = (m + mblock_num - 1) / mblock_num;
MC = (MC + MR - 1) / MR * MR;
// Pad B
NC = (n + NR - 1) / NR * NR;
#if __aarch64__
procPackA = PackMatrixA_6r;
procPackB = PackMatrixB_omp_16c;
procAddDot = AddDot6x16;
#else
procPackA = PackMatrixA_6r;
procPackB = PackMatrixB_omp_8c;
procAddDot = AddDot6x8;
#endif
packedB = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * KC * NC));
procPackB(KC, NC, NC % NR, B, ldb, packedB);
packedA = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * MC * KC * max_threads));
} else {
// Block (tile) B
NC = L1 / (KC * sizeof(float));
int nblock_num = (n + NC - 1) / NC;
NC = (n + nblock_num - 1) / nblock_num;
NC = (NC + NR - 1) / NR * NR;
// Pad A
MC = (m + MR - 1) / MR * MR;
#if __aarch64__
procPackA = PackMatrixA_omp_6r;
procPackB = PackMatrixB_16c;
procAddDot = AddDot6x16;
#else
procPackA = PackMatrixA_omp_6r;
procPackB = PackMatrixB_8c;
procAddDot = AddDot6x8;
#endif
packedA = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * MC * KC));
procPackA(MC, KC, MC % MR, A, lda, packedA);
packedB = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * KC * NC * max_threads));
}
zero = static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * KC));
memset(static_cast<void *>(zero), 0, sizeof(float) * KC);
packedC = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * MC * NC * max_threads));
if (m > n) {
#pragma omp parallel for
for (int i = 0; i < m; i += MC) {
#ifdef _OPENMP
int local_threads = omp_get_thread_num();
#else
int local_threads = 0;
#endif
int mc;
mc = s_min(m - i, MC);
float *local_A = packedA + MC * KC * local_threads;
float *local_C = packedC + MC * NC * local_threads;
procPackA(mc, KC, mc % MR, &A(i, 0), lda, local_A);
InnerKernelWithBias(mc, n, alpha, local_A, packedB, beta, local_C,
&C(i, 0), ldc, relu, bias + i);
}
} else {
#pragma omp parallel for
for (int j = 0; j < n; j += NC) {
#ifdef _OPENMP
int local_threads = omp_get_thread_num();
#else
int local_threads = 0;
#endif
int nc;
nc = s_min(n - j, NC);
float *local_B = packedB + KC * NC * local_threads;
float *local_C = packedC + MC * NC * local_threads;
procPackB(KC, nc, nc % NR, &B(0, j), ldb, local_B);
InnerKernelWithBias(m, nc, alpha, packedA, local_B, beta, local_C,
&C(0, j), ldc, relu, bias);
}
}
paddle_mobile::memory::Free(packedA);
paddle_mobile::memory::Free(packedB);
paddle_mobile::memory::Free(packedC);
paddle_mobile::memory::Free(zero);
}
void SgemmWithBn_omp(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc,
bool relu, float *new_scale, float *new_bias) {
#ifdef _OPENMP
int max_threads = omp_get_max_threads();
#else
int max_threads = 1;
#endif
int L1 = 32 * 1024;
KC = k;
if (m > n) {
// Block (tile) A
MC = L1 / (KC * sizeof(float));
int mblock_num = (m + MC - 1) / MC;
MC = (m + mblock_num - 1) / mblock_num;
MC = (MC + MR - 1) / MR * MR;
// Pad B
NC = (n + NR - 1) / NR * NR;
#if __aarch64__
procPackA = PackMatrixA_6r;
procPackB = PackMatrixB_omp_16c;
procAddDot = AddDot6x16;
#else
procPackA = PackMatrixA_6r;
procPackB = PackMatrixB_omp_8c;
procAddDot = AddDot6x8;
#endif
packedB = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * KC * NC));
procPackB(KC, NC, NC % NR, B, ldb, packedB);
packedA = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * MC * KC * max_threads));
} else {
// Block (tile) B
NC = L1 / (KC * sizeof(float));
int nblock_num = (n + NC - 1) / NC;
NC = (n + nblock_num - 1) / nblock_num;
NC = (NC + NR - 1) / NR * NR;
// Pad A
MC = (m + MR - 1) / MR * MR;
#if __aarch64__
procPackA = PackMatrixA_omp_6r;
procPackB = PackMatrixB_16c;
procAddDot = AddDot6x16;
#else
procPackA = PackMatrixA_omp_6r;
procPackB = PackMatrixB_8c;
procAddDot = AddDot6x8;
#endif
packedA = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * MC * KC));
procPackA(MC, KC, MC % MR, A, lda, packedA);
packedB = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * KC * NC * max_threads));
}
zero = static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * KC));
memset(static_cast<void *>(zero), 0, sizeof(float) * KC);
packedC = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * MC * NC * max_threads));
if (m > n) {
#pragma omp parallel for
for (int i = 0; i < m; i += MC) {
#ifdef _OPENMP
int local_threads = omp_get_thread_num();
#else
int local_threads = 0;
#endif
int mc;
mc = s_min(m - i, MC);
float *local_A = packedA + MC * KC * local_threads;
float *local_C = packedC + MC * NC * local_threads;
procPackA(mc, KC, mc % MR, &A(i, 0), lda, local_A);
InnerKernelWithBn(mc, n, alpha, local_A, packedB, beta, local_C, &C(i, 0),
ldc, relu, new_scale + i, new_bias + i);
}
} else {
#pragma omp parallel for
for (int j = 0; j < n; j += NC) {
#ifdef _OPENMP
int local_threads = omp_get_thread_num();
#else
int local_threads = 0;
#endif
int nc;
nc = s_min(n - j, NC);
float *local_B = packedB + KC * NC * local_threads;
float *local_C = packedC + MC * NC * local_threads;
procPackB(KC, nc, nc % NR, &B(0, j), ldb, local_B);
InnerKernelWithBn(m, nc, alpha, packedA, local_B, beta, local_C, &C(0, j),
ldc, relu, new_scale, new_bias);
}
}
paddle_mobile::memory::Free(packedA);
paddle_mobile::memory::Free(packedB);
paddle_mobile::memory::Free(packedC);
paddle_mobile::memory::Free(zero);
}
void AddDot6x8(int k, const float *a, const float *b, float *c, int ldc) {
#if __ARM_NEON
#if __aarch64__
......
...@@ -50,6 +50,10 @@ void PackMatrixA_6r(int m, int k, int m_tail, const float *A, int lda,
                    float *buffer);
void PackMatrixA_8r(int m, int k, int m_tail, const float *A, int lda,
                    float *buffer);
void PackMatrixA_omp_6r(int m, int k, int m_tail, const float *A, int lda,
float *buffer);
void PackMatrixA_omp_8r(int m, int k, int m_tail, const float *A, int lda,
float *buffer);
// Copy blocks of matrix B into contiguous memory (RowMajor)
void PackMatrixB_8c(int k, int n, int n_tail, const float *B, int ldb,
...@@ -58,6 +62,12 @@ void PackMatrixB_12c(int k, int n, int n_tail, const float *B, int ldb,
                     float *buffer);
void PackMatrixB_16c(int k, int n, int n_tail, const float *B, int ldb,
                     float *buffer);
void PackMatrixB_omp_8c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer);
void PackMatrixB_omp_12c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer);
void PackMatrixB_omp_16c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer);
// Blocked (tiled) matrix multiplication
void InnerKernel(int mc, int nc, float alpha, const float *a, const float *b,
...@@ -136,6 +146,16 @@ void SgemmWithBn(int m, int n, int k, float alpha, const float *A, int lda,
                 const float *B, int ldb, float beta, float *C, int ldc,
                 bool relu, float *new_scale, float *new_bias);
// 32-bit float matrix multiplication (OpenMP multi-threaded version)
void Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc,
bool relu, float *bias);
// 32-bit float matrix multiplication with batchnorm applied to the result (OpenMP multi-threaded version)
void SgemmWithBn_omp(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc,
bool relu, float *new_scale, float *new_bias);
} // namespace math
} // namespace operators
} // namespace paddle_mobile
...@@ -42,8 +42,13 @@ void matmul<float>(const framework::Tensor &matrix_a, bool trans_a,
  int N = dim_out[1];
  int K = (!trans_a) ? dim_a[1] : dim_a[0];
#ifdef _OPENMP
Sgemm_omp(M, N, K, alpha, matrix_a.data<float>(), K, matrix_b.data<float>(),
N, beta, matrix_out->data<float>(), N, relu, bias);
#else
  Sgemm(M, N, K, alpha, matrix_a.data<float>(), K, matrix_b.data<float>(), N,
        beta, matrix_out->data<float>(), N, relu, bias);
#endif
}

template <>
...@@ -70,10 +75,17 @@ void matmulWithBn<float>(const framework::Tensor &matrix_a, bool trans_a,
  int N = dim_out[1];
  int K = (!trans_a) ? dim_a[1] : dim_a[0];
#ifdef _OPENMP
SgemmWithBn_omp(M, N, K, alpha, matrix_a.data<float>(), K,
matrix_b.data<float>(), N, beta, matrix_out->data<float>(), N,
relu, new_scale->data<float>() + group,
new_bias->data<float>() + group);
#else
  SgemmWithBn(M, N, K, alpha, matrix_a.data<float>(), K, matrix_b.data<float>(),
              N, beta, matrix_out->data<float>(), N, relu,
              new_scale->data<float>() + group,
              new_bias->data<float>() + group);
#endif
}
} // namespace math
......
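With the dispatch above, callers built with OpenMP reach Sgemm_omp transparently through matmul. For someone calling the new routine directly, a minimal usage sketch (row-major buffers; the forward declaration placement and the per-row bias length are my assumptions, not taken from this commit) might look like:

```cpp
#include <vector>

namespace paddle_mobile {
namespace operators {
namespace math {
// Declaration matching the one added to gemm.h by this commit.
void Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda,
               const float *B, int ldb, float beta, float *C, int ldc,
               bool relu, float *bias);
}  // namespace math
}  // namespace operators
}  // namespace paddle_mobile

// Hypothetical caller: C = 1.0 * A * B + 0.0 * C, no relu.
void sgemm_omp_example() {
  const int m = 6, n = 8, k = 4;
  std::vector<float> A(m * k, 1.0f);   // m x k, row-major, lda = k
  std::vector<float> B(k * n, 2.0f);   // k x n, row-major, ldb = n
  std::vector<float> C(m * n, 0.0f);   // m x n, row-major, ldc = n
  std::vector<float> bias(m, 0.0f);    // assumed: one bias value per row of C

  paddle_mobile::operators::math::Sgemm_omp(
      m, n, k, /*alpha=*/1.0f, A.data(), /*lda=*/k, B.data(), /*ldb=*/n,
      /*beta=*/0.0f, C.data(), /*ldc=*/n, /*relu=*/false, bias.data());
}
```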