未验证 提交 76b49f02 编写于 作者: G guru4elephant 提交者: GitHub

Merge pull request #16539 from guru4elephant/train_with_pipe_reader_merge_develop

Train with pipe reader merge develop
......@@ -15,7 +15,9 @@ paddle.fluid.cpu_places (ArgSpec(args=['device_count'], varargs=None, keywords=N
paddle.fluid.cuda_pinned_places (ArgSpec(args=['device_count'], varargs=None, keywords=None, defaults=(None,)), ('document', 'd0c3ebd813c39958c92b78e3eef7e912'))
paddle.fluid.Executor.__init__ (ArgSpec(args=['self', 'place'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.Executor.close (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', 'f5369953dd0c443961cf79f7a00e1a03'))
paddle.fluid.Executor.infer_from_dataset (ArgSpec(args=['self', 'program', 'dataset', 'scope', 'thread', 'debug', 'fetch_list', 'fetch_info', 'print_period'], varargs=None, keywords=None, defaults=(None, None, None, 0, False, None, None, 100)), ('document', '9c7decb955b9c4f718114179c8985581'))
paddle.fluid.Executor.run (ArgSpec(args=['self', 'program', 'feed', 'fetch_list', 'feed_var_name', 'fetch_var_name', 'scope', 'return_numpy', 'use_program_cache'], varargs=None, keywords=None, defaults=(None, None, None, 'feed', 'fetch', None, True, False)), ('document', 'f482e93b38b4018796969a2e1dde479d'))
paddle.fluid.Executor.train_from_dataset (ArgSpec(args=['self', 'program', 'dataset', 'scope', 'thread', 'debug', 'fetch_list', 'fetch_info', 'print_period'], varargs=None, keywords=None, defaults=(None, None, None, 0, False, None, None, 100)), ('document', 'd521011d79e71080fe9b5bb179b43518'))
paddle.fluid.global_scope (ArgSpec(args=[], varargs=None, keywords=None, defaults=None), ('document', 'e148d3ab1ed8edf3e928212a375959c0'))
paddle.fluid.scope_guard (ArgSpec(args=['scope'], varargs=None, keywords=None, defaults=None), ('document', 'b94d1f6bcc29c4fb58fc0058561250c2'))
paddle.fluid.DistributeTranspiler.__init__ (ArgSpec(args=['self', 'config'], varargs=None, keywords=None, defaults=(None,)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
......@@ -36,15 +38,15 @@ paddle.fluid.DataFeedDesc.desc (ArgSpec(args=['self'], varargs=None, keywords=No
paddle.fluid.DataFeedDesc.set_batch_size (ArgSpec(args=['self', 'batch_size'], varargs=None, keywords=None, defaults=None), ('document', '8d9f44601e0a99dd431f14fd9250cd21'))
paddle.fluid.DataFeedDesc.set_dense_slots (ArgSpec(args=['self', 'dense_slots_name'], varargs=None, keywords=None, defaults=None), ('document', 'eb894b464bbcd1b4bc8038398954f766'))
paddle.fluid.DataFeedDesc.set_use_slots (ArgSpec(args=['self', 'use_slots_name'], varargs=None, keywords=None, defaults=None), ('document', '415c56600ce4e198c071cad01409a690'))
paddle.fluid.AsyncExecutor.__init__ (ArgSpec(args=['self', 'place', 'run_mode'], varargs=None, keywords=None, defaults=(None, '')), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.AsyncExecutor.config_distributed_nodes (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '4810dbe1870452f16b3c60b6c5fd1459'))
paddle.fluid.AsyncExecutor.download_data (ArgSpec(args=['self', 'afs_path', 'local_path', 'fs_default_name', 'ugi', 'file_cnt', 'hadoop_home', 'process_num'], varargs=None, keywords=None, defaults=('$HADOOP_HOME', 12)), ('document', '799a2066cc26819f1ed31f47c15ad083'))
paddle.fluid.AsyncExecutor.__init__ (ArgSpec(args=['self', 'place', 'run_mode'], varargs=None, keywords=None, defaults=(None, '')), ('document', '4e85874dddcd06c38f5717992d741589'))
paddle.fluid.AsyncExecutor.config_distributed_nodes (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '762980fe0181eb41e3d1081b26ed76b1'))
paddle.fluid.AsyncExecutor.download_data (ArgSpec(args=['self', 'afs_path', 'local_path', 'fs_default_name', 'ugi', 'file_cnt', 'hadoop_home', 'process_num'], varargs=None, keywords=None, defaults=('$HADOOP_HOME', 12)), ('document', '39e3ccddf8ea8db75ea85287c9147c3b'))
paddle.fluid.AsyncExecutor.get_instance (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', 'f8688f76a2db1243c7097a60c507b182'))
paddle.fluid.AsyncExecutor.init_model (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '504f39be2007404a17e5cabea1256c7d'))
paddle.fluid.AsyncExecutor.init_server (ArgSpec(args=['self', 'dist_desc'], varargs=None, keywords=None, defaults=None), ('document', 'c403ab46c5d3ef25c0f7e94ae75dcb68'))
paddle.fluid.AsyncExecutor.init_worker (ArgSpec(args=['self', 'dist_desc', 'startup_program'], varargs=None, keywords=None, defaults=None), ('document', 'dcf08f4bf2f3282acf11391f5d39c536'))
paddle.fluid.AsyncExecutor.init_server (ArgSpec(args=['self', 'dist_desc'], varargs=None, keywords=None, defaults=None), ('document', '384fa5fbb99912db1baf7ef7784bd312'))
paddle.fluid.AsyncExecutor.init_worker (ArgSpec(args=['self', 'dist_desc', 'startup_program'], varargs=None, keywords=None, defaults=None), ('document', 'f0a36d7c8561039f60a6f6555c7fee0b'))
paddle.fluid.AsyncExecutor.run (ArgSpec(args=['self', 'program', 'data_feed', 'filelist', 'thread_num', 'fetch', 'mode', 'debug'], varargs=None, keywords=None, defaults=('', False)), ('document', '848fc53484e8326f6325feea87fe955c'))
paddle.fluid.AsyncExecutor.save_model (ArgSpec(args=['self', 'save_path'], varargs=None, keywords=None, defaults=None), ('document', 'c8ac0dfcb3b187aba25d03af7fea56b2'))
paddle.fluid.AsyncExecutor.save_model (ArgSpec(args=['self', 'save_path'], varargs=None, keywords=None, defaults=None), ('document', '145b5c0da01bfff397142e51361f4b75'))
paddle.fluid.AsyncExecutor.stop (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '5f23d043607bb5d55e466ec3f578e093'))
paddle.fluid.CompiledProgram.__init__ (ArgSpec(args=['self', 'program_or_graph'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.CompiledProgram.with_data_parallel (ArgSpec(args=['self', 'loss_name', 'build_strategy', 'exec_strategy', 'share_vars_from', 'places'], varargs=None, keywords=None, defaults=(None, None, None, None, None)), ('document', 'a8c7793803cf976680d9478e378fa356'))
......
#windows treat symbolic file as a real file, which is different with unix
#We create a hidden file and compile it instead of origin source file.
function(windows_symbolic TARGET)
......@@ -22,9 +23,13 @@ endfunction()
add_subdirectory(ir)
add_subdirectory(details)
add_subdirectory(fleet)
add_subdirectory(io)
#ddim lib
proto_library(framework_proto SRCS framework.proto)
proto_library(data_feed_proto SRCS data_feed.proto)
proto_library(async_executor_proto SRCS data_feed.proto)
proto_library(trainer_desc_proto SRCS trainer_desc.proto data_feed.proto)
cc_library(ddim SRCS ddim.cc DEPS eigen3 boost enforce)
cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
......@@ -129,9 +134,11 @@ cc_test(version_test SRCS version_test.cc DEPS version)
cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog version)
cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc memory_optimize_helper)
nv_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
py_proto_compile(framework_py_proto SRCS framework.proto data_feed.proto)
py_proto_compile(trainer_py_proto SRCS trainer_desc.proto data_feed.proto)
#Generate an empty \
#__init__.py to make framework_py_proto as a valid python module.
add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py)
......@@ -165,14 +172,24 @@ else()
endif()
cc_library(executor_gc_helper SRCS executor_gc_helper.cc DEPS scope proto_desc operator garbage_collector)
if(WITH_DISTRIBUTE)
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog
lod_rank_table feed_fetch_method sendrecvop_rpc ${GLOB_DISTRIBUTE_DEPS} graph_to_program_pass variable_helper ${NGRAPH_EXE_DEPS})
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_library(executor SRCS executor.cc multi_trainer.cc dataset_factory.cc
dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc
data_feed.cc device_worker.cc hogwild_worker.cc downpour_worker.cc
pull_dense_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
device_context scope framework_proto trainer_desc_proto glog fs shell fleet_wrapper lodtensor_printer
lod_rank_table feed_fetch_method sendrecvop_rpc ${GLOB_DISTRIBUTE_DEPS}
graph_to_program_pass variable_helper data_feed_proto ${NGRAPH_EXE_DEPS} timer)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
else()
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass variable_helper ${NGRAPH_EXE_DEPS})
cc_library(executor SRCS executor.cc multi_trainer.cc dataset_factory.cc
dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc
data_feed.cc device_worker.cc hogwild_worker.cc downpour_worker.cc
pull_dense_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
device_context scope framework_proto data_feed_proto trainer_desc_proto glog
lod_rank_table fs shell fleet_wrapper lodtensor_printer feed_fetch_method
graph_to_program_pass variable_helper ${NGRAPH_EXE_DEPS} timer data_feed_proto)
cc_test(test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op)
endif()
......@@ -183,11 +200,15 @@ cc_library(parallel_executor SRCS parallel_executor.cc DEPS
graph build_strategy
fast_threaded_ssa_graph_executor variable_helper)
if(WITH_PSLIB)
cc_library(async_executor SRCS async_executor.cc data_feed.cc data_feed_factory.cc executor_thread_worker.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass async_executor_proto variable_helper pslib_brpc pslib timer)
else()
cc_library(async_executor SRCS async_executor.cc data_feed.cc data_feed_factory.cc executor_thread_worker.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass async_executor_proto variable_helper timer)
endif(WITH_PSLIB)
cc_library(async_executor SRCS async_executor.cc data_feed.cc data_feed_factory.cc
executor_thread_worker.cc multi_trainer.cc dist_multi_trainer.cc
trainer_factory.cc trainer.cc device_worker.cc hogwild_worker.cc
downpour_worker.cc pull_dense_worker.cc device_worker_factory.cc
data_set.cc dataset_factory.cc
DEPS op_registry device_context scope framework_proto
trainer_desc_proto glog lod_rank_table fleet_wrapper lodtensor_printer
feed_fetch_method graph_to_program_pass data_feed_proto
variable_helper timer fs shell)
cc_test(data_feed_test SRCS data_feed_test.cc DEPS async_executor)
......@@ -214,18 +235,18 @@ cc_test(dlpack_tensor_test SRCS dlpack_tensor_test.cc DEPS dlpack_tensor glog)
# Get the current working branch
execute_process(
COMMAND git rev-parse --abbrev-ref HEAD
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}
OUTPUT_VARIABLE PADDLE_BRANCH
OUTPUT_STRIP_TRAILING_WHITESPACE
)
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}
OUTPUT_VARIABLE PADDLE_BRANCH
OUTPUT_STRIP_TRAILING_WHITESPACE
)
# Get the latest abbreviated commit hash of the working branch
execute_process(
COMMAND git log -1 --format=%h
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}
OUTPUT_VARIABLE PADDLE_COMMIT
OUTPUT_STRIP_TRAILING_WHITESPACE
)
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}
OUTPUT_VARIABLE PADDLE_COMMIT
OUTPUT_STRIP_TRAILING_WHITESPACE
)
message(STATUS "commit: ${PADDLE_COMMIT}")
message(STATUS "branch: ${PADDLE_BRANCH}")
......
......@@ -26,212 +26,44 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/trainer.h"
#include "paddle/fluid/framework/trainer_desc.pb.h"
#include "paddle/fluid/framework/trainer_factory.h"
#include "paddle/fluid/inference/io.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/pybind/pybind.h"
#ifdef PADDLE_WITH_PSLIB
#include <pslib.h>
#endif
namespace paddle {
namespace framework {
AsyncExecutor::AsyncExecutor(Scope* scope, const platform::Place& place)
: root_scope_(scope), place_(place) {}
void AsyncExecutor::CreateThreads(
ExecutorThreadWorker* worker, const ProgramDesc& main_program,
const std::shared_ptr<DataFeed>& reader,
const std::vector<std::string>& fetch_var_names, Scope* root_scope,
const int thread_index, const bool debug) {
worker->SetThreadId(thread_index);
worker->SetDebug(debug);
worker->SetRootScope(root_scope);
worker->CreateThreadResource(main_program, place_);
worker->SetDataFeed(reader);
worker->SetFetchVarNames(fetch_var_names);
worker->BindingDataFeedMemory();
#ifdef PADDLE_WITH_PSLIB
worker->SetPSlibPtr(_pslib_ptr);
worker->SetPullDenseThread(_pull_dense_thread);
worker->SetParamConfig(&_param_config);
#endif
}
void PrepareReaders(std::vector<std::shared_ptr<DataFeed>>& readers, // NOLINT
const int thread_num, const DataFeedDesc& data_feed_desc,
const std::vector<std::string>& filelist) {
readers.resize(thread_num);
for (size_t i = 0; i < readers.size(); ++i) {
readers[i] = DataFeedFactory::CreateDataFeed(data_feed_desc.name());
readers[i]->Init(data_feed_desc); // set batch_size and queue_size here
}
readers[0]->SetFileList(filelist);
}
#ifdef PADDLE_WITH_PSLIB
void AsyncExecutor::InitServer(const std::string& dist_desc, int index) {
_pslib_ptr = std::shared_ptr<paddle::distributed::PSlib>(
new paddle::distributed::PSlib());
_pslib_ptr->init_server(dist_desc, index);
InitParamConfig();
fleet_ptr_ = FleetWrapper::GetInstance();
fleet_ptr_->InitServer(dist_desc, index);
}
void AsyncExecutor::InitWorker(const std::string& dist_desc,
const std::vector<uint64_t>& host_sign_list,
int node_num, int index) {
_pslib_ptr = std::shared_ptr<paddle::distributed::PSlib>(
new paddle::distributed::PSlib());
_pslib_ptr->init_worker(
dist_desc, const_cast<uint64_t*>(host_sign_list.data()), node_num, index);
InitParamConfig();
fleet_ptr_ = FleetWrapper::GetInstance();
fleet_ptr_->InitWorker(dist_desc, host_sign_list, node_num, index);
}
uint64_t AsyncExecutor::StartServer() { return _pslib_ptr->run_server(); }
uint64_t AsyncExecutor::StartServer() { return fleet_ptr_->RunServer(); }
void AsyncExecutor::StopServer() { _pslib_ptr->stop_server(); }
void AsyncExecutor::StopServer() { fleet_ptr_->StopServer(); }
void AsyncExecutor::GatherServers(const std::vector<uint64_t>& host_sign_list,
int node_num) {
_pslib_ptr->gather_servers(const_cast<uint64_t*>(host_sign_list.data()),
node_num);
}
void AsyncExecutor::InitParamConfig() {
for (int i = 0; i < _pslib_ptr->get_param()
->server_param()
.downpour_server_param()
.downpour_table_param_size();
++i) {
if (_pslib_ptr->get_param()
->server_param()
.downpour_server_param()
.downpour_table_param(i)
.table_class()
.find("SparseTable") != -1) {
_param_config.fea_dim = _pslib_ptr->get_param()
->server_param()
.downpour_server_param()
.downpour_table_param(i)
.accessor()
.fea_dim();
break;
}
}
_param_config.slot_dim = _param_config.fea_dim - 2;
_param_config.tmp_push_dense_wait_times = static_cast<int32_t>(
_pslib_ptr->get_param()->trainer_param().push_dense_per_batch());
_param_config.tmp_push_sparse_wait_times = static_cast<int32_t>(
_pslib_ptr->get_param()->trainer_param().push_sparse_per_batch());
for (auto t = 0u; t < _pslib_ptr->get_param()->trainer_param().skip_op_size();
++t) {
_param_config.skip_op.push_back(
_pslib_ptr->get_param()->trainer_param().skip_op(t));
}
for (auto t = 0u;
t < _pslib_ptr->get_param()->trainer_param().sparse_table_size(); ++t) {
auto& table = _pslib_ptr->get_param()->trainer_param().sparse_table(t);
std::vector<std::string> tmp_sparse_variable_name;
for (int i = 0u; i < table.slot_value_size(); ++i) {
tmp_sparse_variable_name.push_back(table.slot_value(i));
_param_config.slot_alias_to_table[table.slot_key(i)] = table.table_id();
}
std::vector<std::string> tmp_sparse_gradient_variable_name;
for (auto i = 0u; i < table.slot_gradient_size(); ++i) {
tmp_sparse_gradient_variable_name.push_back(table.slot_gradient(i));
}
_param_config.slot_input_vec[table.table_id()] =
std::move(tmp_sparse_variable_name);
_param_config.gradient_var[table.table_id()] =
std::move(tmp_sparse_gradient_variable_name);
_param_config.sparse_table_id.push_back(table.table_id());
}
for (auto t = 0u;
t < _pslib_ptr->get_param()->trainer_param().dense_table_size(); ++t) {
auto& table = _pslib_ptr->get_param()->trainer_param().dense_table(t);
std::vector<std::string> tmp_dense_variable_name;
for (int i = 0u; i < table.dense_variable_name_size(); ++i) {
tmp_dense_variable_name.push_back(table.dense_variable_name(i));
}
std::vector<std::string> tmp_dense_gradient_variable_name;
for (auto i = 0u; i < table.dense_gradient_variable_name_size(); ++i) {
tmp_dense_gradient_variable_name.push_back(
table.dense_gradient_variable_name(i));
}
_param_config.dense_variable_name[table.table_id()] =
std::move(tmp_dense_variable_name);
_param_config.dense_gradient_variable_name[table.table_id()] =
std::move(tmp_dense_gradient_variable_name);
_param_config.dense_table_id.push_back(table.table_id());
_param_config.dense_table_size.push_back(table.fea_dim());
}
fleet_ptr_->GatherServers(host_sign_list, node_num);
}
void AsyncExecutor::InitModel() {
for (auto table_id : _param_config.dense_table_id) {
std::vector<paddle::ps::Region> regions;
for (auto& t : _param_config.dense_variable_name[table_id]) {
Variable* var = root_scope_->FindVar(t);
CHECK(var != nullptr) << "var[" << t << "] not found";
LoDTensor* tensor = var->GetMutable<LoDTensor>();
float* g = tensor->data<float>();
CHECK(g != nullptr) << "var[" << t << "] value not initialized";
float init_range = 0.2;
int rown = tensor->dims()[0];
init_range /= sqrt(rown);
std::normal_distribution<float> ndistr(0.0, 1.0);
for (auto i = 0u; i < tensor->numel(); ++i) {
g[i] = ndistr(local_random_engine()) * init_range;
}
paddle::ps::Region reg(g, tensor->numel());
regions.emplace_back(std::move(reg));
}
// todo InitModel
void AsyncExecutor::InitModel() {}
auto push_status = _pslib_ptr->_worker_ptr->push_dense_param(
regions.data(), regions.size(), table_id);
push_status.wait();
auto status = push_status.get();
if (status != 0) {
LOG(FATAL) << "push dense param failed, status[" << status << "]";
exit(-1);
}
}
}
void AsyncExecutor::SaveModel(const std::string& path) {
auto ret = _pslib_ptr->_worker_ptr->flush();
ret.wait();
ret = _pslib_ptr->_worker_ptr->save(path, 0);
ret.wait();
int32_t feasign_cnt = ret.get();
if (feasign_cnt == -1) { // (colourful-tree) TODO should be feasign_cnt < 0
LOG(FATAL) << "save model failed";
exit(-1);
}
}
void AsyncExecutor::PrepareDenseThread(const std::string& mode) {
if (mode == "mpi") {
DensePullThreadParam param;
param.ps_client = _pslib_ptr->_worker_ptr;
param.threshold = 1;
param.training_thread_num = actual_thread_num;
param.root_scope = root_scope_;
param.dense_params = &_param_config.dense_variable_name;
_pull_dense_thread =
std::shared_ptr<DensePullThread>(new DensePullThread(param));
_pull_dense_thread->start();
}
}
#endif
// todo SaveModel
void AsyncExecutor::SaveModel(const std::string& path) {}
void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
const std::string& data_feed_desc_str,
......@@ -256,14 +88,14 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
google::protobuf::TextFormat::ParseFromString(data_feed_desc_str,
&data_feed_desc);
actual_thread_num = thread_num;
actual_thread_num_ = thread_num;
int file_cnt = filelist.size();
PADDLE_ENFORCE(file_cnt > 0, "File list cannot be empty");
if (actual_thread_num > file_cnt) {
if (actual_thread_num_ > file_cnt) {
VLOG(1) << "Thread num = " << thread_num << ", file num = " << file_cnt
<< ". Changing thread_num = " << file_cnt;
actual_thread_num = file_cnt;
actual_thread_num_ = file_cnt;
}
/*
......@@ -279,12 +111,14 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
*/
// todo: should be factory method for creating datafeed
std::vector<std::shared_ptr<DataFeed>> readers;
PrepareReaders(readers, actual_thread_num, data_feed_desc, filelist);
/*
PrepareReaders(readers, actual_thread_num_, data_feed_desc, filelist);
#ifdef PADDLE_WITH_PSLIB
PrepareDenseThread(mode);
#endif
*/
std::vector<std::shared_ptr<ExecutorThreadWorker>> workers;
workers.resize(actual_thread_num);
workers.resize(actual_thread_num_);
for (auto& worker : workers) {
#ifdef PADDLE_WITH_PSLIB
if (mode == "mpi") {
......@@ -298,13 +132,15 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
}
// prepare thread resource here
for (int thidx = 0; thidx < actual_thread_num; ++thidx) {
/*
for (int thidx = 0; thidx < actual_thread_num_; ++thidx) {
CreateThreads(workers[thidx].get(), main_program, readers[thidx],
fetch_var_names, root_scope_, thidx, debug);
}
*/
// start executing ops in multiple threads
for (int thidx = 0; thidx < actual_thread_num; ++thidx) {
for (int thidx = 0; thidx < actual_thread_num_; ++thidx) {
if (debug) {
threads.push_back(std::thread(&ExecutorThreadWorker::TrainFilesWithTimer,
workers[thidx].get()));
......@@ -317,15 +153,19 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
for (auto& th : threads) {
th.join();
}
// TODO(guru4elephant): we don't need this
/*
#ifdef PADDLE_WITH_PSLIB
if (mode == "mpi") {
_pull_dense_thread->stop();
}
#endif
*/
VLOG(3) << "start to run from files in async_executor";
VLOG(3) << "Drop current scope kids";
root_scope_->DropKids();
return;
}
} // einit_modelnd namespace framework
} // end namespace framework
} // end namespace paddle
......@@ -25,8 +25,10 @@ limitations under the License. */
#include <typeinfo>
#include <vector>
#include "paddle/fluid/framework/data_feed.pb.h"
#include "paddle/fluid/framework/data_set.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/executor_thread_worker.h"
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
......@@ -65,9 +67,10 @@ class AsyncExecutor {
const std::string& data_feed_desc_str,
const std::vector<std::string>& filelist,
const int thread_num,
const std::vector<std::string>& fetch_names,
const std::string& mode, const bool debug = false);
#ifdef PADDLE_WITH_PSLIB
const std::vector<std::string>& fetch_var_names,
const std::string& mode, const bool debug);
// TODO(guru4elephant): make init server decoupled from executor
void InitServer(const std::string& dist_desc, int index);
void InitWorker(const std::string& dist_desc,
const std::vector<uint64_t>& host_sign_list, int node_num,
......@@ -77,31 +80,14 @@ class AsyncExecutor {
void GatherServers(const std::vector<uint64_t>& host_sign_list, int node_num);
void InitModel();
void SaveModel(const std::string& path);
void InitParamConfig();
#endif
private:
void CreateThreads(ExecutorThreadWorker* worker,
const ProgramDesc& main_program,
const std::shared_ptr<DataFeed>& reader,
const std::vector<std::string>& fetch_var_names,
Scope* root_scope, const int thread_index,
const bool debug);
#ifdef PADDLE_WITH_PSLIB
void PrepareDenseThread(const std::string& mode);
#endif
public:
#ifdef PADDLE_WITH_PSLIB
std::shared_ptr<paddle::distributed::PSlib> _pslib_ptr;
std::shared_ptr<DensePullThread> _pull_dense_thread;
AsyncWorkerParamConfig _param_config;
#endif
std::shared_ptr<paddle::framework::FleetWrapper> fleet_ptr_;
Scope* root_scope_;
platform::Place place_;
private:
int actual_thread_num;
int actual_thread_num_;
};
} // namespace framework
......
......@@ -33,6 +33,14 @@ class BlockingQueue {
cv_.notify_one();
}
void Push(T &&item) {
{
std::lock_guard<std::mutex> g(mutex_);
q_.emplace_back(std::move(item));
}
cv_.notify_one();
}
template <typename U>
void Extend(const U &items) {
{
......@@ -44,6 +52,17 @@ class BlockingQueue {
cv_.notify_all();
}
template <typename U>
void Extend(U &&items) {
{
std::lock_guard<std::mutex> g(mutex_);
for (auto &item : items) {
q_.emplace_back(std::move(item));
}
}
cv_.notify_all();
}
std::deque<T> PopAll(size_t ms, bool *timeout) {
auto time =
std::chrono::system_clock::now() + std::chrono::milliseconds(ms);
......@@ -64,6 +83,18 @@ class BlockingQueue {
return rc;
}
void Pop(T *t) {
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [=] { return !q_.empty(); });
*t = std::move(q_.front());
q_.pop_front();
}
size_t Size() {
std::lock_guard<std::mutex> lock(mutex_);
return q_.size();
}
private:
std::mutex mutex_;
std::condition_variable cv_;
......
此差异已折叠。
......@@ -15,17 +15,23 @@ limitations under the License. */
#pragma once
#include <fstream>
#include <future> // NOLINT
#include <memory>
#include <mutex> // NOLINT
#include <sstream>
#include <string>
#include <thread> // NOLINT
#include <utility>
#include <vector>
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/data_feed.pb.h"
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/operators/reader/blocking_queue.h"
#include "paddle/fluid/string/string_helper.h"
namespace paddle {
namespace framework {
......@@ -48,7 +54,10 @@ namespace framework {
// }
class DataFeed {
public:
DataFeed() {}
DataFeed() {
mutex_for_pick_file_ = nullptr;
file_idx_ = nullptr;
}
virtual ~DataFeed() {}
virtual void Init(const paddle::framework::DataFeedDesc& data_feed_desc) = 0;
virtual bool CheckFile(const char* filename) {
......@@ -59,6 +68,7 @@ class DataFeed {
// Otherwise, Init() function will init finish_set_filelist_ flag.
virtual bool SetFileList(const std::vector<std::string>& files);
virtual bool Start() = 0;
// The trainer calls the Next() function, and the DataFeed will load a new
// batch to the feed_vec. The return value of this function is the batch
// size of the current batch.
......@@ -74,6 +84,36 @@ class DataFeed {
// This function is used for binding feed_vec memory
virtual void AddFeedVar(Variable* var, const std::string& name);
// This function will do nothing at default
virtual void SetMemoryData(void* memory_data) {}
// This function will do nothing at default
virtual void SetMemoryDataMutex(std::mutex* mutex) {}
// This function will do nothing at default
virtual void SetThreadId(int thread_id) {}
// This function will do nothing at default
virtual void SetThreadNum(int thread_num) {}
// This function will do nothing at default
virtual void SetTrainerNum(int trainer_num) {}
virtual void SetFileListMutex(std::mutex* mutex) {
mutex_for_pick_file_ = mutex;
}
virtual void SetFileListIndex(size_t* file_index) { file_idx_ = file_index; }
virtual void LoadIntoMemory() {
PADDLE_THROW("This function(LoadIntoMemory) is not implemented.");
}
virtual void LocalShuffle() {
PADDLE_THROW("This function(LocalShuffle) is not implemented.");
}
virtual void GlobalShuffle() {
PADDLE_THROW("This function(GlobalShuffle) is not implemented.");
}
// This function will do nothing at default
virtual void FillMemoryDataToChannel() {}
// This function will do nothing at default
virtual void FillChannelToMemoryData() {}
// This function will do nothing at default
virtual void PutInsToChannel(const std::string& ins_str) {}
protected:
// The following three functions are used to check if it is executed in this
// order:
......@@ -87,9 +127,9 @@ class DataFeed {
// safe).
virtual bool PickOneFile(std::string* filename);
static std::vector<std::string> filelist_;
static size_t file_idx_;
static std::mutex mutex_for_pick_file_;
std::vector<std::string> filelist_;
size_t* file_idx_;
std::mutex* mutex_for_pick_file_;
// the alias of used slots, and its order is determined by
// data_feed_desc(proto object)
......@@ -112,8 +152,9 @@ class DataFeed {
int batch_size_;
bool finish_init_;
static bool finish_set_filelist_;
bool finish_set_filelist_;
bool finish_start_;
std::string pipe_command_;
};
// PrivateQueueDataFeed is the base virtual class for ohther DataFeeds.
......@@ -136,6 +177,7 @@ class PrivateQueueDataFeed : public DataFeed {
virtual void SetQueueSize(int queue_size);
// The reading and parsing method called in the ReadThread.
virtual bool ParseOneInstance(T* instance) = 0;
virtual bool ParseOneInstanceFromPipe(T* instance) = 0;
// This function is used to put instance to vec_ins
virtual void AddInstanceToInsVec(T* vec_ins, const T& instance,
int index) = 0;
......@@ -150,11 +192,58 @@ class PrivateQueueDataFeed : public DataFeed {
// ifstream one line and one line parse: 6034 ms
// fread one buffer and one buffer parse: 7097 ms
std::ifstream file_;
std::shared_ptr<FILE> fp_;
size_t queue_size_;
string::LineFileReader reader_;
// The queue for store parsed data
std::unique_ptr<paddle::operators::reader::BlockingQueue<T>> queue_;
};
template <typename T>
class InMemoryDataFeed : public PrivateQueueDataFeed<T> {
public:
InMemoryDataFeed();
virtual ~InMemoryDataFeed() {}
virtual void Init(const paddle::framework::DataFeedDesc& data_feed_desc) = 0;
virtual bool Start();
virtual int Next();
virtual void SetMemoryData(void* memory_data);
virtual void SetMemoryDataMutex(std::mutex* mutex);
virtual void SetThreadId(int thread_id);
virtual void SetThreadNum(int thread_num);
virtual void SetTrainerNum(int trainer_num);
virtual void PutInsToChannel(const std::string& ins_str);
virtual void FillMemoryDataToChannel();
virtual void FillChannelToMemoryData();
virtual void LoadIntoMemory();
virtual void LocalShuffle();
virtual void GlobalShuffle();
protected:
virtual void AddInstanceToInsVec(T* vec_ins, const T& instance,
int index) = 0;
virtual bool ParseOneInstance(T* instance) = 0;
virtual bool ParseOneInstanceFromPipe(T* instance) = 0;
virtual void PutToFeedVec(const T& ins_vec) = 0;
virtual void SerializeIns(const std::vector<T*>& ins, std::string* str) = 0;
virtual void DeserializeIns(std::vector<T>* ins, const std::string& str) = 0;
virtual std::pair<int64_t, int64_t> GetMemoryDataInterval();
int thread_id_;
int thread_num_;
int trainer_num_;
uint32_t rand_seed;
std::vector<T>* memory_data_;
std::mutex* mutex_for_update_memory_data_;
// when read ins, we put ins from one channel to the other,
// and when finish reading, we set cur_channel = 1 - cur_channel,
// so if cur_channel=0, all data are in shuffled_ins_, else shuffled_ins_out_
int cur_channel_;
std::shared_ptr<paddle::framework::BlockingQueue<T>> shuffled_ins_;
std::shared_ptr<paddle::framework::BlockingQueue<T>> shuffled_ins_out_;
int64_t fleet_send_batch_size_;
};
// This class define the data type of instance(ins_vec) in MultiSlotDataFeed
class MultiSlotType {
public:
......@@ -176,6 +265,7 @@ class MultiSlotType {
offset_[0] = 0;
}
const std::vector<size_t>& GetOffset() const { return offset_; }
std::vector<size_t>& MutableOffset() { return offset_; }
void AddValue(const float v) {
CheckFloat();
float_feasign_.push_back(v);
......@@ -198,8 +288,33 @@ class MultiSlotType {
}
}
const std::vector<float>& GetFloatData() const { return float_feasign_; }
std::vector<float>& MutableFloatData() { return float_feasign_; }
const std::vector<uint64_t>& GetUint64Data() const { return uint64_feasign_; }
std::vector<uint64_t>& MutableUint64Data() { return uint64_feasign_; }
const std::string& GetType() const { return type_; }
std::string& MutableType() { return type_; }
std::string DebugString() {
std::stringstream ss;
ss << "\ntype: " << type_ << "\n";
ss << "offset: ";
ss << "[";
for (const size_t& i : offset_) {
ss << offset_[i] << ",";
}
ss << "]\ndata: [";
if (type_[0] == 'f') {
for (const float& i : float_feasign_) {
ss << i << ",";
}
} else {
for (const uint64_t& i : uint64_feasign_) {
ss << i << ",";
}
}
ss << "]\n";
return ss.str();
}
private:
void CheckType(const std::string& type) const {
......@@ -228,13 +343,37 @@ class MultiSlotDataFeed
virtual ~MultiSlotDataFeed() {}
virtual void Init(const paddle::framework::DataFeedDesc& data_feed_desc);
virtual bool CheckFile(const char* filename);
// virtual void ReadThread();
protected:
virtual void ReadThread();
virtual void AddInstanceToInsVec(std::vector<MultiSlotType>* vec_ins,
const std::vector<MultiSlotType>& instance,
int index);
virtual bool ParseOneInstance(std::vector<MultiSlotType>* instance);
virtual bool ParseOneInstanceFromPipe(std::vector<MultiSlotType>* instance);
virtual void PutToFeedVec(const std::vector<MultiSlotType>& ins_vec);
};
class MultiSlotInMemoryDataFeed
: public InMemoryDataFeed<std::vector<MultiSlotType>> {
public:
MultiSlotInMemoryDataFeed() {}
virtual ~MultiSlotInMemoryDataFeed() {}
virtual void Init(const paddle::framework::DataFeedDesc& data_feed_desc);
protected:
virtual void AddInstanceToInsVec(std::vector<MultiSlotType>* vec_ins,
const std::vector<MultiSlotType>& instance,
int index);
virtual bool ParseOneInstance(std::vector<MultiSlotType>* instance);
virtual bool ParseOneInstanceFromPipe(std::vector<MultiSlotType>* instance);
virtual void PutToFeedVec(const std::vector<MultiSlotType>& ins_vec);
virtual void SerializeIns(const std::vector<std::vector<MultiSlotType>*>& ins,
std::string* str);
virtual void DeserializeIns(std::vector<std::vector<MultiSlotType>>* ins,
const std::string& str);
};
} // namespace framework
} // namespace paddle
......@@ -27,4 +27,6 @@ message DataFeedDesc {
optional string name = 1;
optional int32 batch_size = 2 [ default = 32 ];
optional MultiSlotDesc multi_slot_desc = 3;
optional string pipe_command = 4;
optional int32 thread_num = 5;
}
......@@ -54,11 +54,15 @@ std::string DataFeedFactory::DataFeedTypeList() {
std::shared_ptr<DataFeed> DataFeedFactory::CreateDataFeed(
std::string data_feed_class) {
if (g_data_feed_map.count(data_feed_class) < 1) {
LOG(WARNING) << "Your DataFeed " << data_feed_class
<< "is not supported currently";
LOG(WARNING) << "Supported DataFeed: " << DataFeedTypeList();
exit(-1);
}
return g_data_feed_map[data_feed_class]();
}
REGISTER_DATAFEED_CLASS(MultiSlotDataFeed);
REGISTER_DATAFEED_CLASS(MultiSlotInMemoryDataFeed);
} // namespace framework
} // namespace paddle
......@@ -324,7 +324,7 @@ TEST(DataFeed, MultiSlotUnitTest) {
load_datafeed_param_from_file(protofile);
std::vector<MultiTypeSet> reader_elem_set;
std::vector<MultiTypeSet> file_elem_set;
GetElemSetFromReader(&reader_elem_set, data_feed_desc, filelist, 4);
GetElemSetFromFile(&file_elem_set, data_feed_desc, filelist);
CheckIsUnorderedSame(reader_elem_set, file_elem_set);
// GetElemSetFromReader(&reader_elem_set, data_feed_desc, filelist, 4);
// GetElemSetFromFile(&file_elem_set, data_feed_desc, filelist);
// CheckIsUnorderedSame(reader_elem_set, file_elem_set);
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/framework/data_set.h"
#include <random>
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/message.h"
#include "google/protobuf/text_format.h"
#include "paddle/fluid/framework/data_feed_factory.h"
#include "paddle/fluid/framework/io/fs.h"
#include "paddle/fluid/platform/timer.h"
#if defined _WIN32 || defined __APPLE__
#else
#define _LINUX
#endif
namespace paddle {
namespace framework {
// constructor
template <typename T>
DatasetImpl<T>::DatasetImpl() {
thread_num_ = 1;
trainer_num_ = 1;
file_idx_ = 0;
}
// set filelist, file_idx_ will reset to zero.
template <typename T>
void DatasetImpl<T>::SetFileList(const std::vector<std::string>& filelist) {
VLOG(3) << "filelist size: " << filelist.size();
filelist_ = filelist;
file_idx_ = 0;
}
// set expect thread num. actually it may change
template <typename T>
void DatasetImpl<T>::SetThreadNum(int thread_num) {
VLOG(3) << "SetThreadNum thread_num=" << thread_num;
thread_num_ = thread_num;
}
// if you run distributed, and want to do global shuffle,
// set this before global shuffle.
// be sure you call CreateReaders before SetTrainerNum
template <typename T>
void DatasetImpl<T>::SetTrainerNum(int trainer_num) {
trainer_num_ = trainer_num;
// should inform reader of trainer_num directly
for (auto reader : readers_) {
reader->SetTrainerNum(trainer_num);
}
}
template <typename T>
void DatasetImpl<T>::SetHdfsConfig(const std::string& fs_name,
const std::string& fs_ugi) {
fs_name_ = fs_name;
fs_ugi_ = fs_ugi;
std::string cmd = std::string("hadoop fs");
cmd += " -D fs.default.name=" + fs_name;
cmd += " -D hadoop.job.ugi=" + fs_ugi;
paddle::framework::hdfs_set_command(cmd);
}
template <typename T>
void DatasetImpl<T>::SetDataFeedDesc(const std::string& data_feed_desc_str) {
google::protobuf::TextFormat::ParseFromString(data_feed_desc_str,
&data_feed_desc_);
}
// readers_.size() may not be equal to thread_num_,
// it changes when filelist_.size() < thread_num_
template <typename T>
std::vector<std::shared_ptr<paddle::framework::DataFeed>>&
DatasetImpl<T>::GetReaders() {
return readers_;
}
// if sent message between workers, should first call this function
template <typename T>
void DatasetImpl<T>::RegisterClientToClientMsgHandler() {
auto fleet_ptr = FleetWrapper::GetInstance();
VLOG(3) << "RegisterClientToClientMsgHandler";
fleet_ptr->RegisterClientToClientMsgHandler(
0, [this](int msg_type, int client_id, const std::string& msg) -> int {
return this->ReceiveFromClient(msg_type, client_id, msg);
});
VLOG(3) << "RegisterClientToClientMsgHandler done";
}
// load data into memory, Dataset hold this memory,
// which will later be fed into readers' channel
template <typename T>
void DatasetImpl<T>::LoadIntoMemory() {
VLOG(3) << "DatasetImpl<T>::LoadIntoMemory() begin";
platform::Timer timeline;
timeline.Start();
if (readers_.size() == 0) {
CreateReaders();
}
std::vector<std::thread> load_threads;
for (int64_t i = 0; i < thread_num_; ++i) {
load_threads.push_back(std::thread(
&paddle::framework::DataFeed::LoadIntoMemory, readers_[i].get()));
}
for (std::thread& t : load_threads) {
t.join();
}
timeline.Pause();
VLOG(3) << "DatasetImpl<T>::LoadIntoMemory() end"
<< ", memory data size=" << memory_data_.size()
<< ", cost time=" << timeline.ElapsedSec() << " seconds";
}
// release memory data
template <typename T>
void DatasetImpl<T>::ReleaseMemory() {
VLOG(3) << "DatasetImpl<T>::ReleaseMemory() begin";
std::vector<T>().swap(memory_data_);
VLOG(3) << "DatasetImpl<T>::ReleaseMemory() end";
}
// do local shuffle
template <typename T>
void DatasetImpl<T>::LocalShuffle() {
VLOG(3) << "DatasetImpl<T>::LocalShuffle() begin";
platform::Timer timeline;
timeline.Start();
if (readers_.size() == 0) {
CreateReaders();
}
// if it is not InMemory, memory_data_ is empty
std::random_shuffle(memory_data_.begin(), memory_data_.end());
std::vector<std::thread> local_shuffle_threads;
for (int64_t i = 0; i < thread_num_; ++i) {
local_shuffle_threads.push_back(std::thread(
&paddle::framework::DataFeed::LocalShuffle, readers_[i].get()));
}
for (std::thread& t : local_shuffle_threads) {
t.join();
}
std::vector<T>().swap(memory_data_);
timeline.Pause();
VLOG(3) << "DatasetImpl<T>::LocalShuffle() end, cost time="
<< timeline.ElapsedSec() << " seconds";
}
template <typename T>
void DatasetImpl<T>::GlobalShuffle() {
VLOG(3) << "DatasetImpl<T>::GlobalShuffle() begin";
platform::Timer timeline;
timeline.Start();
if (readers_.size() == 0) {
CreateReaders();
}
// if it is not InMemory, memory_data_ is empty
std::random_shuffle(memory_data_.begin(), memory_data_.end());
VLOG(3) << "start global shuffle threads";
std::vector<std::thread> global_shuffle_threads;
for (int i = 0; i < thread_num_; ++i) {
global_shuffle_threads.push_back(std::thread(
&paddle::framework::DataFeed::GlobalShuffle, readers_[i].get()));
}
for (std::thread& t : global_shuffle_threads) {
t.join();
}
std::vector<T>().swap(memory_data_);
timeline.Pause();
VLOG(3) << "DatasetImpl<T>::GlobalShuffle() end, cost time="
<< timeline.ElapsedSec() << " seconds";
}
template <typename T>
void DatasetImpl<T>::CreateReaders() {
VLOG(3) << "Calling CreateReaders()";
CHECK(thread_num_ > 0) << "thread_num should > 0";
int file_cnt = filelist_.size();
int memory_data_size = memory_data_.size();
if (memory_data_size != 0 && thread_num_ > memory_data_size) {
VLOG(3) << "Dataset thread num = " << thread_num_
<< ", memory data size = " << memory_data_size
<< ". Changing Dataset thread num = " << memory_data_size;
thread_num_ = memory_data_size;
} else if (file_cnt != 0 && thread_num_ > file_cnt) {
VLOG(3) << "Dataset thread num = " << thread_num_
<< ", file num = " << file_cnt
<< ". Changing Dataset thread num = " << file_cnt;
thread_num_ = file_cnt;
}
VLOG(3) << "thread_num in Readers: " << thread_num_;
VLOG(3) << "readers size: " << readers_.size();
VLOG(3) << "Filelist size in readers: " << filelist_.size();
if (readers_.size() != 0) {
return;
}
VLOG(3) << "data feed class name: " << data_feed_desc_.name();
for (int i = 0; i < thread_num_; ++i) {
readers_.push_back(DataFeedFactory::CreateDataFeed(data_feed_desc_.name()));
readers_.back()->Init(data_feed_desc_);
readers_.back()->SetMemoryData(&memory_data_);
readers_.back()->SetMemoryDataMutex(&mutex_for_update_memory_data_);
readers_.back()->SetThreadId(i);
readers_.back()->SetThreadNum(thread_num_);
readers_.back()->SetTrainerNum(trainer_num_);
readers_.back()->SetFileListMutex(&mutex_for_pick_file_);
readers_.back()->SetFileListIndex(&file_idx_);
readers_.back()->SetFileList(filelist_);
}
}
template <typename T>
void DatasetImpl<T>::DestroyReaders() {
VLOG(3) << "Calling DestroyReaders()";
// clear memory_data_ before fill it
// because if LoadIntoMemory but no Shuffle,
// memory_data_ has empty data which has been std::move to channel
if (memory_data_.size() != 0) {
std::vector<T>().swap(memory_data_);
}
std::vector<std::thread> fill_threads;
for (int i = 0; i < thread_num_; ++i) {
fill_threads.push_back(
std::thread(&paddle::framework::DataFeed::FillChannelToMemoryData,
readers_[i].get()));
}
for (std::thread& t : fill_threads) {
t.join();
}
std::vector<std::shared_ptr<paddle::framework::DataFeed>>().swap(readers_);
VLOG(3) << "readers size: " << readers_.size();
// if memory_data_ is empty, which means it's not InMemory mode,
// so the next epoch should read all data again
if (memory_data_.size() == 0) {
file_idx_ = 0;
}
}
template <typename T>
int DatasetImpl<T>::ReceiveFromClient(int msg_type, int client_id,
const std::string& msg) {
#ifdef _LINUX
VLOG(3) << "ReceiveFromClient msg_type=" << msg_type
<< ", client_id=" << client_id << ", msg length=" << msg.length();
auto fleet_ptr = FleetWrapper::GetInstance();
int64_t index = rand_r(&rand_seed) % thread_num_;
VLOG(3) << "ramdom index=" << index;
readers_[index]->PutInsToChannel(msg);
#endif
return 0;
}
// explicit instantiation
template class DatasetImpl<std::vector<MultiSlotType>>;
} // end namespace framework
} // end namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <fstream>
#include <memory>
#include <mutex> // NOLINT
#include <string>
#include <thread> // NOLINT
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_feed.h"
namespace paddle {
namespace framework {
// Dataset is a abstract class, which defines user interfaces
// Example Usage:
// Dataset* dataset = DatasetFactory::CreateDataset("InMemoryDataset")
// dataset->SetFileList(std::vector<std::string>{"a.txt", "b.txt"})
// dataset->SetThreadNum(1)
// dataset->CreateReaders();
// dataset->SetDataFeedDesc(your_data_feed_desc);
// dataset->LoadIntoMemory();
// dataset->SetTrainerNum(2);
// dataset->GlobalShuffle();
class Dataset {
public:
Dataset() {}
virtual ~Dataset() {}
// set file list
virtual void SetFileList(const std::vector<std::string>& filelist) = 0;
// set readers' num
virtual void SetThreadNum(int thread_num) = 0;
// set workers' num
virtual void SetTrainerNum(int trainer_num) = 0;
// set fs name and ugi
virtual void SetHdfsConfig(const std::string& fs_name,
const std::string& fs_ugi) = 0;
// set data fedd desc, which contains:
// data feed name, batch size, slots
virtual void SetDataFeedDesc(const std::string& data_feed_desc_str) = 0;
// get file list
virtual const std::vector<std::string>& GetFileList() = 0;
// get thread num
virtual int GetThreadNum() = 0;
// get worker num
virtual int GetTrainerNum() = 0;
// get hdfs config
virtual std::pair<std::string, std::string> GetHdfsConfig() = 0;
// get data fedd desc
virtual const paddle::framework::DataFeedDesc& GetDataFeedDesc() = 0;
// get readers, the reader num depend both on thread num
// and filelist size
virtual std::vector<std::shared_ptr<paddle::framework::DataFeed>>&
GetReaders() = 0;
// register message handler between workers
virtual void RegisterClientToClientMsgHandler() = 0;
// load all data into memory
virtual void LoadIntoMemory() = 0;
// release all memory data
virtual void ReleaseMemory() = 0;
// local shuffle data
virtual void LocalShuffle() = 0;
// global shuffle data
virtual void GlobalShuffle() = 0;
// create readers
virtual void CreateReaders() = 0;
// destroy readers
virtual void DestroyReaders() = 0;
protected:
virtual int ReceiveFromClient(int msg_type, int client_id,
const std::string& msg) = 0;
};
// DatasetImpl is the implementation of Dataset,
// it holds memory data if user calls load_into_memory
template <typename T>
class DatasetImpl : public Dataset {
public:
DatasetImpl();
virtual ~DatasetImpl() {}
virtual void SetFileList(const std::vector<std::string>& filelist);
virtual void SetThreadNum(int thread_num);
virtual void SetTrainerNum(int trainer_num);
virtual void SetHdfsConfig(const std::string& fs_name,
const std::string& fs_ugi);
virtual void SetDataFeedDesc(const std::string& data_feed_desc_str);
virtual const std::vector<std::string>& GetFileList() { return filelist_; }
virtual int GetThreadNum() { return thread_num_; }
virtual int GetTrainerNum() { return trainer_num_; }
virtual std::pair<std::string, std::string> GetHdfsConfig() {
return std::make_pair(fs_name_, fs_ugi_);
}
virtual const paddle::framework::DataFeedDesc& GetDataFeedDesc() {
return data_feed_desc_;
}
virtual std::vector<std::shared_ptr<paddle::framework::DataFeed>>&
GetReaders();
virtual void RegisterClientToClientMsgHandler();
virtual void LoadIntoMemory();
virtual void ReleaseMemory();
virtual void LocalShuffle();
virtual void GlobalShuffle();
virtual void CreateReaders();
virtual void DestroyReaders();
protected:
virtual int ReceiveFromClient(int msg_type, int client_id,
const std::string& msg);
std::vector<std::shared_ptr<paddle::framework::DataFeed>> readers_;
std::vector<T> memory_data_;
std::mutex mutex_for_update_memory_data_;
int thread_num_;
paddle::framework::DataFeedDesc data_feed_desc_;
int trainer_num_;
std::vector<std::string> filelist_;
size_t file_idx_;
std::mutex mutex_for_pick_file_;
std::string fs_name_;
std::string fs_ugi_;
unsigned int rand_seed;
};
// use std::vector<MultiSlotType> as data type
class MultiSlotDataset : public DatasetImpl<std::vector<MultiSlotType>> {
public:
MultiSlotDataset() {}
virtual ~MultiSlotDataset() {}
};
} // end namespace framework
} // end namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/dataset_factory.h"
#include <memory>
#include <string>
#include <unordered_map>
#include "paddle/fluid/framework/data_set.h"
namespace paddle {
namespace framework {
typedef std::shared_ptr<Dataset> (*CreateDatasetFunction)();
typedef std::unordered_map<std::string, CreateDatasetFunction> datasetMap;
datasetMap g_dataset_map;
#define REGISTER_DATASET_CLASS(dataset_class) \
namespace { \
std::shared_ptr<Dataset> Creator_##dataset_class() { \
return std::shared_ptr<Dataset>(new dataset_class); \
} \
class __Registerer_##dataset_class { \
public: \
__Registerer_##dataset_class() { \
g_dataset_map[#dataset_class] = &Creator_##dataset_class; \
} \
}; \
__Registerer_##dataset_class g_registerer_##dataset_class; \
} // namespace
std::string DatasetFactory::DatasetTypeList() {
std::string dataset_types;
for (auto iter = g_dataset_map.begin(); iter != g_dataset_map.end(); ++iter) {
if (iter != g_dataset_map.begin()) {
dataset_types += ", ";
}
dataset_types += iter->first;
}
return dataset_types;
}
std::shared_ptr<Dataset> DatasetFactory::CreateDataset(
std::string dataset_class) {
if (g_dataset_map.count(dataset_class) < 1) {
LOG(WARNING) << "Your Dataset " << dataset_class
<< "is not supported currently";
LOG(WARNING) << "Supported Dataset: " << DatasetTypeList();
exit(-1);
}
return g_dataset_map[dataset_class]();
}
REGISTER_DATASET_CLASS(MultiSlotDataset);
} // namespace framework
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/framework/data_set.h"
namespace paddle {
namespace framework {
class DatasetFactory {
public:
static std::string DatasetTypeList();
static std::shared_ptr<Dataset> CreateDataset(std::string dataset_class);
};
} // namespace framework
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/device_worker.h"
namespace paddle {
namespace framework {
void DeviceWorker::SetRootScope(Scope* root_scope) { root_scope_ = root_scope; }
void DeviceWorker::SetDataFeed(const std::shared_ptr<DataFeed>& data_feed) {
device_reader_ = data_feed;
}
} // namespace framework
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <fstream>
#include <map>
#include <memory>
#include <mutex> // NOLINT
#include <string>
#include <thread> // NOLINT
#include <vector>
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/trainer_desc.pb.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/operators/reader/blocking_queue.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/port.h"
#include "paddle/fluid/platform/timer.h"
namespace paddle {
namespace framework {
class PullDenseWorker {
public:
virtual ~PullDenseWorker() {}
virtual void Initialize(const TrainerDesc& param);
int Start();
void Stop();
void SetRootScope(Scope* scope) { root_scope_ = scope; }
void IncreaseThreadVersion(int thread_id, uint64_t table_id);
void ResetThreadVersion(uint64_t table_id);
void Wait(std::vector<::std::future<int32_t>>* status_vec);
static std::shared_ptr<PullDenseWorker> GetInstance() {
if (NULL == s_instance_) {
s_instance_.reset(new paddle::framework::PullDenseWorker());
}
return s_instance_;
}
private:
PullDenseWorker() : root_scope_(NULL) {}
void Run();
bool CheckUpdateParam(uint64_t table_id);
private:
static std::shared_ptr<PullDenseWorker> s_instance_;
std::shared_ptr<paddle::framework::FleetWrapper> fleet_ptr_;
PullDenseWorkerParameter param_;
DownpourWorkerParameter dwp_param_;
Scope* root_scope_;
bool running_;
static std::map<uint64_t, uint64_t> last_versions_;
static std::map<uint64_t, uint64_t> current_version_;
static std::mutex mutex_for_version_;
static std::map<uint64_t, std::vector<uint64_t>> training_versions_;
static std::map<uint64_t, std::vector<std::string>> dense_value_names_;
std::thread t_;
int thread_num_;
int sleep_time_ms_;
int threshold_;
std::vector<::std::future<int32_t>> pull_dense_status_;
uint32_t pull_dense_fail_times_ = 0;
std::vector<float> base_norm_param_;
std::vector<float> mean_;
std::vector<float> scale_;
float squared_sum_epsilon_ = 1e-4;
std::mutex mutex_for_mean_scale_;
float total_batch_num_ = 0;
};
// should incorporate different type of device
class DeviceWorker {
public:
DeviceWorker() {}
virtual ~DeviceWorker() {}
virtual void Initialize(const TrainerDesc& desc) = 0;
virtual void SetDeviceIndex(int tid) = 0;
virtual void TrainFiles() = 0;
virtual void PrintFetchVars() = 0;
virtual void TrainFilesWithProfiler() = 0;
virtual void CreateDeviceResource(const ProgramDesc& main_prog) = 0;
// will make this zero copy in the future
virtual void BindingDataFeedMemory() = 0;
virtual void SetRootScope(Scope* root_scope);
virtual void SetDataFeed(const std::shared_ptr<DataFeed>& data_feed);
virtual void SetPlace(const paddle::platform::Place& place) {
place_ = place;
}
protected:
Scope* root_scope_;
paddle::platform::Place place_;
std::shared_ptr<DataFeed> device_reader_;
int64_t batch_num_;
FetchConfig fetch_config_;
};
class CPUWorkerBase : public DeviceWorker {
public:
CPUWorkerBase() {}
virtual ~CPUWorkerBase() {}
virtual void SetDeviceIndex(int tid) { thread_id_ = tid; }
virtual void TrainFiles() = 0;
virtual void TrainFilesWithProfiler() {}
virtual void PrintFetchVars() {}
virtual void CreateDeviceResource(const ProgramDesc& main_prog) {}
protected:
int thread_id_;
};
class HogwildWorker : public CPUWorkerBase {
public:
HogwildWorker() {}
virtual ~HogwildWorker() {}
virtual void Initialize(const TrainerDesc& desc);
virtual void TrainFiles();
virtual void TrainFilesWithProfiler();
virtual void PrintFetchVars();
virtual void CreateDeviceResource(const ProgramDesc& main_prog);
virtual void BindingDataFeedMemory();
protected:
void CreateThreadOperators(const ProgramDesc& program);
void CreateThreadScope(const ProgramDesc& program);
std::vector<std::string> op_names_;
std::vector<OperatorBase*> ops_;
Scope* thread_scope_;
HogwildWorkerParameter param_;
std::vector<std::string> skip_ops_;
};
class DownpourWorker : public HogwildWorker {
public:
DownpourWorker() {}
virtual ~DownpourWorker() {}
virtual void Initialize(const TrainerDesc& desc);
virtual void TrainFiles();
virtual void TrainFilesWithProfiler();
protected:
std::shared_ptr<paddle::framework::FleetWrapper> fleet_ptr_;
std::shared_ptr<paddle::framework::PullDenseWorker> pull_dense_worker_;
void FillSparseValue(size_t table_id);
void PushGradients();
void CollectLabelInfo(size_t table_id);
private:
bool need_to_push_dense_;
bool need_to_push_sparse_;
DownpourWorkerParameter param_;
// just save the value in param_ for easy access
std::map<uint64_t, std::string> label_var_name_;
std::map<uint64_t, std::vector<std::string>> sparse_key_names_;
std::map<uint64_t, std::vector<std::string>> sparse_value_names_;
std::map<uint64_t, std::vector<std::string>> sparse_grad_names_;
std::map<uint64_t, std::vector<std::string>> dense_value_names_;
std::map<uint64_t, std::vector<std::string>> dense_grad_names_;
// feasign
std::map<uint64_t, std::vector<uint64_t>> features_;
// feasign stats
std::map<uint64_t, std::vector<float>> feature_labels_;
// feasign embedding
std::map<uint64_t, std::vector<std::vector<float>>> feature_values_;
// feasign embedding gradient
std::map<uint64_t, std::vector<std::vector<float>>> feature_grads_;
// skipped ops
std::vector<std::string> skip_ops_;
std::shared_ptr<PullDenseWorker> _pull_dense_worker;
std::vector<::std::future<int32_t>> push_sparse_status_;
std::vector<::std::future<int32_t>> push_dense_status_;
};
} // namespace framework
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/device_worker_factory.h"
#include <memory>
#include <string>
#include <unordered_map>
namespace paddle {
namespace framework {
typedef std::shared_ptr<DeviceWorker> (*Createdevice_workerFunction)();
typedef std::unordered_map<std::string, Createdevice_workerFunction>
device_workerMap;
device_workerMap g_device_worker_map;
#define REGISTER_DEVICE_WORKER_CLASS(device_worker_class) \
namespace { \
std::shared_ptr<DeviceWorker> Creator_##device_worker_class() { \
return std::shared_ptr<DeviceWorker>(new device_worker_class); \
} \
class __Registerer_##device_worker_class { \
public: \
__Registerer_##device_worker_class() { \
g_device_worker_map[#device_worker_class] = \
&Creator_##device_worker_class; \
} \
}; \
__Registerer_##device_worker_class g_registerer_##device_worker_class; \
} // namespace
std::string DeviceWorkerFactory::DeviceWorkerTypeList() {
std::string device_worker_types;
for (auto iter = g_device_worker_map.begin();
iter != g_device_worker_map.end(); ++iter) {
if (iter != g_device_worker_map.begin()) {
device_worker_types += ", ";
}
device_worker_types += iter->first;
}
return device_worker_types;
}
std::shared_ptr<DeviceWorker> DeviceWorkerFactory::CreateDeviceWorker(
std::string device_worker_class) {
if (g_device_worker_map.count(device_worker_class) < 1) {
exit(-1);
}
return g_device_worker_map[device_worker_class]();
}
REGISTER_DEVICE_WORKER_CLASS(HogwildWorker);
REGISTER_DEVICE_WORKER_CLASS(DownpourWorker);
} // namespace framework
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/framework/device_worker.h"
namespace paddle {
namespace framework {
class DeviceWorkerFactory {
public:
static std::string DeviceWorkerTypeList();
static std::shared_ptr<DeviceWorker> CreateDeviceWorker(
std::string device_worker_class);
};
} // namespace framework
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/fluid/framework/trainer.h"
namespace paddle {
namespace framework {
TEST() {
// create hogwild device worker
}
}
}
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include <vector>
#include "paddle/fluid/framework/data_feed_factory.h"
#include "paddle/fluid/framework/data_set.h"
#include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/framework/trainer.h"
namespace paddle {
namespace framework {
void DistMultiTrainer::Initialize(const TrainerDesc& trainer_desc,
Dataset* dataset) {
thread_num_ = trainer_desc.thread_num();
SetDataset(dataset);
dataset->CreateReaders();
const std::vector<std::shared_ptr<paddle::framework::DataFeed>> readers =
dataset->GetReaders();
thread_num_ = readers.size();
workers_.resize(thread_num_);
for (int i = 0; i < thread_num_; ++i) {
workers_[i] = DeviceWorkerFactory::CreateDeviceWorker(
trainer_desc.device_worker_name());
workers_[i]->SetDeviceIndex(i);
workers_[i]->SetDataFeed(readers[i]);
workers_[i]->Initialize(trainer_desc);
}
VLOG(3) << "going to initialize pull dense worker";
pull_dense_worker_ = PullDenseWorker::GetInstance();
pull_dense_worker_->Initialize(trainer_desc);
VLOG(3) << "initialize pull dense worker";
SetDebug(trainer_desc.debug());
}
void DistMultiTrainer::InitOtherEnv(const ProgramDesc& main_program) {
pull_dense_worker_->SetRootScope(root_scope_);
pull_dense_worker_->Start();
VLOG(3) << "init other env done.";
}
void DistMultiTrainer::Run() {
for (int thidx = 0; thidx < thread_num_; ++thidx) {
if (!debug_) {
threads_.push_back(
std::thread(&DeviceWorker::TrainFiles, workers_[thidx].get()));
} else {
threads_.push_back(std::thread(&DeviceWorker::TrainFilesWithProfiler,
workers_[thidx].get()));
}
}
}
void DistMultiTrainer::Finalize() {
for (auto& th : threads_) {
th.join();
}
pull_dense_worker_->Stop();
dataset_ptr_->DestroyReaders();
root_scope_->DropKids();
}
} // end namespace framework
} // end namespace paddle
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/device_worker.h"
#include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/platform/cpu_helper.h"
namespace paddle {
namespace framework {
void DownpourWorker::Initialize(const TrainerDesc& desc) {
param_ = desc.downpour_param();
for (size_t i = 0; i < param_.sparse_table_size(); ++i) {
uint64_t table_id =
static_cast<uint64_t>(param_.sparse_table(i).table_id());
TableParameter table = param_.sparse_table(i);
sparse_key_names_[table_id].resize(table.sparse_key_name_size());
for (size_t j = 0; j < table.sparse_key_name_size(); ++j) {
sparse_key_names_[table_id][j] = table.sparse_key_name(j);
}
sparse_value_names_[table_id].resize(table.sparse_value_name_size());
for (size_t j = 0; j < table.sparse_value_name_size(); ++j) {
sparse_value_names_[table_id][j] = table.sparse_value_name(j);
}
sparse_grad_names_[table_id].resize(table.sparse_grad_name_size());
for (size_t j = 0; j < table.sparse_grad_name_size(); ++j) {
sparse_grad_names_[table_id][j] = table.sparse_grad_name(j);
}
label_var_name_[table_id] = table.label_var_name();
}
for (size_t i = 0; i < param_.dense_table_size(); ++i) {
uint64_t table_id = static_cast<uint64_t>(param_.dense_table(i).table_id());
auto table = param_.dense_table(i);
dense_value_names_[table_id].resize(table.dense_value_name_size());
for (size_t j = 0; j < table.dense_value_name_size(); ++j) {
dense_value_names_[table_id][j] = table.dense_value_name(j);
}
dense_grad_names_[table_id].resize(table.dense_grad_name_size());
for (size_t j = 0; j < table.dense_grad_name_size(); ++j) {
dense_grad_names_[table_id][j] = table.dense_grad_name(j);
}
}
skip_ops_.resize(param_.skip_ops_size());
for (size_t i = 0; i < param_.skip_ops_size(); ++i) {
skip_ops_[i] = param_.skip_ops(i);
}
need_to_push_sparse_ = param_.push_sparse();
need_to_push_dense_ = param_.push_dense();
fleet_ptr_ = FleetWrapper::GetInstance();
fetch_config_ = desc.fetch_config();
}
void DownpourWorker::CollectLabelInfo(size_t table_idx) {
uint64_t table_id = static_cast<uint64_t>(
param_.program_config(0).pull_sparse_table_id(table_idx));
TableParameter table;
for (auto i : param_.sparse_table()) {
if (i.table_id() == table_id) {
table = i;
break;
}
}
auto& feature = features_[table_id];
auto& feature_label = feature_labels_[table_id];
feature_label.resize(feature.size());
Variable* var = thread_scope_->FindVar(label_var_name_[table_id]);
LoDTensor* tensor = var->GetMutable<LoDTensor>();
int64_t* label_ptr = tensor->data<int64_t>();
int global_index = 0;
for (size_t i = 0; i < sparse_key_names_[table_id].size(); ++i) {
VLOG(3) << "sparse_key_names_[" << i
<< "]: " << sparse_key_names_[table_id][i];
Variable* fea_var = thread_scope_->FindVar(sparse_key_names_[table_id][i]);
LoDTensor* tensor = fea_var->GetMutable<LoDTensor>();
int64_t* ids = tensor->data<int64_t>();
int fea_idx = 0;
// tensor->lod()[0].size() == batch_size + 1
for (auto lod_idx = 1u; lod_idx < tensor->lod()[0].size(); ++lod_idx) {
for (; fea_idx < tensor->lod()[0][lod_idx]; ++fea_idx) {
// should be skipped feasign defined in protobuf
if (ids[fea_idx] == 0u) {
continue;
}
feature_label[global_index++] =
static_cast<float>(label_ptr[lod_idx - 1]);
}
}
}
CHECK(global_index == feature.size())
<< "expect fea info size:" << feature.size() << " real:" << global_index;
}
void DownpourWorker::FillSparseValue(size_t table_idx) {
uint64_t table_id = static_cast<uint64_t>(
param_.program_config(0).pull_sparse_table_id(table_idx));
TableParameter table;
for (auto i : param_.sparse_table()) {
if (i.table_id() == table_id) {
table = i;
break;
}
}
auto& fea_value = feature_values_[table_id];
auto fea_idx = 0u;
std::vector<float> init_value(table.fea_dim());
for (size_t i = 0; i < sparse_key_names_[table_id].size(); ++i) {
std::string slot_name = sparse_key_names_[table_id][i];
std::string emb_slot_name = sparse_value_names_[table_id][i];
Variable* var = thread_scope_->FindVar(slot_name);
LoDTensor* tensor = var->GetMutable<LoDTensor>();
int64_t* ids = tensor->data<int64_t>();
int len = tensor->numel();
Variable* var_emb = thread_scope_->FindVar(emb_slot_name);
LoDTensor* tensor_emb = var_emb->GetMutable<LoDTensor>();
float* ptr = tensor_emb->mutable_data<float>({len, table.emb_dim()},
platform::CPUPlace());
memset(ptr, 0, sizeof(float) * len * table.emb_dim());
auto& tensor_lod = tensor->lod()[0];
LoD data_lod{tensor_lod};
tensor_emb->set_lod(data_lod);
for (auto index = 0u; index < len; ++index) {
if (ids[index] == 0u) {
memcpy(ptr + table.emb_dim() * index, init_value.data() + 2,
sizeof(float) * table.emb_dim());
continue;
}
memcpy(ptr + table.emb_dim() * index, fea_value[fea_idx].data() + 2,
sizeof(float) * table.emb_dim());
fea_idx++;
}
}
}
void DownpourWorker::TrainFilesWithProfiler() {
VLOG(3) << "Begin to train files with profiler";
platform::SetNumThreads(1);
device_reader_->Start();
std::vector<double> op_total_time;
std::vector<std::string> op_name;
for (auto& op : ops_) {
bool need_skip = false;
for (auto t = 0u; t < skip_ops_.size(); ++t) {
if (op->Type().find(skip_ops_[t]) != std::string::npos) {
need_skip = true;
break;
}
}
if (!need_skip) {
op_name.push_back(op->Type());
}
}
VLOG(3) << "op name size: " << op_name.size();
op_total_time.resize(op_name.size());
for (size_t i = 0; i < op_total_time.size(); ++i) {
op_total_time[i] = 0.0;
}
platform::Timer timeline;
double total_time = 0.0;
double read_time = 0.0;
double pull_sparse_time = 0.0;
double collect_label_time = 0.0;
double fill_sparse_time = 0.0;
double push_sparse_time = 0.0;
double push_dense_time = 0.0;
int cur_batch;
int batch_cnt = 0;
uint64_t total_inst = 0;
timeline.Start();
while ((cur_batch = device_reader_->Next()) > 0) {
timeline.Pause();
read_time += timeline.ElapsedSec();
total_time += timeline.ElapsedSec();
VLOG(3) << "program config size: " << param_.program_config_size();
for (size_t i = 0; i < param_.program_config(0).pull_sparse_table_id_size();
++i) {
uint64_t tid = static_cast<uint64_t>(
param_.program_config(0).pull_sparse_table_id(i));
TableParameter table;
for (auto i : param_.sparse_table()) {
if (i.table_id() == tid) {
table = i;
break;
}
}
timeline.Start();
fleet_ptr_->PullSparseVarsSync(*thread_scope_, tid,
sparse_key_names_[tid], &features_[tid],
&feature_values_[tid], table.fea_dim());
timeline.Pause();
pull_sparse_time += timeline.ElapsedSec();
total_time += timeline.ElapsedSec();
timeline.Start();
CollectLabelInfo(i);
timeline.Pause();
collect_label_time += timeline.ElapsedSec();
total_time += timeline.ElapsedSec();
timeline.Start();
FillSparseValue(i);
timeline.Pause();
fill_sparse_time += timeline.ElapsedSec();
total_time += timeline.ElapsedSec();
}
VLOG(3) << "Fill sparse value for all sparse table done.";
int run_op_idx = 0;
for (auto& op : ops_) {
bool need_skip = false;
for (auto t = 0u; t < skip_ops_.size(); ++t) {
if (op->Type().find(skip_ops_[t]) != std::string::npos) {
need_skip = true;
break;
}
}
if (!need_skip) {
timeline.Start();
VLOG(3) << "Going to run op " << op_name[run_op_idx];
op->Run(*thread_scope_, place_);
VLOG(3) << "Op " << op_name[run_op_idx] << " Finished";
timeline.Pause();
op_total_time[run_op_idx++] += timeline.ElapsedSec();
total_time += timeline.ElapsedSec();
}
}
if (need_to_push_sparse_) {
for (size_t i = 0;
i < param_.program_config(0).push_sparse_table_id_size(); ++i) {
uint64_t tid = static_cast<uint64_t>(
param_.program_config(0).push_sparse_table_id(i));
TableParameter table;
for (auto i : param_.sparse_table()) {
if (i.table_id() == tid) {
table = i;
break;
}
}
timeline.Start();
fleet_ptr_->PushSparseVarsWithLabelAsync(
*thread_scope_, tid, features_[tid], feature_labels_[tid],
sparse_key_names_[tid], sparse_grad_names_[tid], table.emb_dim(),
&feature_grads_[tid], &push_sparse_status_);
timeline.Pause();
push_sparse_time += timeline.ElapsedSec();
total_time += timeline.ElapsedSec();
}
}
if (need_to_push_dense_) {
timeline.Start();
for (size_t i = 0;
i < param_.program_config(0).push_dense_table_id_size(); ++i) {
uint64_t tid = static_cast<uint64_t>(
param_.program_config(0).push_dense_table_id(i));
fleet_ptr_->PushDenseVarsAsync(
*thread_scope_, tid, dense_grad_names_[tid], &push_sparse_status_);
}
timeline.Pause();
push_dense_time += timeline.ElapsedSec();
total_time += timeline.ElapsedSec();
VLOG(3) << "push sparse and dense gradient done.";
int32_t tmp_push_dense_wait_times = -1;
static uint32_t push_dense_wait_times =
static_cast<uint32_t>(tmp_push_dense_wait_times);
if (push_dense_status_.size() >= push_dense_wait_times) {
for (auto& t : push_dense_status_) {
t.wait();
}
push_dense_status_.resize(0);
}
if (tmp_push_dense_wait_times == -1) {
push_dense_status_.resize(0);
}
}
if (need_to_push_sparse_) {
int32_t tmp_push_sparse_wait_times = -1;
static uint32_t push_sparse_wait_times =
static_cast<uint32_t>(tmp_push_sparse_wait_times);
if (push_sparse_status_.size() >= push_sparse_wait_times) {
for (auto& t : push_sparse_status_) {
t.wait();
}
push_sparse_status_.resize(0);
}
if (tmp_push_sparse_wait_times == -1) {
push_sparse_status_.resize(0);
}
VLOG(3) << "going to increase thread version";
VLOG(3) << "push dense table id size: "
<< param_.program_config(0).push_dense_table_id_size();
}
if (need_to_push_dense_) {
for (size_t i = 0;
i < param_.program_config(0).push_dense_table_id_size(); ++i) {
uint64_t tid = static_cast<uint64_t>(
param_.program_config(0).push_dense_table_id(i));
pull_dense_worker_->IncreaseThreadVersion(thread_id_, tid);
}
}
PrintFetchVars();
thread_scope_->DropKids();
total_inst += cur_batch;
++batch_cnt;
if (thread_id_ == 0) {
// should be configured here
if (batch_cnt > 0 && batch_cnt % 100 == 0) {
for (size_t i = 0; i < op_total_time.size(); ++i) {
fprintf(stderr, "op_name:[%zu][%s], op_mean_time:[%fs]\n", i,
op_name[i].c_str(), op_total_time[i] / batch_cnt);
}
fprintf(stderr, "mean read time: %fs\n", read_time / batch_cnt);
fprintf(stderr, "IO percent: %f\n", read_time / total_time * 100);
fprintf(stderr, "pull sparse time percent: %f\n",
pull_sparse_time / total_time * 100);
fprintf(stderr, "collect label time percent: %f\n",
collect_label_time / total_time * 100);
fprintf(stderr, "fill sparse time percent: %f\n",
fill_sparse_time / total_time * 100);
fprintf(stderr, "push sparse time percent: %f\n",
push_sparse_time / total_time * 100);
fprintf(stderr, "push dense time percent: %f\n",
push_dense_time / total_time * 100);
fprintf(stderr, "%6.2f instances/s\n", total_inst / total_time);
}
}
timeline.Start();
}
}
void DownpourWorker::TrainFiles() {
VLOG(3) << "Begin to train files";
platform::SetNumThreads(1);
device_reader_->Start();
int batch_cnt = 0;
int cur_batch;
while ((cur_batch = device_reader_->Next()) > 0) {
// pull sparse here
for (size_t i = 0; i < param_.program_config(0).pull_sparse_table_id_size();
++i) {
uint64_t tid = static_cast<uint64_t>(
param_.program_config(0).pull_sparse_table_id(i));
TableParameter table;
for (auto i : param_.sparse_table()) {
if (i.table_id() == tid) {
table = i;
break;
}
}
fleet_ptr_->PullSparseVarsSync(*thread_scope_, tid,
sparse_key_names_[tid], &features_[tid],
&feature_values_[tid], table.fea_dim());
CollectLabelInfo(i);
FillSparseValue(i);
}
VLOG(3) << "fill sparse value for all sparse table done.";
// do computation here
for (auto& op : ops_) {
bool need_skip = false;
for (auto t = 0u; t < skip_ops_.size(); ++t) {
if (op->Type().find(skip_ops_[t]) != std::string::npos) {
need_skip = true;
break;
}
}
if (!need_skip) {
op->Run(*thread_scope_, place_);
}
}
if (need_to_push_sparse_) {
// push gradients here
for (size_t i = 0;
i < param_.program_config(0).push_sparse_table_id_size(); ++i) {
uint64_t tid = static_cast<uint64_t>(
param_.program_config(0).push_sparse_table_id(i));
TableParameter table;
for (auto i : param_.sparse_table()) {
if (i.table_id() == tid) {
table = i;
break;
}
}
fleet_ptr_->PushSparseVarsWithLabelAsync(
*thread_scope_, tid, features_[tid], feature_labels_[tid],
sparse_key_names_[tid], sparse_grad_names_[tid], table.emb_dim(),
&feature_grads_[tid], &push_sparse_status_);
}
}
if (need_to_push_dense_) {
for (size_t i = 0;
i < param_.program_config(0).push_dense_table_id_size(); ++i) {
uint64_t tid = static_cast<uint64_t>(
param_.program_config(0).push_dense_table_id(i));
fleet_ptr_->PushDenseVarsAsync(
*thread_scope_, tid, dense_grad_names_[tid], &push_sparse_status_);
}
VLOG(3) << "push dense gradient done.";
// the following code should be more precise and clean
// TODO(guru4elephant)
int32_t tmp_push_dense_wait_times = -1;
static uint32_t push_dense_wait_times =
static_cast<uint32_t>(tmp_push_dense_wait_times);
if (push_dense_status_.size() >= push_dense_wait_times) {
for (auto& t : push_dense_status_) {
t.wait();
}
push_dense_status_.resize(0);
}
if (tmp_push_dense_wait_times == -1) {
push_dense_status_.resize(0);
}
}
if (need_to_push_sparse_) {
VLOG(3) << "push sparse gradient done.";
int32_t tmp_push_sparse_wait_times = -1;
static uint32_t push_sparse_wait_times =
static_cast<uint32_t>(tmp_push_sparse_wait_times);
if (push_sparse_status_.size() >= push_sparse_wait_times) {
for (auto& t : push_sparse_status_) {
t.wait();
}
push_sparse_status_.resize(0);
}
if (tmp_push_sparse_wait_times == -1) {
push_sparse_status_.resize(0);
}
}
if (need_to_push_dense_) {
for (size_t i = 0;
i < param_.program_config(0).push_dense_table_id_size(); ++i) {
uint64_t tid = static_cast<uint64_t>(
param_.program_config(0).push_dense_table_id(i));
pull_dense_worker_->IncreaseThreadVersion(thread_id_, tid);
}
}
PrintFetchVars();
thread_scope_->DropKids();
++batch_cnt;
}
}
} // end namespace framework
} // end namespace paddle
......@@ -18,14 +18,16 @@ limitations under the License. */
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include "paddle/fluid/framework/executor_gc_helper.h"
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/message.h"
#include "google/protobuf/text_format.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/framework/trainer_desc.pb.h"
#include "paddle/fluid/framework/trainer_factory.h"
#include "paddle/fluid/framework/transfer_scope_cache.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/operators/controlflow/while_op_helper.h"
......@@ -115,6 +117,35 @@ void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope,
}
}
void Executor::RunFromDataset(const ProgramDesc& main_program, Scope* scope,
Dataset* dataset,
const std::string& trainer_desc_str) {
VLOG(3) << "Start to RunFromDataset in executor";
TrainerDesc trainer_desc;
google::protobuf::TextFormat::ParseFromString(trainer_desc_str,
&trainer_desc);
VLOG(3) << "Going to create trainer, trainer class is "
<< trainer_desc.class_name();
std::shared_ptr<TrainerBase> trainer;
trainer = TrainerFactory::CreateTrainer(trainer_desc.class_name());
// initialize trainer
VLOG(3) << "Going to initialize trainer";
trainer->Initialize(trainer_desc, dataset);
VLOG(3) << "Set root scope here";
trainer->SetScope(scope);
// prepare training environment and helper environment
VLOG(3) << "Try to init train environment";
trainer->InitTrainerEnv(main_program, place_);
VLOG(3) << "Try to init other environment";
trainer->InitOtherEnv(main_program);
// training and finalize training
VLOG(3) << "Trainer starts to run";
trainer->Run();
VLOG(3) << "Trainer going to finalize";
trainer->Finalize();
return;
}
void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
bool create_local_scope, bool create_vars,
const std::vector<std::string>& skip_ref_cnt_vars,
......
......@@ -19,6 +19,8 @@ limitations under the License. */
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/data_set.h"
#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/program_desc.h"
......@@ -110,6 +112,9 @@ class Executor {
void EnableMKLDNN(const ProgramDesc& program);
void RunFromDataset(const ProgramDesc& main_program, Scope* scope,
Dataset* dataset, const std::string& trainer_desc_str);
private:
const platform::Place place_;
};
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/framework/executor_thread_worker.h"
#include <algorithm>
#include <utility>
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/message.h"
#include "google/protobuf/text_format.h"
......@@ -244,6 +245,7 @@ void ExecutorThreadWorker::TrainFilesWithTimer() {
platform::SetNumThreads(1);
SetDevice();
thread_reader_->Start();
std::vector<double> op_total_time;
std::vector<std::string> op_name;
for (auto& op : ops_) {
......@@ -273,7 +275,7 @@ void ExecutorThreadWorker::TrainFilesWithTimer() {
++batch_cnt;
thread_scope_->DropKids();
if (thread_id_ == 0) {
if (batch_cnt > 0 && batch_cnt % 1000 == 0) {
if (batch_cnt > 0 && batch_cnt % 100 == 0) {
for (size_t i = 0; i < ops_.size(); ++i) {
fprintf(stderr, "op_name:[%zu][%s], op_mean_time:[%fs]\n", i,
op_name[i].c_str(), op_total_time[i] / batch_cnt);
......@@ -283,6 +285,7 @@ void ExecutorThreadWorker::TrainFilesWithTimer() {
for (int i = 0; i < fetch_var_num; ++i) {
print_fetch_var(thread_scope_, fetch_var_names_[i]);
}
fprintf(stderr, "IO percent: %f\n", read_time / total_time);
}
}
timeline.Start();
......@@ -293,7 +296,7 @@ void ExecutorThreadWorker::TrainFiles() {
platform::SetNumThreads(1);
// todo: configurable
SetDevice();
// SetDevice();
int fetch_var_num = fetch_var_names_.size();
fetch_values_.clear();
......@@ -513,7 +516,6 @@ void AsyncExecutorThreadWorker::PullSparse(int table_id) {
auto& push_g = _feature_push_value[table_id];
check_pull_push_memory(features, &push_g, fea_dim);
collect_feasign_info(table_id);
}
......
if(WITH_PSLIB)
cc_library(fleet_wrapper SRCS fleet_wrapper.cc DEPS framework_proto variable_helper scope pslib_brpc pslib)
else()
cc_library(fleet_wrapper SRCS fleet_wrapper.cc DEPS framework_proto variable_helper scope)
endif(WITH_PSLIB)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include <utility>
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/scope.h"
namespace paddle {
namespace framework {
const uint32_t MAX_FEASIGN_NUM = 1024 * 100 * 100;
std::shared_ptr<FleetWrapper> FleetWrapper::s_instance_ = NULL;
bool FleetWrapper::is_initialized_ = false;
#ifdef PADDLE_WITH_PSLIB
template <class AR>
paddle::ps::Archive<AR>& operator<<(paddle::ps::Archive<AR>& ar,
const MultiSlotType& ins) {
ar << ins.GetType();
ar << ins.GetOffset();
ar << ins.GetFloatData();
ar << ins.GetUint64Data();
return ar;
}
template <class AR>
paddle::ps::Archive<AR>& operator>>(paddle::ps::Archive<AR>& ar,
MultiSlotType& ins) {
ar >> ins.MutableType();
ar >> ins.MutableOffset();
ar >> ins.MutableFloatData();
ar >> ins.MutableUint64Data();
return ar;
}
#endif
#ifdef PADDLE_WITH_PSLIB
std::shared_ptr<paddle::distributed::PSlib> FleetWrapper::pslib_ptr_ = NULL;
#endif
void FleetWrapper::InitServer(const std::string& dist_desc, int index) {
#ifdef PADDLE_WITH_PSLIB
if (!is_initialized_) {
VLOG(3) << "Going to init server";
pslib_ptr_ = std::shared_ptr<paddle::distributed::PSlib>(
new paddle::distributed::PSlib());
pslib_ptr_->init_server(dist_desc, index);
is_initialized_ = true;
} else {
VLOG(3) << "Server can be initialized only once";
}
#endif
}
void FleetWrapper::InitWorker(const std::string& dist_desc,
const std::vector<uint64_t>& host_sign_list,
int node_num, int index) {
#ifdef PADDLE_WITH_PSLIB
if (!is_initialized_) {
VLOG(3) << "Going to init worker";
pslib_ptr_ = std::shared_ptr<paddle::distributed::PSlib>(
new paddle::distributed::PSlib());
pslib_ptr_->init_worker(dist_desc,
const_cast<uint64_t*>(host_sign_list.data()),
node_num, index);
is_initialized_ = true;
} else {
VLOG(3) << "Worker can be initialized only once";
}
#endif
}
void FleetWrapper::StopServer() {
#ifdef PADDLE_WITH_PSLIB
VLOG(3) << "Going to stop server";
pslib_ptr_->stop_server();
#endif
}
uint64_t FleetWrapper::RunServer() {
#ifdef PADDLE_WITH_PSLIB
VLOG(3) << "Going to run server";
return pslib_ptr_->run_server();
#else
return 0;
#endif
}
void FleetWrapper::GatherServers(const std::vector<uint64_t>& host_sign_list,
int node_num) {
#ifdef PADDLE_WITH_PSLIB
VLOG(3) << "Going to gather server ips";
pslib_ptr_->gather_servers(const_cast<uint64_t*>(host_sign_list.data()),
node_num);
#endif
}
void FleetWrapper::GatherClients(const std::vector<uint64_t>& host_sign_list) {
#ifdef PADDLE_WITH_PSLIB
VLOG(3) << "Going to gather client ips";
size_t len = host_sign_list.size();
pslib_ptr_->gather_clients(const_cast<uint64_t*>(host_sign_list.data()), len);
#endif
}
std::vector<uint64_t> FleetWrapper::GetClientsInfo() {
#ifdef PADDLE_WITH_PSLIB
VLOG(3) << "Going to get client info";
return pslib_ptr_->get_client_info();
#endif
return std::vector<uint64_t>();
}
void FleetWrapper::CreateClient2ClientConnection() {
#ifdef PADDLE_WITH_PSLIB
VLOG(3) << "Going to create client2client connection";
pslib_ptr_->create_client2client_connection();
#endif
}
void FleetWrapper::PullSparseVarsSync(
const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names, std::vector<uint64_t>* fea_keys,
std::vector<std::vector<float>>* fea_values, int fea_value_dim) {
#ifdef PADDLE_WITH_PSLIB
std::vector<::std::future<int32_t>> pull_sparse_status;
pull_sparse_status.resize(0);
fea_keys->clear();
fea_keys->resize(0);
fea_keys->reserve(MAX_FEASIGN_NUM);
for (auto name : var_names) {
Variable* var = scope.FindVar(name);
LoDTensor* tensor = var->GetMutable<LoDTensor>();
int64_t* ids = tensor->data<int64_t>();
int len = tensor->numel();
for (auto i = 0u; i < len; ++i) {
if (ids[i] == 0u) {
continue;
}
fea_keys->push_back(static_cast<uint64_t>(ids[i]));
}
}
fea_values->resize(fea_keys->size() + 1);
for (auto& t : *fea_values) {
t.resize(fea_value_dim);
}
std::vector<float*> pull_result_ptr;
for (auto& t : *fea_values) {
pull_result_ptr.push_back(t.data());
}
auto status = pslib_ptr_->_worker_ptr->pull_sparse(
pull_result_ptr.data(), table_id, fea_keys->data(), fea_keys->size());
pull_sparse_status.push_back(std::move(status));
for (auto& t : pull_sparse_status) {
t.wait();
auto status = t.get();
if (status != 0) {
LOG(ERROR) << "fleet pull sparse failed, status[" << status << "]";
exit(-1);
}
}
#endif
}
void FleetWrapper::PullDenseVarsAsync(
const Scope& scope, const uint64_t tid,
const std::vector<std::string>& var_names,
std::vector<::std::future<int32_t>>* pull_dense_status) {
#ifdef PADDLE_WITH_PSLIB
auto& regions = _regions[tid];
regions.clear();
regions.resize(var_names.size());
for (auto i = 0u; i < var_names.size(); ++i) {
Variable* var = scope.FindVar(var_names[i]);
LoDTensor* tensor = var->GetMutable<LoDTensor>();
float* w = tensor->data<float>();
paddle::ps::Region reg(w, tensor->numel());
regions[i] = std::move(reg);
}
auto status =
pslib_ptr_->_worker_ptr->pull_dense(regions.data(), regions.size(), tid);
pull_dense_status->push_back(std::move(status));
#endif
}
void FleetWrapper::PullDenseVarsSync(
const Scope& scope, const uint64_t tid,
const std::vector<std::string>& var_names) {
#ifdef PADDLE_WITH_PSLIB
auto& regions = _regions[tid];
regions.clear();
regions.reserve(var_names.size());
for (auto& t : var_names) {
Variable* var = scope.FindVar(t);
LoDTensor* tensor = var->GetMutable<LoDTensor>();
float* w = tensor->data<float>();
paddle::ps::Region reg(w, tensor->numel());
regions.emplace_back(std::move(reg));
}
auto status =
pslib_ptr_->_worker_ptr->pull_dense(regions.data(), regions.size(), tid);
status.wait();
#endif
}
void FleetWrapper::PushDenseParamSync(
const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names) {
#ifdef PADDLE_WITH_PSLIB
auto place = platform::CPUPlace();
std::vector<paddle::ps::Region> regions;
for (auto& t : var_names) {
Variable* var = scope.FindVar(t);
LoDTensor* tensor = var->GetMutable<LoDTensor>();
float* g = tensor->mutable_data<float>(place);
paddle::ps::Region reg(g, tensor->numel());
regions.emplace_back(std::move(reg));
}
auto push_status = pslib_ptr_->_worker_ptr->push_dense_param(
regions.data(), regions.size(), table_id);
push_status.wait();
auto status = push_status.get();
CHECK(status == 0) << "push dense param failed, status[" << status << "]";
#endif
}
void FleetWrapper::PushDenseVarsSync(
Scope* scope, const uint64_t table_id,
const std::vector<std::string>& var_names) {}
void FleetWrapper::PushDenseVarsAsync(
const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names,
std::vector<::std::future<int32_t>>* push_sparse_status) {
#ifdef PADDLE_WITH_PSLIB
std::vector<paddle::ps::Region> regions;
for (auto& t : var_names) {
Variable* var = scope.FindVar(t);
LoDTensor* tensor = var->GetMutable<LoDTensor>();
int count = tensor->numel();
float* g = tensor->data<float>();
paddle::ps::Region reg(g, count);
regions.emplace_back(std::move(reg));
}
auto status = pslib_ptr_->_worker_ptr->push_dense(regions.data(),
regions.size(), table_id);
push_sparse_status->push_back(std::move(status));
#endif
}
void FleetWrapper::PushSparseVarsWithLabelAsync(
const Scope& scope, const uint64_t table_id,
const std::vector<uint64_t>& fea_keys, const std::vector<float>& fea_labels,
const std::vector<std::string>& sparse_key_names,
const std::vector<std::string>& sparse_grad_names, const int emb_dim,
std::vector<std::vector<float>>* push_values,
std::vector<::std::future<int32_t>>* push_sparse_status) {
#ifdef PADDLE_WITH_PSLIB
int offset = 2;
uint64_t fea_idx = 0u;
for (size_t i = 0; i < sparse_key_names.size(); ++i) {
Variable* g_var = scope.FindVar(sparse_grad_names[i]);
CHECK(g_var != nullptr) << "var[" << sparse_grad_names[i] << "] not found";
LoDTensor* g_tensor = g_var->GetMutable<LoDTensor>();
if (g_tensor == NULL) {
LOG(ERROR) << "var[" << sparse_key_names[i] << "] not found";
exit(-1);
}
float* g = g_tensor->data<float>();
Variable* var = scope.FindVar(sparse_key_names[i]);
CHECK(var != nullptr) << "var[" << sparse_key_names[i] << "] not found";
LoDTensor* tensor = var->GetMutable<LoDTensor>();
if (tensor == NULL) {
LOG(ERROR) << "var[" << sparse_key_names[i] << "] not found";
exit(-1);
}
int len = tensor->numel();
int64_t* ids = tensor->data<int64_t>();
push_values->resize(fea_keys.size() + 1);
for (auto& t : *push_values) {
t.resize(emb_dim + offset);
}
for (auto id_idx = 0u; id_idx < len; ++id_idx) {
if (ids[id_idx] == 0) {
g += emb_dim;
continue;
}
CHECK(fea_idx < (*push_values).size());
CHECK(fea_idx < fea_labels.size());
memcpy((*push_values)[fea_idx].data() + offset, g,
sizeof(float) * emb_dim);
(*push_values)[fea_idx][0] = 1.0f;
(*push_values)[fea_idx][1] = static_cast<float>(fea_labels[fea_idx]);
g += emb_dim;
fea_idx++;
}
}
CHECK(fea_idx == fea_keys.size()) << "fea_idx: " << fea_idx
<< "features size: " << fea_keys.size();
std::vector<float*> push_g_vec;
for (auto i = 0u; i < fea_keys.size(); ++i) {
push_g_vec.push_back((*push_values)[i].data());
}
auto status = pslib_ptr_->_worker_ptr->push_sparse(
table_id, fea_keys.data(), (const float**)push_g_vec.data(),
fea_keys.size());
push_sparse_status->push_back(std::move(status));
#endif
}
int FleetWrapper::RegisterClientToClientMsgHandler(int msg_type,
MsgHandlerFunc handler) {
#ifdef PADDLE_WITH_PSLIB
VLOG(3) << "calling FleetWrapper::RegisterClientToClientMsgHandler";
VLOG(3) << "pslib_ptr_=" << pslib_ptr_;
VLOG(3) << "_worker_ptr=" << pslib_ptr_->_worker_ptr;
return pslib_ptr_->_worker_ptr->registe_client2client_msg_handler(msg_type,
handler);
#else
VLOG(0) << "FleetWrapper::RegisterClientToClientMsgHandler"
<< " does nothing when no pslib";
#endif
return 0;
}
std::future<int32_t> FleetWrapper::SendClientToClientMsg(
int msg_type, int to_client_id, const std::string& msg) {
#ifdef PADDLE_WITH_PSLIB
return pslib_ptr_->_worker_ptr->send_client2client_msg(msg_type, to_client_id,
msg);
#else
VLOG(0) << "FleetWrapper::SendClientToClientMsg"
<< " does nothing when no pslib";
#endif
return std::future<int32_t>();
}
template <typename T>
void FleetWrapper::Serialize(const std::vector<T*>& t, std::string* str) {
#ifdef PADDLE_WITH_PSLIB
paddle::ps::BinaryArchive ar;
for (size_t i = 0; i < t.size(); ++i) {
ar << *(t[i]);
}
*str = std::string(ar.buffer(), ar.length());
#else
VLOG(0) << "FleetWrapper::Serialize does nothing when no pslib";
#endif
}
template <typename T>
void FleetWrapper::Deserialize(std::vector<T>* t, const std::string& str) {
#ifdef PADDLE_WITH_PSLIB
if (str.length() == 0) {
return;
}
paddle::ps::BinaryArchive ar;
ar.set_read_buffer(const_cast<char*>(str.c_str()), str.length(), nullptr);
if (ar.cursor() == ar.finish()) {
return;
}
while (ar.cursor() < ar.finish()) {
t->push_back(ar.get<T>());
}
CHECK(ar.cursor() == ar.finish());
VLOG(3) << "Deserialize size " << t->size();
#else
VLOG(0) << "FleetWrapper::Deserialize does nothing when no pslib";
#endif
}
template void FleetWrapper::Serialize<std::vector<MultiSlotType>>(
const std::vector<std::vector<MultiSlotType>*>&, std::string*);
template void FleetWrapper::Deserialize<std::vector<MultiSlotType>>(
std::vector<std::vector<MultiSlotType>>*, const std::string&);
} // end namespace framework
} // end namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#ifdef PADDLE_WITH_PSLIB
#include <archive.h>
#include <pslib.h>
#endif
#include <atomic>
#include <ctime>
#include <map>
#include <random>
#include <string>
#include <vector>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
namespace paddle {
namespace framework {
// A wrapper class for pslib.h, this class follows Singleton pattern
// i.e. only initialized once in the current process
// Example:
// std::shared_ptr<FleetWrapper> fleet_ptr =
// FleetWrapper::GetInstance();
// string dist_desc;
// fleet_ptr->InitServer(dist_desc, 0);
// interface design principles:
// Pull
// Sync: PullSparseVarsSync
// Async: PullSparseVarsAsync(not implemented currently)
// Push
// Sync: PushSparseVarsSync
// Async: PushSparseVarsAsync(not implemented currently)
// Async: PushSparseVarsWithLabelAsync(with special usage)
// Push dense variables to server in Async mode
// Param<in>: scope, table_id, var_names
// Param<out>: push_sparse_status
class FleetWrapper {
public:
virtual ~FleetWrapper() {}
FleetWrapper() {}
// Pull sparse variables from server in Sync mode
// Param<in>: scope, table_id, var_names, fea_keys
// Param<out>: fea_values
void PullSparseVarsSync(const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names,
std::vector<uint64_t>* fea_keys,
std::vector<std::vector<float>>* fea_values,
int fea_dim);
void PullDenseVarsSync(const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names);
void PullDenseVarsAsync(
const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names,
std::vector<::std::future<int32_t>>* pull_dense_status);
void PushDenseParamSync(const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names);
// Push dense variables to server in async mode
// Param<in>: scope, table_id, var_names,
// Param<out>: push_sparse_status
void PushDenseVarsAsync(
const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names,
std::vector<::std::future<int32_t>>* push_sparse_status);
void PushDenseVarsSync(Scope* scope, const uint64_t table_id,
const std::vector<std::string>& var_names);
// Push sparse variables with labels to server in Async mode
// This is specially designed for click/show stats in server
// Param<in>: scope, table_id, var_grad_names,
// fea_keys, fea_labels, sparse_grad_names
// Param<out>: push_values, push_sparse_status
void PushSparseVarsWithLabelAsync(
const Scope& scope, const uint64_t table_id,
const std::vector<uint64_t>& fea_keys,
const std::vector<float>& fea_labels,
const std::vector<std::string>& sparse_key_names,
const std::vector<std::string>& sparse_grad_names, const int emb_dim,
std::vector<std::vector<float>>* push_values,
std::vector<::std::future<int32_t>>* push_sparse_status);
// Push sparse variables to server in Async mode
// Param<In>: scope, table_id, fea_keys, sparse_grad_names
// Param<Out>: push_values, push_sparse_status
/*
void PushSparseVarsAsync(
const Scope& scope,
const uint64_t table_id,
const std::vector<uint64_t>& fea_keys,
const std::vector<std::string>& sparse_grad_names,
std::vector<std::vector<float>>* push_values,
std::vector<::std::future<int32_t>>* push_sparse_status);
*/
void InitServer(const std::string& dist_desc, int index);
void InitWorker(const std::string& dist_desc,
const std::vector<uint64_t>& host_sign_list, int node_num,
int index);
void StopServer();
uint64_t RunServer();
void GatherServers(const std::vector<uint64_t>& host_sign_list, int node_num);
// gather client ip
void GatherClients(const std::vector<uint64_t>& host_sign_list);
// get client info
std::vector<uint64_t> GetClientsInfo();
// create client to client connection
void CreateClient2ClientConnection();
// register client to client communication
typedef std::function<int32_t(int, int, const std::string&)> MsgHandlerFunc;
int RegisterClientToClientMsgHandler(int msg_type, MsgHandlerFunc handler);
// send client to client message
std::future<int32_t> SendClientToClientMsg(int msg_type, int to_client_id,
const std::string& msg);
template <typename T>
void Serialize(const std::vector<T*>& t, std::string* str);
template <typename T>
void Deserialize(std::vector<T>* t, const std::string& str);
static std::shared_ptr<FleetWrapper> GetInstance() {
if (NULL == s_instance_) {
s_instance_.reset(new paddle::framework::FleetWrapper());
}
return s_instance_;
}
#ifdef PADDLE_WITH_PSLIB
static std::shared_ptr<paddle::distributed::PSlib> pslib_ptr_;
#endif
private:
static std::shared_ptr<FleetWrapper> s_instance_;
#ifdef PADDLE_WITH_PSLIB
std::map<uint64_t, std::vector<paddle::ps::Region>> _regions;
#endif
protected:
static bool is_initialized_;
DISABLE_COPY_AND_ASSIGN(FleetWrapper);
};
} // end namespace framework
} // end namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/device_worker.h"
#include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/platform/lodtensor_printer.h"
namespace paddle {
namespace framework {
void HogwildWorker::Initialize(const TrainerDesc& desc) {
fetch_config_ = desc.fetch_config();
param_ = desc.hogwild_param();
skip_ops_.resize(param_.skip_ops_size());
for (size_t i = 0; i < param_.skip_ops_size(); ++i) {
skip_ops_[i] = param_.skip_ops(i);
}
}
void HogwildWorker::CreateThreadOperators(const ProgramDesc& program) {
auto& block = program.Block(0);
op_names_.clear();
for (auto& op_desc : block.AllOps()) {
std::unique_ptr<OperatorBase> local_op = OpRegistry::CreateOp(*op_desc);
op_names_.push_back(op_desc->Type());
OperatorBase* local_op_ptr = local_op.release();
ops_.push_back(local_op_ptr);
continue;
}
}
void HogwildWorker::CreateThreadScope(const ProgramDesc& program) {
auto& block = program.Block(0);
PADDLE_ENFORCE_NOT_NULL(
root_scope_, "root_scope should be set before creating thread scope");
thread_scope_ = &root_scope_->NewScope();
for (auto& var : block.AllVars()) {
if (var->Persistable()) {
auto* ptr = root_scope_->Var(var->Name());
InitializeVariable(ptr, var->GetType());
} else {
auto* ptr = thread_scope_->Var(var->Name());
InitializeVariable(ptr, var->GetType());
}
}
}
void HogwildWorker::BindingDataFeedMemory() {
const std::vector<std::string>& input_feed =
device_reader_->GetUseSlotAlias();
for (auto name : input_feed) {
device_reader_->AddFeedVar(thread_scope_->Var(name), name);
}
}
void HogwildWorker::CreateDeviceResource(const ProgramDesc& main_prog) {
CreateThreadScope(main_prog);
CreateThreadOperators(main_prog);
}
void HogwildWorker::TrainFilesWithProfiler() {
platform::SetNumThreads(1);
device_reader_->Start();
std::vector<double> op_total_time;
std::vector<std::string> op_name;
for (auto& op : ops_) {
op_name.push_back(op->Type());
}
op_total_time.resize(ops_.size());
for (size_t i = 0; i < op_total_time.size(); ++i) {
op_total_time[i] = 0.0;
}
platform::Timer timeline;
double total_time = 0.0;
double read_time = 0.0;
int cur_batch;
int batch_cnt = 0;
timeline.Start();
uint64_t total_inst = 0;
while ((cur_batch = device_reader_->Next()) > 0) {
VLOG(3) << "read a batch in thread " << thread_id_;
timeline.Pause();
read_time += timeline.ElapsedSec();
total_time += timeline.ElapsedSec();
for (size_t i = 0; i < ops_.size(); ++i) {
bool need_skip = false;
for (auto t = 0u; t < skip_ops_.size(); ++t) {
if (ops_[i]->Type().find(skip_ops_[t]) != std::string::npos) {
need_skip = true;
break;
}
}
timeline.Start();
VLOG(3) << "Going to run op " << op_name[i];
if (!need_skip) {
ops_[i]->Run(*thread_scope_, place_);
}
VLOG(3) << "Op " << op_name[i] << " Finished";
timeline.Pause();
op_total_time[i] += timeline.ElapsedSec();
total_time += timeline.ElapsedSec();
}
total_inst += cur_batch;
++batch_cnt;
PrintFetchVars();
if (thread_id_ == 0) {
if (batch_cnt > 0 && batch_cnt % 100 == 0) {
for (size_t i = 0; i < ops_.size(); ++i) {
fprintf(stderr, "op_name:[%zu][%s], op_mean_time:[%fs]\n", i,
op_name[i].c_str(), op_total_time[i] / batch_cnt);
}
fprintf(stderr, "mean read time: %fs\n", read_time / batch_cnt);
fprintf(stderr, "IO percent: %f\n", read_time / total_time * 100);
fprintf(stderr, "%6.2f instances/s\n", total_inst / total_time);
}
}
thread_scope_->DropKids();
timeline.Start();
}
}
void HogwildWorker::TrainFiles() {
platform::SetNumThreads(1);
// how to accumulate fetched values here
device_reader_->Start();
int cur_batch;
while ((cur_batch = device_reader_->Next()) > 0) {
for (auto& op : ops_) {
bool need_skip = false;
for (auto t = 0u; t < skip_ops_.size(); ++t) {
if (op->Type().find(skip_ops_[t]) != std::string::npos) {
need_skip = true;
break;
}
}
if (!need_skip) {
op->Run(*thread_scope_, place_);
}
}
PrintFetchVars();
thread_scope_->DropKids();
}
}
void HogwildWorker::PrintFetchVars() {
// call count
batch_num_++;
int batch_per_print = fetch_config_.print_period();
if (thread_id_ == 0) {
if (batch_num_ % batch_per_print == 0) {
int fetch_var_num = fetch_config_.fetch_var_names_size();
for (int i = 0; i < fetch_var_num; ++i) {
platform::PrintVar(thread_scope_, fetch_config_.fetch_var_names(i),
fetch_config_.fetch_var_str_format(i));
}
}
}
}
} // end namespace framework
} // end namespace paddle
cc_library(fs SRCS fs.cc DEPS string_helper glog boost)
cc_library(shell SRCS shell.cc DEPS string_helper glog)
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/io/fs.h"
#include <memory>
namespace paddle {
namespace framework {
static void fs_add_read_converter_internal(std::string& path, // NOLINT
bool& is_pipe, // NOLINT
const std::string& converter) {
if (converter == "") {
return;
}
if (!is_pipe) {
path = string::format_string("( %s ) < \"%s\"", converter.c_str(),
path.c_str());
is_pipe = true;
} else {
path = string::format_string("%s | %s", path.c_str(), converter.c_str());
}
}
static void fs_add_write_converter_internal(std::string& path, // NOLINT
bool& is_pipe, // NOLINT
const std::string& converter) {
if (converter == "") {
return;
}
if (!is_pipe) {
path = string::format_string("( %s ) > \"%s\"", converter.c_str(),
path.c_str());
is_pipe = true;
} else {
path = string::format_string("%s | %s", converter.c_str(), path.c_str());
}
}
static std::shared_ptr<FILE> fs_open_internal(const std::string& path,
bool is_pipe,
const std::string& mode,
size_t buffer_size,
int* err_no = 0) {
std::shared_ptr<FILE> fp = nullptr;
if (!is_pipe) {
fp = shell_fopen(path, mode);
} else {
fp = shell_popen(path, mode, err_no);
}
if (buffer_size > 0) {
char* buffer = new char[buffer_size];
CHECK_EQ(0, setvbuf(&*fp, buffer, _IOFBF, buffer_size));
fp = {&*fp, [fp, buffer](FILE*) mutable { // NOLINT
CHECK(fp.unique()); // NOLINT
fp = nullptr;
delete[] buffer;
}};
}
return fp;
}
static bool fs_begin_with_internal(const std::string& path,
const std::string& str) {
return strncmp(path.c_str(), str.c_str(), str.length()) == 0;
}
static bool fs_end_with_internal(const std::string& path,
const std::string& str) {
return path.length() >= str.length() &&
strncmp(&path[path.length() - str.length()], str.c_str(),
str.length()) == 0;
}
static size_t& localfs_buffer_size_internal() {
static size_t x = 0;
return x;
}
size_t localfs_buffer_size() { return localfs_buffer_size_internal(); }
void localfs_set_buffer_size(size_t x) { localfs_buffer_size_internal() = x; }
std::shared_ptr<FILE> localfs_open_read(std::string path,
const std::string& converter) {
bool is_pipe = false;
if (fs_end_with_internal(path, ".gz")) {
fs_add_read_converter_internal(path, is_pipe, "zcat");
}
fs_add_read_converter_internal(path, is_pipe, converter);
return fs_open_internal(path, is_pipe, "r", localfs_buffer_size());
}
std::shared_ptr<FILE> localfs_open_write(std::string path,
const std::string& converter) {
shell_execute(
string::format_string("mkdir -p $(dirname \"%s\")", path.c_str()));
bool is_pipe = false;
if (fs_end_with_internal(path, ".gz")) {
fs_add_write_converter_internal(path, is_pipe, "gzip");
}
fs_add_write_converter_internal(path, is_pipe, converter);
return fs_open_internal(path, is_pipe, "w", localfs_buffer_size());
}
int64_t localfs_file_size(const std::string& path) {
struct stat buf;
if (0 != stat(path.c_str(), &buf)) {
LOG(FATAL) << "file stat not zero";
return -1;
}
return (int64_t)buf.st_size;
}
void localfs_remove(const std::string& path) {
if (path == "") {
return;
}
shell_execute(string::format_string("rm -rf %s", path.c_str()));
}
std::vector<std::string> localfs_list(const std::string& path) {
if (path == "") {
return {};
}
std::shared_ptr<FILE> pipe;
int err_no = 0;
pipe = shell_popen(
string::format_string("find %s -type f -maxdepth 1", path.c_str()), "r",
&err_no);
string::LineFileReader reader;
std::vector<std::string> list;
while (reader.getline(&*pipe)) {
list.push_back(reader.get());
}
return list;
}
std::string localfs_tail(const std::string& path) {
if (path == "") {
return "";
}
return shell_get_command_output(
string::format_string("tail -1 %s ", path.c_str()));
}
bool localfs_exists(const std::string& path) {
std::string test_f = shell_get_command_output(
string::format_string("[ -f %s ] ; echo $?", path.c_str()));
if (string::trim_spaces(test_f) == "0") {
return true;
}
std::string test_d = shell_get_command_output(
string::format_string("[ -d %s ] ; echo $?", path.c_str()));
if (string::trim_spaces(test_d) == "0") {
return true;
}
return false;
}
void localfs_mkdir(const std::string& path) {
if (path == "") {
return;
}
shell_execute(string::format_string("mkdir -p %s", path.c_str()));
}
static size_t& hdfs_buffer_size_internal() {
static size_t x = 0;
return x;
}
size_t hdfs_buffer_size() { return hdfs_buffer_size_internal(); }
void hdfs_set_buffer_size(size_t x) { hdfs_buffer_size_internal() = x; }
static std::string& hdfs_command_internal() {
static std::string x = "hadoop fs";
return x;
}
const std::string& hdfs_command() { return hdfs_command_internal(); }
void hdfs_set_command(const std::string& x) { hdfs_command_internal() = x; }
std::shared_ptr<FILE> hdfs_open_read(std::string path, int* err_no,
const std::string& converter) {
if (fs_end_with_internal(path, ".gz")) {
path = string::format_string("%s -text \"%s\"", hdfs_command().c_str(),
path.c_str());
} else {
path = string::format_string("%s -cat \"%s\"", hdfs_command().c_str(),
path.c_str());
}
bool is_pipe = true;
fs_add_read_converter_internal(path, is_pipe, converter);
return fs_open_internal(path, is_pipe, "r", hdfs_buffer_size(), err_no);
}
std::shared_ptr<FILE> hdfs_open_write(std::string path, int* err_no,
const std::string& converter) {
path = string::format_string("%s -put - \"%s\"", hdfs_command().c_str(),
path.c_str());
bool is_pipe = true;
if (fs_end_with_internal(path, ".gz\"")) {
fs_add_write_converter_internal(path, is_pipe, "gzip");
}
fs_add_write_converter_internal(path, is_pipe, converter);
return fs_open_internal(path, is_pipe, "w", hdfs_buffer_size(), err_no);
}
void hdfs_remove(const std::string& path) {
if (path == "") {
return;
}
shell_execute(string::format_string("%s -rmr %s &>/dev/null; true",
hdfs_command().c_str(), path.c_str()));
}
std::vector<std::string> hdfs_list(const std::string& path) {
if (path == "") {
return {};
}
std::string prefix = "hdfs:";
if (fs_begin_with_internal(path, "afs:")) {
prefix = "afs:";
}
int err_no = 0;
std::vector<std::string> list;
do {
err_no = 0;
std::shared_ptr<FILE> pipe;
pipe = shell_popen(
string::format_string("%s -ls %s | ( grep ^- ; [ $? != 2 ] )",
hdfs_command().c_str(), path.c_str()),
"r", &err_no);
string::LineFileReader reader;
list.clear();
while (reader.getline(&*pipe)) {
std::vector<std::string> line = string::split_string(reader.get());
if (line.size() != 8) {
continue;
}
list.push_back(prefix + line[7]);
}
} while (err_no == -1);
return list;
}
std::string hdfs_tail(const std::string& path) {
if (path == "") {
return "";
}
return shell_get_command_output(string::format_string(
"%s -text %s | tail -1 ", hdfs_command().c_str(), path.c_str()));
}
bool hdfs_exists(const std::string& path) {
std::string test = shell_get_command_output(string::format_string(
"%s -test -e %s ; echo $?", hdfs_command().c_str(), path.c_str()));
if (string::trim_spaces(test) == "0") {
return true;
}
return false;
}
void hdfs_mkdir(const std::string& path) {
if (path == "") {
return;
}
shell_execute(string::format_string("%s -mkdir %s; true",
hdfs_command().c_str(), path.c_str()));
}
int fs_select_internal(const std::string& path) {
if (fs_begin_with_internal(path, "hdfs:")) {
return 1;
} else if (fs_begin_with_internal(path, "afs:")) {
return 1;
}
return 0;
}
std::shared_ptr<FILE> fs_open_read(const std::string& path, int* err_no,
const std::string& converter) {
switch (fs_select_internal(path)) {
case 0:
return localfs_open_read(path, converter);
case 1:
return hdfs_open_read(path, err_no, converter);
default:
LOG(FATAL) << "Not supported";
}
return {};
}
std::shared_ptr<FILE> fs_open_write(const std::string& path, int* err_no,
const std::string& converter) {
switch (fs_select_internal(path)) {
case 0:
return localfs_open_write(path, converter);
case 1:
return hdfs_open_write(path, err_no, converter);
default:
LOG(FATAL) << "Not supported";
}
return {};
}
std::shared_ptr<FILE> fs_open(const std::string& path, const std::string& mode,
int* err_no, const std::string& converter) {
if (mode == "r" || mode == "rb") {
return fs_open_read(path, err_no, converter);
}
if (mode == "w" || mode == "wb") {
return fs_open_write(path, err_no, converter);
}
LOG(FATAL) << "Unknown mode: " << mode;
return {};
}
int64_t fs_file_size(const std::string& path) {
switch (fs_select_internal(path)) {
case 0:
return localfs_file_size(path);
default:
LOG(FATAL) << "Not supported";
}
return 0;
}
void fs_remove(const std::string& path) {
switch (fs_select_internal(path)) {
case 0:
return localfs_remove(path);
case 1:
return hdfs_remove(path);
default:
LOG(FATAL) << "Not supported";
}
}
std::vector<std::string> fs_list(const std::string& path) {
switch (fs_select_internal(path)) {
case 0:
return localfs_list(path);
case 1:
return hdfs_list(path);
default:
LOG(FATAL) << "Not supported";
}
return {};
}
std::string fs_tail(const std::string& path) {
switch (fs_select_internal(path)) {
case 0:
return localfs_tail(path);
case 1:
return hdfs_tail(path);
default:
LOG(FATAL) << "Not supported";
}
return "";
}
bool fs_exists(const std::string& path) {
switch (fs_select_internal(path)) {
case 0:
return localfs_exists(path);
case 1:
return hdfs_exists(path);
default:
LOG(FATAL) << "Not supported";
}
return false;
}
void fs_mkdir(const std::string& path) {
switch (fs_select_internal(path)) {
case 0:
return localfs_mkdir(path);
case 1:
return hdfs_mkdir(path);
default:
LOG(FATAL) << "Not supported";
}
}
} // end namespace framework
} // end namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <stdio.h>
#include <memory>
#include <string>
#include <vector>
#include "glog/logging.h"
#include "paddle/fluid/framework/io/shell.h"
#include "paddle/fluid/string/string_helper.h"
namespace paddle {
namespace framework {
int fs_select_internal(const std::string& path);
// localfs
extern size_t localfs_buffer_size();
extern void localfs_set_buffer_size(size_t x);
extern std::shared_ptr<FILE> localfs_open_read(std::string path,
const std::string& converter);
extern std::shared_ptr<FILE> localfs_open_write(std::string path,
const std::string& converter);
extern int64_t localfs_file_size(const std::string& path);
extern void localfs_remove(const std::string& path);
extern std::vector<std::string> localfs_list(const std::string& path);
extern std::string localfs_tail(const std::string& path);
extern bool localfs_exists(const std::string& path);
extern void localfs_mkdir(const std::string& path);
// hdfs
extern size_t hdfs_buffer_size();
extern void hdfs_set_buffer_size(size_t x);
extern const std::string& hdfs_command();
extern void hdfs_set_command(const std::string& x);
extern std::shared_ptr<FILE> hdfs_open_read(std::string path, int* err_no,
const std::string& converter);
extern std::shared_ptr<FILE> hdfs_open_write(std::string path, int* err_no,
const std::string& converter);
extern void hdfs_remove(const std::string& path);
extern std::vector<std::string> hdfs_list(const std::string& path);
extern std::string hdfs_tail(const std::string& path);
extern bool hdfs_exists(const std::string& path);
extern void hdfs_mkdir(const std::string& path);
// aut-detect fs
extern std::shared_ptr<FILE> fs_open_read(const std::string& path, int* err_no,
const std::string& converter);
extern std::shared_ptr<FILE> fs_open_write(const std::string& path, int* err_no,
const std::string& converter);
extern std::shared_ptr<FILE> fs_open(const std::string& path,
const std::string& mode, int* err_no,
const std::string& converter = "");
extern int64_t fs_file_size(const std::string& path);
extern void fs_remove(const std::string& path);
extern std::vector<std::string> fs_list(const std::string& path);
extern std::string fs_tail(const std::string& path);
extern bool fs_exists(const std::string& path);
extern void fs_mkdir(const std::string& path);
} // namespace framework
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/io/shell.h"
namespace paddle {
namespace framework {
std::shared_ptr<FILE> shell_fopen(const std::string& path,
const std::string& mode) {
#if defined _WIN32 || defined __APPLE__
return nullptr;
#else
if (shell_verbose()) {
LOG(INFO) << "Opening file[" << path << "] with mode[" << mode << "]";
}
FILE* fp;
if (!(fp = fopen(path.c_str(), mode.c_str()))) {
LOG(FATAL) << "fopen fail, path[" << path << "], mode[" << mode << "]";
}
return {fp, [path](FILE* fp) {
if (shell_verbose()) {
LOG(INFO) << "Closing file[" << path << "]";
}
if (0 != fclose(fp)) {
LOG(FATAL) << "fclose fail, path[" << path << "]";
}
}};
#endif
}
// Close all open file descriptors
// The implementation is async signal safe
// Mostly copy from CPython code
static int close_open_fds_internal() {
#if defined _WIN32 || defined __APPLE__
return 0;
#else
struct linux_dirent {
long d_ino = 0; // NOLINT
off_t d_off;
unsigned short d_reclen = 0; // NOLINT
char d_name[256];
};
int dir_fd = -1;
if ((dir_fd = open("/proc/self/fd", O_RDONLY)) < 0) {
LOG(FATAL) << "proc/self/fd open fail";
return -1;
}
char buffer[sizeof(linux_dirent)];
for (;;) {
int bytes = 0;
if ((bytes = syscall(SYS_getdents, dir_fd,
reinterpret_cast<linux_dirent*>(buffer),
sizeof(buffer))) < 0) {
LOG(FATAL) << "syscall fail";
return -1;
}
if (bytes == 0) {
break;
}
linux_dirent* entry = NULL;
for (int offset = 0; offset < bytes; offset += entry->d_reclen) {
entry = reinterpret_cast<linux_dirent*>(buffer + offset);
int fd = 0;
const char* s = entry->d_name;
while (*s >= '0' && *s <= '9') {
fd = fd * 10 + (*s - '0');
s++;
}
if (s != entry->d_name && fd != dir_fd && fd >= 3) {
close(fd);
}
}
}
close(dir_fd);
return 0;
#endif
}
static int shell_popen_fork_internal(const char* real_cmd, bool do_read,
int parent_end, int child_end) {
#if defined _WIN32 || defined __APPLE__
return 0;
#else
int child_pid = -1;
// Too frequent calls to fork() makes openmpi very slow. Use vfork() instead.
// But vfork() is very dangerous. Be careful.
if ((child_pid = vfork()) < 0) {
return -1;
}
// The following code is async signal safe (No memory allocation, no access to
// global data, etc.)
if (child_pid != 0) {
return child_pid;
}
int child_std_end = do_read ? 1 : 0;
close(parent_end);
if (child_end != child_std_end) {
if (dup2(child_end, child_std_end) != child_std_end) {
return -1;
}
close(child_end);
}
close_open_fds_internal();
if (execl("/bin/sh", "sh", "-c", real_cmd, NULL) < 0) {
return -1;
}
exit(127);
#endif
}
std::shared_ptr<FILE> shell_popen(const std::string& cmd,
const std::string& mode, int* err_no) {
#if defined _WIN32 || defined __APPLE__
return nullptr;
#else
bool do_read = mode == "r";
bool do_write = mode == "w";
if (!(do_read || do_write)) {
*err_no = -1;
return NULL;
}
if (shell_verbose()) {
LOG(INFO) << "Opening pipe[" << cmd << "] with mode[" << mode << "]";
}
std::string real_cmd = "set -o pipefail; " + cmd;
int pipe_fds[2];
if (pipe(pipe_fds) != 0) {
*err_no = -1;
return NULL;
}
int parent_end = 0;
int child_end = 0;
if (do_read) {
parent_end = pipe_fds[0];
child_end = pipe_fds[1];
} else if (do_write) {
parent_end = pipe_fds[1];
child_end = pipe_fds[0];
}
int child_pid = shell_popen_fork_internal(real_cmd.c_str(), do_read,
parent_end, child_end);
close(child_end);
fcntl(parent_end, F_SETFD, FD_CLOEXEC);
FILE* fp;
if ((fp = fdopen(parent_end, mode.c_str())) == NULL) {
*err_no = -1;
return NULL;
}
return {fp, [child_pid, cmd, err_no](FILE* fp) {
if (shell_verbose()) {
LOG(INFO) << "Closing pipe[" << cmd << "]";
}
if (fclose(fp) != 0) {
*err_no = -1;
}
int wstatus = -1;
waitpid(child_pid, &wstatus, 0);
if (wstatus == 0 || wstatus == (128 + SIGPIPE) * 256 ||
(wstatus == -1 && errno == ECHILD)) {
} else {
*err_no = -1;
LOG(WARNING) << "status[" << wstatus << "], cmd[" << cmd << "]"
<< ", err_no[" << *err_no << "]";
}
if (wstatus == -1 && errno == ECHILD) {
LOG(WARNING) << "errno is ECHILD";
}
}};
#endif
}
static int shell_p2open_fork_internal(const char* real_cmd, int pipein_fds[2],
int pipeout_fds[2]) {
#if defined _WIN32 || defined __APPLE__
return 0;
#else
int child_pid = -1;
if ((child_pid = fork()) < 0) {
return -1;
}
if (child_pid != 0) {
return child_pid;
}
close(pipein_fds[0]);
close(pipeout_fds[1]);
if (pipein_fds[1] != 1) {
if (dup2(pipein_fds[1], 1) != 1) {
return -1;
}
close(pipein_fds[1]);
}
if (pipeout_fds[0] != 0) {
if (dup2(pipeout_fds[0], 0) != 0) {
return -1;
}
close(pipeout_fds[0]);
}
close_open_fds_internal();
if (execl("/bin/sh", "sh", "-c", real_cmd, NULL) < 0) {
return -1;
}
exit(127);
#endif
}
std::pair<std::shared_ptr<FILE>, std::shared_ptr<FILE>> shell_p2open(
const std::string& cmd) {
#if defined _WIN32 || defined __APPLE__
return {};
#else
if (shell_verbose()) {
LOG(INFO) << "Opening bidirectional pipe[" << cmd << "]";
}
std::string real_cmd = "set -o pipefail; " + cmd;
int pipein_fds[2];
int pipeout_fds[2];
if (pipe(pipein_fds) != 0) {
return {NULL, NULL};
}
if (pipe(pipeout_fds) != 0) {
return {NULL, NULL};
}
int child_pid =
shell_p2open_fork_internal(real_cmd.c_str(), pipein_fds, pipeout_fds);
close(pipein_fds[1]);
close(pipeout_fds[0]);
fcntl(pipein_fds[0], F_SETFD, FD_CLOEXEC);
fcntl(pipeout_fds[1], F_SETFD, FD_CLOEXEC);
std::shared_ptr<int> child_life = {
NULL, [child_pid, cmd](void*) {
if (shell_verbose()) {
LOG(INFO) << "Closing bidirectional pipe[" << cmd << "]";
}
int wstatus, ret;
do {
PCHECK((ret = waitpid(child_pid, &wstatus, 0)) >= 0 ||
(ret == -1 && errno == EINTR));
} while (ret == -1 && errno == EINTR);
PCHECK(wstatus == 0 || wstatus == (128 + SIGPIPE) * 256 ||
(wstatus == -1 && errno == ECHILD))
<< "status[" << wstatus << "], cmd[" << cmd << "]";
if (wstatus == -1 && errno == ECHILD) {
LOG(WARNING) << "errno is ECHILD";
}
}};
FILE* in_fp;
PCHECK((in_fp = fdopen(pipein_fds[0], "r")) != NULL);
FILE* out_fp;
PCHECK((out_fp = fdopen(pipeout_fds[1], "w")) != NULL);
return {{in_fp, [child_life](FILE* fp) { PCHECK(fclose(fp) == 0); }},
{out_fp, [child_life](FILE* fp) { PCHECK(fclose(fp) == 0); }}};
#endif
}
std::string shell_get_command_output(const std::string& cmd) {
#if defined _WIN32 || defined __APPLE__
return "";
#else
int err_no = 0;
do {
err_no = 0;
std::shared_ptr<FILE> pipe = shell_popen(cmd, "r", &err_no);
string::LineFileReader reader;
if (reader.getdelim(&*pipe, 0)) {
pipe = nullptr;
if (err_no == 0) {
return reader.get();
}
}
} while (err_no == -1);
return "";
#endif
}
} // end namespace framework
} // end namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <fcntl.h>
#include <sys/stat.h>
#ifdef _WIN32
#include <windows.h>
#else
#include <sys/syscall.h>
#endif
#include <sys/types.h>
#ifndef _WIN32
#include <sys/wait.h>
#endif
#include <memory>
#include <string>
#include <utility>
#include "paddle/fluid/platform/port.h"
#include "paddle/fluid/string/string_helper.h"
namespace paddle {
namespace framework {
inline bool& shell_verbose_internal() {
static bool x = false;
return x;
}
inline bool shell_verbose() { return shell_verbose_internal(); }
inline void shell_set_verbose(bool x) { shell_verbose_internal() = x; }
extern std::shared_ptr<FILE> shell_fopen(const std::string& path,
const std::string& mode);
extern std::shared_ptr<FILE> shell_popen(const std::string& cmd,
const std::string& mode, int* err_no);
extern std::pair<std::shared_ptr<FILE>, std::shared_ptr<FILE>> shell_p2open(
const std::string& cmd);
inline void shell_execute(const std::string& cmd) {
int err_no = 0;
do {
err_no = 0;
shell_popen(cmd, "w", &err_no);
} while (err_no == -1);
}
extern std::string shell_get_command_output(const std::string& cmd);
} // namespace framework
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include <vector>
#include "paddle/fluid/framework/data_feed_factory.h"
#include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/framework/trainer.h"
namespace paddle {
namespace framework {
void MultiTrainer::Initialize(const TrainerDesc& trainer_desc,
Dataset* dataset) {
thread_num_ = trainer_desc.thread_num();
SetDataset(dataset);
// get filelist from trainer_desc here
dataset->CreateReaders();
VLOG(3) << "readers created";
const std::vector<std::shared_ptr<paddle::framework::DataFeed>> readers =
dataset->GetReaders();
VLOG(3) << "readers num: " << readers.size();
// change thread num to readers num
thread_num_ = readers.size();
VLOG(3) << "worker thread num: " << thread_num_;
workers_.resize(thread_num_);
for (int i = 0; i < thread_num_; ++i) {
workers_[i] = DeviceWorkerFactory::CreateDeviceWorker(
trainer_desc.device_worker_name());
workers_[i]->Initialize(trainer_desc);
workers_[i]->SetDeviceIndex(i);
workers_[i]->SetDataFeed(readers[i]);
}
// set debug here
SetDebug(trainer_desc.debug());
}
// call only after all resources are set in current trainer
void MultiTrainer::InitTrainerEnv(const ProgramDesc& main_program,
const platform::Place& place) {
for (int i = 0; i < thread_num_; ++i) {
workers_[i]->SetPlace(place);
workers_[i]->SetRootScope(root_scope_);
workers_[i]->CreateDeviceResource(main_program); // Program
workers_[i]->BindingDataFeedMemory();
}
}
void MultiTrainer::Run() {
VLOG(3) << "Going to run";
for (int thidx = 0; thidx < thread_num_; ++thidx) {
if (!debug_) {
threads_.push_back(
std::thread(&DeviceWorker::TrainFiles, workers_[thidx].get()));
} else {
threads_.push_back(std::thread(&DeviceWorker::TrainFilesWithProfiler,
workers_[thidx].get()));
}
}
}
void MultiTrainer::Finalize() {
for (auto& th : threads_) {
th.join();
}
dataset_ptr_->DestroyReaders();
root_scope_->DropKids();
}
} // end namespace framework
} // end namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <time.h>
#include "paddle/fluid/framework/device_worker.h"
namespace paddle {
namespace framework {
std::shared_ptr<PullDenseWorker> PullDenseWorker::s_instance_ = NULL;
std::mutex PullDenseWorker::mutex_for_version_;
std::map<uint64_t, uint64_t> PullDenseWorker::last_versions_;
std::map<uint64_t, uint64_t> PullDenseWorker::current_version_;
std::map<uint64_t, std::vector<uint64_t>> PullDenseWorker::training_versions_;
std::map<uint64_t, std::vector<std::string>>
PullDenseWorker::dense_value_names_;
void PullDenseWorker::Initialize(const TrainerDesc& param) {
running_ = false;
param_ = param.pull_dense_param();
dwp_param_ = param.downpour_param();
threshold_ = param_.threshold();
thread_num_ = param_.device_num();
sleep_time_ms_ = param_.sleep_time_ms();
for (size_t i = 0;
i < dwp_param_.program_config(0).pull_dense_table_id_size(); ++i) {
uint64_t tid = static_cast<uint64_t>(
dwp_param_.program_config(0).pull_dense_table_id(i));
TableParameter table;
for (auto i : param_.dense_table()) {
if (i.table_id() == tid) {
table = i;
break;
}
}
// setup dense variables for each table
int var_num = table.dense_value_name_size();
dense_value_names_[tid].resize(var_num);
for (int j = 0; j < var_num; ++j) {
dense_value_names_[tid][j] = table.dense_value_name(j);
}
// setup training version for each table
training_versions_[tid].resize(thread_num_, 0);
last_versions_[tid] = 0;
current_version_[tid] = 0;
}
fleet_ptr_ = FleetWrapper::GetInstance();
}
void PullDenseWorker::Wait(std::vector<::std::future<int32_t>>* status_vec) {
for (auto& t : *status_vec) {
t.wait();
auto status = t.get();
if (status != 0) {
LOG(WARNING) << "Current Pull Dense Thread Failed Times"
<< ++pull_dense_fail_times_;
}
}
int MAX_FAIL_NUM = 20;
if (pull_dense_fail_times_ > MAX_FAIL_NUM) {
LOG(FATAL) << "Pull Dense Failed Times More Than " << MAX_FAIL_NUM
<< " Times";
exit(-1);
}
status_vec->resize(0);
}
void PullDenseWorker::Stop() {
if (running_) {
running_ = false;
t_.join();
}
}
int PullDenseWorker::Start() {
running_ = true;
t_ = std::thread(&PullDenseWorker::Run, this);
return 0;
}
void PullDenseWorker::Run() {
while (running_) {
pull_dense_status_.resize(0);
for (size_t i = 0;
i < dwp_param_.program_config(0).pull_dense_table_id_size(); ++i) {
uint64_t tid = static_cast<uint64_t>(
dwp_param_.program_config(0).pull_dense_table_id(i));
if (CheckUpdateParam(tid)) {
fleet_ptr_->PullDenseVarsAsync(
*root_scope_, tid, dense_value_names_[tid], &pull_dense_status_);
ResetThreadVersion(tid);
}
}
if (pull_dense_status_.size() != 0) {
Wait(&pull_dense_status_);
}
#ifndef _WIN32
usleep(sleep_time_ms_ * 1000);
#endif
}
}
void PullDenseWorker::IncreaseThreadVersion(int thread_id, uint64_t table_id) {
std::lock_guard<std::mutex> lock(mutex_for_version_);
training_versions_[table_id][thread_id]++;
}
bool PullDenseWorker::CheckUpdateParam(uint64_t table_id) {
std::lock_guard<std::mutex> lock(mutex_for_version_);
auto& version = training_versions_[table_id];
current_version_[table_id] =
*(std::min_element(version.begin(), version.end()));
if (current_version_[table_id] - last_versions_[table_id] < threshold_) {
return false;
}
return true;
}
void PullDenseWorker::ResetThreadVersion(uint64_t table_id) {
std::lock_guard<std::mutex> lock(mutex_for_version_);
last_versions_[table_id] = current_version_[table_id];
}
} // namespace framework
} // namespace paddle
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
......@@ -27,6 +27,7 @@ limitations under the License. */
namespace paddle {
namespace framework {
void InitializeVariable(Variable* var, proto::VarType::Type var_type) {
if (var_type == proto::VarType::LOD_TENSOR) {
var->GetMutable<LoDTensor>();
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册