提交 2fc1a20a 编写于 作者: 朔-望's avatar 朔-望 提交者: GitHub

Merge branch 'develop' into develop

......@@ -2,7 +2,7 @@ cmake_minimum_required(VERSION 3.0)
project(paddle-mobile)
option(DEBUGING "enable debug mode" ON)
option(USE_OPENMP "openmp support" OFF)
option(USE_OPENMP "openmp support" ON)
option(USE_EXCEPTION "use std exception" ON)
option(LOG_PROFILE "log profile" ON)
# select the platform to build
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "io/executor.h"
#include <operators/math/gemm.h>
#include <algorithm>
#include <vector>
#include "common/enforce.h"
......@@ -25,6 +26,9 @@ limitations under the License. */
#include "framework/program/var_desc.h"
#include "framework/scope.h"
#include "framework/tensor.h"
#ifdef _OPENMP
#include <omp.h>
#endif // _OPENMP
#ifdef PADDLE_EXECUTOR_MULTITHREAD
#include <queue>
#include <utility>
......@@ -403,6 +407,17 @@ std::vector<typename Executor<Dtype, P>::Ptype> Executor<Dtype, P>::Predict(
return result_vector;
}
template <typename Dtype, Precision P>
void Executor<Dtype, P>::SetThreadNum(int num) {
for (int k = 0; k < std::max(num, 3); ++k) {
operators::math::Gemmer::gemmers.push_back(new operators::math::Gemmer());
}
#ifdef _OPENMP
// omp_set_dynamic(0);
omp_set_num_threads(num);
#endif
}
template class Executor<CPU, Precision::FP32>;
template class Executor<FPGA, Precision::FP32>;
template class Executor<GPU_MALI, Precision::FP32>;
......
......@@ -58,6 +58,8 @@ class Executor {
std::vector<Ptype> Predict(const std::vector<Ptype> &input,
const std::vector<int64_t> &dims);
void SetThreadNum(int num);
protected:
Executor() = default;
void InitMemory();
......
......@@ -79,11 +79,11 @@ class FusionConvAddBNReluOp
#ifdef PADDLE_MOBILE_CPU
//#ifndef FUSION_CONV_ADD_BN_RELU_REGISTER
// static framework::FusionOpRegistrar fusion_conv_add_bn_relu_registrar(
// new FusionConvAddBNReluMatcher());
//#define FUSION_CONV_ADD_BN_RELU_REGISTER
//#endif
#ifndef FUSION_CONV_ADD_BN_RELU_REGISTER
static framework::FusionOpRegistrar fusion_conv_add_bn_relu_registrar(
new FusionConvAddBNReluMatcher());
#define FUSION_CONV_ADD_BN_RELU_REGISTER
#endif
#endif
......
......@@ -14,10 +14,14 @@ limitations under the License. */
#ifdef FUSION_CONVADD_OP
#pragma once
#if _OPENMP
#include <omp.h>
#endif
#include <vector>
#include "operators/math/conv_func.h"
#include "operators/math/depthwise_conv_3x3.h"
#include "operators/math/gemm.h"
#include "operators/math/im2col.h"
#include "operators/math/math_function.h"
#include "operators/math/vol2col.h"
......@@ -106,9 +110,33 @@ void ConvAddBasic(const FusionConvAddParam &param) {
// gemm
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
math::matmul<float>(filter_slice, false, col_matrix, false,
static_cast<float>(1), &out_slice,
static_cast<float>(1));
auto dim_a = filter_slice.dims();
auto dim_b = col_matrix.dims();
auto dim_out = out_slice.dims();
int m = dim_out[0];
int n = dim_out[1];
int k = dim_a[1];
float *output_data = out_slice.data<float>();
int thread_num = 4;
int m1 = m / thread_num;
int m2 = m % thread_num;
#pragma omp parallel for
for (int j = 0; j < thread_num; ++j) {
int row_count = m1;
if (j == thread_num - 1) {
row_count = m1 + m2;
}
math::Gemmer::gemmers[j]->Sgemm(
row_count, n, k, 1, filter_slice.data<float>() + j * m1 * k, k,
col_matrix.data<float>(), n, 1, output_data + j * m1 * n, n, false);
}
// math::matmul<float>(filter_slice, false, col_matrix, false,
// static_cast<float>(1), &out_slice,
// static_cast<float>(1));
}
}
}
......@@ -124,9 +152,15 @@ void ConvAddCompute(const FusionConvAddParam &param) {
} else if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3) {
math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(),
param.Filter(), param.Bias(), param.Output(), true);
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) {
// math::DepthwiseConv3x3(param.Input(), param.Strides(),
// param.Paddings(),
// param.Filter(), param.Bias(),
// param.Output(), false);
math::DepthwiseConv3x3s2p1v2(param.Input(), param.Filter(), param.Output(),
*param.Bias(), true);
} else {
ConvAddBasic(param);
}
......
......@@ -26,8 +26,6 @@ void ConvAddBNReluBasic(const FusionConvAddBNReluParam &param) {
Tensor bias = *param.Bias();
Tensor new_bias = *param.NewBias();
Tensor new_scale = *param.NewScale();
auto new_bias_ptr = new_bias.data<float>();
auto new_scale_ptr = new_scale.data<float>();
int axis = param.Axis();
Tensor *output = param.Output();
math::expand_bias(bias, axis, output->dims());
......@@ -106,20 +104,10 @@ void ConvAddBNReluBasic(const FusionConvAddBNReluParam &param) {
// gemm
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
math::matmul<float>(filter_slice, false, col_matrix, false,
static_cast<float>(1), &out_slice,
static_cast<float>(0));
}
}
/// todo : use neon in special case instead of 2for(300ms)
auto output_ptr = output->data<float>();
for (int c = 0; c < output_matrix_shape[0]; c++) {
int start = c * output_matrix_shape[1];
for (int j = 0; j < output_matrix_shape[1]; j++) {
output_ptr[start + j] =
output_ptr[start + j] * new_scale_ptr[c] + new_bias_ptr[c];
output_ptr[start + j] =
output_ptr[start + j] < 0 ? 0 : output_ptr[start + j];
math::matmulWithBn<float>(
filter_slice, false, col_matrix, false, static_cast<float>(1),
&out_slice, static_cast<float>(0), true, &new_scale, &new_bias);
}
}
}
......@@ -138,9 +126,12 @@ void ConvAddBNReluCompute(const FusionConvAddBNReluParam &param) {
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) {
math::DepthwiseConvAddBNRelu3x3s2p1(param.Input(), param.Filter(),
// math::DepthwiseConvAddBNRelu3x3s2p1(param.Input(), param.Filter(),
// param.Output(), param.NewScale(),
// param.NewBias(), 1);
math::DepthwiseConvAddBNRelu3x3s2p1v2(param.Input(), param.Filter(),
param.Output(), param.NewScale(),
param.NewBias(), 1);
param.NewBias(), true);
} else {
ConvAddBNReluBasic(param);
}
......
......@@ -37,8 +37,12 @@ void DepthwiseConvCompute(const ConvParam &param) {
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) {
math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(),
param.Filter(), &Bias, param.Output(), false);
// math::DepthwiseConv3x3(param.Input(), param.Strides(),
// param.Paddings(),
// param.Filter(), &Bias, param.Output(), false);
math::DepthwiseConv3x3s2p1v2(param.Input(), param.Filter(), param.Output(),
Bias, false);
} else {
ConvBasic(param);
}
......
......@@ -13,7 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */
#ifdef LRN_OP
#ifdef _OPENMP
#include <omp.h>
#endif
#include "framework/operator.h"
#include "operators/op_param.h"
......@@ -47,6 +49,7 @@ struct LRNFunctor {
std::fill(sqr_buffer_ptr, sqr_buffer_ptr + sqr_buffer.numel(), 0.0);
for (int a = 0; a < N; a++) {
#pragma parallel for
for (int b = 0; b < C; b++) {
for (int index = start; index < end; index++) {
int channel = b + index;
......
......@@ -1010,6 +1010,442 @@ void DepthwiseConvAddBNRelu3x3s2p1(const Tensor *input, const Tensor *filter,
output_data += output_batch_stride;
}
}
void DepthwiseConv3x3s2p1v2(const Tensor *input, const Tensor *filter,
Tensor *output, Tensor bias, bool if_bias) {
const float *input_data = input->data<float>();
const float *filter_data = filter->data<float>();
float *output_data = output->data<float>();
const float *bias_data = bias.data<float>();
const int in_h = static_cast<int>(input->dims()[2]);
const int in_w = static_cast<int>(input->dims()[3]);
const int out_h = static_cast<int>(output->dims()[2]);
const int out_w = static_cast<int>(output->dims()[3]);
const int out_l = out_h;
const int in_l = in_h;
const int inhxw = in_h * in_w;
const int outhxw = out_h * out_w;
const int if_pad = in_l - 1 == (out_l - 1) * 2 ? 1 : 0;
const int batch_size = static_cast<int>(input->dims()[0]);
const int c = static_cast<int>(input->dims()[1]);
const float *input_row_ptr;
float *output_row_ptr;
const int w_times = (out_w - 2) / 3;
float32x4_t vbias = vdupq_n_f32(0.0);
float32x4x2_t input_buff_mid{}, input_buff_bottom[w_times + 1];
float32x4_t elewise_res0, elewise_res1, elewise_res2, res3;
int out2in_mid;
float32x4_t zero = vdupq_n_f32(0.0);
for (int b = batch_size; b > 0; --b) {
const float *filter_data_tmp = filter_data;
for (int j = 0; j < c; ++j) {
auto output_data_tmp = output_data + j * out_h * out_w;
auto input_data_tmp = input_data + j * in_h * in_w;
auto input_const = input_data_tmp;
if (if_bias) {
vbias = vdupq_n_f32(bias_data[j]);
}
float w00 = filter_data_tmp[0];
float w01 = filter_data_tmp[1];
float w02 = filter_data_tmp[2];
float w10 = filter_data_tmp[3];
float w11 = filter_data_tmp[4];
float w12 = filter_data_tmp[5];
float w20 = filter_data_tmp[6];
float w21 = filter_data_tmp[7];
float w22 = filter_data_tmp[8];
int h_mid = 0;
for (; h_mid < out_h - 1; h_mid++) {
input_row_ptr = input_data_tmp + 1 + h_mid * 2 * in_w;
output_row_ptr = output_data_tmp + 1 + h_mid * out_w;
for (int w4 = 0; w4 < w_times + 1; w4++) {
if (h_mid == 0) {
elewise_res1 = zero;
elewise_res0 = zero;
elewise_res2 = zero;
} else {
elewise_res1 = vmulq_n_f32(input_buff_bottom[w4].val[1], w01);
elewise_res0 = vmulq_n_f32(input_buff_bottom[w4].val[0], w00);
elewise_res2 = vmulq_n_f32(input_buff_bottom[w4].val[0], w02);
}
input_buff_mid = vld2q_f32(input_row_ptr);
input_buff_bottom[w4] = vld2q_f32(input_row_ptr + in_w);
elewise_res1 = vmlaq_n_f32(elewise_res1, input_buff_mid.val[1], w11);
elewise_res0 = vmlaq_n_f32(elewise_res0, input_buff_mid.val[0], w10);
elewise_res2 = vmlaq_n_f32(elewise_res2, input_buff_mid.val[0], w12);
elewise_res1 =
vmlaq_n_f32(elewise_res1, input_buff_bottom[w4].val[1], w21);
elewise_res0 =
vmlaq_n_f32(elewise_res0, input_buff_bottom[w4].val[0], w20);
elewise_res2 =
vmlaq_n_f32(elewise_res2, input_buff_bottom[w4].val[0], w22);
res3 = vaddq_f32(vextq_f32(elewise_res2, zero, 1),
vaddq_f32(elewise_res0, elewise_res1));
res3 = vaddq_f32(res3, vbias);
vst1q_f32(output_row_ptr, res3);
input_row_ptr += 6;
output_row_ptr += 3;
}
}
clock();
input_row_ptr = input_data_tmp + 1 + h_mid * 2 * in_w;
output_row_ptr = output_data_tmp + 1 + h_mid * out_w;
for (int w4 = 0; w4 < w_times + 1; w4++) {
elewise_res1 = vmulq_n_f32(input_buff_bottom[w4].val[1], w01);
elewise_res0 = vmulq_n_f32(input_buff_bottom[w4].val[0], w00);
elewise_res2 = vmulq_n_f32(input_buff_bottom[w4].val[0], w02);
input_buff_mid = vld2q_f32(input_row_ptr);
input_buff_bottom[w4] = vld2q_f32(input_row_ptr + in_w);
elewise_res1 = vmlaq_n_f32(elewise_res1, input_buff_mid.val[1], w11);
elewise_res0 = vmlaq_n_f32(elewise_res0, input_buff_mid.val[0], w10);
elewise_res2 = vmlaq_n_f32(elewise_res2, input_buff_mid.val[0], w12);
if (!if_pad) {
elewise_res1 =
vmlaq_n_f32(elewise_res1, input_buff_bottom[w4].val[1], w21);
elewise_res0 =
vmlaq_n_f32(elewise_res0, input_buff_bottom[w4].val[0], w20);
elewise_res2 =
vmlaq_n_f32(elewise_res2, input_buff_bottom[w4].val[0], w22);
}
res3 = vaddq_f32(vextq_f32(elewise_res2, zero, 1),
vaddq_f32(elewise_res0, elewise_res1));
res3 = vaddq_f32(res3, vbias);
if ((w4 != w_times)) {
vst1q_f32(output_row_ptr, res3);
} else {
if (out_l - 2 - w_times * 3 == 1) {
vst1q_lane_f32(output_row_ptr, res3, 0);
} else if (out_l - 2 - w_times * 3 == 2) {
vst1q_lane_f32(output_row_ptr, res3, 0);
vst1q_lane_f32(output_row_ptr + 1, res3, 1);
}
}
input_row_ptr += 6;
output_row_ptr += 3;
}
output_data_tmp[0] = input_const[0] * w11 + input_const[1] * w12 +
input_const[in_l] * w21 +
input_const[in_l + 1] * w22;
out2in_mid = (out_l - 1) * 2;
output_data_tmp[out_l - 1] =
w10 * input_const[out2in_mid - 1] + w11 * input_const[out2in_mid] +
w20 * input_const[out2in_mid + in_w - 1] +
w21 * input_const[out2in_mid + in_w] +
(1 - if_pad) * (w12 * input_const[out2in_mid + 1] +
w22 * input_const[out2in_mid + in_w + 1]);
out2in_mid = (out_l - 1) * 2 * in_w;
output_data_tmp[out_l * (out_l - 1)] =
w01 * input_const[out2in_mid - in_w] +
w02 * input_const[out2in_mid - in_w + 1] +
w11 * input_const[out2in_mid] + w12 * input_const[out2in_mid + 1] +
(1 - if_pad) * (w21 * input_const[out2in_mid + in_w] +
w22 * input_const[out2in_mid + in_w + 1]);
out2in_mid = (out_l - 1) * 2 * in_w + (out_l - 1) * 2;
output_data_tmp[out_l * out_l - 1] =
w00 * input_const[out2in_mid - in_w - 1] +
w01 * input_const[out2in_mid - in_w] +
w10 * input_const[out2in_mid - 1] + w11 * input_const[out2in_mid] +
(1 - if_pad) * (w20 * input_const[out2in_mid + in_w - 1] +
w21 * input_const[out2in_mid + in_w] +
w02 * input_const[out2in_mid - in_w + 1] +
w12 * input_const[out2in_mid + 1] +
w22 * input_const[out2in_mid + in_w + 1]);
if (if_bias) {
output_data_tmp[0] += bias_data[j];
output_data_tmp[out_l - 1] += bias_data[j];
output_data_tmp[out_l * (out_l - 1)] += bias_data[j];
output_data_tmp[out_l * out_l - 1] += bias_data[j];
}
for (int i = 1; i < out_h - 1; i++) {
out2in_mid = i * 2 * in_w;
output_data_tmp[i * out_l] = w01 * input_const[out2in_mid - in_w] +
w02 * input_const[out2in_mid - in_w + 1] +
w11 * input_const[out2in_mid] +
w12 * input_const[out2in_mid + 1] +
w21 * input_const[out2in_mid + in_w] +
w22 * input_const[out2in_mid + in_w + 1];
out2in_mid = i * 2 * in_w + (out_l - 1) * 2;
output_data_tmp[i * out_l + out_l - 1] =
w00 * input_const[out2in_mid - in_w - 1] +
w01 * input_const[out2in_mid - in_w] +
w10 * input_const[out2in_mid - 1] + w11 * input_const[out2in_mid] +
w20 * input_const[out2in_mid + in_w - 1] +
w21 * input_const[out2in_mid + in_w] +
(1 - if_pad) * (w02 * input_const[out2in_mid - in_w + 1] +
w12 * input_const[out2in_mid + 1] +
w22 * input_const[out2in_mid + in_w + 1]);
if (if_bias) {
output_data_tmp[i * out_l] += bias_data[j];
output_data_tmp[i * out_l + out_l - 1] += bias_data[j];
}
}
filter_data_tmp += 9;
}
input_data += inhxw * c;
output_data += outhxw * c;
}
}
void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter,
Tensor *output, const Tensor *new_scale,
const Tensor *new_bias, bool if_relu) {
const float *input_data = input->data<float>();
const float *filter_data = filter->data<float>();
float *output_data = output->data<float>();
const float *newscale_data = new_scale->data<float>();
const float *newbias_data = new_bias->data<float>();
float32x4_t vnewbias = vdupq_n_f32(0.0);
float32x4_t vnewscale = vdupq_n_f32(1.0);
const int in_h = static_cast<int>(input->dims()[2]);
const int in_w = static_cast<int>(input->dims()[3]);
const int out_h = static_cast<int>(output->dims()[2]);
const int out_w = static_cast<int>(output->dims()[3]);
const int out_l = out_h;
const int in_l = in_h;
const int inhxw = in_h * in_w;
const int outhxw = out_h * out_w;
const int if_pad = in_l - 1 == (out_l - 1) * 2 ? 1 : 0;
const int batch_size = static_cast<int>(input->dims()[0]);
const int c = static_cast<int>(input->dims()[1]);
const float *input_row_ptr;
float *output_row_ptr;
const int w_times = (out_w - 2) / 3;
float32x4x2_t input_buff_mid{}, input_buff_bottom[w_times + 1];
float32x4_t elewise_res0, elewise_res1, elewise_res2, res3;
int out2in_mid;
float32x4_t zero = vdupq_n_f32(0.0);
for (int b = batch_size; b > 0; --b) {
const float *filter_data_tmp = filter_data;
for (int j = 0; j < c; ++j) {
auto output_data_tmp = output_data + j * out_h * out_w;
auto input_data_tmp = input_data + j * in_h * in_w;
auto input_const = input_data_tmp;
vnewbias = vdupq_n_f32(newbias_data[j]);
vnewscale = vdupq_n_f32(newscale_data[j]);
float w00 = filter_data_tmp[0];
float w01 = filter_data_tmp[1];
float w02 = filter_data_tmp[2];
float w10 = filter_data_tmp[3];
float w11 = filter_data_tmp[4];
float w12 = filter_data_tmp[5];
float w20 = filter_data_tmp[6];
float w21 = filter_data_tmp[7];
float w22 = filter_data_tmp[8];
int h_mid = 0;
for (; h_mid < out_h - 1; h_mid++) {
input_row_ptr = input_data_tmp + 1 + h_mid * 2 * in_w;
output_row_ptr = output_data_tmp + 1 + h_mid * out_w;
for (int w4 = 0; w4 < w_times + 1; w4++) {
if (h_mid == 0) {
elewise_res1 = zero;
elewise_res0 = zero;
elewise_res2 = zero;
} else {
elewise_res1 = vmulq_n_f32(input_buff_bottom[w4].val[1], w01);
elewise_res0 = vmulq_n_f32(input_buff_bottom[w4].val[0], w00);
elewise_res2 = vmulq_n_f32(input_buff_bottom[w4].val[0], w02);
}
input_buff_mid = vld2q_f32(input_row_ptr);
input_buff_bottom[w4] = vld2q_f32(input_row_ptr + in_w);
elewise_res1 = vmlaq_n_f32(elewise_res1, input_buff_mid.val[1], w11);
elewise_res0 = vmlaq_n_f32(elewise_res0, input_buff_mid.val[0], w10);
elewise_res2 = vmlaq_n_f32(elewise_res2, input_buff_mid.val[0], w12);
elewise_res1 =
vmlaq_n_f32(elewise_res1, input_buff_bottom[w4].val[1], w21);
elewise_res0 =
vmlaq_n_f32(elewise_res0, input_buff_bottom[w4].val[0], w20);
elewise_res2 =
vmlaq_n_f32(elewise_res2, input_buff_bottom[w4].val[0], w22);
res3 = vaddq_f32(vextq_f32(elewise_res2, zero, 1),
vaddq_f32(elewise_res0, elewise_res1));
res3 = vmlaq_f32(vnewbias, vnewscale, res3);
if (if_relu) {
res3 = vmaxq_f32(res3, zero);
}
vst1q_f32(output_row_ptr, res3);
input_row_ptr += 6;
output_row_ptr += 3;
}
}
clock();
input_row_ptr = input_data_tmp + 1 + h_mid * 2 * in_w;
output_row_ptr = output_data_tmp + 1 + h_mid * out_w;
for (int w4 = 0; w4 < w_times + 1; w4++) {
elewise_res1 = vmulq_n_f32(input_buff_bottom[w4].val[1], w01);
elewise_res0 = vmulq_n_f32(input_buff_bottom[w4].val[0], w00);
elewise_res2 = vmulq_n_f32(input_buff_bottom[w4].val[0], w02);
input_buff_mid = vld2q_f32(input_row_ptr);
input_buff_bottom[w4] = vld2q_f32(input_row_ptr + in_w);
elewise_res1 = vmlaq_n_f32(elewise_res1, input_buff_mid.val[1], w11);
elewise_res0 = vmlaq_n_f32(elewise_res0, input_buff_mid.val[0], w10);
elewise_res2 = vmlaq_n_f32(elewise_res2, input_buff_mid.val[0], w12);
if (!if_pad) {
elewise_res1 =
vmlaq_n_f32(elewise_res1, input_buff_bottom[w4].val[1], w21);
elewise_res0 =
vmlaq_n_f32(elewise_res0, input_buff_bottom[w4].val[0], w20);
elewise_res2 =
vmlaq_n_f32(elewise_res2, input_buff_bottom[w4].val[0], w22);
}
res3 = vaddq_f32(vextq_f32(elewise_res2, zero, 1),
vaddq_f32(elewise_res0, elewise_res1));
res3 = vmlaq_f32(vnewbias, vnewscale, res3);
if (if_relu) {
res3 = vmaxq_f32(res3, zero);
}
if ((w4 != w_times)) {
vst1q_f32(output_row_ptr, res3);
} else {
if (out_l - 2 - w_times * 3 == 1) {
vst1q_lane_f32(output_row_ptr, res3, 0);
} else if (out_l - 2 - w_times * 3 == 2) {
vst1q_lane_f32(output_row_ptr, res3, 0);
vst1q_lane_f32(output_row_ptr + 1, res3, 1);
}
}
input_row_ptr += 6;
output_row_ptr += 3;
}
output_data_tmp[0] = input_const[0] * w11 + input_const[1] * w12 +
input_const[in_l] * w21 +
input_const[in_l + 1] * w22;
out2in_mid = (out_l - 1) * 2;
output_data_tmp[out_l - 1] =
w10 * input_const[out2in_mid - 1] + w11 * input_const[out2in_mid] +
w20 * input_const[out2in_mid + in_w - 1] +
w21 * input_const[out2in_mid + in_w] +
(1 - if_pad) * (w12 * input_const[out2in_mid + 1] +
w22 * input_const[out2in_mid + in_w + 1]);
out2in_mid = (out_l - 1) * 2 * in_w;
output_data_tmp[out_l * (out_l - 1)] =
w01 * input_const[out2in_mid - in_w] +
w02 * input_const[out2in_mid - in_w + 1] +
w11 * input_const[out2in_mid] + w12 * input_const[out2in_mid + 1] +
(1 - if_pad) * (w21 * input_const[out2in_mid + in_w] +
w22 * input_const[out2in_mid + in_w + 1]);
out2in_mid = (out_l - 1) * 2 * in_w + (out_l - 1) * 2;
output_data_tmp[out_l * out_l - 1] =
w00 * input_const[out2in_mid - in_w - 1] +
w01 * input_const[out2in_mid - in_w] +
w10 * input_const[out2in_mid - 1] + w11 * input_const[out2in_mid] +
(1 - if_pad) * (w20 * input_const[out2in_mid + in_w - 1] +
w21 * input_const[out2in_mid + in_w] +
w02 * input_const[out2in_mid - in_w + 1] +
w12 * input_const[out2in_mid + 1] +
w22 * input_const[out2in_mid + in_w + 1]);
output_data_tmp[0] =
output_data_tmp[0] * newscale_data[j] + newbias_data[j];
output_data_tmp[out_l - 1] =
output_data_tmp[out_l - 1] * newscale_data[j] + newbias_data[j];
output_data_tmp[out_l * (out_l - 1)] =
output_data_tmp[out_l * (out_l - 1)] * newscale_data[j] +
newbias_data[j];
output_data_tmp[out_l * out_l - 1] =
output_data_tmp[out_l * out_l - 1] * newscale_data[j] +
newbias_data[j];
if (if_relu) {
output_data_tmp[0] = output_data_tmp[0] < 0 ? 0 : output_data_tmp[0];
output_data_tmp[out_l - 1] =
output_data_tmp[out_l - 1] < 0 ? 0 : output_data_tmp[out_l - 1];
output_data_tmp[out_l * (out_l - 1)] =
output_data_tmp[out_l * (out_l - 1)] < 0
? 0
: output_data_tmp[out_l * (out_l - 1)];
output_data_tmp[out_l * out_l - 1] =
output_data_tmp[out_l * out_l - 1] < 0
? 0
: output_data_tmp[out_l * out_l - 1];
}
for (int i = 1; i < out_h - 1; i++) {
out2in_mid = i * 2 * in_w;
output_data_tmp[i * out_l] = w01 * input_const[out2in_mid - in_w] +
w02 * input_const[out2in_mid - in_w + 1] +
w11 * input_const[out2in_mid] +
w12 * input_const[out2in_mid + 1] +
w21 * input_const[out2in_mid + in_w] +
w22 * input_const[out2in_mid + in_w + 1];
out2in_mid = i * 2 * in_w + (out_l - 1) * 2;
output_data_tmp[i * out_l + out_l - 1] =
w00 * input_const[out2in_mid - in_w - 1] +
w01 * input_const[out2in_mid - in_w] +
w10 * input_const[out2in_mid - 1] + w11 * input_const[out2in_mid] +
w20 * input_const[out2in_mid + in_w - 1] +
w21 * input_const[out2in_mid + in_w] +
(1 - if_pad) * (w02 * input_const[out2in_mid - in_w + 1] +
w12 * input_const[out2in_mid + 1] +
w22 * input_const[out2in_mid + in_w + 1]);
output_data_tmp[i * out_l] =
output_data_tmp[i * out_l] * newscale_data[j] + newbias_data[j];
output_data_tmp[i * out_l + out_l - 1] =
output_data_tmp[i * out_l + out_l - 1] * newscale_data[j] +
newbias_data[j];
if (if_relu) {
output_data_tmp[i * out_l] =
output_data_tmp[i * out_l] < 0 ? 0 : output_data_tmp[i * out_l];
output_data_tmp[i * out_l + out_l - 1] =
output_data_tmp[i * out_l + out_l - 1] < 0
? 0
: output_data_tmp[i * out_l + out_l - 1];
}
}
filter_data_tmp += 9;
}
input_data += inhxw * c;
output_data += outhxw * c;
}
}
} // namespace math
} // namespace operators
} // namespace paddle_mobile
......@@ -38,6 +38,11 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter,
void DepthwiseConvAddBNRelu3x3s2p1(const Tensor *input, const Tensor *filter,
Tensor *output, const Tensor *new_scale,
const Tensor *new_bias, bool if_relu);
void DepthwiseConv3x3s2p1v2(const Tensor *input, const Tensor *filter,
Tensor *output, Tensor bias, bool if_bias);
void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter,
Tensor *output, const Tensor *new_scale,
const Tensor *new_bias, bool if_relu);
} // namespace math
} // namespace operators
} // namespace paddle_mobile
......@@ -22,16 +22,10 @@ limitations under the License. */
namespace paddle_mobile {
namespace operators {
namespace math {
int MC = 0;
int KC = 0;
int NC = 0;
float *packedA;
float *packedB;
float *packedC;
float *zero;
std::vector<Gemmer *> Gemmer::gemmers;
// 将A矩阵分块复制到连续内存(ColMajor)
void PackMatrixA(int m, int k, int m_tail, const float *A, int lda,
void Gemmer::PackMatrixA(int m, int k, int m_tail, const float *A, int lda,
float *buffer) {
int i, j;
const float *Aij;
......@@ -58,7 +52,7 @@ void PackMatrixA(int m, int k, int m_tail, const float *A, int lda,
}
// 将A矩阵分块复制到连续内存(RowMajor)
void PackMatrixA_(int m, int k, int m_tail, const float *A, int lda,
void Gemmer::PackMatrixA_(int m, int k, int m_tail, const float *A, int lda,
float *buffer) {
const float *a0, *a1, *a2, *a3;
for (int i = 0; i < m - m_tail; i += MR) {
......@@ -98,7 +92,7 @@ void PackMatrixA_(int m, int k, int m_tail, const float *A, int lda,
}
// 将B矩阵分块复制到连续内存(ColMajor)
void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
void Gemmer::PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
float *buffer) {
int i, j;
const float *Bj, *Bj1, *Bj2, *Bj3;
......@@ -127,7 +121,7 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
}
// 将B矩阵分块复制到连续内存(RowMajor)
void PackMatrixB_(int k, int n, int n_tail, const float *B, int ldb,
void Gemmer::PackMatrixB_(int k, int n, int n_tail, const float *B, int ldb,
float *buffer) {
const float *b0;
for (int j = 0; j < n - n_tail; j += NR) {
......@@ -156,8 +150,9 @@ void PackMatrixB_(int k, int n, int n_tail, const float *B, int ldb,
}
// 分块矩阵乘法
void InnerKernel(int mc, int nc, float alpha, const float *a, const float *b,
float beta, float *c, float *C, int ldc, bool relu) {
void Gemmer::InnerKernel(int mc, int nc, float alpha, const float *a,
const float *b, float beta, float *c, float *C,
int ldc, bool relu) {
for (int j = 0; j < nc; j += NR) {
for (int i = 0; i < mc; i += MR) {
// AddDot4x4(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
......@@ -184,9 +179,10 @@ void InnerKernel(int mc, int nc, float alpha, const float *a, const float *b,
}
// 分块矩阵乘法
void InnerKernelWithBn(int mc, int nc, float alpha, const float *a,
const float *b, float beta, float *c, float *C, int ldc,
bool relu, float *new_scale, float *new_bias) {
void Gemmer::InnerKernelWithBn(int mc, int nc, float alpha, const float *a,
const float *b, float beta, float *c, float *C,
int ldc, bool relu, float *new_scale,
float *new_bias) {
for (int j = 0; j < nc; j += NR) {
for (int i = 0; i < mc; i += MR) {
// AddDot4x4(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
......@@ -202,7 +198,8 @@ void InnerKernelWithBn(int mc, int nc, float alpha, const float *a,
}
#if defined(IOS)
void AddDot4x4(int k, const float *a, const float *b, float *C, int ldc) {
void Gemmer::AddDot4x4(int k, const float *a, const float *b, float *C,
int ldc) {
// init C
float32x4_t cv0 = vdupq_n_f32(0.0);
float32x4_t cv1 = vdupq_n_f32(0.0);
......@@ -253,7 +250,8 @@ void AddDot4x4(int k, const float *a, const float *b, float *C, int ldc) {
} // namespace math
#elif defined(ARMV7)
void AddDot4x4(int k, const float *a, const float *b, float *c, int ldc) {
void Gemmer::AddDot4x4(int k, const float *a, const float *b, float *c,
int ldc) {
const float *a_ptr, *b_ptr;
a_ptr = a;
b_ptr = b;
......@@ -324,7 +322,8 @@ void AddDot4x4(int k, const float *a, const float *b, float *c, int ldc) {
}
#else
void AddDot4x4(int k, const float *a, const float *b, float *c, int ldc) {
void Gemmer::AddDot4x4(int k, const float *a, const float *b, float *c,
int ldc) {
float *c0, *c1, *c2, *c3;
c0 = c;
c1 = c + ldc;
......@@ -363,8 +362,9 @@ void AddDot4x4(int k, const float *a, const float *b, float *c, int ldc) {
#endif
// 32位 float 矩阵乘法
void Sgemm(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc, bool relu) {
void Gemmer::Sgemm(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc,
bool relu) {
// L1 data cache is 32 kib (Per Contex-A57, Contex-A72, Contex-A73)
// L2 cache is 0.5~4 Mib (Contex-A72 cluster)
int L1 = 30 * 1024;
......@@ -415,9 +415,10 @@ void Sgemm(int m, int n, int k, float alpha, const float *A, int lda,
paddle_mobile::memory::Free(zero);
}
void SgemmWithBn(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc,
bool relu, float *new_scale, float *new_bias) {
void Gemmer::SgemmWithBn(int m, int n, int k, float alpha, const float *A,
int lda, const float *B, int ldb, float beta, float *C,
int ldc, bool relu, float *new_scale,
float *new_bias) {
// L1 data cache is 32 kib (Per Contex-A57, Contex-A72, Contex-A73)
// L2 cache is 0.5~4 Mib (Contex-A72 cluster)
int L1 = 30 * 1024;
......@@ -458,8 +459,7 @@ void SgemmWithBn(int m, int n, int k, float alpha, const float *A, int lda,
mc = s_min(m - i, MC);
PackMatrixA_(mc, KC, mc % MR, &A(i, 0), lda, packedA);
InnerKernelWithBn(mc, nc, alpha, packedA, packedB, beta, packedC,
&C(i, j), ldc, relu, new_scale + ldc * i + j,
new_bias + ldc * i + j);
&C(i, j), ldc, relu, new_scale + i, new_bias + i);
}
}
......@@ -469,9 +469,9 @@ void SgemmWithBn(int m, int n, int k, float alpha, const float *A, int lda,
paddle_mobile::memory::Free(zero);
}
void VectorKernel(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc,
bool relu) {
void Gemmer::VectorKernel(int m, int n, int k, float alpha, const float *A,
int lda, const float *B, int ldb, float beta,
float *C, int ldc, bool relu) {
float *bufferC = static_cast<float *>(memory::Alloc(sizeof(float) * n));
const float *a0, *b0, *b1, *b2, *b3;
......@@ -691,9 +691,10 @@ void VectorKernel(int m, int n, int k, float alpha, const float *A, int lda,
}
}
void VectorKernelWithBn(int m, int n, int k, float alpha, const float *A,
int lda, const float *B, int ldb, float beta, float *C,
int ldc, bool relu, float *new_scale, float *new_bias) {
void Gemmer::VectorKernelWithBn(int m, int n, int k, float alpha,
const float *A, int lda, const float *B,
int ldb, float beta, float *C, int ldc,
bool relu, float *new_scale, float *new_bias) {
float *bufferC = static_cast<float *>(memory::Alloc(sizeof(float) * n));
const float *a0, *b0, *b1, *b2, *b3;
......@@ -902,7 +903,8 @@ void VectorKernelWithBn(int m, int n, int k, float alpha, const float *A,
}
}
void AddDot4x8(int k, const float *a, const float *b, float *c, int ldc) {
void Gemmer::AddDot4x8(int k, const float *a, const float *b, float *c,
int ldc) {
const float *a_ptr, *b_ptr;
a_ptr = a;
b_ptr = b;
......@@ -1010,7 +1012,7 @@ void AddDot4x8(int k, const float *a, const float *b, float *c, int ldc) {
}
// C = A * B
void WriteBasic(int mc, int nc, float *c, float *C, int ldc) {
void Gemmer::WriteBasic(int mc, int nc, float *c, float *C, int ldc) {
int nc1 = nc / 16;
int _nc1 = nc % 16;
int step = 4 * ldc;
......@@ -1067,10 +1069,10 @@ void WriteBasic(int mc, int nc, float *c, float *C, int ldc) {
}
// C = alpha * A * B + beta * C
void WriteWithAlphaBeta(int mc, int nc, float *c, float *C, int ldc) {}
void Gemmer::WriteWithAlphaBeta(int mc, int nc, float *c, float *C, int ldc) {}
// C = A * B + C
void WriteWithAdd(int mc, int nc, float *c, float *C, int ldc) {
void Gemmer::WriteWithAdd(int mc, int nc, float *c, float *C, int ldc) {
int nc1 = nc / 16;
int _nc1 = nc % 16;
int step = 4 * ldc;
......@@ -1134,7 +1136,7 @@ void WriteWithAdd(int mc, int nc, float *c, float *C, int ldc) {
}
// C = A * B + C, relu(C)
void WriteWithAddRelu(int mc, int nc, float *c, float *C, int ldc) {
void Gemmer::WriteWithAddRelu(int mc, int nc, float *c, float *C, int ldc) {
int nc1 = nc / 16;
int _nc1 = nc % 16;
int step = 4 * ldc;
......@@ -1208,8 +1210,8 @@ void WriteWithAddRelu(int mc, int nc, float *c, float *C, int ldc) {
}
// C = A * B, batchnorm(C)
void WriteWithBn(int mc, int nc, float *c, float *C, int ldc, float *scale,
float *bias) {
void Gemmer::WriteWithBn(int mc, int nc, float *c, float *C, int ldc,
float *scale, float *bias) {
int nc1 = nc / 16;
int _nc1 = nc % 16;
int nc2 = _nc1 / 4;
......@@ -1224,23 +1226,27 @@ void WriteWithBn(int mc, int nc, float *c, float *C, int ldc, float *scale,
"mov r5, %[nc1] \n\t"
"mov r6, %[nc2] \n\t"
"vld1.32 {d0}, [%[scale]] \n\t"
"vld1.32 {d1}, [%[bias]] \n\t"
"vdup.32 q1, d0[0] \n\t"
"vdup.32 q2, d1[0] \n\t"
"subs r5, r5, #1 \n\t"
"blt end_nc1_%= \n\t"
"loop_nc1_%=: \n\t"
"vld1.32 {q0, q1}, [%[c]]! \n\t"
"vld1.32 {q2, q3}, [%[scale]]! \n\t"
"vld1.32 {q10, q11}, [%[bias]]! \n\t"
"vmla.f32 q10, q0, q2 \n\t"
"vmla.f32 q11, q1, q3 \n\t"
"vld1.32 {q3, q4}, [%[c]]! \n\t"
"vmul.f32 q10, q3, q1 \n\t"
"vmul.f32 q11, q4, q1 \n\t"
"vadd.f32 q10, q10, q2 \n\t"
"vadd.f32 q11, q11, q2 \n\t"
"vst1.32 {q10, q11}, [%[C]]! \n\t"
"vld1.32 {q4, q5}, [%[c]]! \n\t"
"vld1.32 {q6, q7}, [%[scale]]! \n\t"
"vld1.32 {q12, q13}, [%[bias]]! \n\t"
"vmla.f32 q12, q4, q6 \n\t"
"vmla.f32 q13, q5, q7 \n\t"
"vld1.32 {q5, q6}, [%[c]]! \n\t"
"vmul.f32 q12, q5, q1 \n\t"
"vmul.f32 q13, q6, q1 \n\t"
"vadd.f32 q12, q12, q2 \n\t"
"vadd.f32 q13, q13, q2 \n\t"
"vst1.32 {q12, q13}, [%[C]]! \n\t"
"subs r5, r5, #1 \n\t"
......@@ -1251,10 +1257,9 @@ void WriteWithBn(int mc, int nc, float *c, float *C, int ldc, float *scale,
"blt end_nc2_%= \n\t"
"loop_nc2_%=: \n\t"
"vld1.32 {q0}, [%[c]]! \n\t"
"vld1.32 {q1}, [%[scale]]! \n\t"
"vld1.32 {q10}, [%[bias]]! \n\t"
"vmla.f32 q10, q0, q1 \n\t"
"vld1.32 {q7}, [%[c]]! \n\t"
"vmul.f32 q10, q7, q1 \n\t"
"vadd.f32 q10, q10, q2 \n\t"
"vst1.32 {q10}, [%[C]]! \n\t"
"subs r6, r6, #1 \n\t"
......@@ -1265,20 +1270,17 @@ void WriteWithBn(int mc, int nc, float *c, float *C, int ldc, float *scale,
"beq end_nc3_%= \n\t"
"sub %[c], %[c], %[nc3] \n\t"
"sub %[scale], %[scale], %[nc3] \n\t"
"sub %[bias], %[bias], %[nc3] \n\t"
"sub %[C], %[C], %[nc3] \n\t"
"vld1.32 {q0}, [%[c]]! \n\t"
"vld1.32 {q1}, [%[scale]]! \n\t"
"vld1.32 {q10}, [%[bias]]! \n\t"
"vmla.f32 q10, q0, q1 \n\t"
"vst1.32 {q10}, [%[C]]! \n\t"
"vld1.32 {q8}, [%[c]]! \n\t"
"vmul.f32 q11, q8, q1 \n\t"
"vadd.f32 q11, q11, q2 \n\t"
"vst1.32 {q11}, [%[C]]! \n\t"
"end_nc3_%=: \n\t"
"add %[scale], %[scale], #4 \n\t"
"add %[bias], %[bias], #4 \n\t"
"add %[c], %[c], %[step1] \n\t"
"add %[scale], %[scale], %[step] \n\t"
"add %[bias], %[bias], %[step] \n\t"
"add %[C], %[C], %[step] \n\t"
"subs %[mc], %[mc], #1 \n\t"
......@@ -1289,13 +1291,13 @@ void WriteWithBn(int mc, int nc, float *c, float *C, int ldc, float *scale,
: [C] "r"(C), [c] "r"(c), [mc] "r"(mc), [nc1] "r"(nc1), [nc2] "r"(nc2),
[nc3] "r"(nc3), [step] "r"(step), [step1] "r"(step1),
[scale] "r"(scale), [bias] "r"(bias)
: "memory", "cc", "r5", "r6", "r7", "r8", "q0", "q1", "q2", "q3", "q4",
"q5", "q6", "q7", "q10", "q11", "q12", "q13");
: "memory", "r5", "r6", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q10", "q11", "q12", "q13");
}
// C = A * B, batchnorm(C), relu(C)
void WriteWithBnRelu(int mc, int nc, float *c, float *C, int ldc, float *scale,
float *bias) {
void Gemmer::WriteWithBnRelu(int mc, int nc, float *c, float *C, int ldc,
float *scale, float *bias) {
int nc1 = nc / 16;
int _nc1 = nc % 16;
int nc2 = _nc1 / 4;
......@@ -1311,25 +1313,29 @@ void WriteWithBnRelu(int mc, int nc, float *c, float *C, int ldc, float *scale,
"mov r5, %[nc1] \n\t"
"mov r6, %[nc2] \n\t"
"vld1.32 {d0}, [%[scale]] \n\t"
"vld1.32 {d1}, [%[bias]] \n\t"
"vdup.32 q1, d0[0] \n\t"
"vdup.32 q2, d1[0] \n\t"
"subs r5, r5, #1 \n\t"
"blt end_nc1_%= \n\t"
"loop_nc1_%=: \n\t"
"vld1.32 {q0, q1}, [%[c]]! \n\t"
"vld1.32 {q2, q3}, [%[scale]]! \n\t"
"vld1.32 {q10, q11}, [%[bias]]! \n\t"
"vmla.f32 q10, q0, q2 \n\t"
"vmla.f32 q11, q1, q3 \n\t"
"vld1.32 {q3, q4}, [%[c]]! \n\t"
"vmul.f32 q10, q3, q1 \n\t"
"vmul.f32 q11, q4, q1 \n\t"
"vadd.f32 q10, q10, q2 \n\t"
"vadd.f32 q11, q11, q2 \n\t"
"vmax.f32 q10, q10, q14 \n\t"
"vmax.f32 q11, q11, q14 \n\t"
"vst1.32 {q10, q11}, [%[C]]! \n\t"
"vld1.32 {q4, q5}, [%[c]]! \n\t"
"vld1.32 {q6, q7}, [%[scale]]! \n\t"
"vld1.32 {q12, q13}, [%[bias]]! \n\t"
"vmla.f32 q12, q4, q6 \n\t"
"vmla.f32 q13, q5, q7 \n\t"
"vld1.32 {q5, q6}, [%[c]]! \n\t"
"vmul.f32 q12, q5, q1 \n\t"
"vmul.f32 q13, q6, q1 \n\t"
"vadd.f32 q12, q12, q2 \n\t"
"vadd.f32 q13, q13, q2 \n\t"
"vmax.f32 q12, q12, q14 \n\t"
"vmax.f32 q13, q13, q14 \n\t"
"vst1.32 {q12, q13}, [%[C]]! \n\t"
......@@ -1342,10 +1348,9 @@ void WriteWithBnRelu(int mc, int nc, float *c, float *C, int ldc, float *scale,
"blt end_nc2_%= \n\t"
"loop_nc2_%=: \n\t"
"vld1.32 {q0}, [%[c]]! \n\t"
"vld1.32 {q1}, [%[scale]]! \n\t"
"vld1.32 {q10}, [%[bias]]! \n\t"
"vmla.f32 q10, q0, q1 \n\t"
"vld1.32 {q7}, [%[c]]! \n\t"
"vmul.f32 q10, q7, q1 \n\t"
"vadd.f32 q10, q10, q2 \n\t"
"vmax.f32 q10, q10, q14 \n\t"
"vst1.32 {q10}, [%[C]]! \n\t"
......@@ -1357,21 +1362,18 @@ void WriteWithBnRelu(int mc, int nc, float *c, float *C, int ldc, float *scale,
"beq end_nc3_%= \n\t"
"sub %[c], %[c], %[nc3] \n\t"
"sub %[scale], %[scale], %[nc3] \n\t"
"sub %[bias], %[bias], %[nc3] \n\t"
"sub %[C], %[C], %[nc3] \n\t"
"vld1.32 {q0}, [%[c]]! \n\t"
"vld1.32 {q1}, [%[scale]]! \n\t"
"vld1.32 {q10}, [%[bias]]! \n\t"
"vmla.f32 q10, q0, q1 \n\t"
"vmax.f32 q10, q10, q14 \n\t"
"vst1.32 {q10}, [%[C]]! \n\t"
"vld1.32 {q8}, [%[c]]! \n\t"
"vmul.f32 q11, q8, q1 \n\t"
"vadd.f32 q11, q11, q2 \n\t"
"vmax.f32 q11, q11, q14 \n\t"
"vst1.32 {q11}, [%[C]]! \n\t"
"end_nc3_%=: \n\t"
"add %[scale], %[scale], #4 \n\t"
"add %[bias], %[bias], #4 \n\t"
"add %[c], %[c], %[step1] \n\t"
"add %[scale], %[scale], %[step] \n\t"
"add %[bias], %[bias], %[step] \n\t"
"add %[C], %[C], %[step] \n\t"
"subs %[mc], %[mc], #1 \n\t"
......@@ -1382,12 +1384,12 @@ void WriteWithBnRelu(int mc, int nc, float *c, float *C, int ldc, float *scale,
: [C] "r"(C), [c] "r"(c), [mc] "r"(mc), [nc1] "r"(nc1), [nc2] "r"(nc2),
[nc3] "r"(nc3), [step] "r"(step), [step1] "r"(step1),
[scale] "r"(scale), [bias] "r"(bias)
: "memory", "r5", "r6", "r7", "r8", "q0", "q1", "q2", "q3", "q4", "q5",
"q6", "q7", "q10", "q11", "q12", "q13", "q14");
: "memory", "r5", "r6", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q10", "q11", "q12", "q13", "q14");
}
// C = A * B
void VecWriteBasic(int n, float *c, float *C, int ldc) {
void Gemmer::VecWriteBasic(int n, float *c, float *C, int ldc) {
int nc1 = n / 16;
int _nc1 = n % 16;
int nc2 = _nc1 / 4;
......@@ -1433,10 +1435,10 @@ void VecWriteBasic(int n, float *c, float *C, int ldc) {
}
// C = alpha * A * B + beta * C
void VecWriteWithAlphaBeta(int n, float *c, float *C, int ldc) {}
void Gemmer::VecWriteWithAlphaBeta(int n, float *c, float *C, int ldc) {}
// C = A * B + C
void VecWriteWithAdd(int n, float *c, float *C, int ldc) {
void Gemmer::VecWriteWithAdd(int n, float *c, float *C, int ldc) {
int nc1 = n / 16;
int _nc1 = n % 16;
......@@ -1474,7 +1476,7 @@ void VecWriteWithAdd(int n, float *c, float *C, int ldc) {
}
// C = A * B + C, relu(C)
void VecWriteWithAddRelu(int n, float *c, float *C, int ldc) {
void Gemmer::VecWriteWithAddRelu(int n, float *c, float *C, int ldc) {
int nc1 = n / 16;
int _nc1 = n % 16;
......@@ -1522,7 +1524,7 @@ void VecWriteWithAddRelu(int n, float *c, float *C, int ldc) {
}
// C = A * B, batchnorm(C)
void VecWriteWithBn(int n, float *c, float *C, int ldc, float *scale,
void Gemmer::VecWriteWithBn(int n, float *c, float *C, int ldc, float *scale,
float *bias) {
int nc1 = n / 16;
int _nc1 = n % 16;
......@@ -1589,8 +1591,8 @@ void VecWriteWithBn(int n, float *c, float *C, int ldc, float *scale,
}
// C = A * B, batchnorm(C), relu(C)
void VecWriteWithBnRelu(int n, float *c, float *C, int ldc, float *scale,
float *bias) {
void Gemmer::VecWriteWithBnRelu(int n, float *c, float *C, int ldc,
float *scale, float *bias) {
int nc1 = n / 16;
int _nc1 = n % 16;
int nc2 = _nc1 / 4;
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
// 矩阵取值运算宏,假设矩阵按行存储
#define A(i, j) A[(i)*lda + (j)]
......@@ -27,88 +28,111 @@ limitations under the License. */
namespace paddle_mobile {
namespace operators {
namespace math {
// 将 A 矩阵分块复制到连续内存(ColMajor)
void PackMatrixA(int m, int k, int m_tail, const float *A, int lda,
struct Gemmer {
int MC = 0;
int KC = 0;
int NC = 0;
float *packedA;
float *packedB;
float *packedC;
float *zero;
static std::vector<Gemmer *> gemmers;
// 将 A 矩阵分块复制到连续内存(ColMajor)
void PackMatrixA(int m, int k, int m_tail, const float *A, int lda,
float *buffer);
// 将 B 矩阵分块复制到连续内存(ColMajor)
void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
// 将 B 矩阵分块复制到连续内存(ColMajor)
void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
float *buffer);
// 将 A 矩阵分块复制到连续内存(RowMajor)
void PackMatrixA_(int m, int k, int m_tail, const float *A, int lda,
// 将 A 矩阵分块复制到连续内存(RowMajor)
void PackMatrixA_(int m, int k, int m_tail, const float *A, int lda,
float *buffer);
// 将 B 矩阵分块复制到连续内存(RowMajor)
void PackMatrixB_(int k, int n, int n_tail, const float *B, int ldb,
// 将 B 矩阵分块复制到连续内存(RowMajor)
void PackMatrixB_(int k, int n, int n_tail, const float *B, int ldb,
float *buffer);
// 分块矩阵乘法
void InnerKernel(int mc, int nc, float alpha, const float *a, const float *b,
// 分块矩阵乘法
void InnerKernel(int mc, int nc, float alpha, const float *a, const float *b,
float beta, float *c, float *C, int ldc, bool relu);
void InnerKernelWithBn(int mc, int nc, float alpha, const float *a,
const float *b, float beta, float *c, float *C, int ldc,
bool relu, float *new_scale, float *new_bias);
void InnerKernelWithBn(int mc, int nc, float alpha, const float *a,
const float *b, float beta, float *c, float *C,
int ldc, bool relu, float *new_scale, float *new_bias);
// 向量矩阵乘法 (M = 1)
void VectorKernel(int m, int n, int k, float alpha, const float *A, int lda,
// 向量矩阵乘法 (M = 1)
void VectorKernel(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc,
bool relu);
void VectorKernelWithBn(int m, int n, int k, float alpha, const float *A,
int lda, const float *B, int ldb, float beta, float *C,
int ldc, bool relu, float *new_scale, float *new_bias);
// 计算一个更小的 C 矩阵分块
void AddDot4x4(int k, const float *a, const float *b, float *c, int ldc);
void AddDot4x8(int k, const float *a, const float *b, float *c, int ldc);
// 分块矩阵乘法结果回写
// C = A * B
void WriteBasic(int mc, int nc, float *c, float *C, int ldc);
// C = alpha * A * B + beta * C
void WriteWithAlphaBeta(int mc, int nc, float *c, float *C, int ldc);
// C = A * B + C
void WriteWithAdd(int mc, int nc, float *c, float *C, int ldc);
// C = A * B + C, relu(C)
void WriteWithAddRelu(int mc, int nc, float *c, float *C, int ldc);
// C = A * B, batchnorm(C)
void WriteWithBn(int mc, int nc, float *c, float *C, int ldc, float *new_scale,
void VectorKernelWithBn(int m, int n, int k, float alpha, const float *A,
int lda, const float *B, int ldb, float beta,
float *C, int ldc, bool relu, float *new_scale,
float *new_bias);
// C = A * B, batchnorm(C), relu(C)
void WriteWithBnRelu(int mc, int nc, float *c, float *C, int ldc,
// 计算一个更小的 C 矩阵分块
void AddDot4x4(int k, const float *a, const float *b, float *c, int ldc);
void AddDot4x8(int k, const float *a, const float *b, float *c, int ldc);
// 分块矩阵乘法结果回写
// C = A * B
void WriteBasic(int mc, int nc, float *c, float *C, int ldc);
// C = alpha * A * B + beta * C
void WriteWithAlphaBeta(int mc, int nc, float *c, float *C, int ldc);
// C = A * B + C
void WriteWithAdd(int mc, int nc, float *c, float *C, int ldc);
// C = A * B + C, relu(C)
void WriteWithAddRelu(int mc, int nc, float *c, float *C, int ldc);
// C = A * B, batchnorm(C)
void WriteWithBn(int mc, int nc, float *c, float *C, int ldc,
float *new_scale, float *new_bias);
// 向量矩阵乘法结果回写
// C = A * B
void VecWriteBasic(int n, float *c, float *C, int ldc);
// C = alpha * A * B + beta * C
void VecWriteWithAlphaBeta(int n, float *c, float *C, int ldc);
// C = A * B + C
void VecWriteWithAdd(int n, float *c, float *C, int ldc);
// C = A * B + C, relu(C)
void VecWriteWithAddRelu(int n, float *c, float *C, int ldc);
// C = A * B, batchnorm(C)
void VecWriteWithBn(int n, float *c, float *C, int ldc, float *new_scale,
// C = A * B, batchnorm(C), relu(C)
void WriteWithBnRelu(int mc, int nc, float *c, float *C, int ldc,
float *new_scale, float *new_bias);
// 向量矩阵乘法结果回写
// C = A * B
void VecWriteBasic(int n, float *c, float *C, int ldc);
// C = alpha * A * B + beta * C
void VecWriteWithAlphaBeta(int n, float *c, float *C, int ldc);
// C = A * B + C
void VecWriteWithAdd(int n, float *c, float *C, int ldc);
// C = A * B + C, relu(C)
void VecWriteWithAddRelu(int n, float *c, float *C, int ldc);
// C = A * B, batchnorm(C)
void VecWriteWithBn(int n, float *c, float *C, int ldc, float *new_scale,
float *new_bias);
// C = A * B, batchnorm(C), relu(C)
void VecWriteWithBnRelu(int n, float *c, float *C, int ldc, float *new_scale,
// C = A * B, batchnorm(C), relu(C)
void VecWriteWithBnRelu(int n, float *c, float *C, int ldc, float *new_scale,
float *new_bias);
// 32位 float 矩阵乘法
void Sgemm(int m, int n, int k, float alpha, const float *A, int lda,
// 32位 float 矩阵乘法
void Sgemm(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc, bool relu);
// 32位 float 矩阵乘法, 并对结果进行 batchnrom
void SgemmWithBn(int m, int n, int k, float alpha, const float *A, int lda,
// 32位 float 矩阵乘法, 并对结果进行 batchnrom
void SgemmWithBn(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc,
bool relu, float *new_scale, float *new_bias);
// 64位 double 矩阵乘法
void dgemm(int m, int n, int k, float alpha, const double *A, int lda,
// 64位 double 矩阵乘法
void dgemm(int m, int n, int k, float alpha, const double *A, int lda,
const double *B, int ldb, float beta, double *C, int ldc);
};
} // namespace math
} // namespace operators
......
......@@ -26,23 +26,14 @@ void matmul<float>(const framework::Tensor &matrix_a, bool trans_a,
auto dim_a = matrix_a.dims();
auto dim_b = matrix_b.dims();
auto dim_out = matrix_out->dims();
// PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 &&
// dim_out.size() ==
// 2,
// "The input and output of matmul be matrix");
//
// PADDLE_ENFORCE(platform::is_cpu_place(matrix_a.place()) &&
// platform::is_cpu_place(matrix_b.place())
// &&
// platform::is_cpu_place(matrix_out->place()),
// "Matrix must all be in CPUPlace");
int M = dim_out[0];
int N = dim_out[1];
int K = (!trans_a) ? dim_a[1] : dim_a[0];
Sgemm(M, N, K, alpha, matrix_a.data<float>(), K, matrix_b.data<float>(), N,
beta, matrix_out->data<float>(), N, relu);
Gemmer::gemmers[0]->Sgemm(M, N, K, alpha, matrix_a.data<float>(), K,
matrix_b.data<float>(), N, beta,
matrix_out->data<float>(), N, relu);
}
template <>
......@@ -54,24 +45,15 @@ void matmulWithBn<float>(const framework::Tensor &matrix_a, bool trans_a,
auto dim_a = matrix_a.dims();
auto dim_b = matrix_b.dims();
auto dim_out = matrix_out->dims();
// PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 &&
// dim_out.size() ==
// 2,
// "The input and output of matmul be matrix");
//
// PADDLE_ENFORCE(platform::is_cpu_place(matrix_a.place()) &&
// platform::is_cpu_place(matrix_b.place())
// &&
// platform::is_cpu_place(matrix_out->place()),
// "Matrix must all be in CPUPlace");
int M = dim_out[0];
int N = dim_out[1];
int K = (!trans_a) ? dim_a[1] : dim_a[0];
SgemmWithBn(M, N, K, alpha, matrix_a.data<float>(), K, matrix_b.data<float>(),
N, beta, matrix_out->data<float>(), N, relu,
new_scale->data<float>(), new_bias->data<float>());
Gemmer::gemmers[0]->SgemmWithBn(
M, N, K, alpha, matrix_a.data<float>(), K, matrix_b.data<float>(), N,
beta, matrix_out->data<float>(), N, relu, new_scale->data<float>(),
new_bias->data<float>());
}
} // namespace math
......
此差异已折叠。
......@@ -15,6 +15,9 @@ limitations under the License. */
#ifdef POOL_OP
#pragma once
#ifdef _OPENMP
#include <omp.h>
#endif
#include <algorithm>
#include <vector>
#include "framework/tensor.h"
......
......@@ -16,6 +16,9 @@ limitations under the License. */
#include "pooling.h"
#include "common/types.h"
#ifdef _OPENMP
#include <omp.h>
#endif
namespace paddle_mobile {
namespace operators {
......@@ -57,8 +60,8 @@ class PoolFunctor<CPU, PoolProcess, T> {
T *output_data = output->mutable_data<T>();
for (int i = 0; i < batch_size; i++) {
// #pragma omp parallel for
for (int c = 0; c < output_channels; ++c) {
#pragma omp parallel for
for (int ph = 0; ph < output_height; ++ph) {
int hstart = ph * stride_height - padding_height;
int hend = std::min(hstart + ksize_height, input_height);
......
......@@ -52,8 +52,9 @@ int main() {
}
auto time1 = time();
paddle_mobile::operators::math::sgemm(m, n, k, 0.9, a, lda, b, ldb, 0.3, c,
ldc);
// paddle_mobile::operators::math::Sgemm(m, n, k, 0.9, a, lda, b, ldb, 0.3,
// c,
// ldc);
auto time2 = time();
DLOG << "gemm cost :" << time_diff(time1, time2) << "ms\n";
for (int i = 0; i < m * n; ++i) {
......
......@@ -26,16 +26,17 @@ int main() {
auto time2 = time();
DLOG << "load cost :" << time_diff(time1, time2) << "ms\n";
paddle_mobile::Executor<paddle_mobile::CPU> executor(program, 1, optimize);
executor.SetThreadNum(4);
std::vector<float> input;
std::vector<int64_t> dims{1, 3, 224, 224};
GetInput<float>(g_test_image_1x3x224x224, &input, dims);
auto time3 = time();
for (int i = 0; i < 10; ++i) {
int count = 1;
for (int i = 0; i < count; ++i) {
executor.Predict(input, dims);
}
auto time4 = time();
DLOG << "predict cost :" << time_diff(time3, time4) << "ms\n";
DLOG << "predict cost :" << time_diff(time3, time4) / count << "ms\n";
return 0;
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册