提交 455ad5b5 作者: qingqing01 提交者: GitHub

Merge pull request #3141 from hedaoyuan/nnpack

Support groups in NNPACKFunction.
......@@ -49,9 +49,7 @@ class NNPACKConvFunction : public ConvFunctionBase {
public:
void init(const FuncConfig& config) override {
ConvFunctionBase::init(config);
CHECK_EQ(groups_, (size_t)1);
algorithm_ = get_nnp_convolution_algorithm(config.get<std::string>("algo"));
// algorithm_ = nnp_convolution_algorithm_auto;
transform_strategy_ = nnp_convolution_transform_strategy_compute;
nnp_status status = nnp_initialize();
CHECK_EQ(status, nnp_status_success);
......@@ -67,8 +65,7 @@ public:
}
}
virtual void check(const BufferArgs& inputs,
const BufferArgs& outputs) override {
void check(const BufferArgs& inputs, const BufferArgs& outputs) override {
const TensorShape& input = inputs[0].shape();
const TensorShape& filter = inputs[1].shape();
const TensorShape& output = outputs[0].shape();
......@@ -91,8 +88,8 @@ public:
size_t filterHeight = getFilterHeight(filter);
size_t filterWidth = getFilterWidth(filter);
size_t outputChannels = output[1];
// size_t outputHeight = output[2];
// size_t outputWidth = output[3];
size_t outputHeight = output[2];
size_t outputWidth = output[3];
nnp_size inputSize = {.width = inputWidth, .height = inputHeight};
nnp_padding padding = {.top = (size_t)paddingH(),
......@@ -171,20 +168,25 @@ public:
}
}
size_t inputOffset = inputChannels / groups_ * inputHeight * inputWidth;
size_t outputOffset = outputChannels / groups_ * outputHeight * outputWidth;
size_t filterOffset = filter.getElements() / groups_;
if (batchSize == 1) {
for (size_t g = 0; g < groups_; g++) {
nnp_status status =
nnp_convolution_inference(algorithm_,
transform_strategy_,
inputChannels,
outputChannels,
inputChannels / groups_,
outputChannels / groups_,
inputSize,
padding,
kernelSize,
outputSubsampling,
inputData,
filterData,
inputData + inputOffset * g,
filterData + filterOffset * g,
nullptr, /* bias */
outputData,
outputData + outputOffset * g,
bufferPtr,
sizePtr,
nnp_activation_identity,
......@@ -192,21 +194,24 @@ public:
threadpool_, /* threadpool */
nullptr);
CHECK_EQ(status, nnp_status_success);
}
} else {
for (size_t g = 0; g < groups_; g++) {
// only supports stride = 1
CHECK_EQ(strideH(), 1);
CHECK_EQ(strideW(), 1);
nnp_status status = nnp_convolution_output(algorithm_,
nnp_status status =
nnp_convolution_output(algorithm_,
batchSize,
inputChannels,
outputChannels,
inputChannels / groups_,
outputChannels / groups_,
inputSize,
padding,
kernelSize,
inputData,
filterData,
inputData + inputOffset * g,
filterData + filterOffset * g,
nullptr, /* bias */
outputData,
outputData + outputOffset * g,
bufferPtr,
sizePtr,
nnp_activation_identity,
......@@ -216,6 +221,7 @@ public:
CHECK_EQ(status, nnp_status_success);
}
}
}
static void create_nnpack_threadpool() {
if (FLAGS_nnpack_num_threads && threadpool_ == nullptr) {
......
......@@ -57,8 +57,7 @@ bool ExpandConvLayer::init(const LayerMap &layerMap,
convGradFilterType = "GemmConvGradFilter";
}
if (FLAGS_use_nnpack) {
CHECK_EQ(isDeconv_, false);
if (FLAGS_use_nnpack && !isDeconv_) {
createFunction(forward_,
"NNPACKConv",
FuncConfig()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册