提交 227fdfb6 编写于 作者: H hedaoyuan

Refine NeonDepthwiseConvFunction.

上级 f00c4112
...@@ -509,10 +509,9 @@ public: ...@@ -509,10 +509,9 @@ public:
size_t filterMultiplier = outputChannels / groups_; size_t filterMultiplier = outputChannels / groups_;
CHECK_EQ(inputChannels, groups_); CHECK_EQ(inputChannels, groups_);
// only support // only support strideH() == strideW() and filterHeight == filterWidth.
CHECK_EQ(strideH(), strideW()); CHECK_EQ(strideH(), strideW());
CHECK_EQ(filterHeight, filterWidth); CHECK_EQ(filterHeight, filterWidth);
CHECK_LT(strideH(), size_t(3));
float* inputData = inputs[0].data<float>(); float* inputData = inputs[0].data<float>();
float* filterData = inputs[1].data<float>(); float* filterData = inputs[1].data<float>();
...@@ -538,39 +537,24 @@ public: ...@@ -538,39 +537,24 @@ public:
inputWidth += 2 * paddingW(); inputWidth += 2 * paddingW();
} }
std::function<void(
const float*, const float*, int, int, int, int, int, int, float*)>
DepthWiseConv;
if (filterWidth == 3 && strideW() == 1) {
DepthWiseConv = DepthwiseConvKernel<3, 1>::run;
} else if (filterWidth == 3 && strideW() == 2) {
DepthWiseConv = DepthwiseConvKernel<3, 2>::run;
} else if (filterWidth == 4 && strideW() == 1) {
DepthWiseConv = DepthwiseConvKernel<4, 1>::run;
} else if (filterWidth == 4 && strideW() == 2) {
DepthWiseConv = DepthwiseConvKernel<4, 2>::run;
} else {
LOG(FATAL) << "Not supported";
}
for (size_t i = 0; i < batchSize; i++) { for (size_t i = 0; i < batchSize; i++) {
if (filterWidth == 3 && strideH() == 1) { DepthWiseConv(inputPadding,
DepthwiseConvKernel<3, 1>::run(inputPadding,
filterData,
inputHeight,
inputWidth,
outputChannels,
outputHeight,
outputWidth,
filterMultiplier,
outputData);
} else if (filterWidth == 3 && strideH() == 2) {
DepthwiseConvKernel<3, 2>::run(inputPadding,
filterData,
inputHeight,
inputWidth,
outputChannels,
outputHeight,
outputWidth,
filterMultiplier,
outputData);
} else if (filterWidth == 4 && strideH() == 1) {
DepthwiseConvKernel<4, 1>::run(inputPadding,
filterData,
inputHeight,
inputWidth,
outputChannels,
outputHeight,
outputWidth,
filterMultiplier,
outputData);
} else if (filterWidth == 4 && strideH() == 2) {
DepthwiseConvKernel<4, 2>::run(inputPadding,
filterData, filterData,
inputHeight, inputHeight,
inputWidth, inputWidth,
...@@ -579,8 +563,6 @@ public: ...@@ -579,8 +563,6 @@ public:
outputWidth, outputWidth,
filterMultiplier, filterMultiplier,
outputData); outputData);
}
inputPadding += inputChannels * inputHeight * inputWidth; inputPadding += inputChannels * inputHeight * inputWidth;
outputData += outputChannels * outputHeight * outputWidth; outputData += outputChannels * outputHeight * outputWidth;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册