提交 2fc1a20a 编写于 作者: 朔-望's avatar 朔-望 提交者: GitHub

Merge branch 'develop' into develop

......@@ -2,7 +2,7 @@ cmake_minimum_required(VERSION 3.0)
project(paddle-mobile)
option(DEBUGING "enable debug mode" ON)
option(USE_OPENMP "openmp support" OFF)
option(USE_OPENMP "openmp support" ON)
option(USE_EXCEPTION "use std exception" ON)
option(LOG_PROFILE "log profile" ON)
# select the platform to build
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "io/executor.h"
#include <operators/math/gemm.h>
#include <algorithm>
#include <vector>
#include "common/enforce.h"
......@@ -25,6 +26,9 @@ limitations under the License. */
#include "framework/program/var_desc.h"
#include "framework/scope.h"
#include "framework/tensor.h"
#ifdef _OPENMP
#include <omp.h>
#endif // _OPENMP
#ifdef PADDLE_EXECUTOR_MULTITHREAD
#include <queue>
#include <utility>
......@@ -403,6 +407,17 @@ std::vector<typename Executor<Dtype, P>::Ptype> Executor<Dtype, P>::Predict(
return result_vector;
}
template <typename Dtype, Precision P>
void Executor<Dtype, P>::SetThreadNum(int num) {
for (int k = 0; k < std::max(num, 3); ++k) {
operators::math::Gemmer::gemmers.push_back(new operators::math::Gemmer());
}
#ifdef _OPENMP
// omp_set_dynamic(0);
omp_set_num_threads(num);
#endif
}
template class Executor<CPU, Precision::FP32>;
template class Executor<FPGA, Precision::FP32>;
template class Executor<GPU_MALI, Precision::FP32>;
......
......@@ -58,6 +58,8 @@ class Executor {
std::vector<Ptype> Predict(const std::vector<Ptype> &input,
const std::vector<int64_t> &dims);
void SetThreadNum(int num);
protected:
Executor() = default;
void InitMemory();
......
......@@ -79,11 +79,11 @@ class FusionConvAddBNReluOp
#ifdef PADDLE_MOBILE_CPU
//#ifndef FUSION_CONV_ADD_BN_RELU_REGISTER
// static framework::FusionOpRegistrar fusion_conv_add_bn_relu_registrar(
// new FusionConvAddBNReluMatcher());
//#define FUSION_CONV_ADD_BN_RELU_REGISTER
//#endif
#ifndef FUSION_CONV_ADD_BN_RELU_REGISTER
static framework::FusionOpRegistrar fusion_conv_add_bn_relu_registrar(
new FusionConvAddBNReluMatcher());
#define FUSION_CONV_ADD_BN_RELU_REGISTER
#endif
#endif
......
......@@ -14,10 +14,14 @@ limitations under the License. */
#ifdef FUSION_CONVADD_OP
#pragma once
#if _OPENMP
#include <omp.h>
#endif
#include <vector>
#include "operators/math/conv_func.h"
#include "operators/math/depthwise_conv_3x3.h"
#include "operators/math/gemm.h"
#include "operators/math/im2col.h"
#include "operators/math/math_function.h"
#include "operators/math/vol2col.h"
......@@ -106,9 +110,33 @@ void ConvAddBasic(const FusionConvAddParam &param) {
// gemm
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
math::matmul<float>(filter_slice, false, col_matrix, false,
static_cast<float>(1), &out_slice,
static_cast<float>(1));
auto dim_a = filter_slice.dims();
auto dim_b = col_matrix.dims();
auto dim_out = out_slice.dims();
int m = dim_out[0];
int n = dim_out[1];
int k = dim_a[1];
float *output_data = out_slice.data<float>();
int thread_num = 4;
int m1 = m / thread_num;
int m2 = m % thread_num;
#pragma omp parallel for
for (int j = 0; j < thread_num; ++j) {
int row_count = m1;
if (j == thread_num - 1) {
row_count = m1 + m2;
}
math::Gemmer::gemmers[j]->Sgemm(
row_count, n, k, 1, filter_slice.data<float>() + j * m1 * k, k,
col_matrix.data<float>(), n, 1, output_data + j * m1 * n, n, false);
}
// math::matmul<float>(filter_slice, false, col_matrix, false,
// static_cast<float>(1), &out_slice,
// static_cast<float>(1));
}
}
}
......@@ -124,9 +152,15 @@ void ConvAddCompute(const FusionConvAddParam &param) {
} else if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3) {
math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(),
param.Filter(), param.Bias(), param.Output(), true);
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) {
// math::DepthwiseConv3x3(param.Input(), param.Strides(),
// param.Paddings(),
// param.Filter(), param.Bias(),
// param.Output(), false);
math::DepthwiseConv3x3s2p1v2(param.Input(), param.Filter(), param.Output(),
*param.Bias(), true);
} else {
ConvAddBasic(param);
}
......
......@@ -26,8 +26,6 @@ void ConvAddBNReluBasic(const FusionConvAddBNReluParam &param) {
Tensor bias = *param.Bias();
Tensor new_bias = *param.NewBias();
Tensor new_scale = *param.NewScale();
auto new_bias_ptr = new_bias.data<float>();
auto new_scale_ptr = new_scale.data<float>();
int axis = param.Axis();
Tensor *output = param.Output();
math::expand_bias(bias, axis, output->dims());
......@@ -106,20 +104,10 @@ void ConvAddBNReluBasic(const FusionConvAddBNReluParam &param) {
// gemm
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
math::matmul<float>(filter_slice, false, col_matrix, false,
static_cast<float>(1), &out_slice,
static_cast<float>(0));
}
}
/// todo : use neon in special case instead of 2for(300ms)
auto output_ptr = output->data<float>();
for (int c = 0; c < output_matrix_shape[0]; c++) {
int start = c * output_matrix_shape[1];
for (int j = 0; j < output_matrix_shape[1]; j++) {
output_ptr[start + j] =
output_ptr[start + j] * new_scale_ptr[c] + new_bias_ptr[c];
output_ptr[start + j] =
output_ptr[start + j] < 0 ? 0 : output_ptr[start + j];
math::matmulWithBn<float>(
filter_slice, false, col_matrix, false, static_cast<float>(1),
&out_slice, static_cast<float>(0), true, &new_scale, &new_bias);
}
}
}
......@@ -138,9 +126,12 @@ void ConvAddBNReluCompute(const FusionConvAddBNReluParam &param) {
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) {
math::DepthwiseConvAddBNRelu3x3s2p1(param.Input(), param.Filter(),
param.Output(), param.NewScale(),
param.NewBias(), 1);
// math::DepthwiseConvAddBNRelu3x3s2p1(param.Input(), param.Filter(),
// param.Output(), param.NewScale(),
// param.NewBias(), 1);
math::DepthwiseConvAddBNRelu3x3s2p1v2(param.Input(), param.Filter(),
param.Output(), param.NewScale(),
param.NewBias(), true);
} else {
ConvAddBNReluBasic(param);
}
......
......@@ -37,8 +37,12 @@ void DepthwiseConvCompute(const ConvParam &param) {
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) {
math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(),
param.Filter(), &Bias, param.Output(), false);
// math::DepthwiseConv3x3(param.Input(), param.Strides(),
// param.Paddings(),
// param.Filter(), &Bias, param.Output(), false);
math::DepthwiseConv3x3s2p1v2(param.Input(), param.Filter(), param.Output(),
Bias, false);
} else {
ConvBasic(param);
}
......
......@@ -13,7 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */
#ifdef LRN_OP
#ifdef _OPENMP
#include <omp.h>
#endif
#include "framework/operator.h"
#include "operators/op_param.h"
......@@ -47,6 +49,7 @@ struct LRNFunctor {
std::fill(sqr_buffer_ptr, sqr_buffer_ptr + sqr_buffer.numel(), 0.0);
for (int a = 0; a < N; a++) {
#pragma parallel for
for (int b = 0; b < C; b++) {
for (int index = start; index < end; index++) {
int channel = b + index;
......
......@@ -1010,6 +1010,442 @@ void DepthwiseConvAddBNRelu3x3s2p1(const Tensor *input, const Tensor *filter,
output_data += output_batch_stride;
}
}
void DepthwiseConv3x3s2p1v2(const Tensor *input, const Tensor *filter,
Tensor *output, Tensor bias, bool if_bias) {
const float *input_data = input->data<float>();
const float *filter_data = filter->data<float>();
float *output_data = output->data<float>();
const float *bias_data = bias.data<float>();
const int in_h = static_cast<int>(input->dims()[2]);
const int in_w = static_cast<int>(input->dims()[3]);
const int out_h = static_cast<int>(output->dims()[2]);
const int out_w = static_cast<int>(output->dims()[3]);
const int out_l = out_h;
const int in_l = in_h;
const int inhxw = in_h * in_w;
const int outhxw = out_h * out_w;
const int if_pad = in_l - 1 == (out_l - 1) * 2 ? 1 : 0;
const int batch_size = static_cast<int>(input->dims()[0]);
const int c = static_cast<int>(input->dims()[1]);
const float *input_row_ptr;
float *output_row_ptr;
const int w_times = (out_w - 2) / 3;
float32x4_t vbias = vdupq_n_f32(0.0);
float32x4x2_t input_buff_mid{}, input_buff_bottom[w_times + 1];
float32x4_t elewise_res0, elewise_res1, elewise_res2, res3;
int out2in_mid;
float32x4_t zero = vdupq_n_f32(0.0);
for (int b = batch_size; b > 0; --b) {
const float *filter_data_tmp = filter_data;
for (int j = 0; j < c; ++j) {
auto output_data_tmp = output_data + j * out_h * out_w;
auto input_data_tmp = input_data + j * in_h * in_w;
auto input_const = input_data_tmp;
if (if_bias) {
vbias = vdupq_n_f32(bias_data[j]);
}
float w00 = filter_data_tmp[0];
float w01 = filter_data_tmp[1];
float w02 = filter_data_tmp[2];
float w10 = filter_data_tmp[3];
float w11 = filter_data_tmp[4];
float w12 = filter_data_tmp[5];
float w20 = filter_data_tmp[6];
float w21 = filter_data_tmp[7];
float w22 = filter_data_tmp[8];
int h_mid = 0;
for (; h_mid < out_h - 1; h_mid++) {
input_row_ptr = input_data_tmp + 1 + h_mid * 2 * in_w;
output_row_ptr = output_data_tmp + 1 + h_mid * out_w;
for (int w4 = 0; w4 < w_times + 1; w4++) {
if (h_mid == 0) {
elewise_res1 = zero;
elewise_res0 = zero;
elewise_res2 = zero;
} else {
elewise_res1 = vmulq_n_f32(input_buff_bottom[w4].val[1], w01);
elewise_res0 = vmulq_n_f32(input_buff_bottom[w4].val[0], w00);
elewise_res2 = vmulq_n_f32(input_buff_bottom[w4].val[0], w02);
}
input_buff_mid = vld2q_f32(input_row_ptr);
input_buff_bottom[w4] = vld2q_f32(input_row_ptr + in_w);
elewise_res1 = vmlaq_n_f32(elewise_res1, input_buff_mid.val[1], w11);
elewise_res0 = vmlaq_n_f32(elewise_res0, input_buff_mid.val[0], w10);
elewise_res2 = vmlaq_n_f32(elewise_res2, input_buff_mid.val[0], w12);
elewise_res1 =
vmlaq_n_f32(elewise_res1, input_buff_bottom[w4].val[1], w21);
elewise_res0 =
vmlaq_n_f32(elewise_res0, input_buff_bottom[w4].val[0], w20);
elewise_res2 =
vmlaq_n_f32(elewise_res2, input_buff_bottom[w4].val[0], w22);
res3 = vaddq_f32(vextq_f32(elewise_res2, zero, 1),
vaddq_f32(elewise_res0, elewise_res1));
res3 = vaddq_f32(res3, vbias);
vst1q_f32(output_row_ptr, res3);
input_row_ptr += 6;
output_row_ptr += 3;
}
}
clock();
input_row_ptr = input_data_tmp + 1 + h_mid * 2 * in_w;
output_row_ptr = output_data_tmp + 1 + h_mid * out_w;
for (int w4 = 0; w4 < w_times + 1; w4++) {
elewise_res1 = vmulq_n_f32(input_buff_bottom[w4].val[1], w01);
elewise_res0 = vmulq_n_f32(input_buff_bottom[w4].val[0], w00);
elewise_res2 = vmulq_n_f32(input_buff_bottom[w4].val[0], w02);
input_buff_mid = vld2q_f32(input_row_ptr);
input_buff_bottom[w4] = vld2q_f32(input_row_ptr + in_w);
elewise_res1 = vmlaq_n_f32(elewise_res1, input_buff_mid.val[1], w11);
elewise_res0 = vmlaq_n_f32(elewise_res0, input_buff_mid.val[0], w10);
elewise_res2 = vmlaq_n_f32(elewise_res2, input_buff_mid.val[0], w12);
if (!if_pad) {
elewise_res1 =
vmlaq_n_f32(elewise_res1, input_buff_bottom[w4].val[1], w21);
elewise_res0 =
vmlaq_n_f32(elewise_res0, input_buff_bottom[w4].val[0], w20);
elewise_res2 =
vmlaq_n_f32(elewise_res2, input_buff_bottom[w4].val[0], w22);
}
res3 = vaddq_f32(vextq_f32(elewise_res2, zero, 1),
vaddq_f32(elewise_res0, elewise_res1));
res3 = vaddq_f32(res3, vbias);
if ((w4 != w_times)) {
vst1q_f32(output_row_ptr, res3);
} else {
if (out_l - 2 - w_times * 3 == 1) {
vst1q_lane_f32(output_row_ptr, res3, 0);
} else if (out_l - 2 - w_times * 3 == 2) {
vst1q_lane_f32(output_row_ptr, res3, 0);
vst1q_lane_f32(output_row_ptr + 1, res3, 1);
}
}
input_row_ptr += 6;
output_row_ptr += 3;
}
output_data_tmp[0] = input_const[0] * w11 + input_const[1] * w12 +
input_const[in_l] * w21 +
input_const[in_l + 1] * w22;
out2in_mid = (out_l - 1) * 2;
output_data_tmp[out_l - 1] =
w10 * input_const[out2in_mid - 1] + w11 * input_const[out2in_mid] +
w20 * input_const[out2in_mid + in_w - 1] +
w21 * input_const[out2in_mid + in_w] +
(1 - if_pad) * (w12 * input_const[out2in_mid + 1] +
w22 * input_const[out2in_mid + in_w + 1]);
out2in_mid = (out_l - 1) * 2 * in_w;
output_data_tmp[out_l * (out_l - 1)] =
w01 * input_const[out2in_mid - in_w] +
w02 * input_const[out2in_mid - in_w + 1] +
w11 * input_const[out2in_mid] + w12 * input_const[out2in_mid + 1] +
(1 - if_pad) * (w21 * input_const[out2in_mid + in_w] +
w22 * input_const[out2in_mid + in_w + 1]);
out2in_mid = (out_l - 1) * 2 * in_w + (out_l - 1) * 2;
output_data_tmp[out_l * out_l - 1] =
w00 * input_const[out2in_mid - in_w - 1] +
w01 * input_const[out2in_mid - in_w] +
w10 * input_const[out2in_mid - 1] + w11 * input_const[out2in_mid] +
(1 - if_pad) * (w20 * input_const[out2in_mid + in_w - 1] +
w21 * input_const[out2in_mid + in_w] +
w02 * input_const[out2in_mid - in_w + 1] +
w12 * input_const[out2in_mid + 1] +
w22 * input_const[out2in_mid + in_w + 1]);
if (if_bias) {
output_data_tmp[0] += bias_data[j];
output_data_tmp[out_l - 1] += bias_data[j];
output_data_tmp[out_l * (out_l - 1)] += bias_data[j];
output_data_tmp[out_l * out_l - 1] += bias_data[j];
}
for (int i = 1; i < out_h - 1; i++) {
out2in_mid = i * 2 * in_w;
output_data_tmp[i * out_l] = w01 * input_const[out2in_mid - in_w] +
w02 * input_const[out2in_mid - in_w + 1] +
w11 * input_const[out2in_mid] +
w12 * input_const[out2in_mid + 1] +
w21 * input_const[out2in_mid + in_w] +
w22 * input_const[out2in_mid + in_w + 1];
out2in_mid = i * 2 * in_w + (out_l - 1) * 2;
output_data_tmp[i * out_l + out_l - 1] =
w00 * input_const[out2in_mid - in_w - 1] +
w01 * input_const[out2in_mid - in_w] +
w10 * input_const[out2in_mid - 1] + w11 * input_const[out2in_mid] +
w20 * input_const[out2in_mid + in_w - 1] +
w21 * input_const[out2in_mid + in_w] +
(1 - if_pad) * (w02 * input_const[out2in_mid - in_w + 1] +
w12 * input_const[out2in_mid + 1] +
w22 * input_const[out2in_mid + in_w + 1]);
if (if_bias) {
output_data_tmp[i * out_l] += bias_data[j];
output_data_tmp[i * out_l + out_l - 1] += bias_data[j];
}
}
filter_data_tmp += 9;
}
input_data += inhxw * c;
output_data += outhxw * c;
}
}
void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter,
Tensor *output, const Tensor *new_scale,
const Tensor *new_bias, bool if_relu) {
const float *input_data = input->data<float>();
const float *filter_data = filter->data<float>();
float *output_data = output->data<float>();
const float *newscale_data = new_scale->data<float>();
const float *newbias_data = new_bias->data<float>();
float32x4_t vnewbias = vdupq_n_f32(0.0);
float32x4_t vnewscale = vdupq_n_f32(1.0);
const int in_h = static_cast<int>(input->dims()[2]);
const int in_w = static_cast<int>(input->dims()[3]);
const int out_h = static_cast<int>(output->dims()[2]);
const int out_w = static_cast<int>(output->dims()[3]);
const int out_l = out_h;
const int in_l = in_h;
const int inhxw = in_h * in_w;
const int outhxw = out_h * out_w;
const int if_pad = in_l - 1 == (out_l - 1) * 2 ? 1 : 0;
const int batch_size = static_cast<int>(input->dims()[0]);
const int c = static_cast<int>(input->dims()[1]);
const float *input_row_ptr;
float *output_row_ptr;
const int w_times = (out_w - 2) / 3;
float32x4x2_t input_buff_mid{}, input_buff_bottom[w_times + 1];
float32x4_t elewise_res0, elewise_res1, elewise_res2, res3;
int out2in_mid;
float32x4_t zero = vdupq_n_f32(0.0);
for (int b = batch_size; b > 0; --b) {
const float *filter_data_tmp = filter_data;
for (int j = 0; j < c; ++j) {
auto output_data_tmp = output_data + j * out_h * out_w;
auto input_data_tmp = input_data + j * in_h * in_w;
auto input_const = input_data_tmp;
vnewbias = vdupq_n_f32(newbias_data[j]);
vnewscale = vdupq_n_f32(newscale_data[j]);
float w00 = filter_data_tmp[0];
float w01 = filter_data_tmp[1];
float w02 = filter_data_tmp[2];
float w10 = filter_data_tmp[3];
float w11 = filter_data_tmp[4];
float w12 = filter_data_tmp[5];
float w20 = filter_data_tmp[6];
float w21 = filter_data_tmp[7];
float w22 = filter_data_tmp[8];
int h_mid = 0;
for (; h_mid < out_h - 1; h_mid++) {
input_row_ptr = input_data_tmp + 1 + h_mid * 2 * in_w;
output_row_ptr = output_data_tmp + 1 + h_mid * out_w;
for (int w4 = 0; w4 < w_times + 1; w4++) {
if (h_mid == 0) {
elewise_res1 = zero;
elewise_res0 = zero;
elewise_res2 = zero;
} else {
elewise_res1 = vmulq_n_f32(input_buff_bottom[w4].val[1], w01);
elewise_res0 = vmulq_n_f32(input_buff_bottom[w4].val[0], w00);
elewise_res2 = vmulq_n_f32(input_buff_bottom[w4].val[0], w02);
}
input_buff_mid = vld2q_f32(input_row_ptr);
input_buff_bottom[w4] = vld2q_f32(input_row_ptr + in_w);
elewise_res1 = vmlaq_n_f32(elewise_res1, input_buff_mid.val[1], w11);
elewise_res0 = vmlaq_n_f32(elewise_res0, input_buff_mid.val[0], w10);
elewise_res2 = vmlaq_n_f32(elewise_res2, input_buff_mid.val[0], w12);
elewise_res1 =
vmlaq_n_f32(elewise_res1, input_buff_bottom[w4].val[1], w21);
elewise_res0 =
vmlaq_n_f32(elewise_res0, input_buff_bottom[w4].val[0], w20);
elewise_res2 =
vmlaq_n_f32(elewise_res2, input_buff_bottom[w4].val[0], w22);
res3 = vaddq_f32(vextq_f32(elewise_res2, zero, 1),
vaddq_f32(elewise_res0, elewise_res1));
res3 = vmlaq_f32(vnewbias, vnewscale, res3);
if (if_relu) {
res3 = vmaxq_f32(res3, zero);
}
vst1q_f32(output_row_ptr, res3);
input_row_ptr += 6;
output_row_ptr += 3;
}
}
clock();
input_row_ptr = input_data_tmp + 1 + h_mid * 2 * in_w;
output_row_ptr = output_data_tmp + 1 + h_mid * out_w;
for (int w4 = 0; w4 < w_times + 1; w4++) {
elewise_res1 = vmulq_n_f32(input_buff_bottom[w4].val[1], w01);
elewise_res0 = vmulq_n_f32(input_buff_bottom[w4].val[0], w00);
elewise_res2 = vmulq_n_f32(input_buff_bottom[w4].val[0], w02);
input_buff_mid = vld2q_f32(input_row_ptr);
input_buff_bottom[w4] = vld2q_f32(input_row_ptr + in_w);
elewise_res1 = vmlaq_n_f32(elewise_res1, input_buff_mid.val[1], w11);
elewise_res0 = vmlaq_n_f32(elewise_res0, input_buff_mid.val[0], w10);
elewise_res2 = vmlaq_n_f32(elewise_res2, input_buff_mid.val[0], w12);
if (!if_pad) {
elewise_res1 =
vmlaq_n_f32(elewise_res1, input_buff_bottom[w4].val[1], w21);
elewise_res0 =
vmlaq_n_f32(elewise_res0, input_buff_bottom[w4].val[0], w20);
elewise_res2 =
vmlaq_n_f32(elewise_res2, input_buff_bottom[w4].val[0], w22);
}
res3 = vaddq_f32(vextq_f32(elewise_res2, zero, 1),
vaddq_f32(elewise_res0, elewise_res1));
res3 = vmlaq_f32(vnewbias, vnewscale, res3);
if (if_relu) {
res3 = vmaxq_f32(res3, zero);
}
if ((w4 != w_times)) {
vst1q_f32(output_row_ptr, res3);
} else {
if (out_l - 2 - w_times * 3 == 1) {
vst1q_lane_f32(output_row_ptr, res3, 0);
} else if (out_l - 2 - w_times * 3 == 2) {
vst1q_lane_f32(output_row_ptr, res3, 0);
vst1q_lane_f32(output_row_ptr + 1, res3, 1);
}
}
input_row_ptr += 6;
output_row_ptr += 3;
}
output_data_tmp[0] = input_const[0] * w11 + input_const[1] * w12 +
input_const[in_l] * w21 +
input_const[in_l + 1] * w22;
out2in_mid = (out_l - 1) * 2;
output_data_tmp[out_l - 1] =
w10 * input_const[out2in_mid - 1] + w11 * input_const[out2in_mid] +
w20 * input_const[out2in_mid + in_w - 1] +
w21 * input_const[out2in_mid + in_w] +
(1 - if_pad) * (w12 * input_const[out2in_mid + 1] +
w22 * input_const[out2in_mid + in_w + 1]);
out2in_mid = (out_l - 1) * 2 * in_w;
output_data_tmp[out_l * (out_l - 1)] =
w01 * input_const[out2in_mid - in_w] +
w02 * input_const[out2in_mid - in_w + 1] +
w11 * input_const[out2in_mid] + w12 * input_const[out2in_mid + 1] +
(1 - if_pad) * (w21 * input_const[out2in_mid + in_w] +
w22 * input_const[out2in_mid + in_w + 1]);
out2in_mid = (out_l - 1) * 2 * in_w + (out_l - 1) * 2;
output_data_tmp[out_l * out_l - 1] =
w00 * input_const[out2in_mid - in_w - 1] +
w01 * input_const[out2in_mid - in_w] +
w10 * input_const[out2in_mid - 1] + w11 * input_const[out2in_mid] +
(1 - if_pad) * (w20 * input_const[out2in_mid + in_w - 1] +
w21 * input_const[out2in_mid + in_w] +
w02 * input_const[out2in_mid - in_w + 1] +
w12 * input_const[out2in_mid + 1] +
w22 * input_const[out2in_mid + in_w + 1]);
output_data_tmp[0] =
output_data_tmp[0] * newscale_data[j] + newbias_data[j];
output_data_tmp[out_l - 1] =
output_data_tmp[out_l - 1] * newscale_data[j] + newbias_data[j];
output_data_tmp[out_l * (out_l - 1)] =
output_data_tmp[out_l * (out_l - 1)] * newscale_data[j] +
newbias_data[j];
output_data_tmp[out_l * out_l - 1] =
output_data_tmp[out_l * out_l - 1] * newscale_data[j] +
newbias_data[j];
if (if_relu) {
output_data_tmp[0] = output_data_tmp[0] < 0 ? 0 : output_data_tmp[0];
output_data_tmp[out_l - 1] =
output_data_tmp[out_l - 1] < 0 ? 0 : output_data_tmp[out_l - 1];
output_data_tmp[out_l * (out_l - 1)] =
output_data_tmp[out_l * (out_l - 1)] < 0
? 0
: output_data_tmp[out_l * (out_l - 1)];
output_data_tmp[out_l * out_l - 1] =
output_data_tmp[out_l * out_l - 1] < 0
? 0
: output_data_tmp[out_l * out_l - 1];
}
for (int i = 1; i < out_h - 1; i++) {
out2in_mid = i * 2 * in_w;
output_data_tmp[i * out_l] = w01 * input_const[out2in_mid - in_w] +
w02 * input_const[out2in_mid - in_w + 1] +
w11 * input_const[out2in_mid] +
w12 * input_const[out2in_mid + 1] +
w21 * input_const[out2in_mid + in_w] +
w22 * input_const[out2in_mid + in_w + 1];
out2in_mid = i * 2 * in_w + (out_l - 1) * 2;
output_data_tmp[i * out_l + out_l - 1] =
w00 * input_const[out2in_mid - in_w - 1] +
w01 * input_const[out2in_mid - in_w] +
w10 * input_const[out2in_mid - 1] + w11 * input_const[out2in_mid] +
w20 * input_const[out2in_mid + in_w - 1] +
w21 * input_const[out2in_mid + in_w] +
(1 - if_pad) * (w02 * input_const[out2in_mid - in_w + 1] +
w12 * input_const[out2in_mid + 1] +
w22 * input_const[out2in_mid + in_w + 1]);
output_data_tmp[i * out_l] =
output_data_tmp[i * out_l] * newscale_data[j] + newbias_data[j];
output_data_tmp[i * out_l + out_l - 1] =
output_data_tmp[i * out_l + out_l - 1] * newscale_data[j] +
newbias_data[j];
if (if_relu) {
output_data_tmp[i * out_l] =
output_data_tmp[i * out_l] < 0 ? 0 : output_data_tmp[i * out_l];
output_data_tmp[i * out_l + out_l - 1] =
output_data_tmp[i * out_l + out_l - 1] < 0
? 0
: output_data_tmp[i * out_l + out_l - 1];
}
}
filter_data_tmp += 9;
}
input_data += inhxw * c;
output_data += outhxw * c;
}
}
} // namespace math
} // namespace operators
} // namespace paddle_mobile
......@@ -38,6 +38,11 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter,
void DepthwiseConvAddBNRelu3x3s2p1(const Tensor *input, const Tensor *filter,
Tensor *output, const Tensor *new_scale,
const Tensor *new_bias, bool if_relu);
void DepthwiseConv3x3s2p1v2(const Tensor *input, const Tensor *filter,
Tensor *output, Tensor bias, bool if_bias);
void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter,
Tensor *output, const Tensor *new_scale,
const Tensor *new_bias, bool if_relu);
} // namespace math
} // namespace operators
} // namespace paddle_mobile
此差异已折叠。
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
// 矩阵取值运算宏,假设矩阵按行存储
#define A(i, j) A[(i)*lda + (j)]
......@@ -27,88 +28,111 @@ limitations under the License. */
namespace paddle_mobile {
namespace operators {
namespace math {
struct Gemmer {
int MC = 0;
int KC = 0;
int NC = 0;
// 将 A 矩阵分块复制到连续内存(ColMajor)
void PackMatrixA(int m, int k, int m_tail, const float *A, int lda,
float *buffer);
// 将 B 矩阵分块复制到连续内存(ColMajor)
void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
float *buffer);
// 将 A 矩阵分块复制到连续内存(RowMajor)
void PackMatrixA_(int m, int k, int m_tail, const float *A, int lda,
float *buffer);
// 将 B 矩阵分块复制到连续内存(RowMajor)
void PackMatrixB_(int k, int n, int n_tail, const float *B, int ldb,
float *buffer);
// 分块矩阵乘法
void InnerKernel(int mc, int nc, float alpha, const float *a, const float *b,
float beta, float *c, float *C, int ldc, bool relu);
void InnerKernelWithBn(int mc, int nc, float alpha, const float *a,
const float *b, float beta, float *c, float *C, int ldc,
bool relu, float *new_scale, float *new_bias);
// 向量矩阵乘法 (M = 1)
void VectorKernel(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc,
bool relu);
void VectorKernelWithBn(int m, int n, int k, float alpha, const float *A,
int lda, const float *B, int ldb, float beta, float *C,
int ldc, bool relu, float *new_scale, float *new_bias);
// 计算一个更小的 C 矩阵分块
void AddDot4x4(int k, const float *a, const float *b, float *c, int ldc);
void AddDot4x8(int k, const float *a, const float *b, float *c, int ldc);
// 分块矩阵乘法结果回写
// C = A * B
void WriteBasic(int mc, int nc, float *c, float *C, int ldc);
// C = alpha * A * B + beta * C
void WriteWithAlphaBeta(int mc, int nc, float *c, float *C, int ldc);
// C = A * B + C
void WriteWithAdd(int mc, int nc, float *c, float *C, int ldc);
// C = A * B + C, relu(C)
void WriteWithAddRelu(int mc, int nc, float *c, float *C, int ldc);
// C = A * B, batchnorm(C)
void WriteWithBn(int mc, int nc, float *c, float *C, int ldc, float *new_scale,
float *new_bias);
// C = A * B, batchnorm(C), relu(C)
void WriteWithBnRelu(int mc, int nc, float *c, float *C, int ldc,
float *new_scale, float *new_bias);
// 向量矩阵乘法结果回写
// C = A * B
void VecWriteBasic(int n, float *c, float *C, int ldc);
// C = alpha * A * B + beta * C
void VecWriteWithAlphaBeta(int n, float *c, float *C, int ldc);
// C = A * B + C
void VecWriteWithAdd(int n, float *c, float *C, int ldc);
// C = A * B + C, relu(C)
void VecWriteWithAddRelu(int n, float *c, float *C, int ldc);
// C = A * B, batchnorm(C)
void VecWriteWithBn(int n, float *c, float *C, int ldc, float *new_scale,
float *new_bias);
// C = A * B, batchnorm(C), relu(C)
void VecWriteWithBnRelu(int n, float *c, float *C, int ldc, float *new_scale,
float *new_bias);
// 32位 float 矩阵乘法
void Sgemm(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc, bool relu);
// 32位 float 矩阵乘法, 并对结果进行 batchnrom
void SgemmWithBn(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc,
bool relu, float *new_scale, float *new_bias);
// 64位 double 矩阵乘法
void dgemm(int m, int n, int k, float alpha, const double *A, int lda,
const double *B, int ldb, float beta, double *C, int ldc);
float *packedA;
float *packedB;
float *packedC;
float *zero;
static std::vector<Gemmer *> gemmers;
// 将 A 矩阵分块复制到连续内存(ColMajor)
void PackMatrixA(int m, int k, int m_tail, const float *A, int lda,
float *buffer);
// 将 B 矩阵分块复制到连续内存(ColMajor)
void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
float *buffer);
// 将 A 矩阵分块复制到连续内存(RowMajor)
void PackMatrixA_(int m, int k, int m_tail, const float *A, int lda,
float *buffer);
// 将 B 矩阵分块复制到连续内存(RowMajor)
void PackMatrixB_(int k, int n, int n_tail, const float *B, int ldb,
float *buffer);
// 分块矩阵乘法
void InnerKernel(int mc, int nc, float alpha, const float *a, const float *b,
float beta, float *c, float *C, int ldc, bool relu);
void InnerKernelWithBn(int mc, int nc, float alpha, const float *a,
const float *b, float beta, float *c, float *C,
int ldc, bool relu, float *new_scale, float *new_bias);
// 向量矩阵乘法 (M = 1)
void VectorKernel(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc,
bool relu);
void VectorKernelWithBn(int m, int n, int k, float alpha, const float *A,
int lda, const float *B, int ldb, float beta,
float *C, int ldc, bool relu, float *new_scale,
float *new_bias);
// 计算一个更小的 C 矩阵分块
void AddDot4x4(int k, const float *a, const float *b, float *c, int ldc);
void AddDot4x8(int k, const float *a, const float *b, float *c, int ldc);
// 分块矩阵乘法结果回写
// C = A * B
void WriteBasic(int mc, int nc, float *c, float *C, int ldc);
// C = alpha * A * B + beta * C
void WriteWithAlphaBeta(int mc, int nc, float *c, float *C, int ldc);
// C = A * B + C
void WriteWithAdd(int mc, int nc, float *c, float *C, int ldc);
// C = A * B + C, relu(C)
void WriteWithAddRelu(int mc, int nc, float *c, float *C, int ldc);
// C = A * B, batchnorm(C)
void WriteWithBn(int mc, int nc, float *c, float *C, int ldc,
float *new_scale, float *new_bias);
// C = A * B, batchnorm(C), relu(C)
void WriteWithBnRelu(int mc, int nc, float *c, float *C, int ldc,
float *new_scale, float *new_bias);
// 向量矩阵乘法结果回写
// C = A * B
void VecWriteBasic(int n, float *c, float *C, int ldc);
// C = alpha * A * B + beta * C
void VecWriteWithAlphaBeta(int n, float *c, float *C, int ldc);
// C = A * B + C
void VecWriteWithAdd(int n, float *c, float *C, int ldc);
// C = A * B + C, relu(C)
void VecWriteWithAddRelu(int n, float *c, float *C, int ldc);
// C = A * B, batchnorm(C)
void VecWriteWithBn(int n, float *c, float *C, int ldc, float *new_scale,
float *new_bias);
// C = A * B, batchnorm(C), relu(C)
void VecWriteWithBnRelu(int n, float *c, float *C, int ldc, float *new_scale,
float *new_bias);
// 32位 float 矩阵乘法
void Sgemm(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc, bool relu);
// 32位 float 矩阵乘法, 并对结果进行 batchnrom
void SgemmWithBn(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc,
bool relu, float *new_scale, float *new_bias);
// 64位 double 矩阵乘法
void dgemm(int m, int n, int k, float alpha, const double *A, int lda,
const double *B, int ldb, float beta, double *C, int ldc);
};
} // namespace math
} // namespace operators
......
......@@ -26,23 +26,14 @@ void matmul<float>(const framework::Tensor &matrix_a, bool trans_a,
auto dim_a = matrix_a.dims();
auto dim_b = matrix_b.dims();
auto dim_out = matrix_out->dims();
// PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 &&
// dim_out.size() ==
// 2,
// "The input and output of matmul be matrix");
//
// PADDLE_ENFORCE(platform::is_cpu_place(matrix_a.place()) &&
// platform::is_cpu_place(matrix_b.place())
// &&
// platform::is_cpu_place(matrix_out->place()),
// "Matrix must all be in CPUPlace");
int M = dim_out[0];
int N = dim_out[1];
int K = (!trans_a) ? dim_a[1] : dim_a[0];
Sgemm(M, N, K, alpha, matrix_a.data<float>(), K, matrix_b.data<float>(), N,
beta, matrix_out->data<float>(), N, relu);
Gemmer::gemmers[0]->Sgemm(M, N, K, alpha, matrix_a.data<float>(), K,
matrix_b.data<float>(), N, beta,
matrix_out->data<float>(), N, relu);
}
template <>
......@@ -54,24 +45,15 @@ void matmulWithBn<float>(const framework::Tensor &matrix_a, bool trans_a,
auto dim_a = matrix_a.dims();
auto dim_b = matrix_b.dims();
auto dim_out = matrix_out->dims();
// PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 &&
// dim_out.size() ==
// 2,
// "The input and output of matmul be matrix");
//
// PADDLE_ENFORCE(platform::is_cpu_place(matrix_a.place()) &&
// platform::is_cpu_place(matrix_b.place())
// &&
// platform::is_cpu_place(matrix_out->place()),
// "Matrix must all be in CPUPlace");
int M = dim_out[0];
int N = dim_out[1];
int K = (!trans_a) ? dim_a[1] : dim_a[0];
SgemmWithBn(M, N, K, alpha, matrix_a.data<float>(), K, matrix_b.data<float>(),
N, beta, matrix_out->data<float>(), N, relu,
new_scale->data<float>(), new_bias->data<float>());
Gemmer::gemmers[0]->SgemmWithBn(
M, N, K, alpha, matrix_a.data<float>(), K, matrix_b.data<float>(), N,
beta, matrix_out->data<float>(), N, relu, new_scale->data<float>(),
new_bias->data<float>());
}
} // namespace math
......
此差异已折叠。
......@@ -15,6 +15,9 @@ limitations under the License. */
#ifdef POOL_OP
#pragma once
#ifdef _OPENMP
#include <omp.h>
#endif
#include <algorithm>
#include <vector>
#include "framework/tensor.h"
......
......@@ -16,6 +16,9 @@ limitations under the License. */
#include "pooling.h"
#include "common/types.h"
#ifdef _OPENMP
#include <omp.h>
#endif
namespace paddle_mobile {
namespace operators {
......@@ -57,8 +60,8 @@ class PoolFunctor<CPU, PoolProcess, T> {
T *output_data = output->mutable_data<T>();
for (int i = 0; i < batch_size; i++) {
// #pragma omp parallel for
for (int c = 0; c < output_channels; ++c) {
#pragma omp parallel for
for (int ph = 0; ph < output_height; ++ph) {
int hstart = ph * stride_height - padding_height;
int hend = std::min(hstart + ksize_height, input_height);
......
......@@ -52,8 +52,9 @@ int main() {
}
auto time1 = time();
paddle_mobile::operators::math::sgemm(m, n, k, 0.9, a, lda, b, ldb, 0.3, c,
ldc);
// paddle_mobile::operators::math::Sgemm(m, n, k, 0.9, a, lda, b, ldb, 0.3,
// c,
// ldc);
auto time2 = time();
DLOG << "gemm cost :" << time_diff(time1, time2) << "ms\n";
for (int i = 0; i < m * n; ++i) {
......
......@@ -26,16 +26,17 @@ int main() {
auto time2 = time();
DLOG << "load cost :" << time_diff(time1, time2) << "ms\n";
paddle_mobile::Executor<paddle_mobile::CPU> executor(program, 1, optimize);
executor.SetThreadNum(4);
std::vector<float> input;
std::vector<int64_t> dims{1, 3, 224, 224};
GetInput<float>(g_test_image_1x3x224x224, &input, dims);
auto time3 = time();
for (int i = 0; i < 10; ++i) {
int count = 1;
for (int i = 0; i < count; ++i) {
executor.Predict(input, dims);
}
auto time4 = time();
DLOG << "predict cost :" << time_diff(time3, time4) << "ms\n";
DLOG << "predict cost :" << time_diff(time3, time4) / count << "ms\n";
return 0;
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册