提交 b7885b08 编写于 作者: H hedaoyuan

Add DepthwiseConvKernel for filter size is 4.

上级 0dffe68c
......@@ -38,6 +38,22 @@ inline float32_t conv3x3(float32x4_t r0,
return vaddvq_f32(tmp);
}
inline float32_t conv4x4(float32x4_t r0,
float32x4_t r1,
float32x4_t r2,
float32x4_t r3,
float32x4_t k0,
float32x4_t k1,
float32x4_t k2,
float32x4_t k3) {
float32x4_t tmp;
tmp = vmulq_f32(r0, k0);
tmp = vmlaq_f32(tmp, r1, k1);
tmp = vmlaq_f32(tmp, r2, k2);
tmp = vmlaq_f32(tmp, r3, k3);
return vaddvq_f32(tmp);
}
/**
* Each step calculates four elements of the output.
* First step:
......@@ -137,6 +153,114 @@ struct DepthwiseConvKernel<3, 1> {
}
};
/**
* Each step calculates four elements of the output.
*/
template <>
struct DepthwiseConvKernel<4, 1> {
static void run(const float* inputData,
const float* filterData,
int inputHeight,
int inputWidth,
int outputChannels,
int outputHeight,
int outputWidth,
int filterMultiplier,
float* outputData) {
const int steps = outputWidth >> 2;
const int remain = outputWidth & 3;
for (int c = 0; c < outputChannels; c++, filterData += 16) {
// Load the filters
float32x4_t k[4];
k[0] = vld1q_f32(filterData);
k[1] = vld1q_f32(filterData + 4);
k[2] = vld1q_f32(filterData + 8);
k[3] = vld1q_f32(filterData + 12);
const float* r0 =
inputData + (c / filterMultiplier) * (inputHeight * inputWidth);
const float* r1 = r0 + inputWidth;
const float* r2 = r0 + inputWidth * 2;
const float* r3 = r0 + inputWidth * 3;
float32x4_t input[4][4];
for (int h = 0; h < outputHeight; h++) {
for (int s = 0; s < steps; s++) {
// Load the inputs
float32x4_t tmp;
input[0][0] = vld1q_f32(r0);
tmp = vld1q_f32(r0 + 4);
input[0][1] = vextq_f32(input[0][0], tmp, 1);
input[0][2] = vextq_f32(input[0][0], tmp, 2);
input[0][3] = vextq_f32(input[0][0], tmp, 3);
input[1][0] = vld1q_f32(r1);
tmp = vld1q_f32(r1 + 4);
input[1][1] = vextq_f32(input[1][0], tmp, 1);
input[1][2] = vextq_f32(input[1][0], tmp, 2);
input[1][3] = vextq_f32(input[1][0], tmp, 3);
input[2][0] = vld1q_f32(r2);
tmp = vld1q_f32(r2 + 4);
input[2][1] = vextq_f32(input[2][0], tmp, 1);
input[2][2] = vextq_f32(input[2][0], tmp, 2);
input[2][3] = vextq_f32(input[2][0], tmp, 3);
input[3][0] = vld1q_f32(r3);
tmp = vld1q_f32(r3 + 4);
input[3][1] = vextq_f32(input[3][0], tmp, 1);
input[3][2] = vextq_f32(input[3][0], tmp, 2);
input[3][3] = vextq_f32(input[3][0], tmp, 3);
float32x4_t tmp1 = vdupq_n_f32(0.f);
float32x4_t tmp2 = vdupq_n_f32(0.f);
tmp1 = vmlaq_laneq_f32(tmp1, input[0][0], k[0], 0);
tmp2 = vmlaq_laneq_f32(tmp2, input[0][1], k[0], 1);
tmp1 = vmlaq_laneq_f32(tmp1, input[0][2], k[0], 2);
tmp2 = vmlaq_laneq_f32(tmp2, input[0][3], k[0], 3);
tmp1 = vmlaq_laneq_f32(tmp1, input[1][0], k[1], 0);
tmp2 = vmlaq_laneq_f32(tmp2, input[1][1], k[1], 1);
tmp1 = vmlaq_laneq_f32(tmp1, input[1][2], k[1], 2);
tmp2 = vmlaq_laneq_f32(tmp2, input[1][3], k[1], 3);
tmp1 = vmlaq_laneq_f32(tmp1, input[2][0], k[2], 0);
tmp2 = vmlaq_laneq_f32(tmp2, input[2][1], k[2], 1);
tmp1 = vmlaq_laneq_f32(tmp1, input[2][2], k[2], 2);
tmp2 = vmlaq_laneq_f32(tmp2, input[2][3], k[2], 3);
tmp1 = vmlaq_laneq_f32(tmp1, input[3][0], k[3], 0);
tmp2 = vmlaq_laneq_f32(tmp2, input[3][1], k[3], 1);
tmp1 = vmlaq_laneq_f32(tmp1, input[3][2], k[3], 2);
tmp2 = vmlaq_laneq_f32(tmp2, input[3][3], k[3], 3);
tmp1 = vaddq_f32(tmp1, tmp2);
vst1q_f32(outputData, tmp1);
r0 += 4;
r1 += 4;
r2 += 4;
r3 += 4;
outputData += 4;
}
for (int r = 0; r < remain; r++) {
float32x4_t i0 = vld1q_f32(r0);
float32x4_t i1 = vld1q_f32(r1);
float32x4_t i2 = vld1q_f32(r2);
float32x4_t i3 = vld1q_f32(r3);
*outputData = conv4x4(i0, i1, i2, i3, k[0], k[1], k[2], k[3]);
r0++;
r1++;
r2++;
r3++;
outputData++;
}
r0 += 3;
r1 += 3;
r2 += 3;
r3 += 3;
}
}
}
};
template <DeviceType Device>
class NeonDepthwiseConvFunction : public ConvFunctionBase {
public:
......@@ -175,7 +299,6 @@ public:
// only support
CHECK_EQ(strideH(), strideW());
CHECK_EQ(filterHeight, filterWidth);
CHECK_EQ(filterHeight, size_t(3));
CHECK_LT(strideH(), size_t(3));
float* inputData = inputs[0].data<float>();
......@@ -203,6 +326,7 @@ public:
}
for (size_t i = 0; i < batchSize; i++) {
if (filterWidth == 3) {
DepthwiseConvKernel<3, 1>::run(inputPadding,
filterData,
inputHeight,
......@@ -212,6 +336,17 @@ public:
outputWidth,
filterMultiplier,
outputData);
} else if (filterWidth == 4) {
DepthwiseConvKernel<4, 1>::run(inputPadding,
filterData,
inputHeight,
inputWidth,
outputChannels,
outputHeight,
outputWidth,
filterMultiplier,
outputData);
}
inputPadding += inputChannels * inputHeight * inputWidth;
outputData += outputChannels * outputHeight * outputWidth;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册