提交 227fdfb6 编写于 作者: H hedaoyuan

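Refine NeonDepthwiseConvFunction: select the DepthwiseConvKernel specialization once, before the batch loop, instead of re-dispatching on every batch item; unsupported filter/stride combinations now hit LOG(FATAL) instead of the removed CHECK_LT on the stride.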

上级 f00c4112
...@@ -509,10 +509,9 @@ public: ...@@ -509,10 +509,9 @@ public:
size_t filterMultiplier = outputChannels / groups_; size_t filterMultiplier = outputChannels / groups_;
CHECK_EQ(inputChannels, groups_); CHECK_EQ(inputChannels, groups_);
// only support // only support strideH() == strideW() and filterHeight == filterWidth.
CHECK_EQ(strideH(), strideW()); CHECK_EQ(strideH(), strideW());
CHECK_EQ(filterHeight, filterWidth); CHECK_EQ(filterHeight, filterWidth);
CHECK_LT(strideH(), size_t(3));
float* inputData = inputs[0].data<float>(); float* inputData = inputs[0].data<float>();
float* filterData = inputs[1].data<float>(); float* filterData = inputs[1].data<float>();
...@@ -538,49 +537,32 @@ public: ...@@ -538,49 +537,32 @@ public:
inputWidth += 2 * paddingW(); inputWidth += 2 * paddingW();
} }
for (size_t i = 0; i < batchSize; i++) { std::function<void(
if (filterWidth == 3 && strideH() == 1) { const float*, const float*, int, int, int, int, int, int, float*)>
DepthwiseConvKernel<3, 1>::run(inputPadding, DepthWiseConv;
filterData,
inputHeight, if (filterWidth == 3 && strideW() == 1) {
inputWidth, DepthWiseConv = DepthwiseConvKernel<3, 1>::run;
outputChannels, } else if (filterWidth == 3 && strideW() == 2) {
outputHeight, DepthWiseConv = DepthwiseConvKernel<3, 2>::run;
outputWidth, } else if (filterWidth == 4 && strideW() == 1) {
filterMultiplier, DepthWiseConv = DepthwiseConvKernel<4, 1>::run;
outputData); } else if (filterWidth == 4 && strideW() == 2) {
} else if (filterWidth == 3 && strideH() == 2) { DepthWiseConv = DepthwiseConvKernel<4, 2>::run;
DepthwiseConvKernel<3, 2>::run(inputPadding, } else {
filterData, LOG(FATAL) << "Not supported";
inputHeight, }
inputWidth,
outputChannels,
outputHeight,
outputWidth,
filterMultiplier,
outputData);
} else if (filterWidth == 4 && strideH() == 1) {
DepthwiseConvKernel<4, 1>::run(inputPadding,
filterData,
inputHeight,
inputWidth,
outputChannels,
outputHeight,
outputWidth,
filterMultiplier,
outputData);
} else if (filterWidth == 4 && strideH() == 2) {
DepthwiseConvKernel<4, 2>::run(inputPadding,
filterData,
inputHeight,
inputWidth,
outputChannels,
outputHeight,
outputWidth,
filterMultiplier,
outputData);
}
for (size_t i = 0; i < batchSize; i++) {
DepthWiseConv(inputPadding,
filterData,
inputHeight,
inputWidth,
outputChannels,
outputHeight,
outputWidth,
filterMultiplier,
outputData);
inputPadding += inputChannels * inputHeight * inputWidth; inputPadding += inputChannels * inputHeight * inputWidth;
outputData += outputChannels * outputHeight * outputWidth; outputData += outputChannels * outputHeight * outputWidth;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册