Commit 3e8bfc7b authored by chenjiaoAngel

add 5x5s2_dw

Parent 04ff99f6
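
For context, the kernels added in this commit implement a depthwise convolution with a 5x5 filter and stride 2: each input channel is convolved with its own 5x5 filter, with an optional per-channel bias and a fused ReLU. The scalar sketch below is an illustrative reference only (not part of this patch); it assumes NCHW layout, assumes hout/wout are already consistent with the paddings, and simply mirrors the parameter order of the new declaration further down.

#include <algorithm>

// Naive 5x5, stride-2 depthwise convolution in NCHW layout.
// Illustrative reference only; one 5x5 filter (25 weights) per channel.
void depthwise_5x5s2_ref(float* dout, const float* din, const float* weights,
                         const float* bias, bool flag_bias, bool flag_relu,
                         int num, int chin, int hin, int win,
                         int hout, int wout, int pad_top, int pad_left) {
  for (int n = 0; n < num; ++n) {
    for (int c = 0; c < chin; ++c) {
      const float* in = din + (n * chin + c) * hin * win;
      const float* w = weights + c * 25;
      float* out = dout + (n * chin + c) * hout * wout;
      for (int oh = 0; oh < hout; ++oh) {
        for (int ow = 0; ow < wout; ++ow) {
          float sum = flag_bias ? bias[c] : 0.f;
          for (int kh = 0; kh < 5; ++kh) {
            for (int kw = 0; kw < 5; ++kw) {
              int ih = oh * 2 - pad_top + kh;   // stride 2 in height
              int iw = ow * 2 - pad_left + kw;  // stride 2 in width
              if (ih >= 0 && ih < hin && iw >= 0 && iw < win) {
                sum += in[ih * win + iw] * w[kh * 5 + kw];
              }
            }
          }
          out[oh * wout + ow] = flag_relu ? std::max(sum, 0.f) : sum;
        }
      }
    }
  }
}

A reference like this is typically what an optimized assembly path is checked against in unit tests.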
......@@ -82,6 +82,7 @@ if (NOT HAS_ARM_MATH_LIB_DIR)
 conv5x5s1_depthwise_fp32.cc
 conv5x5s2_depthwise_int8.cc
 conv5x5s2_depthwise_fp32.cc
+conv5x5s2_depthwise_fp32_c4.cc
 conv3x3_winograd_fp32_c4.cc
 conv3x3_winograd_int8.cc
 conv_winograd_3x3.cc
......
This diff is collapsed.
......@@ -193,6 +193,25 @@ void conv_depthwise_5x5s2_fp32(const float* din,
                                const operators::ActivationParam act_param,
                                ARMContext* ctx);
+void conv_depthwise_5x5s2_fp32(float* dout,
+                               const float* din,
+                               const float* weights,
+                               const float* bias,
+                               bool flag_bias,
+                               bool flag_relu,
+                               int num,
+                               int chin,
+                               int hin,
+                               int win,
+                               int hout,
+                               int wout,
+                               int pad_top,
+                               int pad_bottom,
+                               int pad_left,
+                               int pad_right,
+                               const operators::ActivationParam& act_param,
+                               ARMContext* ctx);
 void conv_depthwise_5x5s2p2_fp32(const float* din,
                                  float* dout,
                                  int num,
......
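
The new overload takes the output shape and all four paddings explicitly. For a 5x5 kernel at stride 2 (and no dilation, the usual setting for these depthwise kernels) the output extents are expected to satisfy hout = (hin + pad_top + pad_bottom - 5) / 2 + 1, and analogously for wout. A tiny illustrative helper for that arithmetic (not part of the patch):

// Expected output extent along one axis for a 5x5, stride-2 convolution.
// Integer division truncates, matching the usual convolution arithmetic.
static int conv5x5s2_out_dim(int in, int pad0, int pad1) {
  return (in + pad0 + pad1 - 5) / 2 + 1;
}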
......@@ -734,22 +734,45 @@ void conv_depthwise_5x5_fp32(const void* din,
   int stride = param.strides[1];
   bool flag_relu = param.fuse_relu;
   bool flag_bias = param.bias != nullptr;
+  bool ch_four = ch_in > 4 * w_in;
+  bool pads_five = (pad_h < 5) || (pad_w < 5);
   ctx->ExtendWorkspace((w_in + w_out) * sizeof(float));
   if (stride == 2) {
-    conv_depthwise_5x5s2_fp32(reinterpret_cast<const float*>(din),
-                              reinterpret_cast<float*>(dout),
-                              num,
-                              ch_out,
-                              h_out,
-                              w_out,
-                              ch_in,
-                              h_in,
-                              w_in,
-                              reinterpret_cast<const float*>(weights),
-                              bias,
-                              param,
-                              act_param,
-                              ctx);
+    if (ch_four || !pads_five || h_in < 5 || w_in < 10) {
+      conv_depthwise_5x5s2_fp32(reinterpret_cast<const float*>(din),
+                                reinterpret_cast<float*>(dout),
+                                num,
+                                ch_out,
+                                h_out,
+                                w_out,
+                                ch_in,
+                                h_in,
+                                w_in,
+                                reinterpret_cast<const float*>(weights),
+                                bias,
+                                param,
+                                act_param,
+                                ctx);
+    } else {
+      conv_depthwise_5x5s2_fp32(reinterpret_cast<float*>(dout),
+                                reinterpret_cast<const float*>(din),
+                                reinterpret_cast<const float*>(weights),
+                                bias,
+                                flag_bias,
+                                flag_relu,
+                                num,
+                                ch_in,
+                                h_in,
+                                w_in,
+                                h_out,
+                                w_out,
+                                paddings[0],
+                                paddings[1],
+                                paddings[2],
+                                paddings[3],
+                                act_param,
+                                ctx);
+    }
   } else if (stride == 1) {
     conv_depthwise_5x5s1_fp32(reinterpret_cast<float*>(dout),
                               reinterpret_cast<const float*>(din),
......
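
In the dispatch above, a stride-2 5x5 depthwise convolution still goes through the original conv_depthwise_5x5s2_fp32 overload when the channel/padding/shape guard at the top of the branch holds; otherwise the new overload is called with the explicit shapes and paddings. A minimal standalone sketch of that guard (hypothetical helper name, logic copied from the branch above):

// Returns true to keep the original conv_depthwise_5x5s2_fp32 overload,
// false to use the new direct stride-2 kernel.
static bool use_legacy_5x5s2_path(int ch_in, int h_in, int w_in,
                                  int pad_h, int pad_w) {
  bool ch_four = ch_in > 4 * w_in;
  bool pads_five = (pad_h < 5) || (pad_w < 5);
  return ch_four || !pads_five || h_in < 5 || w_in < 10;
}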
......@@ -28,7 +28,11 @@ void DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
   auto& ctx = this->ctx_->template As<ARMContext>();
   auto w_dims = param.filter->dims();
   auto kw = w_dims[3];
+  auto channel = w_dims[0];
+  auto hin = param.x->dims()[2];
+  auto win = param.x->dims()[3];
   auto paddings = *param.paddings;
+  bool ch_four = channel <= 4 * win;
   // select dw conv kernel
   if (kw == 3) {
     bool pads_less = ((paddings[1] < 2) && (paddings[3] < 2));
......@@ -54,7 +58,15 @@ void DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
 #endif
   } else if (kw == 5) {
     auto strides = param.strides;
-    if ((strides[0] == 1 && strides[1] == 1) ||
+    bool pads_five = (paddings[0] < 5) || (paddings[2] < 5);
+    if (ch_four && pads_five && win >= 2 * kw && hin >= kw &&
+        (strides[0] == 2 && strides[1] == 2)) {
+      flag_trans_weights_ = false;
+      impl_ = lite::arm::math::conv_depthwise_5x5_fp32;
+#ifdef LITE_WITH_PROFILE
+      kernel_func_name_ = "conv_depthwise_5x5_fp32";
+#endif
+    } else if ((strides[0] == 1 && strides[1] == 1) ||
                (strides[0] == 2 && strides[1] == 2)) {
       // trans weights
       constexpr int cblock = 4;
......
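
The kernel-selection change applies the same idea at PrepareForRun time: for kw == 5, the direct conv_depthwise_5x5_fp32 path is chosen without transposing the weights only when the new guard holds; everything else falls through to the existing weight-transforming branch. A sketch of that guard as a standalone predicate (hypothetical helper; paddings and strides assumed to be std::vector<int> as in the surrounding code):

#include <vector>

// Mirrors the new kw == 5 selection guard above.
static bool pick_direct_5x5_dw(int channel, int hin, int win, int kw,
                               const std::vector<int>& paddings,
                               const std::vector<int>& strides) {
  bool ch_four = channel <= 4 * win;
  bool pads_five = (paddings[0] < 5) || (paddings[2] < 5);
  return ch_four && pads_five && win >= 2 * kw && hin >= kw &&
         strides[0] == 2 && strides[1] == 2;
}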