diff --git a/README.md b/README.md index 59ef597dd749ea16658977cd6d548cedaa90d166..c29165d57204561e702997187d55e6cf869c4b39 100644 --- a/README.md +++ b/README.md @@ -26,8 +26,15 @@ Paddle-Moible是PaddlePaddle组织下的项目,是一个致力于嵌入式平 - **ARM CPU** - -![](http://mms-graph.bj.bcebos.com/paddle-mobile%2F2018_07_29.png) +|mobilenet arm v7|1 thread|2 threads|4 threads| +|------------|----|-----|-----| +|Kirin 960 (ms)|110.586|72.474|49.833| +||||| +|mobilenetssd arm v7|1 thread|2 threads|4 threads| +|Kirin 960 (ms)|224.464|142.544|96.068| +||||| +|googlenet(v1) arm v7|1 thread|2 threads|4 threads| +|Kirin 960 (ms)|348.018|242.689|169.998| The ARM CPU is the main focus of paddle-mobile, and the generality of CPUs has always been their strength. Embedded deep learning requires a large amount of hand-written CPU assembly, and we are coding at full speed so that every bit of acceleration the hardware offers can be fully exploited. ARM CPU optimization is still in progress; only conventional CPU optimizations are applied so far. On an ARM A73, paddle-mobile arm-v7 currently runs a single MobileNet 1.0 inference in 110+ ms on one core. This is clearly not our final goal: we are rewriting large parts in assembly, so there is still substantial headroom. Only armv7 is supported at the moment; armv8 support will follow. diff --git a/src/io/api.cc b/src/io/api.cc index 2103c5317b8d15988b19d1c1bf07e1042a6453a0..0e254aa15ac06083038773d89c23d40242847782 100644 --- a/src/io/api.cc +++ b/src/io/api.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "cstring" #include "io/paddle_inference_api.h" namespace paddle_mobile { diff --git a/src/operators/kernel/fpga/elementwise_add_relu_kernel.cpp b/src/operators/kernel/fpga/elementwise_add_relu_kernel.cpp index 5dd8991e2a23540e81f043cd6199443d98098ff8..88a19beb41f67e5fc9336c8883c8ea75aaa939e0 100644 --- a/src/operators/kernel/fpga/elementwise_add_relu_kernel.cpp +++ b/src/operators/kernel/fpga/elementwise_add_relu_kernel.cpp @@ -25,9 +25,9 @@ bool ElementwiseAddReluKernel::Init( const Tensor *input_x = param->InputX(); const Tensor *input_y = param->InputY(); Tensor *out = param->Out(); - auto input_x_ptr = input_x->data(); - auto input_y_ptr = input_y->data(); - auto out_ptr = out->mutable_data(); + auto input_x_ptr = input_x->data(); + auto input_y_ptr = input_y->data(); + auto out_ptr = out->mutable_data(); fpga::EWAddArgs ewaddArgs; ewaddArgs.relu_enabled = relu_enabled; diff --git a/src/operators/kernel/fpga/fc_relu_kernel.cpp b/src/operators/kernel/fpga/fc_relu_kernel.cpp index 6b828f102412fb5aa8ef125c4ccb9b96598fc458..21e334b12b70be1980d9417ed11161143106d1c6 100644 --- a/src/operators/kernel/fpga/fc_relu_kernel.cpp +++ b/src/operators/kernel/fpga/fc_relu_kernel.cpp @@ -22,13 +22,13 @@ template <> bool FusionFcReluKernel::Init(FusionFcReluParam *param) { bool relu_enabled = true; const Tensor *input_x = param->InputX(); - auto input_x_ptr = input_x->data(); + auto input_x_ptr = input_x->data(); const Tensor *input_y = param->InputY(); auto input_y_ptr = input_y->data(); const Tensor *input_z = param->InputZ(); auto input_z_ptr = input_z->data(); Tensor *out = param->Out(); - auto out_ptr = out->mutable_data(); + auto out_ptr = out->mutable_data(); PADDLE_MOBILE_ENFORCE(input_x->dims()[1] == input_y->dims()[0], "Image channel should be equal to weight number"); diff --git a/src/operators/kernel/fpga/fusion_fc_kernel.cpp b/src/operators/kernel/fpga/fusion_fc_kernel.cpp index 340561a9aa97ceda0bd37c40d721a0c5e3e535b4..505b8768565dc4003152c3493b558448f9d73d04 100644 --- a/src/operators/kernel/fpga/fusion_fc_kernel.cpp +++ b/src/operators/kernel/fpga/fusion_fc_kernel.cpp @@ -22,13 +22,13 @@ template <> bool FusionFcKernel::Init(FusionFcParam *param) { bool relu_enabled = false; const Tensor *input_x = param->InputX(); - auto input_x_ptr = input_x->data(); + auto input_x_ptr = input_x->data(); const Tensor
*input_y = param->InputY(); auto input_y_ptr = input_y->data(); const Tensor *input_z = param->InputZ(); auto input_z_ptr = input_z->data(); Tensor *out = param->Out(); - auto out_ptr = out->mutable_data(); + auto out_ptr = out->mutable_data(); PADDLE_MOBILE_ENFORCE(input_x->dims()[1] == input_y->dims()[0], "Image channel should be equal to weight number"); diff --git a/src/operators/kernel/fpga/pool_kernel.cpp b/src/operators/kernel/fpga/pool_kernel.cpp index 3e7dc5fd591fc85b98c7850102248c2264c62ba3..a7ff022c3b8616847c48a71bf94e4018cedcad2e 100644 --- a/src/operators/kernel/fpga/pool_kernel.cpp +++ b/src/operators/kernel/fpga/pool_kernel.cpp @@ -22,9 +22,9 @@ namespace operators { template <> bool PoolKernel::Init(PoolParam *param) { const Tensor *input = param->Input(); - auto input_ptr = input->data(); + auto input_ptr = input->data(); Tensor *output = param->Output(); - auto output_ptr = output->mutable_data(); + auto output_ptr = output->mutable_data(); vector ksize = param->Ksize(); vector strides = param->Strides(); vector paddings = param->Paddings(); diff --git a/src/operators/math/depthwise_conv_3x3.cpp b/src/operators/math/depthwise_conv_3x3.cpp index 7e353c29b80279f895ad6d0150b31eb1703d97d4..a67a24a4decddb76b654d9052946ce5de9a52b26 100644 --- a/src/operators/math/depthwise_conv_3x3.cpp +++ b/src/operators/math/depthwise_conv_3x3.cpp @@ -529,42 +529,42 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter, const float *newscale_data = new_scale->data(); const float *newbias_data = new_bias->data(); - const int h = static_cast(input->dims()[2]); - const int w = static_cast(input->dims()[3]); - const int l = h; - const int batch_size = static_cast(input->dims()[0]); - const int c = static_cast(input->dims()[1]); - const int hxw = h * w; + const int input_channel = static_cast(input->dims()[1]); + + const int input_height = static_cast(input->dims()[2]); + const int input_width = static_cast(input->dims()[3]); + const int output_height = static_cast(output->dims()[2]); + const int output_width = static_cast(output->dims()[3]); + + const int hxw = input_height * input_width; + + const int l = input_height; float32x4_t vnewbias = vdupq_n_f32(0.0); float32x4_t vnewscale = vdupq_n_f32(1.0); float32x4_t vzero = vdupq_n_f32(0); - for (int b = 0; b < batch_size; ++b) { - const float *filter_data_tmp = filter_data; - - for (int j = 0; j < c; ++j) { - vnewbias = vdupq_n_f32(newbias_data[j]); - vnewscale = vdupq_n_f32(newscale_data[j]); - - int l_mid = l - 2; // l=1->l_mid=-1,l=2->l_mid=0 - float w00 = filter_data_tmp[0]; - float w01 = filter_data_tmp[1]; - float w02 = filter_data_tmp[2]; - float w10 = filter_data_tmp[3]; - float w11 = filter_data_tmp[4]; - float w12 = filter_data_tmp[5]; - float w20 = filter_data_tmp[6]; - float w21 = filter_data_tmp[7]; - float w22 = filter_data_tmp[8]; + for (int b = 0; b < batch_size; b++) { + filter_data = filter->data(); + for (int c = 0; c < input_channel; c++) { + vnewbias = vdupq_n_f32(newbias_data[c]); + vnewscale = vdupq_n_f32(newscale_data[c]); + + float w00 = filter_data[0]; + float w01 = filter_data[1]; + float w02 = filter_data[2]; + float w10 = filter_data[3]; + float w11 = filter_data[4]; + float w12 = filter_data[5]; + float w20 = filter_data[6]; + float w21 = filter_data[7]; + float w22 = filter_data[8]; output_data[0] = w11 * input_data[0] + w12 * input_data[1] + w21 * input_data[l] + w22 * input_data[l + 1]; - output_data[l - 1] = w10 * input_data[l - 2] + w11 * input_data[l - 1] + w20 * input_data[2 * l - 2] + 
w21 * input_data[2 * l - 1]; - output_data[(l - 1) * l] = w01 * input_data[(l - 2) * l] + w02 * input_data[(l - 2) * l + 1] + w11 * input_data[(l - 1) * l] + w12 * input_data[(l - 1) * l + 1]; @@ -572,13 +572,13 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter, w01 * input_data[(l - 2) * (l + 1) + 1] + w10 * input_data[l * l - 2] + w11 * input_data[l * l - 1]; - output_data[0] = output_data[0] * newscale_data[j] + newbias_data[j]; + output_data[0] = output_data[0] * newscale_data[c] + newbias_data[c]; output_data[l - 1] = - output_data[l - 1] * newscale_data[j] + newbias_data[j]; + output_data[l - 1] * newscale_data[c] + newbias_data[c]; output_data[(l - 1) * l] = - output_data[(l - 1) * l] * newscale_data[j] + newbias_data[j]; + output_data[(l - 1) * l] * newscale_data[c] + newbias_data[c]; output_data[l * l - 1] = - output_data[l * l - 1] * newscale_data[j] + newbias_data[j]; + output_data[l * l - 1] * newscale_data[c] + newbias_data[c]; if (if_relu) { output_data[0] = output_data[0] < 0 ? 0 : output_data[0]; @@ -593,6 +593,7 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter, w01 * input_data[i * l - l] + w02 * input_data[i * l - l + 1] + w11 * input_data[i * l] + w12 * input_data[i * l + 1] + w21 * input_data[i * l + l] + w22 * input_data[i * l + l + 1]; + output_data[i * l + l - 1] = w00 * input_data[i * l + l - 1 - l - 1] + w01 * input_data[i * l + l - 1 - l] + w10 * input_data[i * l + l - 1 - 1] + @@ -600,9 +601,9 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter, w20 * input_data[i * l + l - 1 + l - 1] + w21 * input_data[i * l + l - 1 + l]; output_data[i * l] = - output_data[i * l] * newscale_data[j] + newbias_data[j]; + output_data[i * l] * newscale_data[c] + newbias_data[c]; output_data[i * l + l - 1] = - output_data[i * l + l - 1] * newscale_data[j] + newbias_data[j]; + output_data[i * l + l - 1] * newscale_data[c] + newbias_data[c]; if (if_relu) { output_data[i * l] = output_data[i * l] < 0 ? 
0 : output_data[i * l]; @@ -611,28 +612,19 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter, } } - // top 1 row and bottom 1 row - const float *input_tmp = input_data; - - float32x4_t in0, in1, in2, in3, in4, in5, in6, in7, tmp0, tmp1, tmp2, - tmp3, tmp4, tmp5, out0; - in0 = vld1q_f32(input_tmp); - in2 = vld1q_f32(input_tmp + l); - const float *input_tmp_end = input_tmp + (l - 2) * l; - in4 = vld1q_f32(input_tmp_end); - in6 = vld1q_f32(input_tmp_end + l); - int c_mid = l_mid; - auto output_ptr = output_data + 1; - for (; c_mid > 3; c_mid -= 4) { - in1 = vld1q_f32(input_tmp + 4); - in3 = vld1q_f32(input_tmp + l + 4); + int m; + for (m = 1; m < output_width - 4; m += 4) { + float *output_ptr = output_data + m; + float32x4_t in0, in1, in2, in3, tmp0, tmp1, tmp2, tmp3, out0; + in0 = vld1q_f32(input_data + m - 1); + in1 = vld1q_f32(input_data + m + 3); + in2 = vld1q_f32(input_data + input_width + m - 1); + in3 = vld1q_f32(input_data + input_width + m + 3); tmp0 = vextq_f32(in0, in1, 1); tmp1 = vextq_f32(in0, in1, 2); - tmp2 = vextq_f32(in2, in3, 1); tmp3 = vextq_f32(in2, in3, 2); - out0 = vmulq_n_f32(in0, w10); out0 = vmlaq_n_f32(out0, tmp0, w11); out0 = vmlaq_n_f32(out0, tmp1, w12); @@ -644,182 +636,438 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter, out0 = vmaxq_f32(out0, vzero); } vst1q_f32(output_ptr, out0); + } + for (m = 1; (m + 3) < output_width - 1; m = m + 4) { + } + for (int j = m; j < output_width - 1; j++) { + output_data[j] = input_data[j - 1] * w10 + input_data[j] * w11 + + input_data[j + 1] * w12 + + input_data[input_width + j - 1] * w20 + + input_data[input_width + j] * w21 + + input_data[input_width + j + 1] * w22; + output_data[j] = output_data[j] * newscale_data[c] + newbias_data[c]; - in5 = vld1q_f32(input_tmp_end + 4); - in7 = vld1q_f32(input_tmp_end + l + 4); + if (if_relu) { + output_data[j] = output_data[j] < 0 ? 0 : output_data[j]; + } + } - tmp0 = vextq_f32(in4, in5, 1); - tmp1 = vextq_f32(in4, in5, 2); - tmp2 = vextq_f32(in6, in7, 1); - tmp3 = vextq_f32(in6, in7, 2); + for (m = 1; (m + 3) < output_width - 1; m = m + 4) { + float *output_ptr = + output_data + (output_height - 1) * output_width + m; - out0 = vmulq_n_f32(in4, w00); + float32x4_t in0, in1, in2, in3, tmp0, tmp1, tmp2, tmp3, out0; + in0 = vld1q_f32(input_data + (output_height - 2) * input_width + m - 1); + in1 = vld1q_f32(input_data + (output_height - 2) * input_width + m + 3); + in2 = vld1q_f32(input_data + (output_height - 1) * input_width + m - 1); + in3 = vld1q_f32(input_data + (output_height - 1) * input_width + m + 3); + tmp0 = vextq_f32(in0, in1, 1); + tmp1 = vextq_f32(in0, in1, 2); + tmp2 = vextq_f32(in2, in3, 1); + tmp3 = vextq_f32(in2, in3, 2); + out0 = vmulq_n_f32(in0, w00); out0 = vmlaq_n_f32(out0, tmp0, w01); out0 = vmlaq_n_f32(out0, tmp1, w02); - out0 = vmlaq_n_f32(out0, in6, w10); + out0 = vmlaq_n_f32(out0, in2, w10); out0 = vmlaq_n_f32(out0, tmp2, w11); out0 = vmlaq_n_f32(out0, tmp3, w12); out0 = vmlaq_f32(vnewbias, vnewscale, out0); if (if_relu) { out0 = vmaxq_f32(out0, vzero); } - vst1q_f32(output_ptr + (l - 1) * l, out0); - - // can optimize to each 8 stride. 
- input_tmp += 4; - input_tmp_end += 4; - output_ptr += 4; - in0 = in1; - in2 = in3; - in4 = in5; - in6 = in7; + vst1q_f32(output_ptr, out0); } + for (m = 1; (m + 3) < output_width - 1; m = m + 4) { + } + for (int j = m; j < output_width - 1; j++) { + output_data[(output_height - 1) * input_width + j] = + input_data[(output_height - 2) * input_width + j - 1] * w00 + + input_data[(output_height - 2) * input_width + j] * w01 + + input_data[(output_height - 2) * input_width + j + 1] * w02 + + input_data[(output_height - 1) * input_width + j - 1] * w10 + + input_data[(output_height - 1) * input_width + j] * w11 + + input_data[(output_height - 1) * input_width + j + 1] * w12; + output_data[(output_height - 1) * output_width + j] = + output_data[(output_height - 1) * output_width + j] * + newscale_data[c] + + newbias_data[c]; - // top right pad - float32x4_t pad0 = vdupq_n_f32(input_data[l - 1]); - float32x4_t pad1 = vdupq_n_f32(input_data[2 * l - 1]); - - tmp0 = vextq_f32(in0, pad0, 1); - tmp1 = vextq_f32(in0, pad0, 2); - tmp2 = vextq_f32(in2, pad1, 1); - tmp3 = vextq_f32(in2, pad1, 2); - - out0 = vmulq_n_f32(in0, w10); - out0 = vmlaq_n_f32(out0, tmp0, w11); - out0 = vmlaq_n_f32(out0, tmp1, w12); - out0 = vmlaq_n_f32(out0, in2, w20); - out0 = vmlaq_n_f32(out0, tmp2, w21); - out0 = vmlaq_n_f32(out0, tmp3, w22); - out0 = vmlaq_f32(vnewbias, vnewscale, out0); - if (if_relu) { - out0 = vmaxq_f32(out0, vzero); + if (if_relu) { + output_data[(output_height - 1) * output_width + j] = + output_data[(output_height - 1) * output_width + j] < 0 + ? 0 + : output_data[(output_height - 1) * output_width + j]; + } } - for (int i = 0; i < c_mid; ++i) { - if (i == 0) { - vst1q_lane_f32(output_ptr + i, out0, 0); + #pragma omp parallel for + for (int i = 1; i < output_height - 1; i++) { + for (int m = 1; (m + 3) < output_width - 1; m = m + 4) { + float *output_ptr = output_data + i * output_width + m; + float32x4_t in0, in1, in2, in3, in4, in5, tmp0, tmp1, tmp2, tmp3, + tmp4, tmp5, out0; + in0 = vld1q_f32(input_data + (i - 1) * input_width + m - 1); + in1 = vld1q_f32(input_data + (i - 1) * input_width + m + 3); + in2 = vld1q_f32(input_data + i * input_width + m - 1); + in3 = vld1q_f32(input_data + i * input_width + m + 3); + in4 = vld1q_f32(input_data + (i + 1) * input_width + m - 1); + in5 = vld1q_f32(input_data + (i + 1) * input_width + m + 3); + + tmp0 = vextq_f32(in0, in1, 1); + tmp1 = vextq_f32(in0, in1, 2); + tmp2 = vextq_f32(in2, in3, 1); + tmp3 = vextq_f32(in2, in3, 2); + tmp4 = vextq_f32(in4, in5, 1); + tmp5 = vextq_f32(in4, in5, 2); + + out0 = vmulq_n_f32(in0, w00); + out0 = vmlaq_n_f32(out0, tmp0, w01); + out0 = vmlaq_n_f32(out0, tmp1, w02); + out0 = vmlaq_n_f32(out0, in2, w10); + out0 = vmlaq_n_f32(out0, tmp2, w11); + out0 = vmlaq_n_f32(out0, tmp3, w12); + out0 = vmlaq_n_f32(out0, in4, w20); + out0 = vmlaq_n_f32(out0, tmp4, w21); + out0 = vmlaq_n_f32(out0, tmp5, w22); + + out0 = vmlaq_f32(vnewbias, vnewscale, out0); + if (if_relu) { + out0 = vmaxq_f32(out0, vzero); + } + vst1q_f32(output_ptr, out0); } - if (i == 1) { - vst1q_lane_f32(output_ptr + i, out0, 1); + int m; + for (m = 1; (m + 3) < output_width - 1; m = m + 4) { } - if (i == 2) { - vst1q_lane_f32(output_ptr + i, out0, 2); + + for (int j = m; j < output_width - 1; j++) { + output_data[i * output_width + j] = + input_data[(i - 1) * input_width + j - 1] * w00 + + input_data[(i - 1) * input_width + j] * w01 + + input_data[(i - 1) * input_width + j + 1] * w02 + + input_data[(i)*input_width + j - 1] * w10 + + input_data[(i)*input_width + j] * w11 
+ + input_data[(i)*input_width + j + 1] * w12 + + input_data[(i + 1) * input_width + j - 1] * w20 + + input_data[(i + 1) * input_width + j] * w21 + + input_data[(i + 1) * input_width + j + 1] * w22; + output_data[i * output_width + j] = + newscale_data[c] * output_data[i * output_width + j] + + newbias_data[c]; + if (if_relu) { + output_data[i * output_width + j] = + output_data[i * output_width + j] < 0 + ? 0 + : output_data[i * output_width + j]; + } } } - // bottom right pad - float32x4_t pad2 = vdupq_n_f32(input_data[l * l - 1 - l]); - float32x4_t pad3 = vdupq_n_f32(input_data[l * l - 1]); + input_data = input_data + hxw; + output_data = output_data + hxw; + filter_data = filter_data + 9; + } + } - tmp0 = vextq_f32(in4, pad2, 1); - tmp1 = vextq_f32(in4, pad2, 2); - tmp2 = vextq_f32(in6, pad3, 1); - tmp3 = vextq_f32(in6, pad3, 2); + /* + const float *input_data = input->data(); + const float *filter_data = filter->data(); + float *output_data = output->data(); + const float *newscale_data = new_scale->data(); + const float *newbias_data = new_bias->data(); + + const int h = static_cast(input->dims()[2]); + const int w = static_cast(input->dims()[3]); + const int l = h; + + const int batch_size = static_cast(input->dims()[0]); + const int c = static_cast(input->dims()[1]); + const int hxw = h * w; + float32x4_t vnewbias = vdupq_n_f32(0.0); + float32x4_t vnewscale = vdupq_n_f32(1.0); + float32x4_t vzero = vdupq_n_f32(0); + + for (int b = 0; b < batch_size; ++b) { + const float *filter_data_tmp = filter_data; + + for (int j = 0; j < c; ++j) { + vnewbias = vdupq_n_f32(newbias_data[j]); + vnewscale = vdupq_n_f32(newscale_data[j]); + + int l_mid = l - 2; // l=1->l_mid=-1,l=2->l_mid=0 + float w00 = filter_data_tmp[0]; + float w01 = filter_data_tmp[1]; + float w02 = filter_data_tmp[2]; + float w10 = filter_data_tmp[3]; + float w11 = filter_data_tmp[4]; + float w12 = filter_data_tmp[5]; + float w20 = filter_data_tmp[6]; + float w21 = filter_data_tmp[7]; + float w22 = filter_data_tmp[8]; + + output_data[0] = w11 * input_data[0] + w12 * input_data[1] + + w21 * input_data[l] + w22 * input_data[l + 1]; + + output_data[l - 1] = w10 * input_data[l - 2] + w11 * input_data[l - 1] + + w20 * input_data[2 * l - 2] + + w21 * input_data[2 * l - 1]; - out0 = vmulq_n_f32(in4, w00); - out0 = vmlaq_n_f32(out0, tmp0, w01); - out0 = vmlaq_n_f32(out0, tmp1, w02); - out0 = vmlaq_n_f32(out0, in6, w10); - out0 = vmlaq_n_f32(out0, tmp2, w11); - out0 = vmlaq_n_f32(out0, tmp3, w12); - out0 = vmlaq_f32(vnewbias, vnewscale, out0); - if (if_relu) { - out0 = vmaxq_f32(out0, vzero); - } - for (int i = 0; i < c_mid; ++i) { - if (i == 0) { - vst1q_lane_f32(output_ptr + (l - 1) * l + i, out0, 0); - } - if (i == 1) { - vst1q_lane_f32(output_ptr + (l - 1) * l + i, out0, 1); + output_data[(l - 1) * l] = + w01 * input_data[(l - 2) * l] + w02 * input_data[(l - 2) * l + 1] + + w11 * input_data[(l - 1) * l] + w12 * input_data[(l - 1) * l + 1]; + output_data[l * l - 1] = w00 * input_data[(l - 2) * (l + 1)] + + w01 * input_data[(l - 2) * (l + 1) + 1] + + w10 * input_data[l * l - 2] + + w11 * input_data[l * l - 1]; + output_data[0] = output_data[0] * newscale_data[j] + newbias_data[j]; + output_data[l - 1] = + output_data[l - 1] * newscale_data[j] + newbias_data[j]; + output_data[(l - 1) * l] = + output_data[(l - 1) * l] * newscale_data[j] + newbias_data[j]; + output_data[l * l - 1] = + output_data[l * l - 1] * newscale_data[j] + newbias_data[j]; + + if (if_relu) { + output_data[0] = output_data[0] < 0 ? 
0 : output_data[0]; + output_data[l - 1] = output_data[l - 1] < 0 ? 0 : output_data[l - 1]; + output_data[(l - 1) * l] = + output_data[(l - 1) * l] < 0 ? 0 : output_data[(l - 1) * l]; + output_data[l * l - 1] = + output_data[l * l - 1] < 0 ? 0 : output_data[l * l - 1]; } - if (i == 2) { - vst1q_lane_f32(output_ptr + (l - 1) * l + i, out0, 2); + for (int i = 1; i < l - 1; ++i) { + output_data[i * l] = + w01 * input_data[i * l - l] + w02 * input_data[i * l - l + 1] + + w11 * input_data[i * l] + w12 * input_data[i * l + 1] + + w21 * input_data[i * l + l] + w22 * input_data[i * l + l + 1]; + output_data[i * l + l - 1] = w00 * input_data[i * l + l - 1 - l - 1] + + w01 * input_data[i * l + l - 1 - l] + + w10 * input_data[i * l + l - 1 - 1] + + w11 * input_data[i * l + l - 1] + + w20 * input_data[i * l + l - 1 + l - 1] + + w21 * input_data[i * l + l - 1 + l]; + output_data[i * l] = + output_data[i * l] * newscale_data[j] + newbias_data[j]; + output_data[i * l + l - 1] = + output_data[i * l + l - 1] * newscale_data[j] + newbias_data[j]; + + if (if_relu) { + output_data[i * l] = output_data[i * l] < 0 ? 0 : output_data[i * + l]; output_data[i * l + l - 1] = + output_data[i * l + l - 1] < 0 ? 0 : output_data[i * l + l - 1]; + } } - } - // mid - for (int i = 0; i < l - 2; ++i) { - auto output_ptr = output_data + (i + 1) * l + 1; - input_tmp = input_data + i * l; - auto in0_tmp = vld1q_f32(input_tmp); - auto in2_tmp = vld1q_f32(input_tmp + l); - auto in4_tmp = vld1q_f32(input_tmp + l + l); - c_mid = l_mid; + // top 1 row and bottom 1 row + const float *input_tmp = input_data; + + float32x4_t in0, in1, in2, in3, in4, in5, in6, in7, tmp0, tmp1, tmp2, + tmp3, tmp4, tmp5, out0; + in0 = vld1q_f32(input_tmp); + in2 = vld1q_f32(input_tmp + l); + const float *input_tmp_end = input_tmp + (l - 2) * l; + in4 = vld1q_f32(input_tmp_end); + in6 = vld1q_f32(input_tmp_end + l); + int c_mid = l_mid; + auto output_ptr = output_data + 1; for (; c_mid > 3; c_mid -= 4) { - auto in1_tmp = vld1q_f32(input_tmp + 4); - auto in3_tmp = vld1q_f32(input_tmp + l + 4); - auto in5_tmp = vld1q_f32(input_tmp + l + l + 4); + in1 = vld1q_f32(input_tmp + 4); + in3 = vld1q_f32(input_tmp + l + 4); - tmp0 = vextq_f32(in0_tmp, in1_tmp, 1); - tmp1 = vextq_f32(in0_tmp, in1_tmp, 2); - tmp2 = vextq_f32(in2_tmp, in3_tmp, 1); - tmp3 = vextq_f32(in2_tmp, in3_tmp, 2); - tmp4 = vextq_f32(in4_tmp, in5_tmp, 1); - tmp5 = vextq_f32(in4_tmp, in5_tmp, 2); + tmp0 = vextq_f32(in0, in1, 1); + tmp1 = vextq_f32(in0, in1, 2); - out0 = vmulq_n_f32(in0_tmp, w00); + tmp2 = vextq_f32(in2, in3, 1); + tmp3 = vextq_f32(in2, in3, 2); + + out0 = vmulq_n_f32(in0, w10); + out0 = vmlaq_n_f32(out0, tmp0, w11); + out0 = vmlaq_n_f32(out0, tmp1, w12); + out0 = vmlaq_n_f32(out0, in2, w20); + out0 = vmlaq_n_f32(out0, tmp2, w21); + out0 = vmlaq_n_f32(out0, tmp3, w22); + out0 = vmlaq_f32(vnewbias, vnewscale, out0); + if (if_relu) { + out0 = vmaxq_f32(out0, vzero); + } + vst1q_f32(output_ptr, out0); + + in5 = vld1q_f32(input_tmp_end + 4); + in7 = vld1q_f32(input_tmp_end + l + 4); + + tmp0 = vextq_f32(in4, in5, 1); + tmp1 = vextq_f32(in4, in5, 2); + tmp2 = vextq_f32(in6, in7, 1); + tmp3 = vextq_f32(in6, in7, 2); + + out0 = vmulq_n_f32(in4, w00); out0 = vmlaq_n_f32(out0, tmp0, w01); out0 = vmlaq_n_f32(out0, tmp1, w02); - out0 = vmlaq_n_f32(out0, in2_tmp, w10); + out0 = vmlaq_n_f32(out0, in6, w10); out0 = vmlaq_n_f32(out0, tmp2, w11); out0 = vmlaq_n_f32(out0, tmp3, w12); - out0 = vmlaq_n_f32(out0, in4_tmp, w20); - out0 = vmlaq_n_f32(out0, tmp4, w21); - out0 = vmlaq_n_f32(out0, tmp5, 
w22); out0 = vmlaq_f32(vnewbias, vnewscale, out0); if (if_relu) { out0 = vmaxq_f32(out0, vzero); } - vst1q_f32(output_ptr, out0); + vst1q_f32(output_ptr + (l - 1) * l, out0); - output_ptr += 4; + // can optimize to each 8 stride. input_tmp += 4; - in0_tmp = in1_tmp; - in2_tmp = in3_tmp; - in4_tmp = in5_tmp; + input_tmp_end += 4; + output_ptr += 4; + in0 = in1; + in2 = in3; + in4 = in5; + in6 = in7; } - float32x4_t pad0 = vdupq_n_f32(input_data[i * l + l - 1]); - float32x4_t pad1 = vdupq_n_f32(input_data[i * l + l - 1 + l]); - float32x4_t pad2 = vdupq_n_f32(input_data[i * l + l - 1 + l + l]); + // top right pad + float32x4_t pad0 = vdupq_n_f32(input_data[l - 1]); + float32x4_t pad1 = vdupq_n_f32(input_data[2 * l - 1]); - tmp0 = vextq_f32(in0_tmp, pad0, 1); - tmp1 = vextq_f32(in0_tmp, pad0, 2); - tmp2 = vextq_f32(in2_tmp, pad1, 1); - tmp3 = vextq_f32(in2_tmp, pad1, 2); - tmp4 = vextq_f32(in4_tmp, pad2, 1); - tmp5 = vextq_f32(in4_tmp, pad2, 2); + tmp0 = vextq_f32(in0, pad0, 1); + tmp1 = vextq_f32(in0, pad0, 2); + tmp2 = vextq_f32(in2, pad1, 1); + tmp3 = vextq_f32(in2, pad1, 2); - out0 = vmulq_n_f32(in0_tmp, w00); + out0 = vmulq_n_f32(in0, w10); + out0 = vmlaq_n_f32(out0, tmp0, w11); + out0 = vmlaq_n_f32(out0, tmp1, w12); + out0 = vmlaq_n_f32(out0, in2, w20); + out0 = vmlaq_n_f32(out0, tmp2, w21); + out0 = vmlaq_n_f32(out0, tmp3, w22); + out0 = vmlaq_f32(vnewbias, vnewscale, out0); + if (if_relu) { + out0 = vmaxq_f32(out0, vzero); + } + for (int i = 0; i < c_mid; ++i) { + if (i == 0) { + vst1q_lane_f32(output_ptr + i, out0, 0); + } + if (i == 1) { + vst1q_lane_f32(output_ptr + i, out0, 1); + } + if (i == 2) { + vst1q_lane_f32(output_ptr + i, out0, 2); + } + } + + // bottom right pad + float32x4_t pad2 = vdupq_n_f32(input_data[l * l - 1 - l]); + float32x4_t pad3 = vdupq_n_f32(input_data[l * l - 1]); + + tmp0 = vextq_f32(in4, pad2, 1); + tmp1 = vextq_f32(in4, pad2, 2); + tmp2 = vextq_f32(in6, pad3, 1); + tmp3 = vextq_f32(in6, pad3, 2); + + out0 = vmulq_n_f32(in4, w00); out0 = vmlaq_n_f32(out0, tmp0, w01); out0 = vmlaq_n_f32(out0, tmp1, w02); - out0 = vmlaq_n_f32(out0, in2_tmp, w10); + out0 = vmlaq_n_f32(out0, in6, w10); out0 = vmlaq_n_f32(out0, tmp2, w11); out0 = vmlaq_n_f32(out0, tmp3, w12); - out0 = vmlaq_n_f32(out0, in4_tmp, w20); - out0 = vmlaq_n_f32(out0, tmp4, w21); - out0 = vmlaq_n_f32(out0, tmp5, w22); out0 = vmlaq_f32(vnewbias, vnewscale, out0); if (if_relu) { out0 = vmaxq_f32(out0, vzero); } for (int i = 0; i < c_mid; ++i) { if (i == 0) { - vst1q_lane_f32(output_ptr + i, out0, 0); + vst1q_lane_f32(output_ptr + (l - 1) * l + i, out0, 0); } if (i == 1) { - vst1q_lane_f32(output_ptr + i, out0, 1); + vst1q_lane_f32(output_ptr + (l - 1) * l + i, out0, 1); } if (i == 2) { - vst1q_lane_f32(output_ptr + i, out0, 2); + vst1q_lane_f32(output_ptr + (l - 1) * l + i, out0, 2); + } + } + // mid + + + for (int i = 0; i < l - 2; ++i) { + auto output_ptr = output_data + (i + 1) * l + 1; + input_tmp = input_data + i * l; + auto in0_tmp = vld1q_f32(input_tmp); + auto in2_tmp = vld1q_f32(input_tmp + l); + auto in4_tmp = vld1q_f32(input_tmp + l + l); + c_mid = l_mid; + for (; c_mid > 3; c_mid -= 4) { + auto in1_tmp = vld1q_f32(input_tmp + 4); + auto in3_tmp = vld1q_f32(input_tmp + l + 4); + auto in5_tmp = vld1q_f32(input_tmp + l + l + 4); + + tmp0 = vextq_f32(in0_tmp, in1_tmp, 1); + tmp1 = vextq_f32(in0_tmp, in1_tmp, 2); + tmp2 = vextq_f32(in2_tmp, in3_tmp, 1); + tmp3 = vextq_f32(in2_tmp, in3_tmp, 2); + tmp4 = vextq_f32(in4_tmp, in5_tmp, 1); + tmp5 = vextq_f32(in4_tmp, in5_tmp, 2); + + out0 = 
vmulq_n_f32(in0_tmp, w00); + out0 = vmlaq_n_f32(out0, tmp0, w01); + out0 = vmlaq_n_f32(out0, tmp1, w02); + out0 = vmlaq_n_f32(out0, in2_tmp, w10); + out0 = vmlaq_n_f32(out0, tmp2, w11); + out0 = vmlaq_n_f32(out0, tmp3, w12); + out0 = vmlaq_n_f32(out0, in4_tmp, w20); + out0 = vmlaq_n_f32(out0, tmp4, w21); + out0 = vmlaq_n_f32(out0, tmp5, w22); + out0 = vmlaq_f32(vnewbias, vnewscale, out0); + if (if_relu) { + out0 = vmaxq_f32(out0, vzero); + } + vst1q_f32(output_ptr, out0); + + output_ptr += 4; + input_tmp += 4; + in0_tmp = in1_tmp; + in2_tmp = in3_tmp; + in4_tmp = in5_tmp; + } + + float32x4_t pad0 = vdupq_n_f32(input_data[i * l + l - 1]); + float32x4_t pad1 = vdupq_n_f32(input_data[i * l + l - 1 + l]); + float32x4_t pad2 = vdupq_n_f32(input_data[i * l + l - 1 + l + l]); + + tmp0 = vextq_f32(in0_tmp, pad0, 1); + tmp1 = vextq_f32(in0_tmp, pad0, 2); + tmp2 = vextq_f32(in2_tmp, pad1, 1); + tmp3 = vextq_f32(in2_tmp, pad1, 2); + tmp4 = vextq_f32(in4_tmp, pad2, 1); + tmp5 = vextq_f32(in4_tmp, pad2, 2); + + out0 = vmulq_n_f32(in0_tmp, w00); + out0 = vmlaq_n_f32(out0, tmp0, w01); + out0 = vmlaq_n_f32(out0, tmp1, w02); + out0 = vmlaq_n_f32(out0, in2_tmp, w10); + out0 = vmlaq_n_f32(out0, tmp2, w11); + out0 = vmlaq_n_f32(out0, tmp3, w12); + out0 = vmlaq_n_f32(out0, in4_tmp, w20); + out0 = vmlaq_n_f32(out0, tmp4, w21); + out0 = vmlaq_n_f32(out0, tmp5, w22); + out0 = vmlaq_f32(vnewbias, vnewscale, out0); + if (if_relu) { + out0 = vmaxq_f32(out0, vzero); + } + for (int i = 0; i < c_mid; ++i) { + if (i == 0) { + vst1q_lane_f32(output_ptr + i, out0, 0); + } + if (i == 1) { + vst1q_lane_f32(output_ptr + i, out0, 1); + } + if (i == 2) { + vst1q_lane_f32(output_ptr + i, out0, 2); + } } } + output_data += hxw; + input_data += hxw; + filter_data_tmp += 9; } - output_data += hxw; - input_data += hxw; - filter_data_tmp += 9; } - } + */ #endif } diff --git a/src/operators/math/gemm.cpp b/src/operators/math/gemm.cpp index 2b7363b0efa551b83d933d9ab2079dc7dcbaf65d..3730cf350a1399e5f3c1473fd1ce8d7b1d13b1b6 100644 --- a/src/operators/math/gemm.cpp +++ b/src/operators/math/gemm.cpp @@ -33,6 +33,14 @@ float *packedA; float *packedB; float *packedC; float *zero; + +typedef void (*FnPack)(int, int, int, const float *, int, float *); +typedef void (*FnAddDot)(int, const float *, const float *, float *, int); + +FnPack procPackA; +FnPack procPackB; +FnAddDot procAddDot; + /* // 将A矩阵分块复制到连续内存(ColMajor) void PackMatrixA(int m, int k, int m_tail, const float *A, int lda, @@ -135,30 +143,32 @@ void PackMatrixA_4r(int m, int k, int m_tail, const float *A, int lda, void PackMatrixA_6r(int m, int k, int m_tail, const float *A, int lda, float *buffer) { - const float *a0, *a1, *a2, *a3, *a4, *a5; - for (int i = 0; i < m - m_tail; i += MR) { - a0 = A + i * lda; - a1 = A + (i + 1) * lda; - a2 = A + (i + 2) * lda; - a3 = A + (i + 3) * lda; - a4 = A + (i + 4) * lda; - a5 = A + (i + 5) * lda; + const int i_length = m - m_tail; + for (int i = 0; i < i_length; i += MR) { + const float *a0 = A + i * lda; + const float *a1 = A + (i + 1) * lda; + const float *a2 = A + (i + 2) * lda; + const float *a3 = A + (i + 3) * lda; + const float *a4 = A + (i + 4) * lda; + const float *a5 = A + (i + 5) * lda; + float *local_buffer = buffer + i * k; for (int j = 0; j < k; ++j) { - *buffer++ = *a0++; - *buffer++ = *a1++; - *buffer++ = *a2++; - *buffer++ = *a3++; - *buffer++ = *a4++; - *buffer++ = *a5++; + *local_buffer++ = *a0++; + *local_buffer++ = *a1++; + *local_buffer++ = *a2++; + *local_buffer++ = *a3++; + *local_buffer++ = *a4++; + *local_buffer++ = 
*a5++; } } if (m_tail != 0) { - a0 = &A(m - m_tail, 0); - a1 = a0 + lda; - a2 = a0 + 2 * lda; - a3 = a0 + 3 * lda; - a4 = a0 + 4 * lda; - a5 = a0 + 5 * lda; + const float *a0 = &A(i_length, 0); + const float *a1 = a0 + lda; + const float *a2 = a0 + 2 * lda; + const float *a3 = a0 + 3 * lda; + const float *a4 = a0 + 4 * lda; + const float *a5 = a0 + 5 * lda; + float *local_buffer = buffer + i_length * k; switch (m_tail) { case 1: a1 = zero; @@ -175,48 +185,105 @@ void PackMatrixA_6r(int m, int k, int m_tail, const float *A, int lda, break; } for (int j = 0; j < k; ++j) { - *buffer++ = *a0++; - *buffer++ = *a1++; - *buffer++ = *a2++; - *buffer++ = *a3++; - *buffer++ = *a4++; - *buffer++ = *a5++; + *local_buffer++ = *a0++; + *local_buffer++ = *a1++; + *local_buffer++ = *a2++; + *local_buffer++ = *a3++; + *local_buffer++ = *a4++; + *local_buffer++ = *a5++; + } + } +} + +void PackMatrixA_omp_6r(int m, int k, int m_tail, const float *A, int lda, + float *buffer) { + const int i_length = m - m_tail; +#pragma omp parallel for + for (int i = 0; i < i_length; i += MR) { + const float *a0 = A + i * lda; + const float *a1 = A + (i + 1) * lda; + const float *a2 = A + (i + 2) * lda; + const float *a3 = A + (i + 3) * lda; + const float *a4 = A + (i + 4) * lda; + const float *a5 = A + (i + 5) * lda; + float *local_buffer = buffer + i * k; + for (int j = 0; j < k; ++j) { + *local_buffer++ = *a0++; + *local_buffer++ = *a1++; + *local_buffer++ = *a2++; + *local_buffer++ = *a3++; + *local_buffer++ = *a4++; + *local_buffer++ = *a5++; + } + } + if (m_tail != 0) { + const float *a0 = &A(i_length, 0); + const float *a1 = a0 + lda; + const float *a2 = a0 + 2 * lda; + const float *a3 = a0 + 3 * lda; + const float *a4 = a0 + 4 * lda; + const float *a5 = a0 + 5 * lda; + float *local_buffer = buffer + i_length * k; + switch (m_tail) { + case 1: + a1 = zero; + case 2: + a2 = zero; + case 3: + a3 = zero; + case 4: + a4 = zero; + case 5: + a5 = zero; + break; + default: + break; + } + for (int j = 0; j < k; ++j) { + *local_buffer++ = *a0++; + *local_buffer++ = *a1++; + *local_buffer++ = *a2++; + *local_buffer++ = *a3++; + *local_buffer++ = *a4++; + *local_buffer++ = *a5++; } } } void PackMatrixA_8r(int m, int k, int m_tail, const float *A, int lda, float *buffer) { - const float *a0, *a1, *a2, *a3, *a4, *a5, *a6, *a7; - for (int i = 0; i < m - m_tail; i += MR) { - a0 = A + i * lda; - a1 = A + (i + 1) * lda; - a2 = A + (i + 2) * lda; - a3 = A + (i + 3) * lda; - a4 = A + (i + 4) * lda; - a5 = A + (i + 5) * lda; - a6 = A + (i + 6) * lda; - a7 = A + (i + 7) * lda; + const int i_length = m - m_tail; + for (int i = 0; i < i_length; i += MR) { + const float *a0 = A + i * lda; + const float *a1 = A + (i + 1) * lda; + const float *a2 = A + (i + 2) * lda; + const float *a3 = A + (i + 3) * lda; + const float *a4 = A + (i + 4) * lda; + const float *a5 = A + (i + 5) * lda; + const float *a6 = A + (i + 6) * lda; + const float *a7 = A + (i + 7) * lda; + float *local_buffer = buffer + i * k; for (int j = 0; j < k; ++j) { - *buffer++ = *a0++; - *buffer++ = *a1++; - *buffer++ = *a2++; - *buffer++ = *a3++; - *buffer++ = *a4++; - *buffer++ = *a5++; - *buffer++ = *a6++; - *buffer++ = *a7++; + *local_buffer++ = *a0++; + *local_buffer++ = *a1++; + *local_buffer++ = *a2++; + *local_buffer++ = *a3++; + *local_buffer++ = *a4++; + *local_buffer++ = *a5++; + *local_buffer++ = *a6++; + *local_buffer++ = *a7++; } } if (m_tail != 0) { - a0 = &A(m - m_tail, 0); - a1 = a0 + lda; - a2 = a0 + 2 * lda; - a3 = a0 + 3 * lda; - a4 = a0 + 4 * lda; - a5 = a0 + 
5 * lda; - a6 = a0 + 6 * lda; - a7 = a0 + 7 * lda; + const float *a0 = &A(i_length, 0); + const float *a1 = a0 + lda; + const float *a2 = a0 + 2 * lda; + const float *a3 = a0 + 3 * lda; + const float *a4 = a0 + 4 * lda; + const float *a5 = a0 + 5 * lda; + const float *a6 = a0 + 6 * lda; + const float *a7 = a0 + 7 * lda; + float *local_buffer = buffer + i_length * k; switch (m_tail) { case 1: a1 = zero; @@ -237,14 +304,81 @@ void PackMatrixA_8r(int m, int k, int m_tail, const float *A, int lda, break; } for (int j = 0; j < k; ++j) { - *buffer++ = *a0++; - *buffer++ = *a1++; - *buffer++ = *a2++; - *buffer++ = *a3++; - *buffer++ = *a4++; - *buffer++ = *a5++; - *buffer++ = *a6++; - *buffer++ = *a7++; + *local_buffer++ = *a0++; + *local_buffer++ = *a1++; + *local_buffer++ = *a2++; + *local_buffer++ = *a3++; + *local_buffer++ = *a4++; + *local_buffer++ = *a5++; + *local_buffer++ = *a6++; + *local_buffer++ = *a7++; + } + } +} + +void PackMatrixA_omp_8r(int m, int k, int m_tail, const float *A, int lda, + float *buffer) { + const int i_length = m - m_tail; +#pragma omp parallel for + for (int i = 0; i < i_length; i += MR) { + const float *a0 = A + i * lda; + const float *a1 = A + (i + 1) * lda; + const float *a2 = A + (i + 2) * lda; + const float *a3 = A + (i + 3) * lda; + const float *a4 = A + (i + 4) * lda; + const float *a5 = A + (i + 5) * lda; + const float *a6 = A + (i + 6) * lda; + const float *a7 = A + (i + 7) * lda; + float *local_buffer = buffer + i * k; + for (int j = 0; j < k; ++j) { + *local_buffer++ = *a0++; + *local_buffer++ = *a1++; + *local_buffer++ = *a2++; + *local_buffer++ = *a3++; + *local_buffer++ = *a4++; + *local_buffer++ = *a5++; + *local_buffer++ = *a6++; + *local_buffer++ = *a7++; + } + } + if (m_tail != 0) { + const float *a0 = &A(i_length, 0); + const float *a1 = a0 + lda; + const float *a2 = a0 + 2 * lda; + const float *a3 = a0 + 3 * lda; + const float *a4 = a0 + 4 * lda; + const float *a5 = a0 + 5 * lda; + const float *a6 = a0 + 6 * lda; + const float *a7 = a0 + 7 * lda; + float *local_buffer = buffer + i_length * k; + switch (m_tail) { + case 1: + a1 = zero; + case 2: + a2 = zero; + case 3: + a3 = zero; + case 4: + a4 = zero; + case 5: + a5 = zero; + case 6: + a6 = zero; + case 7: + a7 = zero; + break; + default: + break; + } + for (int j = 0; j < k; ++j) { + *local_buffer++ = *a0++; + *local_buffer++ = *a1++; + *local_buffer++ = *a2++; + *local_buffer++ = *a3++; + *local_buffer++ = *a4++; + *local_buffer++ = *a5++; + *local_buffer++ = *a6++; + *local_buffer++ = *a7++; } } } @@ -252,48 +386,102 @@ void PackMatrixA_8r(int m, int k, int m_tail, const float *A, int lda, // 将B矩阵分块复制到连续内存(RowMajor) void PackMatrixB_8c(int k, int n, int n_tail, const float *B, int ldb, float *buffer) { - const float *b0; - for (int j = 0; j < n - n_tail; j += NR) { + const int j_length = n - n_tail; + for (int j = 0; j < j_length; j += NR) { + float *local_buffer = buffer + j * k; for (int i = 0; i < k; ++i) { - b0 = &B(i, j); + const float *b0 = &B(i, j); #if __ARM_NEON #if __aarch64__ asm volatile( "prfm pldl1keep, [%[b0]] \n\t" "ld1 {v0.4s, v1.4s}, [%[b0]] \n\t" - "st1 {v0.4s, v1.4s}, [%[buffer]], #32 \n\t" - : [buffer] "+r"(buffer) + "st1 {v0.4s, v1.4s}, [%[local_buffer]], #32 \n\t" + : [local_buffer] "+r"(local_buffer) : [b0] "r"(b0) : "memory", "v0", "v1"); #else asm volatile( "pld [%[b0]] \n\t" "vld1.32 {q0, q1}, [%[b0]] \n\t" - "vst1.32 {q0, q1}, [%[buffer]]! \n\t" - : [buffer] "+r"(buffer) + "vst1.32 {q0, q1}, [%[local_buffer]]! 
\n\t" + : [local_buffer] "+r"(local_buffer) : [b0] "r"(b0) : "memory", "q0", "q1"); #endif // __aarch64__ #else - *buffer++ = *b0++; - *buffer++ = *b0++; - *buffer++ = *b0++; - *buffer++ = *b0++; - *buffer++ = *b0++; - *buffer++ = *b0++; - *buffer++ = *b0++; - *buffer++ = *b0++; + *local_buffer++ = *b0++; + *local_buffer++ = *b0++; + *local_buffer++ = *b0++; + *local_buffer++ = *b0++; + *local_buffer++ = *b0++; + *local_buffer++ = *b0++; + *local_buffer++ = *b0++; + *local_buffer++ = *b0++; #endif // __ARM_NEON } } if (n_tail != 0) { + float *local_buffer = buffer + j_length * k; for (int i = 0; i < k; ++i) { - b0 = &B(i, n - n_tail); - for (int j = n - n_tail; j < n; ++j) { - *buffer++ = *b0++; + const float *b0 = &B(i, j_length); + for (int j = j_length; j < n; ++j) { + *local_buffer++ = *b0++; } - for (int j = n; j < n + (NR - n_tail); ++j) { - *buffer++ = 0; + for (int j = n; j < j_length + NR; ++j) { + *local_buffer++ = 0; + } + } + } +} + +void PackMatrixB_omp_8c(int k, int n, int n_tail, const float *B, int ldb, + float *buffer) { + const int j_length = n - n_tail; +#pragma omp parallel for + for (int j = 0; j < j_length; j += NR) { + float *local_buffer = buffer + j * k; + for (int i = 0; i < k; ++i) { + const float *b0 = &B(i, j); +#if __ARM_NEON +#if __aarch64__ + asm volatile( + "prfm pldl1keep, [%[b0]] \n\t" + "ld1 {v0.4s, v1.4s}, [%[b0]] \n\t" + "st1 {v0.4s, v1.4s}, [%[local_buffer]], #32 \n\t" + : [local_buffer] "+r"(local_buffer) + : [b0] "r"(b0) + : "memory", "v0", "v1"); +#else + asm volatile( + "pld [%[b0]] \n\t" + "vld1.32 {q0, q1}, [%[b0]] \n\t" + "vst1.32 {q0, q1}, [%[local_buffer]]! \n\t" + : [local_buffer] "+r"(local_buffer) + : [b0] "r"(b0) + : "memory", "q0", "q1"); +#endif // __aarch64__ +#else + *local_buffer++ = *b0++; + *local_buffer++ = *b0++; + *local_buffer++ = *b0++; + *local_buffer++ = *b0++; + *local_buffer++ = *b0++; + *local_buffer++ = *b0++; + *local_buffer++ = *b0++; + *local_buffer++ = *b0++; +#endif // __ARM_NEON + } + } + if (n_tail != 0) { + float *local_buffer = buffer + j_length * k; + for (int i = 0; i < k; ++i) { + const float *b0 = &B(i, j_length); + for (int j = j_length; j < n; ++j) { + *local_buffer++ = *b0++; + } + for (int j = n; j < j_length + NR; ++j) { + *local_buffer++ = 0; } } } @@ -302,27 +490,60 @@ void PackMatrixB_8c(int k, int n, int n_tail, const float *B, int ldb, #if __aarch64__ void PackMatrixB_12c(int k, int n, int n_tail, const float *B, int ldb, float *buffer) { - const float *b0; - for (int j = 0; j < n - n_tail; j += NR) { + const int j_length = n - n_tail; + for (int j = 0; j < j_length; j += NR) { + float *local_buffer = buffer + j * k; for (int i = 0; i < k; ++i) { - b0 = &B(i, j); + const float *b0 = &B(i, j); asm volatile( "prfm pldl2keep, [%[b0], #64] \n\t" "ld1 {v0.4s, v1.4s, v2.4s}, [%[b0]] \n\t" - "st1 {v0.4s, v1.4s, v2.4s}, [%[buffer]], #48 \n\t" - : [buffer] "+r"(buffer) + "st1 {v0.4s, v1.4s, v2.4s}, [%[local_buffer]], #48 \n\t" + : [local_buffer] "+r"(local_buffer) : [b0] "r"(b0) : "memory", "v0", "v1", "v2"); } } if (n_tail != 0) { + float *local_buffer = buffer + j_length * k; for (int i = 0; i < k; ++i) { - b0 = &B(i, n - n_tail); - for (int j = n - n_tail; j < n; ++j) { - *buffer++ = *b0++; + const float *b0 = &B(i, j_length); + for (int j = j_length; j < n; ++j) { + *local_buffer++ = *b0++; } - for (int j = n; j < n + (NR - n_tail); ++j) { - *buffer++ = 0; + for (int j = n; j < j_length + NR; ++j) { + *local_buffer++ = 0; + } + } + } +} + +void PackMatrixB_omp_12c(int k, int n, int n_tail, const float 
*B, int ldb, + float *buffer) { + const int j_length = n - n_tail; +#pragma omp parallel for + for (int j = 0; j < j_length; j += NR) { + float *local_buffer = buffer + j * k; + for (int i = 0; i < k; ++i) { + const float *b0 = &B(i, j); + asm volatile( + "prfm pldl2keep, [%[b0], #64] \n\t" + "ld1 {v0.4s, v1.4s, v2.4s}, [%[b0]] \n\t" + "st1 {v0.4s, v1.4s, v2.4s}, [%[local_buffer]], #48 \n\t" + : [local_buffer] "+r"(local_buffer) + : [b0] "r"(b0) + : "memory", "v0", "v1", "v2"); + } + } + if (n_tail != 0) { + float *local_buffer = buffer + j_length * k; + for (int i = 0; i < k; ++i) { + const float *b0 = &B(i, j_length); + for (int j = j_length; j < n; ++j) { + *local_buffer++ = *b0++; + } + for (int j = n; j < j_length + NR; ++j) { + *local_buffer++ = 0; } } } @@ -330,27 +551,60 @@ void PackMatrixB_12c(int k, int n, int n_tail, const float *B, int ldb, void PackMatrixB_16c(int k, int n, int n_tail, const float *B, int ldb, float *buffer) { - const float *b0; + const int j_length = n - n_tail; for (int j = 0; j < n - n_tail; j += NR) { + float *local_buffer = buffer + j * k; for (int i = 0; i < k; ++i) { - b0 = &B(i, j); + const float *b0 = &B(i, j); asm volatile( "prfm pldl2keep, [%[b0], #64] \n\t" "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[b0]] \n\t" - "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[buffer]], #64 \n\t" - : [buffer] "+r"(buffer) + "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[local_buffer]], #64 \n\t" + : [local_buffer] "+r"(local_buffer) : [b0] "r"(b0) : "memory", "v0", "v1", "v2", "v3"); } } if (n_tail != 0) { + float *local_buffer = buffer + j_length * k; for (int i = 0; i < k; ++i) { - b0 = &B(i, n - n_tail); - for (int j = n - n_tail; j < n; ++j) { - *buffer++ = *b0++; + const float *b0 = &B(i, j_length); + for (int j = j_length; j < n; ++j) { + *local_buffer++ = *b0++; } - for (int j = n; j < n + (NR - n_tail); ++j) { - *buffer++ = 0; + for (int j = n; j < j_length + NR; ++j) { + *local_buffer++ = 0; + } + } + } +} + +void PackMatrixB_omp_16c(int k, int n, int n_tail, const float *B, int ldb, + float *buffer) { + const int j_length = n - n_tail; +#pragma omp parallel for + for (int j = 0; j < n - n_tail; j += NR) { + float *local_buffer = buffer + j * k; + for (int i = 0; i < k; ++i) { + const float *b0 = &B(i, j); + asm volatile( + "prfm pldl2keep, [%[b0], #64] \n\t" + "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[b0]] \n\t" + "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[local_buffer]], #64 \n\t" + : [local_buffer] "+r"(local_buffer) + : [b0] "r"(b0) + : "memory", "v0", "v1", "v2", "v3"); + } + } + if (n_tail != 0) { + float *local_buffer = buffer + j_length * k; + for (int i = 0; i < k; ++i) { + const float *b0 = &B(i, j_length); + for (int j = j_length; j < n; ++j) { + *local_buffer++ = *b0++; + } + for (int j = n; j < j_length + NR; ++j) { + *local_buffer++ = 0; } } } @@ -2244,6 +2498,27 @@ void AddDot4x4(int k, const float *a, const float *b, float *c, int ldc) { } } +void AddDot4x8(int k, const float *a, const float *b, float *c, int ldc) {} + +void WriteBasic(int mc, int nc, float *c, float *C, int ldc) {} + +void WriteWithAlphaBeta(int mc, int nc, float *c, float *C, int ldc) {} + +void WriteWithAdd(int mc, int nc, float *c, float *C, int ldc) {} + +void WriteWithAddV1(int mc, int nc, float *c, float *C, int ldc, float *bias) {} + +void WriteWithAddRelu(int mc, int nc, float *c, float *C, int ldc) {} + +void WriteWithAddReluV1(int mc, int nc, float *c, float *C, int ldc, + float *bias) {} + +void WriteWithBn(int mc, int nc, float *c, float *C, int ldc, float *new_scale, + float *new_bias) {} + 
+void WriteWithBnRelu(int mc, int nc, float *c, float *C, int ldc, + float *new_scale, float *new_bias) {} + #endif // __ARM_NEON // 32位 float 矩阵乘法 @@ -2373,6 +2648,221 @@ void SgemmWithBn(int m, int n, int k, float alpha, const float *A, int lda, paddle_mobile::memory::Free(zero); } +// 32位 float 矩阵乘法 +void Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda, + const float *B, int ldb, float beta, float *C, int ldc, + bool relu, float *bias) { +#ifdef _OPENMP + int max_threads = omp_get_max_threads(); +#else + int max_threads = 1; +#endif + + int L1 = 32 * 1024; + KC = k; + if (m > n) { + // 对 A 分块 + MC = L1 / (KC * sizeof(float)); + int mblock_num = (m + MC - 1) / MC; + MC = (m + mblock_num - 1) / mblock_num; + MC = (MC + MR - 1) / MR * MR; + // 补齐 B + NC = (n + NR - 1) / NR * NR; + +#if __aarch64__ + procPackA = PackMatrixA_6r; + procPackB = PackMatrixB_omp_16c; + procAddDot = AddDot6x16; +#else + procPackA = PackMatrixA_6r; + procPackB = PackMatrixB_omp_8c; + procAddDot = AddDot6x8; +#endif + + packedB = static_cast( + paddle_mobile::memory::Alloc(sizeof(float) * KC * NC)); + procPackB(KC, NC, NC % NR, B, ldb, packedB); + packedA = static_cast( + paddle_mobile::memory::Alloc(sizeof(float) * MC * KC * max_threads)); + } else { + // 对 B 分块 + NC = L1 / (KC * sizeof(float)); + int nblock_num = (n + NC - 1) / NC; + NC = (n + nblock_num - 1) / nblock_num; + NC = (NC + NR - 1) / NR * NR; + // 补齐 A + MC = (m + MR - 1) / MR * MR; + +#if __aarch64__ + procPackA = PackMatrixA_omp_6r; + procPackB = PackMatrixB_16c; + procAddDot = AddDot6x16; +#else + procPackA = PackMatrixA_omp_6r; + procPackB = PackMatrixB_8c; + procAddDot = AddDot6x8; +#endif + + packedA = static_cast( + paddle_mobile::memory::Alloc(sizeof(float) * MC * KC)); + procPackA(MC, KC, MC % MR, A, lda, packedA); + packedB = static_cast( + paddle_mobile::memory::Alloc(sizeof(float) * KC * NC * max_threads)); + } + zero = static_cast(paddle_mobile::memory::Alloc(sizeof(float) * KC)); + memset(static_cast(zero), 0, sizeof(float) * KC); + packedC = static_cast( + paddle_mobile::memory::Alloc(sizeof(float) * MC * NC * max_threads)); + + if (m > n) { +#pragma omp parallel for + for (int i = 0; i < m; i += MC) { +#ifdef _OPENMP + int local_threads = omp_get_thread_num(); +#else + int local_threads = 0; +#endif + + int mc; + mc = s_min(m - i, MC); + float *local_A = packedA + MC * KC * local_threads; + float *local_C = packedC + MC * NC * local_threads; + procPackA(mc, KC, mc % MR, &A(i, 0), lda, local_A); + InnerKernelWithBias(mc, n, alpha, local_A, packedB, beta, local_C, + &C(i, 0), ldc, relu, bias + i); + } + } else { +#pragma omp parallel for + for (int j = 0; j < n; j += NC) { +#ifdef _OPENMP + int local_threads = omp_get_thread_num(); +#else + int local_threads = 0; +#endif + + int nc; + nc = s_min(n - j, NC); + float *local_B = packedB + KC * NC * local_threads; + float *local_C = packedC + MC * NC * local_threads; + procPackB(KC, nc, nc % NR, &B(0, j), ldb, local_B); + InnerKernelWithBias(m, nc, alpha, packedA, local_B, beta, local_C, + &C(0, j), ldc, relu, bias); + } + } + + paddle_mobile::memory::Free(packedA); + paddle_mobile::memory::Free(packedB); + paddle_mobile::memory::Free(packedC); + paddle_mobile::memory::Free(zero); +} + +void SgemmWithBn_omp(int m, int n, int k, float alpha, const float *A, int lda, + const float *B, int ldb, float beta, float *C, int ldc, + bool relu, float *new_scale, float *new_bias) { +#ifdef _OPENMP + int max_threads = omp_get_max_threads(); +#else + int max_threads = 1; +#endif + + 
int L1 = 32 * 1024; + KC = k; + if (m > n) { + // 对 A 分块 + MC = L1 / (KC * sizeof(float)); + int mblock_num = (m + MC - 1) / MC; + MC = (m + mblock_num - 1) / mblock_num; + MC = (MC + MR - 1) / MR * MR; + // 补齐 B + NC = (n + NR - 1) / NR * NR; + +#if __aarch64__ + procPackA = PackMatrixA_6r; + procPackB = PackMatrixB_omp_16c; + procAddDot = AddDot6x16; +#else + procPackA = PackMatrixA_6r; + procPackB = PackMatrixB_omp_8c; + procAddDot = AddDot6x8; +#endif + + packedB = static_cast( + paddle_mobile::memory::Alloc(sizeof(float) * KC * NC)); + procPackB(KC, NC, NC % NR, B, ldb, packedB); + packedA = static_cast( + paddle_mobile::memory::Alloc(sizeof(float) * MC * KC * max_threads)); + } else { + // 对 B 分块 + NC = L1 / (KC * sizeof(float)); + int nblock_num = (n + NC - 1) / NC; + NC = (n + nblock_num - 1) / nblock_num; + NC = (NC + NR - 1) / NR * NR; + // 补齐 A + MC = (m + MR - 1) / MR * MR; + +#if __aarch64__ + procPackA = PackMatrixA_omp_6r; + procPackB = PackMatrixB_16c; + procAddDot = AddDot6x16; +#else + procPackA = PackMatrixA_omp_6r; + procPackB = PackMatrixB_8c; + procAddDot = AddDot6x8; +#endif + + packedA = static_cast( + paddle_mobile::memory::Alloc(sizeof(float) * MC * KC)); + procPackA(MC, KC, MC % MR, A, lda, packedA); + packedB = static_cast( + paddle_mobile::memory::Alloc(sizeof(float) * KC * NC * max_threads)); + } + zero = static_cast(paddle_mobile::memory::Alloc(sizeof(float) * KC)); + memset(static_cast(zero), 0, sizeof(float) * KC); + packedC = static_cast( + paddle_mobile::memory::Alloc(sizeof(float) * MC * NC * max_threads)); + + if (m > n) { +#pragma omp parallel for + for (int i = 0; i < m; i += MC) { +#ifdef _OPENMP + int local_threads = omp_get_thread_num(); +#else + int local_threads = 0; +#endif + + int mc; + mc = s_min(m - i, MC); + float *local_A = packedA + MC * KC * local_threads; + float *local_C = packedC + MC * NC * local_threads; + procPackA(mc, KC, mc % MR, &A(i, 0), lda, local_A); + InnerKernelWithBn(mc, n, alpha, local_A, packedB, beta, local_C, &C(i, 0), + ldc, relu, new_scale + i, new_bias + i); + } + } else { +#pragma omp parallel for + for (int j = 0; j < n; j += NC) { +#ifdef _OPENMP + int local_threads = omp_get_thread_num(); +#else + int local_threads = 0; +#endif + + int nc; + nc = s_min(n - j, NC); + float *local_B = packedB + KC * NC * local_threads; + float *local_C = packedC + MC * NC * local_threads; + procPackB(KC, nc, nc % NR, &B(0, j), ldb, local_B); + InnerKernelWithBn(m, nc, alpha, packedA, local_B, beta, local_C, &C(0, j), + ldc, relu, new_scale, new_bias); + } + } + + paddle_mobile::memory::Free(packedA); + paddle_mobile::memory::Free(packedB); + paddle_mobile::memory::Free(packedC); + paddle_mobile::memory::Free(zero); +} + void AddDot6x8(int k, const float *a, const float *b, float *c, int ldc) { #if __ARM_NEON #if __aarch64__ diff --git a/src/operators/math/gemm.h b/src/operators/math/gemm.h index 625fce0323580545c1655c1d3c325f995aa054f2..40199faa4c30ec965a3980f44f1dbb6ae7d6799b 100644 --- a/src/operators/math/gemm.h +++ b/src/operators/math/gemm.h @@ -50,6 +50,10 @@ void PackMatrixA_6r(int m, int k, int m_tail, const float *A, int lda, float *buffer); void PackMatrixA_8r(int m, int k, int m_tail, const float *A, int lda, float *buffer); +void PackMatrixA_omp_6r(int m, int k, int m_tail, const float *A, int lda, + float *buffer); +void PackMatrixA_omp_8r(int m, int k, int m_tail, const float *A, int lda, + float *buffer); // 将 B 矩阵分块复制到连续内存(RowMajor) void PackMatrixB_8c(int k, int n, int n_tail, const float *B, int ldb, @@ -58,6 
+62,12 @@ void PackMatrixB_12c(int k, int n, int n_tail, const float *B, int ldb, float *buffer); void PackMatrixB_16c(int k, int n, int n_tail, const float *B, int ldb, float *buffer); +void PackMatrixB_omp_8c(int k, int n, int n_tail, const float *B, int ldb, + float *buffer); +void PackMatrixB_omp_12c(int k, int n, int n_tail, const float *B, int ldb, + float *buffer); +void PackMatrixB_omp_16c(int k, int n, int n_tail, const float *B, int ldb, + float *buffer); // 分块矩阵乘法 void InnerKernel(int mc, int nc, float alpha, const float *a, const float *b, @@ -136,6 +146,16 @@ void SgemmWithBn(int m, int n, int k, float alpha, const float *A, int lda, const float *B, int ldb, float beta, float *C, int ldc, bool relu, float *new_scale, float *new_bias); +// 32位 float 矩阵乘法(openmp 多线程版本) +void Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda, + const float *B, int ldb, float beta, float *C, int ldc, + bool relu, float *bias); + +// 32位 float 矩阵乘法, 并对结果进行 batchnorm(openmp 多线程版本) +void SgemmWithBn_omp(int m, int n, int k, float alpha, const float *A, int lda, + const float *B, int ldb, float beta, float *C, int ldc, + bool relu, float *new_scale, float *new_bias); + } // namespace math } // namespace operators } // namespace paddle_mobile diff --git a/src/operators/math/math_function.cpp b/src/operators/math/math_function.cpp index 9ac8d79e89b7a577f0a89807dc96c9f368fed6de..381624250af87f4eeff7cf316a2f0f346c399137 100644 --- a/src/operators/math/math_function.cpp +++ b/src/operators/math/math_function.cpp @@ -42,8 +42,13 @@ void matmul(const framework::Tensor &matrix_a, bool trans_a, int N = dim_out[1]; int K = (!trans_a) ? dim_a[1] : dim_a[0]; +#ifdef _OPENMP + Sgemm_omp(M, N, K, alpha, matrix_a.data(), K, matrix_b.data(), + N, beta, matrix_out->data(), N, relu, bias); +#else Sgemm(M, N, K, alpha, matrix_a.data(), K, matrix_b.data(), N, beta, matrix_out->data(), N, relu, bias); +#endif } template <> @@ -70,10 +75,17 @@ void matmulWithBn(const framework::Tensor &matrix_a, bool trans_a, int N = dim_out[1]; int K = (!trans_a) ? dim_a[1] : dim_a[0]; +#ifdef _OPENMP + SgemmWithBn_omp(M, N, K, alpha, matrix_a.data(), K, + matrix_b.data(), N, beta, matrix_out->data(), N, + relu, new_scale->data() + group, + new_bias->data() + group); +#else SgemmWithBn(M, N, K, alpha, matrix_a.data(), K, matrix_b.data(), N, beta, matrix_out->data(), N, relu, new_scale->data() + group, new_bias->data() + group); +#endif } } // namespace math
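Note on the new OpenMP path: `Sgemm_omp` and `SgemmWithBn_omp` size `packedA`/`packedB`/`packedC` for `omp_get_max_threads()` blocks and let each iteration of the parallel loop index its own slice with `omp_get_thread_num()`, so concurrent packing never shares a buffer. The sketch below only illustrates that workspace layout under assumed placeholder values of `MC`/`KC` and a dummy packing loop; it is not paddle-mobile code.

```cpp
// Sketch: one packing workspace per OpenMP thread, as in Sgemm_omp.
#include <algorithm>
#include <cstdio>
#include <vector>
#ifdef _OPENMP
#include <omp.h>
#endif

int main() {
  const int m = 1000;  // total rows to process
  const int MC = 64;   // rows per block (placeholder value)
  const int KC = 128;  // packed width per block (placeholder value)

#ifdef _OPENMP
  const int max_threads = omp_get_max_threads();
#else
  const int max_threads = 1;
#endif

  // One MC x KC slice per thread, mirroring
  // memory::Alloc(sizeof(float) * MC * KC * max_threads) in the diff.
  std::vector<float> packed(static_cast<size_t>(MC) * KC * max_threads);

#pragma omp parallel for
  for (int i = 0; i < m; i += MC) {
#ifdef _OPENMP
    const int tid = omp_get_thread_num();
#else
    const int tid = 0;
#endif
    // Each iteration packs into its thread's private slice, so no two
    // threads ever write the same region of `packed`.
    float *local = packed.data() + static_cast<size_t>(MC) * KC * tid;
    const int mc = std::min(MC, m - i);
    for (int p = 0; p < mc * KC; ++p) {
      local[p] = static_cast<float>(i + p);  // stand-in for PackMatrixA_6r
    }
  }
  std::printf("packed %d row blocks on up to %d threads\n", (m + MC - 1) / MC,
              max_threads);
  return 0;
}
```

The diff applies the same idea in both branches of `if (m > n)`: whichever matrix is packed per block gets the per-thread slices, while the other operand is packed once up front.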
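On the rewritten `DepthwiseConvAddBNRelu3x3s1p1`: the interior loops now produce four horizontally adjacent outputs per iteration by loading two overlapping `float32x4_t` registers per input row and deriving the shifted windows with `vextq_f32` instead of reloading. A compressed, hypothetical helper showing just that addressing trick for one interior row is sketched below (it assumes an ARM target with NEON and that column `j + 6` is still inside the row; the real kernel additionally folds in the batch-norm scale/bias and optional ReLU).

```cpp
#include <arm_neon.h>

// Hypothetical helper: four adjacent 3x3 outputs for one interior row.
// row0/row1/row2 point to the three input rows, w holds the 9 filter taps
// in row-major order, and results are written to out[j..j+3].
static inline void Conv3x3Row4(const float *row0, const float *row1,
                               const float *row2, const float *w, float *out,
                               int j) {
  float32x4_t acc = vdupq_n_f32(0.0f);
  const float *rows[3] = {row0, row1, row2};
  for (int r = 0; r < 3; ++r) {
    float32x4_t lo = vld1q_f32(rows[r] + j - 1);  // columns j-1 .. j+2
    float32x4_t hi = vld1q_f32(rows[r] + j + 3);  // columns j+3 .. j+6
    acc = vmlaq_n_f32(acc, lo, w[r * 3 + 0]);                    // left tap
    acc = vmlaq_n_f32(acc, vextq_f32(lo, hi, 1), w[r * 3 + 1]);  // center tap
    acc = vmlaq_n_f32(acc, vextq_f32(lo, hi, 2), w[r * 3 + 2]);  // right tap
  }
  vst1q_f32(out + j, acc);  // the real kernel applies scale/bias/ReLU here
}
```

This is also why the kernel keeps a separate scalar tail loop for the last few columns: the vector path consumes four outputs at a time plus a column of look-ahead per load.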