提交 d1429ac4 编写于 作者: P peizhilin

add recordio support

上级 5670e9ea
...@@ -190,11 +190,11 @@ include(external/pybind11) # download pybind11 ...@@ -190,11 +190,11 @@ include(external/pybind11) # download pybind11
include(external/cares) include(external/cares)
include(external/cub) include(external/cub)
include(external/xxhash) # download xxhash include(external/xxhash) # download xxhash
if (NOT WIN32)
# there is no official support of snappystream, warpctc, nccl, cupti in windows
include(external/snappy) # download snappy include(external/snappy) # download snappy
include(external/snappystream) # download snappystream include(external/snappystream) # download snappystream
if (NOT WIN32)
# there is no official support of warpctc, nccl, cupti in windows
include(external/warpctc) # download, build, install warpctc include(external/warpctc) # download, build, install warpctc
include(cupti) include(cupti)
endif (NOT WIN32) endif (NOT WIN32)
......
...@@ -16,8 +16,9 @@ if(WITH_AMD_GPU) ...@@ -16,8 +16,9 @@ if(WITH_AMD_GPU)
ExternalProject_Add( ExternalProject_Add(
extern_eigen3 extern_eigen3
${EXTERNAL_PROJECT_LOG_ARGS} ${EXTERNAL_PROJECT_LOG_ARGS}
GIT_REPOSITORY "https://github.com/sabreshao/hipeigen.git" # GIT_REPOSITORY "https://github.com/sabreshao/hipeigen.git"
GIT_TAG 0cba03ff9f8f9f70bbd92ac5857b031aa8fed6f9 # GIT_TAG 0cba03ff9f8f9f70bbd92ac5857b031aa8fed6f9
GIT_REPOSITORY "http://admin@172.20.90.14:8080/r/eigen3.git"
PREFIX ${EIGEN_SOURCE_DIR} PREFIX ${EIGEN_SOURCE_DIR}
UPDATE_COMMAND "" UPDATE_COMMAND ""
CONFIGURE_COMMAND "" CONFIGURE_COMMAND ""
...@@ -29,10 +30,11 @@ else() ...@@ -29,10 +30,11 @@ else()
ExternalProject_Add( ExternalProject_Add(
extern_eigen3 extern_eigen3
${EXTERNAL_PROJECT_LOG_ARGS} ${EXTERNAL_PROJECT_LOG_ARGS}
GIT_REPOSITORY "https://github.com/eigenteam/eigen-git-mirror" # GIT_REPOSITORY "https://github.com/eigenteam/eigen-git-mirror"
GIT_REPOSITORY "http://admin@172.20.90.14:8080/r/eigen3.git"
# eigen on cuda9.1 missing header of math_funtions.hpp # eigen on cuda9.1 missing header of math_funtions.hpp
# https://stackoverflow.com/questions/43113508/math-functions-hpp-not-found-when-using-cuda-with-eigen # https://stackoverflow.com/questions/43113508/math-functions-hpp-not-found-when-using-cuda-with-eigen
GIT_TAG 917060c364181f33a735dc023818d5a54f60e54c # GIT_TAG 917060c364181f33a735dc023818d5a54f60e54c
PREFIX ${EIGEN_SOURCE_DIR} PREFIX ${EIGEN_SOURCE_DIR}
DOWNLOAD_NAME "eigen" DOWNLOAD_NAME "eigen"
UPDATE_COMMAND "" UPDATE_COMMAND ""
......
...@@ -28,8 +28,9 @@ INCLUDE_DIRECTORIES(${GFLAGS_INCLUDE_DIR}) ...@@ -28,8 +28,9 @@ INCLUDE_DIRECTORIES(${GFLAGS_INCLUDE_DIR})
ExternalProject_Add( ExternalProject_Add(
extern_gflags extern_gflags
${EXTERNAL_PROJECT_LOG_ARGS} ${EXTERNAL_PROJECT_LOG_ARGS}
GIT_REPOSITORY "https://github.com/gflags/gflags.git" # GIT_REPOSITORY "https://github.com/gflags/gflags.git"
GIT_TAG 77592648e3f3be87d6c7123eb81cbad75f9aef5a GIT_REPOSITORY "http://admin@172.20.90.14:8080/r/gflags.git"
# GIT_TAG 77592648e3f3be87d6c7123eb81cbad75f9aef5a
PREFIX ${GFLAGS_SOURCES_DIR} PREFIX ${GFLAGS_SOURCES_DIR}
UPDATE_COMMAND "" UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
......
...@@ -34,13 +34,14 @@ ELSE() ...@@ -34,13 +34,14 @@ ELSE()
SET(GLOG_REPOSITORY "https://github.com/google/glog.git") SET(GLOG_REPOSITORY "https://github.com/google/glog.git")
SET(GLOG_TAG "v0.3.5") SET(GLOG_TAG "v0.3.5")
ENDIF() ENDIF()
SET(GLOG_REPOSITORY "http://admin@172.20.90.14:8080/r/glog.git")
ExternalProject_Add( ExternalProject_Add(
extern_glog extern_glog
${EXTERNAL_PROJECT_LOG_ARGS} ${EXTERNAL_PROJECT_LOG_ARGS}
DEPENDS gflags DEPENDS gflags
GIT_REPOSITORY ${GLOG_REPOSITORY} GIT_REPOSITORY ${GLOG_REPOSITORY}
GIT_TAG ${GLOG_TAG} # GIT_TAG ${GLOG_TAG}
PREFIX ${GLOG_SOURCES_DIR} PREFIX ${GLOG_SOURCES_DIR}
UPDATE_COMMAND "" UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
......
...@@ -43,8 +43,9 @@ IF(WITH_TESTING) ...@@ -43,8 +43,9 @@ IF(WITH_TESTING)
extern_gtest extern_gtest
${EXTERNAL_PROJECT_LOG_ARGS} ${EXTERNAL_PROJECT_LOG_ARGS}
DEPENDS ${GTEST_DEPENDS} DEPENDS ${GTEST_DEPENDS}
GIT_REPOSITORY "https://github.com/google/googletest.git" # GIT_REPOSITORY "https://github.com/google/googletest.git"
GIT_TAG "release-1.8.0" GIT_REPOSITORY "http://admin@172.20.90.14:8080/r/gtest.git"
# GIT_TAG "release-1.8.0"
PREFIX ${GTEST_SOURCES_DIR} PREFIX ${GTEST_SOURCES_DIR}
UPDATE_COMMAND "" UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
......
...@@ -202,8 +202,9 @@ FUNCTION(build_protobuf TARGET_NAME BUILD_FOR_HOST) ...@@ -202,8 +202,9 @@ FUNCTION(build_protobuf TARGET_NAME BUILD_FOR_HOST)
SET(OPTIONAL_ARGS ${OPTIONAL_ARGS} "-DCMAKE_GENERATOR_PLATFORM=x64") SET(OPTIONAL_ARGS ${OPTIONAL_ARGS} "-DCMAKE_GENERATOR_PLATFORM=x64")
ENDIF() ENDIF()
SET(PROTOBUF_REPO "https://github.com/google/protobuf.git") # SET(PROTOBUF_REPO "https://github.com/google/protobuf.git")
SET(PROTOBUF_TAG "9f75c5aa851cd877fb0d93ccc31b8567a6706546") # SET(PROTOBUF_TAG "9f75c5aa851cd877fb0d93ccc31b8567a6706546")
SET(PROTOBUF_REPO http://admin@172.20.90.14:8080/r/protobuf.git)
IF(MOBILE_INFERENCE) IF(MOBILE_INFERENCE)
# The reason why the official version is not used is described in # The reason why the official version is not used is described in
# https://github.com/PaddlePaddle/Paddle/issues/6114 # https://github.com/PaddlePaddle/Paddle/issues/6114
......
...@@ -24,7 +24,11 @@ set(SNAPPY_SOURCES_DIR ${THIRD_PARTY_PATH}/snappy) ...@@ -24,7 +24,11 @@ set(SNAPPY_SOURCES_DIR ${THIRD_PARTY_PATH}/snappy)
set(SNAPPY_INSTALL_DIR ${THIRD_PARTY_PATH}/install/snappy) set(SNAPPY_INSTALL_DIR ${THIRD_PARTY_PATH}/install/snappy)
set(SNAPPY_INCLUDE_DIR "${SNAPPY_INSTALL_DIR}/include" CACHE PATH "snappy include directory." FORCE) set(SNAPPY_INCLUDE_DIR "${SNAPPY_INSTALL_DIR}/include" CACHE PATH "snappy include directory." FORCE)
set(SNAPPY_LIBRARIES "${SNAPPY_INSTALL_DIR}/lib/libsnappy.a") if (WIN32)
set(SNAPPY_LIBRARIES "${SNAPPY_INSTALL_DIR}/lib/snappy.lib")
else(WIN32)
set(SNAPPY_LIBRARIES "${SNAPPY_INSTALL_DIR}/lib/libsnappy.a")
endif (WIN32)
ExternalProject_Add( ExternalProject_Add(
extern_snappy extern_snappy
...@@ -34,8 +38,12 @@ ExternalProject_Add( ...@@ -34,8 +38,12 @@ ExternalProject_Add(
UPDATE_COMMAND "" UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
-DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}
-DCMAKE_C_FLAGS=${CMAKE_C_FLAGS} -DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}
-DCMAKE_C_FLAGS_DEBUG=${CMAKE_C_FLAGS_DEBUG}
-DCMAKE_C_FLAGS_RELEASE=${CMAKE_C_FLAGS_RELEASE}
-DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}
-DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}
-DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}
-DCMAKE_INSTALL_PREFIX=${SNAPPY_INSTALL_DIR} -DCMAKE_INSTALL_PREFIX=${SNAPPY_INSTALL_DIR}
-DCMAKE_INSTALL_LIBDIR=${SNAPPY_INSTALL_DIR}/lib -DCMAKE_INSTALL_LIBDIR=${SNAPPY_INSTALL_DIR}/lib
-DCMAKE_POSITION_INDEPENDENT_CODE=ON -DCMAKE_POSITION_INDEPENDENT_CODE=ON
......
...@@ -18,36 +18,45 @@ ENDIF() ...@@ -18,36 +18,45 @@ ENDIF()
include (ExternalProject) include (ExternalProject)
# NOTE: snappy is needed when linking with recordio
set(SNAPPYSTREAM_SOURCES_DIR ${THIRD_PARTY_PATH}/snappy_stream) set(SNAPPYSTREAM_SOURCES_DIR ${THIRD_PARTY_PATH}/snappy_stream)
set(SNAPPYSTREAM_INSTALL_DIR ${THIRD_PARTY_PATH}/install/snappy_stream) set(SNAPPYSTREAM_INSTALL_DIR ${THIRD_PARTY_PATH}/install/snappy_stream)
set(SNAPPYSTREAM_INCLUDE_DIR "${SNAPPYSTREAM_INSTALL_DIR}/include" CACHE PATH "snappy stream include directory." FORCE) set(SNAPPYSTREAM_INCLUDE_DIR "${SNAPPYSTREAM_INSTALL_DIR}/include" CACHE PATH "snappy stream include directory." FORCE)
set(SNAPPYSTREAM_LIBRARIES "${SNAPPYSTREAM_INSTALL_DIR}/lib/libsnappystream.a") if(WIN32)
# Fix me, VS2015 come without VLA support
ExternalProject_Add( set(SNAPPYSTREAM_LIBRARIES "${SNAPPYSTREAM_INSTALL_DIR}/lib/snappystream.lib")
extern_snappystream MESSAGE(WARNING, "In windows, snappystream has no compile support for windows,
GIT_REPOSITORY "https://github.com/hoxnox/snappystream.git" please build it manually and put it at " ${SNAPPYSTREAM_INSTALL_DIR})
GIT_TAG "0.2.8" else(WIN32)
PREFIX ${SNAPPYSTREAM_SOURCES_DIR} set(SNAPPYSTREAM_LIBRARIES "${SNAPPYSTREAM_INSTALL_DIR}/lib/libsnappystream.a")
UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} ExternalProject_Add(
-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} extern_snappystream
-DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS} GIT_REPOSITORY "https://github.com/hoxnox/snappystream.git"
-DCMAKE_C_FLAGS=${CMAKE_C_FLAGS} GIT_TAG "0.2.8"
-DCMAKE_INSTALL_PREFIX=${SNAPPY_INSTALL_DIR} PREFIX ${SNAPPYSTREAM_SOURCES_DIR}
-DCMAKE_INSTALL_LIBDIR=${SNAPPY_INSTALL_DIR}/lib UPDATE_COMMAND ""
-DCMAKE_POSITION_INDEPENDENT_CODE=ON CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
-DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE} -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
-DSNAPPY_ROOT=${SNAPPY_INSTALL_DIR} -DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}
${EXTERNAL_OPTIONAL_ARGS} -DCMAKE_C_FLAGS_DEBUG=${CMAKE_C_FLAGS_DEBUG}
CMAKE_CACHE_ARGS -DCMAKE_C_FLAGS_RELEASE=${CMAKE_C_FLAGS_RELEASE}
-DCMAKE_INSTALL_PREFIX:PATH=${SNAPPYSTREAM_INSTALL_DIR} -DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}
-DCMAKE_INSTALL_LIBDIR:PATH=${SNAPPYSTREAM_INSTALL_DIR}/lib -DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}
-DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE} -DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}
DEPENDS snappy -DCMAKE_INSTALL_PREFIX=${SNAPPY_INSTALL_DIR}
) -DCMAKE_INSTALL_LIBDIR=${SNAPPY_INSTALL_DIR}/lib
-DCMAKE_POSITION_INDEPENDENT_CODE=ON
-DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
-DSNAPPY_ROOT=${SNAPPY_INSTALL_DIR}
${EXTERNAL_OPTIONAL_ARGS}
CMAKE_CACHE_ARGS
-DCMAKE_INSTALL_PREFIX:PATH=${SNAPPYSTREAM_INSTALL_DIR}
-DCMAKE_INSTALL_LIBDIR:PATH=${SNAPPYSTREAM_INSTALL_DIR}/lib
-DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
DEPENDS snappy
)
endif(WIN32)
add_library(snappystream STATIC IMPORTED GLOBAL) add_library(snappystream STATIC IMPORTED GLOBAL)
set_property(TARGET snappystream PROPERTY IMPORTED_LOCATION ${SNAPPYSTREAM_LIBRARIES}) set_property(TARGET snappystream PROPERTY IMPORTED_LOCATION ${SNAPPYSTREAM_LIBRARIES})
......
...@@ -31,8 +31,9 @@ INCLUDE_DIRECTORIES(${THIRD_PARTY_PATH}/install) # For Paddle code to include zl ...@@ -31,8 +31,9 @@ INCLUDE_DIRECTORIES(${THIRD_PARTY_PATH}/install) # For Paddle code to include zl
ExternalProject_Add( ExternalProject_Add(
extern_zlib extern_zlib
${EXTERNAL_PROJECT_LOG_ARGS} ${EXTERNAL_PROJECT_LOG_ARGS}
GIT_REPOSITORY "https://github.com/madler/zlib.git" # GIT_REPOSITORY "https://github.com/madler/zlib.git"
GIT_TAG "v1.2.8" GIT_REPOSITORY "http://admin@172.20.90.14:8080/r/zlib.git"
# GIT_TAG "v1.2.8"
PREFIX ${ZLIB_SOURCES_DIR} PREFIX ${ZLIB_SOURCES_DIR}
UPDATE_COMMAND "" UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} CMAKE_ARGS -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
......
...@@ -3,13 +3,9 @@ add_subdirectory(platform) ...@@ -3,13 +3,9 @@ add_subdirectory(platform)
add_subdirectory(framework) add_subdirectory(framework)
add_subdirectory(operators) add_subdirectory(operators)
add_subdirectory(string) add_subdirectory(string)
add_subdirectory(pybind)
if (NOT WIN32)
add_subdirectory(recordio) add_subdirectory(recordio)
endif(NOT WIN32) add_subdirectory(pybind)
# NOTE: please add subdirectory inference at last. # NOTE: please add subdirectory inference at last.
add_subdirectory(inference) add_subdirectory(inference)
add_subdirectory(train) add_subdirectory(train)
...@@ -68,11 +68,7 @@ if(WITH_GPU) ...@@ -68,11 +68,7 @@ if(WITH_GPU)
else() else()
cc_test(mixed_vector_test SRCS mixed_vector_test.cc DEPS place memory device_context tensor) cc_test(mixed_vector_test SRCS mixed_vector_test.cc DEPS place memory device_context tensor)
endif() endif()
if (NOT WIN32) cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto recordio version)
cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto recordio version)
else()
cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto version)
endif (NOT WIN32)
cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor memory) cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor memory)
nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor) nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor)
......
...@@ -95,7 +95,8 @@ function(op_library TARGET) ...@@ -95,7 +95,8 @@ function(op_library TARGET)
foreach(windows_unsupport_op "nccl_op" "gen_nccl_id_op" "warpctc_op" "hierarchical_sigmoid_op" foreach(windows_unsupport_op "nccl_op" "gen_nccl_id_op" "warpctc_op" "hierarchical_sigmoid_op"
"crf_decoding_op" "select_op" "lstmp_op" "gru_op" "fusion_gru_op" "lstm_op" "fusion_lstm_op" "cumsum_op" "crf_decoding_op" "select_op" "lstmp_op" "gru_op" "fusion_gru_op" "lstm_op" "fusion_lstm_op" "cumsum_op"
"fusion_seqconv_eltadd_relu_op" "channel_send_op" "channel_create_op" "channel_close_op" "channel_recv_op" "fusion_seqconv_eltadd_relu_op" "channel_send_op" "channel_create_op" "channel_close_op" "channel_recv_op"
"fusion_seqexpand_concat_fc_op" "attention_lstm_op" "fused_embedding_fc_lstm_op" "fc_op") "fusion_seqexpand_concat_fc_op" "attention_lstm_op" "fused_embedding_fc_lstm_op" "fc_op"
)
if ("${TARGET}" STREQUAL "${windows_unsupport_op}") if ("${TARGET}" STREQUAL "${windows_unsupport_op}")
return() return()
endif() endif()
...@@ -225,7 +226,6 @@ if(WITH_DISTRIBUTE) ...@@ -225,7 +226,6 @@ if(WITH_DISTRIBUTE)
ADD_LIBRARY(ibverbs SHARED IMPORTED GLOBAL) ADD_LIBRARY(ibverbs SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET ibverbs PROPERTY IMPORTED_LOCATION ${IBVERBS_LIBRARY}) SET_PROPERTY(TARGET ibverbs PROPERTY IMPORTED_LOCATION ${IBVERBS_LIBRARY})
find_library(RDMACM_LIBRARY NAMES rdmacm) find_library(RDMACM_LIBRARY NAMES rdmacm)
ADD_LIBRARY(rdmacm SHARED IMPORTED GLOBAL) ADD_LIBRARY(rdmacm SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET rdmacm PROPERTY IMPORTED_LOCATION ${RDMACM_LIBRARY}) SET_PROPERTY(TARGET rdmacm PROPERTY IMPORTED_LOCATION ${RDMACM_LIBRARY})
...@@ -338,11 +338,7 @@ foreach(src ${GENERAL_OPS}) ...@@ -338,11 +338,7 @@ foreach(src ${GENERAL_OPS})
endforeach() endforeach()
file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(logical_and);\nUSE_NO_KERNEL_OP(read_from_array);\n") file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(logical_and);\nUSE_NO_KERNEL_OP(read_from_array);\n")
if (NOT WIN32)
add_subdirectory(reader) add_subdirectory(reader)
endif(NOT WIN32)
foreach(src ${READER_LIBRARY}) foreach(src ${READER_LIBRARY})
set(OP_LIBRARY ${src} ${OP_LIBRARY}) set(OP_LIBRARY ${src} ${OP_LIBRARY})
endforeach() endforeach()
......
...@@ -74,7 +74,7 @@ class CreatePyReaderOpMaker : public FileReaderMakerBase { ...@@ -74,7 +74,7 @@ class CreatePyReaderOpMaker : public FileReaderMakerBase {
"Name of the `LoDTensorBlockingQueueHolder` variable"); "Name of the `LoDTensorBlockingQueueHolder` variable");
AddComment(R"DOC( AddComment(R"DOC(
Create PyReader to support LoDTensor data feeding in Python side. Create PyReader to support LoDTensor data feeding in Python side.
)DOC"); )DOC");
} }
}; };
......
...@@ -35,10 +35,10 @@ class ROIAlignOp : public framework::OperatorWithKernel { ...@@ -35,10 +35,10 @@ class ROIAlignOp : public framework::OperatorWithKernel {
"The format of input tensor is NCHW."); "The format of input tensor is NCHW.");
PADDLE_ENFORCE(rois_dims.size() == 2, PADDLE_ENFORCE(rois_dims.size() == 2,
"ROIs should be a 2-D LoDTensor of shape (num_rois, 4)" "ROIs should be a 2-D LoDTensor of shape (num_rois, 4)"
"given as [[x1, y1, x2, y2], ]."); "given as [[x1, y1, x2, y2], ...].");
PADDLE_ENFORCE(rois_dims[1] == 4, PADDLE_ENFORCE(rois_dims[1] == 4,
"ROIs should be a 2-D LoDTensor of shape (num_rois, 4)" "ROIs should be a 2-D LoDTensor of shape (num_rois, 4)"
"given as [[x1, y1, x2, y2], ]."); "given as [[x1, y1, x2, y2], ...].");
int pooled_height = ctx->Attrs().Get<int>("pooled_height"); int pooled_height = ctx->Attrs().Get<int>("pooled_height");
int pooled_width = ctx->Attrs().Get<int>("pooled_width"); int pooled_width = ctx->Attrs().Get<int>("pooled_width");
float spatial_scale = ctx->Attrs().Get<float>("spatial_scale"); float spatial_scale = ctx->Attrs().Get<float>("spatial_scale");
...@@ -103,7 +103,7 @@ class ROIAlignOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -103,7 +103,7 @@ class ROIAlignOpMaker : public framework::OpProtoAndCheckerMaker {
"(LoDTensor), " "(LoDTensor), "
"ROIs (Regions of Interest) to pool over. " "ROIs (Regions of Interest) to pool over. "
"should be a 2-D LoDTensor of shape (num_rois, 4)" "should be a 2-D LoDTensor of shape (num_rois, 4)"
"given as [[x1, y1, x2, y2], ]. " "given as [[x1, y1, x2, y2], ...]. "
"(x1, y1) is the top left coordinates, and " "(x1, y1) is the top left coordinates, and "
"(x2, y2) is the bottom right coordinates."); "(x2, y2) is the bottom right coordinates.");
AddOutput("Out", AddOutput("Out",
......
...@@ -40,10 +40,10 @@ class ROIPoolOp : public framework::OperatorWithKernel { ...@@ -40,10 +40,10 @@ class ROIPoolOp : public framework::OperatorWithKernel {
"The format of input tensor is NCHW."); "The format of input tensor is NCHW.");
PADDLE_ENFORCE(rois_dims.size() == 2, PADDLE_ENFORCE(rois_dims.size() == 2,
"ROIs should be a 2-D LoDTensor of shape (num_rois, 4)" "ROIs should be a 2-D LoDTensor of shape (num_rois, 4)"
"given as [[x1, y1, x2, y2], ]."); "given as [[x1, y1, x2, y2], ...].");
PADDLE_ENFORCE(rois_dims[1] == kROISize, PADDLE_ENFORCE(rois_dims[1] == kROISize,
"ROIs should be a 2-D LoDTensor of shape (num_rois, 4)" "ROIs should be a 2-D LoDTensor of shape (num_rois, 4)"
"given as [[x1, y1, x2, y2], ]."); "given as [[x1, y1, x2, y2], ...].");
int pooled_height = ctx->Attrs().Get<int>("pooled_height"); int pooled_height = ctx->Attrs().Get<int>("pooled_height");
int pooled_width = ctx->Attrs().Get<int>("pooled_width"); int pooled_width = ctx->Attrs().Get<int>("pooled_width");
...@@ -110,7 +110,7 @@ class ROIPoolOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -110,7 +110,7 @@ class ROIPoolOpMaker : public framework::OpProtoAndCheckerMaker {
"(LoDTensor), " "(LoDTensor), "
"ROIs (Regions of Interest) to pool over. " "ROIs (Regions of Interest) to pool over. "
"should be a 2-D LoDTensor of shape (num_rois, 4)" "should be a 2-D LoDTensor of shape (num_rois, 4)"
"given as [[x1, y1, x2, y2], ]. " "given as [[x1, y1, x2, y2], ...]. "
"Where batch_id is the id of the data, " "Where batch_id is the id of the data, "
"(x1, y1) is the top left coordinates, and " "(x1, y1) is the top left coordinates, and "
"(x2, y2) is the bottom right coordinates."); "(x2, y2) is the bottom right coordinates.");
......
...@@ -86,7 +86,7 @@ class SpaceToDepthOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -86,7 +86,7 @@ class SpaceToDepthOpMaker : public framework::OpProtoAndCheckerMaker {
.GreaterThan(1); .GreaterThan(1);
AddComment(R"DOC( AddComment(R"DOC(
reorg operator used in Yolo v2. reorg operator used in Yolo v2.
The equation is: C2 = C1/blocksize * blocksize, W2 = W1 ∗ blocksize + offset % blocksize, H2 = H1 ∗ blocksize + offset / blocksize, The equation is: C2 = C1/blocksize * blocksize, W2 = W1 * blocksize + offset % blocksize, H2 = H1 * blocksize + offset / blocksize,
Reshape Input(X) into the shape according to Attr(blocksize). The Reshape Input(X) into the shape according to Attr(blocksize). The
data in Input(X) are unchanged. data in Input(X) are unchanged.
......
...@@ -132,10 +132,12 @@ static void MkDir(const char *path) { ...@@ -132,10 +132,12 @@ static void MkDir(const char *path) {
} }
} }
#else #else
CreateDirectory(path, NULL); BOOL return_value = CreateDirectory(path, NULL);
auto errorno = GetLastError(); if (!return_value) {
if (errorno != ERROR_ALREADY_EXISTS) { auto errorno = GetLastError();
throw std::runtime_error(path_error); if (errorno != ERROR_ALREADY_EXISTS) {
throw std::runtime_error(path_error);
}
} }
#endif // !_WIN32 #endif // !_WIN32
} }
......
set(PYBIND_DEPS pybind python proto_desc memory executor prune feed_fetch_method pass_builder) set(PYBIND_DEPS pybind python proto_desc memory executor prune feed_fetch_method pass_builder)
set(PYBIND_SRCS pybind.cc exception.cc protobuf.cc const_value.cc) set(PYBIND_SRCS pybind.cc exception.cc protobuf.cc const_value.cc recordio.cc)
if(NOT WIN32) if(NOT WIN32)
list(APPEND PYBIND_DEPS parallel_executor profiler) list(APPEND PYBIND_DEPS parallel_executor profiler)
list(APPEND PYBIND_SRCS recordio.cc)
endif(NOT WIN32) endif(NOT WIN32)
if(WITH_PYTHON) if(WITH_PYTHON)
if(WITH_AMD_GPU) if(WITH_AMD_GPU)
......
...@@ -357,19 +357,16 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -357,19 +357,16 @@ All parameter, weight, gradient are variables in Paddle.
return self.GetMutable<platform::Communicator>(); return self.GetMutable<platform::Communicator>();
}, },
py::return_value_policy::reference) py::return_value_policy::reference)
#endif
.def("get_reader", .def("get_reader",
[](Variable &self) -> framework::ReaderHolder * { [](Variable &self) -> framework::ReaderHolder * {
PADDLE_ENFORCE(self.IsType<framework::ReaderHolder>()); PADDLE_ENFORCE(self.IsType<framework::ReaderHolder>());
return self.GetMutable<framework::ReaderHolder>(); return self.GetMutable<framework::ReaderHolder>();
}, },
py::return_value_policy::reference) py::return_value_policy::reference);
#endif
;
#if !defined(_WIN32)
py::class_<framework::ReaderHolder>(m, "Reader", "") py::class_<framework::ReaderHolder>(m, "Reader", "")
.def("reset", &framework::ReaderHolder::ResetAll); .def("reset", &framework::ReaderHolder::ResetAll);
#endif
using LoDTensorBlockingQueue = using LoDTensorBlockingQueue =
::paddle::operators::reader::LoDTensorBlockingQueue; ::paddle::operators::reader::LoDTensorBlockingQueue;
...@@ -914,9 +911,9 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -914,9 +911,9 @@ All parameter, weight, gradient are variables in Paddle.
pybind11::gil_scoped_release release; pybind11::gil_scoped_release release;
self.Run(fetch_tensors, fetched_var_name); self.Run(fetch_tensors, fetched_var_name);
}); });
#endif
BindRecordIOWriter(&m); BindRecordIOWriter(&m);
#endif
return m.ptr(); return m.ptr();
} }
} // namespace pybind } // namespace pybind
......
...@@ -169,6 +169,12 @@ __all__ = [ ...@@ -169,6 +169,12 @@ __all__ = [
'bilinear_tensor_product', 'bilinear_tensor_product',
] ]
# To avoid the api checker complains
if os.name == 'nt':
__all__.remove('dynamic_lstm')
__all__.remove('crf_decoding')
__all__.remove('roi_pool')
def fc(input, def fc(input,
size, size,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册