Commit 532cff71 authored by: H hjchen2

Refactor depthwise conv3x3 and fix its bugs for armv8

Parent cb5e15b9
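A scalar reference (not part of this commit; raw pointer arguments are assumed instead of the framework::Tensor API) for what the new padding-agnostic DepthwiseConv3x3S1/DepthwiseConv3x3S2 kernels compute. The NEON code in this commit splits the same computation into scalar padded borders plus a vectorized valid interior, and applies bias/BN/ReLU fusion afterwards through the channel-wise helpers:

// Scalar reference for a padding-aware 3x3 depthwise convolution.
void DepthwiseConv3x3Ref(const float *input, const float *filter,
                         float *output, int channels, int input_h, int input_w,
                         int output_h, int output_w, int stride, int padding_h,
                         int padding_w) {
  for (int c = 0; c < channels; ++c) {
    const float *in = input + c * input_h * input_w;
    const float *ker = filter + c * 9;
    float *out = output + c * output_h * output_w;
    for (int oh = 0; oh < output_h; ++oh) {
      for (int ow = 0; ow < output_w; ++ow) {
        float acc = 0.f;
        for (int kh = 0; kh < 3; ++kh) {
          for (int kw = 0; kw < 3; ++kw) {
            int ih = oh * stride - padding_h + kh;
            int iw = ow * stride - padding_w + kw;
            // Taps that fall into the padding contribute nothing.
            if (ih >= 0 && ih < input_h && iw >= 0 && iw < input_w) {
              acc += ker[kh * 3 + kw] * in[ih * input_w + iw];
            }
          }
        }
        out[oh * output_w + ow] = acc;
      }
    }
  }
}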
......@@ -61,25 +61,15 @@ template <>
void ConvAddBNReluKernel<CPU, float>::Compute(
const FusionConvAddBNReluParam<CPU> &param) {
switch (param.ExecMode()) {
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1P1_FLOAT:
math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(),
param.Output(), param.NewScale(),
param.NewBias(), true);
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P1_FLOAT:
math::DepthwiseConvAddBNRelu3x3s2p1v2(param.Input(), param.Filter(),
param.Output(), param.NewScale(),
param.NewBias(), true);
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P0_FLOAT:
math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(),
nullptr, false, false);
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1_FLOAT:
math::DepthwiseConv3x3S1<float, float>(*param.Input(), *param.Filter(),
param.Paddings(), param.Output());
math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
param.NewBias(), param.Output());
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3_FLOAT:
math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(),
param.Filter(), nullptr, param.Output(), false);
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2_FLOAT:
math::DepthwiseConv3x3S2<float, float>(*param.Input(), *param.Filter(),
param.Paddings(), param.Output());
math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
param.NewBias(), param.Output());
break;
......
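The BN+ReLU epilogue above relies on math::ScaleAddChannelWise<RELU>. A minimal scalar sketch of its assumed per-channel semantics, inferred only from the call sites in this commit (y = relu(scale[c] * x + bias[c]), with NewScale/NewBias being the folded batch-norm parameters):

inline void ScaleAddChannelWiseRef(const float *x, const float *scale,
                                   const float *bias, float *y, int channels,
                                   int spatial_size, bool relu) {
  // Per-channel affine transform (folded batch-norm) followed by optional ReLU.
  for (int c = 0; c < channels; ++c) {
    for (int i = 0; i < spatial_size; ++i) {
      float v = scale[c] * x[c * spatial_size + i] + bias[c];
      y[c * spatial_size + i] = relu ? (v > 0.f ? v : 0.f) : v;
    }
  }
}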
......@@ -14,6 +14,7 @@ limitations under the License. */
#ifdef FUSION_CONVADD_OP
#include "operators/kernel/conv_add_kernel.h"
#include "operators/kernel/arm/convolution/conv_common.h"
#include "operators/kernel/central-arm-func/conv_add_arm_func.h"
namespace paddle_mobile {
......@@ -21,12 +22,44 @@ namespace operators {
template <>
bool ConvAddKernel<CPU, float>::Init(FusionConvAddParam<CPU> *param) {
InitBaseConvKernel(param);
return true;
}
template <>
void ConvAddKernel<CPU, float>::Compute(const FusionConvAddParam<CPU> &param) {
ConvAddCompute<float>(param);
switch (param.ExecMode()) {
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1_FLOAT:
math::DepthwiseConv3x3S1<float, float>(*param.Input(), *param.Filter(),
param.Paddings(), param.Output());
math::AddChannelWise<IDENTITY>(param.Output(), param.Bias(),
param.Output());
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2_FLOAT:
math::DepthwiseConv3x3S2<float, float>(*param.Input(), *param.Filter(),
param.Paddings(), param.Output());
math::AddChannelWise<IDENTITY>(param.Output(), param.Bias(),
param.Output());
break;
#ifndef __aarch64__
case ConvParam<CPU>::EXEC_DEPTHWISE5x5_FLOAT:
DepthwiseConv5x5<float, float>(param);
math::AddChannelWise<IDENTITY>(param.Output(), param.Bias(),
param.Output());
break;
case ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT:
WinogradConv3x3<8, 3>(param);
math::AddChannelWise<IDENTITY>(param.Output(), param.Bias(),
param.Output());
break;
#endif // __aarch64__
case ConvParam<CPU>::EXEC_GEMM_FLOAT:
ConvAddBasic(param);
break;
default:
PADDLE_MOBILE_THROW_EXCEPTION("Invalid convolution execute mode %d",
param.ExecMode());
}
}
template class ConvAddKernel<CPU, float>;
......
......@@ -31,21 +31,14 @@ template <>
void ConvAddReluKernel<CPU, float>::Compute(
const FusionConvAddReluParam<CPU> &param) {
switch (param.ExecMode()) {
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1P1_FLOAT:
math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(),
param.Bias(), true, true);
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P1_FLOAT:
math::DepthwiseConv3x3s2p1v2(param.Input(), param.Filter(),
param.Output(), param.Bias(), true, true);
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P0_FLOAT:
math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(),
param.Bias(), true, true);
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1_FLOAT:
math::DepthwiseConv3x3S1<float, float>(*param.Input(), *param.Filter(),
param.Paddings(), param.Output());
math::AddChannelWise<RELU>(param.Output(), param.Bias(), param.Output());
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3_FLOAT:
math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(),
param.Filter(), nullptr, param.Output(), false);
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2_FLOAT:
math::DepthwiseConv3x3S2<float, float>(*param.Input(), *param.Filter(),
param.Paddings(), param.Output());
math::AddChannelWise<RELU>(param.Output(), param.Bias(), param.Output());
break;
#ifndef __aarch64__
......
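Similarly, math::AddChannelWise<IDENTITY>/<RELU> replaces the old kernels' built-in bias/relu flags. A scalar sketch of its assumed semantics (per-channel bias add followed by an optional ReLU):

inline void AddChannelWiseRef(const float *x, const float *bias, float *y,
                              int channels, int spatial_size, bool relu) {
  for (int c = 0; c < channels; ++c) {
    for (int i = 0; i < spatial_size; ++i) {
      float v = x[c * spatial_size + i] + bias[c];
      y[c * spatial_size + i] = relu ? (v > 0.f ? v : 0.f) : v;
    }
  }
}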
......@@ -16,7 +16,8 @@ limitations under the License. */
#include "operators/kernel/conv_bn_add_relu_kernel.h"
#include <cmath>
#include "operators/kernel/central-arm-func/conv_bn_add_relu_arm_func.h"
#include "operators/kernel/arm/convolution/conv_common.h"
#include "operators/kernel/central-arm-func/conv_arm_func.h"
namespace paddle_mobile {
namespace operators {
......@@ -51,13 +52,46 @@ bool ConvBNAddReluKernel<CPU, float>::Init(
}
param->SetNewScale(new_scale);
param->SetNewBias(new_bias);
InitBaseConvKernel(param);
return true;
}
template <>
void ConvBNAddReluKernel<CPU, float>::Compute(
const FusionConvBNAddReluParam<CPU> &param) {
ConvBNAddReluCompute<float>(param);
switch (param.ExecMode()) {
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1_FLOAT:
math::DepthwiseConv3x3S1<float, float>(*param.Input(), *param.Filter(),
param.Paddings(), param.Output());
math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
param.NewBias(), param.Output());
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2_FLOAT:
math::DepthwiseConv3x3S2<float, float>(*param.Input(), *param.Filter(),
param.Paddings(), param.Output());
math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
param.NewBias(), param.Output());
break;
#ifndef __aarch64__
case ConvParam<CPU>::EXEC_DEPTHWISE5x5_FLOAT:
DepthwiseConv5x5<float, float>(param);
math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
param.NewBias(), param.Output());
break;
case ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT:
WinogradConv3x3<8, 3>(param);
math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
param.NewBias(), param.Output());
break;
#endif // __aarch64__
case ConvParam<CPU>::EXEC_GEMM_FLOAT:
ConvBNReluBasic<FusionConvBNAddReluParam<CPU>>(param);
break;
default:
PADDLE_MOBILE_THROW_EXCEPTION("Invalid convolution execute mode %d",
param.ExecMode());
}
}
template class ConvBNAddReluKernel<CPU, float>;
......
......@@ -60,25 +60,15 @@ template <>
void ConvBNReluKernel<CPU, float>::Compute(
const FusionConvBNReluParam<CPU> &param) {
switch (param.ExecMode()) {
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1P1_FLOAT:
math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(),
param.Output(), param.NewScale(),
param.NewBias(), true);
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P1_FLOAT:
math::DepthwiseConvAddBNRelu3x3s2p1v2(param.Input(), param.Filter(),
param.Output(), param.NewScale(),
param.NewBias(), true);
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P0_FLOAT:
math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(),
nullptr, false, false);
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1_FLOAT:
math::DepthwiseConv3x3S1<float, float>(*param.Input(), *param.Filter(),
param.Paddings(), param.Output());
math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
param.NewBias(), param.Output());
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3_FLOAT:
math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(),
param.Filter(), nullptr, param.Output(), false);
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2_FLOAT:
math::DepthwiseConv3x3S2<float, float>(*param.Input(), *param.Filter(),
param.Paddings(), param.Output());
math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
param.NewBias(), param.Output());
break;
......
......@@ -44,29 +44,25 @@ void InitBaseConvKernel(ConvParam<CPU> *param) {
#endif // __aarch64__
} else {
if (depth3x3 && param->Strides()[0] == param->Strides()[1] &&
param->Strides()[0] == 1 && param->Paddings()[0] == 1 &&
param->Paddings()[0] == param->Paddings()[1]) {
param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3S1P1_FLOAT;
param->Strides()[0] == 1) {
param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3S1_FLOAT;
} else if (depth3x3 && param->Strides()[0] == param->Strides()[1] &&
param->Strides()[0] == 2 && param->Paddings()[0] == 0 &&
param->Paddings()[0] == param->Paddings()[1]) {
param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P0_FLOAT;
} else if (depth3x3 && param->Strides()[0] == param->Strides()[1] &&
param->Strides()[0] == 2 && param->Paddings()[0] == 1 &&
param->Paddings()[0] == param->Paddings()[1]) {
param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P1_FLOAT;
} else if (depth3x3) {
param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3_FLOAT;
param->Strides()[0] == 2) {
param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3S2_FLOAT;
#ifndef __aarch64__
} else if (depth5x5 && param->Strides()[0] == param->Strides()[1] &&
param->Strides()[0] == 1) {
param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE5x5_FLOAT;
} else if (conv3x3 && param->Strides()[0] == param->Strides()[1] &&
} else if (conv3x3 && !depth3x3 &&
param->Strides()[0] == param->Strides()[1] &&
param->Dilations()[0] == param->Dilations()[1] &&
param->Strides()[0] == 1 && param->Dilations()[0] == 1 /* &&
param->Output()->dims()[1] >= 16 &&
param->Strides()[0] == 1 && param->Dilations()[0] == 1
#if 0
&& param->Output()->dims()[1] >= 16 &&
param->Input()->dims()[1] >= 16 &&
param->Input()->dims()[2] <= 140 */ /* referred from ncnn */) {
param->Input()->dims()[2] <= 140 */ /* referred from ncnn */
#endif
) {
param->ExecMode() = ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT;
// transform weight
param->transformed_filter_ = new framework::LoDTensor;
......
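With the padding-agnostic kernels, the depthwise 3x3 mode in InitBaseConvKernel now depends only on the stride; paddings are forwarded to the kernel itself. A standalone restatement of that selection rule (hypothetical helper, not part of the commit):

#include <vector>

enum class Dw3x3Mode { kStride1, kStride2, kGeneric };

// Stride alone picks the depthwise 3x3 kernel; any padding is handled inside it.
inline Dw3x3Mode SelectDw3x3Mode(const std::vector<int> &strides) {
  if (strides[0] == strides[1] && strides[0] == 1) return Dw3x3Mode::kStride1;
  if (strides[0] == strides[1] && strides[0] == 2) return Dw3x3Mode::kStride2;
  return Dw3x3Mode::kGeneric;  // fall back to the GEMM path
}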
......@@ -18,6 +18,8 @@ limitations under the License. */
#include "operators/kernel/arm/convolution/conv_common.h"
#include "operators/kernel/central-arm-func/conv_arm_func.h"
#include <iostream>
namespace paddle_mobile {
namespace operators {
......@@ -41,21 +43,13 @@ void ConvKernel<CPU, float>::Compute(const ConvParam<CPU> &param) {
DepthwiseConv5x5<int8_t, int32_t>(param);
break;
#endif // __aarch64__
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1P1_FLOAT:
math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(),
nullptr, false, false);
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P1_FLOAT:
math::DepthwiseConv3x3s2p1v2(param.Input(), param.Filter(),
param.Output(), nullptr, false, false);
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P0_FLOAT:
math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(),
nullptr, false, false);
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1_FLOAT:
math::DepthwiseConv3x3S1<float, float>(*param.Input(), *param.Filter(),
param.Paddings(), param.Output());
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3_FLOAT:
math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(),
param.Filter(), nullptr, param.Output(), false);
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2_FLOAT:
math::DepthwiseConv3x3S2<float, float>(*param.Input(), *param.Filter(),
param.Paddings(), param.Output());
break;
#ifndef __aarch64__
case ConvParam<CPU>::EXEC_DEPTHWISE5x5_FLOAT:
......
......@@ -60,25 +60,15 @@ template <>
void DWConvBNReluKernel<CPU, float>::Compute(
const FusionDWConvBNReluParam<CPU> &param) {
switch (param.ExecMode()) {
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1P1_FLOAT:
math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(),
param.Output(), param.NewScale(),
param.NewBias(), true);
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P1_FLOAT:
math::DepthwiseConvAddBNRelu3x3s2p1v2(param.Input(), param.Filter(),
param.Output(), param.NewScale(),
param.NewBias(), true);
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P0_FLOAT:
math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(),
nullptr, false, false);
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1_FLOAT:
math::DepthwiseConv3x3S1<float, float>(*param.Input(), *param.Filter(),
param.Paddings(), param.Output());
math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
param.NewBias(), param.Output());
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3_FLOAT:
math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(),
param.Filter(), nullptr, param.Output(), false);
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2_FLOAT:
math::DepthwiseConv3x3S2<float, float>(*param.Input(), *param.Filter(),
param.Paddings(), param.Output());
math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
param.NewBias(), param.Output());
break;
......
......@@ -115,35 +115,6 @@ void ConvAddBasic(const FusionConvAddParam<CPU> &param) {
}
}
template <typename P>
void ConvAddCompute(const FusionConvAddParam<CPU> &param) {
param.Output()->mutable_data<float>();
if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) {
math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(),
param.Bias(), true, false);
} else if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) {
// math::DepthwiseConv3x3(param.Input(), param.Strides(),
// param.Paddings(),
// param.Filter(), param.Bias(),
// param.Output(), false);
if (param.Paddings()[0] == 0) {
math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(),
param.Bias(), true, false);
} else {
math::DepthwiseConv3x3s2p1v2(param.Input(), param.Filter(),
param.Output(), param.Bias(), true, false);
}
} else {
ConvAddBasic(param);
}
}
} // namespace operators
} // namespace paddle_mobile
......
......@@ -115,31 +115,6 @@ void ConvBNAddReluBasic(const FusionConvBNAddReluParam<CPU> &param) {
}
}
}
template <typename P>
void ConvBNAddReluCompute(const FusionConvBNAddReluParam<CPU> &param) {
Tensor Bias;
Bias.mutable_data<float>({param.Groups()});
if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) {
math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(),
param.Output(), param.NewScale(),
param.NewBias(), true);
} else if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) {
// math::DepthwiseConvAddBNRelu3x3s2p1(param.Input(), param.Filter(),
// param.Output(), param.NewScale(),
// param.NewBias(), 1);
math::DepthwiseConvAddBNRelu3x3s2p1v2(param.Input(), param.Filter(),
param.Output(), param.NewScale(),
param.NewBias(), true);
} else {
ConvBNAddReluBasic(param);
}
}
} // namespace operators
} // namespace paddle_mobile
......
......@@ -99,7 +99,6 @@ void ConvTransposeCompute(const ConvTransposeParam<CPU> &param) {
std::vector<int>{paddings[0], paddings[1], paddings[0],
paddings[1]},
&out_slice);
} else if (data_dim == 3U) {
col2vol(col, dilations, strides, paddings, &out_slice);
}
......
......@@ -12,2066 +12,1042 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
#include "operators/math/depthwise_conv3x3.h"
#include <vector>
#if __ARM_NEON
#include <arm_neon.h>
#endif
namespace paddle_mobile {
namespace operators {
namespace math {
void DepthwiseConv3x3(const framework::Tensor *input,
const std::vector<int> &strides,
const std::vector<int> &paddings,
const framework::Tensor *filter, framework::Tensor *bias,
framework::Tensor *output, bool if_bias) {
const int batch_size = input->dims()[0];
const int input_height = input->dims()[2];
const int input_width = input->dims()[3];
const int output_channels = output->dims()[1];
const int output_height = output->dims()[2];
const int output_width = output->dims()[3];
const int _kernel_size = 3;
const int stride_height = strides[0];
const int stride_width = strides[1];
const int padding_height = paddings[0];
const int padding_width = paddings[1];
const float zero = 0;
const int input_channel_stride = input_height * input_width;
const int output_channel_stride = output_height * output_width;
const int filter_channel_stride = 9;
const float *input_ptr = input->data<float>();
const float *filter_ptr = filter->data<float>();
if (if_bias) {
math::expand_bias(*bias, 1, output->dims());
output->ShareDataWith(*bias);
}
float *output_ptr = output->mutable_data<float>();
const float *pos1, *pos2, *pos3, *filter1, *filter2, *filter3, *output_ptr2;
int hstart, wstart, hend, wend;
float result;
for (int i = 0; i < batch_size; ++i) {
#pragma omp parallel for
for (int c = 0; c < output_channels; ++c) {
const float *input_data =
input_ptr + (i * output_channels + c) * input_channel_stride;
float *output_data =
output_ptr + (i * output_channels + c) * output_channel_stride;
filter1 = filter_ptr + c * filter_channel_stride;
filter2 = filter1 + 3;
filter3 = filter2 + 3;
for (int ph = 0; ph < output_height; ph++) {
for (int pw = 0; pw < output_width; pw++) {
hstart = ph * stride_height - padding_height;
wstart = pw * stride_width - padding_width;
hend = std::min(hstart + _kernel_size, input_height + padding_height);
wend = std::min(wstart + _kernel_size, input_width + padding_width);
hstart = std::max(hstart, 0);
wstart = std::max(wstart, 0);
hend = std::min(hend, input_height);
wend = std::min(wend, input_width);
pos1 = input_data + hstart * input_width + wstart;
pos2 = input_data + (hstart + 1) * input_width + wstart;
pos3 = input_data + (hstart + 2) * input_width + wstart;
output_ptr2 = output_data + ph * output_width + pw;
if (hend - hstart != 3 || wend - wstart != 3) {
result = 0;
float fake_input[9] = {0};
if (hstart == 0 && wstart == 0) {
// top-left corner
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 3; ++k) {
if (j >= 3 - hend && k >= 3 - wend) {
fake_input[3 * j + k] =
input_data[(j - (3 - hend)) * input_width + k -
(3 - wend)];
}
}
}
} else if (hstart == 0 && wend == input_width) {
// top-right corner
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 3; ++k) {
if (j >= 3 - hend && k <= input_width - wstart - 1) {
fake_input[3 * j + k] =
input_data[(j - (3 - hend)) * input_width + k + wstart];
}
}
}
} else if (hend == input_height && wstart == 0) {
// bottom-left corner
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 3; ++k) {
if (j <= input_height - 1 - hstart && k >= 3 - wend) {
fake_input[3 * j + k] =
input_data[(j + hstart) * input_width + k - (3 - wend)];
}
}
}
} else if (hend == input_height && wend == input_width) {
// bottom-right corner
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 3; ++k) {
if (j <= input_height - hstart - 1 &&
k <= input_width - wstart - 1) {
fake_input[3 * j + k] =
input_data[(j + hstart) * input_width + k + wstart];
}
}
}
} else if (hstart == 0) {
// top edge
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 3; ++k) {
if (j >= 3 - hend) {
fake_input[3 * j + k] =
input_data[(j - (3 - hend)) * input_width + k + wstart];
}
}
}
} else if (hend == input_height) {
// bottom edge
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 3; ++k) {
if (j <= input_height - hstart - 1) {
fake_input[3 * j + k] =
input_data[(j + hstart) * input_width + k + wstart];
}
}
}
} else if (wstart == 0) {
// left edge
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 3; ++k) {
if (k >= 3 - wend) {
fake_input[3 * j + k] =
input_data[(j + hstart) * input_width +
(k - (3 - wend))];
}
}
}
} else if (wend == input_width) {
// right edge
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 3; ++k) {
if (k <= input_width - wstart - 1) {
fake_input[3 * j + k] =
input_data[(j + hstart) * input_width + k + wstart];
}
}
}
}
for (int l = 0; l < 9; ++l) {
result += fake_input[l] * filter1[l];
}
if (if_bias) {
output_data[ph * output_width + pw] += result;
} else {
output_data[ph * output_width + pw] = result;
}
} else {
#if __ARM_NEON
#if __aarch64__
const float32x4_t data1 = vld1q_f32(pos1);
const float32x4_t data2 = vld1q_f32(pos2);
const float32x4_t data3 = vld1q_f32(pos3);
const float32x4_t v_filter1 = vld1q_f32(filter1);
const float32x4_t v_filter2 = vld1q_f32(filter2);
const float32x4_t v_filter3 = vld1q_f32(filter3);
float32x4_t mula = vmulq_f32(data1, v_filter1);
mula = vmlaq_f32(mula, data2, v_filter2);
mula = vmlaq_f32(mula, data3, v_filter3);
float32x2_t res = vpadd_f32(
vget_high_f32(vsetq_lane_f32(0, mula, 3)), vget_low_f32(mula));
res = vpadd_f32(res, res);
if (if_bias) {
output_data[ph * output_width + pw] += vget_lane_f32(res, 0);
} else {
output_data[ph * output_width + pw] = vget_lane_f32(res, 0);
}
#else
asm volatile(
"vld1.32 {q1}, [%[pos1]] \n\t"
"vld1.32 {q4}, [%[filter1]] \n\t"
"vmov.f32 q0, #0.0 \n\t"
"vld1.32 {q2}, [%[pos2]] \n\t"
"vld1.32 {q5}, [%[filter2]] \n\t"
"vmla.f32 q0, q1, q4 \n\t"
"vld1.32 {q3}, [%[pos3]] \n\t"
"vld1.32 {q6}, [%[filter3]] \n\t"
"vmla.f32 q0, q2, q5 \n\t"
"vmla.f32 q0, q3, q6 \n\t"
"vmov.f32 d1[1], %[zero] \n\t"
"vadd.f32 d4, d0, d1 \n\t"
"vadd.f32 s10, s8, s9 \n\t"
"vst1.32 {d5[0]},[%[output_ptr]] \n\t"
:
: [input_data] "r"(input_data), [pos1] "r"(pos1),
[pos2] "r"(pos2), [pos3] "r"(pos3), [filter1] "r"(filter1),
[filter2] "r"(filter2), [filter3] "r"(filter3),
[output_ptr] "r"(output_ptr2), [zero] "r"(zero)
: "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6");
#endif // __aarch64__
#else
#endif // __ARM_NEON
}
}
}
}
}
#ifndef __aarch64__
inline float32x4_t vpaddq_f32(float32x4_t r0, float32x4_t r1) {
float32x2_t sum0 = vpadd_f32(vget_low_f32(r0), vget_high_f32(r0));
float32x2_t sum1 = vpadd_f32(vget_low_f32(r1), vget_high_f32(r1));
return vcombine_f32(sum0, sum1);
}
#endif
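// Editor's sketch (not part of this commit): vpaddq_f32 performs one pairwise
// horizontal add per 128-bit half; armv8 provides it natively, and the armv7
// fallback above rebuilds it from two vpadd_f32.  A typical horizontal-reduction
// pattern it enables, assuming each accumulator holds the three products of one
// 3x3 window row in lanes 0..2 with lane 3 zero:
#include <arm_neon.h>
inline float32x2_t HorizontalSum3x2(float32x4_t acc0, float32x4_t acc1) {
  float32x2_t s0 = vpadd_f32(vget_low_f32(acc0), vget_high_f32(acc0));
  float32x2_t s1 = vpadd_f32(vget_low_f32(acc1), vget_high_f32(acc1));
  return vpadd_f32(s0, s1);  // {sum of acc0 lanes, sum of acc1 lanes}
}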
void DepthwiseConv3x3s1p1(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output, framework::Tensor *bias,
bool if_bias, bool if_relu) {
#if __ARM_NEON
const int batch_size = static_cast<int>(input->dims()[0]);
const int c = static_cast<int>(input->dims()[1]);
const int h = static_cast<int>(input->dims()[2]);
const int w = static_cast<int>(input->dims()[3]);
const int hxw = h * w;
// const int l = h;
// leftTop, rightTop, leftBottom, rightBottom
const int lt = 0;
const int rt = w - 1;
const int lb = (h - 1) * w;
const int rb = h * w - 1;
const float *bias_data;
if (if_bias) {
bias_data = bias->data<float>();
}
float32x4_t zero = vdupq_n_f32(0.0);
for (int b = 0; b < batch_size; ++b) {
#pragma omp parallel for
for (int j = 0; j < c; ++j) {
const float *filter_data_tmp = filter->data<float>() + j * 9;
const float *input_data = input->data<float>() + j * hxw;
float *output_data = output->mutable_data<float>() + j * hxw;
float32x4_t vbias;
if (if_bias) {
vbias = vdupq_n_f32(bias_data[j]);
}
int w_mid = w - 2; // l=1->l_mid=-1,l=2->l_mid=0
float w00 = filter_data_tmp[0];
float w01 = filter_data_tmp[1];
float w02 = filter_data_tmp[2];
float w10 = filter_data_tmp[3];
float w11 = filter_data_tmp[4];
float w12 = filter_data_tmp[5];
float w20 = filter_data_tmp[6];
float w21 = filter_data_tmp[7];
float w22 = filter_data_tmp[8];
output_data[lt] = w11 * input_data[0] + w12 * input_data[1] +
w21 * input_data[w] + w22 * input_data[w + 1];
output_data[rt] = w10 * input_data[w - 2] + w11 * input_data[w - 1] +
w20 * input_data[2 * w - 2] +
w21 * input_data[2 * w - 1];
output_data[lb] =
w01 * input_data[(h - 2) * w] + w02 * input_data[(h - 2) * w + 1] +
w11 * input_data[(h - 1) * w] + w12 * input_data[(h - 1) * w + 1];
output_data[rb] =
w00 * input_data[h * w - w - 2] + w01 * input_data[h * w - w - 1] +
w10 * input_data[h * w - 2] + w11 * input_data[h * w - 1];
if (if_bias) {
output_data[lt] += bias_data[j];
output_data[rt] += bias_data[j];
output_data[lb] += bias_data[j];
output_data[rb] += bias_data[j];
}
if (if_relu) {
output_data[lt] = output_data[lt] < 0 ? 0 : output_data[lt];
output_data[rt] = output_data[rt] < 0 ? 0 : output_data[rt];
output_data[lb] = output_data[lb] < 0 ? 0 : output_data[lb];
output_data[rb] = output_data[rb] < 0 ? 0 : output_data[rb];
}
for (int i = 1; i < h - 1; ++i) {
int left = i * w;
int right = i * w + w - 1;
output_data[left] =
w01 * input_data[i * w - w] + w02 * input_data[i * w - w + 1] +
w11 * input_data[i * w] + w12 * input_data[i * w + 1] +
w21 * input_data[i * w + w] + w22 * input_data[i * w + w + 1];
output_data[right] = w00 * input_data[i * w + w - 1 - w - 1] +
w01 * input_data[i * w + w - 1 - w] +
w10 * input_data[i * w + w - 1 - 1] +
w11 * input_data[i * w + w - 1] +
w20 * input_data[i * w + w - 1 + w - 1] +
w21 * input_data[i * w + w - 1 + w];
if (if_bias) {
output_data[left] += bias_data[j];
output_data[right] += bias_data[j];
}
if (if_relu) {
output_data[left] = output_data[left] < 0 ? 0 : output_data[left];
output_data[right] = output_data[right] < 0 ? 0 : output_data[right];
}
}
// top 1 row and bottom 1 row
const float *input_tmp = input_data;
float32x4_t in0, in1, in2, in3, in4, in5, in6, in7, tmp0, tmp1, tmp2,
tmp3, tmp4, tmp5, out0;
in0 = vld1q_f32(input_tmp);
in2 = vld1q_f32(input_tmp + w);
const float *input_tmp_end = input_tmp + (h - 2) * w;
in4 = vld1q_f32(input_tmp_end);
in6 = vld1q_f32(input_tmp_end + w);
int c_mid = w_mid;
auto output_ptr = output_data + 1;
for (; c_mid > 3; c_mid -= 4) {
in1 = vld1q_f32(input_tmp + 4);
in3 = vld1q_f32(input_tmp + w + 4);
tmp0 = vextq_f32(in0, in1, 1);
tmp1 = vextq_f32(in0, in1, 2);
tmp2 = vextq_f32(in2, in3, 1);
tmp3 = vextq_f32(in2, in3, 2);
out0 = vmulq_n_f32(in0, w10);
out0 = vmlaq_n_f32(out0, tmp0, w11);
out0 = vmlaq_n_f32(out0, tmp1, w12);
out0 = vmlaq_n_f32(out0, in2, w20);
out0 = vmlaq_n_f32(out0, tmp2, w21);
out0 = vmlaq_n_f32(out0, tmp3, w22);
out0 = vaddq_f32(out0, vbias);
if (if_relu) {
out0 = vmaxq_f32(out0, zero);
}
vst1q_f32(output_ptr, out0);
in5 = vld1q_f32(input_tmp_end + 4);
in7 = vld1q_f32(input_tmp_end + w + 4);
tmp0 = vextq_f32(in4, in5, 1);
tmp1 = vextq_f32(in4, in5, 2);
tmp2 = vextq_f32(in6, in7, 1);
tmp3 = vextq_f32(in6, in7, 2);
out0 = vmulq_n_f32(in4, w00);
out0 = vmlaq_n_f32(out0, tmp0, w01);
out0 = vmlaq_n_f32(out0, tmp1, w02);
out0 = vmlaq_n_f32(out0, in6, w10);
out0 = vmlaq_n_f32(out0, tmp2, w11);
out0 = vmlaq_n_f32(out0, tmp3, w12);
out0 = vaddq_f32(out0, vbias);
if (if_relu) {
out0 = vmaxq_f32(out0, zero);
}
vst1q_f32(output_ptr + (h - 1) * w, out0);
// could be optimized to process 8 elements per iteration.
input_tmp += 4;
input_tmp_end += 4;
output_ptr += 4;
in0 = in1;
in2 = in3;
in4 = in5;
in6 = in7;
}
// top right pad
float32x4_t pad0 = vdupq_n_f32(input_data[w - 1]);
float32x4_t pad1 = vdupq_n_f32(input_data[2 * w - 1]);
tmp0 = vextq_f32(in0, pad0, 1);
tmp1 = vextq_f32(in0, pad0, 2);
tmp2 = vextq_f32(in2, pad1, 1);
tmp3 = vextq_f32(in2, pad1, 2);
out0 = vmulq_n_f32(in0, w10);
out0 = vmlaq_n_f32(out0, tmp0, w11);
out0 = vmlaq_n_f32(out0, tmp1, w12);
out0 = vmlaq_n_f32(out0, in2, w20);
out0 = vmlaq_n_f32(out0, tmp2, w21);
out0 = vmlaq_n_f32(out0, tmp3, w22);
out0 = vaddq_f32(out0, vbias);
if (if_relu) {
out0 = vmaxq_f32(out0, zero);
}
for (int i = 0; i < c_mid; ++i) {
if (i == 0) {
vst1q_lane_f32(output_ptr + i, out0, 0);
}
if (i == 1) {
vst1q_lane_f32(output_ptr + i, out0, 1);
}
if (i == 2) {
vst1q_lane_f32(output_ptr + i, out0, 2);
}
}
// bottom right pad
float32x4_t pad2 = vdupq_n_f32(input_data[h * w - 1 - w]);
float32x4_t pad3 = vdupq_n_f32(input_data[h * w - 1]);
tmp0 = vextq_f32(in4, pad2, 1);
tmp1 = vextq_f32(in4, pad2, 2);
tmp2 = vextq_f32(in6, pad3, 1);
tmp3 = vextq_f32(in6, pad3, 2);
out0 = vmulq_n_f32(in4, w00);
out0 = vmlaq_n_f32(out0, tmp0, w01);
out0 = vmlaq_n_f32(out0, tmp1, w02);
out0 = vmlaq_n_f32(out0, in6, w10);
out0 = vmlaq_n_f32(out0, tmp2, w11);
out0 = vmlaq_n_f32(out0, tmp3, w12);
out0 = vaddq_f32(out0, vbias);
if (if_relu) {
out0 = vmaxq_f32(out0, zero);
}
for (int i = 0; i < c_mid; ++i) {
if (i == 0) {
vst1q_lane_f32(output_ptr + (h - 1) * w + i, out0, 0);
}
if (i == 1) {
vst1q_lane_f32(output_ptr + (h - 1) * w + i, out0, 1);
}
if (i == 2) {
vst1q_lane_f32(output_ptr + (h - 1) * w + i, out0, 2);
}
}
// mid
for (int i = 0; i < h - 2; ++i) {
auto output_ptr = output_data + (i + 1) * w + 1;
input_tmp = input_data + i * w;
auto in0_tmp = vld1q_f32(input_tmp);
auto in2_tmp = vld1q_f32(input_tmp + w);
auto in4_tmp = vld1q_f32(input_tmp + w + w);
c_mid = w_mid;
for (; c_mid > 3; c_mid -= 4) {
auto in1_tmp = vld1q_f32(input_tmp + 4);
auto in3_tmp = vld1q_f32(input_tmp + w + 4);
auto in5_tmp = vld1q_f32(input_tmp + w + w + 4);
tmp0 = vextq_f32(in0_tmp, in1_tmp, 1);
tmp1 = vextq_f32(in0_tmp, in1_tmp, 2);
tmp2 = vextq_f32(in2_tmp, in3_tmp, 1);
tmp3 = vextq_f32(in2_tmp, in3_tmp, 2);
tmp4 = vextq_f32(in4_tmp, in5_tmp, 1);
tmp5 = vextq_f32(in4_tmp, in5_tmp, 2);
out0 = vmulq_n_f32(in0_tmp, w00);
out0 = vmlaq_n_f32(out0, tmp0, w01);
out0 = vmlaq_n_f32(out0, tmp1, w02);
out0 = vmlaq_n_f32(out0, in2_tmp, w10);
out0 = vmlaq_n_f32(out0, tmp2, w11);
out0 = vmlaq_n_f32(out0, tmp3, w12);
out0 = vmlaq_n_f32(out0, in4_tmp, w20);
out0 = vmlaq_n_f32(out0, tmp4, w21);
out0 = vmlaq_n_f32(out0, tmp5, w22);
out0 = vaddq_f32(out0, vbias);
if (if_relu) {
out0 = vmaxq_f32(out0, zero);
}
vst1q_f32(output_ptr, out0);
template <int Stride = 1>
inline void Depth3x3NormalRowLoadInput(const float *input, float32x4_t *y) {
y[0] = vld1q_f32(input);
y[2] = vld1q_f32(input + 4);
y[1] = vextq_f32(y[0], y[2], 1);
y[2] = vextq_f32(y[0], y[2], 2);
}
output_ptr += 4;
input_tmp += 4;
in0_tmp = in1_tmp;
in2_tmp = in3_tmp;
in4_tmp = in5_tmp;
}
template <>
inline void Depth3x3NormalRowLoadInput<2>(const float *input, float32x4_t *y) {
float32x4x2_t x = vld2q_f32(input);
y[0] = x.val[0];
y[1] = x.val[1];
y[2] = vextq_f32(y[0], y[0], 1);
y[2] = vsetq_lane_f32(input[8], y[2], 3);
}
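// Editor's note (not part of this commit): for four consecutive outputs, y[k]
// must hold the input columns hit by filter tap k of the row:
//   stride 1: y[0]={in[0..3]}, y[1]={in[1..4]}, y[2]={in[2..5]}
//   stride 2: y[0]={in[0,2,4,6]}, y[1]={in[1,3,5,7]}, y[2]={in[2,4,6,8]}
// which is what the vld1q/vextq (stride 1) and vld2q deinterleave (stride 2)
// loaders above produce.  Scalar equivalent:
inline void Depth3x3NormalRowLoadInputRef(const float *in, int stride,
                                          float y[3][4]) {
  for (int k = 0; k < 3; ++k) {
    for (int i = 0; i < 4; ++i) {
      y[k][i] = in[i * stride + k];
    }
  }
}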
float32x4_t pad0 = vdupq_n_f32(input_data[i * w + w - 1]);
float32x4_t pad1 = vdupq_n_f32(input_data[i * w + w - 1 + w]);
float32x4_t pad2 = vdupq_n_f32(input_data[i * w + w - 1 + w + w]);
tmp0 = vextq_f32(in0_tmp, pad0, 1);
tmp1 = vextq_f32(in0_tmp, pad0, 2);
tmp2 = vextq_f32(in2_tmp, pad1, 1);
tmp3 = vextq_f32(in2_tmp, pad1, 2);
tmp4 = vextq_f32(in4_tmp, pad2, 1);
tmp5 = vextq_f32(in4_tmp, pad2, 2);
out0 = vmulq_n_f32(in0_tmp, w00);
out0 = vmlaq_n_f32(out0, tmp0, w01);
out0 = vmlaq_n_f32(out0, tmp1, w02);
out0 = vmlaq_n_f32(out0, in2_tmp, w10);
out0 = vmlaq_n_f32(out0, tmp2, w11);
out0 = vmlaq_n_f32(out0, tmp3, w12);
out0 = vmlaq_n_f32(out0, in4_tmp, w20);
out0 = vmlaq_n_f32(out0, tmp4, w21);
out0 = vmlaq_n_f32(out0, tmp5, w22);
out0 = vaddq_f32(out0, vbias);
if (if_relu) {
out0 = vmaxq_f32(out0, zero);
}
#define DEPTHWISE_CONV3X3_NORMAL_BORDER(start, end) \
for (int w = start; w < end; ++w) { \
const int w_in_start = -padding_w + w * Stride_w; \
const int w_in_end = w_in_start + 3; \
const int w_start = w_in_start > 0 ? w_in_start : 0; \
const int w_end = w_in_end < input_w ? w_in_end : input_w; \
float value = 0; \
for (int h_in = h_start; h_in < h_end; ++h_in) { \
for (int w_in = w_start; w_in < w_end; ++w_in) { \
value += filter[(h_in - h_in_start) * 3 + (w_in - w_in_start)] * \
input[h_in * input_w + w_in]; \
} \
} \
output_ptr[w] = value; \
}
for (int i = 0; i < c_mid; ++i) {
if (i == 0) {
vst1q_lane_f32(output_ptr + i, out0, 0);
}
if (i == 1) {
vst1q_lane_f32(output_ptr + i, out0, 1);
}
if (i == 2) {
vst1q_lane_f32(output_ptr + i, out0, 2);
}
}
}
template <int Stride_h, int Stride_w>
inline void DepthwiseConv3x3NormalRow(const float *input, const float *filter,
const int h_output, const int input_h,
const int input_w, const int padding_h,
const int padding_w, const int output_w,
float *output, float32x4_t *ker) {
const int h_in_start = -padding_h + h_output * Stride_h;
const int h_in_end = h_in_start + 3;
const int h_start = h_in_start > 0 ? h_in_start : 0;
const int h_end = h_in_end < input_h ? h_in_end : input_h;
const int valid_w_start = (padding_w + Stride_w - 1) / Stride_w;
const int valid_w_end = (input_w + padding_w - 3) / Stride_w + 1;
// const int valid_w_end = output_w - valid_w_start;
float *output_ptr = output + h_output * output_w;
// border left
DEPTHWISE_CONV3X3_NORMAL_BORDER(0, valid_w_start)
// middle
int output_tiles = (valid_w_end - valid_w_start) >> 2;
float32x4_t _sum, _x[3];
// valid w
for (int w = 0; w < output_tiles * 4; w += 4) {
_sum = vdupq_n_f32(0.f);
int output_offset = valid_w_start + w;
int input_w_offset = output_offset * Stride_w - padding_w;
for (int h_in = h_start; h_in < h_end; ++h_in) {
int index = h_in - h_in_start;
Depth3x3NormalRowLoadInput<Stride_w>(
input + h_in * input_w + input_w_offset, _x);
_sum = vmlaq_lane_f32(_sum, _x[0], vget_low_f32(ker[index]), 0);
_sum = vmlaq_lane_f32(_sum, _x[1], vget_low_f32(ker[index]), 1);
_sum = vmlaq_lane_f32(_sum, _x[2], vget_high_f32(ker[index]), 0);
}
vst1q_f32(output_ptr + output_offset, _sum);
}
#endif
// remain valid w
int remain = (valid_w_end - valid_w_start) & 0x3;
if (remain > 0) {
_sum = vdupq_n_f32(0.f);
int remain_start = valid_w_start + (output_tiles << 2);
int input_w_offset = remain_start * Stride_w - padding_w;
float *output_ptr0 = output_ptr + remain_start;
for (int h_in = h_start; h_in < h_end; ++h_in) {
int index = h_in - h_in_start;
Depth3x3NormalRowLoadInput<Stride_w>(
input + h_in * input_w + input_w_offset, _x);
_sum = vmlaq_lane_f32(_sum, _x[0], vget_low_f32(ker[index]), 0);
_sum = vmlaq_lane_f32(_sum, _x[1], vget_low_f32(ker[index]), 1);
_sum = vmlaq_lane_f32(_sum, _x[2], vget_high_f32(ker[index]), 0);
}
switch (remain) {
case 3:
vst1q_lane_f32(output_ptr0 + 2, _sum, 2);
case 2:
vst1_f32(output_ptr0, vget_low_f32(_sum));
break;
case 1:
vst1_lane_f32(output_ptr0, vget_low_f32(_sum), 0);
break;
}
}
// border right
DEPTHWISE_CONV3X3_NORMAL_BORDER(valid_w_end, output_w)
}
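// Editor's sketch (not part of this commit): a scalar restatement of
// DepthwiseConv3x3NormalRow above.  It handles one output row whose 3x3 window
// may be clipped by top/bottom padding; clipped taps contribute nothing, the
// left/right borders use the scalar macro and the interior is vectorized.
template <int StrideH, int StrideW>
void DepthwiseConv3x3NormalRowRef(const float *input, const float *filter,
                                  int h_output, int input_h, int input_w,
                                  int padding_h, int padding_w, int output_w,
                                  float *output) {
  const int h_in_start = -padding_h + h_output * StrideH;
  float *output_ptr = output + h_output * output_w;
  for (int w = 0; w < output_w; ++w) {
    const int w_in_start = -padding_w + w * StrideW;
    float value = 0.f;
    for (int kh = 0; kh < 3; ++kh) {
      for (int kw = 0; kw < 3; ++kw) {
        const int h_in = h_in_start + kh;
        const int w_in = w_in_start + kw;
        if (h_in >= 0 && h_in < input_h && w_in >= 0 && w_in < input_w) {
          value += filter[kh * 3 + kw] * input[h_in * input_w + w_in];
        }
      }
    }
    output_ptr[w] = value;
  }
}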
void DepthwiseConvAddBNRelu3x3s1p1(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output,
const framework::Tensor *new_scale,
const framework::Tensor *new_bias,
bool if_relu) {
#if __ARM_NEON
const float *input_data = input->data<float>();
const float *filter_data = filter->data<float>();
float *output_data = output->mutable_data<float>();
const float *newscale_data = new_scale->data<float>();
const float *newbias_data = new_bias->data<float>();
const int batch_size = static_cast<int>(input->dims()[0]);
const int input_channel = static_cast<int>(input->dims()[1]);
const int input_height = static_cast<int>(input->dims()[2]);
const int input_width = static_cast<int>(input->dims()[3]);
const int output_height = static_cast<int>(output->dims()[2]);
const int output_width = static_cast<int>(output->dims()[3]);
const int hxw = input_height * input_width;
// const int l = input_height;
const int h = input_height;
const int w = input_width;
float32x4_t vzero = vdupq_n_f32(0);
for (int b = 0; b < batch_size; b++) {
#pragma omp parallel for
for (int c = 0; c < input_channel; c++) {
const float *filter_data = filter->data<float>() + c * 9;
const float *input_data = input->data<float>() + c * hxw;
float *output_data = output->data<float>() + c * hxw;
float32x4_t vnewbias = vdupq_n_f32(newbias_data[c]);
float32x4_t vnewscale = vdupq_n_f32(newscale_data[c]);
float w00 = filter_data[0];
float w01 = filter_data[1];
float w02 = filter_data[2];
float w10 = filter_data[3];
float w11 = filter_data[4];
float w12 = filter_data[5];
float w20 = filter_data[6];
float w21 = filter_data[7];
float w22 = filter_data[8];
for (int i = 1; i < output_height - 1; i++) {
float *output_ptr;
float32x4_t in0, in1, in2, in3, in4, in5, tmp0, tmp1, tmp2, tmp3, tmp4,
tmp5, out0;
for (int m = 1; m < output_width - 4; m += 4) {
output_ptr = output_data + i * output_width + m;
in0 = vld1q_f32(input_data + (i - 1) * input_width + m - 1);
in1 = vld1q_f32(input_data + (i - 1) * input_width + m + 3);
in2 = vld1q_f32(input_data + i * input_width + m - 1);
in3 = vld1q_f32(input_data + i * input_width + m + 3);
in4 = vld1q_f32(input_data + (i + 1) * input_width + m - 1);
in5 = vld1q_f32(input_data + (i + 1) * input_width + m + 3);
tmp0 = vextq_f32(in0, in1, 1);
tmp1 = vextq_f32(in0, in1, 2);
tmp2 = vextq_f32(in2, in3, 1);
tmp3 = vextq_f32(in2, in3, 2);
tmp4 = vextq_f32(in4, in5, 1);
tmp5 = vextq_f32(in4, in5, 2);
out0 = vmulq_n_f32(in0, w00);
out0 = vmlaq_n_f32(out0, tmp0, w01);
out0 = vmlaq_n_f32(out0, tmp1, w02);
out0 = vmlaq_n_f32(out0, in2, w10);
out0 = vmlaq_n_f32(out0, tmp2, w11);
out0 = vmlaq_n_f32(out0, tmp3, w12);
out0 = vmlaq_n_f32(out0, in4, w20);
out0 = vmlaq_n_f32(out0, tmp4, w21);
out0 = vmlaq_n_f32(out0, tmp5, w22);
out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) {
out0 = vmaxq_f32(out0, vzero);
}
vst1q_f32(output_ptr, out0);
}
int m;
for (m = 1; (m + 3) < output_width - 1; m = m + 4) {
}
template <>
void DepthwiseConv3x3S1<float, float>(const framework::Tensor &input,
const framework::Tensor &filter,
const std::vector<int> &paddings,
framework::Tensor *output) {
const float *input_data = input.data<float>();
const float *filter_data = filter.data<float>();
float *out_data = output->mutable_data<float>();
int input_h = input.dims()[2];
int input_w = input.dims()[3];
int output_h = output->dims()[2];
int output_w = output->dims()[3];
int padding_h = paddings[0];
int padding_w = paddings[1];
int image_size = input_h * input_w;
int out_image_size = output_h * output_w;
int valid_h_start = padding_h;
int valid_h_end = output_h - valid_h_start;
int valid_h = valid_h_end - valid_h_start;
int valid_w_start = padding_w;
int valid_w_end = output_w - valid_w_start;
int valid_w = valid_w_end - valid_w_start;
#pragma omp parallel for
for (int g = 0; g < input.dims()[1]; ++g) {
const float *input_ptr = input_data + g * image_size;
const float *filter_ptr = filter_data + g * 9;
float *output_ptr = out_data + g * out_image_size;
const float *filter_ptr0 = filter_ptr;
const float *filter_ptr1 = filter_ptr0 + 3;
const float *filter_ptr2 = filter_ptr1 + 3;
float32x4_t _ker[3];
_ker[0] = vld1q_f32(filter_ptr0);
_ker[1] = vld1q_f32(filter_ptr1);
_ker[2] = vld1q_f32(filter_ptr2);
// pad top
for (int h = 0; h < valid_h_start; ++h) {
DepthwiseConv3x3NormalRow<1, 1>(input_ptr, filter_ptr, h, input_h,
input_w, padding_h, padding_w, output_w,
output_ptr, _ker);
}
for (int j = m; j < output_width - 1; j++) {
output_data[i * output_width + j] =
input_data[(i - 1) * input_width + j - 1] * w00 +
input_data[(i - 1) * input_width + j] * w01 +
input_data[(i - 1) * input_width + j + 1] * w02 +
input_data[(i)*input_width + j - 1] * w10 +
input_data[(i)*input_width + j] * w11 +
input_data[(i)*input_width + j + 1] * w12 +
input_data[(i + 1) * input_width + j - 1] * w20 +
input_data[(i + 1) * input_width + j] * w21 +
input_data[(i + 1) * input_width + j + 1] * w22;
output_data[i * output_width + j] =
newscale_data[c] * output_data[i * output_width + j] +
newbias_data[c];
if (if_relu) {
output_data[i * output_width + j] =
output_data[i * output_width + j] < 0
? 0
: output_data[i * output_width + j];
// output 2x6
int output_w_tiles = valid_w / 6;
int output_w_remain = valid_w - output_w_tiles * 6;
for (int h = valid_h_start; h < valid_h_end - 1; h += 2) {
const float *input_ptr0 = input_ptr + (h - padding_h) * input_w;
const float *input_ptr1 = input_ptr0 + input_w;
const float *input_ptr2 = input_ptr1 + input_w;
const float *input_ptr3 = input_ptr2 + input_w;
float *output_ptr0 = output_ptr + h * output_w;
float *output_ptr1 = output_ptr0 + output_w;
// pad left
if (padding_w) {
float32x4_t row0 = vld1q_f32(input_ptr0);
float32x4_t row1 = vld1q_f32(input_ptr1);
float32x4_t row2 = vld1q_f32(input_ptr2);
float32x4_t row3 = vld1q_f32(input_ptr3);
float32x4_t zero = vdupq_n_f32(0.f);
row0 = vextq_f32(zero, row0, 3);
row1 = vextq_f32(zero, row1, 3);
row2 = vextq_f32(zero, row2, 3);
row3 = vextq_f32(zero, row3, 3);
float32x4_t acc0, acc1;
for (int w = valid_w_start - 1; w >= 0; --w) {
int padding = padding_w - w;
if (padding >= 3) {
output_ptr0[w] = 0.f;
output_ptr1[w] = 0.f;
} else {
acc0 = vmulq_f32(row0, _ker[0]);
acc0 = vmlaq_f32(acc0, row1, _ker[1]);
acc0 = vmlaq_f32(acc0, row2, _ker[2]);
acc0 = vextq_f32(acc0, acc0, 1);
acc1 = vmulq_f32(row1, _ker[0]);
acc1 = vmlaq_f32(acc1, row2, _ker[1]);
acc1 = vmlaq_f32(acc1, row3, _ker[2]);
acc1 = vextq_f32(acc1, acc1, 1);
float32x2_t sum = vpadd_f32(vget_low_f32(acc0), vget_low_f32(acc1));
vst1_lane_f32(output_ptr0 + w, sum, 0);
vst1_lane_f32(output_ptr1 + w, sum, 1);
row0 = vextq_f32(zero, row0, 3);
row1 = vextq_f32(zero, row1, 3);
row2 = vextq_f32(zero, row2, 3);
row3 = vextq_f32(zero, row3, 3);
}
}
output_ptr0 += valid_w_start;
output_ptr1 += valid_w_start;
}
output_data[0] = w11 * input_data[0] + w12 * input_data[1] +
w21 * input_data[w] + w22 * input_data[w + 1];
output_data[w - 1] = w10 * input_data[w - 2] + w11 * input_data[w - 1] +
w20 * input_data[2 * w - 2] +
w21 * input_data[2 * w - 1];
output_data[(h - 1) * w] =
w01 * input_data[(h - 2) * w] + w02 * input_data[(h - 2) * w + 1] +
w11 * input_data[(h - 1) * w] + w12 * input_data[(h - 1) * w + 1];
output_data[h * w - 1] =
w00 * input_data[h * w - w - 2] + w01 * input_data[h * w - w - 1] +
w10 * input_data[h * w - 2] + w11 * input_data[h * w - 1];
output_data[0] = output_data[0] * newscale_data[c] + newbias_data[c];
output_data[w - 1] =
output_data[w - 1] * newscale_data[c] + newbias_data[c];
output_data[(h - 1) * w] =
output_data[(h - 1) * w] * newscale_data[c] + newbias_data[c];
output_data[h * w - 1] =
output_data[h * w - 1] * newscale_data[c] + newbias_data[c];
if (if_relu) {
output_data[0] = output_data[0] < 0 ? 0 : output_data[0];
output_data[w - 1] = output_data[w - 1] < 0 ? 0 : output_data[w - 1];
output_data[(h - 1) * w] =
output_data[(h - 1) * w] < 0 ? 0 : output_data[(h - 1) * w];
output_data[h * w - 1] =
output_data[h * w - 1] < 0 ? 0 : output_data[h * w - 1];
}
for (int i = 1; i < h - 1; ++i) {
output_data[i * w] =
w01 * input_data[i * w - w] + w02 * input_data[i * w - w + 1] +
w11 * input_data[i * w] + w12 * input_data[i * w + 1] +
w21 * input_data[i * w + w] + w22 * input_data[i * w + w + 1];
output_data[i * w + w - 1] = w00 * input_data[i * w + w - 1 - w - 1] +
w01 * input_data[i * w + w - 1 - w] +
w10 * input_data[i * w + w - 1 - 1] +
w11 * input_data[i * w + w - 1] +
w20 * input_data[i * w + w - 1 + w - 1] +
w21 * input_data[i * w + w - 1 + w];
output_data[i * w] =
output_data[i * w] * newscale_data[c] + newbias_data[c];
output_data[i * w + w - 1] =
output_data[i * w + w - 1] * newscale_data[c] + newbias_data[c];
if (if_relu) {
output_data[i * w] = output_data[i * w] < 0 ? 0 : output_data[i * w];
output_data[i * w + w - 1] =
output_data[i * w + w - 1] < 0 ? 0 : output_data[i * w + w - 1];
}
// valid
float32x4_t _result0, _result1, _result2, _result3;
for (int loop = 0; loop < output_w_tiles; ++loop) {
float32x4_t _row00 = vld1q_f32(input_ptr0);
float32x4_t _row01 = vld1q_f32(input_ptr0 + 4);
float32x4_t _row10 = vld1q_f32(input_ptr1);
float32x4_t _row11 = vld1q_f32(input_ptr1 + 4);
float32x4_t _ext01 = vextq_f32(_row00, _row01, 1);
float32x4_t _ext02 = vextq_f32(_row00, _row01, 2);
float32x4_t _ext03 = vextq_f32(_row01, _row01, 1);
float32x4_t _ext04 = vextq_f32(_row01, _row01, 2);
_result0 = vmulq_lane_f32(_row00, vget_low_f32(_ker[0]), 0);
_result0 = vmlaq_lane_f32(_result0, _ext01, vget_low_f32(_ker[0]), 1);
_result0 = vmlaq_lane_f32(_result0, _ext02, vget_high_f32(_ker[0]), 0);
_result1 = vmulq_lane_f32(_row01, vget_low_f32(_ker[0]), 0);
_result1 = vmlaq_lane_f32(_result1, _ext03, vget_low_f32(_ker[0]), 1);
_result1 = vmlaq_lane_f32(_result1, _ext04, vget_high_f32(_ker[0]), 0);
_ext01 = vextq_f32(_row10, _row11, 1);
_ext02 = vextq_f32(_row10, _row11, 2);
_ext03 = vextq_f32(_row11, _row11, 1);
_ext04 = vextq_f32(_row11, _row11, 2);
_result0 = vmlaq_lane_f32(_result0, _row10, vget_low_f32(_ker[1]), 0);
_result0 = vmlaq_lane_f32(_result0, _ext01, vget_low_f32(_ker[1]), 1);
_result0 = vmlaq_lane_f32(_result0, _ext02, vget_high_f32(_ker[1]), 0);
_result1 = vmlaq_lane_f32(_result1, _row11, vget_low_f32(_ker[1]), 0);
_result1 = vmlaq_lane_f32(_result1, _ext03, vget_low_f32(_ker[1]), 1);
_result1 = vmlaq_lane_f32(_result1, _ext04, vget_high_f32(_ker[1]), 0);
_result2 = vmulq_lane_f32(_row10, vget_low_f32(_ker[0]), 0);
_result2 = vmlaq_lane_f32(_result2, _ext01, vget_low_f32(_ker[0]), 1);
_result2 = vmlaq_lane_f32(_result2, _ext02, vget_high_f32(_ker[0]), 0);
_result3 = vmulq_lane_f32(_row11, vget_low_f32(_ker[0]), 0);
_result3 = vmlaq_lane_f32(_result3, _ext03, vget_low_f32(_ker[0]), 1);
_result3 = vmlaq_lane_f32(_result3, _ext04, vget_high_f32(_ker[0]), 0);
_row00 = vld1q_f32(input_ptr2);
_row01 = vld1q_f32(input_ptr2 + 4);
_row10 = vld1q_f32(input_ptr3);
_row11 = vld1q_f32(input_ptr3 + 4);
_ext01 = vextq_f32(_row00, _row01, 1);
_ext02 = vextq_f32(_row00, _row01, 2);
_ext03 = vextq_f32(_row01, _row01, 1);
_ext04 = vextq_f32(_row01, _row01, 2);
_result0 = vmlaq_lane_f32(_result0, _row00, vget_low_f32(_ker[2]), 0);
_result0 = vmlaq_lane_f32(_result0, _ext01, vget_low_f32(_ker[2]), 1);
_result0 = vmlaq_lane_f32(_result0, _ext02, vget_high_f32(_ker[2]), 0);
_result1 = vmlaq_lane_f32(_result1, _row01, vget_low_f32(_ker[2]), 0);
_result1 = vmlaq_lane_f32(_result1, _ext03, vget_low_f32(_ker[2]), 1);
_result1 = vmlaq_lane_f32(_result1, _ext04, vget_high_f32(_ker[2]), 0);
_result2 = vmlaq_lane_f32(_result2, _row00, vget_low_f32(_ker[1]), 0);
_result2 = vmlaq_lane_f32(_result2, _ext01, vget_low_f32(_ker[1]), 1);
_result2 = vmlaq_lane_f32(_result2, _ext02, vget_high_f32(_ker[1]), 0);
_result3 = vmlaq_lane_f32(_result3, _row01, vget_low_f32(_ker[1]), 0);
_result3 = vmlaq_lane_f32(_result3, _ext03, vget_low_f32(_ker[1]), 1);
_result3 = vmlaq_lane_f32(_result3, _ext04, vget_high_f32(_ker[1]), 0);
_ext01 = vextq_f32(_row10, _row11, 1);
_ext02 = vextq_f32(_row10, _row11, 2);
_ext03 = vextq_f32(_row11, _row11, 1);
_ext04 = vextq_f32(_row11, _row11, 2);
_result2 = vmlaq_lane_f32(_result2, _row10, vget_low_f32(_ker[2]), 0);
_result2 = vmlaq_lane_f32(_result2, _ext01, vget_low_f32(_ker[2]), 1);
_result2 = vmlaq_lane_f32(_result2, _ext02, vget_high_f32(_ker[2]), 0);
_result3 = vmlaq_lane_f32(_result3, _row11, vget_low_f32(_ker[2]), 0);
_result3 = vmlaq_lane_f32(_result3, _ext03, vget_low_f32(_ker[2]), 1);
_result3 = vmlaq_lane_f32(_result3, _ext04, vget_high_f32(_ker[2]), 0);
vst1q_f32(output_ptr0, _result0);
vst1_f32(output_ptr0 + 4, vget_low_f32(_result1));
vst1q_f32(output_ptr1, _result2);
vst1_f32(output_ptr1 + 4, vget_low_f32(_result3));
input_ptr0 += 6;
input_ptr1 += 6;
input_ptr2 += 6;
input_ptr3 += 6;
output_ptr0 += 6;
output_ptr1 += 6;
}
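// Editor's sketch (not part of this commit): each iteration of the loop above
// emits a 2-row x 6-column output tile of the stride-1 valid region from a
// 4-row x 8-column input window (two q-register loads per input row).
inline void Dw3x3S1Tile2x6Ref(const float *in0, const float *in1,
                              const float *in2, const float *in3,
                              const float *ker, float *out0, float *out1) {
  const float *rows[4] = {in0, in1, in2, in3};
  float *outs[2] = {out0, out1};
  for (int r = 0; r < 2; ++r) {
    for (int c = 0; c < 6; ++c) {
      float acc = 0.f;
      for (int kh = 0; kh < 3; ++kh) {
        for (int kw = 0; kw < 3; ++kw) {
          acc += ker[kh * 3 + kw] * rows[r + kh][c + kw];
        }
      }
      outs[r][c] = acc;
    }
  }
}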
int m;
for (m = 1; m < output_width - 4; m += 4) {
float *output_ptr = output_data + m;
float32x4_t in0, in1, in2, in3, tmp0, tmp1, tmp2, tmp3, out0;
in0 = vld1q_f32(input_data + m - 1);
in1 = vld1q_f32(input_data + m + 3);
in2 = vld1q_f32(input_data + input_width + m - 1);
in3 = vld1q_f32(input_data + input_width + m + 3);
tmp0 = vextq_f32(in0, in1, 1);
tmp1 = vextq_f32(in0, in1, 2);
tmp2 = vextq_f32(in2, in3, 1);
tmp3 = vextq_f32(in2, in3, 2);
out0 = vmulq_n_f32(in0, w10);
out0 = vmlaq_n_f32(out0, tmp0, w11);
out0 = vmlaq_n_f32(out0, tmp1, w12);
out0 = vmlaq_n_f32(out0, in2, w20);
out0 = vmlaq_n_f32(out0, tmp2, w21);
out0 = vmlaq_n_f32(out0, tmp3, w22);
out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) {
out0 = vmaxq_f32(out0, vzero);
// remain w
if (output_w_remain > 0) {
float32x4_t _row00 = vld1q_f32(input_ptr0);
float32x4_t _row01 = vld1q_f32(input_ptr0 + 4);
float32x4_t _row10 = vld1q_f32(input_ptr1);
float32x4_t _row11 = vld1q_f32(input_ptr1 + 4);
float32x4_t _ext01 = vextq_f32(_row00, _row01, 1);
float32x4_t _ext02 = vextq_f32(_row00, _row01, 2);
float32x4_t _ext03 = vextq_f32(_row01, _row01, 1);
float32x4_t _ext04 = vextq_f32(_row01, _row01, 2);
_result0 = vmulq_lane_f32(_row00, vget_low_f32(_ker[0]), 0);
_result0 = vmlaq_lane_f32(_result0, _ext01, vget_low_f32(_ker[0]), 1);
_result0 = vmlaq_lane_f32(_result0, _ext02, vget_high_f32(_ker[0]), 0);
_result1 = vmulq_lane_f32(_row01, vget_low_f32(_ker[0]), 0);
_result1 = vmlaq_lane_f32(_result1, _ext03, vget_low_f32(_ker[0]), 1);
_result1 = vmlaq_lane_f32(_result1, _ext04, vget_high_f32(_ker[0]), 0);
_ext01 = vextq_f32(_row10, _row11, 1);
_ext02 = vextq_f32(_row10, _row11, 2);
_ext03 = vextq_f32(_row11, _row11, 1);
_ext04 = vextq_f32(_row11, _row11, 2);
_result0 = vmlaq_lane_f32(_result0, _row10, vget_low_f32(_ker[1]), 0);
_result0 = vmlaq_lane_f32(_result0, _ext01, vget_low_f32(_ker[1]), 1);
_result0 = vmlaq_lane_f32(_result0, _ext02, vget_high_f32(_ker[1]), 0);
_result1 = vmlaq_lane_f32(_result1, _row11, vget_low_f32(_ker[1]), 0);
_result1 = vmlaq_lane_f32(_result1, _ext03, vget_low_f32(_ker[1]), 1);
_result1 = vmlaq_lane_f32(_result1, _ext04, vget_high_f32(_ker[1]), 0);
_result2 = vmulq_lane_f32(_row10, vget_low_f32(_ker[0]), 0);
_result2 = vmlaq_lane_f32(_result2, _ext01, vget_low_f32(_ker[0]), 1);
_result2 = vmlaq_lane_f32(_result2, _ext02, vget_high_f32(_ker[0]), 0);
_result3 = vmulq_lane_f32(_row11, vget_low_f32(_ker[0]), 0);
_result3 = vmlaq_lane_f32(_result3, _ext03, vget_low_f32(_ker[0]), 1);
_result3 = vmlaq_lane_f32(_result3, _ext04, vget_high_f32(_ker[0]), 0);
_row00 = vld1q_f32(input_ptr2);
_row01 = vld1q_f32(input_ptr2 + 4);
_row10 = vld1q_f32(input_ptr3);
_row11 = vld1q_f32(input_ptr3 + 4);
_ext01 = vextq_f32(_row00, _row01, 1);
_ext02 = vextq_f32(_row00, _row01, 2);
_ext03 = vextq_f32(_row01, _row01, 1);
_ext04 = vextq_f32(_row01, _row01, 2);
_result0 = vmlaq_lane_f32(_result0, _row00, vget_low_f32(_ker[2]), 0);
_result0 = vmlaq_lane_f32(_result0, _ext01, vget_low_f32(_ker[2]), 1);
_result0 = vmlaq_lane_f32(_result0, _ext02, vget_high_f32(_ker[2]), 0);
_result1 = vmlaq_lane_f32(_result1, _row01, vget_low_f32(_ker[2]), 0);
_result1 = vmlaq_lane_f32(_result1, _ext03, vget_low_f32(_ker[2]), 1);
_result1 = vmlaq_lane_f32(_result1, _ext04, vget_high_f32(_ker[2]), 0);
_result2 = vmlaq_lane_f32(_result2, _row00, vget_low_f32(_ker[1]), 0);
_result2 = vmlaq_lane_f32(_result2, _ext01, vget_low_f32(_ker[1]), 1);
_result2 = vmlaq_lane_f32(_result2, _ext02, vget_high_f32(_ker[1]), 0);
_result3 = vmlaq_lane_f32(_result3, _row01, vget_low_f32(_ker[1]), 0);
_result3 = vmlaq_lane_f32(_result3, _ext03, vget_low_f32(_ker[1]), 1);
_result3 = vmlaq_lane_f32(_result3, _ext04, vget_high_f32(_ker[1]), 0);
_ext01 = vextq_f32(_row10, _row11, 1);
_ext02 = vextq_f32(_row10, _row11, 2);
_ext03 = vextq_f32(_row11, _row11, 1);
_ext04 = vextq_f32(_row11, _row11, 2);
_result2 = vmlaq_lane_f32(_result2, _row10, vget_low_f32(_ker[2]), 0);
_result2 = vmlaq_lane_f32(_result2, _ext01, vget_low_f32(_ker[2]), 1);
_result2 = vmlaq_lane_f32(_result2, _ext02, vget_high_f32(_ker[2]), 0);
_result3 = vmlaq_lane_f32(_result3, _row11, vget_low_f32(_ker[2]), 0);
_result3 = vmlaq_lane_f32(_result3, _ext03, vget_low_f32(_ker[2]), 1);
_result3 = vmlaq_lane_f32(_result3, _ext04, vget_high_f32(_ker[2]), 0);
switch (output_w_remain) {
case 5:
vst1q_lane_f32(output_ptr0 + 4, _result1, 0);
vst1q_lane_f32(output_ptr1 + 4, _result3, 0);
case 4:
vst1q_f32(output_ptr0, _result0);
vst1q_f32(output_ptr1, _result2);
break;
case 3:
vst1q_lane_f32(output_ptr0 + 2, _result0, 2);
vst1q_lane_f32(output_ptr1 + 2, _result2, 2);
case 2:
vst1_f32(output_ptr0, vget_low_f32(_result0));
vst1_f32(output_ptr1, vget_low_f32(_result2));
break;
case 1:
vst1q_lane_f32(output_ptr0, _result0, 0);
vst1q_lane_f32(output_ptr1, _result2, 0);
break;
}
vst1q_f32(output_ptr, out0);
}
for (m = 1; (m + 3) < output_width - 1; m += 4) {
input_ptr0 += output_w_remain;
input_ptr1 += output_w_remain;
input_ptr2 += output_w_remain;
input_ptr3 += output_w_remain;
output_ptr0 += output_w_remain;
output_ptr1 += output_w_remain;
}
for (int j = m; j < output_width - 1; j++) {
output_data[j] = input_data[j - 1] * w10 + input_data[j] * w11 +
input_data[j + 1] * w12 +
input_data[input_width + j - 1] * w20 +
input_data[input_width + j] * w21 +
input_data[input_width + j + 1] * w22;
output_data[j] = output_data[j] * newscale_data[c] + newbias_data[c];
if (if_relu) {
output_data[j] = output_data[j] < 0 ? 0 : output_data[j];
// pad right
if (padding_w) {
float32x2_t row0 = vld1_f32(input_ptr0);
float32x2_t row1 = vld1_f32(input_ptr1);
float32x2_t row2 = vld1_f32(input_ptr2);
float32x2_t row3 = vld1_f32(input_ptr3);
float32x2_t zero = vdup_n_f32(0.f);
float32x2_t acc0, acc1;
for (int w = valid_w_end; w < output_w; ++w) {
int padding = w + 3 - (padding_w + input_w);
if (padding >= 3) {
*output_ptr0 = 0.f;
*output_ptr1 = 0.f;
} else {
acc0 = vmul_f32(row0, vget_low_f32(_ker[0]));
acc0 = vmla_f32(acc0, row1, vget_low_f32(_ker[1]));
acc0 = vmla_f32(acc0, row2, vget_low_f32(_ker[2]));
acc1 = vmul_f32(row1, vget_low_f32(_ker[0]));
acc1 = vmla_f32(acc1, row2, vget_low_f32(_ker[1]));
acc1 = vmla_f32(acc1, row3, vget_low_f32(_ker[2]));
float32x2_t sum = vpadd_f32(acc0, acc1);
vst1_lane_f32(output_ptr0, sum, 0);
vst1_lane_f32(output_ptr1, sum, 1);
row0 = vext_f32(row0, zero, 1);
row1 = vext_f32(row1, zero, 1);
row2 = vext_f32(row2, zero, 1);
row3 = vext_f32(row3, zero, 1);
}
output_ptr0++;
output_ptr1++;
}
}
for (m = 1; m < output_width - 4; m += 4) {
float *output_ptr =
output_data + (output_height - 1) * output_width + m;
float32x4_t in0, in1, in2, in3, tmp0, tmp1, tmp2, tmp3, out0;
in0 = vld1q_f32(input_data + (output_height - 2) * input_width + m - 1);
in1 = vld1q_f32(input_data + (output_height - 2) * input_width + m + 3);
in2 = vld1q_f32(input_data + (output_height - 1) * input_width + m - 1);
in3 = vld1q_f32(input_data + (output_height - 1) * input_width + m + 3);
tmp0 = vextq_f32(in0, in1, 1);
tmp1 = vextq_f32(in0, in1, 2);
tmp2 = vextq_f32(in2, in3, 1);
tmp3 = vextq_f32(in2, in3, 2);
out0 = vmulq_n_f32(in0, w00);
out0 = vmlaq_n_f32(out0, tmp0, w01);
out0 = vmlaq_n_f32(out0, tmp1, w02);
out0 = vmlaq_n_f32(out0, in2, w10);
out0 = vmlaq_n_f32(out0, tmp2, w11);
out0 = vmlaq_n_f32(out0, tmp3, w12);
out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) {
out0 = vmaxq_f32(out0, vzero);
}
// remain height
int start_h = valid_h_start + (valid_h & 0xfffe);
if (start_h < valid_h_end) {
const float *input_ptr0 = input_ptr + (start_h - padding_h) * input_w;
const float *input_ptr1 = input_ptr0 + input_w;
const float *input_ptr2 = input_ptr1 + input_w;
float *output_ptr0 = output_ptr + start_h * output_w;
// pad left
if (padding_w) {
float32x4_t row0 = vld1q_f32(input_ptr0);
float32x4_t row1 = vld1q_f32(input_ptr1);
float32x4_t row2 = vld1q_f32(input_ptr2);
float32x4_t zero = vdupq_n_f32(0.f);
row0 = vextq_f32(zero, row0, 3);
row1 = vextq_f32(zero, row1, 3);
row2 = vextq_f32(zero, row2, 3);
float32x4_t acc;
for (int w = valid_w_start - 1; w >= 0; --w) {
int padding = padding_w - w;
if (padding >= 3) {
output_ptr0[w] = 0.f;
} else {
acc = vmulq_f32(row0, _ker[0]);
acc = vmlaq_f32(acc, row1, _ker[1]);
acc = vmlaq_f32(acc, row2, _ker[2]);
acc = vextq_f32(acc, acc, 1);
float32x2_t sum = vpadd_f32(vget_low_f32(acc), vget_low_f32(acc));
vst1_lane_f32(output_ptr0 + w, sum, 0);
row0 = vextq_f32(zero, row0, 3);
row1 = vextq_f32(zero, row1, 3);
row2 = vextq_f32(zero, row2, 3);
}
}
vst1q_f32(output_ptr, out0);
output_ptr0 += valid_w_start;
}
for (m = 1; (m + 3) < output_width - 1; m = m + 4) {
// valid
float32x4_t _result0, _result1;
for (int loop = 0; loop < output_w_tiles; ++loop) {
float32x4_t _row00 = vld1q_f32(input_ptr0);
float32x4_t _row01 = vld1q_f32(input_ptr0 + 4);
float32x4_t _row10 = vld1q_f32(input_ptr1);
float32x4_t _row11 = vld1q_f32(input_ptr1 + 4);
float32x4_t _ext01 = vextq_f32(_row00, _row01, 1);
float32x4_t _ext02 = vextq_f32(_row00, _row01, 2);
float32x4_t _ext03 = vextq_f32(_row01, _row01, 1);
float32x4_t _ext04 = vextq_f32(_row01, _row01, 2);
_result0 = vmulq_lane_f32(_row00, vget_low_f32(_ker[0]), 0);
_result0 = vmlaq_lane_f32(_result0, _ext01, vget_low_f32(_ker[0]), 1);
_result0 = vmlaq_lane_f32(_result0, _ext02, vget_high_f32(_ker[0]), 0);
_result1 = vmulq_lane_f32(_row01, vget_low_f32(_ker[0]), 0);
_result1 = vmlaq_lane_f32(_result1, _ext03, vget_low_f32(_ker[0]), 1);
_result1 = vmlaq_lane_f32(_result1, _ext04, vget_high_f32(_ker[0]), 0);
_ext01 = vextq_f32(_row10, _row11, 1);
_ext02 = vextq_f32(_row10, _row11, 2);
_ext03 = vextq_f32(_row11, _row11, 1);
_ext04 = vextq_f32(_row11, _row11, 2);
_result0 = vmlaq_lane_f32(_result0, _row10, vget_low_f32(_ker[1]), 0);
_result0 = vmlaq_lane_f32(_result0, _ext01, vget_low_f32(_ker[1]), 1);
_result0 = vmlaq_lane_f32(_result0, _ext02, vget_high_f32(_ker[1]), 0);
_result1 = vmlaq_lane_f32(_result1, _row11, vget_low_f32(_ker[1]), 0);
_result1 = vmlaq_lane_f32(_result1, _ext03, vget_low_f32(_ker[1]), 1);
_result1 = vmlaq_lane_f32(_result1, _ext04, vget_high_f32(_ker[1]), 0);
_row00 = vld1q_f32(input_ptr2);
_row01 = vld1q_f32(input_ptr2 + 4);
_ext01 = vextq_f32(_row00, _row01, 1);
_ext02 = vextq_f32(_row00, _row01, 2);
_ext03 = vextq_f32(_row01, _row01, 1);
_ext04 = vextq_f32(_row01, _row01, 2);
_result0 = vmlaq_lane_f32(_result0, _row00, vget_low_f32(_ker[2]), 0);
_result0 = vmlaq_lane_f32(_result0, _ext01, vget_low_f32(_ker[2]), 1);
_result0 = vmlaq_lane_f32(_result0, _ext02, vget_high_f32(_ker[2]), 0);
_result1 = vmlaq_lane_f32(_result1, _row01, vget_low_f32(_ker[2]), 0);
_result1 = vmlaq_lane_f32(_result1, _ext03, vget_low_f32(_ker[2]), 1);
_result1 = vmlaq_lane_f32(_result1, _ext04, vget_high_f32(_ker[2]), 0);
vst1q_f32(output_ptr0, _result0);
vst1_f32(output_ptr0 + 4, vget_low_f32(_result1));
input_ptr0 += 6;
input_ptr1 += 6;
input_ptr2 += 6;
output_ptr0 += 6;
}
for (int j = m; j < output_width - 1; j++) {
output_data[(output_height - 1) * input_width + j] =
input_data[(output_height - 2) * input_width + j - 1] * w00 +
input_data[(output_height - 2) * input_width + j] * w01 +
input_data[(output_height - 2) * input_width + j + 1] * w02 +
input_data[(output_height - 1) * input_width + j - 1] * w10 +
input_data[(output_height - 1) * input_width + j] * w11 +
input_data[(output_height - 1) * input_width + j + 1] * w12;
output_data[(output_height - 1) * output_width + j] =
output_data[(output_height - 1) * output_width + j] *
newscale_data[c] +
newbias_data[c];
if (if_relu) {
output_data[(output_height - 1) * output_width + j] =
output_data[(output_height - 1) * output_width + j] < 0
? 0
: output_data[(output_height - 1) * output_width + j];
}
}
}
}
/*
const float *input_data = input->data<float>();
const float *filter_data = filter->data<float>();
float *output_data = output->data<float>();
const float *newscale_data = new_scale->data<float>();
const float *newbias_data = new_bias->data<float>();
const int h = static_cast<int>(input->dims()[2]);
const int w = static_cast<int>(input->dims()[3]);
// const int l = h;
const int batch_size = static_cast<int>(input->dims()[0]);
const int c = static_cast<int>(input->dims()[1]);
const int hxw = h * w;
float32x4_t vnewbias = vdupq_n_f32(0.0);
float32x4_t vnewscale = vdupq_n_f32(1.0);
float32x4_t vzero = vdupq_n_f32(0);
for (int b = 0; b < batch_size; ++b) {
const float *filter_data_tmp = filter_data;
for (int j = 0; j < c; ++j) {
vnewbias = vdupq_n_f32(newbias_data[j]);
vnewscale = vdupq_n_f32(newscale_data[j]);
int w_mid = w - 2; // l=1->l_mid=-1,l=2->l_mid=0
float w00 = filter_data_tmp[0];
float w01 = filter_data_tmp[1];
float w02 = filter_data_tmp[2];
float w10 = filter_data_tmp[3];
float w11 = filter_data_tmp[4];
float w12 = filter_data_tmp[5];
float w20 = filter_data_tmp[6];
float w21 = filter_data_tmp[7];
float w22 = filter_data_tmp[8];
output_data[0] = w11 * input_data[0] + w12 * input_data[1] +
w21 * input_data[w] + w22 * input_data[w + 1];
output_data[w - 1] = w10 * input_data[w - 2] + w11 * input_data[w -
1] + w20 * input_data[2 * w - 2] + w21 * input_data[2 * w - 1];
output_data[(h - 1) * w] =
w01 * input_data[(h - 2) * w] + w02 * input_data[(h - 2) * w +
1] + w11 * input_data[(h - 1) * w] + w12 * input_data[(h - 1) * w + 1];
output_data[h * w - 1] = w00 * input_data[h*w-w-2] +
w01 * input_data[h*w-w-1] +
w10 * input_data[h * w - 2] +
w11 * input_data[h * w - 1];
output_data[0] = output_data[0] * newscale_data[j] +
newbias_data[j]; output_data[w - 1] = output_data[w - 1] *
newscale_data[j] + newbias_data[j]; output_data[(h - 1) * w] =
output_data[(h - 1) * w] * newscale_data[j] + newbias_data[j];
output_data[h * w - 1] =
output_data[h * w - 1] * newscale_data[j] + newbias_data[j];
if (if_relu) {
output_data[0] = output_data[0] < 0 ? 0 : output_data[0];
output_data[w - 1] = output_data[w - 1] < 0 ? 0 : output_data[w -
1]; output_data[(h - 1) * w] = output_data[(h - 1) * w] < 0 ? 0 :
output_data[(h - 1) * w]; output_data[h * w - 1] = output_data[h * w - 1]
< 0 ? 0 : output_data[h * w - 1];
}
for (int i = 1; i < h - 1; ++i) {
output_data[i * w] =
w01 * input_data[i * w - w] + w02 * input_data[i * w - w + 1]
+ w11 * input_data[i * w] + w12 * input_data[i * w + 1] + w21 *
input_data[i * w + w] + w22 * input_data[i * w + w + 1]; output_data[i *
w + w - 1] = w00 * input_data[i * w + w - 1 - w - 1] + w01 * input_data[i
* w + w - 1 - w] + w10 * input_data[i * w + w - 1 - 1] + w11 *
input_data[i * w + w - 1] + w20 * input_data[i * w + w - 1 + w - 1] + w21
* input_data[i * w + w - 1 + w]; output_data[i * w] = output_data[i * w]
* newscale_data[j] + newbias_data[j]; output_data[i * w + w - 1] =
output_data[i * w + w - 1] * newscale_data[j] +
newbias_data[j];
if (if_relu) {
output_data[i * w] = output_data[i * w] < 0 ? 0 : output_data[i
* w]; output_data[i * w + w - 1] = output_data[i * w + w - 1] < 0 ? 0 :
output_data[i * w + w - 1];
}
}
// top 1 row and bottom 1 row
const float *input_tmp = input_data;
float32x4_t in0, in1, in2, in3, in4, in5, in6, in7, tmp0, tmp1,
tmp2, tmp3, tmp4, tmp5, out0; in0 = vld1q_f32(input_tmp); in2 =
vld1q_f32(input_tmp + w); const float *input_tmp_end = input_tmp + (h -
2) * w; in4 = vld1q_f32(input_tmp_end); in6 = vld1q_f32(input_tmp_end +
w); int c_mid = w_mid; auto output_ptr = output_data + 1; for (; c_mid >
3; c_mid -= 4) { in1 = vld1q_f32(input_tmp + 4); in3 =
vld1q_f32(input_tmp + w + 4);
tmp0 = vextq_f32(in0, in1, 1);
tmp1 = vextq_f32(in0, in1, 2);
tmp2 = vextq_f32(in2, in3, 1);
tmp3 = vextq_f32(in2, in3, 2);
out0 = vmulq_n_f32(in0, w10);
out0 = vmlaq_n_f32(out0, tmp0, w11);
out0 = vmlaq_n_f32(out0, tmp1, w12);
out0 = vmlaq_n_f32(out0, in2, w20);
out0 = vmlaq_n_f32(out0, tmp2, w21);
out0 = vmlaq_n_f32(out0, tmp3, w22);
out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) {
out0 = vmaxq_f32(out0, vzero);
}
vst1q_f32(output_ptr, out0);
in5 = vld1q_f32(input_tmp_end + 4);
in7 = vld1q_f32(input_tmp_end + w + 4);
tmp0 = vextq_f32(in4, in5, 1);
tmp1 = vextq_f32(in4, in5, 2);
tmp2 = vextq_f32(in6, in7, 1);
tmp3 = vextq_f32(in6, in7, 2);
out0 = vmulq_n_f32(in4, w00);
out0 = vmlaq_n_f32(out0, tmp0, w01);
out0 = vmlaq_n_f32(out0, tmp1, w02);
out0 = vmlaq_n_f32(out0, in6, w10);
out0 = vmlaq_n_f32(out0, tmp2, w11);
out0 = vmlaq_n_f32(out0, tmp3, w12);
out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) {
out0 = vmaxq_f32(out0, vzero);
}
vst1q_f32(output_ptr + (h - 1) * w, out0);
// could be optimized to process 8 elements per iteration.
input_tmp += 4;
input_tmp_end += 4;
output_ptr += 4;
in0 = in1;
in2 = in3;
in4 = in5;
in6 = in7;
}
// top right pad
float32x4_t pad0 = vdupq_n_f32(input_data[w - 1]);
float32x4_t pad1 = vdupq_n_f32(input_data[2 * w - 1]);
tmp0 = vextq_f32(in0, pad0, 1);
tmp1 = vextq_f32(in0, pad0, 2);
tmp2 = vextq_f32(in2, pad1, 1);
tmp3 = vextq_f32(in2, pad1, 2);
out0 = vmulq_n_f32(in0, w10);
out0 = vmlaq_n_f32(out0, tmp0, w11);
out0 = vmlaq_n_f32(out0, tmp1, w12);
out0 = vmlaq_n_f32(out0, in2, w20);
out0 = vmlaq_n_f32(out0, tmp2, w21);
out0 = vmlaq_n_f32(out0, tmp3, w22);
out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) {
out0 = vmaxq_f32(out0, vzero);
}
for (int i = 0; i < c_mid; ++i) {
if (i == 0) {
vst1q_lane_f32(output_ptr + i, out0, 0);
}
if (i == 1) {
vst1q_lane_f32(output_ptr + i, out0, 1);
}
if (i == 2) {
vst1q_lane_f32(output_ptr + i, out0, 2);
}
}
// bottom right pad
float32x4_t pad2 = vdupq_n_f32(input_data[h * w - 1 - w]);
float32x4_t pad3 = vdupq_n_f32(input_data[h * w - 1]);
tmp0 = vextq_f32(in4, pad2, 1);
tmp1 = vextq_f32(in4, pad2, 2);
tmp2 = vextq_f32(in6, pad3, 1);
tmp3 = vextq_f32(in6, pad3, 2);
out0 = vmulq_n_f32(in4, w00);
out0 = vmlaq_n_f32(out0, tmp0, w01);
out0 = vmlaq_n_f32(out0, tmp1, w02);
out0 = vmlaq_n_f32(out0, in6, w10);
out0 = vmlaq_n_f32(out0, tmp2, w11);
out0 = vmlaq_n_f32(out0, tmp3, w12);
out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) {
out0 = vmaxq_f32(out0, vzero);
}
for (int i = 0; i < c_mid; ++i) {
if (i == 0) {
vst1q_lane_f32(output_ptr + (h - 1) * w + i, out0, 0);
}
if (i == 1) {
vst1q_lane_f32(output_ptr + (h - 1) * w + i, out0, 1);
}
if (i == 2) {
vst1q_lane_f32(output_ptr + (h - 1) * w + i, out0, 2);
}
}
// mid
for (int i = 0; i < h - 2; ++i) {
auto output_ptr = output_data + (i + 1) * w + 1;
input_tmp = input_data + i * w;
auto in0_tmp = vld1q_f32(input_tmp);
auto in2_tmp = vld1q_f32(input_tmp + w);
auto in4_tmp = vld1q_f32(input_tmp + w + w);
c_mid = w_mid;
for (; c_mid > 3; c_mid -= 4) {
auto in1_tmp = vld1q_f32(input_tmp + 4);
auto in3_tmp = vld1q_f32(input_tmp + w + 4);
auto in5_tmp = vld1q_f32(input_tmp + w + w + 4);
tmp0 = vextq_f32(in0_tmp, in1_tmp, 1);
tmp1 = vextq_f32(in0_tmp, in1_tmp, 2);
tmp2 = vextq_f32(in2_tmp, in3_tmp, 1);
tmp3 = vextq_f32(in2_tmp, in3_tmp, 2);
tmp4 = vextq_f32(in4_tmp, in5_tmp, 1);
tmp5 = vextq_f32(in4_tmp, in5_tmp, 2);
out0 = vmulq_n_f32(in0_tmp, w00);
out0 = vmlaq_n_f32(out0, tmp0, w01);
out0 = vmlaq_n_f32(out0, tmp1, w02);
out0 = vmlaq_n_f32(out0, in2_tmp, w10);
out0 = vmlaq_n_f32(out0, tmp2, w11);
out0 = vmlaq_n_f32(out0, tmp3, w12);
out0 = vmlaq_n_f32(out0, in4_tmp, w20);
out0 = vmlaq_n_f32(out0, tmp4, w21);
out0 = vmlaq_n_f32(out0, tmp5, w22);
out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) {
out0 = vmaxq_f32(out0, vzero);
}
vst1q_f32(output_ptr, out0);
output_ptr += 4;
input_tmp += 4;
in0_tmp = in1_tmp;
in2_tmp = in3_tmp;
in4_tmp = in5_tmp;
}
float32x4_t pad0 = vdupq_n_f32(input_data[i * w + w - 1]);
float32x4_t pad1 = vdupq_n_f32(input_data[i * w + w - 1 + w]);
float32x4_t pad2 = vdupq_n_f32(input_data[i * w + w - 1 + w + w]);
tmp0 = vextq_f32(in0_tmp, pad0, 1);
tmp1 = vextq_f32(in0_tmp, pad0, 2);
tmp2 = vextq_f32(in2_tmp, pad1, 1);
tmp3 = vextq_f32(in2_tmp, pad1, 2);
tmp4 = vextq_f32(in4_tmp, pad2, 1);
tmp5 = vextq_f32(in4_tmp, pad2, 2);
out0 = vmulq_n_f32(in0_tmp, w00);
out0 = vmlaq_n_f32(out0, tmp0, w01);
out0 = vmlaq_n_f32(out0, tmp1, w02);
out0 = vmlaq_n_f32(out0, in2_tmp, w10);
out0 = vmlaq_n_f32(out0, tmp2, w11);
out0 = vmlaq_n_f32(out0, tmp3, w12);
out0 = vmlaq_n_f32(out0, in4_tmp, w20);
out0 = vmlaq_n_f32(out0, tmp4, w21);
out0 = vmlaq_n_f32(out0, tmp5, w22);
out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) {
out0 = vmaxq_f32(out0, vzero);
}
for (int i = 0; i < c_mid; ++i) {
if (i == 0) {
vst1q_lane_f32(output_ptr + i, out0, 0);
}
if (i == 1) {
vst1q_lane_f32(output_ptr + i, out0, 1);
}
if (i == 2) {
vst1q_lane_f32(output_ptr + i, out0, 2);
}
}
}
output_data += hxw;
input_data += hxw;
filter_data_tmp += 9;
}
if (output_w_remain > 0) {
float32x4_t _row00 = vld1q_f32(input_ptr0);
float32x4_t _row01 = vld1q_f32(input_ptr0 + 4);
float32x4_t _row10 = vld1q_f32(input_ptr1);
float32x4_t _row11 = vld1q_f32(input_ptr1 + 4);
float32x4_t _ext01 = vextq_f32(_row00, _row01, 1);
float32x4_t _ext02 = vextq_f32(_row00, _row01, 2);
float32x4_t _ext03 = vextq_f32(_row01, _row01, 1);
float32x4_t _ext04 = vextq_f32(_row01, _row01, 2);
_result0 = vmulq_lane_f32(_row00, vget_low_f32(_ker[0]), 0);
_result0 = vmlaq_lane_f32(_result0, _ext01, vget_low_f32(_ker[0]), 1);
_result0 = vmlaq_lane_f32(_result0, _ext02, vget_high_f32(_ker[0]), 0);
_result1 = vmulq_lane_f32(_row01, vget_low_f32(_ker[0]), 0);
_result1 = vmlaq_lane_f32(_result1, _ext03, vget_low_f32(_ker[0]), 1);
_result1 = vmlaq_lane_f32(_result1, _ext04, vget_high_f32(_ker[0]), 0);
_ext01 = vextq_f32(_row10, _row11, 1);
_ext02 = vextq_f32(_row10, _row11, 2);
_ext03 = vextq_f32(_row11, _row11, 1);
_ext04 = vextq_f32(_row11, _row11, 2);
_result0 = vmlaq_lane_f32(_result0, _row10, vget_low_f32(_ker[1]), 0);
_result0 = vmlaq_lane_f32(_result0, _ext01, vget_low_f32(_ker[1]), 1);
_result0 = vmlaq_lane_f32(_result0, _ext02, vget_high_f32(_ker[1]), 0);
_result1 = vmlaq_lane_f32(_result1, _row11, vget_low_f32(_ker[1]), 0);
_result1 = vmlaq_lane_f32(_result1, _ext03, vget_low_f32(_ker[1]), 1);
_result1 = vmlaq_lane_f32(_result1, _ext04, vget_high_f32(_ker[1]), 0);
_row00 = vld1q_f32(input_ptr2);
_row01 = vld1q_f32(input_ptr2 + 4);
_ext01 = vextq_f32(_row00, _row01, 1);
_ext02 = vextq_f32(_row00, _row01, 2);
_ext03 = vextq_f32(_row01, _row01, 1);
_ext04 = vextq_f32(_row01, _row01, 2);
_result0 = vmlaq_lane_f32(_result0, _row00, vget_low_f32(_ker[2]), 0);
_result0 = vmlaq_lane_f32(_result0, _ext01, vget_low_f32(_ker[2]), 1);
_result0 = vmlaq_lane_f32(_result0, _ext02, vget_high_f32(_ker[2]), 0);
_result1 = vmlaq_lane_f32(_result1, _row01, vget_low_f32(_ker[2]), 0);
_result1 = vmlaq_lane_f32(_result1, _ext03, vget_low_f32(_ker[2]), 1);
_result1 = vmlaq_lane_f32(_result1, _ext04, vget_high_f32(_ker[2]), 0);
switch (output_w_remain) {
case 5:
vst1q_lane_f32(output_ptr0 + 4, _result1, 0);
case 4:
vst1q_f32(output_ptr0, _result0);
break;
case 3:
vst1q_lane_f32(output_ptr0 + 2, _result0, 2);
case 2:
vst1_f32(output_ptr0, vget_low_f32(_result0));
break;
case 1:
vst1q_lane_f32(output_ptr0, _result0, 0);
break;
}
*/
#endif
}
/// FIXME: the w != h case is not handled correctly yet
void DepthwiseConvAddBNRelu3x3s2p1(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output,
const framework::Tensor *new_scale,
const framework::Tensor *new_bias,
bool if_relu) {
#if __ARM_NEON
const int batch_size = input->dims()[0];
const int input_height = input->dims()[2];
const int input_width = input->dims()[3];
const int output_channels = output->dims()[1];
const int output_height = output->dims()[2];
const int output_width = output->dims()[3];
const int _kernel_size = 3;
const int stride_height = 2;
const int stride_width = 2;
const int padding_height = 1;
const int padding_width = 1;
const float zero = 0;
const int input_channel_stride = input_height * input_width;
const int output_channel_stride = output_height * output_width;
const int filter_channel_stride = 9;
const float *newscale_data = new_scale->data<float>();
const float *newbias_data = new_bias->data<float>();
const float *input_data = input->data<float>();
const float *filter_data = filter->data<float>();
float *output_data = output->mutable_data<float>();
const int input_batch_stride = output_channels * input_channel_stride;
const int output_batch_stride = output_channels * output_channel_stride;
const int filter_batch_stride = output_channels * filter_channel_stride;
const float *pos1, *pos2, *pos3, *filter1, *filter2, *filter3, *output_ptr;
int hstart, wstart, hend, wend;
float result;
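// Naive per-pixel loop over (ph, pw); windows that cross the border are
// gathered into a zero-initialized 3x3 fake_input patch so the same
// 9-element multiply can be applied everywhere.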
for (int i = 0; i < batch_size; ++i) {
for (int c = 0; c < output_channels; ++c) {
filter1 = filter_data;
filter2 = filter1 + 3;
filter3 = filter2 + 3;
for (int ph = 0; ph < output_height; ph++) {
for (int pw = 0; pw < output_width; pw++) {
hstart = ph * stride_height - padding_height;
wstart = pw * stride_width - padding_width;
hend = std::min(hstart + _kernel_size, input_height + padding_height);
wend = std::min(wstart + _kernel_size, input_width + padding_width);
hstart = std::max(hstart, 0);
wstart = std::max(wstart, 0);
hend = std::min(hend, input_height);
wend = std::min(wend, input_width);
pos1 = input_data + hstart * input_width + wstart;
pos2 = input_data + (hstart + 1) * input_width + wstart;
pos3 = input_data + (hstart + 2) * input_width + wstart;
output_ptr = output_data + ph * output_width + pw;
if (hend - hstart != 3 || wend - wstart != 3) {
result = 0;
float fake_input[9] = {0};
if (hstart == 0 && wstart == 0) {
// top-left corner
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 3; ++k) {
if (j >= 3 - hend && k >= 3 - wend) {
fake_input[3 * j + k] =
input_data[(j - (3 - hend)) * input_width + k -
(3 - wend)];
}
}
}
} else if (hstart == 0 && wend == input_width) {
// top-right corner
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 3; ++k) {
if (j >= 3 - hend && k <= input_width - wstart - 1) {
fake_input[3 * j + k] =
input_data[(j - (3 - hend)) * input_width + k + wstart];
}
}
}
} else if (hend == input_height && wstart == 0) {
// bottom-left corner
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 3; ++k) {
if (j <= input_height - 1 - hstart && k >= 3 - wend) {
fake_input[3 * j + k] =
input_data[(j + hstart) * input_width + k - (3 - wend)];
}
}
}
} else if (hend == input_height && wend == input_width) {
// bottom-right corner
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 3; ++k) {
if (j <= input_height - hstart - 1 &&
k <= input_width - wstart - 1) {
fake_input[3 * j + k] =
input_data[(j + hstart) * input_width + k + wstart];
}
}
}
} else if (hstart == 0) {
// top edge
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 3; ++k) {
if (j >= 3 - hend) {
fake_input[3 * j + k] =
input_data[(j - (3 - hend)) * input_width + k + wstart];
}
}
}
} else if (hend == input_height) {
// bottom edge
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 3; ++k) {
if (j <= input_height - hstart - 1) {
fake_input[3 * j + k] =
input_data[(j + hstart) * input_width + k + wstart];
}
}
}
} else if (wstart == 0) {
// left edge
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 3; ++k) {
if (k >= 3 - wend) {
fake_input[3 * j + k] =
input_data[(j + hstart) * input_width +
(k - (3 - wend))];
}
}
}
} else if (wend == input_width) {
// right edge
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 3; ++k) {
if (k <= input_width - wstart - 1) {
fake_input[3 * j + k] =
input_data[(j + hstart) * input_width + k + wstart];
}
}
}
}
for (int l = 0; l < 9; ++l) {
result += fake_input[l] * filter1[l];
}
output_data[ph * output_width + pw] =
newscale_data[c] * result + newbias_data[c];
if (if_relu) {
output_data[ph * output_width + pw] =
output_data[ph * output_width + pw] < 0
? 0
: output_data[ph * output_width + pw];
}
input_ptr0 += output_w_remain;
input_ptr1 += output_w_remain;
input_ptr2 += output_w_remain;
output_ptr0 += output_w_remain;
}
// pad right
if (padding_w) {
float32x2_t row0 = vld1_f32(input_ptr0);
float32x2_t row1 = vld1_f32(input_ptr1);
float32x2_t row2 = vld1_f32(input_ptr2);
float32x2_t zero = vdup_n_f32(0.f);
float32x2_t acc;
for (int w = valid_w_end; w < output_w; ++w) {
int padding = w + 3 - (padding_w + input_w);
if (padding >= 3) {
*output_ptr0 = 0.f;
} else {
const float32x4_t data1 = vld1q_f32(pos1);
const float32x4_t data2 = vld1q_f32(pos2);
const float32x4_t data3 = vld1q_f32(pos3);
const float32x4_t v_filter1 = vld1q_f32(filter1);
const float32x4_t v_filter2 = vld1q_f32(filter2);
const float32x4_t v_filter3 = vld1q_f32(filter3);
float32x4_t mula = vmulq_f32(data1, v_filter1);
mula = vmlaq_f32(mula, data2, v_filter2);
mula = vmlaq_f32(mula, data3, v_filter3);
float32x2_t res = vpadd_f32(
vget_high_f32(vsetq_lane_f32(0, mula, 3)), vget_low_f32(mula));
res = vpadd_f32(res, res);
output_data[ph * output_width + pw] =
vget_lane_f32(res, 0) * newscale_data[c] + newbias_data[c];
if (if_relu) {
output_data[ph * output_width + pw] =
output_data[ph * output_width + pw] < 0
? 0
: output_data[ph * output_width + pw];
}
acc = vmul_f32(row0, vget_low_f32(_ker[0]));
acc = vmla_f32(acc, row1, vget_low_f32(_ker[1]));
acc = vmla_f32(acc, row2, vget_low_f32(_ker[2]));
float32x2_t sum = vpadd_f32(acc, acc);
vst1_lane_f32(output_ptr0, sum, 0);
row0 = vext_f32(row0, zero, 1);
row1 = vext_f32(row1, zero, 1);
row2 = vext_f32(row2, zero, 1);
}
output_ptr0++;
}
}
input_data += input_channel_stride;
output_data += output_channel_stride;
filter_data += filter_channel_stride;
}
input_data += input_batch_stride;
output_data += output_batch_stride;
// pad bottom
for (int h = valid_h_end; h < output_h; ++h) {
DepthwiseConv3x3NormalRow<1, 1>(input_ptr, filter_ptr, h, input_h,
input_w, padding_h, padding_w, output_w,
output_ptr, _ker);
}
}
#endif
}
void DepthwiseConv3x3s2p1v2(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output, framework::Tensor *bias,
bool if_bias, bool if_relu) {
#if __ARM_NEON
const float *input_data = input->data<float>();
const float *filter_data = filter->data<float>();
float *output_data = output->mutable_data<float>();
const float *bias_data;
if (if_bias) {
bias_data = bias->data<float>();
}
const int in_h = static_cast<int>(input->dims()[2]);
const int in_w = static_cast<int>(input->dims()[3]);
const int out_h = static_cast<int>(output->dims()[2]);
const int out_w = static_cast<int>(output->dims()[3]);
const int out_l = out_h;
const int in_l = in_h;
const int inhxw = in_h * in_w;
const int outhxw = out_h * out_w;
/// TODO: fix if_pad when w != h
const int if_pad_r = in_w - 1 == (out_w - 1) * 2 ? 1 : 0;
const int if_pad_b = in_h - 1 == (out_h - 1) * 2 ? 1 : 0;
const int batch_size = static_cast<int>(input->dims()[0]);
const int c = static_cast<int>(input->dims()[1]);
const float *input_row_ptr;
float *output_row_ptr;
const int w_times = (out_w - 2) / 3;
float32x4_t vbias = vdupq_n_f32(0.0);
float32x4x2_t input_buff_mid{}, input_buff_bottom[w_times + 1];
float32x4_t elewise_res0, elewise_res1, elewise_res2, res3;
int out2in_mid;
float32x4_t zero = vdupq_n_f32(0.0);
for (int b = batch_size; b > 0; --b) {
const float *filter_data_tmp = filter_data;
for (int j = 0; j < c; ++j) {
auto output_data_tmp = output_data + j * out_h * out_w;
auto input_data_tmp = input_data + j * in_h * in_w;
auto input_const = input_data_tmp;
if (if_bias) {
vbias = vdupq_n_f32(bias_data[j]);
}
float w00 = filter_data_tmp[0];
float w01 = filter_data_tmp[1];
float w02 = filter_data_tmp[2];
float w10 = filter_data_tmp[3];
float w11 = filter_data_tmp[4];
float w12 = filter_data_tmp[5];
float w20 = filter_data_tmp[6];
float w21 = filter_data_tmp[7];
float w22 = filter_data_tmp[8];
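// Stride-2 main loop: vld2q_f32 de-interleaves even/odd input columns, each
// inner iteration advances three outputs, and input_buff_bottom[] caches the
// bottom row so it can be reused as the next row's top contribution.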
int h_mid = 0;
for (; h_mid < out_h - 1; h_mid++) {
input_row_ptr = input_data_tmp + 1 + h_mid * 2 * in_w;
output_row_ptr = output_data_tmp + 1 + h_mid * out_w;
for (int w4 = 0; w4 < w_times + 1; w4++) {
if (h_mid == 0) {
elewise_res1 = zero;
elewise_res0 = zero;
elewise_res2 = zero;
template <>
void DepthwiseConv3x3S2<float, float>(const framework::Tensor &input,
const framework::Tensor &filter,
const std::vector<int> &paddings,
framework::Tensor *output) {
const float *input_data = input.data<float>();
const float *filter_data = filter.data<float>();
float *out_data = output->mutable_data<float>();
int input_h = input.dims()[2];
int input_w = input.dims()[3];
int output_h = output->dims()[2];
int output_w = output->dims()[3];
int padding_h = paddings[0];
int padding_w = paddings[1];
int image_size = input_h * input_w;
int out_image_size = output_h * output_w;
int valid_h_start = (padding_h + 1) / 2;
int valid_h_end = (input_h + padding_h - 1) / 2;
int valid_h = valid_h_end - valid_h_start;
int valid_w_start = (padding_w + 1) / 2;
int valid_w_end = (input_w + padding_w - 1) / 2;
int valid_w = valid_w_end - valid_w_start;
int input_w_start = 2 * valid_w_start - padding_w;
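// valid_[hw]_start/end bound the output rows/cols whose 3x3 window (stride 2)
// falls entirely inside the input, so the main loops can skip padding checks.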
#pragma omp parallel for
for (int g = 0; g < input.dims()[1]; ++g) {
const float *input_ptr = input_data + g * image_size;
const float *filter_ptr = filter_data + g * 9;
float *output_ptr = out_data + g * out_image_size;
const float *filter_ptr0 = filter_ptr;
const float *filter_ptr1 = filter_ptr0 + 3;
const float *filter_ptr2 = filter_ptr1 + 3;
float32x4_t _ker[3];
_ker[0] = vld1q_f32(filter_ptr0);
_ker[1] = vld1q_f32(filter_ptr1);
_ker[2] = vld1q_f32(filter_ptr2);
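// _ker[0..2] hold the three filter rows; only the first three lanes of each
// vector are used.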
// pad top
for (int h = 0; h < valid_h_start; ++h) {
DepthwiseConv3x3NormalRow<2, 2>(input_ptr, filter_ptr, h, input_h,
input_w, padding_h, padding_w, output_w,
output_ptr, _ker);
}
// valid 2x4
int output_w_tiles = valid_w / 4;
int output_w_remain = valid_w - output_w_tiles * 4;
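// The valid region is processed four outputs at a time; the tail of
// output_w_remain columns is handled after the tiled loop.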
for (int h = valid_h_start; h < valid_h_end - 1; h += 2) {
const float *input_ptr0 = input_ptr + (2 * h - padding_h) * input_w;
const float *input_ptr1 = input_ptr0 + input_w;
const float *input_ptr2 = input_ptr1 + input_w;
const float *input_ptr3 = input_ptr2 + input_w;
const float *input_ptr4 = input_ptr3 + input_w;
float *output_ptr0 = output_ptr + h * output_w;
float *output_ptr1 = output_ptr0 + output_w;
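// Two output rows (output_ptr0/1) are produced per iteration from five
// consecutive input rows, since the vertical stride is 2.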
// pad left
if (padding_w) {
for (int w = valid_w_start - 1; w >= 0; --w) {
int padding = padding_w - (w << 1);
if (padding >= 3) {
output_ptr0[w] = 0;
output_ptr1[w] = 0;
} else {
elewise_res1 = vmulq_n_f32(input_buff_bottom[w4].val[1], w01);
elewise_res0 = vmulq_n_f32(input_buff_bottom[w4].val[0], w00);
elewise_res2 = vmulq_n_f32(input_buff_bottom[w4].val[0], w02);
}
input_buff_mid = vld2q_f32(input_row_ptr);
input_buff_bottom[w4] = vld2q_f32(input_row_ptr + in_w);
elewise_res1 = vmlaq_n_f32(elewise_res1, input_buff_mid.val[1], w11);
elewise_res0 = vmlaq_n_f32(elewise_res0, input_buff_mid.val[0], w10);
elewise_res2 = vmlaq_n_f32(elewise_res2, input_buff_mid.val[0], w12);
elewise_res1 =
vmlaq_n_f32(elewise_res1, input_buff_bottom[w4].val[1], w21);
elewise_res0 =
vmlaq_n_f32(elewise_res0, input_buff_bottom[w4].val[0], w20);
elewise_res2 =
vmlaq_n_f32(elewise_res2, input_buff_bottom[w4].val[0], w22);
res3 = vaddq_f32(vextq_f32(elewise_res2, zero, 1),
vaddq_f32(elewise_res0, elewise_res1));
res3 = vaddq_f32(res3, vbias);
if (if_relu) {
res3 = vmaxq_f32(res3, zero);
}
vst1q_f32(output_row_ptr, res3);
input_row_ptr += 6;
output_row_ptr += 3;
}
}
clock();
input_row_ptr = input_data_tmp + 1 + h_mid * 2 * in_w;
output_row_ptr = output_data_tmp + 1 + h_mid * out_w;
for (int w4 = 0; w4 < w_times + 1; w4++) {
elewise_res1 = vmulq_n_f32(input_buff_bottom[w4].val[1], w01);
elewise_res0 = vmulq_n_f32(input_buff_bottom[w4].val[0], w00);
elewise_res2 = vmulq_n_f32(input_buff_bottom[w4].val[0], w02);
input_buff_mid = vld2q_f32(input_row_ptr);
input_buff_bottom[w4] = vld2q_f32(input_row_ptr + in_w);
elewise_res1 = vmlaq_n_f32(elewise_res1, input_buff_mid.val[1], w11);
elewise_res0 = vmlaq_n_f32(elewise_res0, input_buff_mid.val[0], w10);
elewise_res2 = vmlaq_n_f32(elewise_res2, input_buff_mid.val[0], w12);
if (!if_pad_b) {
elewise_res1 =
vmlaq_n_f32(elewise_res1, input_buff_bottom[w4].val[1], w21);
elewise_res0 =
vmlaq_n_f32(elewise_res0, input_buff_bottom[w4].val[0], w20);
elewise_res2 =
vmlaq_n_f32(elewise_res2, input_buff_bottom[w4].val[0], w22);
}
res3 = vaddq_f32(vextq_f32(elewise_res2, zero, 1),
vaddq_f32(elewise_res0, elewise_res1));
res3 = vaddq_f32(res3, vbias);
if (if_relu) {
res3 = vmaxq_f32(res3, zero);
}
if ((w4 != w_times)) {
vst1q_f32(output_row_ptr, res3);
} else {
if (out_w - 2 - w_times * 3 == 1) {
vst1q_lane_f32(output_row_ptr, res3, 0);
} else if (out_w - 2 - w_times * 3 == 2) {
vst1q_lane_f32(output_row_ptr, res3, 0);
vst1q_lane_f32(output_row_ptr + 1, res3, 1);
float32x4_t row0 = vld1q_f32(input_ptr0 - padding);
float32x4_t row1 = vld1q_f32(input_ptr1 - padding);
float32x4_t row2 = vld1q_f32(input_ptr2 - padding);
float32x4_t row3 = vld1q_f32(input_ptr3 - padding);
float32x4_t row4 = vld1q_f32(input_ptr4 - padding);
float32x4_t acc0 = vmulq_f32(row0, _ker[0]);
float32x4_t acc1 = vmulq_f32(row2, _ker[0]);
acc0 = vmlaq_f32(acc0, row1, _ker[1]);
acc1 = vmlaq_f32(acc1, row3, _ker[1]);
acc0 = vmlaq_f32(acc0, row2, _ker[2]);
acc1 = vmlaq_f32(acc1, row4, _ker[2]);
float sum0 = vgetq_lane_f32(acc0, 2);
float sum1 = vgetq_lane_f32(acc1, 2);
if (padding == 1) {
sum0 += vgetq_lane_f32(acc0, 1);
sum1 += vgetq_lane_f32(acc1, 1);
}
output_ptr0[w] = sum0;
output_ptr1[w] = sum1;
}
}
input_row_ptr += 6;
output_row_ptr += 3;
input_ptr0 += input_w_start;
input_ptr1 += input_w_start;
input_ptr2 += input_w_start;
input_ptr3 += input_w_start;
input_ptr4 += input_w_start;
output_ptr0 += valid_w_start;
output_ptr1 += valid_w_start;
}
// leftTop, rightTop, leftBottom, rightBottom
int lt = 0;
int rt = out_w - 1;
int lb = out_w * (out_h - 1);
int rb = out_h * out_w - 1;
output_data_tmp[lt] = input_const[0] * w11 + input_const[1] * w12 +
input_const[in_w] * w21 +
input_const[in_w + 1] * w22;
out2in_mid = (out_w - 1) * 2;
output_data_tmp[rt] =
w10 * input_const[out2in_mid - 1] + w11 * input_const[out2in_mid] +
w20 * input_const[out2in_mid + in_w - 1] +
w21 * input_const[out2in_mid + in_w] +
(1 - if_pad_r) * (w12 * input_const[out2in_mid + 1] +
w22 * input_const[out2in_mid + in_w + 1]);
out2in_mid = (out_h - 1) * 2 * in_w;
output_data_tmp[lb] =
w01 * input_const[out2in_mid - in_w] +
w02 * input_const[out2in_mid - in_w + 1] +
w11 * input_const[out2in_mid] + w12 * input_const[out2in_mid + 1] +
(1 - if_pad_b) * (w21 * input_const[out2in_mid + in_w] +
w22 * input_const[out2in_mid + in_w + 1]);
out2in_mid = (out_h - 1) * 2 * in_w + (out_w - 1) * 2;
output_data_tmp[rb] =
w00 * input_const[out2in_mid - in_w - 1] +
w01 * input_const[out2in_mid - in_w] +
w10 * input_const[out2in_mid - 1] + w11 * input_const[out2in_mid] +
(1 - if_pad_r) * (w20 * input_const[out2in_mid + in_w - 1] +
w21 * input_const[out2in_mid + in_w]) +
(1 - if_pad_b) * (w02 * input_const[out2in_mid - in_w + 1] +
w12 * input_const[out2in_mid + 1]) +
(1 - if_pad_r) * (1 - if_pad_b) * w22 *
input_const[out2in_mid + in_w + 1];
if (if_bias) {
output_data_tmp[lt] += bias_data[j];
output_data_tmp[rt] += bias_data[j];
output_data_tmp[lb] += bias_data[j];
output_data_tmp[rb] += bias_data[j];
// valid
float32x4_t _result0, _result1, _ext;
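// vld2q_f32 splits eight input floats into even/odd columns, providing the
// stride-2 horizontal sampling; _ext supplies the third tap (offset +2).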
for (int loop = 0; loop < output_w_tiles; ++loop) {
float32x4x2_t _row0 = vld2q_f32(input_ptr0);
float32x4x2_t _row1 = vld2q_f32(input_ptr1);
_ext = vextq_f32(_row0.val[0], _ext, 1);
_ext = vsetq_lane_f32(input_ptr0[8], _ext, 3);
_result0 = vmulq_lane_f32(_row0.val[0], vget_low_f32(_ker[0]), 0);
_result0 =
vmlaq_lane_f32(_result0, _row0.val[1], vget_low_f32(_ker[0]), 1);
_result0 = vmlaq_lane_f32(_result0, _ext, vget_high_f32(_ker[0]), 0);
_ext = vextq_f32(_row1.val[0], _ext, 1);
_ext = vsetq_lane_f32(input_ptr1[8], _ext, 3);
_result0 =
vmlaq_lane_f32(_result0, _row1.val[0], vget_low_f32(_ker[1]), 0);
_result0 =
vmlaq_lane_f32(_result0, _row1.val[1], vget_low_f32(_ker[1]), 1);
_result0 = vmlaq_lane_f32(_result0, _ext, vget_high_f32(_ker[1]), 0);
_row0 = vld2q_f32(input_ptr2);
_row1 = vld2q_f32(input_ptr3);
_ext = vextq_f32(_row0.val[0], _ext, 1);
_ext = vsetq_lane_f32(input_ptr2[8], _ext, 3);
_result0 =
vmlaq_lane_f32(_result0, _row0.val[0], vget_low_f32(_ker[2]), 0);
_result0 =
vmlaq_lane_f32(_result0, _row0.val[1], vget_low_f32(_ker[2]), 1);
_result0 = vmlaq_lane_f32(_result0, _ext, vget_high_f32(_ker[2]), 0);
_result1 = vmulq_lane_f32(_row0.val[0], vget_low_f32(_ker[0]), 0);
_result1 =
vmlaq_lane_f32(_result1, _row0.val[1], vget_low_f32(_ker[0]), 1);
_result1 = vmlaq_lane_f32(_result1, _ext, vget_high_f32(_ker[0]), 0);
_ext = vextq_f32(_row1.val[0], _ext, 1);
_ext = vsetq_lane_f32(input_ptr3[8], _ext, 3);
_result1 =
vmlaq_lane_f32(_result1, _row1.val[0], vget_low_f32(_ker[1]), 0);
_result1 =
vmlaq_lane_f32(_result1, _row1.val[1], vget_low_f32(_ker[1]), 1);
_result1 = vmlaq_lane_f32(_result1, _ext, vget_high_f32(_ker[1]), 0);
_row0 = vld2q_f32(input_ptr4);
_ext = vextq_f32(_row0.val[0], _ext, 1);
_ext = vsetq_lane_f32(input_ptr4[8], _ext, 3);
_result1 =
vmlaq_lane_f32(_result1, _row0.val[0], vget_low_f32(_ker[2]), 0);
_result1 =
vmlaq_lane_f32(_result1, _row0.val[1], vget_low_f32(_ker[2]), 1);
_result1 = vmlaq_lane_f32(_result1, _ext, vget_high_f32(_ker[2]), 0);
vst1q_f32(output_ptr0, _result0);
vst1q_f32(output_ptr1, _result1);
input_ptr0 += 8;
input_ptr1 += 8;
input_ptr2 += 8;
input_ptr3 += 8;
input_ptr4 += 8;
output_ptr0 += 4;
output_ptr1 += 4;
}
if (if_relu) {
output_data_tmp[lt] = output_data_tmp[lt] < 0 ? 0 : output_data_tmp[lt];
output_data_tmp[rt] = output_data_tmp[rt] < 0 ? 0 : output_data_tmp[rt];
output_data_tmp[lb] = output_data_tmp[lb] < 0 ? 0 : output_data_tmp[lb];
output_data_tmp[rb] = output_data_tmp[rb] < 0 ? 0 : output_data_tmp[rb];
}
for (int i = 1; i < out_h - 1; i++) {
out2in_mid = i * 2 * in_w;
int left = i * out_w;
output_data_tmp[left] = w01 * input_const[out2in_mid - in_w] +
w02 * input_const[out2in_mid - in_w + 1] +
w11 * input_const[out2in_mid] +
w12 * input_const[out2in_mid + 1] +
w21 * input_const[out2in_mid + in_w] +
w22 * input_const[out2in_mid + in_w + 1];
out2in_mid = i * 2 * in_w + (out_w - 1) * 2;
int right = i * out_w + out_w - 1;
output_data_tmp[right] =
w00 * input_const[out2in_mid - in_w - 1] +
w01 * input_const[out2in_mid - in_w] +
w10 * input_const[out2in_mid - 1] + w11 * input_const[out2in_mid] +
w20 * input_const[out2in_mid + in_w - 1] +
w21 * input_const[out2in_mid + in_w] +
(1 - if_pad_r) * (w02 * input_const[out2in_mid - in_w + 1] +
w12 * input_const[out2in_mid + 1] +
w22 * input_const[out2in_mid + in_w + 1]);
if (if_bias) {
output_data_tmp[left] += bias_data[j];
output_data_tmp[right] += bias_data[j];
}
if (if_relu) {
output_data_tmp[left] =
output_data_tmp[left] < 0 ? 0 : output_data_tmp[left];
output_data_tmp[right] =
output_data_tmp[right] < 0 ? 0 : output_data_tmp[right];
// remain w
if (output_w_remain > 0) {
float32x4x2_t _row0 = vld2q_f32(input_ptr0);
float32x4x2_t _row1 = vld2q_f32(input_ptr1);
_ext = vextq_f32(_row0.val[0], _ext, 1);
_ext = vsetq_lane_f32(input_ptr0[8], _ext, 3);
_result0 = vmulq_lane_f32(_row0.val[0], vget_low_f32(_ker[0]), 0);
_result0 =
vmlaq_lane_f32(_result0, _row0.val[1], vget_low_f32(_ker[0]), 1);
_result0 = vmlaq_lane_f32(_result0, _ext, vget_high_f32(_ker[0]), 0);
_ext = vextq_f32(_row1.val[0], _ext, 1);
_ext = vsetq_lane_f32(input_ptr1[8], _ext, 3);
_result0 =
vmlaq_lane_f32(_result0, _row1.val[0], vget_low_f32(_ker[1]), 0);
_result0 =
vmlaq_lane_f32(_result0, _row1.val[1], vget_low_f32(_ker[1]), 1);
_result0 = vmlaq_lane_f32(_result0, _ext, vget_high_f32(_ker[1]), 0);
_row0 = vld2q_f32(input_ptr2);
_row1 = vld2q_f32(input_ptr3);
_ext = vextq_f32(_row0.val[0], _ext, 1);
_ext = vsetq_lane_f32(input_ptr2[8], _ext, 3);
_result0 =
vmlaq_lane_f32(_result0, _row0.val[0], vget_low_f32(_ker[2]), 0);
_result0 =
vmlaq_lane_f32(_result0, _row0.val[1], vget_low_f32(_ker[2]), 1);
_result0 = vmlaq_lane_f32(_result0, _ext, vget_high_f32(_ker[2]), 0);
_result1 = vmulq_lane_f32(_row0.val[0], vget_low_f32(_ker[0]), 0);
_result1 =
vmlaq_lane_f32(_result1, _row0.val[1], vget_low_f32(_ker[0]), 1);
_result1 = vmlaq_lane_f32(_result1, _ext, vget_high_f32(_ker[0]), 0);
_ext = vextq_f32(_row1.val[0], _ext, 1);
_ext = vsetq_lane_f32(input_ptr3[8], _ext, 3);
_result1 =
vmlaq_lane_f32(_result1, _row1.val[0], vget_low_f32(_ker[1]), 0);
_result1 =
vmlaq_lane_f32(_result1, _row1.val[1], vget_low_f32(_ker[1]), 1);
_result1 = vmlaq_lane_f32(_result1, _ext, vget_high_f32(_ker[1]), 0);
_row0 = vld2q_f32(input_ptr4);
_ext = vextq_f32(_row0.val[0], _ext, 1);
_ext = vsetq_lane_f32(input_ptr4[8], _ext, 3);
_result1 =
vmlaq_lane_f32(_result1, _row0.val[0], vget_low_f32(_ker[2]), 0);
_result1 =
vmlaq_lane_f32(_result1, _row0.val[1], vget_low_f32(_ker[2]), 1);
_result1 = vmlaq_lane_f32(_result1, _ext, vget_high_f32(_ker[2]), 0);
switch (output_w_remain) {
case 3:
vst1q_lane_f32(output_ptr0 + 2, _result0, 2);
vst1q_lane_f32(output_ptr1 + 2, _result1, 2);
case 2:
vst1_f32(output_ptr0, vget_low_f32(_result0));
vst1_f32(output_ptr1, vget_low_f32(_result1));
break;
case 1:
vst1q_lane_f32(output_ptr0, _result0, 0);
vst1q_lane_f32(output_ptr1, _result1, 0);
break;
}
input_ptr0 += output_w_remain * 2;
input_ptr1 += output_w_remain * 2;
input_ptr2 += output_w_remain * 2;
input_ptr3 += output_w_remain * 2;
input_ptr4 += output_w_remain * 2;
output_ptr0 += output_w_remain;
output_ptr1 += output_w_remain;
}
filter_data_tmp += 9;
}
input_data += inhxw * c;
output_data += outhxw * c;
}
#endif
}
void DepthwiseConvAddBNRelu3x3s2p1v2(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output,
const framework::Tensor *new_scale,
const framework::Tensor *new_bias,
bool if_relu) {
#if __ARM_NEON
// #ifdef _OPENMP
// const float *newscale_data = new_scale->data<float>();
// const float *newbias_data = new_bias->data<float>();
//
// const int batch_size = static_cast<int>(input->dims()[0]);
// const int input_channel = static_cast<int>(input->dims()[1]);
//
// const int input_height = static_cast<int>(input->dims()[2]);
// const int input_width = static_cast<int>(input->dims()[3]);
// const int output_height = static_cast<int>(output->dims()[2]);
// const int output_width = static_cast<int>(output->dims()[3]);
// const int inhxw = input_height * input_width;
// const int outhxw = output_height * output_width;
//
// float32x4_t zero = vdupq_n_f32(0.0);
// for (int b = 0; b < batch_size; b++) {
// #pragma omp parallel for
// for (int c = 0; c < input_channel; c++) {
// const float *filter_data = filter->data<float>() + c * 9;
// const float *input_data = input->data<float>() + c * inhxw;
// float *output_data = output->data<float>() + c * outhxw;
// float32x4_t vnewbias = vdupq_n_f32(newbias_data[c]);
// float32x4_t vnewscale = vdupq_n_f32(newscale_data[c]);
//
// float w00 = filter_data[0];
// float w01 = filter_data[1];
// float w02 = filter_data[2];
// float w10 = filter_data[3];
// float w11 = filter_data[4];
// float w12 = filter_data[5];
// float w20 = filter_data[6];
// float w21 = filter_data[7];
// float w22 = filter_data[8];
//
// int m;
// for (m = 1; m < output_width - 2; m = m + 3) {
// float *output_ptr = output_data + m;
// float32x4x2_t input_buff_mid{}, input_buff_bottom{};
// float32x4_t in0, in1, in2, in3, tmp0, tmp1, tmp2, tmp3, out0;
// input_buff_mid = vld2q_f32(input_data + (2 * m - 1));
// input_buff_bottom = vld2q_f32(input_data + input_width + (2 * m -
// 1));
//
// in0 = input_buff_mid.val[0];
// tmp0 = input_buff_mid.val[1];
// tmp1 = vextq_f32(in0, zero, 1);
//
// in2 = input_buff_bottom.val[0];
// tmp2 = input_buff_bottom.val[1];
// tmp3 = vextq_f32(in2, zero, 1);
//
// out0 = vmulq_n_f32(in0, w10);
// out0 = vmlaq_n_f32(out0, tmp0, w11);
// out0 = vmlaq_n_f32(out0, tmp1, w12);
// out0 = vmlaq_n_f32(out0, in2, w20);
// out0 = vmlaq_n_f32(out0, tmp2, w21);
// out0 = vmlaq_n_f32(out0, tmp3, w22);
// out0 = vmlaq_f32(vnewbias, vnewscale, out0);
// if (if_relu) {
// out0 = vmaxq_f32(out0, zero);
// }
// vst1q_lane_f32(output_ptr, out0, 0);
// vst1q_lane_f32(output_ptr + 1, out0, 1);
// vst1q_lane_f32(output_ptr + 2, out0, 2);
// }
// for (m = 1; m < output_width - 2; m += 3) {
// }
// for (int j = m; j < output_width; j++) {
// output_data[j] = input_data[2 * j - 1] * w10 + input_data[2 * j] *
// w11 +
// input_data[2 * j + 1] * w12 +
// input_data[2 * j - 1 + input_width] * w20 +
// input_data[2 * j + input_width] * w21 +
// input_data[2 * j + 1 + input_width] * w22;
// output_data[j] = newscale_data[c] * output_data[j] +
// newbias_data[c]; if (if_relu) {
// output_data[j] = output_data[j] < 0 ? 0 : output_data[j];
// }
// }
//
// for (int i = 1; i < output_height; i += 1) {
// for (int m = 1; m < output_width - 2; m += 3) {
// float *output_ptr = output_data + i * output_width + m;
// float32x4x2_t input_buff_top{}, input_buff_mid{},
// input_buff_bottom{}; float32x4_t in0, in1, in2, in3, in4, in5,
// tmp0, tmp1, tmp2, tmp3,
// tmp4, tmp5, out0;
// input_buff_top =
// vld2q_f32(input_data + (2 * i - 1) * input_width + (2 * m -
// 1));
// input_buff_mid =
// vld2q_f32(input_data + (2 * i) * input_width + (2 * m - 1));
// input_buff_bottom =
// vld2q_f32(input_data + (2 * i + 1) * input_width + (2 * m -
// 1));
//
// in0 = input_buff_top.val[0];
// tmp0 = input_buff_top.val[1];
// tmp1 = vextq_f32(in0, zero, 1);
//
// in2 = input_buff_mid.val[0];
// tmp2 = input_buff_mid.val[1];
// tmp3 = vextq_f32(in2, zero, 1);
//
// in4 = input_buff_bottom.val[0];
// tmp4 = input_buff_bottom.val[1];
// tmp5 = vextq_f32(in4, zero, 1);
//
// out0 = vmulq_n_f32(in0, w00);
// out0 = vmlaq_n_f32(out0, tmp0, w01);
// out0 = vmlaq_n_f32(out0, tmp1, w02);
// out0 = vmlaq_n_f32(out0, in2, w10);
// out0 = vmlaq_n_f32(out0, tmp2, w11);
// out0 = vmlaq_n_f32(out0, tmp3, w12);
// out0 = vmlaq_n_f32(out0, in4, w20);
// out0 = vmlaq_n_f32(out0, tmp4, w21);
// out0 = vmlaq_n_f32(out0, tmp5, w22);
// out0 = vmlaq_f32(vnewbias, vnewscale, out0);
// if (if_relu) {
// out0 = vmaxq_f32(out0, zero);
// }
// vst1q_lane_f32(output_ptr, out0, 0);
// vst1q_lane_f32(output_ptr + 1, out0, 1);
// vst1q_lane_f32(output_ptr + 2, out0, 2);
// }
// int m;
// for (m = 1; m < output_width - 2; m += 3) {
// }
// for (int j = m; j < output_width; j++) {
// output_data[i * output_width + j] =
// input_data[(2 * i - 1) * input_width + 2 * j - 1] * w00 +
// input_data[(2 * i - 1) * input_width + 2 * j] * w01 +
// input_data[(2 * i - 1) * input_width + 2 * j + 1] * w02 +
// input_data[(2 * i) * input_width + 2 * j - 1] * w10 +
// input_data[(2 * i) * input_width + 2 * j] * w11 +
// input_data[(2 * i) * input_width + 2 * j + 1] * w12 +
// input_data[(2 * i + 1) * input_width + 2 * j - 1] * w20 +
// input_data[(2 * i + 1) * input_width + 2 * j] * w21 +
// input_data[(2 * i + 1) * input_width + 2 * j + 1] * w22;
// output_data[i * output_width + j] =
// newscale_data[c] * output_data[i * output_width + j] +
// newbias_data[c];
// if (if_relu) {
// output_data[i * output_width + j] =
// output_data[i * output_width + j] < 0
// ? 0
// : output_data[i * output_width + j];
// }
// }
// }
// output_data[0] = input_data[0] * w11 + input_data[1] * w12 +
// input_data[input_height] * w21 +
// input_data[input_height + 1] * w22;
//
// output_data[0] = newscale_data[c] * output_data[0] + newbias_data[c];
// if (if_relu) {
// output_data[0] = output_data[0] < 0 ? 0 : output_data[0];
// }
// for (int i = 1; i < output_height; i++) {
// output_data[i * output_width] =
// input_data[(2 * i - 1) * input_width] * w01 +
// input_data[(2 * i - 1) * input_width + 1] * w02 +
// input_data[(2 * i) * input_width] * w11 +
// input_data[(2 * i) * input_width + 1] * w12 +
// input_data[(2 * i + 1) * input_width] * w21 +
// input_data[(2 * i + 1) * input_width + 1] * w22;
//
// output_data[i * output_width] =
// newscale_data[c] * output_data[i * output_width] +
// newbias_data[c];
// if (if_relu) {
// output_data[i * output_width] = output_data[i * output_width] < 0
// ? 0
// : output_data[i *
// output_width];
// }
// }
// }
// }
//
// #else
const float *input_data = input->data<float>();
const float *filter_data = filter->data<float>();
float *output_data = output->mutable_data<float>();
const float *newscale_data = new_scale->data<float>();
const float *newbias_data = new_bias->data<float>();
const int in_h = static_cast<int>(input->dims()[2]);
const int in_w = static_cast<int>(input->dims()[3]);
const int out_h = static_cast<int>(output->dims()[2]);
const int out_w = static_cast<int>(output->dims()[3]);
// const int out_l = out_h;
// const int in_l = in_h;
const int inhxw = in_h * in_w;
const int outhxw = out_h * out_w;
/// TODO: fix if_pad when w != h
const int if_pad_r = in_w - 1 == (out_w - 1) * 2 ? 1 : 0;
const int if_pad_b = in_h - 1 == (out_h - 1) * 2 ? 1 : 0;
const int batch_size = static_cast<int>(input->dims()[0]);
const int c = static_cast<int>(input->dims()[1]);
const int w_times = (out_w - 2) / 3;
float32x4_t zero = vdupq_n_f32(0.0);
for (int b = batch_size; b > 0; --b) {
#pragma omp parallel for
for (int j = 0; j < c; j++) {
const float *input_row_ptr;
float *output_row_ptr;
float32x4x2_t input_buff_mid{}, input_buff_bottom[w_times + 1];
float32x4_t elewise_res0, elewise_res1, elewise_res2, res3;
int out2in_mid;
float32x4_t vnewbias = vdupq_n_f32(0.0);
float32x4_t vnewscale = vdupq_n_f32(1.0);
auto output_data_tmp = output_data + j * out_h * out_w;
auto input_data_tmp = input_data + j * in_h * in_w;
auto input_const = input_data_tmp;
const float *filter_data_tmp = filter_data + 9 * j;
vnewbias = vdupq_n_f32(newbias_data[j]);
vnewscale = vdupq_n_f32(newscale_data[j]);
float w00 = filter_data_tmp[0];
float w01 = filter_data_tmp[1];
float w02 = filter_data_tmp[2];
float w10 = filter_data_tmp[3];
float w11 = filter_data_tmp[4];
float w12 = filter_data_tmp[5];
float w20 = filter_data_tmp[6];
float w21 = filter_data_tmp[7];
float w22 = filter_data_tmp[8];
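// Same stride-2 scheme as DepthwiseConv3x3s2p1v2, with the batch-norm scale
// and bias fused in via vmlaq_f32(vnewbias, vnewscale, ...).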
int h_mid = 0;
for (; h_mid < out_h - 1; h_mid++) {
input_row_ptr = input_data_tmp + 1 + h_mid * 2 * in_w;
output_row_ptr = output_data_tmp + 1 + h_mid * out_w;
for (int w4 = 0; w4 < w_times + 1; w4++) {
if (h_mid == 0) {
elewise_res1 = zero;
elewise_res0 = zero;
elewise_res2 = zero;
// pad right
if (padding_w > 0) {
float32x4_t row0 = vld1q_f32(input_ptr0);
float32x4_t row1 = vld1q_f32(input_ptr1);
float32x4_t row2 = vld1q_f32(input_ptr2);
float32x4_t row3 = vld1q_f32(input_ptr3);
float32x4_t row4 = vld1q_f32(input_ptr4);
float32x4_t acc0, acc1;
for (int w = valid_w_end; w < output_w; ++w) {
int padding = 2 * w + 3 - (padding_w + input_w);
if (padding >= 3) {
*output_ptr0 = 0;
*output_ptr1 = 0;
} else {
elewise_res1 = vmulq_n_f32(input_buff_bottom[w4].val[1], w01);
elewise_res0 = vmulq_n_f32(input_buff_bottom[w4].val[0], w00);
elewise_res2 = vmulq_n_f32(input_buff_bottom[w4].val[0], w02);
}
input_buff_mid = vld2q_f32(input_row_ptr);
input_buff_bottom[w4] = vld2q_f32(input_row_ptr + in_w);
elewise_res1 = vmlaq_n_f32(elewise_res1, input_buff_mid.val[1], w11);
elewise_res0 = vmlaq_n_f32(elewise_res0, input_buff_mid.val[0], w10);
elewise_res2 = vmlaq_n_f32(elewise_res2, input_buff_mid.val[0], w12);
elewise_res1 =
vmlaq_n_f32(elewise_res1, input_buff_bottom[w4].val[1], w21);
elewise_res0 =
vmlaq_n_f32(elewise_res0, input_buff_bottom[w4].val[0], w20);
elewise_res2 =
vmlaq_n_f32(elewise_res2, input_buff_bottom[w4].val[0], w22);
res3 = vaddq_f32(vextq_f32(elewise_res2, zero, 1),
vaddq_f32(elewise_res0, elewise_res1));
res3 = vmlaq_f32(vnewbias, vnewscale, res3);
if (if_relu) {
res3 = vmaxq_f32(res3, zero);
acc0 = vmulq_f32(row0, _ker[0]);
acc1 = vmulq_f32(row2, _ker[0]);
acc0 = vmlaq_f32(acc0, row1, _ker[1]);
acc1 = vmlaq_f32(acc1, row3, _ker[1]);
acc0 = vmlaq_f32(acc0, row2, _ker[2]);
acc1 = vmlaq_f32(acc1, row4, _ker[2]);
float sum0 = vgetq_lane_f32(acc0, 0);
float sum1 = vgetq_lane_f32(acc1, 0);
if (padding == 1) {
sum0 += vgetq_lane_f32(acc0, 1);
sum1 += vgetq_lane_f32(acc1, 1);
}
*output_ptr0 = sum0;
*output_ptr1 = sum1;
}
vst1q_lane_f32(output_row_ptr, res3, 0);
vst1q_lane_f32(output_row_ptr + 1, res3, 1);
vst1q_lane_f32(output_row_ptr + 2, res3, 2);
input_row_ptr += 6;
output_row_ptr += 3;
output_ptr0++;
output_ptr1++;
}
}
clock();
input_row_ptr = input_data_tmp + 1 + h_mid * 2 * in_w;
output_row_ptr = output_data_tmp + 1 + h_mid * out_w;
for (int w4 = 0; w4 < w_times + 1; w4++) {
elewise_res1 = vmulq_n_f32(input_buff_bottom[w4].val[1], w01);
elewise_res0 = vmulq_n_f32(input_buff_bottom[w4].val[0], w00);
elewise_res2 = vmulq_n_f32(input_buff_bottom[w4].val[0], w02);
input_buff_mid = vld2q_f32(input_row_ptr);
input_buff_bottom[w4] = vld2q_f32(input_row_ptr + in_w);
elewise_res1 = vmlaq_n_f32(elewise_res1, input_buff_mid.val[1], w11);
elewise_res0 = vmlaq_n_f32(elewise_res0, input_buff_mid.val[0], w10);
elewise_res2 = vmlaq_n_f32(elewise_res2, input_buff_mid.val[0], w12);
if (!if_pad_b) {
elewise_res1 =
vmlaq_n_f32(elewise_res1, input_buff_bottom[w4].val[1], w21);
elewise_res0 =
vmlaq_n_f32(elewise_res0, input_buff_bottom[w4].val[0], w20);
elewise_res2 =
vmlaq_n_f32(elewise_res2, input_buff_bottom[w4].val[0], w22);
}
res3 = vaddq_f32(vextq_f32(elewise_res2, zero, 1),
vaddq_f32(elewise_res0, elewise_res1));
res3 = vmlaq_f32(vnewbias, vnewscale, res3);
if (if_relu) {
res3 = vmaxq_f32(res3, zero);
}
if ((w4 != w_times)) {
vst1q_lane_f32(output_row_ptr, res3, 0);
vst1q_lane_f32(output_row_ptr + 1, res3, 1);
vst1q_lane_f32(output_row_ptr + 2, res3, 2);
} else {
if (out_w - 2 - w_times * 3 == 1) {
vst1q_lane_f32(output_row_ptr, res3, 0);
} else if (out_w - 2 - w_times * 3 == 2) {
vst1q_lane_f32(output_row_ptr, res3, 0);
vst1q_lane_f32(output_row_ptr + 1, res3, 1);
}
// remain height
int start_h = valid_h_start + (valid_h & 0xfffe);
if (start_h < valid_h_end) {
const float *input_ptr0 = input_ptr + (2 * start_h - padding_h) * input_w;
const float *input_ptr1 = input_ptr0 + input_w;
const float *input_ptr2 = input_ptr1 + input_w;
float *output_ptr0 = output_ptr + start_h * output_w;
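// Last odd valid output row: computed alone from three input rows.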
// pad left
if (padding_w) {
for (int w = valid_w_start - 1; w >= 0; --w) {
int padding = padding_w - (w << 1);
if (padding >= 3) {
output_ptr0[w] = 0;
} else {
float32x4_t row0 = vld1q_f32(input_ptr0 - padding);
float32x4_t row1 = vld1q_f32(input_ptr1 - padding);
float32x4_t row2 = vld1q_f32(input_ptr2 - padding);
float32x4_t acc0 = vmulq_f32(row0, _ker[0]);
acc0 = vmlaq_f32(acc0, row1, _ker[1]);
acc0 = vmlaq_f32(acc0, row2, _ker[2]);
float sum0 = vgetq_lane_f32(acc0, 2);
if (padding == 1) {
sum0 += vgetq_lane_f32(acc0, 1);
}
output_ptr0[w] = sum0;
}
}
input_row_ptr += 6;
output_row_ptr += 3;
input_ptr0 += input_w_start;
input_ptr1 += input_w_start;
input_ptr2 += input_w_start;
output_ptr0 += valid_w_start;
}
output_data_tmp[0] = input_const[0] * w11 + input_const[1] * w12 +
input_const[in_w] * w21 +
input_const[in_w + 1] * w22;
out2in_mid = (out_w - 1) * 2;
output_data_tmp[out_w - 1] =
w10 * input_const[out2in_mid - 1] + w11 * input_const[out2in_mid] +
w20 * input_const[out2in_mid + in_w - 1] +
w21 * input_const[out2in_mid + in_w] +
(1 - if_pad_r) * (w12 * input_const[out2in_mid + 1] +
w22 * input_const[out2in_mid + in_w + 1]);
out2in_mid = (out_h - 1) * 2 * in_w;
output_data_tmp[out_w * (out_h - 1)] =
w01 * input_const[out2in_mid - in_w] +
w02 * input_const[out2in_mid - in_w + 1] +
w11 * input_const[out2in_mid] + w12 * input_const[out2in_mid + 1] +
(1 - if_pad_b) * (w21 * input_const[out2in_mid + in_w] +
w22 * input_const[out2in_mid + in_w + 1]);
out2in_mid = (out_h - 1) * 2 * in_w + (out_w - 1) * 2;
output_data_tmp[out_h * out_w - 1] =
w00 * input_const[out2in_mid - in_w - 1] +
w01 * input_const[out2in_mid - in_w] +
w10 * input_const[out2in_mid - 1] + w11 * input_const[out2in_mid] +
(1 - if_pad_r) * (w20 * input_const[out2in_mid + in_w - 1] +
w21 * input_const[out2in_mid + in_w]) +
(1 - if_pad_b) * (w02 * input_const[out2in_mid - in_w + 1] +
w12 * input_const[out2in_mid + 1]) +
(1 - if_pad_r) * (1 - if_pad_b) * w22 *
input_const[out2in_mid + in_w + 1];
output_data_tmp[0] =
output_data_tmp[0] * newscale_data[j] + newbias_data[j];
output_data_tmp[out_w - 1] =
output_data_tmp[out_w - 1] * newscale_data[j] + newbias_data[j];
output_data_tmp[out_w * (out_h - 1)] =
output_data_tmp[out_w * (out_h - 1)] * newscale_data[j] +
newbias_data[j];
output_data_tmp[out_h * out_w - 1] =
output_data_tmp[out_h * out_w - 1] * newscale_data[j] +
newbias_data[j];
if (if_relu) {
output_data_tmp[0] = output_data_tmp[0] < 0 ? 0 : output_data_tmp[0];
output_data_tmp[out_w - 1] =
output_data_tmp[out_w - 1] < 0 ? 0 : output_data_tmp[out_w - 1];
output_data_tmp[out_w * (out_h - 1)] =
output_data_tmp[out_w * (out_h - 1)] < 0
? 0
: output_data_tmp[out_w * (out_h - 1)];
output_data_tmp[out_h * out_w - 1] =
output_data_tmp[out_h * out_w - 1] < 0
? 0
: output_data_tmp[out_h * out_w - 1];
// valid
float32x4_t _result0, _ext;
for (int loop = 0; loop < output_w_tiles; ++loop) {
float32x4x2_t _row0 = vld2q_f32(input_ptr0);
float32x4x2_t _row1 = vld2q_f32(input_ptr1);
float32x4x2_t _row2 = vld2q_f32(input_ptr2);
_ext = vextq_f32(_row0.val[0], _ext, 1);
_ext = vsetq_lane_f32(input_ptr0[8], _ext, 3);
_result0 = vmulq_lane_f32(_row0.val[0], vget_low_f32(_ker[0]), 0);
_result0 =
vmlaq_lane_f32(_result0, _row0.val[1], vget_low_f32(_ker[0]), 1);
_result0 = vmlaq_lane_f32(_result0, _ext, vget_high_f32(_ker[0]), 0);
_ext = vextq_f32(_row1.val[0], _ext, 1);
_ext = vsetq_lane_f32(input_ptr1[8], _ext, 3);
_result0 =
vmlaq_lane_f32(_result0, _row1.val[0], vget_low_f32(_ker[1]), 0);
_result0 =
vmlaq_lane_f32(_result0, _row1.val[1], vget_low_f32(_ker[1]), 1);
_result0 = vmlaq_lane_f32(_result0, _ext, vget_high_f32(_ker[1]), 0);
_ext = vextq_f32(_row2.val[0], _ext, 1);
_ext = vsetq_lane_f32(input_ptr2[8], _ext, 3);
_result0 =
vmlaq_lane_f32(_result0, _row2.val[0], vget_low_f32(_ker[2]), 0);
_result0 =
vmlaq_lane_f32(_result0, _row2.val[1], vget_low_f32(_ker[2]), 1);
_result0 = vmlaq_lane_f32(_result0, _ext, vget_high_f32(_ker[2]), 0);
vst1q_f32(output_ptr0, _result0);
input_ptr0 += 8;
input_ptr1 += 8;
input_ptr2 += 8;
output_ptr0 += 4;
}
for (int i = 1; i < out_h - 1; i++) {
out2in_mid = i * 2 * in_w;
output_data_tmp[i * out_w] = w01 * input_const[out2in_mid - in_w] +
w02 * input_const[out2in_mid - in_w + 1] +
w11 * input_const[out2in_mid] +
w12 * input_const[out2in_mid + 1] +
w21 * input_const[out2in_mid + in_w] +
w22 * input_const[out2in_mid + in_w + 1];
out2in_mid = i * 2 * in_w + (out_w - 1) * 2;
output_data_tmp[i * out_w + out_w - 1] =
w00 * input_const[out2in_mid - in_w - 1] +
w01 * input_const[out2in_mid - in_w] +
w10 * input_const[out2in_mid - 1] + w11 * input_const[out2in_mid] +
w20 * input_const[out2in_mid + in_w - 1] +
w21 * input_const[out2in_mid + in_w] +
(1 - if_pad_r) * (w02 * input_const[out2in_mid - in_w + 1] +
w12 * input_const[out2in_mid + 1] +
w22 * input_const[out2in_mid + in_w + 1]);
output_data_tmp[i * out_w] =
output_data_tmp[i * out_w] * newscale_data[j] + newbias_data[j];
output_data_tmp[i * out_w + out_w - 1] =
output_data_tmp[i * out_w + out_w - 1] * newscale_data[j] +
newbias_data[j];
if (if_relu) {
output_data_tmp[i * out_w] =
output_data_tmp[i * out_w] < 0 ? 0 : output_data_tmp[i * out_w];
output_data_tmp[i * out_w + out_w - 1] =
output_data_tmp[i * out_w + out_w - 1] < 0
? 0
: output_data_tmp[i * out_w + out_w - 1];
// remain w
if (output_w_remain > 0) {
float32x4x2_t _row0 = vld2q_f32(input_ptr0);
float32x4x2_t _row1 = vld2q_f32(input_ptr1);
float32x4x2_t _row2 = vld2q_f32(input_ptr2);
_ext = vextq_f32(_row0.val[0], _ext, 1);
_ext = vsetq_lane_f32(input_ptr0[8], _ext, 3);
_result0 = vmulq_lane_f32(_row0.val[0], vget_low_f32(_ker[0]), 0);
_result0 =
vmlaq_lane_f32(_result0, _row0.val[1], vget_low_f32(_ker[0]), 1);
_result0 = vmlaq_lane_f32(_result0, _ext, vget_high_f32(_ker[0]), 0);
_ext = vextq_f32(_row1.val[0], _ext, 1);
_ext = vsetq_lane_f32(input_ptr1[8], _ext, 3);
_result0 =
vmlaq_lane_f32(_result0, _row1.val[0], vget_low_f32(_ker[1]), 0);
_result0 =
vmlaq_lane_f32(_result0, _row1.val[1], vget_low_f32(_ker[1]), 1);
_result0 = vmlaq_lane_f32(_result0, _ext, vget_high_f32(_ker[1]), 0);
_ext = vextq_f32(_row2.val[0], _ext, 1);
_ext = vsetq_lane_f32(input_ptr2[8], _ext, 3);
_result0 =
vmlaq_lane_f32(_result0, _row2.val[0], vget_low_f32(_ker[2]), 0);
_result0 =
vmlaq_lane_f32(_result0, _row2.val[1], vget_low_f32(_ker[2]), 1);
_result0 = vmlaq_lane_f32(_result0, _ext, vget_high_f32(_ker[2]), 0);
switch (output_w_remain) {
case 3:
vst1q_lane_f32(output_ptr0 + 2, _result0, 2);
case 2:
vst1_f32(output_ptr0, vget_low_f32(_result0));
break;
case 1:
vst1q_lane_f32(output_ptr0, _result0, 0);
break;
}
input_ptr0 += output_w_remain * 2;
input_ptr1 += output_w_remain * 2;
input_ptr2 += output_w_remain * 2;
output_ptr0 += output_w_remain;
}
}
input_data += inhxw * c;
output_data += outhxw * c;
}
// #endif
#endif
}
void DepthwiseConv3x3s2p0(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output, framework::Tensor *bias,
bool if_bias, bool if_relu) {
#if __ARM_NEON
const int batch_size = static_cast<int>(input->dims()[0]);
const int input_channel = static_cast<int>(input->dims()[1]);
const int input_height = static_cast<int>(input->dims()[2]);
const int input_width = static_cast<int>(input->dims()[3]);
const int output_height = static_cast<int>(output->dims()[2]);
const int output_width = static_cast<int>(output->dims()[3]);
const int inhxw = input_height * input_width;
const int outhxw = output_height * output_width;
output->mutable_data<float>();
float32x4_t zero = vdupq_n_f32(0.0);
for (int b = 0; b < batch_size; b++) {
#pragma omp parallel for
for (int c = 0; c < input_channel; c++) {
const float *filter_data = filter->data<float>() + c * 9;
const float *input_data = input->data<float>() + c * inhxw;
const float *bias_data;
float32x4_t biasv;
if (if_bias) {
bias_data = bias->data<float>() + c;
biasv = vld1q_dup_f32(bias_data);
}
float *output_data = output->data<float>() + c * outhxw;
float w00 = filter_data[0];
float w01 = filter_data[1];
float w02 = filter_data[2];
float w10 = filter_data[3];
float w11 = filter_data[4];
float w12 = filter_data[5];
float w20 = filter_data[6];
float w21 = filter_data[7];
float w22 = filter_data[8];
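// Main loop: vld2q_f32 realizes the stride-2 column sampling and emits three
// outputs per iteration; the scalar loop below handles the remaining columns.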
for (int i = 0; i < output_height; i += 1) {
for (int m = 0; m < output_width - 2; m += 3) {
float *output_ptr = output_data + i * output_width + m;
float32x4x2_t input_buff_top{}, input_buff_mid{}, input_buff_bottom{};
float32x4_t in0, in1, in2, in3, in4, in5, tmp0, tmp1, tmp2, tmp3,
tmp4, tmp5, out0;
input_buff_top =
vld2q_f32(input_data + (2 * i) * input_width + (2 * m));
input_buff_mid =
vld2q_f32(input_data + (2 * i + 1) * input_width + (2 * m));
input_buff_bottom =
vld2q_f32(input_data + (2 * i + 2) * input_width + (2 * m));
in0 = input_buff_top.val[0];
tmp0 = input_buff_top.val[1];
tmp1 = vextq_f32(in0, zero, 1);
in2 = input_buff_mid.val[0];
tmp2 = input_buff_mid.val[1];
tmp3 = vextq_f32(in2, zero, 1);
in4 = input_buff_bottom.val[0];
tmp4 = input_buff_bottom.val[1];
tmp5 = vextq_f32(in4, zero, 1);
out0 = vmulq_n_f32(in0, w00);
out0 = vmlaq_n_f32(out0, tmp0, w01);
out0 = vmlaq_n_f32(out0, tmp1, w02);
out0 = vmlaq_n_f32(out0, in2, w10);
out0 = vmlaq_n_f32(out0, tmp2, w11);
out0 = vmlaq_n_f32(out0, tmp3, w12);
out0 = vmlaq_n_f32(out0, in4, w20);
out0 = vmlaq_n_f32(out0, tmp4, w21);
out0 = vmlaq_n_f32(out0, tmp5, w22);
if (if_bias) {
out0 = vaddq_f32(out0, biasv);
}
if (if_relu) {
out0 = vmaxq_f32(out0, zero);
}
vst1q_lane_f32(output_ptr, out0, 0);
vst1q_lane_f32(output_ptr + 1, out0, 1);
vst1q_lane_f32(output_ptr + 2, out0, 2);
}
int m;
for (m = 0; m < output_width - 2; m += 3) {
}
for (int j = m; j < output_width; j++) {
int index = i * output_width + j;
output_data[index] =
input_data[(2 * i) * input_width + 2 * j] * w00 +
input_data[(2 * i) * input_width + 2 * j + 1] * w01 +
input_data[(2 * i) * input_width + 2 * j + 2] * w02 +
input_data[(2 * i + 1) * input_width + 2 * j] * w10 +
input_data[(2 * i + 1) * input_width + 2 * j + 1] * w11 +
input_data[(2 * i + 1) * input_width + 2 * j + 2] * w12 +
input_data[(2 * i + 2) * input_width + 2 * j] * w20 +
input_data[(2 * i + 2) * input_width + 2 * j + 1] * w21 +
input_data[(2 * i + 2) * input_width + 2 * j + 2] * w22;
if (if_bias) {
output_data[index] += *bias_data;
}
if (if_relu) {
output_data[index] =
output_data[index] < 0 ? 0 : output_data[index];
......
// pad right
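// 'padding' counts how many of the three filter taps fall past the right
// edge for output column w: three or more means the whole window lies in
// the zero padding; padding == 1 leaves two valid columns, so lane 1 of the
// accumulator is added to lane 0 as well.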
if (padding_w) {
float32x4_t row0 = vld1q_f32(input_ptr0);
float32x4_t row1 = vld1q_f32(input_ptr1);
float32x4_t row2 = vld1q_f32(input_ptr2);
float32x4_t acc0;
for (int w = valid_w_end; w < output_w; ++w) {
int padding = 2 * w + 3 - (padding_w + input_w);
if (padding >= 3) {
*output_ptr0 = 0;
} else {
acc0 = vmulq_f32(row0, _ker[0]);
acc0 = vmlaq_f32(acc0, row1, _ker[1]);
acc0 = vmlaq_f32(acc0, row2, _ker[2]);
float sum0 = vgetq_lane_f32(acc0, 0);
if (padding == 1) {
sum0 += vgetq_lane_f32(acc0, 1);
}
*output_ptr0 = sum0;
}
output_ptr0++;
}
}
}
// pad bottom
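// output rows whose input window reaches into the bottom padding fall back
// to the generic per-row kernel, which handles the boundary explicitly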
for (int h = valid_h_end; h < output_h; ++h) {
DepthwiseConv3x3NormalRow<2, 2>(input_ptr, filter_ptr, h, input_h,
input_w, padding_h, padding_w, output_w,
output_ptr, _ker);
}
}
#endif
}
} // namespace math
} // namespace operators
} // namespace paddle_mobile
#endif // __ARM_NEON__
......@@ -23,48 +23,6 @@ namespace paddle_mobile {
namespace operators {
namespace math {
void DepthwiseConv3x3(const framework::Tensor *input,
const std::vector<int> &strides,
const std::vector<int> &paddings,
const framework::Tensor *filter, framework::Tensor *bias,
framework::Tensor *output, bool if_bias);
void DepthwiseConv3x3s1p1(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output, framework::Tensor *bias,
bool if_bias, bool if_relu);
void DepthwiseConvAddBNRelu3x3s1p1(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output,
const framework::Tensor *new_scale,
const framework::Tensor *new_bias,
bool if_relu);
void DepthwiseConvAddBNRelu3x3s2p1(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output,
const framework::Tensor *new_scale,
const framework::Tensor *new_bias,
bool if_relu);
void DepthwiseConv3x3s2p1v2(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output, framework::Tensor *bias,
bool if_bias, bool if_relu);
void DepthwiseConvAddBNRelu3x3s2p1v2(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output,
const framework::Tensor *new_scale,
const framework::Tensor *new_bias,
bool if_relu);
void DepthwiseConv3x3s2p0(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output, framework::Tensor *bias,
bool if_bias, bool if_relu);
// TODO(hjchen2) need to be implemented
// template<typename Itype, typename Otype>
// void DepthwiseConv3x3(const framework::Tensor *input,
......
......@@ -31,16 +31,14 @@ void cblas_sgemm(const bool transA, const bool transB, const int M, const int N,
// return cblas_sgemv(transA, M, K, alpha, A, lda, B, beta, C);
// }
CPUInfo *info = CPUInfo::Info();
GemmExecutor<SgemmStrategy> exec(info, transA, transB, M, N, K);
GemmExecutor<SgemmStrategy> exec(transA, transB, M, N, K);
exec(alpha, A, lda, B, ldb, beta, C, ldc);
}
void cblas_sgemv(const bool trans, const int M, const int N, const float alpha,
const float *A, const int lda, const float *B,
const float beta, float *C) {
CPUInfo *info = CPUInfo::Info();
GemvExecutor<SgemvStrategy> exec(info, trans, M, N);
GemvExecutor<SgemvStrategy> exec(trans, M, N);
exec(alpha, A, lda, B, beta, C);
}
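// Usage sketch (illustrative only; assumes the remaining cblas_sgemm
// parameters follow the usual (K, alpha, A, lda, B, ldb, beta, C, ldc)
// order and row-major storage). Both wrappers now pick up CPU cache info
// internally, so callers pass only the problem size and operand buffers:
//
//   cblas_sgemm(false, false, M, N, K, 1.f, A, K, B, N, 0.f, C, N);  // C = A * B
//   cblas_sgemv(false, M, N, 1.f, A, N, x, 0.f, y);                  // y = A * x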
......
......@@ -19,7 +19,6 @@ limitations under the License. */
#include <omp.h>
#endif
#include <sys/time.h>
#include <iostream>
#include "common/log.h"
#include "memory/t_malloc.h"
#include "operators/math/gemm/cpu_info.h"
......@@ -29,6 +28,8 @@ namespace paddle_mobile {
namespace operators {
namespace math {
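// CPU cache information is now read once through this file-scope singleton;
// GemmExecutor / GemvExecutor below no longer take a CPUInfo* themselves.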
static CPUInfo *info = CPUInfo::Info();
int CeilDiv(const int &x, const int &y) { return (x + y - 1) / y; }
unsigned int ResetL1Cache(const unsigned int L1_size, const int thread_num,
const int N, const int K) {
......@@ -62,15 +63,9 @@ class GemmExecutor : public Executor {
typedef typename Strategy::Otype Otype;
public:
GemmExecutor(const CPUInfo *info, const bool transA, const bool transB,
const int M, const int N, const int K)
: Executor(),
info_(info),
transA_(transA),
transB_(transB),
M_(M),
N_(N),
K_(K) {
GemmExecutor(const bool transA, const bool transB, const int M, const int N,
const int K)
: Executor(), transA_(transA), transB_(transB), M_(M), N_(N), K_(K) {
unsigned int L1_size = 0;
unsigned int L2_size = 0;
if (M_ > N_) {
......@@ -212,8 +207,6 @@ class GemmExecutor : public Executor {
virtual ~GemmExecutor() {}
private:
const CPUInfo *info_;
const unsigned int M_;
const unsigned int N_;
const unsigned int K_;
......@@ -242,8 +235,8 @@ class GemvExecutor : public Executor {
typedef typename Strategy::Otype Otype;
public:
GemvExecutor(const CPUInfo *info, const bool transA, const int M, const int N)
: Executor(), info_(info), M_(M), N_(N) {}
GemvExecutor(const bool transA, const int M, const int N)
: Executor(), M_(M), N_(N) {}
void operator()(const float alpha, const Itype *A, const int lda,
const Itype *B, const float beta, Otype *C) {
......@@ -253,8 +246,6 @@ class GemvExecutor : public Executor {
virtual ~GemvExecutor() {}
private:
const CPUInfo *const info_;
const unsigned int M_;
const unsigned int N_;
......
......@@ -44,7 +44,17 @@ void ExtractToImg(const float *im_data, float *col_data, const int im_height,
for (int i = start_height; i < end_height; i += stride_h) {
if (stride_w == 1) {
memcpy(col_data, im_data, extract * sizeof(float));
// memcpy(col_data, im_data, extract * sizeof(float));
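// NEON path: copy four floats per iteration; the scalar loop afterwards
// handles the remaining (extract % 4) elements.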
int s = 0;
#if __ARM_NEON
for (; s < extract - 3; s += 4) {
float32x4_t img = vld1q_f32(im_data + s);
vst1q_f32(col_data + s, img);
}
#endif
for (; s < extract; ++s) {
col_data[s] = im_data[s];
}
} else if (stride_w == 2) {
int s = 0;
#if __ARM_NEON
......@@ -109,325 +119,7 @@ void Im2ColFunctor<ColFormat::kCFO, CPU, float>::operator()(
const float *im_data = im.data<float>();
float *col_data = col->data<float>();
#if __ARM_NEON
const int osize = col_height;
const int isize = im_height;
bool pad1 = padding[0] > 0;
bool pad2 =
(pad1 && padding[1] &&
(((isize - 2 * padding[0] + filter_height) % stride[0] == 0) ? 1 : 0));
int fill = isize % 2;
if (stride[0] == 1 && filter_height == 3 && pad1 && pad2 &&
dilation[0] == 1 && im_height > 2 && im_height == im_width) {
for (int c = 0; c < im_channels; ++c) {
int oosize = osize * osize;
int nk4 = osize / 4;
int mk4 = osize % 4;
float *col0 = col_data + 0 * oosize + 2 * osize + 2;
float *col1 = col_data + 1 * oosize + 2 * osize + 1;
float *col2 = col_data + 2 * oosize + 2 * osize;
float *col3 = col_data + 3 * oosize + osize + 2;
float *col4 = col_data + 4 * oosize + osize + 1;
float *col5 = col_data + 5 * oosize + osize;
float *col6 = col_data + 6 * oosize + 2;
float *col7 = col_data + 7 * oosize + 1;
float *col8 = col_data + 8 * oosize;
float32x4_t im1;
const float *im_tmp_data = im_data + osize + 1;
int rrsize = oosize - osize - 1;
int nr4 = rrsize / 4;
int mr4 = rrsize % 4;
for (int i = 0; i < nr4; ++i) {
im1 = vld1q_f32(im_tmp_data);
vst1q_f32(col0, im1);
vst1q_f32(col1, im1);
vst1q_f32(col2, im1);
vst1q_f32(col3, im1);
vst1q_f32(col4, im1);
vst1q_f32(col5, im1);
vst1q_f32(col6, im1);
vst1q_f32(col7, im1);
vst1q_f32(col8, im1);
col0 += 4;
col1 += 4;
col2 += 4;
col3 += 4;
col4 += 4;
col5 += 4;
col6 += 4;
col7 += 4;
col8 += 4;
im_tmp_data += 4;
}
for (int i = 0; i < mr4; ++i) {
*col0 = *im_tmp_data;
*col1 = *im_tmp_data;
*col2 = *im_tmp_data;
*col3 = *im_tmp_data;
*col4 = *im_tmp_data;
*col5 = *im_tmp_data;
*col6 = *im_tmp_data;
*col7 = *im_tmp_data;
*col8 = *im_tmp_data;
col0++;
col1++;
col2++;
col3++;
col4++;
col5++;
col6++;
col7++;
col8++;
im_tmp_data++;
}
im_tmp_data = im_data + 1;
col0 = col_data + 0 * oosize + osize + 2;
col1 = col_data + 1 * oosize + osize + 1;
col2 = col_data + 2 * oosize + osize;
col3 = col_data + 3 * oosize + 2;
col4 = col_data + 4 * oosize + 1;
col5 = col_data + 5 * oosize;
for (int i = 0; i < nk4; i++) {
im1 = vld1q_f32(im_tmp_data);
vst1q_f32(col0, im1);
vst1q_f32(col1, im1);
vst1q_f32(col2, im1);
vst1q_f32(col3, im1);
vst1q_f32(col4, im1);
vst1q_f32(col5, im1);
col0 += 4;
col1 += 4;
col2 += 4;
col3 += 4;
col4 += 4;
col5 += 4;
im_tmp_data += 4;
}
for (int i = 0; i < mk4; i++) {
*col0 = *im_tmp_data;
*col1 = *im_tmp_data;
*col2 = *im_tmp_data;
*col3 = *im_tmp_data;
*col4 = *im_tmp_data;
*col5 = *im_tmp_data;
col0++;
col1++;
col2++;
col3++;
col4++;
col5++;
im_tmp_data++;
}
// fill 0 1 11;
for (int i = 0; i < osize; ++i) {
col_data[0 * oosize + i * osize] = 0.0;
col_data[3 * oosize + i * osize] = 0.0;
col_data[6 * oosize + i * osize] = 0.0;
col_data[2 * oosize + osize - 1 + i * osize] = 0.0;
col_data[5 * oosize + osize - 1 + i * osize] = 0.0;
col_data[8 * oosize + osize - 1 + i * osize] = 0.0;
}
col_data[0 * oosize + osize + 1] = im_data[0];
col_data[3 * oosize + 1] = im_data[0];
col_data[6 * oosize + 1] = im_data[osize];
col_data[1 * oosize + osize] = im_data[0];
col_data[4 * oosize] = im_data[0];
col_data[7 * oosize] = im_data[osize];
float32x4_t zero4;
zero4 = vdupq_n_f32(0.0);
auto col_z0 = col_data;
auto col_z1 = col_data + oosize;
auto col_z2 = col_data + 2 * oosize;
auto col_z6 = col_data + 6 * oosize + osize * (osize - 1);
auto col_z7 = col_data + 7 * oosize + osize * (osize - 1);
auto col_z8 = col_data + 8 * oosize + osize * (osize - 1);
for (int i = 0; i < nk4; ++i) {
vst1q_f32(col_z0, zero4);
vst1q_f32(col_z1, zero4);
vst1q_f32(col_z2, zero4);
vst1q_f32(col_z6, zero4);
vst1q_f32(col_z7, zero4);
vst1q_f32(col_z8, zero4);
col_z0 += 4;
col_z1 += 4;
col_z2 += 4;
col_z6 += 4;
col_z7 += 4;
col_z8 += 4;
}
for (int i = 0; i < mk4; ++i) {
col_z0[i] = 0.0;
col_z1[i] = 0.0;
col_z2[i] = 0.0;
col_z6[i] = 0.0;
col_z7[i] = 0.0;
col_z8[i] = 0.0;
}
col_data += 9 * oosize;
im_data += isize * isize;
}
} else if (stride[0] == 2 && filter_height == 3 && pad1 && dilation[0] == 1 &&
im_height > 2 && im_height == im_width) {
for (int c = 0; c < im_channels; ++c) {
int oosize = osize * osize;
int nk4 = osize / 4;
int mk4 = osize % 4;
// 3 2 3 1 0 1 3 2 3
float *col0 = col_data + 0 * oosize + osize + 1;
float *col1 = col_data + 1 * oosize + osize;
float *col2 = col_data + 2 * oosize + osize;
float *col3 = col_data + 3 * oosize + 1;
float *col4 = col_data + 4 * oosize;
float *col5 = col_data + 5 * oosize;
float *col6 = col_data + 6 * oosize + 1;
float *col7 = col_data + 7 * oosize;
float *col8 = col_data + 8 * oosize;
float32x4x2_t im01;
float32x4x2_t im23;
const float *im_tmp_data0 = im_data;
const float *im_tmp_data2 = im_data + isize;
for (int j = 0; j < osize; ++j) {
for (int i = 0; i < nk4; ++i) {
im01 = vld2q_f32(im_tmp_data0);
im23 = vld2q_f32(im_tmp_data2);
vst1q_f32(col0, im23.val[1]);
vst1q_f32(col1, im23.val[0]);
vst1q_f32(col2, im23.val[1]);
vst1q_f32(col3, im01.val[1]);
vst1q_f32(col4, im01.val[0]);
vst1q_f32(col5, im01.val[1]);
vst1q_f32(col6, im23.val[1]);
vst1q_f32(col7, im23.val[0]);
vst1q_f32(col8, im23.val[1]);
col0 += 4;
col1 += 4;
col2 += 4;
col3 += 4;
col4 += 4;
col5 += 4;
col6 += 4;
col7 += 4;
col8 += 4;
im_tmp_data0 += 8;
im_tmp_data2 += 8;
}
const float *im_tmp_data1 = im_tmp_data0 + 1;
const float *im_tmp_data3 = im_tmp_data2 + 1;
for (int i = 0; i < mk4; ++i) {
*col0 = *im_tmp_data3;
*col1 = *im_tmp_data2;
*col2 = *im_tmp_data3;
*col3 = *im_tmp_data1;
*col4 = *im_tmp_data0;
*col5 = *im_tmp_data1;
*col6 = *im_tmp_data3;
*col7 = *im_tmp_data2;
*col8 = *im_tmp_data3;
col0++;
col1++;
col2++;
col3++;
col4++;
col5++;
col6++;
col7++;
col8++;
im_tmp_data0 += 2;
im_tmp_data1 += 2;
im_tmp_data2 += 2;
im_tmp_data3 += 2;
}
im_tmp_data0 += (isize - fill);
im_tmp_data2 += (isize - fill);
}
for (int i = 0; i < osize; ++i) {
col_data[0 * oosize + i * osize] = 0.0;
col_data[3 * oosize + i * osize] = 0.0;
col_data[6 * oosize + i * osize] = 0.0;
if (pad2) {
col_data[2 * oosize + osize - 1 + i * osize] = 0.0;
col_data[5 * oosize + osize - 1 + i * osize] = 0.0;
col_data[8 * oosize + osize - 1 + i * osize] = 0.0;
}
}
float32x4_t zero4;
zero4 = vdupq_n_f32(0.0);
auto col_z0 = col_data;
auto col_z1 = col_data + oosize;
auto col_z2 = col_data + 2 * oosize;
auto col_z6 = col_data + 6 * oosize + osize * (osize - 1);
auto col_z7 = col_data + 7 * oosize + osize * (osize - 1);
auto col_z8 = col_data + 8 * oosize + osize * (osize - 1);
for (int i = 0; i < nk4; ++i) {
vst1q_f32(col_z0, zero4);
vst1q_f32(col_z1, zero4);
vst1q_f32(col_z2, zero4);
if (pad2) {
vst1q_f32(col_z6, zero4);
vst1q_f32(col_z7, zero4);
vst1q_f32(col_z8, zero4);
}
col_z0 += 4;
col_z1 += 4;
col_z2 += 4;
col_z6 += 4;
col_z7 += 4;
col_z8 += 4;
}
for (int i = 0; i < mk4; ++i) {
col_z0[i] = 0.0;
col_z1[i] = 0.0;
col_z2[i] = 0.0;
if (pad2) {
col_z6[i] = 0.0;
col_z7[i] = 0.0;
col_z8[i] = 0.0;
}
}
col_data[1 * oosize + osize] = im_data[isize];
for (int i = 1; i < osize; ++i) {
col_data[3 * oosize + i] = im_data[(i - 1) * stride[0] + 1];
}
col_data[4 * oosize] = im_data[0];
col_data[7 * oosize] = im_data[isize];
col_data += 9 * oosize;
im_data += isize * isize;
}
} else if (stride[0] <= 4 && dilation[0] == 1 && dilation[0] == dilation[1]) {
if (stride[0] <= 4 && dilation[0] == 1 && dilation[0] == dilation[1]) {
int im_spatial_size = im_height * im_width;
int col_spatial_size = col_height * col_width;
// pad 0
......
......@@ -441,10 +441,8 @@ class ConvParam : public OpParam {
enum ExecMode {
EXEC_INVALID = 0,
EXEC_GEMM_FLOAT,
EXEC_DEPTHWISE3x3S1P1_FLOAT,
EXEC_DEPTHWISE3x3S2P0_FLOAT,
EXEC_DEPTHWISE3x3S2P1_FLOAT,
EXEC_DEPTHWISE3x3_FLOAT,
EXEC_DEPTHWISE3x3S1_FLOAT,
EXEC_DEPTHWISE3x3S2_FLOAT,
EXEC_WINOGRAD3X3_FLOAT,
EXEC_WINOGRAD5X5_FLOAT,
EXEC_DEPTHWISE5x5_FLOAT,
......