提交 953f8ddf 编写于 作者: H hedaoyuan

Support groups in NNPACKFunction.

上级 0973c2c9
...@@ -49,9 +49,7 @@ class NNPACKConvFunction : public ConvFunctionBase { ...@@ -49,9 +49,7 @@ class NNPACKConvFunction : public ConvFunctionBase {
public: public:
void init(const FuncConfig& config) override { void init(const FuncConfig& config) override {
ConvFunctionBase::init(config); ConvFunctionBase::init(config);
CHECK_EQ(groups_, (size_t)1);
algorithm_ = get_nnp_convolution_algorithm(config.get<std::string>("algo")); algorithm_ = get_nnp_convolution_algorithm(config.get<std::string>("algo"));
// algorithm_ = nnp_convolution_algorithm_auto;
transform_strategy_ = nnp_convolution_transform_strategy_compute; transform_strategy_ = nnp_convolution_transform_strategy_compute;
nnp_status status = nnp_initialize(); nnp_status status = nnp_initialize();
CHECK_EQ(status, nnp_status_success); CHECK_EQ(status, nnp_status_success);
...@@ -67,8 +65,7 @@ public: ...@@ -67,8 +65,7 @@ public:
} }
} }
virtual void check(const BufferArgs& inputs, void check(const BufferArgs& inputs, const BufferArgs& outputs) override {
const BufferArgs& outputs) override {
const TensorShape& input = inputs[0].shape(); const TensorShape& input = inputs[0].shape();
const TensorShape& filter = inputs[1].shape(); const TensorShape& filter = inputs[1].shape();
const TensorShape& output = outputs[0].shape(); const TensorShape& output = outputs[0].shape();
...@@ -91,8 +88,8 @@ public: ...@@ -91,8 +88,8 @@ public:
size_t filterHeight = getFilterHeight(filter); size_t filterHeight = getFilterHeight(filter);
size_t filterWidth = getFilterWidth(filter); size_t filterWidth = getFilterWidth(filter);
size_t outputChannels = output[1]; size_t outputChannels = output[1];
// size_t outputHeight = output[2]; size_t outputHeight = output[2];
// size_t outputWidth = output[3]; size_t outputWidth = output[3];
nnp_size inputSize = {.width = inputWidth, .height = inputHeight}; nnp_size inputSize = {.width = inputWidth, .height = inputHeight};
nnp_padding padding = {.top = (size_t)paddingH(), nnp_padding padding = {.top = (size_t)paddingH(),
...@@ -171,20 +168,25 @@ public: ...@@ -171,20 +168,25 @@ public:
} }
} }
size_t inputOffset = inputChannels / groups_ * inputHeight * inputWidth;
size_t outputOffset = outputChannels / groups_ * outputHeight * outputWidth;
size_t filterOffset = filter.getElements() / groups_;
if (batchSize == 1) { if (batchSize == 1) {
for (size_t g = 0; g < groups_; g++) {
nnp_status status = nnp_status status =
nnp_convolution_inference(algorithm_, nnp_convolution_inference(algorithm_,
transform_strategy_, transform_strategy_,
inputChannels, inputChannels / groups_,
outputChannels, outputChannels / groups_,
inputSize, inputSize,
padding, padding,
kernelSize, kernelSize,
outputSubsampling, outputSubsampling,
inputData, inputData + inputOffset * g,
filterData, filterData + filterOffset * g,
nullptr, /* bias */ nullptr, /* bias */
outputData, outputData + outputOffset * g,
bufferPtr, bufferPtr,
sizePtr, sizePtr,
nnp_activation_identity, nnp_activation_identity,
...@@ -192,21 +194,24 @@ public: ...@@ -192,21 +194,24 @@ public:
threadpool_, /* threadpool */ threadpool_, /* threadpool */
nullptr); nullptr);
CHECK_EQ(status, nnp_status_success); CHECK_EQ(status, nnp_status_success);
}
} else { } else {
for (size_t g = 0; g < groups_; g++) {
// only supports stride = 1 // only supports stride = 1
CHECK_EQ(strideH(), 1); CHECK_EQ(strideH(), 1);
CHECK_EQ(strideW(), 1); CHECK_EQ(strideW(), 1);
nnp_status status = nnp_convolution_output(algorithm_, nnp_status status =
nnp_convolution_output(algorithm_,
batchSize, batchSize,
inputChannels, inputChannels / groups_,
outputChannels, outputChannels / groups_,
inputSize, inputSize,
padding, padding,
kernelSize, kernelSize,
inputData, inputData + inputOffset * g,
filterData, filterData + filterOffset * g,
nullptr, /* bias */ nullptr, /* bias */
outputData, outputData + outputOffset * g,
bufferPtr, bufferPtr,
sizePtr, sizePtr,
nnp_activation_identity, nnp_activation_identity,
...@@ -216,6 +221,7 @@ public: ...@@ -216,6 +221,7 @@ public:
CHECK_EQ(status, nnp_status_success); CHECK_EQ(status, nnp_status_success);
} }
} }
}
static void create_nnpack_threadpool() { static void create_nnpack_threadpool() {
if (FLAGS_nnpack_num_threads && threadpool_ == nullptr) { if (FLAGS_nnpack_num_threads && threadpool_ == nullptr) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册