提交 f656097b 编写于 作者: E eclipsess

dwconv 3x3 s1p1: support inputs where w != h (non-square feature maps)

上级 f10af946
...@@ -118,8 +118,7 @@ void ConvAddCompute(const FusionConvAddParam<CPU> &param) { ...@@ -118,8 +118,7 @@ void ConvAddCompute(const FusionConvAddParam<CPU> &param) {
if (param.Groups() == param.Input()->dims()[1] && if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] && param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] && param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1 && param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) {
param.Input()->dims()[2] == param.Input()->dims()[3]) {
math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(), math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(),
param.Bias(), true); param.Bias(), true);
} else if (param.Groups() == param.Input()->dims()[1] && } else if (param.Groups() == param.Input()->dims()[1] &&
......
...@@ -124,8 +124,7 @@ void ConvCompute(const ConvParam<CPU> &param) { ...@@ -124,8 +124,7 @@ void ConvCompute(const ConvParam<CPU> &param) {
if (param.Groups() == param.Input()->dims()[1] && if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] && param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] && param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1 && param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) {
param.Input()->dims()[2] == param.Input()->dims()[3]) {
math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(), math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(),
nullptr, false); nullptr, false);
} else if (param.Groups() == param.Input()->dims()[1] && } else if (param.Groups() == param.Input()->dims()[1] &&
......
...@@ -30,8 +30,7 @@ void DepthwiseConvCompute(const ConvParam<CPU> &param) { ...@@ -30,8 +30,7 @@ void DepthwiseConvCompute(const ConvParam<CPU> &param) {
Bias.mutable_data<float>({param.Groups()}); Bias.mutable_data<float>({param.Groups()});
if (param.Groups() == param.Input()->dims()[1] && if (param.Groups() == param.Input()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] && param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1 && param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) {
param.Input()->dims()[2] == param.Input()->dims()[3]) {
math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(), math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(),
&Bias, false); &Bias, false);
} else if (param.Groups() == param.Input()->dims()[1] && } else if (param.Groups() == param.Input()->dims()[1] &&
......
...@@ -257,8 +257,7 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter, ...@@ -257,8 +257,7 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter,
const int h = static_cast<int>(input->dims()[2]); const int h = static_cast<int>(input->dims()[2]);
const int w = static_cast<int>(input->dims()[3]); const int w = static_cast<int>(input->dims()[3]);
const int l = h; // const int l = h;
const int batch_size = static_cast<int>(input->dims()[0]); const int batch_size = static_cast<int>(input->dims()[0]);
const int c = static_cast<int>(input->dims()[1]); const int c = static_cast<int>(input->dims()[1]);
const int hxw = h * w; const int hxw = h * w;
...@@ -271,7 +270,7 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter, ...@@ -271,7 +270,7 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter,
vbias = vdupq_n_f32(bias_data[j]); vbias = vdupq_n_f32(bias_data[j]);
} }
int l_mid = l - 2; // l=1->l_mid=-1,l=2->l_mid=0 int w_mid = w - 2; // w=1->w_mid=-1,w=2->w_mid=0
float w00 = filter_data_tmp[0]; float w00 = filter_data_tmp[0];
float w01 = filter_data_tmp[1]; float w01 = filter_data_tmp[1];
float w02 = filter_data_tmp[2]; float w02 = filter_data_tmp[2];
...@@ -283,39 +282,38 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter, ...@@ -283,39 +282,38 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter,
float w22 = filter_data_tmp[8]; float w22 = filter_data_tmp[8];
output_data[0] = w11 * input_data[0] + w12 * input_data[1] + output_data[0] = w11 * input_data[0] + w12 * input_data[1] +
w21 * input_data[l] + w22 * input_data[l + 1]; w21 * input_data[w] + w22 * input_data[w + 1];
output_data[l - 1] = w10 * input_data[l - 2] + w11 * input_data[l - 1] + output_data[w - 1] = w10 * input_data[w - 2] + w11 * input_data[w - 1] +
w20 * input_data[2 * l - 2] + w20 * input_data[2 * w - 2] +
w21 * input_data[2 * l - 1]; w21 * input_data[2 * w - 1];
output_data[(l - 1) * l] = output_data[(h - 1) * w] =
w01 * input_data[(l - 2) * l] + w02 * input_data[(l - 2) * l + 1] + w01 * input_data[(h - 2) * w] + w02 * input_data[(h - 2) * w + 1] +
w11 * input_data[(l - 1) * l] + w12 * input_data[(l - 1) * l + 1]; w11 * input_data[(h - 1) * w] + w12 * input_data[(h - 1) * w + 1];
output_data[l * l - 1] = w00 * input_data[(l - 2) * (l + 1)] + output_data[h * w - 1] =
w01 * input_data[(l - 2) * (l + 1) + 1] + w00 * input_data[h * w - w - 2] + w01 * input_data[h * w - w - 1] +
w10 * input_data[l * l - 2] + w10 * input_data[h * w - 2] + w11 * input_data[h * w - 1];
w11 * input_data[l * l - 1];
if (if_bias) { if (if_bias) {
output_data[0] += bias_data[j]; output_data[0] += bias_data[j];
output_data[l - 1] += bias_data[j]; output_data[w - 1] += bias_data[j];
output_data[(l - 1) * l] += bias_data[j]; output_data[(h - 1) * w] += bias_data[j];
output_data[l * l - 1] += bias_data[j]; output_data[h * w - 1] += bias_data[j];
} }
for (int i = 1; i < l - 1; ++i) { for (int i = 1; i < h - 1; ++i) {
output_data[i * l] = output_data[i * w] =
w01 * input_data[i * l - l] + w02 * input_data[i * l - l + 1] + w01 * input_data[i * w - w] + w02 * input_data[i * w - w + 1] +
w11 * input_data[i * l] + w12 * input_data[i * l + 1] + w11 * input_data[i * w] + w12 * input_data[i * w + w] +
w21 * input_data[i * l + l] + w22 * input_data[i * l + l + 1]; w21 * input_data[i * w + w] + w22 * input_data[i * w + w + 1];
output_data[i * l + l - 1] = w00 * input_data[i * l + l - 1 - l - 1] + output_data[i * w + w - 1] = w00 * input_data[i * w + w - 1 - w - 1] +
w01 * input_data[i * l + l - 1 - l] + w01 * input_data[i * w + w - 1 - w] +
w10 * input_data[i * l + l - 1 - 1] + w10 * input_data[i * w + w - 1 - 1] +
w11 * input_data[i * l + l - 1] + w11 * input_data[i * w + w - 1] +
w20 * input_data[i * l + l - 1 + l - 1] + w20 * input_data[i * w + w - 1 + w - 1] +
w21 * input_data[i * l + l - 1 + l]; w21 * input_data[i * w + w - 1 + w];
if (if_bias) { if (if_bias) {
output_data[i * l] += bias_data[j]; output_data[i * w] += bias_data[j];
output_data[i * l + l - 1] += bias_data[j]; output_data[i * w + w - 1] += bias_data[j];
} }
} }
...@@ -325,15 +323,15 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter, ...@@ -325,15 +323,15 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter,
float32x4_t in0, in1, in2, in3, in4, in5, in6, in7, tmp0, tmp1, tmp2, float32x4_t in0, in1, in2, in3, in4, in5, in6, in7, tmp0, tmp1, tmp2,
tmp3, tmp4, tmp5, out0; tmp3, tmp4, tmp5, out0;
in0 = vld1q_f32(input_tmp); in0 = vld1q_f32(input_tmp);
in2 = vld1q_f32(input_tmp + l); in2 = vld1q_f32(input_tmp + w);
const float *input_tmp_end = input_tmp + (l - 2) * l; const float *input_tmp_end = input_tmp + (h - 2) * w;
in4 = vld1q_f32(input_tmp_end); in4 = vld1q_f32(input_tmp_end);
in6 = vld1q_f32(input_tmp_end + l); in6 = vld1q_f32(input_tmp_end + w);
int c_mid = l_mid; int c_mid = w_mid;
auto output_ptr = output_data + 1; auto output_ptr = output_data + 1;
for (; c_mid > 3; c_mid -= 4) { for (; c_mid > 3; c_mid -= 4) {
in1 = vld1q_f32(input_tmp + 4); in1 = vld1q_f32(input_tmp + 4);
in3 = vld1q_f32(input_tmp + l + 4); in3 = vld1q_f32(input_tmp + w + 4);
tmp0 = vextq_f32(in0, in1, 1); tmp0 = vextq_f32(in0, in1, 1);
tmp1 = vextq_f32(in0, in1, 2); tmp1 = vextq_f32(in0, in1, 2);
...@@ -352,7 +350,7 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter, ...@@ -352,7 +350,7 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter,
vst1q_f32(output_ptr, out0); vst1q_f32(output_ptr, out0);
in5 = vld1q_f32(input_tmp_end + 4); in5 = vld1q_f32(input_tmp_end + 4);
in7 = vld1q_f32(input_tmp_end + l + 4); in7 = vld1q_f32(input_tmp_end + w + 4);
tmp0 = vextq_f32(in4, in5, 1); tmp0 = vextq_f32(in4, in5, 1);
tmp1 = vextq_f32(in4, in5, 2); tmp1 = vextq_f32(in4, in5, 2);
...@@ -367,7 +365,7 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter, ...@@ -367,7 +365,7 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter,
out0 = vmlaq_n_f32(out0, tmp3, w12); out0 = vmlaq_n_f32(out0, tmp3, w12);
out0 = vaddq_f32(out0, vbias); out0 = vaddq_f32(out0, vbias);
vst1q_f32(output_ptr + (l - 1) * l, out0); vst1q_f32(output_ptr + (h - 1) * w, out0);
// can optimize to each 8 stride. // can optimize to each 8 stride.
input_tmp += 4; input_tmp += 4;
...@@ -380,8 +378,8 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter, ...@@ -380,8 +378,8 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter,
} }
// top right pad // top right pad
float32x4_t pad0 = vdupq_n_f32(input_data[l - 1]); float32x4_t pad0 = vdupq_n_f32(input_data[w - 1]);
float32x4_t pad1 = vdupq_n_f32(input_data[2 * l - 1]); float32x4_t pad1 = vdupq_n_f32(input_data[2 * w - 1]);
tmp0 = vextq_f32(in0, pad0, 1); tmp0 = vextq_f32(in0, pad0, 1);
tmp1 = vextq_f32(in0, pad0, 2); tmp1 = vextq_f32(in0, pad0, 2);
...@@ -409,8 +407,8 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter, ...@@ -409,8 +407,8 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter,
} }
// bottom right pad // bottom right pad
float32x4_t pad2 = vdupq_n_f32(input_data[l * l - 1 - l]); float32x4_t pad2 = vdupq_n_f32(input_data[h * w - 1 - w]);
float32x4_t pad3 = vdupq_n_f32(input_data[l * l - 1]); float32x4_t pad3 = vdupq_n_f32(input_data[h * w - 1]);
tmp0 = vextq_f32(in4, pad2, 1); tmp0 = vextq_f32(in4, pad2, 1);
tmp1 = vextq_f32(in4, pad2, 2); tmp1 = vextq_f32(in4, pad2, 2);
...@@ -427,28 +425,28 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter, ...@@ -427,28 +425,28 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter,
for (int i = 0; i < c_mid; ++i) { for (int i = 0; i < c_mid; ++i) {
if (i == 0) { if (i == 0) {
vst1q_lane_f32(output_ptr + (l - 1) * l + i, out0, 0); vst1q_lane_f32(output_ptr + (h - 1) * w + i, out0, 0);
} }
if (i == 1) { if (i == 1) {
vst1q_lane_f32(output_ptr + (l - 1) * l + i, out0, 1); vst1q_lane_f32(output_ptr + (h - 1) * w + i, out0, 1);
} }
if (i == 2) { if (i == 2) {
vst1q_lane_f32(output_ptr + (l - 1) * l + i, out0, 2); vst1q_lane_f32(output_ptr + (h - 1) * w + i, out0, 2);
} }
} }
// mid // mid
for (int i = 0; i < l - 2; ++i) { for (int i = 0; i < h - 2; ++i) {
auto output_ptr = output_data + (i + 1) * l + 1; auto output_ptr = output_data + (i + 1) * w + 1;
input_tmp = input_data + i * l; input_tmp = input_data + i * w;
auto in0_tmp = vld1q_f32(input_tmp); auto in0_tmp = vld1q_f32(input_tmp);
auto in2_tmp = vld1q_f32(input_tmp + l); auto in2_tmp = vld1q_f32(input_tmp + w);
auto in4_tmp = vld1q_f32(input_tmp + l + l); auto in4_tmp = vld1q_f32(input_tmp + w + w);
c_mid = l_mid; c_mid = w_mid;
for (; c_mid > 3; c_mid -= 4) { for (; c_mid > 3; c_mid -= 4) {
auto in1_tmp = vld1q_f32(input_tmp + 4); auto in1_tmp = vld1q_f32(input_tmp + 4);
auto in3_tmp = vld1q_f32(input_tmp + l + 4); auto in3_tmp = vld1q_f32(input_tmp + w + 4);
auto in5_tmp = vld1q_f32(input_tmp + l + l + 4); auto in5_tmp = vld1q_f32(input_tmp + w + w + 4);
tmp0 = vextq_f32(in0_tmp, in1_tmp, 1); tmp0 = vextq_f32(in0_tmp, in1_tmp, 1);
tmp1 = vextq_f32(in0_tmp, in1_tmp, 2); tmp1 = vextq_f32(in0_tmp, in1_tmp, 2);
...@@ -477,9 +475,9 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter, ...@@ -477,9 +475,9 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter,
in4_tmp = in5_tmp; in4_tmp = in5_tmp;
} }
float32x4_t pad0 = vdupq_n_f32(input_data[i * l + l - 1]); float32x4_t pad0 = vdupq_n_f32(input_data[i * w + w - 1]);
float32x4_t pad1 = vdupq_n_f32(input_data[i * l + l - 1 + l]); float32x4_t pad1 = vdupq_n_f32(input_data[i * w + w - 1 + w]);
float32x4_t pad2 = vdupq_n_f32(input_data[i * l + l - 1 + l + l]); float32x4_t pad2 = vdupq_n_f32(input_data[i * w + w - 1 + w + w]);
tmp0 = vextq_f32(in0_tmp, pad0, 1); tmp0 = vextq_f32(in0_tmp, pad0, 1);
tmp1 = vextq_f32(in0_tmp, pad0, 2); tmp1 = vextq_f32(in0_tmp, pad0, 2);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册