提交 7062be0f 编写于 作者: H hedaoyuan

Add CMake support for compiling NNPACKConvOp.cpp.

上级 2e02952b
...@@ -48,6 +48,7 @@ option(COVERALLS_UPLOAD "Package code coverage data to coveralls" OFF) ...@@ -48,6 +48,7 @@ option(COVERALLS_UPLOAD "Package code coverage data to coveralls" OFF)
option(ON_TRAVIS "Exclude special unit test on Travis CI" OFF) option(ON_TRAVIS "Exclude special unit test on Travis CI" OFF)
option(WITH_C_API "Compile PaddlePaddle with C-API(Prediction)" OFF) option(WITH_C_API "Compile PaddlePaddle with C-API(Prediction)" OFF)
option(WITH_GOLANG "Compile PaddlePaddle with GOLANG" OFF) option(WITH_GOLANG "Compile PaddlePaddle with GOLANG" OFF)
option(USE_NNPACK "Compile PaddlePaddle with NNPACK library" OFF)
# CMAKE_BUILD_TYPE # CMAKE_BUILD_TYPE
if(NOT CMAKE_BUILD_TYPE) if(NOT CMAKE_BUILD_TYPE)
...@@ -126,6 +127,10 @@ if(WITH_GPU) ...@@ -126,6 +127,10 @@ if(WITH_GPU)
endif(NOT WITH_DSO) endif(NOT WITH_DSO)
endif(WITH_GPU) endif(WITH_GPU)
# When NNPACK support is enabled, add the NNPACK library, the pthreadpool
# library, and "rt" to the EXTERNAL_LIBS list.
# NOTE(review): NNPACK_LIB and PTHREADPOOL_LIB are not set anywhere in the
# lines visible before this point -- confirm they are already populated
# (e.g. from the cache) when this statement runs.
if(USE_NNPACK)
  list(APPEND EXTERNAL_LIBS
    ${NNPACK_LIB}
    ${PTHREADPOOL_LIB}
    "rt")
endif()
add_subdirectory(proto) add_subdirectory(proto)
add_subdirectory(paddle) add_subdirectory(paddle)
add_subdirectory(python) add_subdirectory(python)
......
...@@ -10,6 +10,11 @@ if(WITH_GPU) ...@@ -10,6 +10,11 @@ if(WITH_GPU)
cuda_compile(cu_objs ${cu_files}) cuda_compile(cu_objs ${cu_files})
endif() endif()
# NNPACK-backed convolution is opt-in: only when USE_NNPACK is enabled is the
# NNPACK lookup script run and NNPACKConvOp.cpp added to the cpp_files list.
if(USE_NNPACK)
  include(nnpack/nnpack.cmake)
  list(APPEND cpp_files
    nnpack/NNPACKConvOp.cpp)
endif()
add_library(paddle_function STATIC ${cpp_files} ${cu_objs}) add_library(paddle_function STATIC ${cpp_files} ${cu_objs})
add_dependencies(paddle_function ${external_project_dependencies}) add_dependencies(paddle_function ${external_project_dependencies})
add_dependencies(paddle_function gen_proto_cpp) add_dependencies(paddle_function gen_proto_cpp)
......
...@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "ConvOp.h"
#include "nnpack.h" #include "nnpack.h"
#include "paddle/function/ConvOp.h"
DEFINE_bool(nnpack_allocate_outside, DEFINE_bool(nnpack_allocate_outside,
false, false,
...@@ -72,14 +72,22 @@ public: ...@@ -72,14 +72,22 @@ public:
} }
} }
// Validates the shapes of the convolution arguments.
//
// inputs[0] is the input tensor, inputs[1] is the filter and outputs[0] is
// the output tensor -- the same mapping calc() uses. The three shapes are
// forwarded to checkShape(input, filter, output) in that order.
virtual void check(const BufferArgs& inputs,
                   const BufferArgs& outputs) override {
  const TensorShape& input = inputs[0].shape();
  const TensorShape& filter = inputs[1].shape();
  const TensorShape& output = outputs[0].shape();
  checkShape(input, filter, output);
}
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ(numInputs_, inputs.size()); CHECK_EQ(numInputs_, inputs.size());
CHECK_EQ(numOutputs_, outputs.size()); CHECK_EQ(numOutputs_, outputs.size());
CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO); CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO);
check(inputs, outputs);
const TensorShape& input = inputs[0].shape(); const TensorShape& input = inputs[0].shape();
const TensorShape& filter = inputs[1].shape(); const TensorShape& filter = inputs[1].shape();
const TensorShape& output = outputs[0].shape(); const TensorShape& output = outputs[0].shape();
check(input, filter, output);
size_t batchSize = input[0]; size_t batchSize = input[0];
size_t inputChannels = input[1]; size_t inputChannels = input[1];
...@@ -92,12 +100,13 @@ public: ...@@ -92,12 +100,13 @@ public:
// size_t outputWidth = output[3]; // size_t outputWidth = output[3];
nnp_size inputSize = {.width = inputWidth, .height = inputHeight}; nnp_size inputSize = {.width = inputWidth, .height = inputHeight};
nnp_padding padding = {.top = paddingH(), nnp_padding padding = {.top = (size_t)paddingH(),
.right = paddingW(), .right = (size_t)paddingW(),
.bottom = paddingH(), .bottom = (size_t)paddingH(),
.left = paddingW()}; .left = (size_t)paddingW()};
nnp_size kernelSize = {.width = filterWidth, .height = filterHeight}; nnp_size kernelSize = {.width = filterWidth, .height = filterHeight};
nnp_size outputSubsampling = {.width = strideW(), .height = strideH()}; nnp_size outputSubsampling = {.width = (size_t)strideW(),
.height = (size_t)strideH()};
float* inputData = inputs[0].data<float>(); float* inputData = inputs[0].data<float>();
float* filterData = inputs[1].data<float>(); float* filterData = inputs[1].data<float>();
...@@ -129,7 +138,8 @@ public: ...@@ -129,7 +138,8 @@ public:
CHECK_EQ(status, nnp_status_success); CHECK_EQ(status, nnp_status_success);
} else { } else {
// only supports stride = 1 // only supports stride = 1
CHECK_EQ(stride_, 1); CHECK_EQ(strideH(), 1);
CHECK_EQ(strideW(), 1);
nnp_status status = nnp_convolution_output(algorithm_, nnp_status status = nnp_convolution_output(algorithm_,
batchSize, batchSize,
inputChannels, inputChannels,
...@@ -189,7 +199,8 @@ public: ...@@ -189,7 +199,8 @@ public:
CHECK_EQ(status, nnp_status_success); CHECK_EQ(status, nnp_status_success);
} else { } else {
// only supports stride = 1 // only supports stride = 1
CHECK_EQ(stride_, 1); CHECK_EQ(strideH(), 1);
CHECK_EQ(strideW(), 1);
nnp_status status = nnp_convolution_output(algorithm_, nnp_status status = nnp_convolution_output(algorithm_,
batchSize, batchSize,
inputChannels, inputChannels,
......
# Locate a pre-built NNPACK installation.
#
# Input:
#   NNPACK_ROOT     - folder that contains NNPACK (with include/ and lib/
#                     sub-folders); the cache entry defaults to the
#                     NNPACK_ROOT environment variable.
# Output:
#   NNPACK_INC_DIR  - directory containing nnpack.h
#   NNPACK_LIB      - the nnpack library
#   PTHREADPOOL_LIB - the pthreadpool library
#   NNPACK_FOUND    - ON once all three of the above were located
#
# Configuration stops with a fatal error if any of the three is missing.
set(NNPACK_FOUND OFF)
set(NNPACK_ROOT $ENV{NNPACK_ROOT} CACHE PATH "Folder contains NNPACK")

find_path(NNPACK_INC_DIR
  nnpack.h
  PATHS ${NNPACK_ROOT}/include)
find_library(NNPACK_LIB
  NAMES nnpack
  PATHS ${NNPACK_ROOT}/lib)
find_library(PTHREADPOOL_LIB
  NAMES pthreadpool
  PATHS ${NNPACK_ROOT}/lib)

# Guard clause: bail out as soon as any required piece is missing.
if(NOT NNPACK_INC_DIR OR NOT NNPACK_LIB OR NOT PTHREADPOOL_LIB)
  message(FATAL_ERROR "Cannot find NNPACK in (${NNPACK_ROOT})")
endif()

set(NNPACK_FOUND ON)
include_directories(${NNPACK_INC_DIR})
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册