提交 4ffa92d4 编写于 作者: P peizhilin

Merge branch 'develop' into windows/build

...@@ -38,7 +38,6 @@ if(NOT CMAKE_CROSSCOMPILING) ...@@ -38,7 +38,6 @@ if(NOT CMAKE_CROSSCOMPILING)
endif(NOT CMAKE_CROSSCOMPILING) endif(NOT CMAKE_CROSSCOMPILING)
find_package(Git REQUIRED) find_package(Git REQUIRED)
find_package(Threads REQUIRED) find_package(Threads REQUIRED)
include(simd) include(simd)
################################ Configurations ####################################### ################################ Configurations #######################################
......
...@@ -172,18 +172,21 @@ set(CUDA_PROPAGATE_HOST_FLAGS OFF) ...@@ -172,18 +172,21 @@ set(CUDA_PROPAGATE_HOST_FLAGS OFF)
# Release/Debug flags set by cmake. Such as -O3 -g -DNDEBUG etc. # Release/Debug flags set by cmake. Such as -O3 -g -DNDEBUG etc.
# So, don't set these flags here. # So, don't set these flags here.
if (NOT WIN32) # windows msvc2015 support c++11 natively. if (NOT WIN32) # windows msvc2015 support c++11 natively.
# -std=c++11 -fPIC not recoginize by msvc, -Xcompiler will be added by cmake. # -std=c++11 -fPIC not recoginize by msvc
list(APPEND CUDA_NVCC_FLAGS "-std=c++11") list(APPEND CUDA_NVCC_FLAGS "-std=c++11")
list(APPEND CUDA_NVCC_FLAGS "-Xcompiler -fPIC") # in cuda9, suppress cuda warning on eigen with "-w"
list(APPEND CUDA_NVCC_FLAGS "-w" "-Xcompiler -fPIC")
else(NOT WIN32)
list(APPEND CUDA_NVCC_FLAGS "-w" "-Xcompiler -fPIC" "-Xcompiler /w")
endif(NOT WIN32) endif(NOT WIN32)
if(WITH_FAST_MATH) if(WITH_FAST_MATH)
# Make use of fast math library. https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html # Make use of fast math library. https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html
list(APPEND CUDA_NVCC_FLAGS "--use_fast_math") list(APPEND CUDA_NVCC_FLAGS "--use_fast_math")
endif() endif(WITH_FAST_MATH)
# in cuda9, suppress cuda warning on eigen
list(APPEND CUDA_NVCC_FLAGS "-w")
# Set :expt-relaxed-constexpr to suppress Eigen warnings # Set :expt-relaxed-constexpr to suppress Eigen warnings
list(APPEND CUDA_NVCC_FLAGS "--expt-relaxed-constexpr") list(APPEND CUDA_NVCC_FLAGS "--expt-relaxed-constexpr")
......
...@@ -53,7 +53,6 @@ find_library(CUDNN_LIBRARY NAMES ${CUDNN_LIB_NAME} # libcudnn_static.a ...@@ -53,7 +53,6 @@ find_library(CUDNN_LIBRARY NAMES ${CUDNN_LIB_NAME} # libcudnn_static.a
NO_DEFAULT_PATH NO_DEFAULT_PATH
DOC "Path to cuDNN library.") DOC "Path to cuDNN library.")
if(CUDNN_INCLUDE_DIR AND CUDNN_LIBRARY) if(CUDNN_INCLUDE_DIR AND CUDNN_LIBRARY)
set(CUDNN_FOUND ON) set(CUDNN_FOUND ON)
else() else()
...@@ -88,7 +87,7 @@ if(CUDNN_FOUND) ...@@ -88,7 +87,7 @@ if(CUDNN_FOUND)
if(NOT CUDNN_MAJOR_VERSION) if(NOT CUDNN_MAJOR_VERSION)
set(CUDNN_VERSION "???") set(CUDNN_VERSION "???")
else() else()
math(EXPR CUDNN_VERSION math(EXPR CUDNN_VERSION
"${CUDNN_MAJOR_VERSION} * 1000 + "${CUDNN_MAJOR_VERSION} * 1000 +
${CUDNN_MINOR_VERSION} * 100 + ${CUDNN_PATCHLEVEL_VERSION}") ${CUDNN_MINOR_VERSION} * 100 + ${CUDNN_PATCHLEVEL_VERSION}")
......
...@@ -33,42 +33,23 @@ MESSAGE(STATUS "BOOST_TAR: ${BOOST_TAR}, BOOST_URL: ${BOOST_URL}") ...@@ -33,42 +33,23 @@ MESSAGE(STATUS "BOOST_TAR: ${BOOST_TAR}, BOOST_URL: ${BOOST_URL}")
set(BOOST_SOURCES_DIR ${THIRD_PARTY_PATH}/boost) set(BOOST_SOURCES_DIR ${THIRD_PARTY_PATH}/boost)
set(BOOST_DOWNLOAD_DIR "${BOOST_SOURCES_DIR}/src/${BOOST_PROJECT}") set(BOOST_DOWNLOAD_DIR "${BOOST_SOURCES_DIR}/src/${BOOST_PROJECT}")
if (WIN32)
set(BOOST_INCLUDE_DIR "${BOOST_DOWNLOAD_DIR}" CACHE PATH "boost include directory." FORCE)
else(WIN32)
set(BOOST_INCLUDE_DIR "${BOOST_DOWNLOAD_DIR}/${BOOST_TAR}" CACHE PATH "boost include directory." FORCE)
endif (WIN32)
set_directory_properties(PROPERTIES CLEAN_NO_CUSTOM 1)
set(BOOST_INCLUDE_DIR "${BOOST_DOWNLOAD_DIR}" CACHE PATH "boost include directory." FORCE)
set_directory_properties(PROPERTIES CLEAN_NO_CUSTOM 1)
include_directories(${BOOST_INCLUDE_DIR}) include_directories(${BOOST_INCLUDE_DIR})
if (WIN32)
ExternalProject_Add( ExternalProject_Add(
${BOOST_PROJECT} ${BOOST_PROJECT}
${EXTERNAL_PROJECT_LOG_ARGS} ${EXTERNAL_PROJECT_LOG_ARGS}
DOWNLOAD_DIR ${BOOST_DOWNLOAD_DIR} DOWNLOAD_DIR ${BOOST_DOWNLOAD_DIR}
URL ${BOOST_URL} URL ${BOOST_URL}
DOWNLOAD_NO_PROGRESS 0 DOWNLOAD_NO_PROGRESS 0
PREFIX ${BOOST_SOURCES_DIR} PREFIX ${BOOST_SOURCES_DIR}
CONFIGURE_COMMAND "" CONFIGURE_COMMAND ""
BUILD_COMMAND "" BUILD_COMMAND ""
INSTALL_COMMAND "" INSTALL_COMMAND ""
UPDATE_COMMAND "" UPDATE_COMMAND ""
) )
else()
ExternalProject_Add(
${BOOST_PROJECT}
${EXTERNAL_PROJECT_LOG_ARGS}
DOWNLOAD_DIR ${BOOST_DOWNLOAD_DIR}
DOWNLOAD_COMMAND "wget --no-check-certificate ${BOOST_URL} -c -q -O ${BOOST_TAR}.tar.gz
&& tar zxf ${BOOST_TAR}.tar.gz"
DOWNLOAD_NO_PROGRESS 0
PREFIX ${BOOST_SOURCES_DIR}
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
INSTALL_COMMAND ""
UPDATE_COMMAND ""
)
endif ()
if (${CMAKE_VERSION} VERSION_LESS "3.3.0" OR NOT WIN32) if (${CMAKE_VERSION} VERSION_LESS "3.3.0" OR NOT WIN32)
set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/boost_dummy.c) set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/boost_dummy.c)
......
...@@ -51,6 +51,10 @@ ExternalProject_Add( ...@@ -51,6 +51,10 @@ ExternalProject_Add(
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
-DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE} -DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
) )
ADD_LIBRARY(gflags STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET gflags PROPERTY IMPORTED_LOCATION ${GFLAGS_LIBRARIES})
ADD_DEPENDENCIES(gflags extern_gflags)
IF(WIN32) IF(WIN32)
IF(NOT EXISTS "${GFLAGS_INSTALL_DIR}/lib/libgflags.lib") IF(NOT EXISTS "${GFLAGS_INSTALL_DIR}/lib/libgflags.lib")
add_custom_command(TARGET extern_gflags POST_BUILD add_custom_command(TARGET extern_gflags POST_BUILD
...@@ -58,9 +62,6 @@ IF(WIN32) ...@@ -58,9 +62,6 @@ IF(WIN32)
) )
ENDIF() ENDIF()
ENDIF(WIN32) ENDIF(WIN32)
ADD_LIBRARY(gflags STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET gflags PROPERTY IMPORTED_LOCATION ${GFLAGS_LIBRARIES})
ADD_DEPENDENCIES(gflags extern_gflags)
LIST(APPEND external_project_dependencies gflags) LIST(APPEND external_project_dependencies gflags)
......
...@@ -52,6 +52,7 @@ IF(WITH_TESTING) ...@@ -52,6 +52,7 @@ IF(WITH_TESTING)
-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
-DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS} -DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}
-DCMAKE_C_FLAGS=${CMAKE_C_FLAGS} -DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}
-DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}
-DCMAKE_INSTALL_PREFIX=${GTEST_INSTALL_DIR} -DCMAKE_INSTALL_PREFIX=${GTEST_INSTALL_DIR}
-DCMAKE_POSITION_INDEPENDENT_CODE=ON -DCMAKE_POSITION_INDEPENDENT_CODE=ON
-DBUILD_GMOCK=ON -DBUILD_GMOCK=ON
...@@ -71,6 +72,5 @@ IF(WITH_TESTING) ...@@ -71,6 +72,5 @@ IF(WITH_TESTING)
ADD_LIBRARY(gtest_main STATIC IMPORTED GLOBAL) ADD_LIBRARY(gtest_main STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET gtest_main PROPERTY IMPORTED_LOCATION ${GTEST_MAIN_LIBRARIES}) SET_PROPERTY(TARGET gtest_main PROPERTY IMPORTED_LOCATION ${GTEST_MAIN_LIBRARIES})
ADD_DEPENDENCIES(gtest_main extern_gtest) ADD_DEPENDENCIES(gtest_main extern_gtest)
LIST(APPEND external_project_dependencies gtest gtest_main) LIST(APPEND external_project_dependencies gtest gtest_main)
ENDIF(WITH_TESTING) ENDIF(WITH_TESTING)
...@@ -149,6 +149,7 @@ INCLUDE_DIRECTORIES(${CBLAS_INC_DIR}) ...@@ -149,6 +149,7 @@ INCLUDE_DIRECTORIES(${CBLAS_INC_DIR})
# linear algebra libraries for cc_library(xxx SRCS xxx.c DEPS cblas) # linear algebra libraries for cc_library(xxx SRCS xxx.c DEPS cblas)
SET(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/cblas_dummy.c) SET(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/cblas_dummy.c)
FILE(WRITE ${dummyfile} "const char *dummy_cblas = \"${dummyfile}\";") FILE(WRITE ${dummyfile} "const char *dummy_cblas = \"${dummyfile}\";")
ADD_LIBRARY(cblas STATIC ${dummyfile}) ADD_LIBRARY(cblas STATIC ${dummyfile})
IF("${CBLAS_PROVIDER}" STREQUAL "MKLML") IF("${CBLAS_PROVIDER}" STREQUAL "MKLML")
......
...@@ -144,11 +144,14 @@ set(GPU_COMMON_FLAGS ...@@ -144,11 +144,14 @@ set(GPU_COMMON_FLAGS
-Wno-error=unused-function # Warnings in Numpy Header. -Wno-error=unused-function # Warnings in Numpy Header.
-Wno-error=array-bounds # Warnings in Eigen::array -Wno-error=array-bounds # Warnings in Eigen::array
) )
else(NOT WIN32) else(NOT WIN32)
set(COMMON_FLAGS set(COMMON_FLAGS
-fPIC
-fno-omit-frame-pointer
"/w") #disable all warnings. "/w") #disable all warnings.
set(GPU_COMMON_FLAGS set(GPU_COMMON_FLAGS
-fPIC
-fno-omit-frame-pointer
"/w") #disable all warnings "/w") #disable all warnings
endif(NOT WIN32) endif(NOT WIN32)
...@@ -164,8 +167,8 @@ endif(APPLE) ...@@ -164,8 +167,8 @@ endif(APPLE)
if(LINUX) if(LINUX)
set(GPU_COMMON_FLAGS set(GPU_COMMON_FLAGS
-Wall -Wall
-Wextra
-Werror -Werror
-Wextra
${GPU_COMMON_FLAGS}) ${GPU_COMMON_FLAGS})
endif(LINUX) endif(LINUX)
......
...@@ -238,6 +238,7 @@ function(cc_library TARGET_NAME) ...@@ -238,6 +238,7 @@ function(cc_library TARGET_NAME)
# add libxxx.lib prefix in windows # add libxxx.lib prefix in windows
set(${TARGET_NAME}_LIB_NAME "${CMAKE_STATIC_LIBRARY_PREFIX}${TARGET_NAME}${CMAKE_STATIC_LIBRARY_SUFFIX}" CACHE STRING "output library name for target ${TARGET_NAME}") set(${TARGET_NAME}_LIB_NAME "${CMAKE_STATIC_LIBRARY_PREFIX}${TARGET_NAME}${CMAKE_STATIC_LIBRARY_SUFFIX}" CACHE STRING "output library name for target ${TARGET_NAME}")
endif(WIN32) endif(WIN32)
if(cc_library_SRCS) if(cc_library_SRCS)
if(cc_library_SHARED OR cc_library_shared) # build *.so if(cc_library_SHARED OR cc_library_shared) # build *.so
add_library(${TARGET_NAME} SHARED ${cc_library_SRCS}) add_library(${TARGET_NAME} SHARED ${cc_library_SRCS})
...@@ -350,7 +351,11 @@ function(cc_test TARGET_NAME) ...@@ -350,7 +351,11 @@ function(cc_test TARGET_NAME)
set(multiValueArgs SRCS DEPS ARGS) set(multiValueArgs SRCS DEPS ARGS)
cmake_parse_arguments(cc_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) cmake_parse_arguments(cc_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
add_executable(${TARGET_NAME} ${cc_test_SRCS}) add_executable(${TARGET_NAME} ${cc_test_SRCS})
if(WIN32) # in windows deps. shlwapi library.
target_link_libraries(${TARGET_NAME} ${cc_test_DEPS} paddle_gtest_main lod_tensor memory gtest gflags glog shlwapi)
else(WIN32)
target_link_libraries(${TARGET_NAME} ${cc_test_DEPS} paddle_gtest_main lod_tensor memory gtest gflags glog) target_link_libraries(${TARGET_NAME} ${cc_test_DEPS} paddle_gtest_main lod_tensor memory gtest gflags glog)
endif(WIN32)
add_dependencies(${TARGET_NAME} ${cc_test_DEPS} paddle_gtest_main lod_tensor memory gtest gflags glog) add_dependencies(${TARGET_NAME} ${cc_test_DEPS} paddle_gtest_main lod_tensor memory gtest gflags glog)
add_test(NAME ${TARGET_NAME} add_test(NAME ${TARGET_NAME}
COMMAND ${TARGET_NAME} ${cc_test_ARGS} COMMAND ${TARGET_NAME} ${cc_test_ARGS}
...@@ -421,7 +426,11 @@ function(nv_test TARGET_NAME) ...@@ -421,7 +426,11 @@ function(nv_test TARGET_NAME)
set(multiValueArgs SRCS DEPS) set(multiValueArgs SRCS DEPS)
cmake_parse_arguments(nv_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) cmake_parse_arguments(nv_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
cuda_add_executable(${TARGET_NAME} ${nv_test_SRCS}) cuda_add_executable(${TARGET_NAME} ${nv_test_SRCS})
if(WIN32)
target_link_libraries(${TARGET_NAME} ${nv_test_DEPS} paddle_gtest_main lod_tensor memory gtest gflags glog shlwapi)
else(WIN32)
target_link_libraries(${TARGET_NAME} ${nv_test_DEPS} paddle_gtest_main lod_tensor memory gtest gflags glog) target_link_libraries(${TARGET_NAME} ${nv_test_DEPS} paddle_gtest_main lod_tensor memory gtest gflags glog)
endif(WIN32)
add_dependencies(${TARGET_NAME} ${nv_test_DEPS} paddle_gtest_main lod_tensor memory gtest gflags glog) add_dependencies(${TARGET_NAME} ${nv_test_DEPS} paddle_gtest_main lod_tensor memory gtest gflags glog)
add_test(${TARGET_NAME} ${TARGET_NAME}) add_test(${TARGET_NAME} ${TARGET_NAME})
if (nv_test_SERIAL) if (nv_test_SERIAL)
......
...@@ -31,8 +31,7 @@ function(copy TARGET) ...@@ -31,8 +31,7 @@ function(copy TARGET)
foreach(index RANGE ${len}) foreach(index RANGE ${len})
list(GET copy_lib_SRCS ${index} src) list(GET copy_lib_SRCS ${index} src)
list(GET copy_lib_DSTS ${index} dst) list(GET copy_lib_DSTS ${index} dst)
if (WIN32)
if (WIN32)
# windows cmd shell will not expand wildcard automatically. # windows cmd shell will not expand wildcard automatically.
# below expand the files,libs and copy them by rules. # below expand the files,libs and copy them by rules.
file(GLOB header_files ${src} "*.h") file(GLOB header_files ${src} "*.h")
...@@ -47,14 +46,14 @@ function(copy TARGET) ...@@ -47,14 +46,14 @@ function(copy TARGET)
COMMAND ${CMAKE_COMMAND} -E make_directory "${dst}" COMMAND ${CMAKE_COMMAND} -E make_directory "${dst}"
) )
foreach(src_file ${src_files}) foreach(src_file ${src_files})
add_custom_command(TARGET ${TARGET} PRE_BUILD add_custom_command(TARGET ${TARGET} PRE_BUILD
COMMAND ${CMAKE_COMMAND} -E copy "${src_file}" "${dst}" COMMAND ${CMAKE_COMMAND} -E copy "${src_file}" "${dst}"
COMMENT "copying ${src_file} -> ${dst}") COMMENT "copying ${src_file} -> ${dst}")
endforeach() endforeach()
else() # not windows else(WIN32) # not windows
add_custom_command(TARGET ${TARGET} PRE_BUILD add_custom_command(TARGET ${TARGET} PRE_BUILD
COMMAND ${CMAKE_COMMAND} -E make_directory "${dst}" COMMAND mkdir -p "${dst}"
COMMAND ${CMAKE_COMMAND} -E copy "${src_files}" "${dst}" COMMAND cp -r "${src}" "${dst}"
COMMENT "copying ${src} -> ${dst}") COMMENT "copying ${src} -> ${dst}")
endif(WIN32) endif(WIN32)
endforeach() endforeach()
......
...@@ -44,5 +44,5 @@ while ("${PADDLE_VERSION}" STREQUAL "") ...@@ -44,5 +44,5 @@ while ("${PADDLE_VERSION}" STREQUAL "")
endif() endif()
endwhile() endwhile()
add_definitions(-DPADDLE_VERSION=${PADDLE_VERSION}) add_definitions(-DPADDLE_VERSION="${PADDLE_VERSION}")
message(STATUS "Paddle version is ${PADDLE_VERSION}") message(STATUS "Paddle version is ${PADDLE_VERSION}")
../../v2/dev/contribute_to_paddle_cn.md
../../v2/dev/contribute_to_paddle_en.md
../../../howto/optimization/cpu_profiling_cn.md
../../../howto/optimization/host_memory_profiling_cn.md
../../../CONTRIBUTING.md ../../../CONTRIBUTING.md
\ No newline at end of file
...@@ -174,6 +174,7 @@ paddle.fluid.layers.mean ArgSpec(args=['x', 'name'], varargs=None, keywords=None ...@@ -174,6 +174,7 @@ paddle.fluid.layers.mean ArgSpec(args=['x', 'name'], varargs=None, keywords=None
paddle.fluid.layers.mul ArgSpec(args=['x', 'y', 'x_num_col_dims', 'y_num_col_dims', 'name'], varargs=None, keywords=None, defaults=(1, 1, None)) paddle.fluid.layers.mul ArgSpec(args=['x', 'y', 'x_num_col_dims', 'y_num_col_dims', 'name'], varargs=None, keywords=None, defaults=(1, 1, None))
paddle.fluid.layers.sigmoid_cross_entropy_with_logits ArgSpec(args=['x', 'label', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.sigmoid_cross_entropy_with_logits ArgSpec(args=['x', 'label', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.maxout ArgSpec(args=['x', 'groups', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.maxout ArgSpec(args=['x', 'groups', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.space_to_depth ArgSpec(args=['x', 'blocksize', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.affine_grid ArgSpec(args=['theta', 'out_shape', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.affine_grid ArgSpec(args=['theta', 'out_shape', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.sequence_reverse ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.sequence_reverse ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.affine_channel ArgSpec(args=['x', 'scale', 'bias', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(None, None, 'NCHW', None)) paddle.fluid.layers.affine_channel ArgSpec(args=['x', 'scale', 'bias', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(None, None, 'NCHW', None))
......
...@@ -37,8 +37,9 @@ struct TestBroadcastOpHandle { ...@@ -37,8 +37,9 @@ struct TestBroadcastOpHandle {
std::vector<Scope*> local_scopes_; std::vector<Scope*> local_scopes_;
std::vector<Scope*> param_scopes_; std::vector<Scope*> param_scopes_;
Scope g_scope_; Scope g_scope_;
std::unique_ptr<OpHandleBase> op_handle_; OpHandleBase* op_handle_;
std::vector<std::unique_ptr<VarHandleBase>> vars_; std::vector<VarHandleBase*> vars_;
std::vector<std::unique_ptr<ir::Node>> nodes_;
std::vector<p::Place> place_list_; std::vector<p::Place> place_list_;
bool use_gpu_; bool use_gpu_;
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
...@@ -90,6 +91,7 @@ struct TestBroadcastOpHandle { ...@@ -90,6 +91,7 @@ struct TestBroadcastOpHandle {
} }
void InitBroadcastOp(size_t input_scope_idx) { void InitBroadcastOp(size_t input_scope_idx) {
nodes_.clear();
for (size_t j = 0; j < place_list_.size(); ++j) { for (size_t j = 0; j < place_list_.size(); ++j) {
local_scopes_.push_back(&(g_scope_.NewScope())); local_scopes_.push_back(&(g_scope_.NewScope()));
Scope& local_scope = local_scopes_.back()->NewScope(); Scope& local_scope = local_scopes_.back()->NewScope();
...@@ -101,39 +103,39 @@ struct TestBroadcastOpHandle { ...@@ -101,39 +103,39 @@ struct TestBroadcastOpHandle {
} }
param_scopes_[input_scope_idx]->Var("input"); param_scopes_[input_scope_idx]->Var("input");
std::unique_ptr<ir::Node> n = nodes_.emplace_back(
ir::CreateNodeForTest("node0", ir::Node::Type::kOperation); ir::CreateNodeForTest("node0", ir::Node::Type::kOperation));
if (use_gpu_) { if (use_gpu_) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
op_handle_.reset(new BroadcastOpHandle(n.get(), local_scopes_, op_handle_ = new BroadcastOpHandle(nodes_.back().get(), local_scopes_,
place_list_, nccl_ctxs_.get())); place_list_, nccl_ctxs_.get());
#else #else
PADDLE_THROW("CUDA is not support."); PADDLE_THROW("CUDA is not support.");
#endif #endif
} else { } else {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
op_handle_.reset(new BroadcastOpHandle(n.get(), local_scopes_, op_handle_ = new BroadcastOpHandle(nodes_.back().get(), local_scopes_,
place_list_, nccl_ctxs_.get())); place_list_, nccl_ctxs_.get());
#else #else
op_handle_.reset( op_handle_ = new BroadcastOpHandle(nodes_.back().get(), local_scopes_,
new BroadcastOpHandle(n.get(), local_scopes_, place_list_)); place_list_);
#endif #endif
} }
std::unique_ptr<ir::Node> v = nodes_.emplace_back(
ir::CreateNodeForTest("node1", ir::Node::Type::kVariable); ir::CreateNodeForTest("node1", ir::Node::Type::kVariable));
auto* in_var_handle = new VarHandle(v.get(), 1, input_scope_idx, "input", auto* in_var_handle = new VarHandle(nodes_.back().get(), 1, input_scope_idx,
place_list_[input_scope_idx]); "input", place_list_[input_scope_idx]);
vars_.emplace_back(in_var_handle); vars_.emplace_back(in_var_handle);
op_handle_->AddInput(in_var_handle); op_handle_->AddInput(in_var_handle);
// add dummy var // add dummy var
std::unique_ptr<ir::Node> v2 = nodes_.emplace_back(
ir::CreateNodeForTest("node2", ir::Node::Type::kVariable); ir::CreateNodeForTest("node2", ir::Node::Type::kVariable));
vars_.emplace_back(new DummyVarHandle(v2.get())); vars_.emplace_back(new DummyVarHandle(nodes_.back().get()));
DummyVarHandle* dummy_var_handle = DummyVarHandle* dummy_var_handle =
static_cast<DummyVarHandle*>(vars_.back().get()); static_cast<DummyVarHandle*>(vars_.back());
dummy_var_handle->ClearGeneratedOp(); dummy_var_handle->ClearGeneratedOp();
op_handle_->AddInput(dummy_var_handle); op_handle_->AddInput(dummy_var_handle);
...@@ -141,20 +143,20 @@ struct TestBroadcastOpHandle { ...@@ -141,20 +143,20 @@ struct TestBroadcastOpHandle {
if (!use_gpu_) { if (!use_gpu_) {
op_handle_->SetDeviceContext(place_list_[j], ctxs_[j].get()); op_handle_->SetDeviceContext(place_list_[j], ctxs_[j].get());
} }
std::unique_ptr<ir::Node> v3 = nodes_.emplace_back(
ir::CreateNodeForTest("node3", ir::Node::Type::kVariable); ir::CreateNodeForTest("node3", ir::Node::Type::kVariable));
VarHandle* out_var_handle = VarHandle* out_var_handle =
new VarHandle(v3.get(), 2, j, "out", place_list_[j]); new VarHandle(nodes_.back().get(), 2, j, "out", place_list_[j]);
vars_.emplace_back(out_var_handle); vars_.emplace_back(out_var_handle);
op_handle_->AddOutput(out_var_handle); op_handle_->AddOutput(out_var_handle);
} }
// add dummy var // add dummy var
std::unique_ptr<ir::Node> v4 = nodes_.emplace_back(
ir::CreateNodeForTest("node4", ir::Node::Type::kVariable); ir::CreateNodeForTest("node4", ir::Node::Type::kVariable));
vars_.emplace_back(new DummyVarHandle(v4.get())); vars_.emplace_back(new DummyVarHandle(nodes_.back().get()));
DummyVarHandle* out_dummy_var_handle = DummyVarHandle* out_dummy_var_handle =
static_cast<DummyVarHandle*>(vars_.back().get()); static_cast<DummyVarHandle*>(vars_.back());
out_dummy_var_handle->ClearGeneratedOp(); out_dummy_var_handle->ClearGeneratedOp();
op_handle_->AddOutput(out_dummy_var_handle); op_handle_->AddOutput(out_dummy_var_handle);
} }
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <vector> #include <vector>
#include "paddle/fluid/framework/details/fetch_op_handle.h" #include "paddle/fluid/framework/details/fetch_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h" #include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -32,13 +33,11 @@ FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor( ...@@ -32,13 +33,11 @@ FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor(
pool_(strategy.num_threads_ + pool_(strategy.num_threads_ +
1), // add one more thread for generate op_deps 1), // add one more thread for generate op_deps
fetch_ctxs_(places) { fetch_ctxs_(places) {
auto &ops = graph_->Get<details::GraphOps>("ops"); for (auto &op : ir::FilterByNodeWrapper<OpHandleBase>(*graph_)) {
for (auto &op : ops) {
int dep = static_cast<int>(op->NotReadyInputSize()); int dep = static_cast<int>(op->NotReadyInputSize());
op_deps_.emplace(op.get(), dep); op_deps_.emplace(op, dep);
if (dep == 0) { if (dep == 0) {
bootstrap_ops_.emplace_back(op.get()); bootstrap_ops_.emplace_back(op);
} }
} }
...@@ -54,13 +53,13 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run( ...@@ -54,13 +53,13 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
paddle::framework::FeedFetchList fetches; paddle::framework::FeedFetchList fetches;
fetches.resize(fetch_tensors.size()); fetches.resize(fetch_tensors.size());
std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars; std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
std::vector<std::unique_ptr<FetchOpHandle>> fetch_ops; std::vector<FetchOpHandle *> fetch_ops;
for (auto &fetch_var_name : fetch_tensors) { for (auto &fetch_var_name : fetch_tensors) {
for (auto &var_map : graph_->Get<details::GraphVars>("vars")) { for (auto &var_map : graph_->Get<details::GraphVars>("vars")) {
auto it = var_map.find(fetch_var_name); auto it = var_map.find(fetch_var_name);
if (it != var_map.end()) { if (it != var_map.end()) {
fetched_vars[fetch_var_name].push_back(it->second.rbegin()->get()); fetched_vars[fetch_var_name].push_back(*it->second.rbegin());
} }
} }
} }
...@@ -110,7 +109,10 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run( ...@@ -110,7 +109,10 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
complete_q->Pop(); complete_q->Pop();
} }
} }
exception_.ReThrow(); if (exception_.IsCaught()) {
ClearFetchOp(graph_.get(), &fetch_ops);
exception_.ReThrow();
}
} }
num_complete += num_comp; num_complete += num_comp;
} }
......
...@@ -28,11 +28,7 @@ FetchOpHandle::FetchOpHandle(ir::Node *node, FeedFetchList *data, size_t offset, ...@@ -28,11 +28,7 @@ FetchOpHandle::FetchOpHandle(ir::Node *node, FeedFetchList *data, size_t offset,
offset_(offset), offset_(offset),
local_scopes_(local_scopes) {} local_scopes_(local_scopes) {}
FetchOpHandle::~FetchOpHandle() { FetchOpHandle::~FetchOpHandle() {}
for (auto *input_var : inputs_) {
input_var->RemoveOutput(this, this->Node());
}
}
void FetchOpHandle::RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) { void FetchOpHandle::RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) {
PADDLE_THROW("Nobody should wait FetchOp. Unexpceted Error"); PADDLE_THROW("Nobody should wait FetchOp. Unexpceted Error");
......
...@@ -22,8 +22,10 @@ namespace details { ...@@ -22,8 +22,10 @@ namespace details {
struct TestFusedBroadcastOpHandle : TestBroadcastOpHandle { struct TestFusedBroadcastOpHandle : TestBroadcastOpHandle {
std::vector<std::string> out_varnames_; std::vector<std::string> out_varnames_;
std::vector<std::unique_ptr<ir::Node>> nodes_;
void InitFusedBroadcastOp(std::vector<size_t> input_scope_idxes) { void InitFusedBroadcastOp(std::vector<size_t> input_scope_idxes) {
nodes_.clear();
// initialize scope and var // initialize scope and var
for (size_t i = 0; i < place_list_.size(); ++i) { for (size_t i = 0; i < place_list_.size(); ++i) {
local_scopes_.push_back(&(g_scope_.NewScope())); local_scopes_.push_back(&(g_scope_.NewScope()));
...@@ -39,41 +41,41 @@ struct TestFusedBroadcastOpHandle : TestBroadcastOpHandle { ...@@ -39,41 +41,41 @@ struct TestFusedBroadcastOpHandle : TestBroadcastOpHandle {
} }
// create op handle node // create op handle node
std::unique_ptr<ir::Node> n = nodes_.emplace_back(
ir::CreateNodeForTest("fused_broadcast", ir::Node::Type::kOperation); ir::CreateNodeForTest("fused_broadcast", ir::Node::Type::kOperation));
if (use_gpu_) { if (use_gpu_) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
op_handle_.reset(new FusedBroadcastOpHandle( op_handle_ = new FusedBroadcastOpHandle(
n.get(), local_scopes_, place_list_, nccl_ctxs_.get())); nodes_.back().get(), local_scopes_, place_list_, nccl_ctxs_.get());
#else #else
PADDLE_THROW("CUDA is not supported."); PADDLE_THROW("CUDA is not supported.");
#endif #endif
} else { } else {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
op_handle_.reset(new FusedBroadcastOpHandle( op_handle_ = new FusedBroadcastOpHandle(
n.get(), local_scopes_, place_list_, nccl_ctxs_.get())); nodes_.back().get(), local_scopes_, place_list_, nccl_ctxs_.get());
#else #else
op_handle_.reset( op_handle_ = new FusedBroadcastOpHandle(nodes_.back().get(),
new FusedBroadcastOpHandle(n.get(), local_scopes_, place_list_)); local_scopes_, place_list_);
#endif #endif
} }
for (size_t i = 0; i < input_scope_idxes.size(); ++i) { for (size_t i = 0; i < input_scope_idxes.size(); ++i) {
// add input var handle // add input var handle
std::unique_ptr<ir::Node> in_node = nodes_.emplace_back(
ir::CreateNodeForTest("in_node" + i, ir::Node::Type::kVariable); ir::CreateNodeForTest("in_node" + i, ir::Node::Type::kVariable));
VarHandle* in_var_handle = VarHandle* in_var_handle =
new VarHandle(in_node.get(), 1, input_scope_idxes[i], "in_var" + i, new VarHandle(nodes_.back().get(), 1, input_scope_idxes[i],
place_list_[input_scope_idxes[i]]); "in_var" + i, place_list_[input_scope_idxes[i]]);
vars_.emplace_back(in_var_handle); vars_.emplace_back(in_var_handle);
op_handle_->AddInput(in_var_handle); op_handle_->AddInput(in_var_handle);
// add output var handle // add output var handle
for (size_t j = 0; j < place_list_.size(); ++j) { for (size_t j = 0; j < place_list_.size(); ++j) {
std::unique_ptr<ir::Node> out_node = nodes_.emplace_back(
ir::CreateNodeForTest("out_node" + i, ir::Node::Type::kVariable); ir::CreateNodeForTest("out_node" + i, ir::Node::Type::kVariable));
VarHandle* out_var_handle = VarHandle* out_var_handle = new VarHandle(
new VarHandle(out_node.get(), 2, j, "out_var" + i, place_list_[j]); nodes_.back().get(), 2, j, "out_var" + i, place_list_[j]);
vars_.emplace_back(out_var_handle); vars_.emplace_back(out_var_handle);
op_handle_->AddOutput(out_var_handle); op_handle_->AddOutput(out_var_handle);
} }
......
...@@ -31,9 +31,10 @@ struct TestGatherOpHandle { ...@@ -31,9 +31,10 @@ struct TestGatherOpHandle {
std::vector<Scope*> local_scopes_; std::vector<Scope*> local_scopes_;
std::vector<Scope*> param_scopes_; std::vector<Scope*> param_scopes_;
Scope g_scope_; Scope g_scope_;
std::unique_ptr<OpHandleBase> op_handle_; OpHandleBase* op_handle_;
std::vector<std::unique_ptr<VarHandleBase>> vars_; std::vector<VarHandleBase*> vars_;
std::vector<p::Place> gpu_list_; std::vector<p::Place> gpu_list_;
std::vector<std::unique_ptr<ir::Node>> nodes_;
void WaitAll() { void WaitAll() {
for (size_t j = 0; j < ctxs_.size(); ++j) { for (size_t j = 0; j < ctxs_.size(); ++j) {
...@@ -70,7 +71,7 @@ struct TestGatherOpHandle { ...@@ -70,7 +71,7 @@ struct TestGatherOpHandle {
} }
void InitGatherOp(size_t input_scope_idx) { void InitGatherOp(size_t input_scope_idx) {
std::vector<std::unique_ptr<ir::Node>> nodes; nodes_.clear();
for (size_t j = 0; j < gpu_list_.size(); ++j) { for (size_t j = 0; j < gpu_list_.size(); ++j) {
local_scopes_.push_back(&(g_scope_.NewScope())); local_scopes_.push_back(&(g_scope_.NewScope()));
Scope& local_scope = local_scopes_.back()->NewScope(); Scope& local_scope = local_scopes_.back()->NewScope();
...@@ -82,44 +83,45 @@ struct TestGatherOpHandle { ...@@ -82,44 +83,45 @@ struct TestGatherOpHandle {
} }
param_scopes_[input_scope_idx]->Var("out"); param_scopes_[input_scope_idx]->Var("out");
nodes.emplace_back( nodes_.emplace_back(
ir::CreateNodeForTest("node", ir::Node::Type::kOperation).release()); ir::CreateNodeForTest("node", ir::Node::Type::kOperation).release());
op_handle_.reset( op_handle_ =
new GatherOpHandle(nodes.back().get(), local_scopes_, gpu_list_)); new GatherOpHandle(nodes_.back().get(), local_scopes_, gpu_list_);
// add input // add input
for (size_t j = 0; j < gpu_list_.size(); ++j) { for (size_t j = 0; j < gpu_list_.size(); ++j) {
op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get()); op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
nodes.emplace_back( nodes_.emplace_back(
ir::CreateNodeForTest("node1", ir::Node::Type::kVariable).release()); ir::CreateNodeForTest("node1", ir::Node::Type::kVariable).release());
auto* in_var_handle = auto* in_var_handle =
new VarHandle(nodes.back().get(), 1, j, "input", gpu_list_[j]); new VarHandle(nodes_.back().get(), 1, j, "input", gpu_list_[j]);
vars_.emplace_back(in_var_handle); vars_.emplace_back(in_var_handle);
op_handle_->AddInput(in_var_handle); op_handle_->AddInput(in_var_handle);
} }
// add dummy var // add dummy var
nodes.emplace_back( nodes_.emplace_back(
ir::CreateNodeForTest("node2", ir::Node::Type::kVariable).release()); ir::CreateNodeForTest("node2", ir::Node::Type::kVariable).release());
vars_.emplace_back(new DummyVarHandle(nodes.back().get())); vars_.emplace_back(new DummyVarHandle(nodes_.back().get()));
DummyVarHandle* in_dummy_var_handle = DummyVarHandle* in_dummy_var_handle =
static_cast<DummyVarHandle*>(vars_.back().get()); static_cast<DummyVarHandle*>(vars_.back());
in_dummy_var_handle->ClearGeneratedOp(); in_dummy_var_handle->ClearGeneratedOp();
op_handle_->AddInput(in_dummy_var_handle); op_handle_->AddInput(in_dummy_var_handle);
// add output // add output
nodes.emplace_back( nodes_.emplace_back(
ir::CreateNodeForTest("node3", ir::Node::Type::kVariable).release()); ir::CreateNodeForTest("node3", ir::Node::Type::kVariable).release());
auto* out_var_handle = new VarHandle(nodes.back().get(), 2, input_scope_idx, auto* out_var_handle =
"out", gpu_list_[input_scope_idx]); new VarHandle(nodes_.back().get(), 2, input_scope_idx, "out",
gpu_list_[input_scope_idx]);
vars_.emplace_back(out_var_handle); vars_.emplace_back(out_var_handle);
op_handle_->AddOutput(out_var_handle); op_handle_->AddOutput(out_var_handle);
// add dummy var // add dummy var
nodes.emplace_back( nodes_.emplace_back(
ir::CreateNodeForTest("node4", ir::Node::Type::kVariable).release()); ir::CreateNodeForTest("node4", ir::Node::Type::kVariable).release());
vars_.emplace_back(new DummyVarHandle(nodes.back().get())); vars_.emplace_back(new DummyVarHandle(nodes_.back().get()));
DummyVarHandle* dummy_var_handle = DummyVarHandle* dummy_var_handle =
static_cast<DummyVarHandle*>(vars_.back().get()); static_cast<DummyVarHandle*>(vars_.back());
op_handle_->AddOutput(dummy_var_handle); op_handle_->AddOutput(dummy_var_handle);
} }
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include "paddle/fluid/framework/details/computation_op_handle.h" #include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h" #include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/op_graph_view.h" #include "paddle/fluid/framework/details/op_graph_view.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -35,10 +36,10 @@ static bool IsLockAndRecordEventFreeComputationOpHandle( ...@@ -35,10 +36,10 @@ static bool IsLockAndRecordEventFreeComputationOpHandle(
std::unique_ptr<ir::Graph> ModifyOpLockAndRecordEventPass::ApplyImpl( std::unique_ptr<ir::Graph> ModifyOpLockAndRecordEventPass::ApplyImpl(
std::unique_ptr<ir::Graph> ir_graph) const { std::unique_ptr<ir::Graph> ir_graph) const {
auto &all_ops = ir_graph->Get<GraphOps>(kGraphOps); auto all_ops = ir::FilterByNodeWrapper<OpHandleBase>(*ir_graph);
OpGraphView graph_view(all_ops); OpGraphView graph_view(all_ops);
for (auto &op : all_ops) { for (auto &op : all_ops) {
auto *compute_op = dynamic_cast<ComputationOpHandle *>(op.get()); auto *compute_op = dynamic_cast<ComputationOpHandle *>(op);
if (compute_op == nullptr) continue; if (compute_op == nullptr) continue;
bool is_lock_and_record_event_free = bool is_lock_and_record_event_free =
IsLockAndRecordEventFreeComputationOpHandle(compute_op, graph_view); IsLockAndRecordEventFreeComputationOpHandle(compute_op, graph_view);
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/details/multi_devices_graph_check_pass.h" #include "paddle/fluid/framework/details/multi_devices_graph_check_pass.h"
#include <string> #include <string>
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -36,20 +37,20 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const { ...@@ -36,20 +37,20 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
for (auto &var_map : graph->Get<GraphVars>(kGraphVars)) { for (auto &var_map : graph->Get<GraphVars>(kGraphVars)) {
for (auto &name_pair : var_map) { for (auto &name_pair : var_map) {
for (auto &version_pair : name_pair.second) { for (auto &version_pair : name_pair.second) {
insert_pending_var(version_pair.get()); insert_pending_var(version_pair);
} }
} }
} }
for (auto &var : graph->Get<GraphDepVars>(kGraphDepVars)) { for (auto &var : graph->Get<GraphDepVars>(kGraphDepVars)) {
insert_pending_var(var.get()); insert_pending_var(var);
} }
for (auto &op : graph->Get<GraphOps>(kGraphOps)) { for (OpHandleBase *op : ir::FilterByNodeWrapper<OpHandleBase>(*graph)) {
if (op->Inputs().empty()) { if (op->Inputs().empty()) {
ready_ops.insert(op.get()); ready_ops.insert(op);
} else { } else {
pending_ops.insert({op.get(), op.get()->NoDupInputSize()}); pending_ops.insert({op, op->NoDupInputSize()});
} }
} }
...@@ -89,6 +90,4 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const { ...@@ -89,6 +90,4 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
REGISTER_PASS(multi_devices_check_pass, REGISTER_PASS(multi_devices_check_pass,
paddle::framework::details::SSAGraghBuilderWithChecker) paddle::framework::details::SSAGraghBuilderWithChecker)
.RequireGraphAttr(paddle::framework::details::kGraphVars) .RequireGraphAttr(paddle::framework::details::kGraphVars)
.RequireGraphAttr(paddle::framework::details::kGraphDepVars) .RequireGraphAttr(paddle::framework::details::kGraphDepVars);
.RequireGraphAttr(paddle::framework::details::kGraphOps)
.RequireGraphAttr(paddle::framework::details::kShardedVarDevice);
...@@ -34,7 +34,14 @@ ...@@ -34,7 +34,14 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
namespace { namespace {
// TODO(panyx0718): Clean this up as well.
// all operators. NOTE that even we use a vector here, the operators is
// unordered.
typedef std::vector<OpHandleBase *> GraphOps;
const char kGraphOps[] = "ops";
void PolishGraphToSupportDataHazards(ir::Graph *graph) { void PolishGraphToSupportDataHazards(ir::Graph *graph) {
for (auto &var_map : graph->Get<GraphVars>(kGraphVars)) { for (auto &var_map : graph->Get<GraphVars>(kGraphVars)) {
for (auto &name_pair : var_map) { for (auto &name_pair : var_map) {
...@@ -92,7 +99,7 @@ VarHandle *CreateOrGetLatestVarHandle(ir::Graph *graph, ir::Node *node, ...@@ -92,7 +99,7 @@ VarHandle *CreateOrGetLatestVarHandle(ir::Graph *graph, ir::Node *node,
} }
var_holder.emplace_back(var); var_holder.emplace_back(var);
} else { } else {
var = var_holder.rbegin()->get(); var = *var_holder.rbegin();
} }
return var; return var;
} }
...@@ -154,7 +161,7 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(ir::Graph *result, ...@@ -154,7 +161,7 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(ir::Graph *result,
ir::Node *node, ir::Node *node,
size_t place_id) const { size_t place_id) const {
auto p = places_[place_id]; auto p = places_[place_id];
auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get(); auto *op_handle = result->Get<GraphOps>(kGraphOps).back();
op_handle->SetDeviceContext(p, op_handle->SetDeviceContext(p,
platform::DeviceContextPool::Instance().Get(p)); platform::DeviceContextPool::Instance().Get(p));
...@@ -303,7 +310,6 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl( ...@@ -303,7 +310,6 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
result.Set(kGraphVars, new GraphVars(places_.size())); result.Set(kGraphVars, new GraphVars(places_.size()));
result.Set(kGraphDepVars, new GraphDepVars); result.Set(kGraphDepVars, new GraphDepVars);
result.Set(kGraphOps, new GraphOps); result.Set(kGraphOps, new GraphOps);
result.Set(kShardedVarDevice, new ShardedVarDevice);
// find send/recv vars so that we can place the distributed training // find send/recv vars so that we can place the distributed training
// related op in the place 0 // related op in the place 0
...@@ -317,11 +323,13 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl( ...@@ -317,11 +323,13 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
bool is_forwarding = true; bool is_forwarding = true;
bool is_dist_train = false; bool is_dist_train = false;
std::unordered_map<std::string, int> sharded_var_device;
for (ir::Node *node : sorted_ops) { for (ir::Node *node : sorted_ops) {
if (boost::get<int>( if (boost::get<int>(
node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) == node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
static_cast<int>(OpRole::kRPC)) { static_cast<int>(OpRole::kRPC)) {
int op_dev_id = CreateRPCOp(&result, node); int op_dev_id = CreateRPCOp(&result, node, &sharded_var_device);
PADDLE_ENFORCE(op_dev_id != -1, PADDLE_ENFORCE(op_dev_id != -1,
"Can not schedule the RPC operator to the right place."); "Can not schedule the RPC operator to the right place.");
if (node->Op()->Type() == "recv") { if (node->Op()->Type() == "recv") {
...@@ -337,7 +345,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl( ...@@ -337,7 +345,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
} else if (boost::get<int>(node->Op()->GetAttr( } else if (boost::get<int>(node->Op()->GetAttr(
OpProtoAndCheckerMaker::OpRoleAttrName())) == OpProtoAndCheckerMaker::OpRoleAttrName())) ==
static_cast<int>(OpRole::kDist)) { static_cast<int>(OpRole::kDist)) {
int op_dev_id = CreateDistTrainOp(&result, node); int op_dev_id = CreateDistTrainOp(&result, node, &sharded_var_device);
if (node->Op()->Type() == "concat") { if (node->Op()->Type() == "concat") {
auto origin_param_name = node->Op()->OutputArgumentNames()[0]; auto origin_param_name = node->Op()->OutputArgumentNames()[0];
bcast_var_name_set[op_dev_id].emplace(origin_param_name); bcast_var_name_set[op_dev_id].emplace(origin_param_name);
...@@ -356,12 +364,11 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl( ...@@ -356,12 +364,11 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
// the block. // the block.
is_forwarding = false; is_forwarding = false;
} else { } else {
int op_dev_id = GetOpDeviceID(result, node); int op_dev_id = GetOpDeviceID(result, node, sharded_var_device);
if (op_dev_id != -1) { // This op only runs on one specific device. if (op_dev_id != -1) { // This op only runs on one specific device.
CreateComputationalOp(&result, node, op_dev_id); CreateComputationalOp(&result, node, op_dev_id);
for (ir::Node *n : node->outputs) { for (ir::Node *n : node->outputs) {
graph->Get<ShardedVarDevice>(kShardedVarDevice) sharded_var_device.emplace(n->Name(), op_dev_id);
.emplace(n->Name(), op_dev_id);
} }
} else { } else {
// This op runs on all devices, and its output may have parameter's // This op runs on all devices, and its output may have parameter's
...@@ -398,8 +405,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl( ...@@ -398,8 +405,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
case BuildStrategy::ReduceStrategy::kReduce: case BuildStrategy::ReduceStrategy::kReduce:
cur_device_id = GetAppropriateDeviceID({g_name}); cur_device_id = GetAppropriateDeviceID({g_name});
CreateReduceOp(&result, g_name, cur_device_id); CreateReduceOp(&result, g_name, cur_device_id);
graph->Get<ShardedVarDevice>(kShardedVarDevice) sharded_var_device.emplace(g_name, cur_device_id);
.emplace(g_name, cur_device_id);
if (!is_dist_train) { if (!is_dist_train) {
bcast_var_name_set[cur_device_id].emplace(p_name); bcast_var_name_set[cur_device_id].emplace(p_name);
} }
...@@ -458,7 +464,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl( ...@@ -458,7 +464,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
* Only variables should be the leaves of graph. * Only variables should be the leaves of graph.
*/ */
AddOutputToLeafOps(&result); AddOutputToLeafOps(&result);
PADDLE_ENFORCE(!ir::HasCircle(result)); result.Erase<GraphOps>(kGraphOps);
return graph; return graph;
} }
...@@ -498,7 +504,7 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result, ...@@ -498,7 +504,7 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result,
result->Get<GraphOps>(kGraphOps).emplace_back(op_handle); result->Get<GraphOps>(kGraphOps).emplace_back(op_handle);
auto *in = auto *in =
result->Get<GraphVars>(kGraphVars).at(src_dev_id).at(p_name).back().get(); result->Get<GraphVars>(kGraphVars).at(src_dev_id).at(p_name).back();
op_handle->AddInput(in); op_handle->AddInput(in);
for (size_t i = 0; i < places_.size(); ++i) { for (size_t i = 0; i < places_.size(); ++i) {
...@@ -535,7 +541,7 @@ void MultiDevSSAGraphBuilder::CreateFusedBroadcastOp( ...@@ -535,7 +541,7 @@ void MultiDevSSAGraphBuilder::CreateFusedBroadcastOp(
for (size_t dev_id = 0; dev_id < bcast_varnames.size(); ++dev_id) { for (size_t dev_id = 0; dev_id < bcast_varnames.size(); ++dev_id) {
for (auto &p_name : bcast_varnames[dev_id]) { for (auto &p_name : bcast_varnames[dev_id]) {
auto *in = auto *in =
result->Get<GraphVars>(kGraphVars).at(dev_id).at(p_name).back().get(); result->Get<GraphVars>(kGraphVars).at(dev_id).at(p_name).back();
op_handle->AddInput(in); op_handle->AddInput(in);
for (size_t out_dev_id = 0; out_dev_id < places_.size(); ++out_dev_id) { for (size_t out_dev_id = 0; out_dev_id < places_.size(); ++out_dev_id) {
auto &p = places_[out_dev_id]; auto &p = places_[out_dev_id];
...@@ -571,7 +577,7 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result, ...@@ -571,7 +577,7 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result,
result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation), result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
local_scopes_, places_)); local_scopes_, places_));
#endif #endif
auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get(); auto *op_handle = result->Get<GraphOps>(kGraphOps).back();
for (size_t i = 0; i < places_.size(); ++i) { for (size_t i = 0; i < places_.size(); ++i) {
auto &p = places_[i]; auto &p = places_[i];
...@@ -579,7 +585,7 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result, ...@@ -579,7 +585,7 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result,
auto &vars = result->Get<GraphVars>(kGraphVars)[i][og]; auto &vars = result->Get<GraphVars>(kGraphVars)[i][og];
PADDLE_ENFORCE(!vars.empty()); PADDLE_ENFORCE(!vars.empty());
auto &prev_grad = vars.back(); auto &prev_grad = vars.back();
op_handle->AddInput(prev_grad.get()); op_handle->AddInput(prev_grad);
auto var = auto var =
new VarHandle(result->CreateEmptyNode(og, ir::Node::Type::kVariable), new VarHandle(result->CreateEmptyNode(og, ir::Node::Type::kVariable),
...@@ -600,14 +606,14 @@ void MultiDevSSAGraphBuilder::InsertDataBalanceOp( ...@@ -600,14 +606,14 @@ void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation), result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation),
local_scopes_, places_)); local_scopes_, places_));
#endif #endif
auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get(); auto *op_handle = result->Get<GraphOps>(kGraphOps).back();
for (size_t i = 0; i < places_.size(); ++i) { for (size_t i = 0; i < places_.size(); ++i) {
auto &p = places_[i]; auto &p = places_[i];
SetCommunicationContext(op_handle, p); SetCommunicationContext(op_handle, p);
for (const std::string &d_name : datas) { for (const std::string &d_name : datas) {
auto &vars = result->Get<GraphVars>(kGraphVars)[i][d_name]; auto &vars = result->Get<GraphVars>(kGraphVars)[i][d_name];
PADDLE_ENFORCE(!vars.empty()); PADDLE_ENFORCE(!vars.empty());
op_handle->AddInput(vars.back().get()); op_handle->AddInput(vars.back());
auto var = new VarHandle( auto var = new VarHandle(
result->CreateEmptyNode(d_name, ir::Node::Type::kVariable), result->CreateEmptyNode(d_name, ir::Node::Type::kVariable),
vars.size(), i, d_name, p); vars.size(), i, d_name, p);
...@@ -617,8 +623,9 @@ void MultiDevSSAGraphBuilder::InsertDataBalanceOp( ...@@ -617,8 +623,9 @@ void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
} }
} }
int MultiDevSSAGraphBuilder::GetOpDeviceID(const ir::Graph &graph, int MultiDevSSAGraphBuilder::GetOpDeviceID(
ir::Node *node) const { const ir::Graph &graph, ir::Node *node,
const std::unordered_map<std::string, int> &sharded_var_device) const {
if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) { if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) {
return -1; return -1;
} }
...@@ -631,15 +638,15 @@ int MultiDevSSAGraphBuilder::GetOpDeviceID(const ir::Graph &graph, ...@@ -631,15 +638,15 @@ int MultiDevSSAGraphBuilder::GetOpDeviceID(const ir::Graph &graph,
node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName())); node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
PADDLE_ENFORCE_EQ(param_grad.size(), 2U); PADDLE_ENFORCE_EQ(param_grad.size(), 2U);
int dev_id = GetVarDeviceID(graph, param_grad[1]); int dev_id = GetVarDeviceID(graph, param_grad[1], sharded_var_device);
PADDLE_ENFORCE_NE(dev_id, -1, "dev_id should not be -1.[%s, %s, %s]", PADDLE_ENFORCE_NE(dev_id, -1, "dev_id should not be -1.[%s, %s, %s]",
node->Op()->Type(), param_grad[0], param_grad[1]); node->Op()->Type(), param_grad[0], param_grad[1]);
return dev_id; return dev_id;
} }
int MultiDevSSAGraphBuilder::GetVarDeviceID(const ir::Graph &graph, int MultiDevSSAGraphBuilder::GetVarDeviceID(
const std::string &varname) const { const ir::Graph &graph, const std::string &varname,
auto &sharded_var_device = graph.Get<ShardedVarDevice>(kShardedVarDevice); const std::unordered_map<std::string, int> &sharded_var_device) const {
auto got = sharded_var_device.find(varname); auto got = sharded_var_device.find(varname);
return got == sharded_var_device.end() ? -1 : got->second; return got == sharded_var_device.end() ? -1 : got->second;
} }
...@@ -690,7 +697,7 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result, ...@@ -690,7 +697,7 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
result->CreateEmptyNode("reduce", ir::Node::Type::kOperation), result->CreateEmptyNode("reduce", ir::Node::Type::kOperation),
local_scopes_, places_)); local_scopes_, places_));
#endif #endif
auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get(); auto *op_handle = result->Get<GraphOps>(kGraphOps).back();
for (size_t i = 0; i < places_.size(); ++i) { for (size_t i = 0; i < places_.size(); ++i) {
auto &p = places_[i]; auto &p = places_[i];
...@@ -698,7 +705,7 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result, ...@@ -698,7 +705,7 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
auto &vars = result->Get<GraphVars>(kGraphVars)[i][og]; auto &vars = result->Get<GraphVars>(kGraphVars)[i][og];
PADDLE_ENFORCE(!vars.empty()); PADDLE_ENFORCE(!vars.empty());
auto &prev_grad = vars.back(); auto &prev_grad = vars.back();
op_handle->AddInput(prev_grad.get()); op_handle->AddInput(prev_grad);
} }
auto &vars = result->Get<GraphVars>(kGraphVars)[dst_dev_id][og]; auto &vars = result->Get<GraphVars>(kGraphVars)[dst_dev_id][og];
auto var = auto var =
...@@ -709,8 +716,9 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result, ...@@ -709,8 +716,9 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
return var; return var;
} }
int MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result, int MultiDevSSAGraphBuilder::CreateDistTrainOp(
ir::Node *node) const { ir::Graph *result, ir::Node *node,
std::unordered_map<std::string, int> *sharded_var_device) const {
int op_dev_id = -1; int op_dev_id = -1;
std::vector<std::string> input_var_names; std::vector<std::string> input_var_names;
std::vector<std::string> output_var_names; std::vector<std::string> output_var_names;
...@@ -725,23 +733,22 @@ int MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result, ...@@ -725,23 +733,22 @@ int MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
node->Op()->Type() == "split_selected_rows" || node->Op()->Type() == "split_selected_rows" ||
node->Op()->Type() == "split_ids") { node->Op()->Type() == "split_ids") {
// TODO(paddle-dev): getting the first var is not safe. // TODO(paddle-dev): getting the first var is not safe.
op_dev_id = GetVarDeviceID(*result, input_var_names[0]); op_dev_id =
GetVarDeviceID(*result, input_var_names[0], *sharded_var_device);
if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) { if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
op_dev_id = GetAppropriateDeviceID(input_var_names); op_dev_id = GetAppropriateDeviceID(input_var_names);
for (auto &varname : input_var_names) { for (auto &varname : input_var_names) {
result->Get<ShardedVarDevice>(kShardedVarDevice) sharded_var_device->emplace(varname, op_dev_id);
.emplace(varname, op_dev_id);
} }
} }
for (auto &varname : output_var_names) { for (auto &varname : output_var_names) {
result->Get<ShardedVarDevice>(kShardedVarDevice) sharded_var_device->emplace(varname, op_dev_id);
.emplace(varname, op_dev_id);
} }
} else if (node->Op()->Type() == "concat") { } else if (node->Op()->Type() == "concat") {
op_dev_id = GetVarDeviceID(*result, input_var_names[0]); op_dev_id =
GetVarDeviceID(*result, input_var_names[0], *sharded_var_device);
for (auto &varname : output_var_names) { for (auto &varname : output_var_names) {
result->Get<ShardedVarDevice>(kShardedVarDevice) sharded_var_device->emplace(varname, op_dev_id);
.emplace(varname, op_dev_id);
} }
} else { } else {
LOG(ERROR) << "got unexpected dist op: " << node->Op()->Type(); LOG(ERROR) << "got unexpected dist op: " << node->Op()->Type();
...@@ -759,14 +766,14 @@ int MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result, ...@@ -759,14 +766,14 @@ int MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
} }
void SetOpInputsAllPlaces(ir::Graph *result, ir::Node *node, int num_places) { void SetOpInputsAllPlaces(ir::Graph *result, ir::Node *node, int num_places) {
auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get(); auto *op_handle = result->Get<GraphOps>(kGraphOps).back();
for (ir::Node *input : node->inputs) { for (ir::Node *input : node->inputs) {
VarHandle *var = nullptr; VarHandle *var = nullptr;
for (int place_offset = 0; place_offset < num_places; ++place_offset) { for (int place_offset = 0; place_offset < num_places; ++place_offset) {
auto &var_holders = result->Get<GraphVars>(kGraphVars)[place_offset]; auto &var_holders = result->Get<GraphVars>(kGraphVars)[place_offset];
auto &var_holder = var_holders[input->Name()]; auto &var_holder = var_holders[input->Name()];
if (!var_holder.empty()) { if (!var_holder.empty()) {
var = var_holder.rbegin()->get(); var = *var_holder.rbegin();
op_handle->AddInput(var); op_handle->AddInput(var);
} }
} }
...@@ -774,12 +781,14 @@ void SetOpInputsAllPlaces(ir::Graph *result, ir::Node *node, int num_places) { ...@@ -774,12 +781,14 @@ void SetOpInputsAllPlaces(ir::Graph *result, ir::Node *node, int num_places) {
} }
// Create RPC related op handles that connects its in ops and out ops. // Create RPC related op handles that connects its in ops and out ops.
int MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result, int MultiDevSSAGraphBuilder::CreateRPCOp(
ir::Node *node) const { ir::Graph *result, ir::Node *node,
std::unordered_map<std::string, int> *sharded_var_device) const {
int op_dev_id = -1; int op_dev_id = -1;
if (node->Op()->Type() == "send") { if (node->Op()->Type() == "send") {
// TODO(paddle-dev): getting the first var is not safe. // TODO(paddle-dev): getting the first var is not safe.
op_dev_id = GetVarDeviceID(*result, node->inputs[0]->Name()); op_dev_id =
GetVarDeviceID(*result, node->inputs[0]->Name(), *sharded_var_device);
PADDLE_ENFORCE(!ir::IsControlDepVar(*node->inputs[0]), PADDLE_ENFORCE(!ir::IsControlDepVar(*node->inputs[0]),
"This hack no longer holds, please fix."); "This hack no longer holds, please fix.");
// the variable name which contains .block means it was splited by // the variable name which contains .block means it was splited by
...@@ -797,11 +806,9 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ...@@ -797,11 +806,9 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
VLOG(10) << "send grad " << input_var_names[0] << " origin " VLOG(10) << "send grad " << input_var_names[0] << " origin "
<< send_param_grad[1] << " place: " << op_dev_id; << send_param_grad[1] << " place: " << op_dev_id;
for (auto &varname : input_var_names) { for (auto &varname : input_var_names) {
result->Get<ShardedVarDevice>(kShardedVarDevice) sharded_var_device->emplace(varname, op_dev_id);
.emplace(varname, op_dev_id);
} }
result->Get<ShardedVarDevice>(kShardedVarDevice) sharded_var_device->emplace(send_param_grad[1], op_dev_id);
.emplace(send_param_grad[1], op_dev_id);
} }
} else if (node->Op()->Type() == "recv") { } else if (node->Op()->Type() == "recv") {
std::vector<std::string> output_var_names; std::vector<std::string> output_var_names;
...@@ -811,7 +818,8 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ...@@ -811,7 +818,8 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
auto recv_param_grad = boost::get<std::vector<std::string>>( auto recv_param_grad = boost::get<std::vector<std::string>>(
node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName())); node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
if (recv_param_grad.size() == 2U) { if (recv_param_grad.size() == 2U) {
op_dev_id = GetVarDeviceID(*result, recv_param_grad[1]); op_dev_id =
GetVarDeviceID(*result, recv_param_grad[1], *sharded_var_device);
VLOG(10) << "recv param " << recv_param_grad[0] VLOG(10) << "recv param " << recv_param_grad[0]
<< " get grad place: " << recv_param_grad[1] << " get grad place: " << recv_param_grad[1]
<< " place: " << op_dev_id; << " place: " << op_dev_id;
...@@ -819,8 +827,7 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ...@@ -819,8 +827,7 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
op_dev_id = GetAppropriateDeviceID(output_var_names); op_dev_id = GetAppropriateDeviceID(output_var_names);
} }
for (auto &varname : output_var_names) { for (auto &varname : output_var_names) {
result->Get<ShardedVarDevice>(kShardedVarDevice) sharded_var_device->emplace(varname, op_dev_id);
.emplace(varname, op_dev_id);
} }
} else { } else {
// send_barrier, fetch_barrier will run on place 0; // send_barrier, fetch_barrier will run on place 0;
...@@ -839,7 +846,7 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ...@@ -839,7 +846,7 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
// send_barrier, recv, fetch_barrier's inputs are deps var, get them from // send_barrier, recv, fetch_barrier's inputs are deps var, get them from
// all places // all places
auto p = places_[op_dev_id]; auto p = places_[op_dev_id];
auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get(); auto *op_handle = result->Get<GraphOps>(kGraphOps).back();
op_handle->SetDeviceContext(p, op_handle->SetDeviceContext(p,
platform::DeviceContextPool::Instance().Get(p)); platform::DeviceContextPool::Instance().Get(p));
...@@ -847,7 +854,8 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ...@@ -847,7 +854,8 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
for (ir::Node *output : node->outputs) { for (ir::Node *output : node->outputs) {
int outvar_dev_id = op_dev_id; int outvar_dev_id = op_dev_id;
if (node->Op()->Type() == "fetch_barrier") { if (node->Op()->Type() == "fetch_barrier") {
outvar_dev_id = GetVarDeviceID(*result, output->Name()); outvar_dev_id =
GetVarDeviceID(*result, output->Name(), *sharded_var_device);
PADDLE_ENFORCE_NE(outvar_dev_id, -1); PADDLE_ENFORCE_NE(outvar_dev_id, -1);
} }
p = places_[outvar_dev_id]; p = places_[outvar_dev_id];
......
...@@ -44,12 +44,18 @@ class MultiDevSSAGraphBuilder : public ir::Pass { ...@@ -44,12 +44,18 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
mutable platform::NCCLContextMap *nccl_ctxs_; mutable platform::NCCLContextMap *nccl_ctxs_;
#endif #endif
int GetVarDeviceID(const ir::Graph &graph, const std::string &varname) const; int GetVarDeviceID(
const ir::Graph &graph, const std::string &varname,
const std::unordered_map<std::string, int> &sharded_var_device) const;
bool IsScaleLossOp(ir::Node *node) const; bool IsScaleLossOp(ir::Node *node) const;
int CreateRPCOp(ir::Graph *result, ir::Node *node) const; int CreateRPCOp(
int CreateDistTrainOp(ir::Graph *result, ir::Node *node) const; ir::Graph *result, ir::Node *node,
std::unordered_map<std::string, int> *sharded_var_device) const;
int CreateDistTrainOp(
ir::Graph *result, ir::Node *node,
std::unordered_map<std::string, int> *sharded_var_device) const;
std::vector<std::string> FindDistTrainSendVars( std::vector<std::string> FindDistTrainSendVars(
const std::vector<ir::Node *> &nodes) const; const std::vector<ir::Node *> &nodes) const;
...@@ -69,7 +75,9 @@ class MultiDevSSAGraphBuilder : public ir::Pass { ...@@ -69,7 +75,9 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
void CreateComputationalOp(ir::Graph *result, ir::Node *node, void CreateComputationalOp(ir::Graph *result, ir::Node *node,
int dev_id) const; int dev_id) const;
int GetOpDeviceID(const ir::Graph &graph, ir::Node *node) const; int GetOpDeviceID(
const ir::Graph &graph, ir::Node *node,
const std::unordered_map<std::string, int> &sharded_var_device) const;
void InsertAllReduceOp(ir::Graph *result, const std::string &og) const; void InsertAllReduceOp(ir::Graph *result, const std::string &og) const;
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h" #include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h"
#include <string> #include <string>
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -62,7 +63,7 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph, ...@@ -62,7 +63,7 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph,
}); });
size_t op_id = 0; size_t op_id = 0;
for (auto &op : graph.Get<GraphOps>(kGraphOps)) { for (auto &op : ir::FilterByNodeWrapper<OpHandleBase>(graph)) {
std::string op_name = "op_" + std::to_string(op_id++); std::string op_name = "op_" + std::to_string(op_id++);
sout << op_name << " [label=\"" << op->Name() << "\", shape=rect]" sout << op_name << " [label=\"" << op->Name() << "\", shape=rect]"
<< std::endl; << std::endl;
......
...@@ -35,23 +35,14 @@ namespace details { ...@@ -35,23 +35,14 @@ namespace details {
// The outside vector is the device vector. Each element of this vector is a // The outside vector is the device vector. Each element of this vector is a
// map from variable name to variables. The variables, who have the same name, // map from variable name to variables. The variables, who have the same name,
// will have a differsent version. The offset in the // will have a differsent version. The offset in the
// `std::vector<std::unique_ptr<VarHandle>>` is the version of varaibles. // `std::vector<VarHandle*>` is the version of varaibles.
typedef std::vector< typedef std::vector<std::unordered_map<std::string, std::vector<VarHandle*>>>
std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>
GraphVars; GraphVars;
const char kGraphVars[] = "vars"; const char kGraphVars[] = "vars";
// aux variables to represent dependency. Useful to resolve data hazard. // aux variables to represent dependency. Useful to resolve data hazard.
typedef std::unordered_set<std::unique_ptr<VarHandleBase>> GraphDepVars; typedef std::unordered_set<VarHandleBase*> GraphDepVars;
const char kGraphDepVars[] = "dep_vars"; const char kGraphDepVars[] = "dep_vars";
// all operators. NOTE that even we use a vector here, the operators is
// unordered.
typedef std::vector<std::unique_ptr<OpHandleBase>> GraphOps;
const char kGraphOps[] = "ops";
typedef std::unordered_map<std::string, int> ShardedVarDevice;
const char kShardedVarDevice[] = "sharded_var_device";
} // namespace details } // namespace details
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -20,19 +20,16 @@ namespace paddle { ...@@ -20,19 +20,16 @@ namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
OpGraphView::OpGraphView( OpGraphView::OpGraphView(const std::vector<OpHandleBase *> &ops) { Build(ops); }
const std::vector<std::unique_ptr<OpHandleBase>> &ops) {
Build(ops);
}
void OpGraphView::Build(const std::vector<std::unique_ptr<OpHandleBase>> &ops) { void OpGraphView::Build(const std::vector<OpHandleBase *> &ops) {
for (auto &op : ops) { for (auto &op : ops) {
preceding_ops_[op.get()]; preceding_ops_[op];
pending_ops_[op.get()]; pending_ops_[op];
for (auto &var : op->Outputs()) { for (auto &var : op->Outputs()) {
for (auto &pending_op : var->PendingOps()) { for (auto &pending_op : var->PendingOps()) {
preceding_ops_[pending_op].insert(op.get()); preceding_ops_[pending_op].insert(op);
pending_ops_[op.get()].insert(pending_op); pending_ops_[op].insert(pending_op);
} }
} }
} }
...@@ -41,8 +38,6 @@ void OpGraphView::Build(const std::vector<std::unique_ptr<OpHandleBase>> &ops) { ...@@ -41,8 +38,6 @@ void OpGraphView::Build(const std::vector<std::unique_ptr<OpHandleBase>> &ops) {
"There are duplicate ops in graph."); "There are duplicate ops in graph.");
} }
size_t OpGraphView::OpNumber() const { return preceding_ops_.size(); }
std::unordered_set<OpHandleBase *> OpGraphView::AllOps() const { std::unordered_set<OpHandleBase *> OpGraphView::AllOps() const {
std::unordered_set<OpHandleBase *> ret; std::unordered_set<OpHandleBase *> ret;
for (auto &pair : preceding_ops_) { for (auto &pair : preceding_ops_) {
...@@ -60,12 +55,6 @@ void OpGraphView::EnforceHasOp(OpHandleBase *op) const { ...@@ -60,12 +55,6 @@ void OpGraphView::EnforceHasOp(OpHandleBase *op) const {
op == nullptr ? "nullptr" : op->DebugString()); op == nullptr ? "nullptr" : op->DebugString());
} }
const std::unordered_set<OpHandleBase *> &OpGraphView::PrecedingOps(
OpHandleBase *op) const {
EnforceHasOp(op);
return preceding_ops_.at(op);
}
const std::unordered_set<OpHandleBase *> &OpGraphView::PendingOps( const std::unordered_set<OpHandleBase *> &OpGraphView::PendingOps(
OpHandleBase *op) const { OpHandleBase *op) const {
EnforceHasOp(op); EnforceHasOp(op);
......
...@@ -26,21 +26,16 @@ namespace details { ...@@ -26,21 +26,16 @@ namespace details {
class OpGraphView { class OpGraphView {
public: public:
explicit OpGraphView(const std::vector<std::unique_ptr<OpHandleBase>> &ops); explicit OpGraphView(const std::vector<OpHandleBase *> &ops);
size_t OpNumber() const;
std::unordered_set<OpHandleBase *> AllOps() const; std::unordered_set<OpHandleBase *> AllOps() const;
const std::unordered_set<OpHandleBase *> &PrecedingOps(
OpHandleBase *op) const;
const std::unordered_set<OpHandleBase *> &PendingOps(OpHandleBase *op) const; const std::unordered_set<OpHandleBase *> &PendingOps(OpHandleBase *op) const;
bool HasOp(OpHandleBase *op) const; bool HasOp(OpHandleBase *op) const;
private: private:
void Build(const std::vector<std::unique_ptr<OpHandleBase>> &ops); void Build(const std::vector<OpHandleBase *> &ops);
void EnforceHasOp(OpHandleBase *op) const; void EnforceHasOp(OpHandleBase *op) const;
std::unordered_map<OpHandleBase *, std::unordered_set<OpHandleBase *>> std::unordered_map<OpHandleBase *, std::unordered_set<OpHandleBase *>>
......
...@@ -31,7 +31,10 @@ constexpr char kLocalExecScopeName[] = "@LCOAL_SCOPE@"; ...@@ -31,7 +31,10 @@ constexpr char kLocalExecScopeName[] = "@LCOAL_SCOPE@";
// It's responsible for populating necessary fields of ir::Node. // It's responsible for populating necessary fields of ir::Node.
class OpHandleBase { class OpHandleBase {
public: public:
explicit OpHandleBase(ir::Node *node) : node_(node) {} // Owned by `node`. No need to be deleted explicitly.
explicit OpHandleBase(ir::Node *node) : node_(node) {
node_->WrappedBy(this);
}
virtual ~OpHandleBase(); virtual ~OpHandleBase();
......
...@@ -30,8 +30,8 @@ struct TestReduceOpHandle { ...@@ -30,8 +30,8 @@ struct TestReduceOpHandle {
Scope g_scope_; Scope g_scope_;
std::vector<Scope *> local_scopes_; std::vector<Scope *> local_scopes_;
std::vector<Scope *> param_scopes_; std::vector<Scope *> param_scopes_;
std::unique_ptr<OpHandleBase> op_handle_; OpHandleBase *op_handle_;
std::vector<std::unique_ptr<VarHandleBase>> vars_; std::vector<VarHandleBase *> vars_;
std::vector<p::Place> gpu_list_; std::vector<p::Place> gpu_list_;
std::vector<std::unique_ptr<p::DeviceContext>> ctxs_; std::vector<std::unique_ptr<p::DeviceContext>> ctxs_;
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include "paddle/fluid/framework/details/computation_op_handle.h" #include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h" #include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/reference_count_pass.h" #include "paddle/fluid/framework/details/reference_count_pass.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -71,14 +72,13 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl( ...@@ -71,14 +72,13 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
// Step 2: Find all variables in non-computation ops which refers to variables // Step 2: Find all variables in non-computation ops which refers to variables
// in computation ops // in computation ops
std::unordered_set<std::string> names; std::unordered_set<std::string> names;
std::unordered_map<OpHandleBase *, std::unique_ptr<ReferenceCountOpHandle>> std::unordered_map<OpHandleBase *, ReferenceCountOpHandle *>
compute_ref_cnt_map; compute_ref_cnt_map;
auto get_ref_cnts_from_compute_op = [&]( auto get_ref_cnts_from_compute_op = [&](
const std::unique_ptr<OpHandleBase> &op, OpHandleBase *op, const std::vector<VarHandleBase *> &vars) {
const std::vector<VarHandleBase *> &vars) {
std::vector<std::string> var_names_in_op; std::vector<std::string> var_names_in_op;
auto *compute_op = dynamic_cast<ComputationOpHandle *>(op.get()); auto *compute_op = dynamic_cast<ComputationOpHandle *>(op);
if (compute_op == nullptr || if (compute_op == nullptr ||
!platform::is_gpu_place(compute_op->GetPlace())) !platform::is_gpu_place(compute_op->GetPlace()))
return var_names_in_op; return var_names_in_op;
...@@ -121,9 +121,8 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl( ...@@ -121,9 +121,8 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
}; };
auto update_ref_cnts_from_non_compute_op = [&]( auto update_ref_cnts_from_non_compute_op = [&](
const std::unique_ptr<OpHandleBase> &op, OpHandleBase *op, const std::vector<VarHandleBase *> &vars) {
const std::vector<VarHandleBase *> &vars) { if (dynamic_cast<ComputationOpHandle *>(op) != nullptr) return;
if (dynamic_cast<ComputationOpHandle *>(op.get()) != nullptr) return;
for (VarHandleBase *var_handle_base : vars) { for (VarHandleBase *var_handle_base : vars) {
auto *var_handle = dynamic_cast<VarHandle *>(var_handle_base); auto *var_handle = dynamic_cast<VarHandle *>(var_handle_base);
if (var_handle == nullptr || !var_handle->Node()->IsVar()) continue; if (var_handle == nullptr || !var_handle->Node()->IsVar()) continue;
...@@ -151,21 +150,21 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl( ...@@ -151,21 +150,21 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
ref_cnt_node, next_compute_op->GetScope(), place, {var_name}, ref_cnt_node, next_compute_op->GetScope(), place, {var_name},
gcs[place.device].get(), cur_ref_cnts[place.device].get()); gcs[place.device].get(), cur_ref_cnts[place.device].get());
AddDependencyBetween(next_compute_op, ref_cnt_handle, graph.get()); AddDependencyBetween(next_compute_op, ref_cnt_handle, graph.get());
compute_ref_cnt_map[next_compute_op].reset(ref_cnt_handle); compute_ref_cnt_map[next_compute_op] = ref_cnt_handle;
} }
} }
} }
} }
}; };
auto &all_ops = graph->Get<GraphOps>(kGraphOps); auto all_ops = ir::FilterByNodeWrapper<OpHandleBase>(*graph);
for (auto &op : all_ops) { for (auto &op : all_ops) {
auto in_var_names = get_ref_cnts_from_compute_op(op, op->Inputs()); auto in_var_names = get_ref_cnts_from_compute_op(op, op->Inputs());
auto out_var_names = get_ref_cnts_from_compute_op(op, op->Outputs()); auto out_var_names = get_ref_cnts_from_compute_op(op, op->Outputs());
if (in_var_names.empty() && out_var_names.empty()) continue; if (in_var_names.empty() && out_var_names.empty()) continue;
in_var_names.insert(in_var_names.end(), out_var_names.begin(), in_var_names.insert(in_var_names.end(), out_var_names.begin(),
out_var_names.end()); out_var_names.end());
auto *compute_op = dynamic_cast<ComputationOpHandle *>(op.get()); auto *compute_op = dynamic_cast<ComputationOpHandle *>(op);
auto place = boost::get<platform::CUDAPlace>(compute_op->GetPlace()); auto place = boost::get<platform::CUDAPlace>(compute_op->GetPlace());
ir::Node *ref_cnt_node = ir::Node *ref_cnt_node =
graph->CreateEmptyNode("reference_count", ir::Node::Type::kOperation); graph->CreateEmptyNode("reference_count", ir::Node::Type::kOperation);
...@@ -173,7 +172,7 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl( ...@@ -173,7 +172,7 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
ref_cnt_node, compute_op->GetScope(), place, in_var_names, ref_cnt_node, compute_op->GetScope(), place, in_var_names,
gcs[place.device].get(), cur_ref_cnts[place.device].get()); gcs[place.device].get(), cur_ref_cnts[place.device].get());
AddDependencyBetween(compute_op, ref_cnt_handle, graph.get()); AddDependencyBetween(compute_op, ref_cnt_handle, graph.get());
compute_ref_cnt_map[compute_op].reset(ref_cnt_handle); compute_ref_cnt_map[compute_op] = ref_cnt_handle;
} }
for (auto &op : all_ops) { for (auto &op : all_ops) {
...@@ -181,11 +180,11 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl( ...@@ -181,11 +180,11 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
update_ref_cnts_from_non_compute_op(op, op->Outputs()); update_ref_cnts_from_non_compute_op(op, op->Outputs());
} }
std::vector<std::unique_ptr<OpHandleBase>> new_all_ops; std::vector<OpHandleBase *> new_all_ops;
new_all_ops.reserve(compute_ref_cnt_map.size() + all_ops.size()); new_all_ops.reserve(compute_ref_cnt_map.size() + all_ops.size());
for (auto &op : all_ops) { for (auto &op : all_ops) {
new_all_ops.emplace_back(std::move(op)); new_all_ops.emplace_back(std::move(op));
auto it = compute_ref_cnt_map.find(new_all_ops.back().get()); auto it = compute_ref_cnt_map.find(new_all_ops.back());
if (it != compute_ref_cnt_map.end()) { if (it != compute_ref_cnt_map.end()) {
// Add LeafNode to ReferenceCountOpHandle // Add LeafNode to ReferenceCountOpHandle
auto *dummy_leaf = new DummyVarHandle(graph->CreateControlDepVar()); auto *dummy_leaf = new DummyVarHandle(graph->CreateControlDepVar());
......
...@@ -19,14 +19,16 @@ namespace framework { ...@@ -19,14 +19,16 @@ namespace framework {
namespace details { namespace details {
SSAGraphExecutor::~SSAGraphExecutor() {} SSAGraphExecutor::~SSAGraphExecutor() {}
void ClearFetchOp(ir::Graph* graph, void ClearFetchOp(ir::Graph* graph, std::vector<FetchOpHandle*>* fetch_ops) {
std::vector<std::unique_ptr<FetchOpHandle>>* fetch_ops) {
if (fetch_ops->empty()) return; if (fetch_ops->empty()) return;
for (auto& op : *fetch_ops) { for (auto& op : *fetch_ops) {
for (auto& out_var : op->Node()->outputs) { for (auto& out_var : op->Node()->outputs) {
graph->RemoveNode(out_var); graph->RemoveNode(out_var);
} }
for (auto& in_var : op->Inputs()) {
in_var->RemoveOutput(op, op->Node());
}
graph->RemoveNode(op->Node()); graph->RemoveNode(op->Node());
} }
fetch_ops->clear(); fetch_ops->clear();
......
...@@ -38,8 +38,7 @@ class SSAGraphExecutor { ...@@ -38,8 +38,7 @@ class SSAGraphExecutor {
virtual FeedFetchList Run(const std::vector<std::string>& fetch_tensors) = 0; virtual FeedFetchList Run(const std::vector<std::string>& fetch_tensors) = 0;
}; };
void ClearFetchOp(ir::Graph* graph, void ClearFetchOp(ir::Graph* graph, std::vector<FetchOpHandle*>* fetch_ops);
std::vector<std::unique_ptr<FetchOpHandle>>* fetch_ops);
} // namespace details } // namespace details
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h" #include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h" #include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
...@@ -51,25 +52,25 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -51,25 +52,25 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
for (auto &var_map : graph_->Get<details::GraphVars>(details::kGraphVars)) { for (auto &var_map : graph_->Get<details::GraphVars>(details::kGraphVars)) {
for (auto &name_pair : var_map) { for (auto &name_pair : var_map) {
for (auto &version_pair : name_pair.second) { for (auto &version_pair : name_pair.second) {
InsertPendingVar(&pending_vars, ready_vars.get(), version_pair.get()); InsertPendingVar(&pending_vars, ready_vars.get(), version_pair);
} }
} }
} }
for (auto &var : graph_->Get<details::GraphDepVars>(details::kGraphDepVars)) { for (auto &var : graph_->Get<details::GraphDepVars>(details::kGraphDepVars)) {
InsertPendingVar(&pending_vars, ready_vars.get(), var.get()); InsertPendingVar(&pending_vars, ready_vars.get(), var);
} }
for (auto &op : graph_->Get<details::GraphOps>(details::kGraphOps)) { for (auto &op : ir::FilterByNodeWrapper<OpHandleBase>(*graph_)) {
if (op->Inputs().empty()) { // Special case, Op has no input. if (op->Inputs().empty()) { // Special case, Op has no input.
ready_ops.insert(op.get()); ready_ops.insert(op);
} else { } else {
InsertPendingOp(&pending_ops, op.get()); InsertPendingOp(&pending_ops, op);
} }
} }
// Step 2. Insert FetchOps // Step 2. Insert FetchOps
std::vector<std::unique_ptr<FetchOpHandle>> fetch_ops; std::vector<FetchOpHandle *> fetch_ops;
std::unordered_set<std::unique_ptr<VarHandleBase>> fetch_dependencies; std::unordered_set<VarHandleBase *> fetch_dependencies;
FeedFetchList fetch_data(fetch_tensors.size()); FeedFetchList fetch_data(fetch_tensors.size());
InsertFetchOps(fetch_tensors, &fetch_ops, &fetch_dependencies, &pending_ops, InsertFetchOps(fetch_tensors, &fetch_ops, &fetch_dependencies, &pending_ops,
...@@ -109,6 +110,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -109,6 +110,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
for (auto &run_op_future : run_op_futures_) { for (auto &run_op_future : run_op_futures_) {
run_op_future.wait(); run_op_future.wait();
} }
ClearFetchOp(graph_.get(), &fetch_ops);
exception_holder_.ReThrow(); exception_holder_.ReThrow();
} else { } else {
continue; continue;
...@@ -140,8 +142,8 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -140,8 +142,8 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
void ThreadedSSAGraphExecutor::InsertFetchOps( void ThreadedSSAGraphExecutor::InsertFetchOps(
const std::vector<std::string> &fetch_tensors, const std::vector<std::string> &fetch_tensors,
std::vector<std::unique_ptr<FetchOpHandle>> *fetch_ops, std::vector<FetchOpHandle *> *fetch_ops,
std::unordered_set<std::unique_ptr<VarHandleBase>> *fetch_dependencies, std::unordered_set<VarHandleBase *> *fetch_dependencies,
std::unordered_map<OpHandleBase *, size_t> *pending_ops, std::unordered_map<OpHandleBase *, size_t> *pending_ops,
std::unordered_set<VarHandleBase *> *pending_vars, std::unordered_set<VarHandleBase *> *pending_vars,
BlockingQueue<VarHandleBase *> *ready_vars, FeedFetchList *fetch_data) { BlockingQueue<VarHandleBase *> *ready_vars, FeedFetchList *fetch_data) {
...@@ -151,7 +153,7 @@ void ThreadedSSAGraphExecutor::InsertFetchOps( ...@@ -151,7 +153,7 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
for (auto &var_map : graph_->Get<details::GraphVars>(details::kGraphVars)) { for (auto &var_map : graph_->Get<details::GraphVars>(details::kGraphVars)) {
auto it = var_map.find(fetch_var_name); auto it = var_map.find(fetch_var_name);
if (it != var_map.end()) { if (it != var_map.end()) {
fetched_vars[fetch_var_name].push_back(it->second.rbegin()->get()); fetched_vars[fetch_var_name].push_back(*it->second.rbegin());
} }
} }
} }
......
...@@ -70,13 +70,13 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -70,13 +70,13 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
BlockingQueue<VarHandleBase *> *ready_vars, BlockingQueue<VarHandleBase *> *ready_vars,
VarHandleBase *var) const; VarHandleBase *var) const;
void InsertFetchOps( void InsertFetchOps(const std::vector<std::string> &fetch_tensors,
const std::vector<std::string> &fetch_tensors, std::vector<FetchOpHandle *> *fetch_ops,
std::vector<std::unique_ptr<FetchOpHandle>> *fetch_ops, std::unordered_set<VarHandleBase *> *fetch_dependencies,
std::unordered_set<std::unique_ptr<VarHandleBase>> *fetch_dependencies, std::unordered_map<OpHandleBase *, size_t> *pending_ops,
std::unordered_map<OpHandleBase *, size_t> *pending_ops, std::unordered_set<VarHandleBase *> *pending_vars,
std::unordered_set<VarHandleBase *> *pending_vars, BlockingQueue<VarHandleBase *> *ready_vars,
BlockingQueue<VarHandleBase *> *ready_vars, FeedFetchList *fetch_data); FeedFetchList *fetch_data);
private: private:
ExecutionStrategy strategy_; ExecutionStrategy strategy_;
......
...@@ -20,6 +20,8 @@ namespace details { ...@@ -20,6 +20,8 @@ namespace details {
VarHandleBase::~VarHandleBase() {} VarHandleBase::~VarHandleBase() {}
VarHandle::~VarHandle() { VLOG(4) << "deleting var handle " << DebugString(); }
std::string VarHandle::DebugString() const { std::string VarHandle::DebugString() const {
std::stringstream ss; std::stringstream ss;
ss << name_ << ":" << place_; ss << name_ << ":" << place_;
...@@ -27,6 +29,10 @@ std::string VarHandle::DebugString() const { ...@@ -27,6 +29,10 @@ std::string VarHandle::DebugString() const {
} }
std::string DummyVarHandle::DebugString() const { return node_->Name(); } std::string DummyVarHandle::DebugString() const { return node_->Name(); }
DummyVarHandle::~DummyVarHandle() {
VLOG(4) << "deleting dummy var handle " << DebugString();
}
} // namespace details } // namespace details
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -35,7 +35,10 @@ class OpHandleBase; ...@@ -35,7 +35,10 @@ class OpHandleBase;
// A variable can only be generated by a single operator. i.e. // A variable can only be generated by a single operator. i.e.
// This is a single assignment graph. // This is a single assignment graph.
struct VarHandleBase { struct VarHandleBase {
explicit VarHandleBase(ir::Node* node) : node_(node) {} // Owned by `node`. No need to be deleted explicitly.
explicit VarHandleBase(ir::Node* node) : node_(node) {
node_->WrappedBy(this);
}
virtual ~VarHandleBase(); virtual ~VarHandleBase();
...@@ -94,6 +97,8 @@ struct VarHandleBase { ...@@ -94,6 +97,8 @@ struct VarHandleBase {
struct VarHandle : public VarHandleBase { struct VarHandle : public VarHandleBase {
explicit VarHandle(ir::Node* node) : VarHandleBase(node) {} explicit VarHandle(ir::Node* node) : VarHandleBase(node) {}
virtual ~VarHandle();
std::string DebugString() const override; std::string DebugString() const override;
VarHandle(ir::Node* node, size_t version, size_t scope_index, VarHandle(ir::Node* node, size_t version, size_t scope_index,
...@@ -121,6 +126,8 @@ struct VarHandle : public VarHandleBase { ...@@ -121,6 +126,8 @@ struct VarHandle : public VarHandleBase {
struct DummyVarHandle : public VarHandleBase { struct DummyVarHandle : public VarHandleBase {
explicit DummyVarHandle(ir::Node* node) : VarHandleBase(node) {} explicit DummyVarHandle(ir::Node* node) : VarHandleBase(node) {}
virtual ~DummyVarHandle();
std::string DebugString() const override; std::string DebugString() const override;
}; };
......
...@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <algorithm>
#include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/feed_fetch_method.h" #include "paddle/fluid/framework/feed_fetch_method.h"
...@@ -46,6 +48,7 @@ ExecutorPrepareContext::~ExecutorPrepareContext() { ...@@ -46,6 +48,7 @@ ExecutorPrepareContext::~ExecutorPrepareContext() {
VLOG(5) << "destroy ExecutorPrepareContext"; VLOG(5) << "destroy ExecutorPrepareContext";
} }
#ifndef _WIN32
template <typename RefCntMap> template <typename RefCntMap>
static void DeleteUnusedTensors(const Scope& scope, const OperatorBase* op, static void DeleteUnusedTensors(const Scope& scope, const OperatorBase* op,
GarbageCollector<Tensor>* gc, GarbageCollector<Tensor>* gc,
...@@ -80,6 +83,7 @@ static void DeleteUnusedTensors(const Scope& scope, const OperatorBase* op, ...@@ -80,6 +83,7 @@ static void DeleteUnusedTensors(const Scope& scope, const OperatorBase* op,
gc->Add(erase_tensors); gc->Add(erase_tensors);
} }
} }
#endif
Executor::Executor(const platform::Place& place) : place_(place) {} Executor::Executor(const platform::Place& place) : place_(place) {}
...@@ -367,6 +371,7 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, ...@@ -367,6 +371,7 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
CreateVariables(ctx->prog_, local_scope, ctx->block_id_); CreateVariables(ctx->prog_, local_scope, ctx->block_id_);
} }
#ifndef _WIN32
int64_t max_memory_size = GetEagerDeletionThreshold(); int64_t max_memory_size = GetEagerDeletionThreshold();
std::unique_ptr<GarbageCollector<Tensor>> gc; std::unique_ptr<GarbageCollector<Tensor>> gc;
// WhileOp would set keep_kids to false // WhileOp would set keep_kids to false
...@@ -408,6 +413,16 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, ...@@ -408,6 +413,16 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
} else { } else {
platform::DeviceContextPool::Instance().Get(place_)->Wait(); platform::DeviceContextPool::Instance().Get(place_)->Wait();
} }
#else // WIN32
for (auto& op : ctx->ops_) {
op->Run(*local_scope, place_);
if (FLAGS_benchmark) {
VLOG(2) << "Memory used after operator " + op->Type() + " running: "
<< memory::memory_usage(place_);
}
}
platform::DeviceContextPool::Instance().Get(place_)->Wait();
#endif // NOT WIN32
if (local_scope != scope) { if (local_scope != scope) {
scope->DeleteScope(local_scope); scope->DeleteScope(local_scope);
......
...@@ -17,12 +17,14 @@ limitations under the License. */ ...@@ -17,12 +17,14 @@ limitations under the License. */
#include <map> #include <map>
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#ifndef _WIN32
#include "paddle/fluid/framework/garbage_collector.h"
#endif
namespace paddle { namespace paddle {
namespace framework { namespace framework {
......
...@@ -53,6 +53,7 @@ set(GLOB_PASS_LIB ${PASS_LIBRARY} CACHE INTERNAL "Global PASS library") ...@@ -53,6 +53,7 @@ set(GLOB_PASS_LIB ${PASS_LIBRARY} CACHE INTERNAL "Global PASS library")
cc_library(pass_builder SRCS pass_builder.cc DEPS pass) cc_library(pass_builder SRCS pass_builder.cc DEPS pass)
cc_test(node_test SRCS node_test.cc DEPS node)
cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper) cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper)
cc_test(graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry) cc_test(graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry)
cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry) cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry)
......
...@@ -102,6 +102,15 @@ class Graph { ...@@ -102,6 +102,15 @@ class Graph {
attr_dels_[attr_name] = []() {}; attr_dels_[attr_name] = []() {};
} }
template <typename AttrType>
void Erase(const std::string &attr_name) {
PADDLE_ENFORCE(attrs_.count(attr_name) != 0, "%s not set in the graph",
attr_name);
attr_dels_[attr_name]();
attrs_.erase(attr_name);
attr_dels_.erase(attr_name);
}
const std::unordered_set<ir::Node *> &Nodes() const { return node_set_; } const std::unordered_set<ir::Node *> &Nodes() const { return node_set_; }
// Create a normal variable with non-null VarDesc. // Create a normal variable with non-null VarDesc.
......
...@@ -37,6 +37,15 @@ std::vector<ir::Node *> TopologySortOperations(const Graph &graph); ...@@ -37,6 +37,15 @@ std::vector<ir::Node *> TopologySortOperations(const Graph &graph);
std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationAdjList( std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationAdjList(
const Graph &graph); const Graph &graph);
template <typename T>
std::vector<T *> FilterByNodeWrapper(const Graph &graph) {
std::vector<T *> ret;
for (ir::Node *n : graph.Nodes()) {
if (n->IsWrappedBy<T>()) ret.push_back(&n->Wrapper<T>());
}
return ret;
}
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -17,7 +17,12 @@ limitations under the License. */ ...@@ -17,7 +17,12 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
// msvc15 don't support constexpr in correct way.
#if !defined(_WIN32)
constexpr char Node::kControlDepVarName[]; constexpr char Node::kControlDepVarName[];
#else
const char Node::kControlDepVarName[] = "__control_var";
#endif
int Node::count_ = 0; int Node::count_ = 0;
std::unique_ptr<Node> CreateNodeForTest(const std::string& name, std::unique_ptr<Node> CreateNodeForTest(const std::string& name,
......
...@@ -15,7 +15,10 @@ limitations under the License. */ ...@@ -15,7 +15,10 @@ limitations under the License. */
#pragma once #pragma once
#include <string> #include <string>
#include <typeindex>
#include <typeinfo>
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/var_desc.h" #include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/platform/macros.h" #include "paddle/fluid/platform/macros.h"
...@@ -24,11 +27,39 @@ namespace paddle { ...@@ -24,11 +27,39 @@ namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
// Node should normally created by Graph::CreateXXXNode(). // Node should only created by Graph::CreateXXXNode().
// 1. Every Node should be part of a graph. No dangling Node exists.
// 2. Node only contains members necessary for building graph structure.
// It doesn't contain other unrelated members, such as device, etc.
//
// Sometimes, for specific usages, Node needs to have additional members,
// such as device_placement, version in order to be executed. It is suggested
// to use composition pattern.
//
// class RunnableOp {
// RunnableOp(ir::Node* n) : n_(n) { n_.WrappedBy(this); }
//
// int any_thing_;
// }
//
// RunnableOp is owned by the ir::Node that composes it. In other words.
// ir::Node will be responsible for deleting RunnableOp, say, when ir::Node
// is deleted from the graph.
class Node { class Node {
public: public:
virtual ~Node() {
if (!wrapper_.empty()) {
VLOG(4) << "ir::Node deleting a wrapper node " << Name();
wrapper_deleter_();
}
}
enum class Type { kOperation, kVariable }; enum class Type { kOperation, kVariable };
static constexpr const char kControlDepVarName[] = "__control_var"; #if !defined(_WIN32) // msvc not support constexpr correctly.
static constexpr char kControlDepVarName[] = "__control_var";
#else
static const char kControlDepVarName[];
#endif
Type NodeType() const { return type_; } Type NodeType() const { return type_; }
...@@ -44,6 +75,29 @@ class Node { ...@@ -44,6 +75,29 @@ class Node {
return op_desc_.get(); return op_desc_.get();
} }
// Set the `wrapper` that wraps the Node. `wrapper` is owned by Node.
template <typename T>
void WrappedBy(T* wrapper) {
if (!wrapper_.empty()) {
wrapper_deleter_();
}
wrapper_ = wrapper;
wrapper_deleter_ = [wrapper]() { delete wrapper; };
wrapper_type_ = std::type_index(typeid(T));
}
// Return a reference to the `wrapper`.
template <typename T>
T& Wrapper() {
return *boost::any_cast<T*>(wrapper_);
}
// Test if the Node is wrapped by type T.
template <typename T>
bool IsWrappedBy() {
return std::type_index(typeid(T)) == wrapper_type_;
}
// Please don't use this API! // Please don't use this API!
int id() const { return id_; } int id() const { return id_; }
...@@ -95,6 +149,11 @@ class Node { ...@@ -95,6 +149,11 @@ class Node {
static int count_; static int count_;
// Please don't use this API or make this public. // Please don't use this API or make this public.
static void ResetId() { count_ = 0; } static void ResetId() { count_ = 0; }
boost::any wrapper_;
std::function<void(void)> wrapper_deleter_;
std::type_index wrapper_type_ = std::type_index(typeid(void));
DISABLE_COPY_AND_ASSIGN(Node); DISABLE_COPY_AND_ASSIGN(Node);
}; };
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
class RunnableOp {
public:
RunnableOp(Node* node, bool* alive) : node_(node), alive_(alive) {
node_->WrappedBy(this);
}
virtual ~RunnableOp() { *alive_ = false; }
private:
Node* node_;
bool* alive_;
};
class RunnableOp2 {
public:
RunnableOp2(Node* node, bool* alive) : node_(node), alive_(alive) {
node_->WrappedBy(this);
}
virtual ~RunnableOp2() { *alive_ = false; }
private:
Node* node_;
bool* alive_;
};
TEST(NodeTest, Basic) {
bool alive1 = true;
bool alive2 = true;
std::unique_ptr<Node> n1(CreateNodeForTest("n1", Node::Type::kVariable));
std::unique_ptr<Node> n2(CreateNodeForTest("n2", Node::Type::kVariable));
EXPECT_FALSE(n1->IsWrappedBy<RunnableOp>());
EXPECT_FALSE(n1->IsWrappedBy<RunnableOp2>());
EXPECT_FALSE(n2->IsWrappedBy<RunnableOp>());
EXPECT_FALSE(n2->IsWrappedBy<RunnableOp2>());
new RunnableOp(n1.get(), &alive1);
new RunnableOp2(n2.get(), &alive2);
EXPECT_TRUE(n1->IsWrappedBy<RunnableOp>());
EXPECT_FALSE(n1->IsWrappedBy<RunnableOp2>());
EXPECT_FALSE(n2->IsWrappedBy<RunnableOp>());
EXPECT_TRUE(n2->IsWrappedBy<RunnableOp2>());
EXPECT_TRUE(alive1);
EXPECT_TRUE(alive2);
n1.reset(nullptr);
n2.reset(nullptr);
EXPECT_FALSE(alive1);
EXPECT_FALSE(alive2);
}
} // namespace ir
} // namespace framework
} // namespace paddle
...@@ -21,6 +21,7 @@ limitations under the License. */ ...@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/node.h" #include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/platform/port.h"
#include "paddle/fluid/platform/variant.h" #include "paddle/fluid/platform/variant.h"
namespace paddle { namespace paddle {
...@@ -195,6 +196,7 @@ struct PassRegistrar : public Registrar { ...@@ -195,6 +196,7 @@ struct PassRegistrar : public Registrar {
__test_global_namespace_##uniq_name##__>::value, \ __test_global_namespace_##uniq_name##__>::value, \
msg) msg)
#if !defined(_WIN32)
// Register a new pass that can be applied on the IR. // Register a new pass that can be applied on the IR.
#define REGISTER_PASS(pass_type, pass_class) \ #define REGISTER_PASS(pass_type, pass_class) \
STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \ STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \
...@@ -217,7 +219,30 @@ struct PassRegistrar : public Registrar { ...@@ -217,7 +219,30 @@ struct PassRegistrar : public Registrar {
extern int TouchPassRegistrar_##pass_type(); \ extern int TouchPassRegistrar_##pass_type(); \
static int use_pass_itself_##pass_type##_ __UNUSED__() = \ static int use_pass_itself_##pass_type##_ __UNUSED__() = \
TouchPassRegistrar_##pass_type() TouchPassRegistrar_##pass_type()
#else
#define REGISTER_PASS(pass_type, pass_class) \
STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \
__reg_pass__##pass_type, \
"REGISTER_PASS must be called in global namespace"); \
static ::paddle::framework::ir::PassRegistrar<pass_class> \
__pass_registrar_##pass_type##__(#pass_type); \
int TouchPassRegistrar_##pass_type() { \
__pass_registrar_##pass_type##__.Touch(); \
return 0; \
} \
static ::paddle::framework::ir::PassRegistrar<pass_class> UNUSED( \
&__pass_tmp_registrar_##pass_type##__) = \
__pass_registrar_##pass_type##__
#define USE_PASS(pass_type) \
STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \
__use_pass_itself_##pass_type, \
"USE_PASS must be called in global namespace"); \
extern int TouchPassRegistrar_##pass_type(); \
static int UNUSED(use_pass_itself_##pass_type##_) = \
TouchPassRegistrar_##pass_type()
#endif // !_WIN32
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -20,6 +20,11 @@ limitations under the License. */ ...@@ -20,6 +20,11 @@ limitations under the License. */
#include <typeindex> #include <typeindex>
#include <vector> #include <vector>
#if defined(_WIN32)
#define GLOG_NO_ABBREVIATED_SEVERITIES // msvc conflict logging with windows.h
#define GOOGLE_GLOG_DLL_DECL
#endif
#include "paddle/fluid/framework/data_layout.h" #include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/ddim.h" #include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/memory/memory.h" #include "paddle/fluid/memory/memory.h"
......
...@@ -17,6 +17,10 @@ cc_library(paddle_fluid_api ...@@ -17,6 +17,10 @@ cc_library(paddle_fluid_api
get_property(fluid_modules GLOBAL PROPERTY FLUID_MODULES) get_property(fluid_modules GLOBAL PROPERTY FLUID_MODULES)
get_property(cuda_modules GLOBAL PROPERTY CUDA_MODULES) get_property(cuda_modules GLOBAL PROPERTY CUDA_MODULES)
get_property(fluid_third_partys GLOBAL PROPERTY FLUID_THRID_PARTYS)
if (WIN32)
list(APPEND fluid_third_partys gflags glog protobuf cblas)
endif(WIN32)
# paddle_fluid_origin exclude inference api interface # paddle_fluid_origin exclude inference api interface
if(WIN32) if(WIN32)
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include <string> #include <string>
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/inference/analysis/data_flow_graph.h" #include "paddle/fluid/inference/analysis/data_flow_graph.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/variant.h" #include "paddle/fluid/platform/variant.h"
namespace paddle { namespace paddle {
...@@ -102,7 +103,6 @@ struct Argument { ...@@ -102,7 +103,6 @@ struct Argument {
std::unordered_map<std::string, std::function<void()>> attr_deleters_; std::unordered_map<std::string, std::function<void()>> attr_deleters_;
}; };
#define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0)
#define ANALYSIS_ARGUMENT_CHECK_FIELD(field__) \ #define ANALYSIS_ARGUMENT_CHECK_FIELD(field__) \
if (UNLIKELY(!(field__))) { \ if (UNLIKELY(!(field__))) { \
LOG(ERROR) << "field " << #field__ << " should be set."; \ LOG(ERROR) << "field " << #field__ << " should be set."; \
......
...@@ -14,7 +14,6 @@ limitations under the License. */ ...@@ -14,7 +14,6 @@ limitations under the License. */
#pragma once #pragma once
#include <sys/stat.h>
#include <cstdio> #include <cstdio>
#include <fstream> #include <fstream>
#include <string> #include <string>
...@@ -26,6 +25,7 @@ limitations under the License. */ ...@@ -26,6 +25,7 @@ limitations under the License. */
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable.h" #include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/port.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
...@@ -124,24 +124,6 @@ T &GetFromScope(const framework::Scope &scope, const std::string &name) { ...@@ -124,24 +124,6 @@ T &GetFromScope(const framework::Scope &scope, const std::string &name) {
return *var->GetMutable<T>(); return *var->GetMutable<T>();
} }
static void ExecShellCommand(const std::string &cmd, std::string *message) {
char buffer[128];
#if !defined(_WIN32)
std::shared_ptr<FILE> pipe(popen(cmd.c_str(), "r"), pclose);
#else
std::shared_ptr<FILE> pipe(_popen(cmd.c_str(), "r"), _pclose);
#endif // _WIN32
if (!pipe) {
LOG(ERROR) << "error running command: " << cmd;
return;
}
while (!feof(pipe.get())) {
if (fgets(buffer, 128, pipe.get()) != nullptr) {
*message += buffer;
}
}
}
static framework::proto::ProgramDesc LoadProgramDesc( static framework::proto::ProgramDesc LoadProgramDesc(
const std::string &model_path) { const std::string &model_path) {
std::ifstream fin(model_path, std::ios::in | std::ios::binary); std::ifstream fin(model_path, std::ios::in | std::ios::binary);
...@@ -163,16 +145,6 @@ static bool FileExists(const std::string &filepath) { ...@@ -163,16 +145,6 @@ static bool FileExists(const std::string &filepath) {
return exists; return exists;
} }
static bool PathExists(const std::string &path) {
struct stat statbuf;
if (stat(path.c_str(), &statbuf) != -1) {
if (S_ISDIR(statbuf.st_mode)) {
return true;
}
}
return false;
}
} // namespace analysis } // namespace analysis
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
......
...@@ -24,6 +24,7 @@ if(WITH_GPU AND TENSORRT_FOUND) ...@@ -24,6 +24,7 @@ if(WITH_GPU AND TENSORRT_FOUND)
endif() endif()
cc_library(reset_tensor_array SRCS details/reset_tensor_array.cc DEPS lod_tensor scope) cc_library(reset_tensor_array SRCS details/reset_tensor_array.cc DEPS lod_tensor scope)
cc_library(helper SRCS helper.cc DEPS reset_tensor_array lod_tensor scope)
cc_library(paddle_inference_api SRCS api.cc api_impl.cc helper.cc DEPS reset_tensor_array lod_tensor scope) cc_library(paddle_inference_api SRCS api.cc api_impl.cc helper.cc DEPS reset_tensor_array lod_tensor scope)
cc_library(analysis_predictor SRCS analysis_predictor.cc DEPS paddle_inference_api analysis naive_executor zero_copy_tensor) cc_library(analysis_predictor SRCS analysis_predictor.cc DEPS paddle_inference_api analysis naive_executor zero_copy_tensor)
cc_library(zero_copy_tensor SRCS details/zero_copy_tensor.cc DEPS paddle_inference_api) cc_library(zero_copy_tensor SRCS details/zero_copy_tensor.cc DEPS paddle_inference_api)
......
...@@ -16,7 +16,6 @@ ...@@ -16,7 +16,6 @@
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h" #include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle_inference_api.h"
namespace paddle { namespace paddle {
......
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <algorithm> #include <algorithm>
#include <fstream>
#include <map> #include <map>
#include <set> #include <set>
#include <sstream> #include <sstream>
...@@ -24,6 +25,7 @@ limitations under the License. */ ...@@ -24,6 +25,7 @@ limitations under the License. */
#include "paddle/fluid/inference/api/api_impl.h" #include "paddle/fluid/inference/api/api_impl.h"
#include "paddle/fluid/inference/api/details/reset_tensor_array.h" #include "paddle/fluid/inference/api/details/reset_tensor_array.h"
#include "paddle/fluid/inference/api/helper.h" #include "paddle/fluid/inference/api/helper.h"
#include "paddle/fluid/inference/api/timer.h"
#include "paddle/fluid/platform/cpu_helper.h" #include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
...@@ -31,16 +33,6 @@ DEFINE_bool(profile, false, "Turn on profiler for fluid"); ...@@ -31,16 +33,6 @@ DEFINE_bool(profile, false, "Turn on profiler for fluid");
DECLARE_int32(paddle_num_threads); DECLARE_int32(paddle_num_threads);
namespace paddle { namespace paddle {
namespace {
using paddle::inference::Timer;
template <class T>
std::string num2str(T a) {
std::stringstream istr;
istr << a;
return istr.str();
}
} // namespace
void NativePaddlePredictor::PrepareFeedFetch() { void NativePaddlePredictor::PrepareFeedFetch() {
for (auto *op : inference_program_->Block(0).AllOps()) { for (auto *op : inference_program_->Block(0).AllOps()) {
...@@ -63,7 +55,6 @@ void NativePaddlePredictor::PrepareFeedFetch() { ...@@ -63,7 +55,6 @@ void NativePaddlePredictor::PrepareFeedFetch() {
bool NativePaddlePredictor::Init( bool NativePaddlePredictor::Init(
std::shared_ptr<framework::Scope> parent_scope) { std::shared_ptr<framework::Scope> parent_scope) {
VLOG(3) << "Predictor::init()";
#if !defined(_WIN32) #if !defined(_WIN32)
if (FLAGS_profile) { if (FLAGS_profile) {
LOG(WARNING) << "Profiler is actived, might affect the performance"; LOG(WARNING) << "Profiler is actived, might affect the performance";
...@@ -91,21 +82,21 @@ bool NativePaddlePredictor::Init( ...@@ -91,21 +82,21 @@ bool NativePaddlePredictor::Init(
paddle::framework::InitDevices(false); paddle::framework::InitDevices(false);
scope_.reset(new paddle::framework::Scope()); scope_.reset(new paddle::framework::Scope());
} }
executor_.reset(new paddle::framework::Executor(place_)); executor_.reset(new paddle::framework::Executor(place_));
// Initialize the inference program // Initialize the inference program
if (!config_.model_dir.empty()) { if (!config_.model_dir.empty()) {
// Parameters are saved in separate files sited in // Parameters are saved in separate files sited in
// the specified `dirname`. // the specified `dirname`.
inference_program_ = paddle::inference::Load(executor_.get(), scope_.get(), inference_program_ = paddle::inference::Load(executor_.get(), scope_.get(),
config_.model_dir); config_.model_dir);
} else if (!config_.prog_file.empty() && !config_.param_file.empty()) { } else if (!config_.prog_file.empty() && !config_.param_file.empty()) {
// All parameters are saved in a single file. // All parameters are saved in a single file.
// The file names should be consistent with that used // The file names should be consistent with that used
// in Python API `fluid.io.save_inference_model`. // in Python API `fluid.io.save_inference_model`.
inference_program_ = paddle::inference::Load( inference_program_ = paddle::inference::Load(
executor_.get(), scope_.get(), config_.prog_file, config_.param_file); executor_.get(), scope_.get(), config_.prog_file, config_.param_file);
} else { } else {
LOG(ERROR) << "fail to load inference model from " << config_.model_dir; LOG(ERROR) << "fail to load inference model from " << config_.model_dir;
return false; return false;
...@@ -135,7 +126,7 @@ NativePaddlePredictor::~NativePaddlePredictor() { ...@@ -135,7 +126,7 @@ NativePaddlePredictor::~NativePaddlePredictor() {
bool NativePaddlePredictor::Run(const std::vector<PaddleTensor> &inputs, bool NativePaddlePredictor::Run(const std::vector<PaddleTensor> &inputs,
std::vector<PaddleTensor> *output_data, std::vector<PaddleTensor> *output_data,
int batch_size) { int batch_size) {
VLOG(3) << "Predictor::predict"; using Timer = paddle::inference::Timer;
Timer timer; Timer timer;
timer.tic(); timer.tic();
// set feed variable // set feed variable
...@@ -147,11 +138,9 @@ bool NativePaddlePredictor::Run(const std::vector<PaddleTensor> &inputs, ...@@ -147,11 +138,9 @@ bool NativePaddlePredictor::Run(const std::vector<PaddleTensor> &inputs,
} }
// Run the inference program // Run the inference program
// if share variables, we need not create variables // if share variables, we need not create variables
VLOG(4) << "Run prepared context";
executor_->RunPreparedContext(ctx_.get(), scope, executor_->RunPreparedContext(ctx_.get(), scope,
false, /* don't create local scope each time*/ false, /* don't create local scope each time*/
false /* don't create variable each time */); false /* don't create variable each time */);
VLOG(4) << "Finish prepared context";
// get fetch variable // get fetch variable
if (!GetFetch(output_data, scope)) { if (!GetFetch(output_data, scope)) {
LOG(ERROR) << "fail to get fetches"; LOG(ERROR) << "fail to get fetches";
...@@ -166,7 +155,6 @@ bool NativePaddlePredictor::Run(const std::vector<PaddleTensor> &inputs, ...@@ -166,7 +155,6 @@ bool NativePaddlePredictor::Run(const std::vector<PaddleTensor> &inputs,
} }
std::unique_ptr<PaddlePredictor> NativePaddlePredictor::Clone() { std::unique_ptr<PaddlePredictor> NativePaddlePredictor::Clone() {
VLOG(3) << "Predictor::clone";
std::unique_ptr<PaddlePredictor> cls(new NativePaddlePredictor(config_)); std::unique_ptr<PaddlePredictor> cls(new NativePaddlePredictor(config_));
if (!dynamic_cast<NativePaddlePredictor *>(cls.get())->Init(scope_)) { if (!dynamic_cast<NativePaddlePredictor *>(cls.get())->Init(scope_)) {
...@@ -184,7 +172,6 @@ std::unique_ptr<PaddlePredictor> NativePaddlePredictor::Clone() { ...@@ -184,7 +172,6 @@ std::unique_ptr<PaddlePredictor> NativePaddlePredictor::Clone() {
bool NativePaddlePredictor::SetFeed(const std::vector<PaddleTensor> &inputs, bool NativePaddlePredictor::SetFeed(const std::vector<PaddleTensor> &inputs,
framework::Scope *scope) { framework::Scope *scope) {
VLOG(3) << "Predictor::set_feed";
if (inputs.size() != feeds_.size()) { if (inputs.size() != feeds_.size()) {
LOG(ERROR) << "wrong feed input size, need " << feeds_.size() << " but get " LOG(ERROR) << "wrong feed input size, need " << feeds_.size() << " but get "
<< inputs.size(); << inputs.size();
...@@ -244,7 +231,6 @@ void NativePaddlePredictor::GetFetchOne(const framework::LoDTensor &fetch, ...@@ -244,7 +231,6 @@ void NativePaddlePredictor::GetFetchOne(const framework::LoDTensor &fetch,
bool NativePaddlePredictor::GetFetch(std::vector<PaddleTensor> *outputs, bool NativePaddlePredictor::GetFetch(std::vector<PaddleTensor> *outputs,
framework::Scope *scope) { framework::Scope *scope) {
VLOG(3) << "Predictor::get_fetch";
outputs->resize(fetchs_.size()); outputs->resize(fetchs_.size());
for (size_t i = 0; i < fetchs_.size(); ++i) { for (size_t i = 0; i < fetchs_.size(); ++i) {
int idx = boost::get<int>(fetchs_[i]->GetAttr("col")); int idx = boost::get<int>(fetchs_[i]->GetAttr("col"));
...@@ -269,25 +255,22 @@ bool NativePaddlePredictor::GetFetch(std::vector<PaddleTensor> *outputs, ...@@ -269,25 +255,22 @@ bool NativePaddlePredictor::GetFetch(std::vector<PaddleTensor> *outputs,
template <> template <>
std::unique_ptr<PaddlePredictor> CreatePaddlePredictor< std::unique_ptr<PaddlePredictor> CreatePaddlePredictor<
NativeConfig, PaddleEngineKind::kNative>(const NativeConfig &config) { NativeConfig, PaddleEngineKind::kNative>(const NativeConfig &config) {
VLOG(3) << "create NativePaddlePredictor";
if (config.use_gpu) { if (config.use_gpu) {
// 1. GPU memeroy // 1. GPU memeroy
PADDLE_ENFORCE_GT( PADDLE_ENFORCE_GT(
config.fraction_of_gpu_memory, 0.f, config.fraction_of_gpu_memory, 0.f,
"fraction_of_gpu_memory in the config should be set to range (0., 1.]"); "fraction_of_gpu_memory in the config should be set to range (0.,1.]");
PADDLE_ENFORCE_GE(config.device, 0, "Invalid device id %d", config.device); PADDLE_ENFORCE_GE(config.device, 0, "Invalid device id %d", config.device);
std::vector<std::string> flags; std::vector<std::string> flags;
if (config.fraction_of_gpu_memory >= 0.0f || if (config.fraction_of_gpu_memory >= 0.0f ||
config.fraction_of_gpu_memory <= 0.95f) { config.fraction_of_gpu_memory <= 0.95f) {
flags.push_back("dummpy"); flags.push_back("dummpy");
std::string flag = "--fraction_of_gpu_memory_to_use=" + std::string flag = "--fraction_of_gpu_memory_to_use=" +
num2str<float>(config.fraction_of_gpu_memory); std::to_string(config.fraction_of_gpu_memory);
flags.push_back(flag); flags.push_back(flag);
VLOG(3) << "set flag: " << flag;
framework::InitGflags(flags); framework::InitGflags(flags);
} }
} }
std::unique_ptr<PaddlePredictor> predictor(new NativePaddlePredictor(config)); std::unique_ptr<PaddlePredictor> predictor(new NativePaddlePredictor(config));
if (!dynamic_cast<NativePaddlePredictor *>(predictor.get())->Init(nullptr)) { if (!dynamic_cast<NativePaddlePredictor *>(predictor.get())->Init(nullptr)) {
return nullptr; return nullptr;
......
...@@ -31,10 +31,10 @@ limitations under the License. */ ...@@ -31,10 +31,10 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor_array.h" #include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/naive_executor.h" #include "paddle/fluid/framework/naive_executor.h"
#include "paddle/fluid/inference/api/details/reset_tensor_array.h" #include "paddle/fluid/inference/api/details/reset_tensor_array.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/inference/io.h" #include "paddle/fluid/inference/io.h"
#include "paddle/fluid/platform/init.h" #include "paddle/fluid/platform/init.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
#include "paddle_inference_api.h" // NOLINT
namespace paddle { namespace paddle {
......
...@@ -6,13 +6,13 @@ option(WITH_STATIC_LIB "Compile demo with static/shared library, default use sta ...@@ -6,13 +6,13 @@ option(WITH_STATIC_LIB "Compile demo with static/shared library, default use sta
option(USE_TENSORRT "Compile demo with TensorRT." OFF) option(USE_TENSORRT "Compile demo with TensorRT." OFF)
macro(safe_set_static_flag) macro(safe_set_static_flag)
foreach(flag_var foreach(flag_var
CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS_RELEASE CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS_RELEASE
CMAKE_CXX_FLAGS_MINSIZEREL CMAKE_CXX_FLAGS_RELWITHDEBINFO) CMAKE_CXX_FLAGS_MINSIZEREL CMAKE_CXX_FLAGS_RELWITHDEBINFO)
if(${flag_var} MATCHES "/MD") if(${flag_var} MATCHES "/MD")
string(REGEX REPLACE "/MD" "/MT" ${flag_var} "${${flag_var}}") string(REGEX REPLACE "/MD" "/MT" ${flag_var} "${${flag_var}}")
endif(${flag_var} MATCHES "/MD") endif(${flag_var} MATCHES "/MD")
endforeach(flag_var) endforeach(flag_var)
endmacro() endmacro()
if (WIN32) if (WIN32)
...@@ -37,26 +37,25 @@ if(NOT DEFINED DEMO_NAME) ...@@ -37,26 +37,25 @@ if(NOT DEFINED DEMO_NAME)
endif() endif()
if(WITH_GPU) if(WITH_GPU) # default gpu path
if(NOT WIN32) if(NOT WIN32)
set(CUDA_LIB "/usr/local/cuda/lib64/" CACHE STRING "CUDA Library") set(CUDA_LIB "/usr/local/cuda/lib64/" CACHE STRING "CUDA Library")
else() else()
if(CUDA_LIB STREQUAL "") if(CUDA_LIB STREQUAL "")
set(CUDA_LIB "C:\\Program\ Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v8.0\\lib\\x64") set(CUDA_LIB "C:\\Program\ Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v8.0\\lib\\x64")
endif() endif()
endif(NOT WIN32) endif(NOT WIN32)
endif() endif()
include_directories("D:/Paddle/")
include_directories("${PADDLE_LIB}") include_directories("${PADDLE_LIB}")
include_directories("${PADDLE_LIB}/third_party/install/protobuf/include") include_directories("${PADDLE_LIB}/third_party/install/protobuf/include")
include_directories("${PADDLE_LIB}/third_party/install/glog/include") include_directories("${PADDLE_LIB}/third_party/install/glog/include")
include_directories("${PADDLE_LIB}/third_party/install/gflags/include") include_directories("${PADDLE_LIB}/third_party/install/gflags/include")
include_directories("${PADDLE_LIB}/third_party/install/xxhash/include") include_directories("${PADDLE_LIB}/third_party/install/xxhash/include")
if (NOT WIN32) if (NOT WIN32)
include_directories("${PADDLE_LIB}/third_party/install/snappy/include") include_directories("${PADDLE_LIB}/third_party/install/snappy/include")
include_directories("${PADDLE_LIB}/third_party/install/snappystream/include") include_directories("${PADDLE_LIB}/third_party/install/snappystream/include")
include_directories("${PADDLE_LIB}/third_party/install/zlib/include") include_directories("${PADDLE_LIB}/third_party/install/zlib/include")
endif(NOT WIN32) endif(NOT WIN32)
include_directories("${PADDLE_LIB}/third_party/boost") include_directories("${PADDLE_LIB}/third_party/boost")
...@@ -64,15 +63,15 @@ include_directories("${PADDLE_LIB}/third_party/eigen3") ...@@ -64,15 +63,15 @@ include_directories("${PADDLE_LIB}/third_party/eigen3")
if (NOT WIN32) if (NOT WIN32)
if (USE_TENSORRT AND WITH_GPU) if (USE_TENSORRT AND WITH_GPU)
include_directories("${TENSORRT_INCLUDE_DIR}") include_directories("${TENSORRT_INCLUDE_DIR}")
link_directories("${TENSORRT_LIB_DIR}") link_directories("${TENSORRT_LIB_DIR}")
endif() endif()
endif(NOT WIN32) endif(NOT WIN32)
if (NOT WIN32) if (NOT WIN32)
link_directories("${PADDLE_LIB}/third_party/install/snappy/lib") link_directories("${PADDLE_LIB}/third_party/install/snappy/lib")
link_directories("${PADDLE_LIB}/third_party/install/snappystream/lib") link_directories("${PADDLE_LIB}/third_party/install/snappystream/lib")
link_directories("${PADDLE_LIB}/third_party/install/zlib/lib") link_directories("${PADDLE_LIB}/third_party/install/zlib/lib")
endif(NOT WIN32) endif(NOT WIN32)
link_directories("${PADDLE_LIB}/third_party/install/protobuf/lib") link_directories("${PADDLE_LIB}/third_party/install/protobuf/lib")
...@@ -86,7 +85,7 @@ add_executable(${DEMO_NAME} ${DEMO_NAME}.cc) ...@@ -86,7 +85,7 @@ add_executable(${DEMO_NAME} ${DEMO_NAME}.cc)
if(WITH_MKL) if(WITH_MKL)
include_directories("${PADDLE_LIB}/third_party/install/mklml/include") include_directories("${PADDLE_LIB}/third_party/install/mklml/include")
set(MATH_LIB ${PADDLE_LIB}/third_party/install/mklml/lib/libmklml_intel${CMAKE_SHARED_LIBRARY_SUFFIX} set(MATH_LIB ${PADDLE_LIB}/third_party/install/mklml/lib/libmklml_intel${CMAKE_SHARED_LIBRARY_SUFFIX}
${PADDLE_LIB}/third_party/install/mklml/lib/libiomp5${CMAKE_SHARED_LIBRARY_SUFFIX}) ${PADDLE_LIB}/third_party/install/mklml/lib/libiomp5${CMAKE_SHARED_LIBRARY_SUFFIX})
set(MKLDNN_PATH "${PADDLE_LIB}/third_party/install/mkldnn") set(MKLDNN_PATH "${PADDLE_LIB}/third_party/install/mkldnn")
if(EXISTS ${MKLDNN_PATH}) if(EXISTS ${MKLDNN_PATH})
include_directories("${MKLDNN_PATH}/include") include_directories("${MKLDNN_PATH}/include")
...@@ -99,25 +98,25 @@ endif() ...@@ -99,25 +98,25 @@ endif()
# Note: libpaddle_inference_api.so/a must put before libpaddle_fluid.so/a # Note: libpaddle_inference_api.so/a must put before libpaddle_fluid.so/a
if(WITH_STATIC_LIB) if(WITH_STATIC_LIB)
set(DEPS set(DEPS
${PADDLE_LIB}/paddle/lib/libpaddle_fluid${CMAKE_STATIC_LIBRARY_SUFFIX}) ${PADDLE_LIB}/paddle/lib/libpaddle_fluid${CMAKE_STATIC_LIBRARY_SUFFIX})
else() else()
set(DEPS set(DEPS
${PADDLE_LIB}/paddle/lib/libpaddle_fluid${CMAKE_SHARED_LIBRARY_SUFFIX}) ${PADDLE_LIB}/paddle/lib/libpaddle_fluid${CMAKE_SHARED_LIBRARY_SUFFIX})
endif() endif()
if (NOT WIN32) if (NOT WIN32)
set(EXTERNAL_LIB "-lrt -ldl -lpthread") set(EXTERNAL_LIB "-lrt -ldl -lpthread")
set(DEPS ${DEPS} set(DEPS ${DEPS}
${MATH_LIB} ${MKLDNN_LIB} ${MATH_LIB} ${MKLDNN_LIB}
glog gflags protobuf snappystream snappy z xxhash glog gflags protobuf snappystream snappy z xxhash
${EXTERNAL_LIB}) ${EXTERNAL_LIB})
else() else()
set(DEPS ${DEPS} set(DEPS ${DEPS}
${MATH_LIB} ${MKLDNN_LIB} ${MATH_LIB} ${MKLDNN_LIB}
${CMAKE_STATIC_LIBRARY_PREFIX}glog ${CMAKE_STATIC_LIBRARY_PREFIX}gflags ${CMAKE_STATIC_LIBRARY_PREFIX}protobuf ${CMAKE_STATIC_LIBRARY_PREFIX}glog ${CMAKE_STATIC_LIBRARY_PREFIX}gflags ${CMAKE_STATIC_LIBRARY_PREFIX}protobuf
${EXTERNAL_LIB}) ${EXTERNAL_LIB})
# NOTE(dzhwinter) shlwapi is deprecated. # NOTE(dzhwinter) shlwapi will be deprecated.
set(DEPS ${DEPS} libcmt shlwapi) set(DEPS ${DEPS} libcmt shlwapi)
endif(NOT WIN32) endif(NOT WIN32)
if(WITH_GPU) if(WITH_GPU)
...@@ -129,8 +128,8 @@ if(WITH_GPU) ...@@ -129,8 +128,8 @@ if(WITH_GPU)
set(DEPS ${DEPS} ${CUDA_LIB}/libcudart${CMAKE_SHARED_LIBRARY_SUFFIX}) set(DEPS ${DEPS} ${CUDA_LIB}/libcudart${CMAKE_SHARED_LIBRARY_SUFFIX})
else() else()
set(DEPS ${DEPS} ${CUDA_LIB}/cudart${CMAKE_STATIC_LIBRARY_SUFFIX} ) set(DEPS ${DEPS} ${CUDA_LIB}/cudart${CMAKE_STATIC_LIBRARY_SUFFIX} )
set(DEPS ${DEPS} ${CUDA_LIB}/cublas${CMAKE_STATIC_LIBRARY_SUFFIX} ) set(DEPS ${DEPS} ${CUDA_LIB}/cublas${CMAKE_STATIC_LIBRARY_SUFFIX} )
set(DEPS ${DEPS} ${CUDA_LIB}/cudnn${CMAKE_STATIC_LIBRARY_SUFFIX} ) set(DEPS ${DEPS} ${CUDA_LIB}/cudnn${CMAKE_STATIC_LIBRARY_SUFFIX} )
endif() endif()
endif() endif()
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#define GOOGLE_GLOG_DLL_DECL
#include <gflags/gflags.h>
#include <glog/logging.h>
#include <chrono> // NOLINT
#include <fstream>
#include <iostream>
#include <thread> // NOLINT
#include <utility>
#include "paddle/fluid/inference/paddle_inference_api.h"
namespace paddle {
NativeConfig GetConfig() {
NativeConfig config;
config.prog_file = "hs_lb_without_bn_cudnn/__model__";
config.param_file = "hs_lb_without_bn_cudnn/__params__";
config.fraction_of_gpu_memory = 0.0;
config.use_gpu = true;
config.device = 0;
return config;
}
using Time = decltype(std::chrono::high_resolution_clock::now());
Time TimeNow() { return std::chrono::high_resolution_clock::now(); }
double TimeDiff(Time t1, Time t2) {
typedef std::chrono::microseconds ms;
auto diff = t2 - t1;
ms counter = std::chrono::duration_cast<ms>(diff);
return counter.count() / 1000.0;
}
std::vector<PaddleTensor> PrepareData() {
int height = 449;
int width = 581;
std::vector<float> data;
for (int i = 0; i < 3 * height * width; ++i) {
data.push_back(0.0);
}
PaddleTensor tensor;
tensor.shape = std::vector<int>({batch_size, 3, height, width});
tensor.data.Resize(sizeof(float) * batch_size * 3 * height * width);
std::copy(data.begin(), data.end(), static_cast<float*>(tensor.data.data()));
tensor.dtype = PaddleDType::FLOAT32;
std::vector<PaddleTensor> paddle_tensor_feeds(1, tensor);
return std::move(paddle_tensor_feeds);
}
void TestNaive(int batch_size, int thread_num) {
NativeConfig config = GetConfig();
int num_jobs = thread_num; // parallel jobs.
constexpr int epoches = 10; // each job run epoches.
std::vector<std::thread> threads;
std::vector<std::unique_ptr<PaddlePredictor>> predictors;
for (int tid = 0; tid < num_jobs; ++tid) {
auto& pred = CreatePaddlePredictor<NativeConfig>(config);
predictors.emplace_back(std::move(pred));
}
auto time1 = TimeNow();
for (int tid = 0; tid < num_jobs; ++tid) {
threads.emplace_back([&, tid]() {
auto& predictor = predictors[tid];
PaddleTensor tensor_out;
std::vector<PaddleTensor> outputs(1, tensor_out);
for (size_t i = 0; i < epoches; i++) {
ASSERT_TRUE(predictor->Run(paddle_tensor_feeds, &outputs));
VLOG(3) << "tid : " << tid << " run: " << i << "finished";
ASSERT_EQ(outputs.size(), 1UL);
}
});
}
for (int i = 0; i < num_jobs; ++i) {
threads[i].join();
}
auto time2 = TimeNow();
VLOG(3) << "Thread num " << thread_num << "total time cost"
<< (time2 - time1);
}
} // namespace paddle
int main(int argc, char** argv) {
paddle::TestNaive(1, 1); // single thread.
paddle::TestNaive(1, 5); // 5 threads.
return 0;
}
...@@ -14,40 +14,26 @@ ...@@ -14,40 +14,26 @@
#pragma once #pragma once
#define GLOG_NO_ABBREVIATED_SEVERITIES
#define GOOGLE_GLOG_DLL_DECL
#include <glog/logging.h> #include <glog/logging.h>
#if !defined(_WIN32) #if !defined(_WIN32)
#include <sys/time.h> #include <sys/time.h>
#else #else
#endif #endif
#include <algorithm>
#include <chrono> // NOLINT #include <chrono> // NOLINT
#include <iterator> #include <iterator>
#include <numeric> #include <numeric>
#include <sstream> #include <sstream>
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/string/printf.h" #include "paddle/fluid/inference/api/timer.h"
#include "paddle_inference_api.h" #include "paddle_inference_api.h" //NOLINT
namespace paddle { namespace paddle {
namespace inference { namespace inference {
// Timer for timer
class Timer {
public:
std::chrono::high_resolution_clock::time_point start;
std::chrono::high_resolution_clock::time_point startu;
void tic() { start = std::chrono::high_resolution_clock::now(); }
double toc() {
startu = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> time_span =
std::chrono::duration_cast<std::chrono::duration<double>>(startu -
start);
double used_time_ms = static_cast<double>(time_span.count()) * 1000.0;
return used_time_ms;
}
};
static void split(const std::string &str, char sep, static void split(const std::string &str, char sep,
std::vector<std::string> *pieces) { std::vector<std::string> *pieces) {
pieces->clear(); pieces->clear();
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <chrono> // NOLINT
namespace paddle {
namespace inference {
// Timer for timer
class Timer {
public:
std::chrono::high_resolution_clock::time_point start;
std::chrono::high_resolution_clock::time_point startu;
void tic() { start = std::chrono::high_resolution_clock::now(); }
double toc() {
startu = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> time_span =
std::chrono::duration_cast<std::chrono::duration<double>>(startu -
start);
double used_time_ms = static_cast<double>(time_span.count()) * 1000.0;
return used_time_ms;
}
};
} // namespace inference
} // namespace paddle
...@@ -11,7 +11,8 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,7 +11,8 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#define GLOG_NO_ABBREVIATED_SEVERITIES
#define GOOGLE_GLOG_DLL_DECL
#include "paddle/fluid/memory/detail/buddy_allocator.h" #include "paddle/fluid/memory/detail/buddy_allocator.h"
#include "glog/logging.h" #include "glog/logging.h"
......
...@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#define GLOG_NO_ABBREVIATED_SEVERITIES
#define GOOGLE_GLOG_DLL_DECL
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/fluid/memory/detail/memory_block.h" #include "paddle/fluid/memory/detail/memory_block.h"
#include "paddle/fluid/platform/assert.h" #include "paddle/fluid/platform/assert.h"
......
...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#define GLOG_NO_ABBREVIATED_SEVERITIES #define GLOG_NO_ABBREVIATED_SEVERITIES
#define GOOGLE_GLOG_DLL_DECL
#include "paddle/fluid/memory/detail/system_allocator.h" #include "paddle/fluid/memory/detail/system_allocator.h"
......
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once #pragma once
#include <algorithm> #include <algorithm>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
namespace paddle { namespace paddle {
......
...@@ -54,6 +54,7 @@ class CastOpKernel : public framework::OpKernel<InT> { ...@@ -54,6 +54,7 @@ class CastOpKernel : public framework::OpKernel<InT> {
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<framework::Tensor>("X"); auto* in = context.Input<framework::Tensor>("X");
auto* out = context.Output<framework::Tensor>("Out"); auto* out = context.Output<framework::Tensor>("Out");
framework::VisitDataType( framework::VisitDataType(
static_cast<framework::proto::VarType::Type>( static_cast<framework::proto::VarType::Type>(
context.Attr<int>("out_dtype")), context.Attr<int>("out_dtype")),
......
...@@ -35,12 +35,12 @@ namespace operators { ...@@ -35,12 +35,12 @@ namespace operators {
template <typename T> template <typename T>
__device__ bool GT_E(T a, T b) { __device__ bool GT_E(T a, T b) {
return (a > b) || Eigen::numext::abs(a - b) < 1e-4; return (a > b) || fabsf(static_cast<float>(a - b)) < 1e-4;
} }
template <typename T> template <typename T>
__device__ bool LT_E(T a, T b) { __device__ bool LT_E(T a, T b) {
return (a < b) || Eigen::numext::abs(a - b) < 1e-4; return (a < b) || fabsf(static_cast<float>(a - b)) < 1e-4;
} }
template <typename T> template <typename T>
......
...@@ -14,7 +14,6 @@ limitations under the License. */ ...@@ -14,7 +14,6 @@ limitations under the License. */
#pragma once #pragma once
#include <glog/logging.h>
#include <algorithm> #include <algorithm>
#include <iterator> #include <iterator>
#include <vector> #include <vector>
......
...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <fstream> #include <fstream>
#include <memory>
#include "paddle/fluid/framework/data_type_transform.h" #include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
...@@ -32,9 +33,15 @@ class LoadCombineOp : public framework::OperatorBase { ...@@ -32,9 +33,15 @@ class LoadCombineOp : public framework::OperatorBase {
const platform::Place &place) const override { const platform::Place &place) const override {
auto filename = Attr<std::string>("file_path"); auto filename = Attr<std::string>("file_path");
auto load_as_fp16 = Attr<bool>("load_as_fp16"); auto load_as_fp16 = Attr<bool>("load_as_fp16");
auto format = Attr<std::string>("format");
std::ifstream fin(filename); std::unique_ptr<std::ifstream> fin;
PADDLE_ENFORCE(static_cast<bool>(fin), if (format == "windows") {
fin.reset(new std::ifstream(filename,
std::ios_base::in | std::ios_base::binary));
} else {
fin.reset(new std::ifstream(filename));
}
PADDLE_ENFORCE(static_cast<bool>(*fin),
"Cannot open file %s for load_combine op", filename); "Cannot open file %s for load_combine op", filename);
auto out_var_names = Outputs("Out"); auto out_var_names = Outputs("Out");
...@@ -54,11 +61,11 @@ class LoadCombineOp : public framework::OperatorBase { ...@@ -54,11 +61,11 @@ class LoadCombineOp : public framework::OperatorBase {
auto *tensor = out_var->GetMutable<framework::LoDTensor>(); auto *tensor = out_var->GetMutable<framework::LoDTensor>();
// Error checking // Error checking
PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot read more from file %s", PADDLE_ENFORCE(static_cast<bool>(*fin), "Cannot read more from file %s",
filename); filename);
// Get data from fin to tensor // Get data from fin to tensor
DeserializeFromStream(fin, tensor, dev_ctx); DeserializeFromStream(*fin, tensor, dev_ctx);
auto in_dtype = framework::ToDataType(tensor->type()); auto in_dtype = framework::ToDataType(tensor->type());
auto out_dtype = auto out_dtype =
...@@ -103,6 +110,18 @@ class LoadCombineOpProtoMaker : public framework::OpProtoAndCheckerMaker { ...@@ -103,6 +110,18 @@ class LoadCombineOpProtoMaker : public framework::OpProtoAndCheckerMaker {
"LoDTensors will be loaded from \"file_path\".") "LoDTensors will be loaded from \"file_path\".")
.AddCustomChecker( .AddCustomChecker(
[](const std::string &path) { return !path.empty(); }); [](const std::string &path) { return !path.empty(); });
AddAttr<std::string>("format",
R"DOC((windows|linux)" "saved model file format
windows and linux file newline symbol is
different. windows(newline is \n\r) or linux(newline is \r)
So if you set attribute format to windows, then we saved model file in binary.
It can be used both linux and windows. If you set format to linux,
it will save file in normal file, newline symbol is \r. Need to note
that these two format is not inter-compatible.)DOC")
.SetDefault("linux")
.AddCustomChecker([](const std::string &s) {
return s == "windows" || s == "linux";
});
AddComment(R"DOC( AddComment(R"DOC(
LoadCombine Operator. LoadCombine Operator.
......
...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <fstream> #include <fstream>
#include <memory>
#include "paddle/fluid/framework/data_type_transform.h" #include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
...@@ -34,8 +35,15 @@ class LoadOp : public framework::OperatorBase { ...@@ -34,8 +35,15 @@ class LoadOp : public framework::OperatorBase {
// FIXME(yuyang18): We save variable to local file now, but we should change // FIXME(yuyang18): We save variable to local file now, but we should change
// it to save an output stream. // it to save an output stream.
auto filename = Attr<std::string>("file_path"); auto filename = Attr<std::string>("file_path");
std::ifstream fin(filename); auto format = Attr<std::string>("format");
PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open file %s for load op", std::unique_ptr<std::ifstream> fin;
if (format == "windows") {
fin.reset(new std::ifstream(filename,
std::ios_base::in | std::ios_base::binary));
} else {
fin.reset(new std::ifstream(filename));
}
PADDLE_ENFORCE(static_cast<bool>(*fin), "Cannot open file %s for load op",
filename); filename);
auto out_var_name = Output("Out"); auto out_var_name = Output("Out");
...@@ -44,9 +52,9 @@ class LoadOp : public framework::OperatorBase { ...@@ -44,9 +52,9 @@ class LoadOp : public framework::OperatorBase {
out_var_name); out_var_name);
if (out_var->IsType<framework::LoDTensor>()) { if (out_var->IsType<framework::LoDTensor>()) {
LoadLodTensor(fin, place, out_var); LoadLodTensor(*fin, place, out_var);
} else if (out_var->IsType<framework::SelectedRows>()) { } else if (out_var->IsType<framework::SelectedRows>()) {
LoadSelectedRows(fin, place, out_var); LoadSelectedRows(*fin, place, out_var);
} else { } else {
PADDLE_ENFORCE( PADDLE_ENFORCE(
false, false,
...@@ -110,6 +118,18 @@ class LoadOpProtoMaker : public framework::OpProtoAndCheckerMaker { ...@@ -110,6 +118,18 @@ class LoadOpProtoMaker : public framework::OpProtoAndCheckerMaker {
R"(Variable will be loaded from "file_path")") R"(Variable will be loaded from "file_path")")
.AddCustomChecker( .AddCustomChecker(
[](const std::string &path) { return !path.empty(); }); [](const std::string &path) { return !path.empty(); });
AddAttr<std::string>("format",
R"DOC((windows|linux)" "saved model file format
windows and linux file newline symbol is
different. windows(newline is \n\r) or linux(newline is \r)
So if you set attribute format to windows, then we saved model file in binary.
It can be used both linux and windows. If you set format to linux,
it will save file in normal file, newline symbol is \r. Need to note
that these two format is not inter-compatible.)DOC")
.SetDefault("linux")
.AddCustomChecker([](const std::string &s) {
return s == "windows" || s == "linux";
});
AddComment( AddComment(
"Load operator will load a LoDTensor / SelectedRows variable from disk " "Load operator will load a LoDTensor / SelectedRows variable from disk "
"file."); "file.");
......
...@@ -4,7 +4,7 @@ Licensed under the Apache License, Version 2.0 (the "License"); ...@@ -4,7 +4,7 @@ Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
......
...@@ -57,9 +57,6 @@ math_library(sequence_padding) ...@@ -57,9 +57,6 @@ math_library(sequence_padding)
math_library(sequence_pooling DEPS math_function) math_library(sequence_pooling DEPS math_function)
math_library(sequence_scale) math_library(sequence_scale)
math_library(softmax DEPS math_function) math_library(softmax DEPS math_function)
if (NOT WIN32)
math_library(matrix_bit_code)
endif (NOT WIN32)
math_library(unpooling) math_library(unpooling)
math_library(vol2col) math_library(vol2col)
...@@ -76,12 +73,11 @@ endif() ...@@ -76,12 +73,11 @@ endif()
cc_test(concat_test SRCS concat_test.cc DEPS concat_and_split) cc_test(concat_test SRCS concat_test.cc DEPS concat_and_split)
cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info) cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info)
if (NOT WIN32) if (NOT WIN32)
math_library(matrix_bit_code)
set(JIT_KERNEL_SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_rnn.cc jit_kernel_crf_decode.cc) set(JIT_KERNEL_SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_rnn.cc jit_kernel_crf_decode.cc)
set(JIT_KERNEL_DEPS cpu_info cblas gflags enforce) set(JIT_KERNEL_DEPS cpu_info cblas gflags enforce)
if(WITH_XBYAK) if(WITH_XBYAK)
list(APPEND JIT_KERNEL_SRCS jit_gen.cc jit_code.cc) list(APPEND JIT_KERNEL_SRCS jit_gen.cc jit_code.cc)
list(APPEND JIT_KERNEL_DEPS xbyak) list(APPEND JIT_KERNEL_DEPS xbyak)
endif() endif()
cc_library(jit_kernel SRCS ${JIT_KERNEL_SRCS} DEPS ${JIT_KERNEL_DEPS}) endif (NOT WIN32)
cc_test(jit_kernel_test SRCS jit_kernel_test.cc DEPS jit_kernel)
endif()
...@@ -18,10 +18,6 @@ limitations under the License. */ ...@@ -18,10 +18,6 @@ limitations under the License. */
#include <string> #include <string>
#include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#ifdef __AVX__
#include <immintrin.h>
#endif
#ifdef PADDLE_WITH_MKLML #ifdef PADDLE_WITH_MKLML
#include "paddle/fluid/platform/dynload/mklml.h" #include "paddle/fluid/platform/dynload/mklml.h"
#endif #endif
......
...@@ -15,13 +15,10 @@ limitations under the License. */ ...@@ -15,13 +15,10 @@ limitations under the License. */
#pragma once #pragma once
#include <math.h> #include <math.h>
#include <string> #include <string>
#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/hostdevice.h" #include "paddle/fluid/platform/hostdevice.h"
#ifdef __AVX__
#include <immintrin.h>
#endif
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace math { namespace math {
......
...@@ -25,10 +25,6 @@ limitations under the License. */ ...@@ -25,10 +25,6 @@ limitations under the License. */
#include "paddle/fluid/platform/dynload/mklml.h" #include "paddle/fluid/platform/dynload/mklml.h"
#endif #endif
#ifdef __AVX__
#include <immintrin.h>
#endif
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace math { namespace math {
......
...@@ -16,9 +16,6 @@ limitations under the License. */ ...@@ -16,9 +16,6 @@ limitations under the License. */
#include <limits> #include <limits>
#include <string> #include <string>
#include "paddle/fluid/operators/math/jit_kernel_macro.h" #include "paddle/fluid/operators/math/jit_kernel_macro.h"
#ifdef __AVX__
#include <immintrin.h>
#endif
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -263,6 +260,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> { ...@@ -263,6 +260,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
} \ } \
} }
#ifndef _WIN32 // commented out crf decoding
#ifdef __AVX__ #ifdef __AVX__
INTRIAVX_FLOAT(kEQ8); INTRIAVX_FLOAT(kEQ8);
INTRIAVX_FLOAT(kGT8LT16); INTRIAVX_FLOAT(kGT8LT16);
...@@ -275,6 +273,7 @@ INTRIAVX2_FLOAT(jit::avx2, kGT8LT16); ...@@ -275,6 +273,7 @@ INTRIAVX2_FLOAT(jit::avx2, kGT8LT16);
INTRIAVX2_FLOAT(jit::avx2, kEQ16); INTRIAVX2_FLOAT(jit::avx2, kEQ16);
INTRIAVX2_FLOAT(jit::avx2, kGT16); INTRIAVX2_FLOAT(jit::avx2, kGT16);
#endif #endif
#endif // WIN32
#ifdef __AVX512F__ #ifdef __AVX512F__
INTRIAVX2_FLOAT(jit::avx512f, kEQ8); INTRIAVX2_FLOAT(jit::avx512f, kEQ8);
INTRIAVX2_FLOAT(jit::avx512f, kGT8LT16); INTRIAVX2_FLOAT(jit::avx512f, kGT8LT16);
......
...@@ -20,10 +20,6 @@ limitations under the License. */ ...@@ -20,10 +20,6 @@ limitations under the License. */
#include "paddle/fluid/platform/dynload/mklml.h" #include "paddle/fluid/platform/dynload/mklml.h"
#endif #endif
#ifdef __AVX__
#include <immintrin.h>
#endif
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace math { namespace math {
...@@ -66,14 +62,18 @@ namespace detail { ...@@ -66,14 +62,18 @@ namespace detail {
#ifdef __AVX__ #ifdef __AVX__
#if defined(_WIN32)
#define ALIGN32 __declspec(align(32))
#else
#define ALIGN32 __attribute__((aligned(32))) #define ALIGN32 __attribute__((aligned(32)))
#endif // _WIN32
#define _PS256_CONST(Name, Val) \ #define _PS256_CONST(Name, Val) \
static const float _ps256_##Name[8] ALIGN32 = {Val, Val, Val, Val, \ static const float ALIGN32 _ps256_##Name[8] = {Val, Val, Val, Val, \
Val, Val, Val, Val} Val, Val, Val, Val}
#define _PI256_CONST(Name, Val) \ #define _PI256_CONST(Name, Val) \
static const int _pi256_##Name[8] ALIGN32 = {Val, Val, Val, Val, \ static const int ALIGN32 _pi256_##Name[8] = {Val, Val, Val, Val, \
Val, Val, Val, Val} Val, Val, Val, Val}
_PI256_CONST(0x7f, 0x7f); _PI256_CONST(0x7f, 0x7f);
...@@ -98,7 +98,7 @@ typedef union imm_xmm_union { ...@@ -98,7 +98,7 @@ typedef union imm_xmm_union {
#define COPY_IMM_TO_XMM(imm_, xmm0_, xmm1_) \ #define COPY_IMM_TO_XMM(imm_, xmm0_, xmm1_) \
{ \ { \
imm_xmm_union u ALIGN32; \ imm_xmm_union ALIGN32 u; \
u.imm = imm_; \ u.imm = imm_; \
xmm0_ = u.xmm[0]; \ xmm0_ = u.xmm[0]; \
xmm1_ = u.xmm[1]; \ xmm1_ = u.xmm[1]; \
...@@ -106,7 +106,7 @@ typedef union imm_xmm_union { ...@@ -106,7 +106,7 @@ typedef union imm_xmm_union {
#define COPY_XMM_TO_IMM(xmm0_, xmm1_, imm_) \ #define COPY_XMM_TO_IMM(xmm0_, xmm1_, imm_) \
{ \ { \
imm_xmm_union u ALIGN32; \ imm_xmm_union ALIGN32 u; \
u.xmm[0] = xmm0_; \ u.xmm[0] = xmm0_; \
u.xmm[1] = xmm1_; \ u.xmm[1] = xmm1_; \
imm_ = u.imm; \ imm_ = u.imm; \
...@@ -508,12 +508,14 @@ class VTanhKernelImpl : public VTanhKernel<T> { ...@@ -508,12 +508,14 @@ class VTanhKernelImpl : public VTanhKernel<T> {
vaddbias_->Compute(-1.f, y, y); \ vaddbias_->Compute(-1.f, y, y); \
} }
#ifndef __WIN32
#ifdef __AVX__ #ifdef __AVX__
INTRI8_FLOAT(jit::avx, detail::ExpAVX); INTRI8_FLOAT(jit::avx, detail::ExpAVX);
INTRI16_FLOAT(jit::avx, detail::ExpAVX); INTRI16_FLOAT(jit::avx, detail::ExpAVX);
INTRI_GT8LT16_FLOAT(jit::avx, detail::ExpAVX); INTRI_GT8LT16_FLOAT(jit::avx, detail::ExpAVX);
INTRI_GT16_FLOAT(jit::avx, detail::ExpAVX); INTRI_GT16_FLOAT(jit::avx, detail::ExpAVX);
#endif #endif // AVX
#endif // WIN32
#ifdef __AVX2__ #ifdef __AVX2__
INTRI8_FLOAT(jit::avx2, detail::ExpAVX2); INTRI8_FLOAT(jit::avx2, detail::ExpAVX2);
INTRI16_FLOAT(jit::avx2, detail::ExpAVX2); INTRI16_FLOAT(jit::avx2, detail::ExpAVX2);
......
...@@ -18,10 +18,6 @@ limitations under the License. */ ...@@ -18,10 +18,6 @@ limitations under the License. */
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/macros.h" #include "paddle/fluid/platform/macros.h"
#ifdef __AVX__
#include <immintrin.h>
#endif
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace math { namespace math {
......
...@@ -16,6 +16,7 @@ limitations under the License. */ ...@@ -16,6 +16,7 @@ limitations under the License. */
#include <vector> #include <vector>
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/math_function_impl.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h" #include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
......
...@@ -16,18 +16,12 @@ limitations under the License. */ ...@@ -16,18 +16,12 @@ limitations under the License. */
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/sequence_pooling.h" #include "paddle/fluid/operators/math/sequence_pooling.h"
#include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/macros.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace math { namespace math {
#if defined(__FLT_MAX__)
#define FLT_MAX __FLT_MAX__
#else
#include <float.h>
#include <limits>
#endif
template <typename T> template <typename T>
struct MaxPoolFunctor { struct MaxPoolFunctor {
HOSTDEVICE void operator()(const T* input, const size_t start, HOSTDEVICE void operator()(const T* input, const size_t start,
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
limitations under the License. */ limitations under the License. */
#include <algorithm> #include <algorithm>
#include <iostream>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/framework/var_type.h"
......
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#include <stdint.h> #include <stdint.h>
#include <fstream> #include <fstream>
#include <memory>
#include <numeric> #include <numeric>
#include <sstream> #include <sstream>
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type.h"
...@@ -41,6 +42,7 @@ class SaveCombineOp : public framework::OperatorBase { ...@@ -41,6 +42,7 @@ class SaveCombineOp : public framework::OperatorBase {
auto filename = Attr<std::string>("file_path"); auto filename = Attr<std::string>("file_path");
auto overwrite = Attr<bool>("overwrite"); auto overwrite = Attr<bool>("overwrite");
auto save_as_fp16 = Attr<bool>("save_as_fp16"); auto save_as_fp16 = Attr<bool>("save_as_fp16");
auto format = Attr<std::string>("format");
bool is_present = FileExists(filename); bool is_present = FileExists(filename);
if (is_present && !overwrite) { if (is_present && !overwrite) {
...@@ -49,8 +51,14 @@ class SaveCombineOp : public framework::OperatorBase { ...@@ -49,8 +51,14 @@ class SaveCombineOp : public framework::OperatorBase {
} }
MkDirRecursively(DirName(filename).c_str()); MkDirRecursively(DirName(filename).c_str());
std::ofstream fout(filename); std::unique_ptr<std::ofstream> fout;
PADDLE_ENFORCE(static_cast<bool>(fout), "Cannot open %s to write", if (format == "windows") {
fout.reset(new std::ofstream(filename,
std::ios_base::out | std::ios_base::binary));
} else {
fout.reset(new std::ofstream(filename));
}
PADDLE_ENFORCE(static_cast<bool>(*fout), "Cannot open %s to write",
filename); filename);
auto inp_var_names = Inputs("X"); auto inp_var_names = Inputs("X");
...@@ -86,12 +94,11 @@ class SaveCombineOp : public framework::OperatorBase { ...@@ -86,12 +94,11 @@ class SaveCombineOp : public framework::OperatorBase {
// copy LoD info to the new tensor // copy LoD info to the new tensor
out.set_lod(tensor.lod()); out.set_lod(tensor.lod());
framework::TransDataType(in_kernel_type, out_kernel_type, tensor, &out); framework::TransDataType(in_kernel_type, out_kernel_type, tensor, &out);
framework::SerializeToStream(fout, out, dev_ctx); framework::SerializeToStream(*fout, out, dev_ctx);
} else { } else {
framework::SerializeToStream(fout, tensor, dev_ctx); framework::SerializeToStream(*fout, tensor, dev_ctx);
} }
} }
fout.close();
} }
}; };
...@@ -124,6 +131,18 @@ to a file on disk. ...@@ -124,6 +131,18 @@ to a file on disk.
"The \"file_path\" where the LoDTensor variables will be saved.") "The \"file_path\" where the LoDTensor variables will be saved.")
.AddCustomChecker( .AddCustomChecker(
[](const std::string &path) { return !path.empty(); }); [](const std::string &path) { return !path.empty(); });
AddAttr<std::string>("format",
R"DOC((windows|linux)" "saved model file format
windows and linux file newline symbol is
different. windows(newline is \n\r) or linux(newline is \r)
So if you set attribute format to windows, then we saved model file in binary.
It can be used both linux and windows. If you set format to linux,
it will save file in normal file, newline symbol is \r. Need to note
that these two format is not inter-compatible.)DOC")
.SetDefault("linux")
.AddCustomChecker([](const std::string &s) {
return s == "windows" || s == "linux";
});
} }
}; };
......
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#include <stdint.h> #include <stdint.h>
#include <fstream> #include <fstream>
#include <memory>
#include <numeric> #include <numeric>
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type.h"
...@@ -64,6 +65,7 @@ class SaveOp : public framework::OperatorBase { ...@@ -64,6 +65,7 @@ class SaveOp : public framework::OperatorBase {
framework::Variable *var) const { framework::Variable *var) const {
auto filename = Attr<std::string>("file_path"); auto filename = Attr<std::string>("file_path");
auto overwrite = Attr<bool>("overwrite"); auto overwrite = Attr<bool>("overwrite");
auto format = Attr<std::string>("format");
if (FileExists(filename) && !overwrite) { if (FileExists(filename) && !overwrite) {
PADDLE_THROW("%s is existed, cannot save to it when overwrite=false", PADDLE_THROW("%s is existed, cannot save to it when overwrite=false",
...@@ -80,8 +82,14 @@ class SaveOp : public framework::OperatorBase { ...@@ -80,8 +82,14 @@ class SaveOp : public framework::OperatorBase {
// FIXME(yuyang18): We save variable to local file now, but we should change // FIXME(yuyang18): We save variable to local file now, but we should change
// it to save an output stream. // it to save an output stream.
std::ofstream fout(filename); std::unique_ptr<std::ofstream> fout;
PADDLE_ENFORCE(static_cast<bool>(fout), "Cannot open %s to write", if (format == "windows") {
fout.reset(new std::ofstream(filename,
std::ios_base::out | std::ios_base::binary));
} else {
fout.reset(new std::ofstream(filename));
}
PADDLE_ENFORCE(static_cast<bool>(*fout), "Cannot open %s to write",
filename); filename);
auto save_as_fp16 = Attr<bool>("save_as_fp16"); auto save_as_fp16 = Attr<bool>("save_as_fp16");
...@@ -95,11 +103,10 @@ class SaveOp : public framework::OperatorBase { ...@@ -95,11 +103,10 @@ class SaveOp : public framework::OperatorBase {
framework::TransDataType(in_kernel_type, out_kernel_type, tensor, &out); framework::TransDataType(in_kernel_type, out_kernel_type, tensor, &out);
// copy LoD info to the new tensor // copy LoD info to the new tensor
out.set_lod(tensor.lod()); out.set_lod(tensor.lod());
framework::SerializeToStream(fout, out, dev_ctx); framework::SerializeToStream(*fout, out, dev_ctx);
} else { } else {
framework::SerializeToStream(fout, tensor, dev_ctx); framework::SerializeToStream(*fout, tensor, dev_ctx);
} }
fout.close();
} }
void SaveSelectedRows(const framework::Scope &scope, void SaveSelectedRows(const framework::Scope &scope,
...@@ -110,6 +117,7 @@ class SaveOp : public framework::OperatorBase { ...@@ -110,6 +117,7 @@ class SaveOp : public framework::OperatorBase {
lt_var != nullptr, lt_var != nullptr,
"Can not find variable kLookupTablePath for SaveSelectedRows"); "Can not find variable kLookupTablePath for SaveSelectedRows");
std::string filename = lt_var->data(); std::string filename = lt_var->data();
auto format = Attr<std::string>("format");
VLOG(4) << "SaveSelectedRows get File name: " << filename; VLOG(4) << "SaveSelectedRows get File name: " << filename;
MkDirRecursively(DirName(filename).c_str()); MkDirRecursively(DirName(filename).c_str());
...@@ -122,11 +130,16 @@ class SaveOp : public framework::OperatorBase { ...@@ -122,11 +130,16 @@ class SaveOp : public framework::OperatorBase {
// FIXME(yuyang18): We save variable to local file now, but we should change // FIXME(yuyang18): We save variable to local file now, but we should change
// it to save an output stream. // it to save an output stream.
std::ofstream fout(filename); std::unique_ptr<std::ofstream> fout;
PADDLE_ENFORCE(static_cast<bool>(fout), "Cannot open %s to write", if (format == "windows") {
fout.reset(new std::ofstream(filename,
std::ios_base::out | std::ios_base::binary));
} else {
fout.reset(new std::ofstream(filename));
}
PADDLE_ENFORCE(static_cast<bool>(*fout), "Cannot open %s to write",
filename); filename);
framework::SerializeToStream(fout, selectedRows, dev_ctx); framework::SerializeToStream(*fout, selectedRows, dev_ctx);
fout.close();
} }
}; };
...@@ -154,6 +167,18 @@ This operator will serialize and write LoDTensor / SelectedRows variable to file ...@@ -154,6 +167,18 @@ This operator will serialize and write LoDTensor / SelectedRows variable to file
"The \"file_path\" where the variable will be saved.") "The \"file_path\" where the variable will be saved.")
.AddCustomChecker( .AddCustomChecker(
[](const std::string &path) { return !path.empty(); }); [](const std::string &path) { return !path.empty(); });
AddAttr<std::string>("format",
R"DOC((windows|linux)" "saved model file format
windows and linux file newline symbol is
different. windows(newline is \n\r) or linux(newline is \r)
So if you set attribute format to windows, then we saved model file in binary.
It can be used both linux and windows. If you set format to linux,
it will save file in normal file, newline symbol is \r. Need to note
that these two format is not inter-compatible.)DOC")
.SetDefault("linux")
.AddCustomChecker([](const std::string &s) {
return s == "windows" || s == "linux";
});
} }
}; };
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/space_to_depth_op.h"
#include <string>
#include <vector>
namespace paddle {
namespace operators {
class SpaceToDepthOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of SpaceToDepthOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of SpaceToDepthOp should not be null.");
auto x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(x_dims.size(), 4, "input should be a 4D tensor");
auto blocksize = ctx->Attrs().Get<int64_t>("blocksize");
PADDLE_ENFORCE_GT(blocksize, 1, "The blocksize should be Greater than 1");
PADDLE_ENFORCE_GT(x_dims[1], 0, "input channel should be Greater than 0");
PADDLE_ENFORCE_GT(x_dims[2], 0, "input Height should be Greater than 0");
PADDLE_ENFORCE_GT(x_dims[3], 0, "input Width should be Greater than 0");
PADDLE_ENFORCE_EQ(x_dims[1] % (blocksize * blocksize), 0,
"input channel should be divisible of the square of "
"SpaceToDepthOp blocksize");
PADDLE_ENFORCE_EQ(x_dims[2] % (blocksize), 0,
"input Height should be divisible of the square of "
"SpaceToDepthOp blocksize");
PADDLE_ENFORCE_EQ(x_dims[3] % (blocksize), 0,
"input Width should be divisible of the square of "
"SpaceToDepthOp blocksize");
VLOG(3) << "SpaceToDepthOp operator x.shape=" << x_dims
<< "Attribute blocksize" << blocksize << std::endl;
std::vector<int64_t> output_shape(4, 0); // [B,C,H,W]
output_shape[0] = x_dims[0];
output_shape[1] = x_dims[1] * blocksize * blocksize;
output_shape[2] = x_dims[2] / blocksize;
output_shape[3] = x_dims[3] / blocksize;
auto out_dims = framework::make_ddim(output_shape);
ctx->SetOutputDim("Out", out_dims);
if (x_dims[0] == out_dims[0]) {
// Only pass LoD when the first dimension of output and Input(X)
// are the same.
ctx->ShareLoD("X", /*->*/ "Out");
}
}
};
class SpaceToDepthOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor). The input should be a 4D tensor B * C * W * H of "
"SpaceToDepthOp "
"operator.");
AddOutput("Out",
"(Tensor), The output should be a 4D tensor B * C2 * W2 * H2 of "
"SpaceToDepthOp operator.");
AddAttr<int64_t>(
"blocksize",
"(int64_t, default 2) blocksize used to do change Space To Depth.")
.SetDefault(2)
.GreaterThan(1);
AddComment(R"DOC(
reorg operator used in Yolo v2.
The equation is: C2 = C1/blocksize * blocksize, W2 = W1 ∗ blocksize + offset % blocksize, H2 = H1 ∗ blocksize + offset / blocksize,
Reshape Input(X) into the shape according to Attr(blocksize). The
data in Input(X) are unchanged.
Examples:
1. Given a 4-D tensor Input(X) with a shape [128, 2048, 26, 26], and the blocksize is 2, the reorg operator will transform Input(X)
into a 4-D tensor with shape [128, 2048, 13, 13] and leaving Input(X)'s data unchanged.
)DOC");
}
};
class SpaceToDepthGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) shouldn't be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) shouldn't be null.");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(space_to_depth, ops::SpaceToDepthOp, ops::SpaceToDepthOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(space_to_depth_grad, ops::SpaceToDepthGradOp);
REGISTER_OP_CPU_KERNEL(
space_to_depth,
ops::SpaceToDepthKernel<paddle::platform::CPUDeviceContext, float>,
ops::SpaceToDepthKernel<paddle::platform::CPUDeviceContext, double>,
ops::SpaceToDepthKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
space_to_depth_grad,
ops::SpaceToDepthGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::SpaceToDepthGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::SpaceToDepthGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/space_to_depth_op.h"
namespace plat = paddle::platform;
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
space_to_depth,
ops::SpaceToDepthKernel<paddle::platform::CUDADeviceContext, float>,
ops::SpaceToDepthKernel<paddle::platform::CUDADeviceContext, double>,
ops::SpaceToDepthKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
space_to_depth_grad,
ops::SpaceToDepthGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::SpaceToDepthGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::SpaceToDepthGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef PADDLE_FLUID_OPERATORS_SPACE_TO_DEPTH_OP_H_
#define PADDLE_FLUID_OPERATORS_SPACE_TO_DEPTH_OP_H_
#endif // PADDLE_FLUID_OPERATORS_SPACE_TO_DEPTH_OP_H_
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/for_range.h"
namespace paddle {
namespace operators {
template <typename T>
class space_to_depth_compute {
public:
HOSTDEVICE space_to_depth_compute(const T *x, int64_t w, int64_t h, int64_t c,
int64_t batch, int64_t blocksize,
int64_t forward, T *out)
: x_(x),
w_(w),
h_(h),
c_(c),
batch_(batch),
blocksize_(blocksize),
forward_(forward),
out_(out) {}
HOSTDEVICE void operator()(int64_t in_index) {
int64_t out_c = c_ / (blocksize_ * blocksize_);
// calculate each dim position with index of tensor
int64_t b = in_index / (c_ * h_ * w_);
int64_t k = (in_index % (c_ * h_ * w_)) / (h_ * w_);
int64_t j = ((in_index % (c_ * h_ * w_)) % (h_ * w_)) / w_;
int64_t i = ((in_index % (c_ * h_ * w_)) % (h_ * w_)) % w_;
int64_t c2 = k % out_c;
int64_t offset = k / out_c;
int64_t w2 = i * blocksize_ + offset % blocksize_;
int64_t h2 = j * blocksize_ + offset / blocksize_;
int64_t out_index =
w2 + w_ * blocksize_ * (h2 + h_ * blocksize_ * (c2 + out_c * b));
if (forward_)
out_[out_index] = x_[in_index];
else
out_[in_index] = x_[out_index];
}
private:
const T *x_;
int64_t w_, h_, c_, batch_, blocksize_, forward_;
T *out_;
};
template <typename DeviceContext, typename T>
class SpaceToDepthKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto *out = context.Output<framework::LoDTensor>("Out");
auto *x = context.Input<framework::LoDTensor>("X");
auto blocksize = context.Attr<int64_t>("blocksize");
auto in_dims = x->dims();
out->mutable_data(context.GetPlace(), x->type());
auto out_dims = out->dims();
auto B = in_dims[0];
auto C = in_dims[1];
auto H = in_dims[2];
auto W = in_dims[3];
platform::ForRange<DeviceContext> for_range(
context.template device_context<DeviceContext>(),
static_cast<size_t>(x->numel()));
auto *x_data = x->data<T>();
auto *out_data = out->data<T>();
paddle::operators::space_to_depth_compute<T> computer(
x_data, W, H, C, B, blocksize, 1, out_data);
for_range(computer);
out->Resize(out_dims);
}
};
template <typename DeviceContext, typename T>
class SpaceToDepthGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto *d_out =
context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
auto *d_x =
context.Output<framework::LoDTensor>(framework::GradVarName("X"));
auto blocksize = context.Attr<int64_t>("blocksize");
auto in_dims = d_x->dims();
d_x->mutable_data(context.GetPlace(), d_out->type());
auto B = in_dims[0];
auto C = in_dims[1];
auto H = in_dims[2];
auto W = in_dims[3];
platform::ForRange<DeviceContext> for_range(
context.template device_context<DeviceContext>(),
static_cast<size_t>(d_x->numel()));
auto *dx_data = d_x->data<T>();
auto *dout_data = d_out->data<T>();
paddle::operators::space_to_depth_compute<T> computer(
dout_data, W, H, C, B, blocksize, 0, dx_data);
for_range(computer);
d_x->Resize(in_dims);
}
};
} // namespace operators
} // namespace paddle
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/port.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -34,7 +34,7 @@ namespace operators { ...@@ -34,7 +34,7 @@ namespace operators {
using FluidDT = framework::proto::VarType_Type; using FluidDT = framework::proto::VarType_Type;
using TRT_DT = nvinfer1::DataType; using TRT_DT = nvinfer1::DataType;
namespace { namespace { // NOLINT
TRT_DT FluidDataType2TRT(FluidDT type) { TRT_DT FluidDataType2TRT(FluidDT type) {
switch (type) { switch (type) {
...@@ -60,7 +60,7 @@ nvinfer1::Dims Vec2TRT_Dims(const std::vector<int64_t>& shape) { ...@@ -60,7 +60,7 @@ nvinfer1::Dims Vec2TRT_Dims(const std::vector<int64_t>& shape) {
return nvinfer1::DimsCHW(shape[1], 1, 1); return nvinfer1::DimsCHW(shape[1], 1, 1);
} }
} // namespace } // NOLINT // namespace
using inference::Singleton; using inference::Singleton;
using inference::tensorrt::TRT_EngineManager; using inference::tensorrt::TRT_EngineManager;
......
...@@ -16,6 +16,18 @@ limitations under the License. */ ...@@ -16,6 +16,18 @@ limitations under the License. */
#include <stddef.h> #include <stddef.h>
#ifdef _WIN32
#if defined(__AVX2__)
#include <immintrin.h> //avx2
#elif defined(__AVX__)
#include <intrin.h> //avx
#endif // AVX
#else // WIN32
#ifdef __AVX__
#include <immintrin.h>
#endif
#endif // WIN32
namespace paddle { namespace paddle {
namespace platform { namespace platform {
......
...@@ -59,6 +59,7 @@ inline const char* cudnnGetErrorString(cudnnStatus_t status) { ...@@ -59,6 +59,7 @@ inline const char* cudnnGetErrorString(cudnnStatus_t status) {
#define CUDNN_VERSION_MIN(major, minor, patch) \ #define CUDNN_VERSION_MIN(major, minor, patch) \
(CUDNN_VERSION >= ((major)*1000 + (minor)*100 + (patch))) (CUDNN_VERSION >= ((major)*1000 + (minor)*100 + (patch)))
#if !defined(_WIN32)
#define CUDNN_ENFORCE(condition) \ #define CUDNN_ENFORCE(condition) \
do { \ do { \
cudnnStatus_t status = condition; \ cudnnStatus_t status = condition; \
...@@ -66,6 +67,16 @@ inline const char* cudnnGetErrorString(cudnnStatus_t status) { ...@@ -66,6 +67,16 @@ inline const char* cudnnGetErrorString(cudnnStatus_t status) {
PADDLE_THROW(::paddle::platform::cudnnGetErrorString(status)); \ PADDLE_THROW(::paddle::platform::cudnnGetErrorString(status)); \
} \ } \
} while (false) } while (false)
#else
// windows
#define CUDNN_ENFORCE(condition) \
do { \
cudnnStatus_t status = condition; \
if (status != CUDNN_STATUS_SUCCESS) { \
std::cerr << ::paddle::platform::cudnnGetErrorString(status); \
} \
} while (false)
#endif
enum class DataLayout { // Not use enum class DataLayout { // Not use
kNHWC, kNHWC,
......
...@@ -55,7 +55,6 @@ DeviceContextPool::DeviceContextPool( ...@@ -55,7 +55,6 @@ DeviceContextPool::DeviceContextPool(
for (auto& p : places) { for (auto& p : places) {
set.insert(p); set.insert(p);
} }
for (auto& p : set) { for (auto& p : set) {
if (platform::is_cpu_place(p)) { if (platform::is_cpu_place(p)) {
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
...@@ -205,7 +204,9 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place) ...@@ -205,7 +204,9 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
<< ", Runtime Version: " << runtime_version_ / 1000 << ", Runtime Version: " << runtime_version_ / 1000
<< "." << (runtime_version_ % 100) / 10; << "." << (runtime_version_ % 100) / 10;
#ifndef _WIN32
callback_manager_.reset(new StreamCallbackManager(stream_)); callback_manager_.reset(new StreamCallbackManager(stream_));
#endif // NOT WIN32
} }
CUDADeviceContext::~CUDADeviceContext() { CUDADeviceContext::~CUDADeviceContext() {
......
...@@ -32,7 +32,7 @@ limitations under the License. */ ...@@ -32,7 +32,7 @@ limitations under the License. */
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#ifdef PADDLE_WITH_CUDA && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include "paddle/fluid/platform/stream_callback_manager.h" #include "paddle/fluid/platform/stream_callback_manager.h"
#endif #endif
#include "unsupported/Eigen/CXX11/Tensor" #include "unsupported/Eigen/CXX11/Tensor"
...@@ -173,6 +173,7 @@ class CUDADeviceContext : public DeviceContext { ...@@ -173,6 +173,7 @@ class CUDADeviceContext : public DeviceContext {
PADDLE_ENFORCE(cudaEventRecord(ev, stream_)); PADDLE_ENFORCE(cudaEventRecord(ev, stream_));
} }
#ifndef _WIN32
template <typename Callback> template <typename Callback>
void AddStreamCallback(Callback&& callback) const { void AddStreamCallback(Callback&& callback) const {
std::lock_guard<std::mutex> guard(callback_mtx_); std::lock_guard<std::mutex> guard(callback_mtx_);
...@@ -183,6 +184,16 @@ class CUDADeviceContext : public DeviceContext { ...@@ -183,6 +184,16 @@ class CUDADeviceContext : public DeviceContext {
std::lock_guard<std::mutex> guard(callback_mtx_); std::lock_guard<std::mutex> guard(callback_mtx_);
callback_manager_->Wait(); callback_manager_->Wait();
} }
#else
template <typename Callback>
void AddStreamCallback(Callback&& callback) const {
// ugly empty functor.
}
void WaitStreamCallback() const {
// ugly empty functor.
}
#endif
private: private:
CUDAPlace place_; CUDAPlace place_;
...@@ -201,10 +212,12 @@ class CUDADeviceContext : public DeviceContext { ...@@ -201,10 +212,12 @@ class CUDADeviceContext : public DeviceContext {
mutable std::mutex mtx_; mutable std::mutex mtx_;
#ifndef _WIN32
// This lock is only used by callback // This lock is only used by callback
// If we use mtx_ for StreamCallbackManager, deadlock may occur sometimes // If we use mtx_ for StreamCallbackManager, deadlock may occur sometimes
mutable std::mutex callback_mtx_; mutable std::mutex callback_mtx_;
std::unique_ptr<StreamCallbackManager> callback_manager_; std::unique_ptr<StreamCallbackManager> callback_manager_;
#endif
}; };
template <> template <>
......
...@@ -127,7 +127,7 @@ struct EOFException : public std::exception { ...@@ -127,7 +127,7 @@ struct EOFException : public std::exception {
#define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0) #define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0)
#else #else
// there is no equivalent intrinsics in msvc. // there is no equivalent intrinsics in msvc.
#define UNLIKELY(condition) (condition == 0) #define UNLIKELY(condition) ((condition) == 0)
#endif #endif
#if !defined(_WIN32) #if !defined(_WIN32)
......
...@@ -175,7 +175,7 @@ void InitGLOG(const std::string &prog_name) { ...@@ -175,7 +175,7 @@ void InitGLOG(const std::string &prog_name) {
// glog will not hold the ARGV[0] inside. // glog will not hold the ARGV[0] inside.
// Use strdup to alloc a new string. // Use strdup to alloc a new string.
google::InitGoogleLogging(strdup(prog_name.c_str())); google::InitGoogleLogging(strdup(prog_name.c_str()));
#ifndef _WIN32 #if !defined(_WIN32)
google::InstallFailureSignalHandler(); google::InstallFailureSignalHandler();
#endif #endif
} }
......
...@@ -28,3 +28,16 @@ limitations under the License. */ ...@@ -28,3 +28,16 @@ limitations under the License. */
#if defined(__FLT_MAX__) #if defined(__FLT_MAX__)
#define FLT_MAX __FLT_MAX__ #define FLT_MAX __FLT_MAX__
#endif // __FLT_MAX__ #endif // __FLT_MAX__
#ifdef _WIN32
#if defined(PADDLE_COMPILE)
// by default, msvc has predefined macro _LIB for static library
// only shared library need to export and import symbols
// static library export all symbols by default.
#define PADDLE_DLL __declspec(dllexport)
#else
#define PADDLE_DLL __declspec(dllimport)
#endif
#else
#define PADDLE_DLL
#endif
...@@ -15,12 +15,13 @@ ...@@ -15,12 +15,13 @@
#pragma once #pragma once
#include <cstdio> #include <cstdio>
#include <stdexcept>
#include <memory> #include <memory>
#include <memory> // NOLINT
#include <stdexcept>
#include <string> #include <string>
#define GLOG_NO_ABBREVIATED_SEVERITIES // msvc conflict logging with windows.h #define GLOG_NO_ABBREVIATED_SEVERITIES // msvc conflict logging with windows.h
#define GOOGLE_GLOG_DLL_DECL
#include "glog/logging.h" #include "glog/logging.h"
#if !defined(_WIN32) #if !defined(_WIN32)
...@@ -61,7 +62,6 @@ static void *dlopen(const char *filename, int flag) { ...@@ -61,7 +62,6 @@ static void *dlopen(const char *filename, int flag) {
} }
return reinterpret_cast<void *>(hModule); return reinterpret_cast<void *>(hModule);
} }
#endif // !_WIN32 #endif // !_WIN32
static void ExecShellCommand(const std::string &cmd, std::string *message) { static void ExecShellCommand(const std::string &cmd, std::string *message) {
......
...@@ -152,6 +152,7 @@ __all__ = [ ...@@ -152,6 +152,7 @@ __all__ = [
'mul', 'mul',
'sigmoid_cross_entropy_with_logits', 'sigmoid_cross_entropy_with_logits',
'maxout', 'maxout',
'space_to_depth',
'affine_grid', 'affine_grid',
'sequence_reverse', 'sequence_reverse',
'affine_channel', 'affine_channel',
...@@ -3064,7 +3065,7 @@ def sequence_pad(x, pad_value, maxlen=None, name=None): ...@@ -3064,7 +3065,7 @@ def sequence_pad(x, pad_value, maxlen=None, name=None):
x = fluid.layers.data(name='y', shape=[10, 5], x = fluid.layers.data(name='y', shape=[10, 5],
dtype='float32', lod_level=1) dtype='float32', lod_level=1)
pad_value = fluid.layers.assign( pad_value = fluid.layers.assign(
input=numpy.array([0], dtype=numpy.float32)) input=numpy.array([0.0], dtype=numpy.float32))
out = fluid.layers.sequence_pad(x=x, pad_value=pad_value) out = fluid.layers.sequence_pad(x=x, pad_value=pad_value)
""" """
...@@ -7679,6 +7680,66 @@ def maxout(x, groups, name=None): ...@@ -7679,6 +7680,66 @@ def maxout(x, groups, name=None):
return out return out
def space_to_depth(x, blocksize, name=None):
"""
Gives a blocksize to space_to_depth the input LoDtensor with Layout: [batch, channel, height, width]
This op rearranges blocks of spatial data, into depth. More specifically, this op outputs a copy of the
input LoDtensor where values from the height and width dimensions are moved to the channel dimension.
The attr blocksize indicates the input block size.
space_to_depth will reorgnize the elements of input with shape[batch, channel, height, width] according
to blocksize to construct output with shape [batch, channel * blocksize * blocksize, height/blocksize, width/blocksize]:
space_to_depth is used to This operation is useful for resizing the activations between convolutions
(but keeping all data)
- Non-overlapping blocks of size block_size x block size are rearranged into depth at each location.
- The depth of the output tensor is block_size * block_size * input channel
- The Y, X coordinates within each block of the input become the high order component of the output channel index
- channel should be divisible by square of blocksize
- height, width should be divsible by blocksize
Args:
x(variable): The input LoDtensor.
blocksize(variable): The blocksize to select the element on each feature map should be > 2
Returns:
Variable: The output LoDtensor.
Raises:
TypeError: blocksize type must be a long.
Examples:
.. code-block:: python
data = fluid.layers.data(
name='data', shape=[1, 4, 2, 2], dtype='float32')
space_to_depthed = fluid.layers.space_to_depth(
x=data, blocksize=2)
"""
helper = LayerHelper("space_to_depth", **locals())
if not (isinstance(blocksize, int)):
raise ValueError("blocksize must be a python Int")
if name is None:
out = helper.create_variable_for_type_inference(
dtype=x.dtype) #fix create
else:
out = helper.create_variable(
name=name, dtype=x.dtype, persistable=False)
helper.append_op(
type="space_to_depth",
inputs={"X": x},
attrs={"blocksize": blocksize},
outputs={"Out": out})
return out
@templatedoc() @templatedoc()
def sequence_reverse(x, name=None): def sequence_reverse(x, name=None):
""" """
......
...@@ -108,6 +108,8 @@ class OpDescCreationMethod(object): ...@@ -108,6 +108,8 @@ class OpDescCreationMethod(object):
new_attr.i = user_defined_attr new_attr.i = user_defined_attr
elif attr.type == framework_pb2.FLOAT: elif attr.type == framework_pb2.FLOAT:
new_attr.f = user_defined_attr new_attr.f = user_defined_attr
elif attr.type == framework_pb2.LONG:
new_attr.l = user_defined_attr
elif attr.type == framework_pb2.STRING: elif attr.type == framework_pb2.STRING:
new_attr.s = user_defined_attr new_attr.s = user_defined_attr
elif attr.type == framework_pb2.BOOLEAN: elif attr.type == framework_pb2.BOOLEAN:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册