提交 4b6b7251 编写于 作者: H hedaoyuan

Refine NeonDepthwiseConv.

上级 f7e75a03
@@ -18,8 +18,6 @@ limitations under the License. */
 namespace paddle {
-namespace neon {
 #if defined(__ARM_NEON__) || defined(__ARM_NEON)
 template <DeviceType Device>
@@ -45,16 +43,16 @@ public:
 const TensorShape& filter = inputs[1].shape();
 const TensorShape& output = outputs[0].shape();
-size_t batchSize = input[0];
-size_t inputChannels = input[1];
-size_t inputHeight = input[2];
-size_t inputWidth = input[3];
-size_t filterHeight = getFilterHeight(filter);
-size_t filterWidth = getFilterWidth(filter);
-size_t outputChannels = output[1];
-size_t outputHeight = output[2];
-size_t outputWidth = output[3];
-size_t filterMultiplier = outputChannels / groups_;
+int batchSize = input[0];
+int inputChannels = input[1];
+int inputHeight = input[2];
+int inputWidth = input[3];
+int filterHeight = getFilterHeight(filter);
+int filterWidth = getFilterWidth(filter);
+int outputChannels = output[1];
+int outputHeight = output[2];
+int outputWidth = output[3];
+int filterMultiplier = outputChannels / groups_;
 CHECK_EQ(inputChannels, groups_);
 // only support strideH() == strideW() and filterHeight == filterWidth.
@@ -90,18 +88,18 @@ public:
 DepthWiseConv;
 if (filterWidth == 3 && strideW() == 1) {
-DepthWiseConv = DepthwiseConvKernel<3, 1>::run;
+DepthWiseConv = neon::DepthwiseConvKernel<3, 1>::run;
 } else if (filterWidth == 3 && strideW() == 2) {
-DepthWiseConv = DepthwiseConvKernel<3, 2>::run;
+DepthWiseConv = neon::DepthwiseConvKernel<3, 2>::run;
 } else if (filterWidth == 4 && strideW() == 1) {
-DepthWiseConv = DepthwiseConvKernel<4, 1>::run;
+DepthWiseConv = neon::DepthwiseConvKernel<4, 1>::run;
 } else if (filterWidth == 4 && strideW() == 2) {
-DepthWiseConv = DepthwiseConvKernel<4, 2>::run;
+DepthWiseConv = neon::DepthwiseConvKernel<4, 2>::run;
 } else {
 LOG(FATAL) << "Not supported";
 }
-for (size_t i = 0; i < batchSize; i++) {
+for (int i = 0; i < batchSize; i++) {
 DepthWiseConv(inputPadding,
 filterData,
 inputHeight,
@@ -117,9 +115,10 @@ public:
 }
 };
+#ifndef PADDLE_TYPE_DOUBLE
 REGISTER_TYPED_FUNC(NeonDepthwiseConv, CPU, NeonDepthwiseConvFunction);
+#endif
 #endif
-} // namespace neon
 } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册