提交 986bcdfc 编写于 作者: H hjchen2

Enable using optimization implementation for conv_add_relu op

上级 c3ae7671
......@@ -15,21 +15,58 @@ limitations under the License. */
#ifdef FUSION_CONVADDRELU_OP
#include "operators/kernel/conv_add_relu_kernel.h"
#include "operators/kernel/central-arm-func/conv_add_relu_arm_func.h"
#include "operators/kernel/arm/convolution/conv_common.h"
#include "operators/kernel/central-arm-func/conv_arm_func.h"
namespace paddle_mobile {
namespace operators {
template <>
bool ConvAddReluKernel<CPU, float>::Init(FusionConvAddReluParam<CPU> *param) {
InitBaseConvKernel(param);
return true;
}
template <>
void ConvAddReluKernel<CPU, float>::Compute(
const FusionConvAddReluParam<CPU> &param) {
ConvAddReluCompute<float, float>(param);
switch (param.ExecMode()) {
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1P1_FLOAT:
math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(),
param.Bias(), true, true);
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P1_FLOAT:
math::DepthwiseConv3x3s2p1v2(param.Input(), param.Filter(),
param.Output(), param.Bias(), true, true);
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P0_FLOAT:
math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(),
param.Bias(), true, true);
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3_FLOAT:
math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(),
param.Filter(), nullptr, param.Output(), false);
math::AddChannelWise<RELU>(param.Output(), param.Bias(), param.Output());
break;
#ifndef __aarch64__
case ConvParam<CPU>::EXEC_DEPTHWISE5x5_FLOAT:
DepthwiseConv5x5<float, float>(param);
math::AddChannelWise<RELU>(param.Output(), param.Bias(), param.Output());
break;
case ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT:
WinogradConv3x3<8, 3>(param);
math::AddChannelWise<RELU>(param.Output(), param.Bias(), param.Output());
break;
#endif // __aarch64__
case ConvParam<CPU>::EXEC_GEMM_FLOAT:
ConvAddReluBasic<FusionConvAddReluParam<CPU>>(param);
break;
default:
PADDLE_MOBILE_THROW_EXCEPTION("Invalid convolution execute mode %d",
param.ExecMode());
}
}
template class ConvAddReluKernel<CPU, float>;
} // namespace operators
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_CONVADDRELU_OP
#pragma once
#include <operators/math/depthwise_conv3x3.h>
#include <vector>
#include "operators/math/conv_func.h"
#include "operators/math/im2col.h"
#include "operators/math/math_function.h"
#include "operators/math/vol2col.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
template <typename Itype, typename Otype>
void ConvAddReluBasic(const FusionConvAddReluParam<CPU> &param) {
const Tensor *input = param.Input();
Tensor filter = *param.Filter();
Tensor bias = *param.Bias();
int32_t axis = param.Axis();
Otype *bias_data = bias.data<Otype>();
Tensor *output = param.Output();
output->mutable_data<Otype>();
float alpha = 1.0f;
float beta = 1.0f;
int32_t groups = param.Groups();
std::vector<int32_t> strides = param.Strides();
std::vector<int32_t> paddings = param.Paddings();
std::vector<int32_t> dilations = param.Dilations();
const int32_t batch_size = static_cast<int32_t>(input->dims()[0]);
std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
std::vector<int64_t> output_shape_vec(framework::vectorize(output->dims()));
size_t data_dim = filter_shape_vec.size() - 2;
std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
col_shape_vec[0] = input->dims()[1] / groups;
for (size_t j = 0; j < data_dim; ++j) {
col_shape_vec[j + 1] = filter_shape_vec[j + 2];
col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2];
}
framework::DDim col_shape(framework::make_ddim(col_shape_vec));
framework::DDim col_matrix_shape =
framework::flatten_to_2d(col_shape, data_dim + 1);
bool is_expand =
math::IsExpand(filter_shape_vec, strides, paddings, dilations);
Tensor col;
Tensor col_matrix;
if (is_expand) {
col.mutable_data<Itype>(col_shape);
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
}
framework::DDim input_shape = framework::slice_ddim(
input->dims(), 1, static_cast<int32_t>(input->dims().size()));
framework::DDim filter_matrix_shape = {filter.dims()[0],
filter.numel() / filter.dims()[0]};
filter.Resize(filter_matrix_shape);
framework::DDim output_matrix_shape = {
output->dims()[1],
output->numel() / (output->dims()[0] * output->dims()[1])};
// convolution operator: im2col(or vol2col) + gemm
int32_t in_step = static_cast<int32_t>(input->dims()[1]) / groups;
int32_t out_step = static_cast<int32_t>(output->dims()[1]) / groups;
math::Vol2ColFunctor<CPU, Itype> vol2col;
math::Im2ColFunctor<math::ColFormat::kCFO, CPU, Itype> im2col;
for (int32_t i = 0; i < batch_size; i++) {
Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape);
for (int32_t g = 0; g < groups; g++) {
Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
if (!is_expand) {
col.ShareDataWith(in_slice);
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
} else if (data_dim == 2U) {
// im2col
im2col(in_slice, dilations, strides,
std::vector<int32_t>{paddings[0], paddings[1], paddings[0],
paddings[1]},
&col);
} else if (data_dim == 3U) {
// vol2col
vol2col(in_slice, dilations, strides, paddings, &col);
}
// gemm
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
math::MatMul<Itype, Otype>(filter_slice, false, col_matrix, false, alpha,
&out_slice, beta, true, bias_data);
}
}
}
template <typename Itype, typename Otype>
void ConvAddReluCompute(const FusionConvAddReluParam<CPU> &param) {
param.Output()->mutable_data<float>();
if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) {
math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(),
param.Bias(), true, true);
} else if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) {
// math::DepthwiseConv3x3(param.Input(), param.Strides(),
// param.Paddings(),
// param.Filter(), param.Bias(),
// param.Output(), false);
if (param.Paddings()[0] == 0) {
math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(),
param.Bias(), true, true);
} else {
math::DepthwiseConv3x3s2p1v2(param.Input(), param.Filter(),
param.Output(), param.Bias(), true, true);
}
} else {
ConvAddReluBasic<Itype, Otype>(param);
}
}
} // namespace operators
} // namespace paddle_mobile
#endif
......@@ -212,6 +212,100 @@ inline void DepthwiseConv5x5(const ConvParam<CPU> &param) {
}
#endif // __aarch64__
template <typename ParamType>
void ConvAddReluBasic(const ParamType &param) {
const Tensor *input = param.Input();
Tensor filter = *param.Filter();
Tensor bias = *param.Bias();
Tensor *output = param.Output();
output->mutable_data<float>();
float alpha = 1.0f;
float beta = 1.0f;
int32_t groups = param.Groups();
int32_t axis = param.Axis();
std::vector<int32_t> strides = param.Strides();
std::vector<int32_t> paddings = param.Paddings();
std::vector<int32_t> dilations = param.Dilations();
const int32_t batch_size = static_cast<int32_t>(input->dims()[0]);
std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
std::vector<int64_t> output_shape_vec(framework::vectorize(output->dims()));
size_t data_dim = filter_shape_vec.size() - 2;
std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
col_shape_vec[0] = input->dims()[1] / groups;
for (size_t j = 0; j < data_dim; ++j) {
col_shape_vec[j + 1] = filter_shape_vec[j + 2];
col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2];
}
framework::DDim col_shape(framework::make_ddim(col_shape_vec));
framework::DDim col_matrix_shape =
framework::flatten_to_2d(col_shape, data_dim + 1);
bool is_expand =
math::IsExpand(filter_shape_vec, strides, paddings, dilations);
Tensor col;
Tensor col_matrix;
if (is_expand) {
col.mutable_data<float>(col_shape);
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
}
framework::DDim input_shape = framework::slice_ddim(
input->dims(), 1, static_cast<int32_t>(input->dims().size()));
framework::DDim filter_matrix_shape = {filter.dims()[0],
filter.numel() / filter.dims()[0]};
filter.Resize(filter_matrix_shape);
framework::DDim output_matrix_shape = {
output->dims()[1],
output->numel() / (output->dims()[0] * output->dims()[1])};
// convolution operator: im2col(or vol2col) + gemm
int32_t in_step = static_cast<int32_t>(input->dims()[1]) / groups;
int32_t out_step = static_cast<int32_t>(output->dims()[1]) / groups;
float *bias_data = bias.data<float>();
math::Vol2ColFunctor<CPU, float> vol2col;
math::Im2ColFunctor<math::ColFormat::kCFO, CPU, float> im2col;
for (int32_t i = 0; i < batch_size; i++) {
Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape);
for (int32_t g = 0; g < groups; g++) {
Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
if (!is_expand) {
col_matrix = in_slice;
col_matrix.Resize(col_matrix_shape);
} else if (data_dim == 2U) {
// im2col
im2col(in_slice, dilations, strides,
std::vector<int32_t>{paddings[0], paddings[1], paddings[0],
paddings[1]},
&col);
} else if (data_dim == 3U) {
// vol2col
vol2col(in_slice, dilations, strides, paddings, &col);
}
// gemm
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
math::MatMul<float, float>(filter_slice, false, col_matrix, false, alpha,
&out_slice, beta, true, bias_data);
}
}
}
template <typename ParamType>
void ConvBNReluBasic(const ParamType &param) {
const Tensor *input = param.Input();
......
......@@ -99,6 +99,58 @@ inline bool IsExpand(const std::vector<int64_t> &filter_dim,
return !(filter_1 && strides_1 && padding_0 && dilation_1);
}
template <ActivationType Act>
void AddChannelWise(const framework::Tensor *input,
const framework::Tensor *bias, framework::Tensor *output) {
const float *input_ptr = input->data<float>();
const float *bias_ptr = bias->data<float>();
float *output_ptr = output->mutable_data<float>();
// maybe check shape
int batch_size = input->dims()[0];
int channels = input->dims()[1];
size_t spatial_size = input->dims()[2] * input->dims()[3];
for (int batch = 0; batch < batch_size; ++batch) {
for (int channel = 0; channel < channels; ++channel) {
size_t offset = (batch * channels + channel) * spatial_size;
const float *x = input_ptr + offset;
float *y = output_ptr + offset;
float beta = bias_ptr[channel];
int j = 0;
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
float32x4_t __bias = vdupq_n_f32(beta);
for (; j < spatial_size - 15; j += 16, x += 16, y += 16) {
float32x4_t in0 = vld1q_f32(x);
float32x4_t in1 = vld1q_f32(x + 4);
float32x4_t in2 = vld1q_f32(x + 8);
float32x4_t in3 = vld1q_f32(x + 12);
in0 = vaddq_f32(__bias, in0);
in1 = vaddq_f32(__bias, in1);
in2 = vaddq_f32(__bias, in2);
in3 = vaddq_f32(__bias, in3);
in0 = math::vActiveq_f32<Act>(in0);
in1 = math::vActiveq_f32<Act>(in1);
in2 = math::vActiveq_f32<Act>(in2);
in3 = math::vActiveq_f32<Act>(in3);
vst1q_f32(y, in0);
vst1q_f32(y + 4, in1);
vst1q_f32(y + 8, in2);
vst1q_f32(y + 12, in3);
}
for (; j < spatial_size - 3; j += 4, x += 4, y += 4) {
float32x4_t in0 = vld1q_f32(x);
in0 = vaddq_f32(__bias, in0);
in0 = math::vActiveq_f32<Act>(in0);
vst1q_f32(y, in0);
}
#endif
for (; j < spatial_size; ++j, ++x, ++y) {
*y = math::Active<Act>((*x) + beta);
}
}
}
}
template <ActivationType Act>
void ScaleAddChannelWise(const framework::Tensor *input,
const framework::Tensor *scale,
......
......@@ -61,7 +61,7 @@ class GemmExecutor : public Executor {
K_(K) {
unsigned int L1_size = info->L1_cache;
unsigned int L2_size = info->L2_cache;
// if (N_ > 10000) L1_size *= 2;
if (N_ > 30000 && K_ > 100) L1_size *= 2;
if (num_threads_ >= 2) L1_size /= 2;
rhs_tile_num_ = L1_size / (K * sizeof(Itype));
......@@ -74,8 +74,8 @@ class GemmExecutor : public Executor {
rhs_tile_num_ *= Strategy::out_width();
}
// lhs_tile_num_ = CeilDiv(M, Strategy::out_height()) *
// Strategy::out_height();
// lhs_tile_num_ = CeilDiv(M, Strategy::out_height()) *
// Strategy::out_height();
lhs_tile_num_ = L2_size / (K * sizeof(Itype));
if (lhs_tile_num_ == 0) {
lhs_tile_num_ = Strategy::out_height();
......@@ -90,8 +90,8 @@ class GemmExecutor : public Executor {
void operator()(const float alpha, const Itype *A, const int lda,
const Itype *B, const int ldb, const float beta, Otype *C,
const int ldc) {
// struct timeval tv_begin, tv_end;
// gettimeofday(&tv_begin,NULL);
// struct timeval tv_begin, tv_end;
// gettimeofday(&tv_begin,NULL);
int mblock = CeilDiv(M_, Strategy::out_height()) * Strategy::out_height();
lhs_worksize_ = sizeof(Itype) * mblock * K_;
......@@ -107,9 +107,10 @@ class GemmExecutor : public Executor {
strategy_.pack_lhs(M_, K_, A, lda, lhs_workspace_, true);
// std::cout << "M: " << M_ << ", N: " << N_ << ", K: " << K_ <<
// std::endl; std::cout << "rhs_block: " << CeilDiv(N_, rhs_tile_num_) <<
// std::endl;
// std::cout << "M: " << M_ << ", N: " << N_
// << ", K: " << K_ << std::endl;
// std::cout << "rhs_block: " << CeilDiv(N_, rhs_tile_num_)
// << std::endl;
#pragma omp parallel for if (N_ > 128)
for (int rhs_block = 0; rhs_block < N_; rhs_block += rhs_tile_num_) {
......@@ -145,11 +146,12 @@ class GemmExecutor : public Executor {
paddle_mobile::memory::Free(rhs_workspace_);
paddle_mobile::memory::Free(out_workspace_);
// gettimeofday(&tv_end,NULL);
// float elapsed = (tv_end.tv_sec - tv_begin.tv_sec) * 1000.f +
// (tv_end.tv_usec - tv_begin.tv_usec) / 1000.f; std::cout << "elapsed: "
// << elapsed << "ms, speed: " << (M_ * N_ * K_ / 1000.f / 1000.f) /
// elapsed << " gflops" << std::endl;
// gettimeofday(&tv_end,NULL);
// float elapsed = (tv_end.tv_sec - tv_begin.tv_sec) * 1000.f +
// (tv_end.tv_usec - tv_begin.tv_usec) / 1000.f;
// std::cout << "elapsed: " << elapsed << "ms, speed: "
// << (M_ * N_ * K_ / 1000.f / 1000.f) / elapsed
// << " gflops" << std::endl;
}
virtual ~GemmExecutor() {}
......@@ -189,7 +191,7 @@ class GemvExecutor : public Executor {
void operator()(const float alpha, const Itype *A, const int lda,
const Itype *B, const float beta, Otype *C) {
// strategy_.kernel();
// strategy_.kernel();
}
virtual ~GemvExecutor() {}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册