提交 227fdfb6 编写于 作者: H hedaoyuan

Refine NeonDepthwiseConvFunction.

上级 f00c4112
......@@ -509,10 +509,9 @@ public:
size_t filterMultiplier = outputChannels / groups_;
CHECK_EQ(inputChannels, groups_);
// only support
// only support strideH() == strideW() and filterHeight == filterWidth.
CHECK_EQ(strideH(), strideW());
CHECK_EQ(filterHeight, filterWidth);
CHECK_LT(strideH(), size_t(3));
float* inputData = inputs[0].data<float>();
float* filterData = inputs[1].data<float>();
......@@ -538,49 +537,32 @@ public:
inputWidth += 2 * paddingW();
}
for (size_t i = 0; i < batchSize; i++) {
if (filterWidth == 3 && strideH() == 1) {
DepthwiseConvKernel<3, 1>::run(inputPadding,
filterData,
inputHeight,
inputWidth,
outputChannels,
outputHeight,
outputWidth,
filterMultiplier,
outputData);
} else if (filterWidth == 3 && strideH() == 2) {
DepthwiseConvKernel<3, 2>::run(inputPadding,
filterData,
inputHeight,
inputWidth,
outputChannels,
outputHeight,
outputWidth,
filterMultiplier,
outputData);
} else if (filterWidth == 4 && strideH() == 1) {
DepthwiseConvKernel<4, 1>::run(inputPadding,
filterData,
inputHeight,
inputWidth,
outputChannels,
outputHeight,
outputWidth,
filterMultiplier,
outputData);
} else if (filterWidth == 4 && strideH() == 2) {
DepthwiseConvKernel<4, 2>::run(inputPadding,
filterData,
inputHeight,
inputWidth,
outputChannels,
outputHeight,
outputWidth,
filterMultiplier,
outputData);
}
std::function<void(
const float*, const float*, int, int, int, int, int, int, float*)>
DepthWiseConv;
if (filterWidth == 3 && strideW() == 1) {
DepthWiseConv = DepthwiseConvKernel<3, 1>::run;
} else if (filterWidth == 3 && strideW() == 2) {
DepthWiseConv = DepthwiseConvKernel<3, 2>::run;
} else if (filterWidth == 4 && strideW() == 1) {
DepthWiseConv = DepthwiseConvKernel<4, 1>::run;
} else if (filterWidth == 4 && strideW() == 2) {
DepthWiseConv = DepthwiseConvKernel<4, 2>::run;
} else {
LOG(FATAL) << "Not supported";
}
for (size_t i = 0; i < batchSize; i++) {
DepthWiseConv(inputPadding,
filterData,
inputHeight,
inputWidth,
outputChannels,
outputHeight,
outputWidth,
filterMultiplier,
outputData);
inputPadding += inputChannels * inputHeight * inputWidth;
outputData += outputChannels * outputHeight * outputWidth;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册