提交 455ad5b5 · 作者: qingqing01 · 提交者: GitHub

Merge pull request #3141 from hedaoyuan/nnpack

Support groups in NNPACKFunction.
...@@ -49,9 +49,7 @@ class NNPACKConvFunction : public ConvFunctionBase { ...@@ -49,9 +49,7 @@ class NNPACKConvFunction : public ConvFunctionBase {
public: public:
void init(const FuncConfig& config) override { void init(const FuncConfig& config) override {
ConvFunctionBase::init(config); ConvFunctionBase::init(config);
CHECK_EQ(groups_, (size_t)1);
algorithm_ = get_nnp_convolution_algorithm(config.get<std::string>("algo")); algorithm_ = get_nnp_convolution_algorithm(config.get<std::string>("algo"));
// algorithm_ = nnp_convolution_algorithm_auto;
transform_strategy_ = nnp_convolution_transform_strategy_compute; transform_strategy_ = nnp_convolution_transform_strategy_compute;
nnp_status status = nnp_initialize(); nnp_status status = nnp_initialize();
CHECK_EQ(status, nnp_status_success); CHECK_EQ(status, nnp_status_success);
...@@ -67,8 +65,7 @@ public: ...@@ -67,8 +65,7 @@ public:
} }
} }
virtual void check(const BufferArgs& inputs, void check(const BufferArgs& inputs, const BufferArgs& outputs) override {
const BufferArgs& outputs) override {
const TensorShape& input = inputs[0].shape(); const TensorShape& input = inputs[0].shape();
const TensorShape& filter = inputs[1].shape(); const TensorShape& filter = inputs[1].shape();
const TensorShape& output = outputs[0].shape(); const TensorShape& output = outputs[0].shape();
...@@ -91,8 +88,8 @@ public: ...@@ -91,8 +88,8 @@ public:
size_t filterHeight = getFilterHeight(filter); size_t filterHeight = getFilterHeight(filter);
size_t filterWidth = getFilterWidth(filter); size_t filterWidth = getFilterWidth(filter);
size_t outputChannels = output[1]; size_t outputChannels = output[1];
// size_t outputHeight = output[2]; size_t outputHeight = output[2];
// size_t outputWidth = output[3]; size_t outputWidth = output[3];
nnp_size inputSize = {.width = inputWidth, .height = inputHeight}; nnp_size inputSize = {.width = inputWidth, .height = inputHeight};
nnp_padding padding = {.top = (size_t)paddingH(), nnp_padding padding = {.top = (size_t)paddingH(),
...@@ -171,49 +168,58 @@ public: ...@@ -171,49 +168,58 @@ public:
} }
} }
size_t inputOffset = inputChannels / groups_ * inputHeight * inputWidth;
size_t outputOffset = outputChannels / groups_ * outputHeight * outputWidth;
size_t filterOffset = filter.getElements() / groups_;
if (batchSize == 1) { if (batchSize == 1) {
nnp_status status = for (size_t g = 0; g < groups_; g++) {
nnp_convolution_inference(algorithm_, nnp_status status =
transform_strategy_, nnp_convolution_inference(algorithm_,
inputChannels, transform_strategy_,
outputChannels, inputChannels / groups_,
inputSize, outputChannels / groups_,
padding, inputSize,
kernelSize, padding,
outputSubsampling, kernelSize,
inputData, outputSubsampling,
filterData, inputData + inputOffset * g,
nullptr, /* bias */ filterData + filterOffset * g,
outputData, nullptr, /* bias */
bufferPtr, outputData + outputOffset * g,
sizePtr, bufferPtr,
nnp_activation_identity, sizePtr,
nullptr, nnp_activation_identity,
threadpool_, /* threadpool */ nullptr,
nullptr); threadpool_, /* threadpool */
CHECK_EQ(status, nnp_status_success); nullptr);
CHECK_EQ(status, nnp_status_success);
}
} else { } else {
// only supports stride = 1 for (size_t g = 0; g < groups_; g++) {
CHECK_EQ(strideH(), 1); // only supports stride = 1
CHECK_EQ(strideW(), 1); CHECK_EQ(strideH(), 1);
nnp_status status = nnp_convolution_output(algorithm_, CHECK_EQ(strideW(), 1);
batchSize, nnp_status status =
inputChannels, nnp_convolution_output(algorithm_,
outputChannels, batchSize,
inputSize, inputChannels / groups_,
padding, outputChannels / groups_,
kernelSize, inputSize,
inputData, padding,
filterData, kernelSize,
nullptr, /* bias */ inputData + inputOffset * g,
outputData, filterData + filterOffset * g,
bufferPtr, nullptr, /* bias */
sizePtr, outputData + outputOffset * g,
nnp_activation_identity, bufferPtr,
nullptr, sizePtr,
threadpool_, /* threadpool */ nnp_activation_identity,
nullptr); nullptr,
CHECK_EQ(status, nnp_status_success); threadpool_, /* threadpool */
nullptr);
CHECK_EQ(status, nnp_status_success);
}
} }
} }
......
...@@ -57,8 +57,7 @@ bool ExpandConvLayer::init(const LayerMap &layerMap, ...@@ -57,8 +57,7 @@ bool ExpandConvLayer::init(const LayerMap &layerMap,
convGradFilterType = "GemmConvGradFilter"; convGradFilterType = "GemmConvGradFilter";
} }
if (FLAGS_use_nnpack) { if (FLAGS_use_nnpack && !isDeconv_) {
CHECK_EQ(isDeconv_, false);
createFunction(forward_, createFunction(forward_,
"NNPACKConv", "NNPACKConv",
FuncConfig() FuncConfig()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册