提交 da616a6f 编写于 作者: H hedaoyuan

Fix some bugs.

上级 370dcf76
......@@ -202,9 +202,10 @@ void DepthwiseConvolution(const std::string& conv1,
for (size_t outputChannels : {32, 64}) {
for (size_t stride : {1, 2}) {
for (size_t padding : {0, 1}) {
// NNPACK only supports stride = 1 if batchSize > 1
// NNPACK only supports stride = 1 if batchSize > 1,
// and there has some bug when batchSize > 1 and groups != 1
if ((conv1 == "NNPACKConv-CPU" || conv2 == "NNPACKConv-CPU") &&
batchSize > 1 && stride > 1)
batchSize > 1)
break;
size_t outputSize =
......
......@@ -201,20 +201,18 @@ public:
CHECK_EQ(strideW(), 1);
// TODO(hedaoyuan): There has some bug when batchSize > 1 and groups_ > 1.
CHECK_EQ(groups_, (size_t)1);
for (size_t g = 0; g < groups_; g++) {
nnp_status status =
nnp_convolution_output(algorithm_,
CHECK_EQ(groups_, static_cast<size_t>(1));
nnp_status status = nnp_convolution_output(algorithm_,
batchSize,
inputChannels / groups_,
outputChannels / groups_,
inputChannels,
outputChannels,
inputSize,
padding,
kernelSize,
inputData + inputOffset * g,
filterData + filterOffset * g,
inputData,
filterData,
nullptr, /* bias */
outputData + outputOffset * g,
outputData,
bufferPtr,
sizePtr,
nnp_activation_identity,
......@@ -224,7 +222,6 @@ public:
CHECK_EQ(status, nnp_status_success);
}
}
}
static void create_nnpack_threadpool() {
if (FLAGS_nnpack_num_threads && threadpool_ == nullptr) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册