提交 532cff71 编写于 作者: H hjchen2

Refator depthwise conv3x3 and fix it's bugs for armv8

上级 cb5e15b9
......@@ -61,25 +61,15 @@ template <>
void ConvAddBNReluKernel<CPU, float>::Compute(
const FusionConvAddBNReluParam<CPU> &param) {
switch (param.ExecMode()) {
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1P1_FLOAT:
math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(),
param.Output(), param.NewScale(),
param.NewBias(), true);
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P1_FLOAT:
math::DepthwiseConvAddBNRelu3x3s2p1v2(param.Input(), param.Filter(),
param.Output(), param.NewScale(),
param.NewBias(), true);
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P0_FLOAT:
math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(),
nullptr, false, false);
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1_FLOAT:
math::DepthwiseConv3x3S1<float, float>(*param.Input(), *param.Filter(),
param.Paddings(), param.Output());
math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
param.NewBias(), param.Output());
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3_FLOAT:
math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(),
param.Filter(), nullptr, param.Output(), false);
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2_FLOAT:
math::DepthwiseConv3x3S2<float, float>(*param.Input(), *param.Filter(),
param.Paddings(), param.Output());
math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
param.NewBias(), param.Output());
break;
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#ifdef FUSION_CONVADD_OP
#include "operators/kernel/conv_add_kernel.h"
#include "operators/kernel/arm/convolution/conv_common.h"
#include "operators/kernel/central-arm-func/conv_add_arm_func.h"
namespace paddle_mobile {
......@@ -21,12 +22,44 @@ namespace operators {
template <>
bool ConvAddKernel<CPU, float>::Init(FusionConvAddParam<CPU> *param) {
InitBaseConvKernel(param);
return true;
}
template <>
void ConvAddKernel<CPU, float>::Compute(const FusionConvAddParam<CPU> &param) {
ConvAddCompute<float>(param);
switch (param.ExecMode()) {
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1_FLOAT:
math::DepthwiseConv3x3S1<float, float>(*param.Input(), *param.Filter(),
param.Paddings(), param.Output());
math::AddChannelWise<IDENTITY>(param.Output(), param.Bias(),
param.Output());
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2_FLOAT:
math::DepthwiseConv3x3S2<float, float>(*param.Input(), *param.Filter(),
param.Paddings(), param.Output());
math::AddChannelWise<IDENTITY>(param.Output(), param.Bias(),
param.Output());
break;
#ifndef __aarch64__
case ConvParam<CPU>::EXEC_DEPTHWISE5x5_FLOAT:
DepthwiseConv5x5<float, float>(param);
math::AddChannelWise<IDENTITY>(param.Output(), param.Bias(),
param.Output());
break;
case ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT:
WinogradConv3x3<8, 3>(param);
math::AddChannelWise<IDENTITY>(param.Output(), param.Bias(),
param.Output());
break;
#endif // __aarch64__
case ConvParam<CPU>::EXEC_GEMM_FLOAT:
ConvAddBasic(param);
break;
default:
PADDLE_MOBILE_THROW_EXCEPTION("Invalid convolution execute mode %d",
param.ExecMode());
}
}
template class ConvAddKernel<CPU, float>;
......
......@@ -31,21 +31,14 @@ template <>
void ConvAddReluKernel<CPU, float>::Compute(
const FusionConvAddReluParam<CPU> &param) {
switch (param.ExecMode()) {
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1P1_FLOAT:
math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(),
param.Bias(), true, true);
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P1_FLOAT:
math::DepthwiseConv3x3s2p1v2(param.Input(), param.Filter(),
param.Output(), param.Bias(), true, true);
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P0_FLOAT:
math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(),
param.Bias(), true, true);
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1_FLOAT:
math::DepthwiseConv3x3S1<float, float>(*param.Input(), *param.Filter(),
param.Paddings(), param.Output());
math::AddChannelWise<RELU>(param.Output(), param.Bias(), param.Output());
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3_FLOAT:
math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(),
param.Filter(), nullptr, param.Output(), false);
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2_FLOAT:
math::DepthwiseConv3x3S2<float, float>(*param.Input(), *param.Filter(),
param.Paddings(), param.Output());
math::AddChannelWise<RELU>(param.Output(), param.Bias(), param.Output());
break;
#ifndef __aarch64__
......
......@@ -16,7 +16,8 @@ limitations under the License. */
#include "operators/kernel/conv_bn_add_relu_kernel.h"
#include <cmath>
#include "operators/kernel/central-arm-func/conv_bn_add_relu_arm_func.h"
#include "operators/kernel/arm/convolution/conv_common.h"
#include "operators/kernel/central-arm-func/conv_arm_func.h"
namespace paddle_mobile {
namespace operators {
......@@ -51,13 +52,46 @@ bool ConvBNAddReluKernel<CPU, float>::Init(
}
param->SetNewScale(new_scale);
param->SetNewBias(new_bias);
InitBaseConvKernel(param);
return true;
}
template <>
void ConvBNAddReluKernel<CPU, float>::Compute(
const FusionConvBNAddReluParam<CPU> &param) {
ConvBNAddReluCompute<float>(param);
switch (param.ExecMode()) {
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1_FLOAT:
math::DepthwiseConv3x3S1<float, float>(*param.Input(), *param.Filter(),
param.Paddings(), param.Output());
math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
param.NewBias(), param.Output());
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2_FLOAT:
math::DepthwiseConv3x3S2<float, float>(*param.Input(), *param.Filter(),
param.Paddings(), param.Output());
math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
param.NewBias(), param.Output());
break;
#ifndef __aarch64__
case ConvParam<CPU>::EXEC_DEPTHWISE5x5_FLOAT:
DepthwiseConv5x5<float, float>(param);
math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
param.NewBias(), param.Output());
break;
case ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT:
WinogradConv3x3<8, 3>(param);
math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
param.NewBias(), param.Output());
break;
#endif // __aarch64__
case ConvParam<CPU>::EXEC_GEMM_FLOAT:
ConvBNReluBasic<FusionConvBNAddReluParam<CPU>>(param);
break;
default:
PADDLE_MOBILE_THROW_EXCEPTION("Invalid convolution execute mode %d",
param.ExecMode());
}
}
template class ConvBNAddReluKernel<CPU, float>;
......
......@@ -60,25 +60,15 @@ template <>
void ConvBNReluKernel<CPU, float>::Compute(
const FusionConvBNReluParam<CPU> &param) {
switch (param.ExecMode()) {
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1P1_FLOAT:
math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(),
param.Output(), param.NewScale(),
param.NewBias(), true);
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P1_FLOAT:
math::DepthwiseConvAddBNRelu3x3s2p1v2(param.Input(), param.Filter(),
param.Output(), param.NewScale(),
param.NewBias(), true);
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P0_FLOAT:
math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(),
nullptr, false, false);
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1_FLOAT:
math::DepthwiseConv3x3S1<float, float>(*param.Input(), *param.Filter(),
param.Paddings(), param.Output());
math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
param.NewBias(), param.Output());
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3_FLOAT:
math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(),
param.Filter(), nullptr, param.Output(), false);
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2_FLOAT:
math::DepthwiseConv3x3S1<float, float>(*param.Input(), *param.Filter(),
param.Paddings(), param.Output());
math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
param.NewBias(), param.Output());
break;
......
......@@ -44,29 +44,25 @@ void InitBaseConvKernel(ConvParam<CPU> *param) {
#endif // __aarch64__
} else {
if (depth3x3 && param->Strides()[0] == param->Strides()[1] &&
param->Strides()[0] == 1 && param->Paddings()[0] == 1 &&
param->Paddings()[0] == param->Paddings()[1]) {
param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3S1P1_FLOAT;
param->Strides()[0] == 1) {
param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3S1_FLOAT;
} else if (depth3x3 && param->Strides()[0] == param->Strides()[1] &&
param->Strides()[0] == 2 && param->Paddings()[0] == 0 &&
param->Paddings()[0] == param->Paddings()[1]) {
param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P0_FLOAT;
} else if (depth3x3 && param->Strides()[0] == param->Strides()[1] &&
param->Strides()[0] == 2 && param->Paddings()[0] == 1 &&
param->Paddings()[0] == param->Paddings()[1]) {
param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P1_FLOAT;
} else if (depth3x3) {
param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3_FLOAT;
param->Strides()[0] == 2) {
param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3S2_FLOAT;
#ifndef __aarch64__
} else if (depth5x5 && param->Strides()[0] == param->Strides()[1] &&
param->Strides()[0] == 1) {
param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE5x5_FLOAT;
} else if (conv3x3 && param->Strides()[0] == param->Strides()[1] &&
} else if (conv3x3 && !depth3x3 &&
param->Strides()[0] == param->Strides()[1] &&
param->Dilations()[0] == param->Dilations()[1] &&
param->Strides()[0] == 1 && param->Dilations()[0] == 1 /* &&
param->Output()->dims()[1] >= 16 &&
param->Strides()[0] == 1 && param->Dilations()[0] == 1
#if 0
&& param->Output()->dims()[1] >= 16 &&
param->Input()->dims()[1] >= 16 &&
param->Input()->dims()[2] <= 140 */ /* refered from ncnn */) {
param->Input()->dims()[2] <= 140 */ /* refered from ncnn */
#endif
) {
param->ExecMode() = ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT;
// transform weight
param->transformed_filter_ = new framework::LoDTensor;
......
......@@ -18,6 +18,8 @@ limitations under the License. */
#include "operators/kernel/arm/convolution/conv_common.h"
#include "operators/kernel/central-arm-func/conv_arm_func.h"
#include <iostream>
namespace paddle_mobile {
namespace operators {
......@@ -41,21 +43,13 @@ void ConvKernel<CPU, float>::Compute(const ConvParam<CPU> &param) {
DepthwiseConv5x5<int8_t, int32_t>(param);
break;
#endif // __aarch64__
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1P1_FLOAT:
math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(),
nullptr, false, false);
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P1_FLOAT:
math::DepthwiseConv3x3s2p1v2(param.Input(), param.Filter(),
param.Output(), nullptr, false, false);
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P0_FLOAT:
math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(),
nullptr, false, false);
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1_FLOAT:
math::DepthwiseConv3x3S1<float, float>(*param.Input(), *param.Filter(),
param.Paddings(), param.Output());
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3_FLOAT:
math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(),
param.Filter(), nullptr, param.Output(), false);
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2_FLOAT:
math::DepthwiseConv3x3S2<float, float>(*param.Input(), *param.Filter(),
param.Paddings(), param.Output());
break;
#ifndef __aarch64__
case ConvParam<CPU>::EXEC_DEPTHWISE5x5_FLOAT:
......
......@@ -60,25 +60,15 @@ template <>
void DWConvBNReluKernel<CPU, float>::Compute(
const FusionDWConvBNReluParam<CPU> &param) {
switch (param.ExecMode()) {
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1P1_FLOAT:
math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(),
param.Output(), param.NewScale(),
param.NewBias(), true);
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P1_FLOAT:
math::DepthwiseConvAddBNRelu3x3s2p1v2(param.Input(), param.Filter(),
param.Output(), param.NewScale(),
param.NewBias(), true);
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P0_FLOAT:
math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(),
nullptr, false, false);
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1_FLOAT:
math::DepthwiseConv3x3S1<float, float>(*param.Input(), *param.Filter(),
param.Paddings(), param.Output());
math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
param.NewBias(), param.Output());
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3_FLOAT:
math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(),
param.Filter(), nullptr, param.Output(), false);
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2_FLOAT:
math::DepthwiseConv3x3S2<float, float>(*param.Input(), *param.Filter(),
param.Paddings(), param.Output());
math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
param.NewBias(), param.Output());
break;
......
......@@ -115,35 +115,6 @@ void ConvAddBasic(const FusionConvAddParam<CPU> &param) {
}
}
template <typename P>
void ConvAddCompute(const FusionConvAddParam<CPU> &param) {
param.Output()->mutable_data<float>();
if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) {
math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(),
param.Bias(), true, false);
} else if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) {
// math::DepthwiseConv3x3(param.Input(), param.Strides(),
// param.Paddings(),
// param.Filter(), param.Bias(),
// param.Output(), false);
if (param.Paddings()[0] == 0) {
math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(),
param.Bias(), true, false);
} else {
math::DepthwiseConv3x3s2p1v2(param.Input(), param.Filter(),
param.Output(), param.Bias(), true, false);
}
} else {
ConvAddBasic(param);
}
}
} // namespace operators
} // namespace paddle_mobile
......
......@@ -115,31 +115,6 @@ void ConvBNAddReluBasic(const FusionConvBNAddReluParam<CPU> &param) {
}
}
}
template <typename P>
void ConvBNAddReluCompute(const FusionConvBNAddReluParam<CPU> &param) {
Tensor Bias;
Bias.mutable_data<float>({param.Groups()});
if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) {
math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(),
param.Output(), param.NewScale(),
param.NewBias(), true);
} else if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) {
// math::DepthwiseConvAddBNRelu3x3s2p1(param.Input(), param.Filter(),
// param.Output(), param.NewScale(),
// param.NewBias(), 1);
math::DepthwiseConvAddBNRelu3x3s2p1v2(param.Input(), param.Filter(),
param.Output(), param.NewScale(),
param.NewBias(), true);
} else {
ConvBNAddReluBasic(param);
}
}
} // namespace operators
} // namespace paddle_mobile
......
......@@ -99,7 +99,6 @@ void ConvTransposeCompute(const ConvTransposeParam<CPU> &param) {
std::vector<int>{paddings[0], paddings[1], paddings[0],
paddings[1]},
&out_slice);
} else if (data_dim == 3U) {
col2vol(col, dilations, strides, paddings, &out_slice);
}
......
......@@ -23,48 +23,6 @@ namespace paddle_mobile {
namespace operators {
namespace math {
void DepthwiseConv3x3(const framework::Tensor *input,
const std::vector<int> &strides,
const std::vector<int> &paddings,
const framework::Tensor *filter, framework::Tensor *bias,
framework::Tensor *output, bool if_bias);
void DepthwiseConv3x3s1p1(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output, framework::Tensor *bias,
bool if_bias, bool if_relu);
void DepthwiseConvAddBNRelu3x3s1p1(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output,
const framework::Tensor *new_scale,
const framework::Tensor *new_bias,
bool if_relu);
void DepthwiseConvAddBNRelu3x3s2p1(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output,
const framework::Tensor *new_scale,
const framework::Tensor *new_bias,
bool if_relu);
void DepthwiseConv3x3s2p1v2(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output, framework::Tensor *bias,
bool if_bias, bool if_relu);
void DepthwiseConvAddBNRelu3x3s2p1v2(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output,
const framework::Tensor *new_scale,
const framework::Tensor *new_bias,
bool if_relu);
void DepthwiseConv3x3s2p0(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output, framework::Tensor *bias,
bool if_bias, bool if_relu);
// TODO(hjchen2) need to be implemented
// template<typename Itype, typename Otype>
// void DepthwiseConv3x3(const framework::Tensor *input,
......
......@@ -31,16 +31,14 @@ void cblas_sgemm(const bool transA, const bool transB, const int M, const int N,
// return cblas_sgemv(transA, M, K, alpha, A, lda, B, beta, C);
// }
CPUInfo *info = CPUInfo::Info();
GemmExecutor<SgemmStrategy> exec(info, transA, transB, M, N, K);
GemmExecutor<SgemmStrategy> exec(transA, transB, M, N, K);
exec(alpha, A, lda, B, ldb, beta, C, ldc);
}
void cblas_sgemv(const bool trans, const int M, const int N, const float alpha,
const float *A, const int lda, const float *B,
const float beta, float *C) {
CPUInfo *info = CPUInfo::Info();
GemvExecutor<SgemvStrategy> exec(info, trans, M, N);
GemvExecutor<SgemvStrategy> exec(trans, M, N);
exec(alpha, A, lda, B, beta, C);
}
......
......@@ -19,7 +19,6 @@ limitations under the License. */
#include <omp.h>
#endif
#include <sys/time.h>
#include <iostream>
#include "common/log.h"
#include "memory/t_malloc.h"
#include "operators/math/gemm/cpu_info.h"
......@@ -29,6 +28,8 @@ namespace paddle_mobile {
namespace operators {
namespace math {
static CPUInfo *info = CPUInfo::Info();
int CeilDiv(const int &x, const int &y) { return (x + y - 1) / y; }
unsigned int ResetL1Cache(const unsigned int L1_size, const int thread_num,
const int N, const int K) {
......@@ -62,15 +63,9 @@ class GemmExecutor : public Executor {
typedef typename Strategy::Otype Otype;
public:
GemmExecutor(const CPUInfo *info, const bool transA, const bool transB,
const int M, const int N, const int K)
: Executor(),
info_(info),
transA_(transA),
transB_(transB),
M_(M),
N_(N),
K_(K) {
GemmExecutor(const bool transA, const bool transB, const int M, const int N,
const int K)
: Executor(), transA_(transA), transB_(transB), M_(M), N_(N), K_(K) {
unsigned int L1_size = 0;
unsigned int L2_size = 0;
if (M_ > N_) {
......@@ -212,8 +207,6 @@ class GemmExecutor : public Executor {
virtual ~GemmExecutor() {}
private:
const CPUInfo *info_;
const unsigned int M_;
const unsigned int N_;
const unsigned int K_;
......@@ -242,8 +235,8 @@ class GemvExecutor : public Executor {
typedef typename Strategy::Otype Otype;
public:
GemvExecutor(const CPUInfo *info, const bool transA, const int M, const int N)
: Executor(), info_(info), M_(M), N_(N) {}
GemvExecutor(const bool transA, const int M, const int N)
: Executor(), M_(M), N_(N) {}
void operator()(const float alpha, const Itype *A, const int lda,
const Itype *B, const float beta, Otype *C) {
......@@ -253,8 +246,6 @@ class GemvExecutor : public Executor {
virtual ~GemvExecutor() {}
private:
const CPUInfo *const info_;
const unsigned int M_;
const unsigned int N_;
......
......@@ -44,7 +44,17 @@ void ExtractToImg(const float *im_data, float *col_data, const int im_height,
for (int i = start_height; i < end_height; i += stride_h) {
if (stride_w == 1) {
memcpy(col_data, im_data, extract * sizeof(float));
// memcpy(col_data, im_data, extract * sizeof(float));
int s = 0;
#if __ARM_NEON
for (; s < extract - 3; s += 4) {
float32x4_t img = vld1q_f32(im_data + s);
vst1q_f32(col_data + s, img);
}
#endif
for (; s < extract; ++s) {
col_data[s] = im_data[s];
}
} else if (stride_w == 2) {
int s = 0;
#if __ARM_NEON
......@@ -109,325 +119,7 @@ void Im2ColFunctor<ColFormat::kCFO, CPU, float>::operator()(
const float *im_data = im.data<float>();
float *col_data = col->data<float>();
#if __ARM_NEON
const int osize = col_height;
const int isize = im_height;
bool pad1 = padding[0] > 0;
bool pad2 =
(pad1 && padding[1] &&
(((isize - 2 * padding[0] + filter_height) % stride[0] == 0) ? 1 : 0));
int fill = isize % 2;
if (stride[0] == 1 && filter_height == 3 && pad1 && pad2 &&
dilation[0] == 1 && im_height > 2 && im_height == im_width) {
for (int c = 0; c < im_channels; ++c) {
int oosize = osize * osize;
int nk4 = osize / 4;
int mk4 = osize % 4;
float *col0 = col_data + 0 * oosize + 2 * osize + 2;
float *col1 = col_data + 1 * oosize + 2 * osize + 1;
float *col2 = col_data + 2 * oosize + 2 * osize;
float *col3 = col_data + 3 * oosize + osize + 2;
float *col4 = col_data + 4 * oosize + osize + 1;
float *col5 = col_data + 5 * oosize + osize;
float *col6 = col_data + 6 * oosize + 2;
float *col7 = col_data + 7 * oosize + 1;
float *col8 = col_data + 8 * oosize;
float32x4_t im1;
const float *im_tmp_data = im_data + osize + 1;
int rrsize = oosize - osize - 1;
int nr4 = rrsize / 4;
int mr4 = rrsize % 4;
for (int i = 0; i < nr4; ++i) {
im1 = vld1q_f32(im_tmp_data);
vst1q_f32(col0, im1);
vst1q_f32(col1, im1);
vst1q_f32(col2, im1);
vst1q_f32(col3, im1);
vst1q_f32(col4, im1);
vst1q_f32(col5, im1);
vst1q_f32(col6, im1);
vst1q_f32(col7, im1);
vst1q_f32(col8, im1);
col0 += 4;
col1 += 4;
col2 += 4;
col3 += 4;
col4 += 4;
col5 += 4;
col6 += 4;
col7 += 4;
col8 += 4;
im_tmp_data += 4;
}
for (int i = 0; i < mr4; ++i) {
*col0 = *im_tmp_data;
*col1 = *im_tmp_data;
*col2 = *im_tmp_data;
*col3 = *im_tmp_data;
*col4 = *im_tmp_data;
*col5 = *im_tmp_data;
*col6 = *im_tmp_data;
*col7 = *im_tmp_data;
*col8 = *im_tmp_data;
col0++;
col1++;
col2++;
col3++;
col4++;
col5++;
col6++;
col7++;
col8++;
im_tmp_data++;
}
im_tmp_data = im_data + 1;
col0 = col_data + 0 * oosize + osize + 2;
col1 = col_data + 1 * oosize + osize + 1;
col2 = col_data + 2 * oosize + osize;
col3 = col_data + 3 * oosize + 2;
col4 = col_data + 4 * oosize + 1;
col5 = col_data + 5 * oosize;
for (int i = 0; i < nk4; i++) {
im1 = vld1q_f32(im_tmp_data);
vst1q_f32(col0, im1);
vst1q_f32(col1, im1);
vst1q_f32(col2, im1);
vst1q_f32(col3, im1);
vst1q_f32(col4, im1);
vst1q_f32(col5, im1);
col0 += 4;
col1 += 4;
col2 += 4;
col3 += 4;
col4 += 4;
col5 += 4;
im_tmp_data += 4;
}
for (int i = 0; i < mk4; i++) {
*col0 = *im_tmp_data;
*col1 = *im_tmp_data;
*col2 = *im_tmp_data;
*col3 = *im_tmp_data;
*col4 = *im_tmp_data;
*col5 = *im_tmp_data;
col0++;
col1++;
col2++;
col3++;
col4++;
col5++;
im_tmp_data++;
}
// fill 0 1 11;
for (int i = 0; i < osize; ++i) {
col_data[0 * oosize + i * osize] = 0.0;
col_data[3 * oosize + i * osize] = 0.0;
col_data[6 * oosize + i * osize] = 0.0;
col_data[2 * oosize + osize - 1 + i * osize] = 0.0;
col_data[5 * oosize + osize - 1 + i * osize] = 0.0;
col_data[8 * oosize + osize - 1 + i * osize] = 0.0;
}
col_data[0 * oosize + osize + 1] = im_data[0];
col_data[3 * oosize + 1] = im_data[0];
col_data[6 * oosize + 1] = im_data[osize];
col_data[1 * oosize + osize] = im_data[0];
col_data[4 * oosize] = im_data[0];
col_data[7 * oosize] = im_data[osize];
float32x4_t zero4;
zero4 = vdupq_n_f32(0.0);
auto col_z0 = col_data;
auto col_z1 = col_data + oosize;
auto col_z2 = col_data + 2 * oosize;
auto col_z6 = col_data + 6 * oosize + osize * (osize - 1);
auto col_z7 = col_data + 7 * oosize + osize * (osize - 1);
auto col_z8 = col_data + 8 * oosize + osize * (osize - 1);
for (int i = 0; i < nk4; ++i) {
vst1q_f32(col_z0, zero4);
vst1q_f32(col_z1, zero4);
vst1q_f32(col_z2, zero4);
vst1q_f32(col_z6, zero4);
vst1q_f32(col_z7, zero4);
vst1q_f32(col_z8, zero4);
col_z0 += 4;
col_z1 += 4;
col_z2 += 4;
col_z6 += 4;
col_z7 += 4;
col_z8 += 4;
}
for (int i = 0; i < mk4; ++i) {
col_z0[i] = 0.0;
col_z1[i] = 0.0;
col_z2[i] = 0.0;
col_z6[i] = 0.0;
col_z7[i] = 0.0;
col_z8[i] = 0.0;
}
col_data += 9 * oosize;
im_data += isize * isize;
}
} else if (stride[0] == 2 && filter_height == 3 && pad1 && dilation[0] == 1 &&
im_height > 2 && im_height == im_width) {
for (int c = 0; c < im_channels; ++c) {
int oosize = osize * osize;
int nk4 = osize / 4;
int mk4 = osize % 4;
// 3 2 3 1 0 1 3 2 3
float *col0 = col_data + 0 * oosize + osize + 1;
float *col1 = col_data + 1 * oosize + osize;
float *col2 = col_data + 2 * oosize + osize;
float *col3 = col_data + 3 * oosize + 1;
float *col4 = col_data + 4 * oosize;
float *col5 = col_data + 5 * oosize;
float *col6 = col_data + 6 * oosize + 1;
float *col7 = col_data + 7 * oosize;
float *col8 = col_data + 8 * oosize;
float32x4x2_t im01;
float32x4x2_t im23;
const float *im_tmp_data0 = im_data;
const float *im_tmp_data2 = im_data + isize;
for (int j = 0; j < osize; ++j) {
for (int i = 0; i < nk4; ++i) {
im01 = vld2q_f32(im_tmp_data0);
im23 = vld2q_f32(im_tmp_data2);
vst1q_f32(col0, im23.val[1]);
vst1q_f32(col1, im23.val[0]);
vst1q_f32(col2, im23.val[1]);
vst1q_f32(col3, im01.val[1]);
vst1q_f32(col4, im01.val[0]);
vst1q_f32(col5, im01.val[1]);
vst1q_f32(col6, im23.val[1]);
vst1q_f32(col7, im23.val[0]);
vst1q_f32(col8, im23.val[1]);
col0 += 4;
col1 += 4;
col2 += 4;
col3 += 4;
col4 += 4;
col5 += 4;
col6 += 4;
col7 += 4;
col8 += 4;
im_tmp_data0 += 8;
im_tmp_data2 += 8;
}
const float *im_tmp_data1 = im_tmp_data0 + 1;
const float *im_tmp_data3 = im_tmp_data2 + 1;
for (int i = 0; i < mk4; ++i) {
*col0 = *im_tmp_data3;
*col1 = *im_tmp_data2;
*col2 = *im_tmp_data3;
*col3 = *im_tmp_data1;
*col4 = *im_tmp_data0;
*col5 = *im_tmp_data1;
*col6 = *im_tmp_data3;
*col7 = *im_tmp_data2;
*col8 = *im_tmp_data3;
col0++;
col1++;
col2++;
col3++;
col4++;
col5++;
col6++;
col7++;
col8++;
im_tmp_data0 += 2;
im_tmp_data1 += 2;
im_tmp_data2 += 2;
im_tmp_data3 += 2;
}
im_tmp_data0 += (isize - fill);
im_tmp_data2 += (isize - fill);
}
for (int i = 0; i < osize; ++i) {
col_data[0 * oosize + i * osize] = 0.0;
col_data[3 * oosize + i * osize] = 0.0;
col_data[6 * oosize + i * osize] = 0.0;
if (pad2) {
col_data[2 * oosize + osize - 1 + i * osize] = 0.0;
col_data[5 * oosize + osize - 1 + i * osize] = 0.0;
col_data[8 * oosize + osize - 1 + i * osize] = 0.0;
}
}
float32x4_t zero4;
zero4 = vdupq_n_f32(0.0);
auto col_z0 = col_data;
auto col_z1 = col_data + oosize;
auto col_z2 = col_data + 2 * oosize;
auto col_z6 = col_data + 6 * oosize + osize * (osize - 1);
auto col_z7 = col_data + 7 * oosize + osize * (osize - 1);
auto col_z8 = col_data + 8 * oosize + osize * (osize - 1);
for (int i = 0; i < nk4; ++i) {
vst1q_f32(col_z0, zero4);
vst1q_f32(col_z1, zero4);
vst1q_f32(col_z2, zero4);
if (pad2) {
vst1q_f32(col_z6, zero4);
vst1q_f32(col_z7, zero4);
vst1q_f32(col_z8, zero4);
}
col_z0 += 4;
col_z1 += 4;
col_z2 += 4;
col_z6 += 4;
col_z7 += 4;
col_z8 += 4;
}
for (int i = 0; i < mk4; ++i) {
col_z0[i] = 0.0;
col_z1[i] = 0.0;
col_z2[i] = 0.0;
if (pad2) {
col_z6[i] = 0.0;
col_z7[i] = 0.0;
col_z8[i] = 0.0;
}
}
col_data[1 * oosize + osize] = im_data[isize];
for (int i = 1; i < osize; ++i) {
col_data[3 * oosize + i] = im_data[(i - 1) * stride[0] + 1];
}
col_data[4 * oosize] = im_data[0];
col_data[7 * oosize] = im_data[isize];
col_data += 9 * oosize;
im_data += isize * isize;
}
} else if (stride[0] <= 4 && dilation[0] == 1 && dilation[0] == dilation[1]) {
if (stride[0] <= 4 && dilation[0] == 1 && dilation[0] == dilation[1]) {
int im_spatial_size = im_height * im_width;
int col_spatial_size = col_height * col_width;
// pad 0
......
......@@ -441,10 +441,8 @@ class ConvParam : public OpParam {
enum ExecMode {
EXEC_INVALID = 0,
EXEC_GEMM_FLOAT,
EXEC_DEPTHWISE3x3S1P1_FLOAT,
EXEC_DEPTHWISE3x3S2P0_FLOAT,
EXEC_DEPTHWISE3x3S2P1_FLOAT,
EXEC_DEPTHWISE3x3_FLOAT,
EXEC_DEPTHWISE3x3S1_FLOAT,
EXEC_DEPTHWISE3x3S2_FLOAT,
EXEC_WINOGRAD3X3_FLOAT,
EXEC_WINOGRAD5X5_FLOAT,
EXEC_DEPTHWISE5x5_FLOAT,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册