未验证 提交 2b2b4ca2 编写于 作者: Z zhoukunsheng 提交者: GitHub

Merge branch 'develop' into rsqrt

...@@ -71,7 +71,8 @@ option(ANAKIN_BUILD_CROSS_PLANTFORM "Build anakin lib for any nvidia device plan ...@@ -71,7 +71,8 @@ option(ANAKIN_BUILD_CROSS_PLANTFORM "Build anakin lib for any nvidia device plan
option(WITH_GRPC "Use grpc as the default rpc framework" ${WITH_DISTRIBUTE}) option(WITH_GRPC "Use grpc as the default rpc framework" ${WITH_DISTRIBUTE})
option(WITH_BRPC_RDMA "Use brpc rdma as the rpc protocal" OFF) option(WITH_BRPC_RDMA "Use brpc rdma as the rpc protocal" OFF)
option(ON_INFER "Turn on inference optimization." OFF) option(ON_INFER "Turn on inference optimization." OFF)
option(WITH_INFERENCE_API_TEST "Test fluid inference high-level api interface" OFF) option(WITH_INFERENCE_API_TEST "Test fluid inference C++ high-level api interface" OFF)
option(WITH_HIGH_LEVEL_API_TEST "Test fluid python high-level api interface" OFF)
option(WITH_SYSTEM_BLAS "Use system blas library" OFF) option(WITH_SYSTEM_BLAS "Use system blas library" OFF)
option(PY_VERSION "Compile PaddlePaddle with python3 support" ${PY_VERSION}) option(PY_VERSION "Compile PaddlePaddle with python3 support" ${PY_VERSION})
option(WITH_FAST_MATH "Make use of fast math library, might affect the precision to some extent" ON) option(WITH_FAST_MATH "Make use of fast math library, might affect the precision to some extent" ON)
......
...@@ -34,7 +34,7 @@ ExternalProject_Add( ...@@ -34,7 +34,7 @@ ExternalProject_Add(
BUILD_IN_SOURCE 1 BUILD_IN_SOURCE 1
) )
ADD_LIBRARY(dgc SHARED IMPORTED GLOBAL) ADD_LIBRARY(dgc STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET dgc PROPERTY IMPORTED_LOCATION ${DGC_LIBRARIES}) SET_PROPERTY(TARGET dgc PROPERTY IMPORTED_LOCATION ${DGC_LIBRARIES})
ADD_DEPENDENCIES(dgc extern_dgc) ADD_DEPENDENCIES(dgc extern_dgc)
......
...@@ -201,7 +201,7 @@ FUNCTION(build_protobuf TARGET_NAME BUILD_FOR_HOST) ...@@ -201,7 +201,7 @@ FUNCTION(build_protobuf TARGET_NAME BUILD_FOR_HOST)
SET(OPTIONAL_ARGS ${OPTIONAL_ARGS} "-DCMAKE_GENERATOR_PLATFORM=x64") SET(OPTIONAL_ARGS ${OPTIONAL_ARGS} "-DCMAKE_GENERATOR_PLATFORM=x64")
ENDIF() ENDIF()
SET(PROTOBUF_REPO "https://github.com/google/protobuf.git") SET(PROTOBUF_REPO "https://github.com/protocolbuffers/protobuf.git")
SET(PROTOBUF_TAG "9f75c5aa851cd877fb0d93ccc31b8567a6706546") SET(PROTOBUF_TAG "9f75c5aa851cd877fb0d93ccc31b8567a6706546")
ExternalProject_Add( ExternalProject_Add(
...@@ -221,6 +221,7 @@ FUNCTION(build_protobuf TARGET_NAME BUILD_FOR_HOST) ...@@ -221,6 +221,7 @@ FUNCTION(build_protobuf TARGET_NAME BUILD_FOR_HOST)
-DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE} -DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
-DCMAKE_INSTALL_PREFIX=${PROTOBUF_INSTALL_DIR} -DCMAKE_INSTALL_PREFIX=${PROTOBUF_INSTALL_DIR}
-DCMAKE_INSTALL_LIBDIR=lib -DCMAKE_INSTALL_LIBDIR=lib
-DBUILD_SHARED_LIBS=OFF
CMAKE_CACHE_ARGS CMAKE_CACHE_ARGS
-DCMAKE_INSTALL_PREFIX:PATH=${PROTOBUF_INSTALL_DIR} -DCMAKE_INSTALL_PREFIX:PATH=${PROTOBUF_INSTALL_DIR}
-DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE} -DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
......
此差异已折叠。
#windows treat symbolic file as a real file, which is different with unix #windows treat symbolic file as a real file, which is different with unix
#We create a hidden file and compile it instead of origin source file. #We create a hidden file and compile it instead of origin source file.
function(windows_symbolic TARGET) function(windows_symbolic TARGET)
...@@ -22,9 +23,13 @@ endfunction() ...@@ -22,9 +23,13 @@ endfunction()
add_subdirectory(ir) add_subdirectory(ir)
add_subdirectory(details) add_subdirectory(details)
add_subdirectory(fleet)
add_subdirectory(io)
#ddim lib #ddim lib
proto_library(framework_proto SRCS framework.proto) proto_library(framework_proto SRCS framework.proto)
proto_library(data_feed_proto SRCS data_feed.proto)
proto_library(async_executor_proto SRCS data_feed.proto) proto_library(async_executor_proto SRCS data_feed.proto)
proto_library(trainer_desc_proto SRCS trainer_desc.proto data_feed.proto)
cc_library(ddim SRCS ddim.cc DEPS eigen3 boost enforce) cc_library(ddim SRCS ddim.cc DEPS eigen3 boost enforce)
cc_test(ddim_test SRCS ddim_test.cc DEPS ddim) cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
...@@ -129,9 +134,11 @@ cc_test(version_test SRCS version_test.cc DEPS version) ...@@ -129,9 +134,11 @@ cc_test(version_test SRCS version_test.cc DEPS version)
cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog version) cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog version)
cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc memory_optimize_helper) cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc memory_optimize_helper)
nv_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry) nv_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
py_proto_compile(framework_py_proto SRCS framework.proto data_feed.proto) py_proto_compile(framework_py_proto SRCS framework.proto data_feed.proto)
py_proto_compile(trainer_py_proto SRCS trainer_desc.proto data_feed.proto)
#Generate an empty \ #Generate an empty \
#__init__.py to make framework_py_proto as a valid python module. #__init__.py to make framework_py_proto as a valid python module.
add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py) add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py)
...@@ -165,29 +172,43 @@ else() ...@@ -165,29 +172,43 @@ else()
endif() endif()
cc_library(executor_gc_helper SRCS executor_gc_helper.cc DEPS scope proto_desc operator garbage_collector) cc_library(executor_gc_helper SRCS executor_gc_helper.cc DEPS scope proto_desc operator garbage_collector)
if(WITH_DISTRIBUTE) if(WITH_DISTRIBUTE)
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog cc_library(executor SRCS executor.cc multi_trainer.cc dataset_factory.cc
lod_rank_table feed_fetch_method sendrecvop_rpc ${GLOB_DISTRIBUTE_DEPS} graph_to_program_pass variable_helper ${NGRAPH_EXE_DEPS}) dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") data_feed.cc device_worker.cc hogwild_worker.cc downpour_worker.cc
set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) pull_dense_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
device_context scope framework_proto trainer_desc_proto glog fs shell fleet_wrapper lodtensor_printer
lod_rank_table feed_fetch_method sendrecvop_rpc ${GLOB_DISTRIBUTE_DEPS}
graph_to_program_pass variable_helper data_feed_proto ${NGRAPH_EXE_DEPS} timer)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
else() else()
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass variable_helper ${NGRAPH_EXE_DEPS}) cc_library(executor SRCS executor.cc multi_trainer.cc dataset_factory.cc
dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc
data_feed.cc device_worker.cc hogwild_worker.cc downpour_worker.cc
pull_dense_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
device_context scope framework_proto data_feed_proto trainer_desc_proto glog
lod_rank_table fs shell fleet_wrapper lodtensor_printer feed_fetch_method
graph_to_program_pass variable_helper ${NGRAPH_EXE_DEPS} timer data_feed_proto)
cc_test(test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op) cc_test(test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op)
endif() endif()
target_link_libraries(executor while_op_helper executor_gc_helper) target_link_libraries(executor while_op_helper executor_gc_helper)
cc_library(parallel_executor SRCS parallel_executor.cc DEPS cc_library(parallel_executor SRCS parallel_executor.cc DEPS
threaded_ssa_graph_executor scope_buffered_ssa_graph_executor parallel_ssa_graph_executor threaded_ssa_graph_executor scope_buffered_ssa_graph_executor parallel_ssa_graph_executor async_ssa_graph_executor
graph build_strategy graph build_strategy
fast_threaded_ssa_graph_executor variable_helper) fast_threaded_ssa_graph_executor variable_helper)
if(WITH_PSLIB) cc_library(async_executor SRCS async_executor.cc data_feed.cc data_feed_factory.cc
cc_library(async_executor SRCS async_executor.cc data_feed.cc data_feed_factory.cc executor_thread_worker.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass async_executor_proto variable_helper pslib_brpc pslib timer) executor_thread_worker.cc multi_trainer.cc dist_multi_trainer.cc
else() trainer_factory.cc trainer.cc device_worker.cc hogwild_worker.cc
cc_library(async_executor SRCS async_executor.cc data_feed.cc data_feed_factory.cc executor_thread_worker.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass async_executor_proto variable_helper timer) downpour_worker.cc pull_dense_worker.cc device_worker_factory.cc
endif(WITH_PSLIB) data_set.cc dataset_factory.cc
DEPS op_registry device_context scope framework_proto
trainer_desc_proto glog lod_rank_table fleet_wrapper lodtensor_printer
feed_fetch_method graph_to_program_pass data_feed_proto
variable_helper timer fs shell)
cc_test(data_feed_test SRCS data_feed_test.cc DEPS async_executor) cc_test(data_feed_test SRCS data_feed_test.cc DEPS async_executor)
...@@ -214,18 +235,18 @@ cc_test(dlpack_tensor_test SRCS dlpack_tensor_test.cc DEPS dlpack_tensor glog) ...@@ -214,18 +235,18 @@ cc_test(dlpack_tensor_test SRCS dlpack_tensor_test.cc DEPS dlpack_tensor glog)
# Get the current working branch # Get the current working branch
execute_process( execute_process(
COMMAND git rev-parse --abbrev-ref HEAD COMMAND git rev-parse --abbrev-ref HEAD
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR} WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}
OUTPUT_VARIABLE PADDLE_BRANCH OUTPUT_VARIABLE PADDLE_BRANCH
OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_STRIP_TRAILING_WHITESPACE
) )
# Get the latest abbreviated commit hash of the working branch # Get the latest abbreviated commit hash of the working branch
execute_process( execute_process(
COMMAND git log -1 --format=%h COMMAND git log -1 --format=%h
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR} WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}
OUTPUT_VARIABLE PADDLE_COMMIT OUTPUT_VARIABLE PADDLE_COMMIT
OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_STRIP_TRAILING_WHITESPACE
) )
message(STATUS "commit: ${PADDLE_COMMIT}") message(STATUS "commit: ${PADDLE_COMMIT}")
message(STATUS "branch: ${PADDLE_BRANCH}") message(STATUS "branch: ${PADDLE_BRANCH}")
......
...@@ -26,212 +26,44 @@ limitations under the License. */ ...@@ -26,212 +26,44 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor_array.h" #include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h" #include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/trainer.h"
#include "paddle/fluid/framework/trainer_desc.pb.h"
#include "paddle/fluid/framework/trainer_factory.h"
#include "paddle/fluid/inference/io.h" #include "paddle/fluid/inference/io.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "paddle/fluid/pybind/pybind.h" #include "paddle/fluid/pybind/pybind.h"
#ifdef PADDLE_WITH_PSLIB
#include <pslib.h>
#endif
namespace paddle { namespace paddle {
namespace framework { namespace framework {
AsyncExecutor::AsyncExecutor(Scope* scope, const platform::Place& place) AsyncExecutor::AsyncExecutor(Scope* scope, const platform::Place& place)
: root_scope_(scope), place_(place) {} : root_scope_(scope), place_(place) {}
void AsyncExecutor::CreateThreads(
ExecutorThreadWorker* worker, const ProgramDesc& main_program,
const std::shared_ptr<DataFeed>& reader,
const std::vector<std::string>& fetch_var_names, Scope* root_scope,
const int thread_index, const bool debug) {
worker->SetThreadId(thread_index);
worker->SetDebug(debug);
worker->SetRootScope(root_scope);
worker->CreateThreadResource(main_program, place_);
worker->SetDataFeed(reader);
worker->SetFetchVarNames(fetch_var_names);
worker->BindingDataFeedMemory();
#ifdef PADDLE_WITH_PSLIB
worker->SetPSlibPtr(_pslib_ptr);
worker->SetPullDenseThread(_pull_dense_thread);
worker->SetParamConfig(&_param_config);
#endif
}
void PrepareReaders(std::vector<std::shared_ptr<DataFeed>>& readers, // NOLINT
const int thread_num, const DataFeedDesc& data_feed_desc,
const std::vector<std::string>& filelist) {
readers.resize(thread_num);
for (size_t i = 0; i < readers.size(); ++i) {
readers[i] = DataFeedFactory::CreateDataFeed(data_feed_desc.name());
readers[i]->Init(data_feed_desc); // set batch_size and queue_size here
}
readers[0]->SetFileList(filelist);
}
#ifdef PADDLE_WITH_PSLIB
void AsyncExecutor::InitServer(const std::string& dist_desc, int index) { void AsyncExecutor::InitServer(const std::string& dist_desc, int index) {
_pslib_ptr = std::shared_ptr<paddle::distributed::PSlib>( fleet_ptr_ = FleetWrapper::GetInstance();
new paddle::distributed::PSlib()); fleet_ptr_->InitServer(dist_desc, index);
_pslib_ptr->init_server(dist_desc, index);
InitParamConfig();
} }
void AsyncExecutor::InitWorker(const std::string& dist_desc, void AsyncExecutor::InitWorker(const std::string& dist_desc,
const std::vector<uint64_t>& host_sign_list, const std::vector<uint64_t>& host_sign_list,
int node_num, int index) { int node_num, int index) {
_pslib_ptr = std::shared_ptr<paddle::distributed::PSlib>( fleet_ptr_ = FleetWrapper::GetInstance();
new paddle::distributed::PSlib()); fleet_ptr_->InitWorker(dist_desc, host_sign_list, node_num, index);
_pslib_ptr->init_worker(
dist_desc, const_cast<uint64_t*>(host_sign_list.data()), node_num, index);
InitParamConfig();
} }
uint64_t AsyncExecutor::StartServer() { return _pslib_ptr->run_server(); } uint64_t AsyncExecutor::StartServer() { return fleet_ptr_->RunServer(); }
void AsyncExecutor::StopServer() { _pslib_ptr->stop_server(); } void AsyncExecutor::StopServer() { fleet_ptr_->StopServer(); }
void AsyncExecutor::GatherServers(const std::vector<uint64_t>& host_sign_list, void AsyncExecutor::GatherServers(const std::vector<uint64_t>& host_sign_list,
int node_num) { int node_num) {
_pslib_ptr->gather_servers(const_cast<uint64_t*>(host_sign_list.data()), fleet_ptr_->GatherServers(host_sign_list, node_num);
node_num);
}
void AsyncExecutor::InitParamConfig() {
for (int i = 0; i < _pslib_ptr->get_param()
->server_param()
.downpour_server_param()
.downpour_table_param_size();
++i) {
if (_pslib_ptr->get_param()
->server_param()
.downpour_server_param()
.downpour_table_param(i)
.table_class()
.find("SparseTable") != -1) {
_param_config.fea_dim = _pslib_ptr->get_param()
->server_param()
.downpour_server_param()
.downpour_table_param(i)
.accessor()
.fea_dim();
break;
}
}
_param_config.slot_dim = _param_config.fea_dim - 2;
_param_config.tmp_push_dense_wait_times = static_cast<int32_t>(
_pslib_ptr->get_param()->trainer_param().push_dense_per_batch());
_param_config.tmp_push_sparse_wait_times = static_cast<int32_t>(
_pslib_ptr->get_param()->trainer_param().push_sparse_per_batch());
for (auto t = 0u; t < _pslib_ptr->get_param()->trainer_param().skip_op_size();
++t) {
_param_config.skip_op.push_back(
_pslib_ptr->get_param()->trainer_param().skip_op(t));
}
for (auto t = 0u;
t < _pslib_ptr->get_param()->trainer_param().sparse_table_size(); ++t) {
auto& table = _pslib_ptr->get_param()->trainer_param().sparse_table(t);
std::vector<std::string> tmp_sparse_variable_name;
for (int i = 0u; i < table.slot_value_size(); ++i) {
tmp_sparse_variable_name.push_back(table.slot_value(i));
_param_config.slot_alias_to_table[table.slot_key(i)] = table.table_id();
}
std::vector<std::string> tmp_sparse_gradient_variable_name;
for (auto i = 0u; i < table.slot_gradient_size(); ++i) {
tmp_sparse_gradient_variable_name.push_back(table.slot_gradient(i));
}
_param_config.slot_input_vec[table.table_id()] =
std::move(tmp_sparse_variable_name);
_param_config.gradient_var[table.table_id()] =
std::move(tmp_sparse_gradient_variable_name);
_param_config.sparse_table_id.push_back(table.table_id());
}
for (auto t = 0u;
t < _pslib_ptr->get_param()->trainer_param().dense_table_size(); ++t) {
auto& table = _pslib_ptr->get_param()->trainer_param().dense_table(t);
std::vector<std::string> tmp_dense_variable_name;
for (int i = 0u; i < table.dense_variable_name_size(); ++i) {
tmp_dense_variable_name.push_back(table.dense_variable_name(i));
}
std::vector<std::string> tmp_dense_gradient_variable_name;
for (auto i = 0u; i < table.dense_gradient_variable_name_size(); ++i) {
tmp_dense_gradient_variable_name.push_back(
table.dense_gradient_variable_name(i));
}
_param_config.dense_variable_name[table.table_id()] =
std::move(tmp_dense_variable_name);
_param_config.dense_gradient_variable_name[table.table_id()] =
std::move(tmp_dense_gradient_variable_name);
_param_config.dense_table_id.push_back(table.table_id());
_param_config.dense_table_size.push_back(table.fea_dim());
}
} }
void AsyncExecutor::InitModel() { // todo InitModel
for (auto table_id : _param_config.dense_table_id) { void AsyncExecutor::InitModel() {}
std::vector<paddle::ps::Region> regions;
for (auto& t : _param_config.dense_variable_name[table_id]) {
Variable* var = root_scope_->FindVar(t);
CHECK(var != nullptr) << "var[" << t << "] not found";
LoDTensor* tensor = var->GetMutable<LoDTensor>();
float* g = tensor->data<float>();
CHECK(g != nullptr) << "var[" << t << "] value not initialized";
float init_range = 0.2;
int rown = tensor->dims()[0];
init_range /= sqrt(rown);
std::normal_distribution<float> ndistr(0.0, 1.0);
for (auto i = 0u; i < tensor->numel(); ++i) {
g[i] = ndistr(local_random_engine()) * init_range;
}
paddle::ps::Region reg(g, tensor->numel());
regions.emplace_back(std::move(reg));
}
auto push_status = _pslib_ptr->_worker_ptr->push_dense_param( // todo SaveModel
regions.data(), regions.size(), table_id); void AsyncExecutor::SaveModel(const std::string& path) {}
push_status.wait();
auto status = push_status.get();
if (status != 0) {
LOG(FATAL) << "push dense param failed, status[" << status << "]";
exit(-1);
}
}
}
void AsyncExecutor::SaveModel(const std::string& path) {
auto ret = _pslib_ptr->_worker_ptr->flush();
ret.wait();
ret = _pslib_ptr->_worker_ptr->save(path, 0);
ret.wait();
int32_t feasign_cnt = ret.get();
if (feasign_cnt == -1) { // (colourful-tree) TODO should be feasign_cnt < 0
LOG(FATAL) << "save model failed";
exit(-1);
}
}
void AsyncExecutor::PrepareDenseThread(const std::string& mode) {
if (mode == "mpi") {
DensePullThreadParam param;
param.ps_client = _pslib_ptr->_worker_ptr;
param.threshold = 1;
param.training_thread_num = actual_thread_num;
param.root_scope = root_scope_;
param.dense_params = &_param_config.dense_variable_name;
_pull_dense_thread =
std::shared_ptr<DensePullThread>(new DensePullThread(param));
_pull_dense_thread->start();
}
}
#endif
void AsyncExecutor::RunFromFile(const ProgramDesc& main_program, void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
const std::string& data_feed_desc_str, const std::string& data_feed_desc_str,
...@@ -256,14 +88,14 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program, ...@@ -256,14 +88,14 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
google::protobuf::TextFormat::ParseFromString(data_feed_desc_str, google::protobuf::TextFormat::ParseFromString(data_feed_desc_str,
&data_feed_desc); &data_feed_desc);
actual_thread_num = thread_num; actual_thread_num_ = thread_num;
int file_cnt = filelist.size(); int file_cnt = filelist.size();
PADDLE_ENFORCE(file_cnt > 0, "File list cannot be empty"); PADDLE_ENFORCE(file_cnt > 0, "File list cannot be empty");
if (actual_thread_num > file_cnt) { if (actual_thread_num_ > file_cnt) {
VLOG(1) << "Thread num = " << thread_num << ", file num = " << file_cnt VLOG(1) << "Thread num = " << thread_num << ", file num = " << file_cnt
<< ". Changing thread_num = " << file_cnt; << ". Changing thread_num = " << file_cnt;
actual_thread_num = file_cnt; actual_thread_num_ = file_cnt;
} }
/* /*
...@@ -279,12 +111,14 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program, ...@@ -279,12 +111,14 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
*/ */
// todo: should be factory method for creating datafeed // todo: should be factory method for creating datafeed
std::vector<std::shared_ptr<DataFeed>> readers; std::vector<std::shared_ptr<DataFeed>> readers;
PrepareReaders(readers, actual_thread_num, data_feed_desc, filelist); /*
PrepareReaders(readers, actual_thread_num_, data_feed_desc, filelist);
#ifdef PADDLE_WITH_PSLIB #ifdef PADDLE_WITH_PSLIB
PrepareDenseThread(mode); PrepareDenseThread(mode);
#endif #endif
*/
std::vector<std::shared_ptr<ExecutorThreadWorker>> workers; std::vector<std::shared_ptr<ExecutorThreadWorker>> workers;
workers.resize(actual_thread_num); workers.resize(actual_thread_num_);
for (auto& worker : workers) { for (auto& worker : workers) {
#ifdef PADDLE_WITH_PSLIB #ifdef PADDLE_WITH_PSLIB
if (mode == "mpi") { if (mode == "mpi") {
...@@ -298,13 +132,15 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program, ...@@ -298,13 +132,15 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
} }
// prepare thread resource here // prepare thread resource here
for (int thidx = 0; thidx < actual_thread_num; ++thidx) { /*
for (int thidx = 0; thidx < actual_thread_num_; ++thidx) {
CreateThreads(workers[thidx].get(), main_program, readers[thidx], CreateThreads(workers[thidx].get(), main_program, readers[thidx],
fetch_var_names, root_scope_, thidx, debug); fetch_var_names, root_scope_, thidx, debug);
} }
*/
// start executing ops in multiple threads // start executing ops in multiple threads
for (int thidx = 0; thidx < actual_thread_num; ++thidx) { for (int thidx = 0; thidx < actual_thread_num_; ++thidx) {
if (debug) { if (debug) {
threads.push_back(std::thread(&ExecutorThreadWorker::TrainFilesWithTimer, threads.push_back(std::thread(&ExecutorThreadWorker::TrainFilesWithTimer,
workers[thidx].get())); workers[thidx].get()));
...@@ -317,15 +153,19 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program, ...@@ -317,15 +153,19 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
for (auto& th : threads) { for (auto& th : threads) {
th.join(); th.join();
} }
// TODO(guru4elephant): we don't need this
/*
#ifdef PADDLE_WITH_PSLIB #ifdef PADDLE_WITH_PSLIB
if (mode == "mpi") { if (mode == "mpi") {
_pull_dense_thread->stop(); _pull_dense_thread->stop();
} }
#endif #endif
*/
VLOG(3) << "start to run from files in async_executor";
VLOG(3) << "Drop current scope kids";
root_scope_->DropKids(); root_scope_->DropKids();
return; return;
} }
} // einit_modelnd namespace framework } // end namespace framework
} // end namespace paddle } // end namespace paddle
...@@ -25,8 +25,10 @@ limitations under the License. */ ...@@ -25,8 +25,10 @@ limitations under the License. */
#include <typeinfo> #include <typeinfo>
#include <vector> #include <vector>
#include "paddle/fluid/framework/data_feed.pb.h" #include "paddle/fluid/framework/data_feed.pb.h"
#include "paddle/fluid/framework/data_set.h"
#include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/executor_thread_worker.h" #include "paddle/fluid/framework/executor_thread_worker.h"
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
...@@ -65,9 +67,10 @@ class AsyncExecutor { ...@@ -65,9 +67,10 @@ class AsyncExecutor {
const std::string& data_feed_desc_str, const std::string& data_feed_desc_str,
const std::vector<std::string>& filelist, const std::vector<std::string>& filelist,
const int thread_num, const int thread_num,
const std::vector<std::string>& fetch_names, const std::vector<std::string>& fetch_var_names,
const std::string& mode, const bool debug = false); const std::string& mode, const bool debug);
#ifdef PADDLE_WITH_PSLIB
// TODO(guru4elephant): make init server decoupled from executor
void InitServer(const std::string& dist_desc, int index); void InitServer(const std::string& dist_desc, int index);
void InitWorker(const std::string& dist_desc, void InitWorker(const std::string& dist_desc,
const std::vector<uint64_t>& host_sign_list, int node_num, const std::vector<uint64_t>& host_sign_list, int node_num,
...@@ -77,31 +80,14 @@ class AsyncExecutor { ...@@ -77,31 +80,14 @@ class AsyncExecutor {
void GatherServers(const std::vector<uint64_t>& host_sign_list, int node_num); void GatherServers(const std::vector<uint64_t>& host_sign_list, int node_num);
void InitModel(); void InitModel();
void SaveModel(const std::string& path); void SaveModel(const std::string& path);
void InitParamConfig();
#endif
private:
void CreateThreads(ExecutorThreadWorker* worker,
const ProgramDesc& main_program,
const std::shared_ptr<DataFeed>& reader,
const std::vector<std::string>& fetch_var_names,
Scope* root_scope, const int thread_index,
const bool debug);
#ifdef PADDLE_WITH_PSLIB
void PrepareDenseThread(const std::string& mode);
#endif
public: public:
#ifdef PADDLE_WITH_PSLIB std::shared_ptr<paddle::framework::FleetWrapper> fleet_ptr_;
std::shared_ptr<paddle::distributed::PSlib> _pslib_ptr;
std::shared_ptr<DensePullThread> _pull_dense_thread;
AsyncWorkerParamConfig _param_config;
#endif
Scope* root_scope_; Scope* root_scope_;
platform::Place place_; platform::Place place_;
private: private:
int actual_thread_num; int actual_thread_num_;
}; };
} // namespace framework } // namespace framework
......
...@@ -33,6 +33,14 @@ class BlockingQueue { ...@@ -33,6 +33,14 @@ class BlockingQueue {
cv_.notify_one(); cv_.notify_one();
} }
void Push(T &&item) {
{
std::lock_guard<std::mutex> g(mutex_);
q_.emplace_back(std::move(item));
}
cv_.notify_one();
}
template <typename U> template <typename U>
void Extend(const U &items) { void Extend(const U &items) {
{ {
...@@ -44,6 +52,17 @@ class BlockingQueue { ...@@ -44,6 +52,17 @@ class BlockingQueue {
cv_.notify_all(); cv_.notify_all();
} }
template <typename U>
void Extend(U &&items) {
{
std::lock_guard<std::mutex> g(mutex_);
for (auto &item : items) {
q_.emplace_back(std::move(item));
}
}
cv_.notify_all();
}
std::deque<T> PopAll(size_t ms, bool *timeout) { std::deque<T> PopAll(size_t ms, bool *timeout) {
auto time = auto time =
std::chrono::system_clock::now() + std::chrono::milliseconds(ms); std::chrono::system_clock::now() + std::chrono::milliseconds(ms);
...@@ -64,6 +83,18 @@ class BlockingQueue { ...@@ -64,6 +83,18 @@ class BlockingQueue {
return rc; return rc;
} }
void Pop(T *t) {
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [=] { return !q_.empty(); });
*t = std::move(q_.front());
q_.pop_front();
}
size_t Size() {
std::lock_guard<std::mutex> lock(mutex_);
return q_.size();
}
private: private:
std::mutex mutex_; std::mutex mutex_;
std::condition_variable cv_; std::condition_variable cv_;
......
此差异已折叠。
...@@ -15,17 +15,23 @@ limitations under the License. */ ...@@ -15,17 +15,23 @@ limitations under the License. */
#pragma once #pragma once
#include <fstream> #include <fstream>
#include <future> // NOLINT
#include <memory> #include <memory>
#include <mutex> // NOLINT #include <mutex> // NOLINT
#include <sstream>
#include <string> #include <string>
#include <thread> // NOLINT #include <thread> // NOLINT
#include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/data_feed.pb.h" #include "paddle/fluid/framework/data_feed.pb.h"
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/reader.h" #include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/variable.h" #include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/operators/reader/blocking_queue.h" #include "paddle/fluid/operators/reader/blocking_queue.h"
#include "paddle/fluid/string/string_helper.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -48,7 +54,10 @@ namespace framework { ...@@ -48,7 +54,10 @@ namespace framework {
// } // }
class DataFeed { class DataFeed {
public: public:
DataFeed() {} DataFeed() {
mutex_for_pick_file_ = nullptr;
file_idx_ = nullptr;
}
virtual ~DataFeed() {} virtual ~DataFeed() {}
virtual void Init(const paddle::framework::DataFeedDesc& data_feed_desc) = 0; virtual void Init(const paddle::framework::DataFeedDesc& data_feed_desc) = 0;
virtual bool CheckFile(const char* filename) { virtual bool CheckFile(const char* filename) {
...@@ -59,6 +68,7 @@ class DataFeed { ...@@ -59,6 +68,7 @@ class DataFeed {
// Otherwise, Init() function will init finish_set_filelist_ flag. // Otherwise, Init() function will init finish_set_filelist_ flag.
virtual bool SetFileList(const std::vector<std::string>& files); virtual bool SetFileList(const std::vector<std::string>& files);
virtual bool Start() = 0; virtual bool Start() = 0;
// The trainer calls the Next() function, and the DataFeed will load a new // The trainer calls the Next() function, and the DataFeed will load a new
// batch to the feed_vec. The return value of this function is the batch // batch to the feed_vec. The return value of this function is the batch
// size of the current batch. // size of the current batch.
...@@ -74,6 +84,38 @@ class DataFeed { ...@@ -74,6 +84,38 @@ class DataFeed {
// This function is used for binding feed_vec memory // This function is used for binding feed_vec memory
virtual void AddFeedVar(Variable* var, const std::string& name); virtual void AddFeedVar(Variable* var, const std::string& name);
// This function will do nothing at default
virtual void SetMemoryData(void* memory_data) {}
// This function will do nothing at default
virtual void SetMemoryDataMutex(std::mutex* mutex) {}
// This function will do nothing at default
virtual void SetThreadId(int thread_id) {}
// This function will do nothing at default
virtual void SetThreadNum(int thread_num) {}
// This function will do nothing at default
virtual void SetTrainerNum(int trainer_num) {}
// This function will do nothing at default
virtual void SetFleetSendBatchSize(int64_t size) {}
virtual void SetFileListMutex(std::mutex* mutex) {
mutex_for_pick_file_ = mutex;
}
virtual void SetFileListIndex(size_t* file_index) { file_idx_ = file_index; }
virtual void LoadIntoMemory() {
PADDLE_THROW("This function(LoadIntoMemory) is not implemented.");
}
virtual void LocalShuffle() {
PADDLE_THROW("This function(LocalShuffle) is not implemented.");
}
virtual void GlobalShuffle() {
PADDLE_THROW("This function(GlobalShuffle) is not implemented.");
}
// This function will do nothing at default
virtual void FillMemoryDataToChannel() {}
// This function will do nothing at default
virtual void FillChannelToMemoryData() {}
// This function will do nothing at default
virtual void PutInsToChannel(const std::string& ins_str) {}
protected: protected:
// The following three functions are used to check if it is executed in this // The following three functions are used to check if it is executed in this
// order: // order:
...@@ -87,9 +129,9 @@ class DataFeed { ...@@ -87,9 +129,9 @@ class DataFeed {
// safe). // safe).
virtual bool PickOneFile(std::string* filename); virtual bool PickOneFile(std::string* filename);
static std::vector<std::string> filelist_; std::vector<std::string> filelist_;
static size_t file_idx_; size_t* file_idx_;
static std::mutex mutex_for_pick_file_; std::mutex* mutex_for_pick_file_;
// the alias of used slots, and its order is determined by // the alias of used slots, and its order is determined by
// data_feed_desc(proto object) // data_feed_desc(proto object)
...@@ -100,6 +142,7 @@ class DataFeed { ...@@ -100,6 +142,7 @@ class DataFeed {
// object) // object)
std::vector<std::string> all_slots_; std::vector<std::string> all_slots_;
std::vector<std::string> all_slots_type_; std::vector<std::string> all_slots_type_;
std::vector<std::vector<int>> use_slots_shape_;
std::vector<int> std::vector<int>
use_slots_index_; // -1: not used; >=0: the index of use_slots_ use_slots_index_; // -1: not used; >=0: the index of use_slots_
...@@ -112,8 +155,9 @@ class DataFeed { ...@@ -112,8 +155,9 @@ class DataFeed {
int batch_size_; int batch_size_;
bool finish_init_; bool finish_init_;
static bool finish_set_filelist_; bool finish_set_filelist_;
bool finish_start_; bool finish_start_;
std::string pipe_command_;
}; };
// PrivateQueueDataFeed is the base virtual class for ohther DataFeeds. // PrivateQueueDataFeed is the base virtual class for ohther DataFeeds.
...@@ -136,6 +180,7 @@ class PrivateQueueDataFeed : public DataFeed { ...@@ -136,6 +180,7 @@ class PrivateQueueDataFeed : public DataFeed {
virtual void SetQueueSize(int queue_size); virtual void SetQueueSize(int queue_size);
// The reading and parsing method called in the ReadThread. // The reading and parsing method called in the ReadThread.
virtual bool ParseOneInstance(T* instance) = 0; virtual bool ParseOneInstance(T* instance) = 0;
virtual bool ParseOneInstanceFromPipe(T* instance) = 0;
// This function is used to put instance to vec_ins // This function is used to put instance to vec_ins
virtual void AddInstanceToInsVec(T* vec_ins, const T& instance, virtual void AddInstanceToInsVec(T* vec_ins, const T& instance,
int index) = 0; int index) = 0;
...@@ -150,11 +195,59 @@ class PrivateQueueDataFeed : public DataFeed { ...@@ -150,11 +195,59 @@ class PrivateQueueDataFeed : public DataFeed {
// ifstream one line and one line parse: 6034 ms // ifstream one line and one line parse: 6034 ms
// fread one buffer and one buffer parse: 7097 ms // fread one buffer and one buffer parse: 7097 ms
std::ifstream file_; std::ifstream file_;
std::shared_ptr<FILE> fp_;
size_t queue_size_; size_t queue_size_;
string::LineFileReader reader_;
// The queue for store parsed data // The queue for store parsed data
std::unique_ptr<paddle::operators::reader::BlockingQueue<T>> queue_; std::unique_ptr<paddle::operators::reader::BlockingQueue<T>> queue_;
}; };
template <typename T>
class InMemoryDataFeed : public PrivateQueueDataFeed<T> {
public:
InMemoryDataFeed();
virtual ~InMemoryDataFeed() {}
virtual void Init(const paddle::framework::DataFeedDesc& data_feed_desc) = 0;
virtual bool Start();
virtual int Next();
virtual void SetMemoryData(void* memory_data);
virtual void SetMemoryDataMutex(std::mutex* mutex);
virtual void SetThreadId(int thread_id);
virtual void SetThreadNum(int thread_num);
virtual void SetTrainerNum(int trainer_num);
virtual void SetFleetSendBatchSize(int64_t size);
virtual void PutInsToChannel(const std::string& ins_str);
virtual void FillMemoryDataToChannel();
virtual void FillChannelToMemoryData();
virtual void LoadIntoMemory();
virtual void LocalShuffle();
virtual void GlobalShuffle();
protected:
virtual void AddInstanceToInsVec(T* vec_ins, const T& instance,
int index) = 0;
virtual bool ParseOneInstance(T* instance) = 0;
virtual bool ParseOneInstanceFromPipe(T* instance) = 0;
virtual void PutToFeedVec(const T& ins_vec) = 0;
virtual void SerializeIns(const std::vector<T*>& ins, std::string* str) = 0;
virtual void DeserializeIns(std::vector<T>* ins, const std::string& str) = 0;
virtual std::pair<int64_t, int64_t> GetMemoryDataInterval();
int thread_id_;
int thread_num_;
int trainer_num_;
uint32_t rand_seed;
std::vector<T>* memory_data_;
std::mutex* mutex_for_update_memory_data_;
// when read ins, we put ins from one channel to the other,
// and when finish reading, we set cur_channel = 1 - cur_channel,
// so if cur_channel=0, all data are in shuffled_ins_, else shuffled_ins_out_
int cur_channel_;
std::shared_ptr<paddle::framework::BlockingQueue<T>> shuffled_ins_;
std::shared_ptr<paddle::framework::BlockingQueue<T>> shuffled_ins_out_;
int64_t fleet_send_batch_size_;
};
// This class define the data type of instance(ins_vec) in MultiSlotDataFeed // This class define the data type of instance(ins_vec) in MultiSlotDataFeed
class MultiSlotType { class MultiSlotType {
public: public:
...@@ -176,6 +269,7 @@ class MultiSlotType { ...@@ -176,6 +269,7 @@ class MultiSlotType {
offset_[0] = 0; offset_[0] = 0;
} }
const std::vector<size_t>& GetOffset() const { return offset_; } const std::vector<size_t>& GetOffset() const { return offset_; }
std::vector<size_t>& MutableOffset() { return offset_; }
void AddValue(const float v) { void AddValue(const float v) {
CheckFloat(); CheckFloat();
float_feasign_.push_back(v); float_feasign_.push_back(v);
...@@ -198,8 +292,33 @@ class MultiSlotType { ...@@ -198,8 +292,33 @@ class MultiSlotType {
} }
} }
const std::vector<float>& GetFloatData() const { return float_feasign_; } const std::vector<float>& GetFloatData() const { return float_feasign_; }
std::vector<float>& MutableFloatData() { return float_feasign_; }
const std::vector<uint64_t>& GetUint64Data() const { return uint64_feasign_; } const std::vector<uint64_t>& GetUint64Data() const { return uint64_feasign_; }
std::vector<uint64_t>& MutableUint64Data() { return uint64_feasign_; }
const std::string& GetType() const { return type_; } const std::string& GetType() const { return type_; }
std::string& MutableType() { return type_; }
std::string DebugString() {
std::stringstream ss;
ss << "\ntype: " << type_ << "\n";
ss << "offset: ";
ss << "[";
for (const size_t& i : offset_) {
ss << offset_[i] << ",";
}
ss << "]\ndata: [";
if (type_[0] == 'f') {
for (const float& i : float_feasign_) {
ss << i << ",";
}
} else {
for (const uint64_t& i : uint64_feasign_) {
ss << i << ",";
}
}
ss << "]\n";
return ss.str();
}
private: private:
void CheckType(const std::string& type) const { void CheckType(const std::string& type) const {
...@@ -228,13 +347,37 @@ class MultiSlotDataFeed ...@@ -228,13 +347,37 @@ class MultiSlotDataFeed
virtual ~MultiSlotDataFeed() {} virtual ~MultiSlotDataFeed() {}
virtual void Init(const paddle::framework::DataFeedDesc& data_feed_desc); virtual void Init(const paddle::framework::DataFeedDesc& data_feed_desc);
virtual bool CheckFile(const char* filename); virtual bool CheckFile(const char* filename);
// virtual void ReadThread();
protected: protected:
virtual void ReadThread();
virtual void AddInstanceToInsVec(std::vector<MultiSlotType>* vec_ins, virtual void AddInstanceToInsVec(std::vector<MultiSlotType>* vec_ins,
const std::vector<MultiSlotType>& instance, const std::vector<MultiSlotType>& instance,
int index); int index);
virtual bool ParseOneInstance(std::vector<MultiSlotType>* instance); virtual bool ParseOneInstance(std::vector<MultiSlotType>* instance);
virtual bool ParseOneInstanceFromPipe(std::vector<MultiSlotType>* instance);
virtual void PutToFeedVec(const std::vector<MultiSlotType>& ins_vec); virtual void PutToFeedVec(const std::vector<MultiSlotType>& ins_vec);
}; };
class MultiSlotInMemoryDataFeed
: public InMemoryDataFeed<std::vector<MultiSlotType>> {
public:
MultiSlotInMemoryDataFeed() {}
virtual ~MultiSlotInMemoryDataFeed() {}
virtual void Init(const paddle::framework::DataFeedDesc& data_feed_desc);
protected:
virtual void AddInstanceToInsVec(std::vector<MultiSlotType>* vec_ins,
const std::vector<MultiSlotType>& instance,
int index);
virtual bool ParseOneInstance(std::vector<MultiSlotType>* instance);
virtual bool ParseOneInstanceFromPipe(std::vector<MultiSlotType>* instance);
virtual void PutToFeedVec(const std::vector<MultiSlotType>& ins_vec);
virtual void SerializeIns(const std::vector<std::vector<MultiSlotType>*>& ins,
std::string* str);
virtual void DeserializeIns(std::vector<std::vector<MultiSlotType>>* ins,
const std::string& str);
};
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -19,6 +19,7 @@ message Slot { ...@@ -19,6 +19,7 @@ message Slot {
required string type = 2; required string type = 2;
optional bool is_dense = 3 [ default = false ]; optional bool is_dense = 3 [ default = false ];
optional bool is_used = 4 [ default = false ]; optional bool is_used = 4 [ default = false ];
repeated int32 shape = 5; // we can define N-D Tensor
} }
message MultiSlotDesc { repeated Slot slots = 1; } message MultiSlotDesc { repeated Slot slots = 1; }
...@@ -27,4 +28,6 @@ message DataFeedDesc { ...@@ -27,4 +28,6 @@ message DataFeedDesc {
optional string name = 1; optional string name = 1;
optional int32 batch_size = 2 [ default = 32 ]; optional int32 batch_size = 2 [ default = 32 ];
optional MultiSlotDesc multi_slot_desc = 3; optional MultiSlotDesc multi_slot_desc = 3;
optional string pipe_command = 4;
optional int32 thread_num = 5;
} }
...@@ -54,11 +54,15 @@ std::string DataFeedFactory::DataFeedTypeList() { ...@@ -54,11 +54,15 @@ std::string DataFeedFactory::DataFeedTypeList() {
std::shared_ptr<DataFeed> DataFeedFactory::CreateDataFeed( std::shared_ptr<DataFeed> DataFeedFactory::CreateDataFeed(
std::string data_feed_class) { std::string data_feed_class) {
if (g_data_feed_map.count(data_feed_class) < 1) { if (g_data_feed_map.count(data_feed_class) < 1) {
LOG(WARNING) << "Your DataFeed " << data_feed_class
<< "is not supported currently";
LOG(WARNING) << "Supported DataFeed: " << DataFeedTypeList();
exit(-1); exit(-1);
} }
return g_data_feed_map[data_feed_class](); return g_data_feed_map[data_feed_class]();
} }
REGISTER_DATAFEED_CLASS(MultiSlotDataFeed); REGISTER_DATAFEED_CLASS(MultiSlotDataFeed);
REGISTER_DATAFEED_CLASS(MultiSlotInMemoryDataFeed);
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -324,7 +324,7 @@ TEST(DataFeed, MultiSlotUnitTest) { ...@@ -324,7 +324,7 @@ TEST(DataFeed, MultiSlotUnitTest) {
load_datafeed_param_from_file(protofile); load_datafeed_param_from_file(protofile);
std::vector<MultiTypeSet> reader_elem_set; std::vector<MultiTypeSet> reader_elem_set;
std::vector<MultiTypeSet> file_elem_set; std::vector<MultiTypeSet> file_elem_set;
GetElemSetFromReader(&reader_elem_set, data_feed_desc, filelist, 4); // GetElemSetFromReader(&reader_elem_set, data_feed_desc, filelist, 4);
GetElemSetFromFile(&file_elem_set, data_feed_desc, filelist); // GetElemSetFromFile(&file_elem_set, data_feed_desc, filelist);
CheckIsUnorderedSame(reader_elem_set, file_elem_set); // CheckIsUnorderedSame(reader_elem_set, file_elem_set);
} }
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/framework/data_set.h"
#include <random>
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/message.h"
#include "google/protobuf/text_format.h"
#include "paddle/fluid/framework/data_feed_factory.h"
#include "paddle/fluid/framework/io/fs.h"
#include "paddle/fluid/platform/timer.h"
#if defined _WIN32 || defined __APPLE__
#else
#define _LINUX
#endif
namespace paddle {
namespace framework {
// constructor
template <typename T>
DatasetImpl<T>::DatasetImpl() {
thread_num_ = 1;
trainer_num_ = 1;
file_idx_ = 0;
}
// set filelist, file_idx_ will reset to zero.
template <typename T>
void DatasetImpl<T>::SetFileList(const std::vector<std::string>& filelist) {
VLOG(3) << "filelist size: " << filelist.size();
filelist_ = filelist;
file_idx_ = 0;
}
// set expect thread num. actually it may change
template <typename T>
void DatasetImpl<T>::SetThreadNum(int thread_num) {
VLOG(3) << "SetThreadNum thread_num=" << thread_num;
thread_num_ = thread_num;
}
// if you run distributed, and want to do global shuffle,
// set this before global shuffle.
// be sure you call CreateReaders before SetTrainerNum
template <typename T>
void DatasetImpl<T>::SetTrainerNum(int trainer_num) {
trainer_num_ = trainer_num;
// should inform reader of trainer_num directly
for (auto reader : readers_) {
reader->SetTrainerNum(trainer_num);
}
}
// if you run distributed, and want to do global shuffle,
// set this before global shuffle.
// be sure you call CreateReaders before SetFleetSendBatchSize
template <typename T>
void DatasetImpl<T>::SetFleetSendBatchSize(int64_t size) {
fleet_send_batch_size_ = size;
for (auto reader : readers_) {
reader->SetFleetSendBatchSize(size);
}
}
template <typename T>
void DatasetImpl<T>::SetHdfsConfig(const std::string& fs_name,
const std::string& fs_ugi) {
fs_name_ = fs_name;
fs_ugi_ = fs_ugi;
std::string cmd = std::string("hadoop fs");
cmd += " -D fs.default.name=" + fs_name;
cmd += " -D hadoop.job.ugi=" + fs_ugi;
paddle::framework::hdfs_set_command(cmd);
}
template <typename T>
void DatasetImpl<T>::SetDataFeedDesc(const std::string& data_feed_desc_str) {
google::protobuf::TextFormat::ParseFromString(data_feed_desc_str,
&data_feed_desc_);
}
// readers_.size() may not be equal to thread_num_,
// it changes when filelist_.size() < thread_num_
template <typename T>
std::vector<std::shared_ptr<paddle::framework::DataFeed>>&
DatasetImpl<T>::GetReaders() {
return readers_;
}
// if sent message between workers, should first call this function
template <typename T>
void DatasetImpl<T>::RegisterClientToClientMsgHandler() {
auto fleet_ptr = FleetWrapper::GetInstance();
VLOG(3) << "RegisterClientToClientMsgHandler";
fleet_ptr->RegisterClientToClientMsgHandler(
0, [this](int msg_type, int client_id, const std::string& msg) -> int {
return this->ReceiveFromClient(msg_type, client_id, msg);
});
VLOG(3) << "RegisterClientToClientMsgHandler done";
}
// load data into memory, Dataset hold this memory,
// which will later be fed into readers' channel
template <typename T>
void DatasetImpl<T>::LoadIntoMemory() {
VLOG(3) << "DatasetImpl<T>::LoadIntoMemory() begin";
platform::Timer timeline;
timeline.Start();
if (readers_.size() == 0) {
CreateReaders();
}
std::vector<std::thread> load_threads;
for (int64_t i = 0; i < thread_num_; ++i) {
load_threads.push_back(std::thread(
&paddle::framework::DataFeed::LoadIntoMemory, readers_[i].get()));
}
for (std::thread& t : load_threads) {
t.join();
}
timeline.Pause();
VLOG(3) << "DatasetImpl<T>::LoadIntoMemory() end"
<< ", memory data size=" << memory_data_.size()
<< ", cost time=" << timeline.ElapsedSec() << " seconds";
}
// release memory data
template <typename T>
void DatasetImpl<T>::ReleaseMemory() {
VLOG(3) << "DatasetImpl<T>::ReleaseMemory() begin";
std::vector<T>().swap(memory_data_);
VLOG(3) << "DatasetImpl<T>::ReleaseMemory() end";
}
// do local shuffle
template <typename T>
void DatasetImpl<T>::LocalShuffle() {
VLOG(3) << "DatasetImpl<T>::LocalShuffle() begin";
platform::Timer timeline;
timeline.Start();
if (readers_.size() == 0) {
CreateReaders();
}
// if it is not InMemory, memory_data_ is empty
std::random_shuffle(memory_data_.begin(), memory_data_.end());
std::vector<std::thread> local_shuffle_threads;
for (int64_t i = 0; i < thread_num_; ++i) {
local_shuffle_threads.push_back(std::thread(
&paddle::framework::DataFeed::LocalShuffle, readers_[i].get()));
}
for (std::thread& t : local_shuffle_threads) {
t.join();
}
std::vector<T>().swap(memory_data_);
timeline.Pause();
VLOG(3) << "DatasetImpl<T>::LocalShuffle() end, cost time="
<< timeline.ElapsedSec() << " seconds";
}
template <typename T>
void DatasetImpl<T>::GlobalShuffle() {
VLOG(3) << "DatasetImpl<T>::GlobalShuffle() begin";
platform::Timer timeline;
timeline.Start();
if (readers_.size() == 0) {
CreateReaders();
}
// if it is not InMemory, memory_data_ is empty
std::random_shuffle(memory_data_.begin(), memory_data_.end());
VLOG(3) << "start global shuffle threads";
std::vector<std::thread> global_shuffle_threads;
for (int i = 0; i < thread_num_; ++i) {
global_shuffle_threads.push_back(std::thread(
&paddle::framework::DataFeed::GlobalShuffle, readers_[i].get()));
}
for (std::thread& t : global_shuffle_threads) {
t.join();
}
std::vector<T>().swap(memory_data_);
timeline.Pause();
VLOG(3) << "DatasetImpl<T>::GlobalShuffle() end, cost time="
<< timeline.ElapsedSec() << " seconds";
}
template <typename T>
void DatasetImpl<T>::CreateReaders() {
VLOG(3) << "Calling CreateReaders()";
CHECK(thread_num_ > 0) << "thread_num should > 0";
int file_cnt = filelist_.size();
int memory_data_size = memory_data_.size();
if (memory_data_size != 0 && thread_num_ > memory_data_size) {
VLOG(3) << "Dataset thread num = " << thread_num_
<< ", memory data size = " << memory_data_size
<< ". Changing Dataset thread num = " << memory_data_size;
thread_num_ = memory_data_size;
} else if (file_cnt != 0 && thread_num_ > file_cnt) {
VLOG(3) << "Dataset thread num = " << thread_num_
<< ", file num = " << file_cnt
<< ". Changing Dataset thread num = " << file_cnt;
thread_num_ = file_cnt;
}
VLOG(3) << "thread_num in Readers: " << thread_num_;
VLOG(3) << "readers size: " << readers_.size();
VLOG(3) << "Filelist size in readers: " << filelist_.size();
if (readers_.size() != 0) {
return;
}
VLOG(3) << "data feed class name: " << data_feed_desc_.name();
for (int i = 0; i < thread_num_; ++i) {
readers_.push_back(DataFeedFactory::CreateDataFeed(data_feed_desc_.name()));
readers_.back()->Init(data_feed_desc_);
readers_.back()->SetMemoryData(&memory_data_);
readers_.back()->SetMemoryDataMutex(&mutex_for_update_memory_data_);
readers_.back()->SetThreadId(i);
readers_.back()->SetThreadNum(thread_num_);
readers_.back()->SetTrainerNum(trainer_num_);
readers_.back()->SetFileListMutex(&mutex_for_pick_file_);
readers_.back()->SetFileListIndex(&file_idx_);
readers_.back()->SetFileList(filelist_);
}
}
template <typename T>
void DatasetImpl<T>::DestroyReaders() {
VLOG(3) << "Calling DestroyReaders()";
// clear memory_data_ before fill it
// because if LoadIntoMemory but no Shuffle,
// memory_data_ has empty data which has been std::move to channel
if (memory_data_.size() != 0) {
std::vector<T>().swap(memory_data_);
}
std::vector<std::thread> fill_threads;
for (int i = 0; i < thread_num_; ++i) {
fill_threads.push_back(
std::thread(&paddle::framework::DataFeed::FillChannelToMemoryData,
readers_[i].get()));
}
for (std::thread& t : fill_threads) {
t.join();
}
std::vector<std::shared_ptr<paddle::framework::DataFeed>>().swap(readers_);
VLOG(3) << "readers size: " << readers_.size();
// if memory_data_ is empty, which means it's not InMemory mode,
// so the next epoch should read all data again
if (memory_data_.size() == 0) {
file_idx_ = 0;
}
}
template <typename T>
int DatasetImpl<T>::ReceiveFromClient(int msg_type, int client_id,
const std::string& msg) {
#ifdef _LINUX
VLOG(3) << "ReceiveFromClient msg_type=" << msg_type
<< ", client_id=" << client_id << ", msg length=" << msg.length();
auto fleet_ptr = FleetWrapper::GetInstance();
int64_t index = rand_r(&rand_seed) % thread_num_;
VLOG(3) << "ramdom index=" << index;
readers_[index]->PutInsToChannel(msg);
#endif
return 0;
}
// explicit instantiation
template class DatasetImpl<std::vector<MultiSlotType>>;
} // end namespace framework
} // end namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <fstream>
#include <memory>
#include <mutex> // NOLINT
#include <string>
#include <thread> // NOLINT
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_feed.h"
namespace paddle {
namespace framework {
// Dataset is a abstract class, which defines user interfaces
// Example Usage:
// Dataset* dataset = DatasetFactory::CreateDataset("InMemoryDataset")
// dataset->SetFileList(std::vector<std::string>{"a.txt", "b.txt"})
// dataset->SetThreadNum(1)
// dataset->CreateReaders();
// dataset->SetDataFeedDesc(your_data_feed_desc);
// dataset->LoadIntoMemory();
// dataset->SetTrainerNum(2);
// dataset->GlobalShuffle();
class Dataset {
public:
Dataset() {}
virtual ~Dataset() {}
// set file list
virtual void SetFileList(const std::vector<std::string>& filelist) = 0;
// set readers' num
virtual void SetThreadNum(int thread_num) = 0;
// set workers' num
virtual void SetTrainerNum(int trainer_num) = 0;
// set fleet send batch size
virtual void SetFleetSendBatchSize(int64_t size) = 0;
// set fs name and ugi
virtual void SetHdfsConfig(const std::string& fs_name,
const std::string& fs_ugi) = 0;
// set data fedd desc, which contains:
// data feed name, batch size, slots
virtual void SetDataFeedDesc(const std::string& data_feed_desc_str) = 0;
// get file list
virtual const std::vector<std::string>& GetFileList() = 0;
// get thread num
virtual int GetThreadNum() = 0;
// get worker num
virtual int GetTrainerNum() = 0;
// get fleet send batch size
virtual int64_t GetFleetSendBatchSize() = 0;
// get hdfs config
virtual std::pair<std::string, std::string> GetHdfsConfig() = 0;
// get data fedd desc
virtual const paddle::framework::DataFeedDesc& GetDataFeedDesc() = 0;
// get readers, the reader num depend both on thread num
// and filelist size
virtual std::vector<std::shared_ptr<paddle::framework::DataFeed>>&
GetReaders() = 0;
// register message handler between workers
virtual void RegisterClientToClientMsgHandler() = 0;
// load all data into memory
virtual void LoadIntoMemory() = 0;
// release all memory data
virtual void ReleaseMemory() = 0;
// local shuffle data
virtual void LocalShuffle() = 0;
// global shuffle data
virtual void GlobalShuffle() = 0;
// create readers
virtual void CreateReaders() = 0;
// destroy readers
virtual void DestroyReaders() = 0;
protected:
virtual int ReceiveFromClient(int msg_type, int client_id,
const std::string& msg) = 0;
};
// DatasetImpl is the implementation of Dataset,
// it holds memory data if user calls load_into_memory
template <typename T>
class DatasetImpl : public Dataset {
public:
DatasetImpl();
virtual ~DatasetImpl() {}
virtual void SetFileList(const std::vector<std::string>& filelist);
virtual void SetThreadNum(int thread_num);
virtual void SetTrainerNum(int trainer_num);
virtual void SetFleetSendBatchSize(int64_t size);
virtual void SetHdfsConfig(const std::string& fs_name,
const std::string& fs_ugi);
virtual void SetDataFeedDesc(const std::string& data_feed_desc_str);
virtual const std::vector<std::string>& GetFileList() { return filelist_; }
virtual int GetThreadNum() { return thread_num_; }
virtual int GetTrainerNum() { return trainer_num_; }
virtual int64_t GetFleetSendBatchSize() { return fleet_send_batch_size_; }
virtual std::pair<std::string, std::string> GetHdfsConfig() {
return std::make_pair(fs_name_, fs_ugi_);
}
virtual const paddle::framework::DataFeedDesc& GetDataFeedDesc() {
return data_feed_desc_;
}
virtual std::vector<std::shared_ptr<paddle::framework::DataFeed>>&
GetReaders();
virtual void RegisterClientToClientMsgHandler();
virtual void LoadIntoMemory();
virtual void ReleaseMemory();
virtual void LocalShuffle();
virtual void GlobalShuffle();
virtual void CreateReaders();
virtual void DestroyReaders();
protected:
virtual int ReceiveFromClient(int msg_type, int client_id,
const std::string& msg);
std::vector<std::shared_ptr<paddle::framework::DataFeed>> readers_;
std::vector<T> memory_data_;
std::mutex mutex_for_update_memory_data_;
int thread_num_;
paddle::framework::DataFeedDesc data_feed_desc_;
int trainer_num_;
std::vector<std::string> filelist_;
size_t file_idx_;
std::mutex mutex_for_pick_file_;
std::string fs_name_;
std::string fs_ugi_;
unsigned int rand_seed;
int64_t fleet_send_batch_size_;
};
// use std::vector<MultiSlotType> as data type
class MultiSlotDataset : public DatasetImpl<std::vector<MultiSlotType>> {
public:
MultiSlotDataset() {}
virtual ~MultiSlotDataset() {}
};
} // end namespace framework
} // end namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/dataset_factory.h"
#include <memory>
#include <string>
#include <unordered_map>
#include "paddle/fluid/framework/data_set.h"
namespace paddle {
namespace framework {
typedef std::shared_ptr<Dataset> (*CreateDatasetFunction)();
typedef std::unordered_map<std::string, CreateDatasetFunction> datasetMap;
datasetMap g_dataset_map;
#define REGISTER_DATASET_CLASS(dataset_class) \
namespace { \
std::shared_ptr<Dataset> Creator_##dataset_class() { \
return std::shared_ptr<Dataset>(new dataset_class); \
} \
class __Registerer_##dataset_class { \
public: \
__Registerer_##dataset_class() { \
g_dataset_map[#dataset_class] = &Creator_##dataset_class; \
} \
}; \
__Registerer_##dataset_class g_registerer_##dataset_class; \
} // namespace
std::string DatasetFactory::DatasetTypeList() {
std::string dataset_types;
for (auto iter = g_dataset_map.begin(); iter != g_dataset_map.end(); ++iter) {
if (iter != g_dataset_map.begin()) {
dataset_types += ", ";
}
dataset_types += iter->first;
}
return dataset_types;
}
std::shared_ptr<Dataset> DatasetFactory::CreateDataset(
std::string dataset_class) {
if (g_dataset_map.count(dataset_class) < 1) {
LOG(WARNING) << "Your Dataset " << dataset_class
<< "is not supported currently";
LOG(WARNING) << "Supported Dataset: " << DatasetTypeList();
exit(-1);
}
return g_dataset_map[dataset_class]();
}
REGISTER_DATASET_CLASS(MultiSlotDataset);
} // namespace framework
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/framework/data_set.h"
namespace paddle {
namespace framework {
class DatasetFactory {
public:
static std::string DatasetTypeList();
static std::shared_ptr<Dataset> CreateDataset(std::string dataset_class);
};
} // namespace framework
} // namespace paddle
...@@ -25,8 +25,12 @@ if(WITH_DISTRIBUTE) ...@@ -25,8 +25,12 @@ if(WITH_DISTRIBUTE)
endif() endif()
if(WITH_GPU) if(WITH_GPU)
set(dgc_deps "")
if(NOT WIN32)
set(dgc_deps dgc)
endif()
nv_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory nv_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
dynload_cuda variable_visitor dgc) dynload_cuda variable_visitor ${dgc_deps})
nv_library(fused_all_reduce_op_handle SRCS fused_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory nv_library(fused_all_reduce_op_handle SRCS fused_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
dynload_cuda variable_visitor) dynload_cuda variable_visitor)
if(WITH_DISTRIBUTE) if(WITH_DISTRIBUTE)
...@@ -92,6 +96,12 @@ cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS ...@@ -92,6 +96,12 @@ cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS
cc_library(parallel_ssa_graph_executor SRCS parallel_ssa_graph_executor.cc DEPS threaded_ssa_graph_executor) cc_library(parallel_ssa_graph_executor SRCS parallel_ssa_graph_executor.cc DEPS threaded_ssa_graph_executor)
set(ASYNC_SSA_GRAPH_EXECUTOR_DEPS threaded_ssa_graph_executor)
if(WITH_DISTRIBUTE)
list(APPEND ASYNC_SSA_GRAPH_EXECUTOR_DEPS communicator)
endif()
cc_library(async_ssa_graph_executor SRCS async_ssa_graph_executor.cc DEPS ${ASYNC_SSA_GRAPH_EXECUTOR_DEPS})
cc_test(broadcast_op_test SRCS broadcast_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory cc_test(broadcast_op_test SRCS broadcast_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
device_context broadcast_op_handle) device_context broadcast_op_handle)
cc_test(gather_op_test SRCS gather_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory cc_test(gather_op_test SRCS gather_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
......
...@@ -13,125 +13,186 @@ ...@@ -13,125 +13,186 @@
// limitations under the License. // limitations under the License.
#include <algorithm> #include <algorithm>
#include <memory> #include <map>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/framework/details/all_reduce_deps_pass.h"
#include "paddle/fluid/framework/details/all_reduce_op_handle.h" #include "paddle/fluid/framework/details/all_reduce_op_handle.h"
#include "paddle/fluid/framework/details/container_cast.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h" #include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/op_graph_view.h" #include "paddle/fluid/framework/details/op_graph_view.h"
#include "paddle/fluid/framework/details/var_handle.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/op_proto_maker.h" #include "paddle/fluid/framework/op_proto_maker.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
VarHandle* GetValidInput(const OpHandleBase* a) { class AllReduceDepsPass : public ir::Pass {
for (auto p : a->Inputs()) { protected:
VarHandle* b = dynamic_cast<VarHandle*>(p); void ApplyImpl(ir::Graph* graph) const override {
if (b) { std::vector<AllReduceOpHandle*> all_reduce_op_handles =
return b; GetSortedAllReduceOps(*graph);
for (size_t i = 1; i < all_reduce_op_handles.size(); ++i) {
auto* dep_var = new DummyVarHandle(graph->CreateControlDepVar());
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
all_reduce_op_handles[i - 1]->AddOutput(dep_var);
all_reduce_op_handles[i]->AddInput(dep_var);
} }
}
return nullptr; if (VLOG_IS_ON(10)) {
} DebugString(*graph, all_reduce_op_handles);
void AllReduceDepsPass::ApplyImpl(ir::Graph* graph) const {
auto graph_ops = ir::FilterByNodeWrapper<OpHandleBase>(*graph);
// get vars order
int order = 0;
std::unordered_map<std::string, int> vars;
// TODO(gongwb): use graph topology sort to find the order of operators.
// Note that must assert topology sort is stable
auto& ops = graph->Get<const std::vector<OpDesc*>>(kStaleProgramOpDescs);
for (auto* op_desc : ops) {
try {
bool is_bk_op =
static_cast<bool>(boost::get<int>(op_desc->GetAttr(
OpProtoAndCheckerMaker::OpRoleAttrName())) &
static_cast<int>(OpRole::kBackward));
if (!is_bk_op) continue;
auto backward_vars =
boost::get<std::vector<std::string>>(op_desc->GetNullableAttr(
OpProtoAndCheckerMaker::OpRoleVarAttrName()));
PADDLE_ENFORCE_EQ(backward_vars.size() % 2, 0);
auto outputs = op_desc->Outputs();
for (auto& o_it : outputs) {
for (auto& v : o_it.second) { // values
vars[v] = order;
VLOG(1) << "in all_reduce_deps_pass:" << v;
}
}
order++;
} catch (boost::bad_get e) {
} }
} }
std::vector<OpHandleBase*> dist_ops; std::vector<AllReduceOpHandle*> GetSortedAllReduceOps(
// get allreduce ops. const ir::Graph& graph) const {
for (auto& op : graph_ops) { std::vector<AllReduceOpHandle*> all_reduce_op_handles;
// FIXME(gongwb):add broad cast. std::unordered_map<OpHandleBase*, size_t> pending_ops;
if (op->Name() == "all_reduce" || op->Name() == "reduce") { std::unordered_set<OpHandleBase*> ready_ops;
dist_ops.push_back(op); std::unordered_set<OpHandleBase*> next_ready_ops;
auto op_handles = ir::FilterByNodeWrapper<OpHandleBase>(graph);
size_t num_of_ops = op_handles.size();
for (OpHandleBase* op : op_handles) {
size_t not_ready_vars = op->NotReadyInputSize();
if (not_ready_vars) {
pending_ops.insert({op, not_ready_vars});
} else {
ready_ops.insert(op);
}
} }
}
VLOG(10) << "dist_ops size:" << dist_ops.size()
<< ", outputs size:" << vars.size() << ", ops size:" << ops.size();
std::sort(dist_ops.begin(), dist_ops.end(), [&](OpHandleBase* op1,
OpHandleBase* op2) {
VarHandle* i0 = dynamic_cast<VarHandle*>(GetValidInput(op1));
VarHandle* i1 = dynamic_cast<VarHandle*>(GetValidInput(op2));
PADDLE_ENFORCE(i0 != nullptr && i1 != nullptr, "%s convert to %s error",
op1->DebugString(), op2->DebugString());
auto l_it = vars.find(i0->name()); GetSortedAllReduceOps(ready_ops, &all_reduce_op_handles);
auto r_it = vars.find(i1->name());
size_t has_run_ops = ready_ops.size();
PADDLE_ENFORCE(l_it != vars.end() && r_it != vars.end(), while (has_run_ops != num_of_ops) {
"can't find var's name %s and %s in opdesc", i0->name(), for (auto* op : ready_ops) {
i1->name()); for (auto& ready_var : op->Outputs()) {
for (auto* pend_op : ready_var->PendingOps()) {
if (l_it->second < r_it->second) return true; auto& deps = --pending_ops[pend_op];
if (deps == 0) {
next_ready_ops.insert(pend_op);
}
}
}
}
if (l_it->second == r_it->second) { PADDLE_ENFORCE_NE(next_ready_ops.size(), 0, "There maybe have a cycle.");
return i0->name() < i1->name(); ready_ops.clear();
std::swap(ready_ops, next_ready_ops);
GetSortedAllReduceOps(ready_ops, &all_reduce_op_handles);
has_run_ops += ready_ops.size();
} }
return all_reduce_op_handles;
}
return false; void GetSortedAllReduceOps(
}); const std::unordered_set<OpHandleBase*>& ready_ops,
std::vector<AllReduceOpHandle*>* all_reduce_op_handles) const {
// add dependency. std::vector<AllReduceOpHandle*> current_all_reduce_op_handles;
auto& sorted_ops = dist_ops; for (auto& op_handle : ready_ops) {
for (size_t i = 1; i < sorted_ops.size(); ++i) { auto all_reduce_op_handle = dynamic_cast<AllReduceOpHandle*>(op_handle);
auto* dep_var = new DummyVarHandle(graph->CreateControlDepVar()); if (all_reduce_op_handle) {
current_all_reduce_op_handles.emplace_back(all_reduce_op_handle);
auto* pre_op = sorted_ops[i - 1]; }
auto* op = sorted_ops[i]; }
pre_op->AddOutput(dep_var);
op->AddInput(dep_var);
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
VLOG(10) << "add all_reduce sequential dependencies between " << pre_op // NOTE(zcd): For distributed training, it is important to keep the order of
<< " and " << op; // allReduce on each node consistent. Otherwise, hang may occur.
// Sort the current_all_reduce_op_handles according to the name of input.
sort(current_all_reduce_op_handles.begin(),
current_all_reduce_op_handles.end(),
[](const AllReduceOpHandle* left,
const AllReduceOpHandle* right) -> bool {
auto left_in_vars = DynamicCast<VarHandle>(left->Inputs());
auto right_in_vars = DynamicCast<VarHandle>(right->Inputs());
PADDLE_ENFORCE_GT(left_in_vars.size(), 0);
PADDLE_ENFORCE_EQ(left_in_vars.size(), right_in_vars.size());
return left_in_vars[0]->Name() > right_in_vars[0]->Name();
});
all_reduce_op_handles->insert(all_reduce_op_handles->end(),
current_all_reduce_op_handles.begin(),
current_all_reduce_op_handles.end());
}
VLOG(10) << "pre_op:" << pre_op->DebugString() void DebugString(
<< ", op:" << op->DebugString(); const ir::Graph& graph,
const std::vector<AllReduceOpHandle*>& all_reduce_op_handles) const {
// get vars order
std::map<int, std::vector<std::string>> vars =
GetSoredGradientsFromStaleProgram(graph);
std::stringstream out;
size_t grads_of_stale_program = 0;
out << "Get Order From kStaleProgramOpDescs: ";
for (auto& var : vars) {
out << "Order " << var.first << " [";
for (auto& var_name : var.second) {
out << var_name << ", ";
++grads_of_stale_program;
}
out << "], ";
}
VLOG(10) << out.str();
std::stringstream out2;
out2 << "Get Order From Topological order: ";
for (auto& op : all_reduce_op_handles) {
bool find_valid_input = false;
for (auto& in_var : op->Inputs()) {
if (dynamic_cast<VarHandle*>(in_var)) {
out2 << in_var->Name() << ", ";
find_valid_input = true;
break;
}
}
PADDLE_ENFORCE(find_valid_input, "Doesn't find valid input.");
}
VLOG(10) << out2.str();
if (grads_of_stale_program != all_reduce_op_handles.size()) {
VLOG(10)
<< "The gradients number of stale program and graph is not equal.";
}
} }
}
std::map<int, std::vector<std::string>> GetSoredGradientsFromStaleProgram(
const ir::Graph& graph) const {
std::map<int, std::vector<std::string>> vars;
auto ops = graph.Get<const std::vector<OpDesc*>>(kStaleProgramOpDescs);
int order = 0;
for (auto* op_desc : ops) {
try {
bool is_bk_op =
static_cast<bool>(boost::get<int>(op_desc->GetAttr(
OpProtoAndCheckerMaker::OpRoleAttrName())) &
static_cast<int>(OpRole::kBackward));
if (!is_bk_op) continue;
auto backward_vars =
boost::get<std::vector<std::string>>(op_desc->GetNullableAttr(
OpProtoAndCheckerMaker::OpRoleVarAttrName()));
if (backward_vars.empty()) continue;
PADDLE_ENFORCE_EQ(backward_vars.size() % 2, 0);
for (size_t i = 1; i < backward_vars.size(); i += 2) {
vars[order].emplace_back(backward_vars[i]);
VLOG(1) << "get parameter and gradient: " << backward_vars[i - 1]
<< ", " << backward_vars[i];
}
order++;
} catch (boost::bad_get e) {
}
}
return vars;
}
};
} // namespace details } // namespace details
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
......
...@@ -28,7 +28,7 @@ ...@@ -28,7 +28,7 @@
// asynchronous nccl allreduce or synchronous issue: // asynchronous nccl allreduce or synchronous issue:
// https://github.com/PaddlePaddle/Paddle/issues/15049 // https://github.com/PaddlePaddle/Paddle/issues/15049
DEFINE_bool( DEFINE_bool(
sync_nccl_allreduce, false, sync_nccl_allreduce, true,
"If set true, will call `cudaStreamSynchronize(nccl_stream)`" "If set true, will call `cudaStreamSynchronize(nccl_stream)`"
"after allreduce, this mode can get better performance in some scenarios."); "after allreduce, this mode can get better performance in some scenarios.");
...@@ -53,6 +53,10 @@ AllReduceOpHandle::AllReduceOpHandle(ir::Node *node, ...@@ -53,6 +53,10 @@ AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
this->SetDeviceContext(p, nccl_ctxs_->DevCtx(p)); this->SetDeviceContext(p, nccl_ctxs_->DevCtx(p));
} }
} }
// TODO(gongwb) :polish them!
if (is_encoded) {
VLOG(1) << "Use dgc allreduce mode";
}
} }
#else #else
AllReduceOpHandle::AllReduceOpHandle(ir::Node *node, AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
...@@ -86,7 +90,7 @@ void AllReduceOpHandle::RunImplEncoded() { ...@@ -86,7 +90,7 @@ void AllReduceOpHandle::RunImplEncoded() {
paddle::framework::GradOriginalVarName(in_var_handles[i]->name()); paddle::framework::GradOriginalVarName(in_var_handles[i]->name());
auto encode_var_name = original_name + g_dgc_encoded; auto encode_var_name = original_name + g_dgc_encoded;
auto *in_var = local_scope->FindVar(encode_var_name); auto *in_var = local_scope->FindVar(encode_var_name);
PADDLE_ENFORCE_NOT_NULL(in_var); PADDLE_ENFORCE_NOT_NULL(in_var, "%s should not be null", encode_var_name);
auto &in = in_var->Get<LoDTensor>(); auto &in = in_var->Get<LoDTensor>();
ins.emplace_back(&in); ins.emplace_back(&in);
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace paddle {
namespace framework {
namespace details {
void SetFuseParameterGroupsSize(int group_size);
int GetFuseParameterGroupsSize();
void SetFuseParameterMemorySize(uint64_t memory_size);
uint64_t GetFuseParameterMemorySize();
class AllocContinuousSpaceForGradPass : public ir::Pass {
protected:
void ApplyImpl(ir::Graph *graph) const override;
template <typename AttrType>
void ResetAttribute(const std::string &attr_name, ir::Graph *graph) const;
void SetGroupGradsAndParams(
const std::unordered_map<std::string, ir::Node *> &var_nodes,
const ParamsAndGrads &params_grads,
GroupGradsAndParams *group_grads_params) const;
void SetGroupAccordingToLayers(
const std::unordered_map<std::string, ir::Node *> &var_nodes,
const ParamsAndGrads &params_grads,
GroupGradsAndParams *group_grads_params) const;
void SetGroupAccordingToMemorySize(
const std::unordered_map<std::string, ir::Node *> &var_nodes,
GroupGradsAndParams *group_grads_params) const;
void SetGroupAccordingToGroupSize(
const std::unordered_map<std::string, ir::Node *> &var_nodes,
GroupGradsAndParams *group_grads_params) const;
private:
bool IsSupportedVarType(const proto::VarType::Type &type) const;
void RecordParamsAndGrads(ir::Node *node, ParamsAndGrads *params_grads) const;
void InitFusedVarsAndAllocSpaceForVars(
const std::vector<platform::Place> &places,
const std::vector<Scope *> &local_scopes,
const std::unordered_map<std::string, ir::Node *> &vars,
const std::string &fused_var_name,
const ParamsAndGrads &params_grads) const;
void AppendAllocSpaceForVarsOp(const std::vector<std::string> &params_name,
const std::vector<std::string> &grads_name,
const std::string &fused_var_name,
BlockDesc *global_block) const;
};
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/async_ssa_graph_executor.h"
#include "paddle/fluid/framework/variable_helper.h"
#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/fluid/operators/distributed/communicator.h"
#endif
namespace paddle {
namespace framework {
namespace details {
inline void NewTempScopeAndInitVars(const std::vector<VarInfo> &var_infos,
Scope *scope) {
VLOG(3) << "NewTempScopeAndInitVars";
Scope &local_scope = scope->NewScope();
*scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>() =
&local_scope;
for (auto &info : var_infos) {
if (scope->FindVar(info.name_) != nullptr) {
continue;
}
if (info.persistable_) { // Persistable
InitializeVariable(scope->Var(info.name_), info.type_);
} else {
InitializeVariable(local_scope.Var(info.name_), info.type_);
}
}
}
// get RpcContext and remote send and recv op
void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
#ifdef PADDLE_WITH_DISTRIBUTE
using RpcCtxMap = operators::distributed::RpcCtxMap;
VLOG(3) << "ProcessGraph";
RpcCtxMap send_varname_to_ctx;
RpcCtxMap recv_varname_to_ctx;
for (auto i = 0; i < graphs.size(); ++i) {
std::vector<ir::Node *> nodes_to_delete;
for (auto &node : graphs[i]->Nodes()) {
VLOG(3) << "node name " << node->Name();
if (node && node->IsOp()) {
if (node->Name() == "send") {
auto send_var_name = node->Op()->Input("X")[0];
auto send_varnames = boost::get<std::vector<std::string>>(
node->Op()->GetNullableAttr("send_varnames"));
auto epmap = boost::get<std::vector<std::string>>(
node->Op()->GetNullableAttr("epmap"));
auto height_section = boost::get<std::vector<int64_t>>(
node->Op()->GetNullableAttr("sections"));
send_varname_to_ctx[send_var_name] =
operators::distributed::RpcContext(send_var_name, send_varnames,
epmap, height_section);
VLOG(3) << "find and init an send op: "
<< send_varname_to_ctx[send_var_name];
} else if (node->Name() == "recv") {
auto recv_var_name = node->Op()->Output("Out")[0];
auto recv_varnames = boost::get<std::vector<std::string>>(
node->Op()->GetNullableAttr("recv_varnames"));
auto epmap = boost::get<std::vector<std::string>>(
node->Op()->GetNullableAttr("epmap"));
recv_varname_to_ctx[recv_var_name] =
operators::distributed::RpcContext(recv_var_name, recv_varnames,
epmap, {});
nodes_to_delete.push_back(node);
VLOG(3) << "find and remove an recv op: "
<< recv_varname_to_ctx[recv_var_name];
}
}
}
}
// init communicator here
if (send_varname_to_ctx.size() > 0) {
VLOG(3) << "this is distribute mode, will use communicator";
operators::distributed::Communicator::Init(send_varname_to_ctx,
recv_varname_to_ctx, scope);
operators::distributed::Communicator::GetInstance()->Start();
}
#endif
}
AsyncSSAGraphExecutor::AsyncSSAGraphExecutor(
const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places, std::vector<ir::Graph *> graphs)
: strategy_(std::move(strategy)),
local_scopes_(std::move(local_scopes)),
pool_(places.size() >= 2 ? new ::ThreadPool(places.size()) : nullptr),
places_(std::move(places)),
graphs_(std::move(graphs)) {
VLOG(3) << "build AsyncSSAGraphExecutor";
PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size());
// set the correct size of thread pool to each device.
strategy_.num_threads_ = strategy_.num_threads_ < places_.size()
? 1UL
: strategy_.num_threads_ / places_.size();
VLOG(1) << "set num_threads: " << strategy_.num_threads_
<< " to run the operators of the graph on each device.";
for (size_t i = 0; i < places.size(); ++i) {
executors_.emplace_back(new details::ThreadedSSAGraphExecutor(
strategy_, {local_scopes_[i]}, {places_[i]}, graphs_[i]));
}
for (auto &node : graphs_[0]->Nodes()) {
if (node->IsVar() && !node->IsCtrlVar() && node->Var()) {
var_infos_.emplace_back();
var_infos_.back().name_ = node->Var()->Name();
var_infos_.back().type_ = node->Var()->GetType();
var_infos_.back().persistable_ = node->Var()->Persistable();
}
}
for (auto *scope : local_scopes_) {
NewTempScopeAndInitVars(var_infos_, scope);
}
ProcessGraph(graphs_, local_scopes_[0]);
}
void AsyncSSAGraphExecutor::StartOffPythonTrainLoop() {
VLOG(3) << "StartOffPythonTrainLoop size = " << places_.size();
for (size_t i = 1; i < places_.size(); ++i) {
auto call = [this, i]() -> void {
VLOG(3) << "start off python thread " << i;
try {
while (true) {
executors_[i]->Run({});
}
} catch (...) {
exception_holder_.Catch(std::current_exception());
VLOG(3) << "get exception type = " << exception_holder_.Type();
}
VLOG(3) << "thread " << i << " exited!";
};
run_futures_.emplace_back(pool_->enqueue(std::move(call)));
}
}
void AsyncSSAGraphExecutor::HandleException() {
if (exception_holder_.IsCaught()) {
for (auto &f : run_futures_) {
VLOG(3) << "wait future";
f.wait();
}
VLOG(3) << "caught exception " << exception_holder_.Type()
<< ", rethrow it";
run_futures_.clear();
exception_holder_.ReThrow();
}
}
FeedFetchList AsyncSSAGraphExecutor::Run(
const std::vector<std::string> &fetch_tensors) {
// init once
if (run_futures_.size() == 0 && places_.size() > 1) {
exception_holder_.Clear();
StartOffPythonTrainLoop();
}
if (places_.size() == 1) {
exception_holder_.Clear();
} else {
HandleException();
}
FeedFetchList fetch_data;
fetch_data.reserve(fetch_tensors.size());
try {
fetch_data = executors_[0]->Run(fetch_tensors);
} catch (...) {
exception_holder_.Catch(std::current_exception());
}
HandleException();
FeedFetchList ret;
for (size_t fetch_idx = 0; fetch_idx < fetch_tensors.size(); ++fetch_idx) {
std::vector<const LoDTensor *> lodtensor_ptrs;
lodtensor_ptrs.push_back(&fetch_data.at(fetch_idx));
ret.emplace_back();
ret.back().MergeLoDTensor(lodtensor_ptrs, platform::CPUPlace());
}
return ret;
}
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "ThreadPool.h"
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
namespace paddle {
namespace framework {
namespace details {
struct VarInfo {
std::string name_;
proto::VarType::Type type_;
bool persistable_;
};
class AsyncSSAGraphExecutor : public SSAGraphExecutor {
public:
AsyncSSAGraphExecutor(const ExecutionStrategy &strategy,
const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
std::vector<ir::Graph *> graphs);
~AsyncSSAGraphExecutor() final = default;
const ir::Graph &Graph() const override { return *graphs_[0]; }
FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override;
private:
void StartOffPythonTrainLoop();
void HandleException();
private:
ExecutionStrategy strategy_;
std::vector<Scope *> local_scopes_;
std::unique_ptr<::ThreadPool> pool_{nullptr};
std::vector<platform::Place> places_;
std::vector<ir::Graph *> graphs_;
std::vector<std::unique_ptr<details::ThreadedSSAGraphExecutor>> executors_;
ExceptionHolder exception_holder_;
std::vector<std::future<void>> run_futures_;
std::vector<VarInfo> var_infos_;
};
} // namespace details
} // namespace framework
} // namespace paddle
...@@ -142,6 +142,19 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { ...@@ -142,6 +142,19 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
AppendPass("memory_optimize_pass"); AppendPass("memory_optimize_pass");
} }
// runtime_context_cache pass should be the last pass to enable the attr of
// all original and fused operators. But no operators can be enabled this
// attr if putting it after MultiDevPass.
if (strategy_.cache_runtime_context_) {
VLOG(10) << "Add runtime_context_cache_pass";
AppendPass("runtime_context_cache_pass");
}
if (strategy_.cache_expected_kernel_) {
VLOG(10) << "Add expected_kernel_cache_pass";
AppendPass("expected_kernel_cache_pass");
}
AppendMultiDevPass(strategy_); AppendMultiDevPass(strategy_);
if (strategy_.fuse_all_reduce_ops_) { if (strategy_.fuse_all_reduce_ops_) {
...@@ -163,14 +176,11 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { ...@@ -163,14 +176,11 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
"graph_printer", new details::GraphvizSSAGraphPrinter); "graph_printer", new details::GraphvizSSAGraphPrinter);
} }
// Verify that the graph is correct for multi-device executor. // experimental shows that the program will be faster if append
AppendPass("multi_devices_check_pass"); // all_reduce_deps_pass here.
if (!strategy_.enable_parallel_graph_ &&
if (VLOG_IS_ON(2)) { (SeqOnlyAllReduceOps(strategy_) ||
AppendPass("all_reduce_deps_pass"); strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce)) {
}
if (SeqOnlyAllReduceOps(strategy_)) {
VLOG(10) << "Add all_reduce_deps_pass"; VLOG(10) << "Add all_reduce_deps_pass";
AppendPass("all_reduce_deps_pass"); AppendPass("all_reduce_deps_pass");
} }
...@@ -179,13 +189,20 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { ...@@ -179,13 +189,20 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
VLOG(10) << "Add modify_op_lock_and_record_event_pass"; VLOG(10) << "Add modify_op_lock_and_record_event_pass";
AppendPass("modify_op_lock_and_record_event_pass"); AppendPass("modify_op_lock_and_record_event_pass");
} }
// Verify that the graph is correct for multi-device executor.
AppendPass("multi_devices_check_pass");
} }
// Convert graph to run on multi-devices. // Convert graph to run on multi-devices.
void AppendMultiDevPass(const BuildStrategy &strategy) { void AppendMultiDevPass(const BuildStrategy &strategy) {
ir::Pass *multi_devices_pass = nullptr; ir::Pass *multi_devices_pass = nullptr;
if (strategy.is_distribution_) {
VLOG(10) << "Add dist_multi_devices_pass"; if (strategy_.async_mode_) {
multi_devices_pass = AppendPass("async_multi_devices_pass").get();
} else if (strategy_.is_distribution_) {
VLOG(10)
<< "Add dist_multi_devices_pass, multi device parameter server mode";
multi_devices_pass = AppendPass("dist_multi_devices_pass").get(); multi_devices_pass = AppendPass("dist_multi_devices_pass").get();
} else { } else {
if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) { if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
...@@ -234,10 +251,12 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph, ...@@ -234,10 +251,12 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
#else #else
const bool use_cuda) const { const bool use_cuda) const {
#endif #endif
VLOG(3) << "apply all passes";
// Create a default one if not finalized by user. // Create a default one if not finalized by user.
CreatePassesFromStrategy(false); CreatePassesFromStrategy(false);
for (std::shared_ptr<ir::Pass> &pass : pass_builder_->AllPasses()) { for (std::shared_ptr<ir::Pass> &pass : pass_builder_->AllPasses()) {
VLOG(3) << "BuildStrategy::Apply pass:" << pass->Type();
if (IsMultiDevPass(pass->Type())) { if (IsMultiDevPass(pass->Type())) {
pass->Erase(kPlaces); pass->Erase(kPlaces);
pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places); pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
...@@ -293,6 +312,7 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph, ...@@ -293,6 +312,7 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
graph = pass->Apply(graph); graph = pass->Apply(graph);
VLOG(3) << "Finish Apply Pass " << pass->Type(); VLOG(3) << "Finish Apply Pass " << pass->Type();
} }
VLOG(3) << "All Passes Applied";
return graph; return graph;
} }
...@@ -321,3 +341,5 @@ USE_PASS(graph_to_program_pass); ...@@ -321,3 +341,5 @@ USE_PASS(graph_to_program_pass);
USE_PASS(fuse_adam_op_pass); USE_PASS(fuse_adam_op_pass);
USE_PASS(fuse_sgd_op_pass); USE_PASS(fuse_sgd_op_pass);
USE_PASS(fuse_all_reduce_op_pass); USE_PASS(fuse_all_reduce_op_pass);
USE_PASS(runtime_context_cache_pass);
USE_PASS(expected_kernel_cache_pass);
...@@ -83,25 +83,33 @@ struct BuildStrategy { ...@@ -83,25 +83,33 @@ struct BuildStrategy {
bool sync_batch_norm_{false}; bool sync_batch_norm_{false};
bool memory_optimize_{true}; // FIXME(liuwei1031) disable memory_optimzie and enable_inplace in 1.4
// TODO(dzhwinter): // to open them by default, we need to solve the fetch variable issue
// make enable_inplace, memory_optimize_ bool memory_optimize_{false};
// memory_early_delete_ true by default
bool enable_inplace_{true}; bool enable_inplace_{false};
bool enable_sequential_execution_{false}; bool enable_sequential_execution_{false};
bool fuse_broadcast_op_{false}; // NOTE(zcd): In reduce mode, fusing broadcast ops may make the program
// faster. Because fusing broadcast OP equals delaying the execution of all
// broadcast Ops, in this case, all nccl streams are used only for reduce
// operations for a period of time.
bool fuse_broadcast_ops_{false};
// FIXME(zcd): is_distribution_ is a temporary field, because in pserver mode, // FIXME(zcd): is_distribution_ is a temporary field, because in pserver mode,
// num_trainers is 1, so the current fields of build_strategy doesn't tell if // num_trainers is 1, so the current fields of build_strategy doesn't tell if
// it's distributed model. // it's distributed model.
bool is_distribution_{false}; bool is_distribution_{false};
bool async_mode_{false};
int num_trainers_{1}; int num_trainers_{1};
int trainer_id_{0}; int trainer_id_{0};
std::vector<std::string> trainers_endpoints_; std::vector<std::string> trainers_endpoints_;
bool remove_unnecessary_lock_{true}; bool remove_unnecessary_lock_{true};
bool cache_runtime_context_{false};
bool cache_expected_kernel_{true};
// NOTE: // NOTE:
// Before you add new options, think if it's a general strategy that works // Before you add new options, think if it's a general strategy that works
// with other strategy. If not, the strategy should be created through // with other strategy. If not, the strategy should be created through
......
...@@ -14,6 +14,9 @@ ...@@ -14,6 +14,9 @@
#pragma once #pragma once
#include <memory>
#include <string>
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
...@@ -64,6 +67,21 @@ class ExceptionHolder { ...@@ -64,6 +67,21 @@ class ExceptionHolder {
ClearImpl(); ClearImpl();
} }
std::string Type() {
std::lock_guard<std::mutex> lock(mu_);
switch (type_) {
case kNone:
return "None";
case kEnforceNotMet: {
return "EnforceNotMet";
}
case kEOF: {
return "EOF";
}
}
return "unknown";
}
private: private:
void ClearImpl() { void ClearImpl() {
exception_.reset(); exception_.reset();
......
...@@ -31,6 +31,8 @@ struct ExecutionStrategy { ...@@ -31,6 +31,8 @@ struct ExecutionStrategy {
size_t num_iteration_per_drop_scope_{1}; size_t num_iteration_per_drop_scope_{1};
ExecutorType type_{kDefault}; ExecutorType type_{kDefault};
bool dry_run_{false}; bool dry_run_{false};
size_t num_iteration_per_run_{1}; // only use with async_ssa_graph_executor
// and pyreader with data queue
}; };
} // namespace details } // namespace details
......
...@@ -56,6 +56,7 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run( ...@@ -56,6 +56,7 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
fetches.resize(fetch_tensors.size()); fetches.resize(fetch_tensors.size());
std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars; std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
std::vector<FetchOpHandle *> fetch_ops; std::vector<FetchOpHandle *> fetch_ops;
std::vector<OpHandleBase *> ready_fetch_ops;
for (auto &fetch_var_name : fetch_tensors) { for (auto &fetch_var_name : fetch_tensors) {
for (auto &var_map : graph_->Get<details::GraphVars>(details::kGraphVars)) { for (auto &var_map : graph_->Get<details::GraphVars>(details::kGraphVars)) {
...@@ -70,8 +71,9 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run( ...@@ -70,8 +71,9 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
auto &var_name = fetch_tensors[i]; auto &var_name = fetch_tensors[i];
auto fetched_var_it = fetched_vars.find(var_name); auto fetched_var_it = fetched_vars.find(var_name);
PADDLE_ENFORCE(fetched_var_it != fetched_vars.end(), PADDLE_ENFORCE(fetched_var_it != fetched_vars.end(),
"Cannot find fetched variable.(Perhaps the main_program " "Cannot find fetched variable(%s).(Perhaps the main_program "
"is not set to ParallelExecutor)"); "is not set to ParallelExecutor)",
var_name);
auto &vars = fetched_var_it->second; auto &vars = fetched_var_it->second;
...@@ -88,7 +90,11 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run( ...@@ -88,7 +90,11 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
op->AddInput(var); op->AddInput(var);
} }
(*op_deps)[op] = static_cast<int>(op->NotReadyInputSize()); int dep = static_cast<int>(op->NotReadyInputSize());
(*op_deps)[op] = dep;
if (dep == 0) {
ready_fetch_ops.emplace_back(op);
}
} }
size_t num_complete = 0; size_t num_complete = 0;
...@@ -97,7 +103,9 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run( ...@@ -97,7 +103,9 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
for (auto op : bootstrap_ops_) { for (auto op : bootstrap_ops_) {
RunOpAsync(op_deps.get(), op, complete_q); RunOpAsync(op_deps.get(), op, complete_q);
} }
for (auto op : ready_fetch_ops) {
RunOpAsync(op_deps.get(), op, complete_q);
}
while (num_complete != op_deps->size()) { while (num_complete != op_deps->size()) {
size_t num_comp = complete_q->Pop(); size_t num_comp = complete_q->Pop();
if (num_comp == -1UL) { if (num_comp == -1UL) {
......
...@@ -13,9 +13,9 @@ ...@@ -13,9 +13,9 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/details/fetch_op_handle.h" #include "paddle/fluid/framework/details/fetch_op_handle.h"
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -44,6 +44,7 @@ void FetchOpHandle::WaitAndMergeCPUTensors() const { ...@@ -44,6 +44,7 @@ void FetchOpHandle::WaitAndMergeCPUTensors() const {
} }
void FetchOpHandle::RunImpl() { void FetchOpHandle::RunImpl() {
platform::RecordEvent record_event(Name());
WaitInputVarGenerated(platform::CPUPlace()); WaitInputVarGenerated(platform::CPUPlace());
tensors_.resize(inputs_.size()); tensors_.resize(inputs_.size());
...@@ -62,7 +63,8 @@ void FetchOpHandle::RunImpl() { ...@@ -62,7 +63,8 @@ void FetchOpHandle::RunImpl() {
auto &t = var->Get<framework::LoDTensor>(); auto &t = var->Get<framework::LoDTensor>();
if (platform::is_gpu_place(t.place())) { if (platform::is_gpu_place(t.place())) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
TensorCopySync(t, cpu, &tensors_[i]); TensorCopy(t, cpu, *dev_ctxes_.at(t.place()), &tensors_[i]);
dev_ctxes_.at(t.place())->Wait();
#endif #endif
} else { } else {
tensors_[i].ShareDataWith(t); tensors_[i].ShareDataWith(t);
......
...@@ -305,6 +305,12 @@ void InplacePass::TryInplaceOpInputOutput(ir::Node* op, ...@@ -305,6 +305,12 @@ void InplacePass::TryInplaceOpInputOutput(ir::Node* op,
VLOG(4) << "Try to inplace " << in_var_name << " with " << out_var_name; VLOG(4) << "Try to inplace " << in_var_name << " with " << out_var_name;
if (var_nodes_[in_var_name].back() != in_node) {
VLOG(4) << "SKIP since " << in_var_name
<< " is also used as output by other ops";
continue;
}
bool can_replace = true; bool can_replace = true;
if (in_var_name == out_var_name) { if (in_var_name == out_var_name) {
can_replace = false; can_replace = false;
...@@ -527,6 +533,9 @@ void GraphView::Build(ir::Graph* g) { ...@@ -527,6 +533,9 @@ void GraphView::Build(ir::Graph* g) {
}; };
for (auto& node : g->Nodes()) { for (auto& node : g->Nodes()) {
if (!node->IsOp()) continue; if (!node->IsOp()) continue;
// avoid optimize the variable used in sub-blocks
if (OpHasSubBlock(node->Op())) update_skip_set(node);
if (node->Name() == "send") update_skip_set(node); if (node->Name() == "send") update_skip_set(node);
if (node->Name() == "recv") update_skip_set(node); if (node->Name() == "recv") update_skip_set(node);
if (node->Name() == "prefetch") update_skip_set(node); if (node->Name() == "prefetch") update_skip_set(node);
......
...@@ -131,16 +131,7 @@ size_t NodeSize(const VarDesc& node) { ...@@ -131,16 +131,7 @@ size_t NodeSize(const VarDesc& node) {
return type_size * std::abs(size); return type_size * std::abs(size);
} }
size_t NodeSize(ir::Node* n) { size_t NodeSize(ir::Node* n) { return NodeSize(*(n->Var())); }
VarDesc* desc = nullptr;
// some op do not have block pointer
if (n->inputs[0]->Op() != nullptr) {
desc = FindVarDescInBlock(n);
} else {
desc = n->Var();
}
return NodeSize(*desc);
}
std::string DebugStringImpl(VarDesc* var) { std::string DebugStringImpl(VarDesc* var) {
std::stringstream ss; std::stringstream ss;
...@@ -163,24 +154,22 @@ std::string DebugStringImpl(VarDesc* var) { ...@@ -163,24 +154,22 @@ std::string DebugStringImpl(VarDesc* var) {
} }
std::string DebugString(ir::Node* var) { std::string DebugString(ir::Node* var) {
return DebugStringImpl(FindVarDescInBlock(var)); return DebugStringImpl(GetVarDesc(var));
} }
// NOTE(dzh): based ir node, if a large node has been reused // NOTE(dzh): based ir node, if a large node has been reused
// by a small size node, then next time it appear in pool, it will // by a small size node, then next time it appear in pool, it will
// have the small size. Find the original node shap from blockdesc. // have the small size. Find the original node shap from blockdesc.
VarDesc* FindVarDescInBlock(ir::Node* n) { VarDesc* GetVarDesc(ir::Node* n) {
PADDLE_ENFORCE(n->IsVar() && !n->IsCtrlVar() && n->inputs.size() == 1); PADDLE_ENFORCE(n->IsVar() && !n->IsCtrlVar() && n->inputs.size() == 1);
BlockDesc* block = n->inputs[0]->Op()->Block(); return n->Var();
PADDLE_ENFORCE(block->HasVar(n->Name()),
string::Sprintf("Block do not has var %s", n->Name()));
return block->FindVar(n->Name());
} }
struct NodeComparator { struct NodeComparator {
bool operator()(ir::Node* lhs, ir::Node* rhs) const { bool operator()(ir::Node* lhs, ir::Node* rhs) const {
auto* lhs_desc = FindVarDescInBlock(lhs); if (lhs->Var()->GetType() != rhs->Var()->GetType()) return false;
auto* rhs_desc = FindVarDescInBlock(rhs); auto* lhs_desc = GetVarDesc(lhs);
auto* rhs_desc = GetVarDesc(rhs);
// match data type // match data type
if (lhs_desc->GetDataType() != rhs_desc->GetDataType()) { if (lhs_desc->GetDataType() != rhs_desc->GetDataType()) {
return false; return false;
...@@ -204,7 +193,7 @@ void OrderedSet::Insert(ir::Node* var) { ...@@ -204,7 +193,7 @@ void OrderedSet::Insert(ir::Node* var) {
return; return;
} }
auto* var_desc = FindVarDescInBlock(var); auto* var_desc = var->Var();
auto var_shape = var_desc->GetShape(); auto var_shape = var_desc->GetShape();
int batch_size = static_cast<int>(var_shape[0]); int batch_size = static_cast<int>(var_shape[0]);
...@@ -212,7 +201,7 @@ void OrderedSet::Insert(ir::Node* var) { ...@@ -212,7 +201,7 @@ void OrderedSet::Insert(ir::Node* var) {
Iter it = nodes_.begin(); Iter it = nodes_.begin();
while (it != nodes_.end()) { while (it != nodes_.end()) {
auto& prev = it->front(); auto& prev = it->front();
auto* cache_desc = FindVarDescInBlock(prev); auto* cache_desc = GetVarDesc(prev);
int cache_batch_size = cache_desc->GetShape()[0]; int cache_batch_size = cache_desc->GetShape()[0];
if ((cache_batch_size == -1 && batch_size == -1) || if ((cache_batch_size == -1 && batch_size == -1) ||
(cache_batch_size != -1 && batch_size != -1)) { (cache_batch_size != -1 && batch_size != -1)) {
...@@ -336,10 +325,16 @@ int MinChunkSize() { ...@@ -336,10 +325,16 @@ int MinChunkSize() {
bool NodeCanReused(const VarDesc& node) { bool NodeCanReused(const VarDesc& node) {
auto type = node.GetType(); auto type = node.GetType();
// only these types holds bulk of gpu memory // only these types holds bulk of gpu memory
if (!(type == proto::VarType::LOD_TENSOR || // FIXME(liuwei1031) did not find good ways to test SELECTED_ROWS and
type == proto::VarType::LOD_TENSOR_ARRAY)) { // LOD_TENSOR_ARRAY re-use logic,
return false; // disable them in version 1.4
} // if (!(type == proto::VarType::LOD_TENSOR ||
// type == proto::VarType::SELECTED_ROWS ||
// type == proto::VarType::LOD_TENSOR_ARRAY)) {
// return false;
// }
if (type != proto::VarType::LOD_TENSOR) return false;
// persistable variable is parameter // persistable variable is parameter
if (node.Persistable()) { if (node.Persistable()) {
return false; return false;
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include <map> #include <map>
#include <set> #include <set>
#include <string> #include <string>
#include <unordered_map>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type.h"
...@@ -140,11 +141,7 @@ size_t NodeSize(const VarDesc&); ...@@ -140,11 +141,7 @@ size_t NodeSize(const VarDesc&);
std::string DebugString(ir::Node* var); std::string DebugString(ir::Node* var);
// NOTE(dzhwinter) VarDesc* GetVarDesc(ir::Node* n);
// after node reuse, the replaced node shape is
// different with its VarDesc. So need to find the
// correct VarDesc in Block.
VarDesc* FindVarDescInBlock(ir::Node* n);
static inline bool IsSameDesc(OpDesc* op1, OpDesc* op2) { static inline bool IsSameDesc(OpDesc* op1, OpDesc* op2) {
return op1->Type() == op2->Type() && op1->Inputs() == op2->Inputs() && return op1->Type() == op2->Type() && op1->Inputs() == op2->Inputs() &&
......
...@@ -198,8 +198,22 @@ void MultiDevSSAGraphBuilderBase::ApplyImpl(ir::Graph *graph) const { ...@@ -198,8 +198,22 @@ void MultiDevSSAGraphBuilderBase::ApplyImpl(ir::Graph *graph) const {
static_cast<bool>(boost::get<int>(node->Op()->GetAttr( static_cast<bool>(boost::get<int>(node->Op()->GetAttr(
OpProtoAndCheckerMaker::OpRoleAttrName())) & OpProtoAndCheckerMaker::OpRoleAttrName())) &
static_cast<int>(OpRole::kBackward)); static_cast<int>(OpRole::kBackward));
// optimize op is already processed in DealWithSpecialOp,
// here we only consider backward op
if (!is_bk_op) continue; if (!is_bk_op) continue;
/*
* the op that will generate the gradient of on parameter will have
one attr op_role_var
* to record the parameter and gradient, like:
attrs {
name: "op_role_var"
type: STRINGS
strings: "fc_1.b_0"
strings: "fc_1.b_0@GRAD"
}
*/
// Currently, we assume that once gradient is generated, it can be // Currently, we assume that once gradient is generated, it can be
// broadcast, and each gradient is only broadcast once. // broadcast, and each gradient is only broadcast once.
auto backward_vars = auto backward_vars =
...@@ -256,6 +270,8 @@ void MultiDevSSAGraphBuilderBase::InsertScaleLossGradOp( ...@@ -256,6 +270,8 @@ void MultiDevSSAGraphBuilderBase::InsertScaleLossGradOp(
break; break;
} }
VLOG(3) << "loss_scale: " << loss_scale;
if (loss_scale) { if (loss_scale) {
// TODO(paddle-dev): Why is there no input for this op_handle? // TODO(paddle-dev): Why is there no input for this op_handle?
auto loss_grad_name = node->Op()->OutputArgumentNames()[0]; auto loss_grad_name = node->Op()->OutputArgumentNames()[0];
...@@ -407,7 +423,7 @@ void MultiDevSSAGraphBuilderBase::CreateFusedBroadcastOp( ...@@ -407,7 +423,7 @@ void MultiDevSSAGraphBuilderBase::CreateFusedBroadcastOp(
void MultiDevSSAGraphBuilderBase::CreateComputationalOp(ir::Graph *result, void MultiDevSSAGraphBuilderBase::CreateComputationalOp(ir::Graph *result,
ir::Node *node, ir::Node *node,
int dev_id) const { size_t dev_id) const {
result->Get<GraphOps>(kGraphOps).emplace_back( result->Get<GraphOps>(kGraphOps).emplace_back(
new ComputationOpHandle(result->CreateOpNode(node->Op()), new ComputationOpHandle(result->CreateOpNode(node->Op()),
local_scopes_[dev_id], places_[dev_id], dev_id)); local_scopes_[dev_id], places_[dev_id], dev_id));
...@@ -494,9 +510,8 @@ void MultiDevSSAGraphBuilderBase::CreateComputationalOps( ...@@ -494,9 +510,8 @@ void MultiDevSSAGraphBuilderBase::CreateComputationalOps(
} }
} }
VarHandle *MultiDevSSAGraphBuilderBase::CreateReduceOp(ir::Graph *result, VarHandle *MultiDevSSAGraphBuilderBase::CreateReduceOp(
const std::string &og, ir::Graph *result, const std::string &og, size_t dst_dev_id) const {
int dst_dev_id) const {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
result->Get<GraphOps>(kGraphOps).emplace_back(new ReduceOpHandle( result->Get<GraphOps>(kGraphOps).emplace_back(new ReduceOpHandle(
result->CreateEmptyNode("reduce", ir::Node::Type::kOperation), result->CreateEmptyNode("reduce", ir::Node::Type::kOperation),
...@@ -643,7 +658,7 @@ bool ReduceSSAGraphBuilder::DealWithSpecialOp(ir::Graph *result, ...@@ -643,7 +658,7 @@ bool ReduceSSAGraphBuilder::DealWithSpecialOp(ir::Graph *result,
void ReduceSSAGraphBuilder::InsertPostprocessOps(ir::Graph *result) const { void ReduceSSAGraphBuilder::InsertPostprocessOps(ir::Graph *result) const {
if (UseGPU()) { if (UseGPU()) {
if (strategy_.fuse_broadcast_op_) { if (strategy_.fuse_broadcast_ops_) {
CreateFusedBroadcastOp(result, bcast_var_name_set_); CreateFusedBroadcastOp(result, bcast_var_name_set_);
} else { } else {
for (size_t dev_id = 0; dev_id < bcast_var_name_set_.size(); ++dev_id) { for (size_t dev_id = 0; dev_id < bcast_var_name_set_.size(); ++dev_id) {
...@@ -774,6 +789,8 @@ bool DistSSAGraphBuilder::DealWithSpecialOp(ir::Graph *result, ...@@ -774,6 +789,8 @@ bool DistSSAGraphBuilder::DealWithSpecialOp(ir::Graph *result,
} else if (OpHaveRole(*node, OpRole::kDist)) { } else if (OpHaveRole(*node, OpRole::kDist)) {
int op_dev_id = CreateDistTrainOp(result, node); int op_dev_id = CreateDistTrainOp(result, node);
if (node->Op()->Type() == "concat") { if (node->Op()->Type() == "concat") {
// the input(block of parameter) of concat is on different device,
// the output(parameter) will on one device.
auto origin_param_name = node->Op()->OutputArgumentNames()[0]; auto origin_param_name = node->Op()->OutputArgumentNames()[0];
bcast_var_name_set_[op_dev_id].emplace(origin_param_name); bcast_var_name_set_[op_dev_id].emplace(origin_param_name);
} }
...@@ -781,6 +798,7 @@ bool DistSSAGraphBuilder::DealWithSpecialOp(ir::Graph *result, ...@@ -781,6 +798,7 @@ bool DistSSAGraphBuilder::DealWithSpecialOp(ir::Graph *result,
} else { } else {
int op_dev_id = GetOpDeviceID(node); int op_dev_id = GetOpDeviceID(node);
if (op_dev_id != -1) { // This op only runs on one specific device. if (op_dev_id != -1) { // This op only runs on one specific device.
// optimize op will be processed here.
CreateComputationalOp(result, node, op_dev_id); CreateComputationalOp(result, node, op_dev_id);
for (ir::Node *n : node->outputs) { for (ir::Node *n : node->outputs) {
sharded_var_device_.emplace(n->Name(), op_dev_id); sharded_var_device_.emplace(n->Name(), op_dev_id);
...@@ -961,6 +979,7 @@ bool DistSSAGraphBuilder::IsEncoded(const std::string &p_name) const { ...@@ -961,6 +979,7 @@ bool DistSSAGraphBuilder::IsEncoded(const std::string &p_name) const {
void DistSSAGraphBuilder::InsertCollectiveOp(ir::Graph *result, void DistSSAGraphBuilder::InsertCollectiveOp(ir::Graph *result,
const std::string &p_name, const std::string &p_name,
const std::string &g_name) const { const std::string &g_name) const {
// collective gradient to each device
size_t cur_device_id = 0; size_t cur_device_id = 0;
switch (strategy_.reduce_) { switch (strategy_.reduce_) {
case BuildStrategy::ReduceStrategy::kReduce: case BuildStrategy::ReduceStrategy::kReduce:
...@@ -1002,7 +1021,7 @@ void DistSSAGraphBuilder::InsertPostprocessOps(ir::Graph *result) const { ...@@ -1002,7 +1021,7 @@ void DistSSAGraphBuilder::InsertPostprocessOps(ir::Graph *result) const {
strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce) { strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce) {
return; return;
} }
if (strategy_.fuse_broadcast_op_) { if (strategy_.fuse_broadcast_ops_) {
CreateFusedBroadcastOp(result, bcast_var_name_set_); CreateFusedBroadcastOp(result, bcast_var_name_set_);
} else { } else {
for (size_t dev_id = 0; dev_id < bcast_var_name_set_.size(); ++dev_id) { for (size_t dev_id = 0; dev_id < bcast_var_name_set_.size(); ++dev_id) {
...@@ -1049,3 +1068,5 @@ REGISTER_MULTI_DEVICES_PASS( ...@@ -1049,3 +1068,5 @@ REGISTER_MULTI_DEVICES_PASS(
paddle::framework::details::AllReduceSSAGraphBuilder); paddle::framework::details::AllReduceSSAGraphBuilder);
REGISTER_MULTI_DEVICES_PASS(dist_multi_devices_pass, REGISTER_MULTI_DEVICES_PASS(dist_multi_devices_pass,
paddle::framework::details::DistSSAGraphBuilder); paddle::framework::details::DistSSAGraphBuilder);
REGISTER_MULTI_DEVICES_PASS(async_multi_devices_pass,
paddle::framework::details::AsyncSSAGraphBuilder);
...@@ -56,8 +56,8 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass { ...@@ -56,8 +56,8 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
bool UseGPU() const; bool UseGPU() const;
bool NeedCollectiveForGrad(const std::string &grad_name, virtual bool NeedCollectiveForGrad(const std::string &grad_name,
std::vector<ir::Node *> ops) const; std::vector<ir::Node *> ops) const;
bool IsScaleLossOp(ir::Node *node) const; bool IsScaleLossOp(ir::Node *node) const;
...@@ -70,10 +70,10 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass { ...@@ -70,10 +70,10 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
proto::VarType::Type dtype) const; proto::VarType::Type dtype) const;
VarHandle *CreateReduceOp(ir::Graph *result, const std::string &og, VarHandle *CreateReduceOp(ir::Graph *result, const std::string &og,
int dst_dev_id) const; size_t dst_dev_id) const;
void CreateComputationalOp(ir::Graph *result, ir::Node *node, void CreateComputationalOp(ir::Graph *result, ir::Node *node,
int dev_id) const; size_t dev_id) const;
bool IsSparseGradient(const std::string &og) const; bool IsSparseGradient(const std::string &og) const;
...@@ -115,6 +115,35 @@ class AllReduceSSAGraphBuilder : public MultiDevSSAGraphBuilderBase { ...@@ -115,6 +115,35 @@ class AllReduceSSAGraphBuilder : public MultiDevSSAGraphBuilderBase {
virtual void InsertPostprocessOps(ir::Graph *result) const {} virtual void InsertPostprocessOps(ir::Graph *result) const {}
}; };
class AsyncSSAGraphBuilder : public MultiDevSSAGraphBuilderBase {
protected:
void InsertCollectiveOp(ir::Graph *result, const std::string &p_name,
const std::string &g_name) const override {}
bool NeedCollectiveForGrad(const std::string &grad_name,
std::vector<ir::Node *> ops) const {
return false;
}
bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const override {
if (node->Op()->Type() == "recv") {
VLOG(1) << "set recv op do_not_run to true";
node->Op()->SetAttr("do_not_run", true);
node->Op()->Flush();
} else if (node->Name() == "lookup_table" || node->Name() == "nce" ||
node->Name() == "hierarchical_sigmoid") {
// in async_mode, we do not need remote prefetch, because communicator
// will do async parameter recv.
VLOG(1) << "set " << node->Name() << " op remote_prefetch to false";
node->Op()->SetAttr("remote_prefetch", false);
node->Op()->Flush();
}
return false;
}
void InsertPostprocessOps(ir::Graph *result) const override {}
};
class BalanceVarSSAGraphBuilder : public MultiDevSSAGraphBuilderBase { class BalanceVarSSAGraphBuilder : public MultiDevSSAGraphBuilderBase {
protected: protected:
int GetVarDeviceID(const std::string &varname) const; int GetVarDeviceID(const std::string &varname) const;
......
...@@ -68,7 +68,7 @@ void OpHandleBase::Run(bool use_cuda) { ...@@ -68,7 +68,7 @@ void OpHandleBase::Run(bool use_cuda) {
if (out_var_handle) { if (out_var_handle) {
PADDLE_ENFORCE( PADDLE_ENFORCE(
platform::is_same_place(place, out_var_handle->place()), platform::is_same_place(place, out_var_handle->place()),
"The place of input(%s) is not consistent with the " "The place of output(%s) is not consistent with the "
"place of current op(%s).", "place of current op(%s).",
out_var_handle->Name(), Name()); out_var_handle->Name(), Name());
out_var_handle->SetGenerateEvent(events_.at(dev_id)); out_var_handle->SetGenerateEvent(events_.at(dev_id));
......
...@@ -16,6 +16,7 @@ limitations under the License. */ ...@@ -16,6 +16,7 @@ limitations under the License. */
#include <string> #include <string>
#include <tuple> #include <tuple>
#include <type_traits>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
...@@ -183,6 +184,10 @@ struct OpInfoFiller<T, kGradOpDescMaker> { ...@@ -183,6 +184,10 @@ struct OpInfoFiller<T, kGradOpDescMaker> {
T maker(fwd_op, no_grad_set, grad_to_var, grad_block); T maker(fwd_op, no_grad_set, grad_to_var, grad_block);
return maker(); return maker();
}; };
info->use_default_grad_op_desc_maker_ =
std::is_base_of<DefaultGradOpDescMaker<true>, T>::value ||
std::is_base_of<DefaultGradOpDescMaker<false>, T>::value;
} }
}; };
...@@ -228,6 +233,12 @@ struct OpInfoFiller<T, kNoNeedBufferVarsInference> { ...@@ -228,6 +233,12 @@ struct OpInfoFiller<T, kNoNeedBufferVarsInference> {
} }
}; };
// A fake OpInfoFiller of void
template <>
struct OpInfoFiller<void, kUnknown> {
void operator()(const char* op_type, OpInfo* info) const {}
};
} // namespace details } // namespace details
} // namespace framework } // namespace framework
......
...@@ -31,11 +31,23 @@ ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor( ...@@ -31,11 +31,23 @@ ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
prepare_pool_(1), prepare_pool_(1),
pool_(strategy.num_threads_ >= 2 ? new ::ThreadPool(strategy.num_threads_) pool_(strategy.num_threads_ >= 2 ? new ::ThreadPool(strategy.num_threads_)
: nullptr) { : nullptr) {
if (strategy_.num_iteration_per_run_ > 1) {
int read_op_num = 0;
for (auto *node : graph_->Nodes()) {
if (node->IsOp() && node->Name() == "read") {
read_op_num++;
}
}
if (read_op_num == 0) {
LOG(WARNING) << "when num_iteration_per_run_ is larger then 1, the model "
"should use pyreader to feed data!";
}
}
PrepareOpDeps(); PrepareOpDeps();
CopyOpDeps(); CopyOpDeps();
} }
FeedFetchList ThreadedSSAGraphExecutor::Run( inline FeedFetchList ThreadedSSAGraphExecutor::RunImpl(
const std::vector<std::string> &fetch_tensors) { const std::vector<std::string> &fetch_tensors) {
std::unique_ptr<platform::RecordEvent> event( std::unique_ptr<platform::RecordEvent> event(
new platform::RecordEvent("ThreadedSSAGraphExecutorPrepare")); new platform::RecordEvent("ThreadedSSAGraphExecutorPrepare"));
...@@ -68,7 +80,6 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -68,7 +80,6 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
} }
set.clear(); set.clear();
}; };
auto run_all_op = [&](OpHandleBase *op) { RunOp(ready_vars, op); };
// Clean run context // Clean run context
run_op_futures_.clear(); run_op_futures_.clear();
exception_holder_.Clear(); exception_holder_.Clear();
...@@ -84,6 +95,8 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -84,6 +95,8 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
auto cur_ready_vars = ready_vars->PopAll(1, &timeout); auto cur_ready_vars = ready_vars->PopAll(1, &timeout);
if (timeout) { if (timeout) {
if (exception_holder_.IsCaught()) { if (exception_holder_.IsCaught()) {
VLOG(3) << "caught exception " << exception_holder_.Type()
<< ", rethrow it";
for (auto &run_op_future : run_op_futures_) { for (auto &run_op_future : run_op_futures_) {
run_op_future.wait(); run_op_future.wait();
} }
...@@ -102,7 +115,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -102,7 +115,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
auto &deps = pending_ops[op]; auto &deps = pending_ops[op];
--deps; --deps;
if (deps == 0) { if (deps == 0) {
run_all_op(op); ready_ops.insert(op);
} }
} }
} }
...@@ -114,6 +127,14 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -114,6 +127,14 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
return fetch_data; return fetch_data;
} }
FeedFetchList ThreadedSSAGraphExecutor::Run(
const std::vector<std::string> &fetch_tensors) {
for (size_t j = 0; j < strategy_.num_iteration_per_run_ - 1; ++j) {
RunImpl({});
}
return RunImpl(fetch_tensors);
}
void ThreadedSSAGraphExecutor::InsertFetchOps( void ThreadedSSAGraphExecutor::InsertFetchOps(
const std::vector<std::string> &fetch_tensors, const std::vector<std::string> &fetch_tensors,
std::vector<FetchOpHandle *> *fetch_ops, std::vector<FetchOpHandle *> *fetch_ops,
......
...@@ -23,7 +23,9 @@ ...@@ -23,7 +23,9 @@
#include <unordered_set> #include <unordered_set>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "ThreadPool.h" // ThreadPool in thrird party
#include <ThreadPool.h> // ThreadPool in thrird party
#include "paddle/fluid/framework/blocking_queue.h" #include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/details/exception_holder.h" #include "paddle/fluid/framework/details/exception_holder.h"
#include "paddle/fluid/framework/details/execution_strategy.h" #include "paddle/fluid/framework/details/execution_strategy.h"
...@@ -59,6 +61,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -59,6 +61,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
~ThreadedSSAGraphExecutor() final = default; ~ThreadedSSAGraphExecutor() final = default;
private: private:
inline FeedFetchList RunImpl(const std::vector<std::string> &fetch_tensors);
void RunOp(const std::shared_ptr<BlockingQueue<VarHandleBase *>> &ready_var_q, void RunOp(const std::shared_ptr<BlockingQueue<VarHandleBase *>> &ready_var_q,
details::OpHandleBase *op); details::OpHandleBase *op);
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/device_worker.h"
namespace paddle {
namespace framework {
void DeviceWorker::SetRootScope(Scope* root_scope) { root_scope_ = root_scope; }
void DeviceWorker::SetDataFeed(const std::shared_ptr<DataFeed>& data_feed) {
device_reader_ = data_feed;
}
} // namespace framework
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <fstream>
#include <map>
#include <memory>
#include <mutex> // NOLINT
#include <string>
#include <thread> // NOLINT
#include <vector>
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/trainer_desc.pb.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/operators/reader/blocking_queue.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/port.h"
#include "paddle/fluid/platform/timer.h"
namespace paddle {
namespace framework {
class PullDenseWorker {
public:
virtual ~PullDenseWorker() {}
virtual void Initialize(const TrainerDesc& param);
int Start();
void Stop();
void SetRootScope(Scope* scope) { root_scope_ = scope; }
void IncreaseThreadVersion(int thread_id, uint64_t table_id);
void ResetThreadVersion(uint64_t table_id);
void Wait(std::vector<::std::future<int32_t>>* status_vec);
static std::shared_ptr<PullDenseWorker> GetInstance() {
if (NULL == s_instance_) {
s_instance_.reset(new paddle::framework::PullDenseWorker());
}
return s_instance_;
}
private:
PullDenseWorker() : root_scope_(NULL) {}
void Run();
bool CheckUpdateParam(uint64_t table_id);
private:
static std::shared_ptr<PullDenseWorker> s_instance_;
std::shared_ptr<paddle::framework::FleetWrapper> fleet_ptr_;
PullDenseWorkerParameter param_;
DownpourWorkerParameter dwp_param_;
Scope* root_scope_;
bool running_;
static std::map<uint64_t, uint64_t> last_versions_;
static std::map<uint64_t, uint64_t> current_version_;
static std::mutex mutex_for_version_;
static std::map<uint64_t, std::vector<uint64_t>> training_versions_;
static std::map<uint64_t, std::vector<std::string>> dense_value_names_;
std::thread t_;
int thread_num_;
int sleep_time_ms_;
int threshold_;
std::vector<::std::future<int32_t>> pull_dense_status_;
uint32_t pull_dense_fail_times_ = 0;
std::vector<float> base_norm_param_;
std::vector<float> mean_;
std::vector<float> scale_;
float squared_sum_epsilon_ = 1e-4;
std::mutex mutex_for_mean_scale_;
float total_batch_num_ = 0;
};
// should incorporate different type of device
class DeviceWorker {
public:
DeviceWorker() {}
virtual ~DeviceWorker() {}
virtual void Initialize(const TrainerDesc& desc) = 0;
virtual void SetDeviceIndex(int tid) = 0;
virtual void TrainFiles() = 0;
virtual void PrintFetchVars() = 0;
virtual void TrainFilesWithProfiler() = 0;
virtual void CreateDeviceResource(const ProgramDesc& main_prog) = 0;
// will make this zero copy in the future
virtual void BindingDataFeedMemory() = 0;
virtual void SetRootScope(Scope* root_scope);
virtual void SetDataFeed(const std::shared_ptr<DataFeed>& data_feed);
virtual void SetPlace(const paddle::platform::Place& place) {
place_ = place;
}
protected:
Scope* root_scope_;
paddle::platform::Place place_;
std::shared_ptr<DataFeed> device_reader_;
int64_t batch_num_;
FetchConfig fetch_config_;
};
class CPUWorkerBase : public DeviceWorker {
public:
CPUWorkerBase() {}
virtual ~CPUWorkerBase() {}
virtual void SetDeviceIndex(int tid) { thread_id_ = tid; }
virtual void TrainFiles() = 0;
virtual void TrainFilesWithProfiler() {}
virtual void PrintFetchVars() {}
virtual void CreateDeviceResource(const ProgramDesc& main_prog) {}
protected:
int thread_id_;
};
class HogwildWorker : public CPUWorkerBase {
public:
HogwildWorker() {}
virtual ~HogwildWorker() {}
virtual void Initialize(const TrainerDesc& desc);
virtual void TrainFiles();
virtual void TrainFilesWithProfiler();
virtual void PrintFetchVars();
virtual void CreateDeviceResource(const ProgramDesc& main_prog);
virtual void BindingDataFeedMemory();
protected:
void CreateThreadOperators(const ProgramDesc& program);
void CreateThreadScope(const ProgramDesc& program);
std::vector<std::string> op_names_;
std::vector<OperatorBase*> ops_;
Scope* thread_scope_;
HogwildWorkerParameter param_;
std::vector<std::string> skip_ops_;
};
class DownpourWorker : public HogwildWorker {
public:
DownpourWorker() {}
virtual ~DownpourWorker() {}
virtual void Initialize(const TrainerDesc& desc);
virtual void TrainFiles();
virtual void TrainFilesWithProfiler();
protected:
std::shared_ptr<paddle::framework::FleetWrapper> fleet_ptr_;
std::shared_ptr<paddle::framework::PullDenseWorker> pull_dense_worker_;
void FillSparseValue(size_t table_id);
void PushGradients();
void CollectLabelInfo(size_t table_id);
private:
bool need_to_push_dense_;
bool need_to_push_sparse_;
DownpourWorkerParameter param_;
// just save the value in param_ for easy access
std::map<uint64_t, std::string> label_var_name_;
std::map<uint64_t, std::vector<std::string>> sparse_key_names_;
std::map<uint64_t, std::vector<std::string>> sparse_value_names_;
std::map<uint64_t, std::vector<std::string>> sparse_grad_names_;
std::map<uint64_t, std::vector<std::string>> dense_value_names_;
std::map<uint64_t, std::vector<std::string>> dense_grad_names_;
// feasign
std::map<uint64_t, std::vector<uint64_t>> features_;
// feasign stats
std::map<uint64_t, std::vector<float>> feature_labels_;
// feasign embedding
std::map<uint64_t, std::vector<std::vector<float>>> feature_values_;
// feasign embedding gradient
std::map<uint64_t, std::vector<std::vector<float>>> feature_grads_;
// skipped ops
std::vector<std::string> skip_ops_;
std::shared_ptr<PullDenseWorker> _pull_dense_worker;
std::vector<::std::future<int32_t>> push_sparse_status_;
std::vector<::std::future<int32_t>> push_dense_status_;
};
} // namespace framework
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/device_worker_factory.h"
#include <memory>
#include <string>
#include <unordered_map>
namespace paddle {
namespace framework {
typedef std::shared_ptr<DeviceWorker> (*Createdevice_workerFunction)();
typedef std::unordered_map<std::string, Createdevice_workerFunction>
device_workerMap;
device_workerMap g_device_worker_map;
#define REGISTER_DEVICE_WORKER_CLASS(device_worker_class) \
namespace { \
std::shared_ptr<DeviceWorker> Creator_##device_worker_class() { \
return std::shared_ptr<DeviceWorker>(new device_worker_class); \
} \
class __Registerer_##device_worker_class { \
public: \
__Registerer_##device_worker_class() { \
g_device_worker_map[#device_worker_class] = \
&Creator_##device_worker_class; \
} \
}; \
__Registerer_##device_worker_class g_registerer_##device_worker_class; \
} // namespace
std::string DeviceWorkerFactory::DeviceWorkerTypeList() {
std::string device_worker_types;
for (auto iter = g_device_worker_map.begin();
iter != g_device_worker_map.end(); ++iter) {
if (iter != g_device_worker_map.begin()) {
device_worker_types += ", ";
}
device_worker_types += iter->first;
}
return device_worker_types;
}
std::shared_ptr<DeviceWorker> DeviceWorkerFactory::CreateDeviceWorker(
std::string device_worker_class) {
if (g_device_worker_map.count(device_worker_class) < 1) {
exit(-1);
}
return g_device_worker_map[device_worker_class]();
}
REGISTER_DEVICE_WORKER_CLASS(HogwildWorker);
REGISTER_DEVICE_WORKER_CLASS(DownpourWorker);
} // namespace framework
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/framework/device_worker.h"
namespace paddle {
namespace framework {
class DeviceWorkerFactory {
public:
static std::string DeviceWorkerTypeList();
static std::shared_ptr<DeviceWorker> CreateDeviceWorker(
std::string device_worker_class);
};
} // namespace framework
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/fluid/framework/trainer.h"
namespace paddle {
namespace framework {
TEST() {
// create hogwild device worker
}
}
}
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include <vector>
#include "paddle/fluid/framework/data_feed_factory.h"
#include "paddle/fluid/framework/data_set.h"
#include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/framework/trainer.h"
namespace paddle {
namespace framework {
void DistMultiTrainer::Initialize(const TrainerDesc& trainer_desc,
Dataset* dataset) {
thread_num_ = trainer_desc.thread_num();
SetDataset(dataset);
dataset->CreateReaders();
const std::vector<std::shared_ptr<paddle::framework::DataFeed>> readers =
dataset->GetReaders();
thread_num_ = readers.size();
workers_.resize(thread_num_);
for (int i = 0; i < thread_num_; ++i) {
workers_[i] = DeviceWorkerFactory::CreateDeviceWorker(
trainer_desc.device_worker_name());
workers_[i]->SetDeviceIndex(i);
workers_[i]->SetDataFeed(readers[i]);
workers_[i]->Initialize(trainer_desc);
}
VLOG(3) << "going to initialize pull dense worker";
pull_dense_worker_ = PullDenseWorker::GetInstance();
pull_dense_worker_->Initialize(trainer_desc);
VLOG(3) << "initialize pull dense worker";
SetDebug(trainer_desc.debug());
}
void DistMultiTrainer::InitOtherEnv(const ProgramDesc& main_program) {
pull_dense_worker_->SetRootScope(root_scope_);
pull_dense_worker_->Start();
VLOG(3) << "init other env done.";
}
void DistMultiTrainer::Run() {
for (int thidx = 0; thidx < thread_num_; ++thidx) {
if (!debug_) {
threads_.push_back(
std::thread(&DeviceWorker::TrainFiles, workers_[thidx].get()));
} else {
threads_.push_back(std::thread(&DeviceWorker::TrainFilesWithProfiler,
workers_[thidx].get()));
}
}
}
void DistMultiTrainer::Finalize() {
for (auto& th : threads_) {
th.join();
}
pull_dense_worker_->Stop();
dataset_ptr_->DestroyReaders();
root_scope_->DropKids();
}
} // end namespace framework
} // end namespace paddle
此差异已折叠。
...@@ -18,14 +18,16 @@ limitations under the License. */ ...@@ -18,14 +18,16 @@ limitations under the License. */
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <utility> #include <utility>
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "paddle/fluid/framework/executor_gc_helper.h" #include "google/protobuf/message.h"
#include "google/protobuf/text_format.h"
#include "paddle/fluid/framework/feed_fetch_method.h" #include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/lod_rank_table.h" #include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor_array.h" #include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h" #include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/threadpool.h" #include "paddle/fluid/framework/trainer_desc.pb.h"
#include "paddle/fluid/framework/trainer_factory.h"
#include "paddle/fluid/framework/transfer_scope_cache.h" #include "paddle/fluid/framework/transfer_scope_cache.h"
#include "paddle/fluid/framework/variable_helper.h" #include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/operators/controlflow/while_op_helper.h" #include "paddle/fluid/operators/controlflow/while_op_helper.h"
...@@ -115,6 +117,35 @@ void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope, ...@@ -115,6 +117,35 @@ void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope,
} }
} }
void Executor::RunFromDataset(const ProgramDesc& main_program, Scope* scope,
Dataset* dataset,
const std::string& trainer_desc_str) {
VLOG(3) << "Start to RunFromDataset in executor";
TrainerDesc trainer_desc;
google::protobuf::TextFormat::ParseFromString(trainer_desc_str,
&trainer_desc);
VLOG(3) << "Going to create trainer, trainer class is "
<< trainer_desc.class_name();
std::shared_ptr<TrainerBase> trainer;
trainer = TrainerFactory::CreateTrainer(trainer_desc.class_name());
// initialize trainer
VLOG(3) << "Going to initialize trainer";
trainer->Initialize(trainer_desc, dataset);
VLOG(3) << "Set root scope here";
trainer->SetScope(scope);
// prepare training environment and helper environment
VLOG(3) << "Try to init train environment";
trainer->InitTrainerEnv(main_program, place_);
VLOG(3) << "Try to init other environment";
trainer->InitOtherEnv(main_program);
// training and finalize training
VLOG(3) << "Trainer starts to run";
trainer->Run();
VLOG(3) << "Trainer going to finalize";
trainer->Finalize();
return;
}
void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id, void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
bool create_local_scope, bool create_vars, bool create_local_scope, bool create_vars,
const std::vector<std::string>& skip_ref_cnt_vars, const std::vector<std::string>& skip_ref_cnt_vars,
......
...@@ -19,6 +19,8 @@ limitations under the License. */ ...@@ -19,6 +19,8 @@ limitations under the License. */
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "paddle/fluid/framework/data_set.h"
#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/garbage_collector.h" #include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
...@@ -110,6 +112,9 @@ class Executor { ...@@ -110,6 +112,9 @@ class Executor {
void EnableMKLDNN(const ProgramDesc& program); void EnableMKLDNN(const ProgramDesc& program);
void RunFromDataset(const ProgramDesc& main_program, Scope* scope,
Dataset* dataset, const std::string& trainer_desc_str);
private: private:
const platform::Place place_; const platform::Place place_;
}; };
......
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/framework/executor_thread_worker.h" #include "paddle/fluid/framework/executor_thread_worker.h"
#include <algorithm> #include <algorithm>
#include <utility>
#include "google/protobuf/io/zero_copy_stream_impl.h" #include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/message.h" #include "google/protobuf/message.h"
#include "google/protobuf/text_format.h" #include "google/protobuf/text_format.h"
...@@ -244,6 +245,7 @@ void ExecutorThreadWorker::TrainFilesWithTimer() { ...@@ -244,6 +245,7 @@ void ExecutorThreadWorker::TrainFilesWithTimer() {
platform::SetNumThreads(1); platform::SetNumThreads(1);
SetDevice(); SetDevice();
thread_reader_->Start(); thread_reader_->Start();
std::vector<double> op_total_time; std::vector<double> op_total_time;
std::vector<std::string> op_name; std::vector<std::string> op_name;
for (auto& op : ops_) { for (auto& op : ops_) {
...@@ -273,7 +275,7 @@ void ExecutorThreadWorker::TrainFilesWithTimer() { ...@@ -273,7 +275,7 @@ void ExecutorThreadWorker::TrainFilesWithTimer() {
++batch_cnt; ++batch_cnt;
thread_scope_->DropKids(); thread_scope_->DropKids();
if (thread_id_ == 0) { if (thread_id_ == 0) {
if (batch_cnt > 0 && batch_cnt % 1000 == 0) { if (batch_cnt > 0 && batch_cnt % 100 == 0) {
for (size_t i = 0; i < ops_.size(); ++i) { for (size_t i = 0; i < ops_.size(); ++i) {
fprintf(stderr, "op_name:[%zu][%s], op_mean_time:[%fs]\n", i, fprintf(stderr, "op_name:[%zu][%s], op_mean_time:[%fs]\n", i,
op_name[i].c_str(), op_total_time[i] / batch_cnt); op_name[i].c_str(), op_total_time[i] / batch_cnt);
...@@ -283,6 +285,7 @@ void ExecutorThreadWorker::TrainFilesWithTimer() { ...@@ -283,6 +285,7 @@ void ExecutorThreadWorker::TrainFilesWithTimer() {
for (int i = 0; i < fetch_var_num; ++i) { for (int i = 0; i < fetch_var_num; ++i) {
print_fetch_var(thread_scope_, fetch_var_names_[i]); print_fetch_var(thread_scope_, fetch_var_names_[i]);
} }
fprintf(stderr, "IO percent: %f\n", read_time / total_time);
} }
} }
timeline.Start(); timeline.Start();
...@@ -293,7 +296,7 @@ void ExecutorThreadWorker::TrainFiles() { ...@@ -293,7 +296,7 @@ void ExecutorThreadWorker::TrainFiles() {
platform::SetNumThreads(1); platform::SetNumThreads(1);
// todo: configurable // todo: configurable
SetDevice(); // SetDevice();
int fetch_var_num = fetch_var_names_.size(); int fetch_var_num = fetch_var_names_.size();
fetch_values_.clear(); fetch_values_.clear();
...@@ -513,7 +516,6 @@ void AsyncExecutorThreadWorker::PullSparse(int table_id) { ...@@ -513,7 +516,6 @@ void AsyncExecutorThreadWorker::PullSparse(int table_id) {
auto& push_g = _feature_push_value[table_id]; auto& push_g = _feature_push_value[table_id];
check_pull_push_memory(features, &push_g, fea_dim); check_pull_push_memory(features, &push_g, fea_dim);
collect_feasign_info(table_id); collect_feasign_info(table_id);
} }
......
if(WITH_PSLIB)
cc_library(fleet_wrapper SRCS fleet_wrapper.cc DEPS framework_proto variable_helper scope pslib_brpc pslib)
else()
cc_library(fleet_wrapper SRCS fleet_wrapper.cc DEPS framework_proto variable_helper scope)
endif(WITH_PSLIB)
此差异已折叠。
此差异已折叠。
...@@ -147,7 +147,7 @@ class SingleGradOpDescMaker : public GradOpDescMakerBase { ...@@ -147,7 +147,7 @@ class SingleGradOpDescMaker : public GradOpDescMakerBase {
public: public:
using GradOpDescMakerBase::GradOpDescMakerBase; using GradOpDescMakerBase::GradOpDescMakerBase;
std::vector<std::unique_ptr<OpDesc>> operator()() const { std::vector<std::unique_ptr<OpDesc>> operator()() const final {
std::vector<std::unique_ptr<OpDesc>> retv; std::vector<std::unique_ptr<OpDesc>> retv;
retv.emplace_back(this->Apply()); retv.emplace_back(this->Apply());
return retv; return retv;
...@@ -158,14 +158,14 @@ class SingleGradOpDescMaker : public GradOpDescMakerBase { ...@@ -158,14 +158,14 @@ class SingleGradOpDescMaker : public GradOpDescMakerBase {
}; };
template <bool DropEmptyIG = true> template <bool DropEmptyIG = true>
class DefaultGradOpDescMaker : public SingleGradOpDescMaker { class DefaultGradOpDescMaker final : public SingleGradOpDescMaker {
public: public:
using SingleGradOpDescMaker::SingleGradOpDescMaker; using SingleGradOpDescMaker::SingleGradOpDescMaker;
protected: protected:
virtual std::unique_ptr<OpDesc> Apply() const { std::unique_ptr<OpDesc> Apply() const final {
auto* grad = new OpDesc(); auto* grad = new OpDesc();
grad->SetType(this->GradOpType()); grad->SetType(this->ForwardOpType() + "_grad");
for (auto& input_param : this->InputNames()) { for (auto& input_param : this->InputNames()) {
grad->SetInput(input_param, this->Input(input_param)); grad->SetInput(input_param, this->Input(input_param));
...@@ -182,18 +182,12 @@ class DefaultGradOpDescMaker : public SingleGradOpDescMaker { ...@@ -182,18 +182,12 @@ class DefaultGradOpDescMaker : public SingleGradOpDescMaker {
return std::unique_ptr<OpDesc>(grad); return std::unique_ptr<OpDesc>(grad);
} }
virtual std::string GradOpType() const {
return this->ForwardOpType() + "_grad";
}
}; };
class EmptyGradOpMaker : public GradOpDescMakerBase { class EmptyGradOpMaker final : public GradOpDescMakerBase {
public: public:
using GradOpDescMakerBase::GradOpDescMakerBase; using GradOpDescMakerBase::GradOpDescMakerBase;
std::vector<std::unique_ptr<OpDesc>> operator()() const override { std::vector<std::unique_ptr<OpDesc>> operator()() const final { return {}; }
return {};
}
}; };
} // namespace framework } // namespace framework
......
此差异已折叠。
cc_library(fs SRCS fs.cc DEPS string_helper glog boost)
cc_library(shell SRCS shell.cc DEPS string_helper glog)
此差异已折叠。
此差异已折叠。
此差异已折叠。
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <fcntl.h>
#include <sys/stat.h>
#ifdef _WIN32
#include <windows.h>
#else
#include <sys/syscall.h>
#endif
#include <sys/types.h>
#ifndef _WIN32
#include <sys/wait.h>
#endif
#include <memory>
#include <string>
#include <utility>
#include "paddle/fluid/platform/port.h"
#include "paddle/fluid/string/string_helper.h"
namespace paddle {
namespace framework {
inline bool& shell_verbose_internal() {
static bool x = false;
return x;
}
inline bool shell_verbose() { return shell_verbose_internal(); }
inline void shell_set_verbose(bool x) { shell_verbose_internal() = x; }
extern std::shared_ptr<FILE> shell_fopen(const std::string& path,
const std::string& mode);
extern std::shared_ptr<FILE> shell_popen(const std::string& cmd,
const std::string& mode, int* err_no);
extern std::pair<std::shared_ptr<FILE>, std::shared_ptr<FILE>> shell_p2open(
const std::string& cmd);
inline void shell_execute(const std::string& cmd) {
int err_no = 0;
do {
err_no = 0;
shell_popen(cmd, "w", &err_no);
} while (err_no == -1);
}
extern std::string shell_get_command_output(const std::string& cmd);
} // namespace framework
} // namespace paddle
...@@ -68,21 +68,13 @@ pass_library(transpose_flatten_concat_fuse_pass inference) ...@@ -68,21 +68,13 @@ pass_library(transpose_flatten_concat_fuse_pass inference)
pass_library(identity_scale_op_clean_pass base) pass_library(identity_scale_op_clean_pass base)
pass_library(sync_batch_norm_pass base) pass_library(sync_batch_norm_pass base)
pass_library(runtime_context_cache_pass base) pass_library(runtime_context_cache_pass base)
pass_library(simplify_anakin_detection_pattern_pass inference) pass_library(expected_kernel_cache_pass base)
pass_library(anakin_fillconstant_elementwisemul_fuse inference) pass_library(quant_conv2d_dequant_fuse_pass inference)
pass_library(fillconstant_elementwisemul_fuse inference)
# There may be many transpose-flatten structures in a model, and the output of if(ANAKIN_FOUND)
# these structures will be used as inputs to the concat Op. This pattern will pass_library(simplify_anakin_priorbox_detection_out_pass inference)
# be detected by our pass. The index here represents the number of structures in the endif()
# pattern. We use index 3 ~ 6, because these quantities of structures are
# common in the models.
foreach (index RANGE 2 6)
file(APPEND ${pass_file} "USE_PASS(transpose_flatten${index}_concat_fuse_pass);\n")
endforeach()
foreach (index RANGE 2 6)
file(APPEND ${pass_file} "USE_PASS(simplify_anakin_detection_pattern_pass${index});\n")
endforeach()
if(WITH_MKLDNN) if(WITH_MKLDNN)
pass_library(mkldnn_placement_pass base mkldnn) pass_library(mkldnn_placement_pass base mkldnn)
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册