“7e3f5a00b456f444a4b3c52c0de5f3f6b21ad224”上不存在“git@gitcode.net:paddlepaddle/Serving.git”
未验证 提交 76b49f02 编写于 作者: G guru4elephant 提交者: GitHub

Merge pull request #16539 from guru4elephant/train_with_pipe_reader_merge_develop

Train with pipe reader merge develop
...@@ -15,7 +15,9 @@ paddle.fluid.cpu_places (ArgSpec(args=['device_count'], varargs=None, keywords=N ...@@ -15,7 +15,9 @@ paddle.fluid.cpu_places (ArgSpec(args=['device_count'], varargs=None, keywords=N
paddle.fluid.cuda_pinned_places (ArgSpec(args=['device_count'], varargs=None, keywords=None, defaults=(None,)), ('document', 'd0c3ebd813c39958c92b78e3eef7e912')) paddle.fluid.cuda_pinned_places (ArgSpec(args=['device_count'], varargs=None, keywords=None, defaults=(None,)), ('document', 'd0c3ebd813c39958c92b78e3eef7e912'))
paddle.fluid.Executor.__init__ (ArgSpec(args=['self', 'place'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.Executor.__init__ (ArgSpec(args=['self', 'place'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.Executor.close (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', 'f5369953dd0c443961cf79f7a00e1a03')) paddle.fluid.Executor.close (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', 'f5369953dd0c443961cf79f7a00e1a03'))
paddle.fluid.Executor.infer_from_dataset (ArgSpec(args=['self', 'program', 'dataset', 'scope', 'thread', 'debug', 'fetch_list', 'fetch_info', 'print_period'], varargs=None, keywords=None, defaults=(None, None, None, 0, False, None, None, 100)), ('document', '9c7decb955b9c4f718114179c8985581'))
paddle.fluid.Executor.run (ArgSpec(args=['self', 'program', 'feed', 'fetch_list', 'feed_var_name', 'fetch_var_name', 'scope', 'return_numpy', 'use_program_cache'], varargs=None, keywords=None, defaults=(None, None, None, 'feed', 'fetch', None, True, False)), ('document', 'f482e93b38b4018796969a2e1dde479d')) paddle.fluid.Executor.run (ArgSpec(args=['self', 'program', 'feed', 'fetch_list', 'feed_var_name', 'fetch_var_name', 'scope', 'return_numpy', 'use_program_cache'], varargs=None, keywords=None, defaults=(None, None, None, 'feed', 'fetch', None, True, False)), ('document', 'f482e93b38b4018796969a2e1dde479d'))
paddle.fluid.Executor.train_from_dataset (ArgSpec(args=['self', 'program', 'dataset', 'scope', 'thread', 'debug', 'fetch_list', 'fetch_info', 'print_period'], varargs=None, keywords=None, defaults=(None, None, None, 0, False, None, None, 100)), ('document', 'd521011d79e71080fe9b5bb179b43518'))
paddle.fluid.global_scope (ArgSpec(args=[], varargs=None, keywords=None, defaults=None), ('document', 'e148d3ab1ed8edf3e928212a375959c0')) paddle.fluid.global_scope (ArgSpec(args=[], varargs=None, keywords=None, defaults=None), ('document', 'e148d3ab1ed8edf3e928212a375959c0'))
paddle.fluid.scope_guard (ArgSpec(args=['scope'], varargs=None, keywords=None, defaults=None), ('document', 'b94d1f6bcc29c4fb58fc0058561250c2')) paddle.fluid.scope_guard (ArgSpec(args=['scope'], varargs=None, keywords=None, defaults=None), ('document', 'b94d1f6bcc29c4fb58fc0058561250c2'))
paddle.fluid.DistributeTranspiler.__init__ (ArgSpec(args=['self', 'config'], varargs=None, keywords=None, defaults=(None,)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.DistributeTranspiler.__init__ (ArgSpec(args=['self', 'config'], varargs=None, keywords=None, defaults=(None,)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
...@@ -36,15 +38,15 @@ paddle.fluid.DataFeedDesc.desc (ArgSpec(args=['self'], varargs=None, keywords=No ...@@ -36,15 +38,15 @@ paddle.fluid.DataFeedDesc.desc (ArgSpec(args=['self'], varargs=None, keywords=No
paddle.fluid.DataFeedDesc.set_batch_size (ArgSpec(args=['self', 'batch_size'], varargs=None, keywords=None, defaults=None), ('document', '8d9f44601e0a99dd431f14fd9250cd21')) paddle.fluid.DataFeedDesc.set_batch_size (ArgSpec(args=['self', 'batch_size'], varargs=None, keywords=None, defaults=None), ('document', '8d9f44601e0a99dd431f14fd9250cd21'))
paddle.fluid.DataFeedDesc.set_dense_slots (ArgSpec(args=['self', 'dense_slots_name'], varargs=None, keywords=None, defaults=None), ('document', 'eb894b464bbcd1b4bc8038398954f766')) paddle.fluid.DataFeedDesc.set_dense_slots (ArgSpec(args=['self', 'dense_slots_name'], varargs=None, keywords=None, defaults=None), ('document', 'eb894b464bbcd1b4bc8038398954f766'))
paddle.fluid.DataFeedDesc.set_use_slots (ArgSpec(args=['self', 'use_slots_name'], varargs=None, keywords=None, defaults=None), ('document', '415c56600ce4e198c071cad01409a690')) paddle.fluid.DataFeedDesc.set_use_slots (ArgSpec(args=['self', 'use_slots_name'], varargs=None, keywords=None, defaults=None), ('document', '415c56600ce4e198c071cad01409a690'))
paddle.fluid.AsyncExecutor.__init__ (ArgSpec(args=['self', 'place', 'run_mode'], varargs=None, keywords=None, defaults=(None, '')), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.AsyncExecutor.__init__ (ArgSpec(args=['self', 'place', 'run_mode'], varargs=None, keywords=None, defaults=(None, '')), ('document', '4e85874dddcd06c38f5717992d741589'))
paddle.fluid.AsyncExecutor.config_distributed_nodes (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '4810dbe1870452f16b3c60b6c5fd1459')) paddle.fluid.AsyncExecutor.config_distributed_nodes (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '762980fe0181eb41e3d1081b26ed76b1'))
paddle.fluid.AsyncExecutor.download_data (ArgSpec(args=['self', 'afs_path', 'local_path', 'fs_default_name', 'ugi', 'file_cnt', 'hadoop_home', 'process_num'], varargs=None, keywords=None, defaults=('$HADOOP_HOME', 12)), ('document', '799a2066cc26819f1ed31f47c15ad083')) paddle.fluid.AsyncExecutor.download_data (ArgSpec(args=['self', 'afs_path', 'local_path', 'fs_default_name', 'ugi', 'file_cnt', 'hadoop_home', 'process_num'], varargs=None, keywords=None, defaults=('$HADOOP_HOME', 12)), ('document', '39e3ccddf8ea8db75ea85287c9147c3b'))
paddle.fluid.AsyncExecutor.get_instance (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', 'f8688f76a2db1243c7097a60c507b182')) paddle.fluid.AsyncExecutor.get_instance (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', 'f8688f76a2db1243c7097a60c507b182'))
paddle.fluid.AsyncExecutor.init_model (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '504f39be2007404a17e5cabea1256c7d')) paddle.fluid.AsyncExecutor.init_model (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '504f39be2007404a17e5cabea1256c7d'))
paddle.fluid.AsyncExecutor.init_server (ArgSpec(args=['self', 'dist_desc'], varargs=None, keywords=None, defaults=None), ('document', 'c403ab46c5d3ef25c0f7e94ae75dcb68')) paddle.fluid.AsyncExecutor.init_server (ArgSpec(args=['self', 'dist_desc'], varargs=None, keywords=None, defaults=None), ('document', '384fa5fbb99912db1baf7ef7784bd312'))
paddle.fluid.AsyncExecutor.init_worker (ArgSpec(args=['self', 'dist_desc', 'startup_program'], varargs=None, keywords=None, defaults=None), ('document', 'dcf08f4bf2f3282acf11391f5d39c536')) paddle.fluid.AsyncExecutor.init_worker (ArgSpec(args=['self', 'dist_desc', 'startup_program'], varargs=None, keywords=None, defaults=None), ('document', 'f0a36d7c8561039f60a6f6555c7fee0b'))
paddle.fluid.AsyncExecutor.run (ArgSpec(args=['self', 'program', 'data_feed', 'filelist', 'thread_num', 'fetch', 'mode', 'debug'], varargs=None, keywords=None, defaults=('', False)), ('document', '848fc53484e8326f6325feea87fe955c')) paddle.fluid.AsyncExecutor.run (ArgSpec(args=['self', 'program', 'data_feed', 'filelist', 'thread_num', 'fetch', 'mode', 'debug'], varargs=None, keywords=None, defaults=('', False)), ('document', '848fc53484e8326f6325feea87fe955c'))
paddle.fluid.AsyncExecutor.save_model (ArgSpec(args=['self', 'save_path'], varargs=None, keywords=None, defaults=None), ('document', 'c8ac0dfcb3b187aba25d03af7fea56b2')) paddle.fluid.AsyncExecutor.save_model (ArgSpec(args=['self', 'save_path'], varargs=None, keywords=None, defaults=None), ('document', '145b5c0da01bfff397142e51361f4b75'))
paddle.fluid.AsyncExecutor.stop (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '5f23d043607bb5d55e466ec3f578e093')) paddle.fluid.AsyncExecutor.stop (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '5f23d043607bb5d55e466ec3f578e093'))
paddle.fluid.CompiledProgram.__init__ (ArgSpec(args=['self', 'program_or_graph'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.CompiledProgram.__init__ (ArgSpec(args=['self', 'program_or_graph'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.CompiledProgram.with_data_parallel (ArgSpec(args=['self', 'loss_name', 'build_strategy', 'exec_strategy', 'share_vars_from', 'places'], varargs=None, keywords=None, defaults=(None, None, None, None, None)), ('document', 'a8c7793803cf976680d9478e378fa356')) paddle.fluid.CompiledProgram.with_data_parallel (ArgSpec(args=['self', 'loss_name', 'build_strategy', 'exec_strategy', 'share_vars_from', 'places'], varargs=None, keywords=None, defaults=(None, None, None, None, None)), ('document', 'a8c7793803cf976680d9478e378fa356'))
......
#windows treat symbolic file as a real file, which is different with unix #windows treat symbolic file as a real file, which is different with unix
#We create a hidden file and compile it instead of origin source file. #We create a hidden file and compile it instead of origin source file.
function(windows_symbolic TARGET) function(windows_symbolic TARGET)
...@@ -22,9 +23,13 @@ endfunction() ...@@ -22,9 +23,13 @@ endfunction()
add_subdirectory(ir) add_subdirectory(ir)
add_subdirectory(details) add_subdirectory(details)
add_subdirectory(fleet)
add_subdirectory(io)
#ddim lib #ddim lib
proto_library(framework_proto SRCS framework.proto) proto_library(framework_proto SRCS framework.proto)
proto_library(data_feed_proto SRCS data_feed.proto)
proto_library(async_executor_proto SRCS data_feed.proto) proto_library(async_executor_proto SRCS data_feed.proto)
proto_library(trainer_desc_proto SRCS trainer_desc.proto data_feed.proto)
cc_library(ddim SRCS ddim.cc DEPS eigen3 boost enforce) cc_library(ddim SRCS ddim.cc DEPS eigen3 boost enforce)
cc_test(ddim_test SRCS ddim_test.cc DEPS ddim) cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
...@@ -129,9 +134,11 @@ cc_test(version_test SRCS version_test.cc DEPS version) ...@@ -129,9 +134,11 @@ cc_test(version_test SRCS version_test.cc DEPS version)
cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog version) cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog version)
cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc memory_optimize_helper) cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc memory_optimize_helper)
nv_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry) nv_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
py_proto_compile(framework_py_proto SRCS framework.proto data_feed.proto) py_proto_compile(framework_py_proto SRCS framework.proto data_feed.proto)
py_proto_compile(trainer_py_proto SRCS trainer_desc.proto data_feed.proto)
#Generate an empty \ #Generate an empty \
#__init__.py to make framework_py_proto as a valid python module. #__init__.py to make framework_py_proto as a valid python module.
add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py) add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py)
...@@ -165,14 +172,24 @@ else() ...@@ -165,14 +172,24 @@ else()
endif() endif()
cc_library(executor_gc_helper SRCS executor_gc_helper.cc DEPS scope proto_desc operator garbage_collector) cc_library(executor_gc_helper SRCS executor_gc_helper.cc DEPS scope proto_desc operator garbage_collector)
if(WITH_DISTRIBUTE) if(WITH_DISTRIBUTE)
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog cc_library(executor SRCS executor.cc multi_trainer.cc dataset_factory.cc
lod_rank_table feed_fetch_method sendrecvop_rpc ${GLOB_DISTRIBUTE_DEPS} graph_to_program_pass variable_helper ${NGRAPH_EXE_DEPS}) dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") data_feed.cc device_worker.cc hogwild_worker.cc downpour_worker.cc
set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) pull_dense_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
device_context scope framework_proto trainer_desc_proto glog fs shell fleet_wrapper lodtensor_printer
lod_rank_table feed_fetch_method sendrecvop_rpc ${GLOB_DISTRIBUTE_DEPS}
graph_to_program_pass variable_helper data_feed_proto ${NGRAPH_EXE_DEPS} timer)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
else() else()
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass variable_helper ${NGRAPH_EXE_DEPS}) cc_library(executor SRCS executor.cc multi_trainer.cc dataset_factory.cc
dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc
data_feed.cc device_worker.cc hogwild_worker.cc downpour_worker.cc
pull_dense_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
device_context scope framework_proto data_feed_proto trainer_desc_proto glog
lod_rank_table fs shell fleet_wrapper lodtensor_printer feed_fetch_method
graph_to_program_pass variable_helper ${NGRAPH_EXE_DEPS} timer data_feed_proto)
cc_test(test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op) cc_test(test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op)
endif() endif()
...@@ -183,11 +200,15 @@ cc_library(parallel_executor SRCS parallel_executor.cc DEPS ...@@ -183,11 +200,15 @@ cc_library(parallel_executor SRCS parallel_executor.cc DEPS
graph build_strategy graph build_strategy
fast_threaded_ssa_graph_executor variable_helper) fast_threaded_ssa_graph_executor variable_helper)
if(WITH_PSLIB) cc_library(async_executor SRCS async_executor.cc data_feed.cc data_feed_factory.cc
cc_library(async_executor SRCS async_executor.cc data_feed.cc data_feed_factory.cc executor_thread_worker.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass async_executor_proto variable_helper pslib_brpc pslib timer) executor_thread_worker.cc multi_trainer.cc dist_multi_trainer.cc
else() trainer_factory.cc trainer.cc device_worker.cc hogwild_worker.cc
cc_library(async_executor SRCS async_executor.cc data_feed.cc data_feed_factory.cc executor_thread_worker.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass async_executor_proto variable_helper timer) downpour_worker.cc pull_dense_worker.cc device_worker_factory.cc
endif(WITH_PSLIB) data_set.cc dataset_factory.cc
DEPS op_registry device_context scope framework_proto
trainer_desc_proto glog lod_rank_table fleet_wrapper lodtensor_printer
feed_fetch_method graph_to_program_pass data_feed_proto
variable_helper timer fs shell)
cc_test(data_feed_test SRCS data_feed_test.cc DEPS async_executor) cc_test(data_feed_test SRCS data_feed_test.cc DEPS async_executor)
...@@ -214,18 +235,18 @@ cc_test(dlpack_tensor_test SRCS dlpack_tensor_test.cc DEPS dlpack_tensor glog) ...@@ -214,18 +235,18 @@ cc_test(dlpack_tensor_test SRCS dlpack_tensor_test.cc DEPS dlpack_tensor glog)
# Get the current working branch # Get the current working branch
execute_process( execute_process(
COMMAND git rev-parse --abbrev-ref HEAD COMMAND git rev-parse --abbrev-ref HEAD
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR} WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}
OUTPUT_VARIABLE PADDLE_BRANCH OUTPUT_VARIABLE PADDLE_BRANCH
OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_STRIP_TRAILING_WHITESPACE
) )
# Get the latest abbreviated commit hash of the working branch # Get the latest abbreviated commit hash of the working branch
execute_process( execute_process(
COMMAND git log -1 --format=%h COMMAND git log -1 --format=%h
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR} WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}
OUTPUT_VARIABLE PADDLE_COMMIT OUTPUT_VARIABLE PADDLE_COMMIT
OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_STRIP_TRAILING_WHITESPACE
) )
message(STATUS "commit: ${PADDLE_COMMIT}") message(STATUS "commit: ${PADDLE_COMMIT}")
message(STATUS "branch: ${PADDLE_BRANCH}") message(STATUS "branch: ${PADDLE_BRANCH}")
......
...@@ -26,212 +26,44 @@ limitations under the License. */ ...@@ -26,212 +26,44 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor_array.h" #include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h" #include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/trainer.h"
#include "paddle/fluid/framework/trainer_desc.pb.h"
#include "paddle/fluid/framework/trainer_factory.h"
#include "paddle/fluid/inference/io.h" #include "paddle/fluid/inference/io.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "paddle/fluid/pybind/pybind.h" #include "paddle/fluid/pybind/pybind.h"
#ifdef PADDLE_WITH_PSLIB
#include <pslib.h>
#endif
namespace paddle { namespace paddle {
namespace framework { namespace framework {
AsyncExecutor::AsyncExecutor(Scope* scope, const platform::Place& place) AsyncExecutor::AsyncExecutor(Scope* scope, const platform::Place& place)
: root_scope_(scope), place_(place) {} : root_scope_(scope), place_(place) {}
void AsyncExecutor::CreateThreads(
ExecutorThreadWorker* worker, const ProgramDesc& main_program,
const std::shared_ptr<DataFeed>& reader,
const std::vector<std::string>& fetch_var_names, Scope* root_scope,
const int thread_index, const bool debug) {
worker->SetThreadId(thread_index);
worker->SetDebug(debug);
worker->SetRootScope(root_scope);
worker->CreateThreadResource(main_program, place_);
worker->SetDataFeed(reader);
worker->SetFetchVarNames(fetch_var_names);
worker->BindingDataFeedMemory();
#ifdef PADDLE_WITH_PSLIB
worker->SetPSlibPtr(_pslib_ptr);
worker->SetPullDenseThread(_pull_dense_thread);
worker->SetParamConfig(&_param_config);
#endif
}
void PrepareReaders(std::vector<std::shared_ptr<DataFeed>>& readers, // NOLINT
const int thread_num, const DataFeedDesc& data_feed_desc,
const std::vector<std::string>& filelist) {
readers.resize(thread_num);
for (size_t i = 0; i < readers.size(); ++i) {
readers[i] = DataFeedFactory::CreateDataFeed(data_feed_desc.name());
readers[i]->Init(data_feed_desc); // set batch_size and queue_size here
}
readers[0]->SetFileList(filelist);
}
#ifdef PADDLE_WITH_PSLIB
void AsyncExecutor::InitServer(const std::string& dist_desc, int index) { void AsyncExecutor::InitServer(const std::string& dist_desc, int index) {
_pslib_ptr = std::shared_ptr<paddle::distributed::PSlib>( fleet_ptr_ = FleetWrapper::GetInstance();
new paddle::distributed::PSlib()); fleet_ptr_->InitServer(dist_desc, index);
_pslib_ptr->init_server(dist_desc, index);
InitParamConfig();
} }
void AsyncExecutor::InitWorker(const std::string& dist_desc, void AsyncExecutor::InitWorker(const std::string& dist_desc,
const std::vector<uint64_t>& host_sign_list, const std::vector<uint64_t>& host_sign_list,
int node_num, int index) { int node_num, int index) {
_pslib_ptr = std::shared_ptr<paddle::distributed::PSlib>( fleet_ptr_ = FleetWrapper::GetInstance();
new paddle::distributed::PSlib()); fleet_ptr_->InitWorker(dist_desc, host_sign_list, node_num, index);
_pslib_ptr->init_worker(
dist_desc, const_cast<uint64_t*>(host_sign_list.data()), node_num, index);
InitParamConfig();
} }
uint64_t AsyncExecutor::StartServer() { return _pslib_ptr->run_server(); } uint64_t AsyncExecutor::StartServer() { return fleet_ptr_->RunServer(); }
void AsyncExecutor::StopServer() { _pslib_ptr->stop_server(); } void AsyncExecutor::StopServer() { fleet_ptr_->StopServer(); }
void AsyncExecutor::GatherServers(const std::vector<uint64_t>& host_sign_list, void AsyncExecutor::GatherServers(const std::vector<uint64_t>& host_sign_list,
int node_num) { int node_num) {
_pslib_ptr->gather_servers(const_cast<uint64_t*>(host_sign_list.data()), fleet_ptr_->GatherServers(host_sign_list, node_num);
node_num);
}
void AsyncExecutor::InitParamConfig() {
for (int i = 0; i < _pslib_ptr->get_param()
->server_param()
.downpour_server_param()
.downpour_table_param_size();
++i) {
if (_pslib_ptr->get_param()
->server_param()
.downpour_server_param()
.downpour_table_param(i)
.table_class()
.find("SparseTable") != -1) {
_param_config.fea_dim = _pslib_ptr->get_param()
->server_param()
.downpour_server_param()
.downpour_table_param(i)
.accessor()
.fea_dim();
break;
}
}
_param_config.slot_dim = _param_config.fea_dim - 2;
_param_config.tmp_push_dense_wait_times = static_cast<int32_t>(
_pslib_ptr->get_param()->trainer_param().push_dense_per_batch());
_param_config.tmp_push_sparse_wait_times = static_cast<int32_t>(
_pslib_ptr->get_param()->trainer_param().push_sparse_per_batch());
for (auto t = 0u; t < _pslib_ptr->get_param()->trainer_param().skip_op_size();
++t) {
_param_config.skip_op.push_back(
_pslib_ptr->get_param()->trainer_param().skip_op(t));
}
for (auto t = 0u;
t < _pslib_ptr->get_param()->trainer_param().sparse_table_size(); ++t) {
auto& table = _pslib_ptr->get_param()->trainer_param().sparse_table(t);
std::vector<std::string> tmp_sparse_variable_name;
for (int i = 0u; i < table.slot_value_size(); ++i) {
tmp_sparse_variable_name.push_back(table.slot_value(i));
_param_config.slot_alias_to_table[table.slot_key(i)] = table.table_id();
}
std::vector<std::string> tmp_sparse_gradient_variable_name;
for (auto i = 0u; i < table.slot_gradient_size(); ++i) {
tmp_sparse_gradient_variable_name.push_back(table.slot_gradient(i));
}
_param_config.slot_input_vec[table.table_id()] =
std::move(tmp_sparse_variable_name);
_param_config.gradient_var[table.table_id()] =
std::move(tmp_sparse_gradient_variable_name);
_param_config.sparse_table_id.push_back(table.table_id());
}
for (auto t = 0u;
t < _pslib_ptr->get_param()->trainer_param().dense_table_size(); ++t) {
auto& table = _pslib_ptr->get_param()->trainer_param().dense_table(t);
std::vector<std::string> tmp_dense_variable_name;
for (int i = 0u; i < table.dense_variable_name_size(); ++i) {
tmp_dense_variable_name.push_back(table.dense_variable_name(i));
}
std::vector<std::string> tmp_dense_gradient_variable_name;
for (auto i = 0u; i < table.dense_gradient_variable_name_size(); ++i) {
tmp_dense_gradient_variable_name.push_back(
table.dense_gradient_variable_name(i));
}
_param_config.dense_variable_name[table.table_id()] =
std::move(tmp_dense_variable_name);
_param_config.dense_gradient_variable_name[table.table_id()] =
std::move(tmp_dense_gradient_variable_name);
_param_config.dense_table_id.push_back(table.table_id());
_param_config.dense_table_size.push_back(table.fea_dim());
}
} }
void AsyncExecutor::InitModel() { // todo InitModel
for (auto table_id : _param_config.dense_table_id) { void AsyncExecutor::InitModel() {}
std::vector<paddle::ps::Region> regions;
for (auto& t : _param_config.dense_variable_name[table_id]) {
Variable* var = root_scope_->FindVar(t);
CHECK(var != nullptr) << "var[" << t << "] not found";
LoDTensor* tensor = var->GetMutable<LoDTensor>();
float* g = tensor->data<float>();
CHECK(g != nullptr) << "var[" << t << "] value not initialized";
float init_range = 0.2;
int rown = tensor->dims()[0];
init_range /= sqrt(rown);
std::normal_distribution<float> ndistr(0.0, 1.0);
for (auto i = 0u; i < tensor->numel(); ++i) {
g[i] = ndistr(local_random_engine()) * init_range;
}
paddle::ps::Region reg(g, tensor->numel());
regions.emplace_back(std::move(reg));
}
auto push_status = _pslib_ptr->_worker_ptr->push_dense_param( // todo SaveModel
regions.data(), regions.size(), table_id); void AsyncExecutor::SaveModel(const std::string& path) {}
push_status.wait();
auto status = push_status.get();
if (status != 0) {
LOG(FATAL) << "push dense param failed, status[" << status << "]";
exit(-1);
}
}
}
void AsyncExecutor::SaveModel(const std::string& path) {
auto ret = _pslib_ptr->_worker_ptr->flush();
ret.wait();
ret = _pslib_ptr->_worker_ptr->save(path, 0);
ret.wait();
int32_t feasign_cnt = ret.get();
if (feasign_cnt == -1) { // (colourful-tree) TODO should be feasign_cnt < 0
LOG(FATAL) << "save model failed";
exit(-1);
}
}
void AsyncExecutor::PrepareDenseThread(const std::string& mode) {
if (mode == "mpi") {
DensePullThreadParam param;
param.ps_client = _pslib_ptr->_worker_ptr;
param.threshold = 1;
param.training_thread_num = actual_thread_num;
param.root_scope = root_scope_;
param.dense_params = &_param_config.dense_variable_name;
_pull_dense_thread =
std::shared_ptr<DensePullThread>(new DensePullThread(param));
_pull_dense_thread->start();
}
}
#endif
void AsyncExecutor::RunFromFile(const ProgramDesc& main_program, void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
const std::string& data_feed_desc_str, const std::string& data_feed_desc_str,
...@@ -256,14 +88,14 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program, ...@@ -256,14 +88,14 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
google::protobuf::TextFormat::ParseFromString(data_feed_desc_str, google::protobuf::TextFormat::ParseFromString(data_feed_desc_str,
&data_feed_desc); &data_feed_desc);
actual_thread_num = thread_num; actual_thread_num_ = thread_num;
int file_cnt = filelist.size(); int file_cnt = filelist.size();
PADDLE_ENFORCE(file_cnt > 0, "File list cannot be empty"); PADDLE_ENFORCE(file_cnt > 0, "File list cannot be empty");
if (actual_thread_num > file_cnt) { if (actual_thread_num_ > file_cnt) {
VLOG(1) << "Thread num = " << thread_num << ", file num = " << file_cnt VLOG(1) << "Thread num = " << thread_num << ", file num = " << file_cnt
<< ". Changing thread_num = " << file_cnt; << ". Changing thread_num = " << file_cnt;
actual_thread_num = file_cnt; actual_thread_num_ = file_cnt;
} }
/* /*
...@@ -279,12 +111,14 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program, ...@@ -279,12 +111,14 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
*/ */
// todo: should be factory method for creating datafeed // todo: should be factory method for creating datafeed
std::vector<std::shared_ptr<DataFeed>> readers; std::vector<std::shared_ptr<DataFeed>> readers;
PrepareReaders(readers, actual_thread_num, data_feed_desc, filelist); /*
PrepareReaders(readers, actual_thread_num_, data_feed_desc, filelist);
#ifdef PADDLE_WITH_PSLIB #ifdef PADDLE_WITH_PSLIB
PrepareDenseThread(mode); PrepareDenseThread(mode);
#endif #endif
*/
std::vector<std::shared_ptr<ExecutorThreadWorker>> workers; std::vector<std::shared_ptr<ExecutorThreadWorker>> workers;
workers.resize(actual_thread_num); workers.resize(actual_thread_num_);
for (auto& worker : workers) { for (auto& worker : workers) {
#ifdef PADDLE_WITH_PSLIB #ifdef PADDLE_WITH_PSLIB
if (mode == "mpi") { if (mode == "mpi") {
...@@ -298,13 +132,15 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program, ...@@ -298,13 +132,15 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
} }
// prepare thread resource here // prepare thread resource here
for (int thidx = 0; thidx < actual_thread_num; ++thidx) { /*
for (int thidx = 0; thidx < actual_thread_num_; ++thidx) {
CreateThreads(workers[thidx].get(), main_program, readers[thidx], CreateThreads(workers[thidx].get(), main_program, readers[thidx],
fetch_var_names, root_scope_, thidx, debug); fetch_var_names, root_scope_, thidx, debug);
} }
*/
// start executing ops in multiple threads // start executing ops in multiple threads
for (int thidx = 0; thidx < actual_thread_num; ++thidx) { for (int thidx = 0; thidx < actual_thread_num_; ++thidx) {
if (debug) { if (debug) {
threads.push_back(std::thread(&ExecutorThreadWorker::TrainFilesWithTimer, threads.push_back(std::thread(&ExecutorThreadWorker::TrainFilesWithTimer,
workers[thidx].get())); workers[thidx].get()));
...@@ -317,15 +153,19 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program, ...@@ -317,15 +153,19 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
for (auto& th : threads) { for (auto& th : threads) {
th.join(); th.join();
} }
// TODO(guru4elephant): we don't need this
/*
#ifdef PADDLE_WITH_PSLIB #ifdef PADDLE_WITH_PSLIB
if (mode == "mpi") { if (mode == "mpi") {
_pull_dense_thread->stop(); _pull_dense_thread->stop();
} }
#endif #endif
*/
VLOG(3) << "start to run from files in async_executor";
VLOG(3) << "Drop current scope kids";
root_scope_->DropKids(); root_scope_->DropKids();
return; return;
} }
} // einit_modelnd namespace framework } // end namespace framework
} // end namespace paddle } // end namespace paddle
...@@ -25,8 +25,10 @@ limitations under the License. */ ...@@ -25,8 +25,10 @@ limitations under the License. */
#include <typeinfo> #include <typeinfo>
#include <vector> #include <vector>
#include "paddle/fluid/framework/data_feed.pb.h" #include "paddle/fluid/framework/data_feed.pb.h"
#include "paddle/fluid/framework/data_set.h"
#include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/executor_thread_worker.h" #include "paddle/fluid/framework/executor_thread_worker.h"
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
...@@ -65,9 +67,10 @@ class AsyncExecutor { ...@@ -65,9 +67,10 @@ class AsyncExecutor {
const std::string& data_feed_desc_str, const std::string& data_feed_desc_str,
const std::vector<std::string>& filelist, const std::vector<std::string>& filelist,
const int thread_num, const int thread_num,
const std::vector<std::string>& fetch_names, const std::vector<std::string>& fetch_var_names,
const std::string& mode, const bool debug = false); const std::string& mode, const bool debug);
#ifdef PADDLE_WITH_PSLIB
// TODO(guru4elephant): make init server decoupled from executor
void InitServer(const std::string& dist_desc, int index); void InitServer(const std::string& dist_desc, int index);
void InitWorker(const std::string& dist_desc, void InitWorker(const std::string& dist_desc,
const std::vector<uint64_t>& host_sign_list, int node_num, const std::vector<uint64_t>& host_sign_list, int node_num,
...@@ -77,31 +80,14 @@ class AsyncExecutor { ...@@ -77,31 +80,14 @@ class AsyncExecutor {
void GatherServers(const std::vector<uint64_t>& host_sign_list, int node_num); void GatherServers(const std::vector<uint64_t>& host_sign_list, int node_num);
void InitModel(); void InitModel();
void SaveModel(const std::string& path); void SaveModel(const std::string& path);
void InitParamConfig();
#endif
private:
void CreateThreads(ExecutorThreadWorker* worker,
const ProgramDesc& main_program,
const std::shared_ptr<DataFeed>& reader,
const std::vector<std::string>& fetch_var_names,
Scope* root_scope, const int thread_index,
const bool debug);
#ifdef PADDLE_WITH_PSLIB
void PrepareDenseThread(const std::string& mode);
#endif
public: public:
#ifdef PADDLE_WITH_PSLIB std::shared_ptr<paddle::framework::FleetWrapper> fleet_ptr_;
std::shared_ptr<paddle::distributed::PSlib> _pslib_ptr;
std::shared_ptr<DensePullThread> _pull_dense_thread;
AsyncWorkerParamConfig _param_config;
#endif
Scope* root_scope_; Scope* root_scope_;
platform::Place place_; platform::Place place_;
private: private:
int actual_thread_num; int actual_thread_num_;
}; };
} // namespace framework } // namespace framework
......
...@@ -33,6 +33,14 @@ class BlockingQueue { ...@@ -33,6 +33,14 @@ class BlockingQueue {
cv_.notify_one(); cv_.notify_one();
} }
void Push(T &&item) {
{
std::lock_guard<std::mutex> g(mutex_);
q_.emplace_back(std::move(item));
}
cv_.notify_one();
}
template <typename U> template <typename U>
void Extend(const U &items) { void Extend(const U &items) {
{ {
...@@ -44,6 +52,17 @@ class BlockingQueue { ...@@ -44,6 +52,17 @@ class BlockingQueue {
cv_.notify_all(); cv_.notify_all();
} }
template <typename U>
void Extend(U &&items) {
{
std::lock_guard<std::mutex> g(mutex_);
for (auto &item : items) {
q_.emplace_back(std::move(item));
}
}
cv_.notify_all();
}
std::deque<T> PopAll(size_t ms, bool *timeout) { std::deque<T> PopAll(size_t ms, bool *timeout) {
auto time = auto time =
std::chrono::system_clock::now() + std::chrono::milliseconds(ms); std::chrono::system_clock::now() + std::chrono::milliseconds(ms);
...@@ -64,6 +83,18 @@ class BlockingQueue { ...@@ -64,6 +83,18 @@ class BlockingQueue {
return rc; return rc;
} }
void Pop(T *t) {
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [=] { return !q_.empty(); });
*t = std::move(q_.front());
q_.pop_front();
}
size_t Size() {
std::lock_guard<std::mutex> lock(mutex_);
return q_.size();
}
private: private:
std::mutex mutex_; std::mutex mutex_;
std::condition_variable cv_; std::condition_variable cv_;
......
...@@ -12,23 +12,29 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,23 +12,29 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#if defined _WIN32 || defined __APPLE__
#else
#define _LINUX
#endif
#include "paddle/fluid/framework/data_feed.h"
#ifdef _LINUX
#include <stdio_ext.h>
#endif
#include <utility>
#include "gflags/gflags.h"
#include "google/protobuf/io/zero_copy_stream_impl.h" #include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/message.h" #include "google/protobuf/message.h"
#include "google/protobuf/text_format.h" #include "google/protobuf/text_format.h"
#include "io/fs.h"
#include "gflags/gflags.h" #include "io/shell.h"
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/feed_fetch_method.h" #include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/feed_fetch_type.h" #include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/platform/timer.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
std::vector<std::string> DataFeed::filelist_;
size_t DataFeed::file_idx_;
std::mutex DataFeed::mutex_for_pick_file_;
bool DataFeed::finish_set_filelist_;
void DataFeed::AddFeedVar(Variable* var, const std::string& name) { void DataFeed::AddFeedVar(Variable* var, const std::string& name) {
CheckInit(); CheckInit();
for (size_t i = 0; i < use_slots_.size(); ++i) { for (size_t i = 0; i < use_slots_.size(); ++i) {
...@@ -39,15 +45,11 @@ void DataFeed::AddFeedVar(Variable* var, const std::string& name) { ...@@ -39,15 +45,11 @@ void DataFeed::AddFeedVar(Variable* var, const std::string& name) {
} }
bool DataFeed::SetFileList(const std::vector<std::string>& files) { bool DataFeed::SetFileList(const std::vector<std::string>& files) {
std::unique_lock<std::mutex> lock(mutex_for_pick_file_); std::unique_lock<std::mutex> lock(*mutex_for_pick_file_);
CheckInit(); CheckInit();
if (finish_set_filelist_) { // Do not set finish_set_filelist_ flag,
VLOG(3) << "info: you have set the filelist."; // since a user may set file many times after init reader
return false;
}
PADDLE_ENFORCE(files.size(), "You have set an empty filelist.");
filelist_.assign(files.begin(), files.end()); filelist_.assign(files.begin(), files.end());
file_idx_ = 0;
finish_set_filelist_ = true; finish_set_filelist_ = true;
return true; return true;
...@@ -59,12 +61,18 @@ void DataFeed::SetBatchSize(int batch_size) { ...@@ -59,12 +61,18 @@ void DataFeed::SetBatchSize(int batch_size) {
} }
bool DataFeed::PickOneFile(std::string* filename) { bool DataFeed::PickOneFile(std::string* filename) {
std::unique_lock<std::mutex> lock(mutex_for_pick_file_); PADDLE_ENFORCE(mutex_for_pick_file_ != nullptr,
if (file_idx_ == filelist_.size()) { "should call SetFileListMutex before PickOneFile");
PADDLE_ENFORCE(file_idx_ != nullptr,
"should call SetFileListIndex before PickOneFile");
std::unique_lock<std::mutex> lock(*mutex_for_pick_file_);
if (*file_idx_ == filelist_.size()) {
VLOG(3) << "DataFeed::PickOneFile no more file to pick";
return false; return false;
} }
*filename = filelist_[file_idx_++]; VLOG(3) << "file_idx_=" << *file_idx_;
LOG(ERROR) << "pick file:" << *filename; *filename = filelist_[(*file_idx_)++];
// LOG(ERROR) << "pick file:" << *filename;
return true; return true;
} }
...@@ -100,21 +108,24 @@ bool PrivateQueueDataFeed<T>::Start() { ...@@ -100,21 +108,24 @@ bool PrivateQueueDataFeed<T>::Start() {
template <typename T> template <typename T>
void PrivateQueueDataFeed<T>::ReadThread() { void PrivateQueueDataFeed<T>::ReadThread() {
#ifdef _LINUX
std::string filename; std::string filename;
while (PickOneFile(&filename)) { while (PickOneFile(&filename)) {
file_.open(filename.c_str()); // is_text_feed int err_no = 0;
PADDLE_ENFORCE(file_.good(), "Open file<%s> fail.", filename.c_str()); fp_ = fs_open_read(filename, &err_no, pipe_command_);
__fsetlocking(&*fp_, FSETLOCKING_BYCALLER);
T instance; T instance;
while (ParseOneInstance(&instance)) { while (ParseOneInstanceFromPipe(&instance)) {
queue_->Send(instance); queue_->Send(instance);
} }
file_.close();
} }
queue_->Close(); queue_->Close();
#endif
} }
template <typename T> template <typename T>
int PrivateQueueDataFeed<T>::Next() { int PrivateQueueDataFeed<T>::Next() {
#ifdef _LINUX
CheckStart(); CheckStart();
int index = 0; int index = 0;
T instance; T instance;
...@@ -130,11 +141,288 @@ int PrivateQueueDataFeed<T>::Next() { ...@@ -130,11 +141,288 @@ int PrivateQueueDataFeed<T>::Next() {
PutToFeedVec(ins_vec); PutToFeedVec(ins_vec);
} }
return batch_size_; return batch_size_;
#else
return 0;
#endif
} }
#ifdef _WIN32 // explicit instantiation
template class PrivateQueueDataFeed<std::vector<MultiSlotType>>; template class PrivateQueueDataFeed<std::vector<MultiSlotType>>;
template <typename T>
InMemoryDataFeed<T>::InMemoryDataFeed() {
cur_channel_ = 0;
shuffled_ins_ = std::make_shared<paddle::framework::BlockingQueue<T>>();
shuffled_ins_out_ = std::make_shared<paddle::framework::BlockingQueue<T>>();
fleet_send_batch_size_ = 80000; // hard code here
memory_data_ = nullptr;
mutex_for_update_memory_data_ = nullptr;
this->file_idx_ = nullptr;
this->mutex_for_pick_file_ = nullptr;
}
template <typename T>
bool InMemoryDataFeed<T>::Start() {
#ifdef _LINUX
DataFeed::CheckSetFileList();
if (shuffled_ins_->Size() == 0 && shuffled_ins_out_->Size() == 0) {
FillMemoryDataToChannel();
}
#endif #endif
DataFeed::finish_start_ = true;
return true;
}
template <typename T>
int InMemoryDataFeed<T>::Next() {
#ifdef _LINUX
DataFeed::CheckStart();
std::shared_ptr<paddle::framework::BlockingQueue<T>> in_channel = nullptr;
std::shared_ptr<paddle::framework::BlockingQueue<T>> out_channel = nullptr;
if (cur_channel_ == 0) {
in_channel = shuffled_ins_;
out_channel = shuffled_ins_out_;
} else {
in_channel = shuffled_ins_out_;
out_channel = shuffled_ins_;
}
CHECK(in_channel != nullptr);
CHECK(out_channel != nullptr);
VLOG(3) << "in_channel size=" << in_channel->Size()
<< ", out_channel size=" << out_channel->Size()
<< ", thread_id=" << thread_id_;
int index = 0;
T instance;
T ins_vec;
while (index < DataFeed::default_batch_size_) {
if (in_channel->Size() == 0) {
break;
}
in_channel->Pop(&instance);
AddInstanceToInsVec(&ins_vec, instance, index++);
out_channel->Push(std::move(instance));
}
DataFeed::batch_size_ = index;
VLOG(3) << "batch_size_=" << DataFeed::batch_size_
<< ", thread_id=" << thread_id_;
if (DataFeed::batch_size_ != 0) {
PutToFeedVec(ins_vec);
} else {
cur_channel_ = 1 - cur_channel_;
}
return DataFeed::batch_size_;
#else
return 0;
#endif
}
template <typename T>
void InMemoryDataFeed<T>::SetMemoryData(void* memory_data) {
memory_data_ = static_cast<std::vector<T>*>(memory_data);
}
template <typename T>
void InMemoryDataFeed<T>::SetMemoryDataMutex(std::mutex* mutex) {
mutex_for_update_memory_data_ = mutex;
}
template <typename T>
void InMemoryDataFeed<T>::SetThreadId(int thread_id) {
thread_id_ = thread_id;
}
template <typename T>
void InMemoryDataFeed<T>::SetThreadNum(int thread_num) {
thread_num_ = thread_num;
}
template <typename T>
void InMemoryDataFeed<T>::SetTrainerNum(int trainer_num) {
trainer_num_ = trainer_num;
}
template <typename T>
void InMemoryDataFeed<T>::PutInsToChannel(const std::string& ins_str) {
#ifdef _LINUX
std::vector<T> ins;
DeserializeIns(&ins, ins_str);
shuffled_ins_->Extend(std::move(ins));
VLOG(3) << "PutInsToChannel put ins num=" << ins.size()
<< " to channel, channel size=" << shuffled_ins_->Size()
<< " thread_id=" << thread_id_;
#endif
}
template <typename T>
void InMemoryDataFeed<T>::FillMemoryDataToChannel() {
#ifdef _LINUX
VLOG(3) << "FillMemoryDataToChannel, thread_id=" << thread_id_;
auto interval = GetMemoryDataInterval();
VLOG(3) << "memory data size=" << memory_data_->size()
<< ", fill data from [" << interval.first << ", " << interval.second
<< "), thread_id=" << thread_id_;
for (int64_t i = interval.first; i < interval.second; ++i) {
T& t = (*memory_data_)[i];
shuffled_ins_->Push(std::move(t));
}
#endif
}
template <typename T>
void InMemoryDataFeed<T>::FillChannelToMemoryData() {
#ifdef _LINUX
VLOG(3) << "FillChannelToMemoryData, thread_id=" << thread_id_;
std::vector<T> local_vec;
std::shared_ptr<paddle::framework::BlockingQueue<T>> channel = nullptr;
std::shared_ptr<paddle::framework::BlockingQueue<T>> pre_channel = nullptr;
if (cur_channel_ == 0) {
channel = shuffled_ins_;
pre_channel = shuffled_ins_out_;
} else {
channel = shuffled_ins_out_;
pre_channel = shuffled_ins_;
}
CHECK(channel != nullptr);
CHECK(pre_channel != nullptr);
CHECK_EQ(pre_channel->Size(), 0);
local_vec.resize(channel->Size());
for (int64_t i = 0; i < local_vec.size(); ++i) {
channel->Pop(&local_vec[i]);
}
VLOG(3) << "local_vec size=" << local_vec.size()
<< ", thread_id=" << thread_id_;
{
std::lock_guard<std::mutex> g(*mutex_for_update_memory_data_);
VLOG(3) << "before insert, memory_data_ size=" << memory_data_->size()
<< ", thread_id=" << thread_id_;
memory_data_->insert(memory_data_->end(), local_vec.begin(),
local_vec.end());
VLOG(3) << "after insert memory_data_ size=" << memory_data_->size()
<< ", thread_id=" << thread_id_;
}
std::vector<T>().swap(local_vec);
#endif
}
template <typename T>
void InMemoryDataFeed<T>::LoadIntoMemory() {
#ifdef _LINUX
VLOG(3) << "LoadIntoMemory() begin, thread_id=" << thread_id_;
std::vector<T> local_vec;
std::string filename;
while (DataFeed::PickOneFile(&filename)) {
VLOG(3) << "PickOneFile, filename=" << filename
<< ", thread_id=" << thread_id_;
int err_no = 0;
PrivateQueueDataFeed<T>::fp_ =
fs_open_read(filename, &err_no, PrivateQueueDataFeed<T>::pipe_command_);
CHECK(PrivateQueueDataFeed<T>::fp_ != nullptr);
__fsetlocking(&*PrivateQueueDataFeed<T>::fp_, FSETLOCKING_BYCALLER);
T instance;
platform::Timer timeline;
timeline.Start();
while (ParseOneInstanceFromPipe(&instance)) {
local_vec.push_back(instance);
}
timeline.Pause();
VLOG(3) << "LoadIntoMemory() read all lines, file=" << filename
<< ", cost time=" << timeline.ElapsedSec()
<< " seconds, thread_id=" << thread_id_;
{
std::lock_guard<std::mutex> lock(*mutex_for_update_memory_data_);
timeline.Start();
memory_data_->insert(memory_data_->end(),
std::make_move_iterator(local_vec.begin()),
std::make_move_iterator(local_vec.end()));
timeline.Pause();
VLOG(3) << "LoadIntoMemory() memory_data insert, cost time="
<< timeline.ElapsedSec() << " seconds, thread_id=" << thread_id_;
}
local_vec.clear();
}
std::vector<T>().swap(local_vec);
VLOG(3) << "LoadIntoMemory() end, thread_id=" << thread_id_;
#endif
}
template <typename T>
void InMemoryDataFeed<T>::LocalShuffle() {
#ifdef _LINUX
VLOG(3) << "LocalShuffle() begin, thread_id=" << thread_id_;
FillMemoryDataToChannel();
VLOG(3) << "LocalShuffle() end, thread_id=" << thread_id_;
#endif
}
template <typename T>
void InMemoryDataFeed<T>::GlobalShuffle() {
#ifdef _LINUX
VLOG(3) << "GlobalShuffle() begin, thread_id=" << thread_id_;
auto fleet_ptr = FleetWrapper::GetInstance();
std::vector<std::vector<T*>> send_vec(trainer_num_);
for (auto& vec : send_vec) {
vec.reserve(fleet_send_batch_size_);
}
std::vector<std::future<int32_t>> total_status;
auto interval = GetMemoryDataInterval();
VLOG(3) << "global shuffle data from [" << interval.first << ", "
<< interval.second << "), thread_id=" << thread_id_;
for (int64_t i = interval.first; i < interval.second; ++i) {
// if get ins id, can also use hash
// std::string ins_id = memory_data_[i].ins_id;
int64_t random_num = rand_r(&rand_seed);
int64_t node_id = random_num % trainer_num_;
send_vec[node_id].push_back(&((*memory_data_)[i]));
if (i % fleet_send_batch_size_ == 0 && i != 0) {
for (int j = 0; j < send_vec.size(); ++j) {
std::string send_str;
SerializeIns(send_vec[j], &send_str);
VLOG(3) << "send str_length=" << send_str.length()
<< ", ins num=" << send_vec[j].size() << " to node_id=" << j
<< ", thread_id=" << thread_id_;
auto ret = fleet_ptr->SendClientToClientMsg(0, j, send_str);
VLOG(3) << "end send, thread_id=" << thread_id_;
send_vec[j].clear();
total_status.push_back(std::move(ret));
}
}
}
for (int j = 0; j < send_vec.size(); ++j) {
if (send_vec[j].size() != 0) {
std::string send_str;
SerializeIns(send_vec[j], &send_str);
VLOG(3) << "send str_length=" << send_str.length() << " to node_id=" << j
<< ", thread_id=" << thread_id_;
auto ret = fleet_ptr->SendClientToClientMsg(0, j, send_str);
VLOG(3) << "end send, thread_id=" << thread_id_;
total_status.push_back(std::move(ret));
}
std::vector<T*>().swap(send_vec[j]);
}
for (auto& t : total_status) {
t.wait();
}
VLOG(3) << "GlobalShuffle() end, thread_id=" << thread_id_;
#endif
}
template <typename T>
std::pair<int64_t, int64_t> InMemoryDataFeed<T>::GetMemoryDataInterval() {
int64_t start = 0;
int64_t end = 0;
int64_t size = memory_data_->size();
for (int64_t i = 0; i <= static_cast<int64_t>(thread_id_); ++i) {
int64_t len = size / static_cast<int64_t>(thread_num_) +
(i < (size % static_cast<int64_t>(thread_num_)));
start = end;
end += len;
}
return std::make_pair(start, end);
}
// explicit instantiation
template class InMemoryDataFeed<std::vector<MultiSlotType>>;
void MultiSlotDataFeed::Init( void MultiSlotDataFeed::Init(
const paddle::framework::DataFeedDesc& data_feed_desc) { const paddle::framework::DataFeedDesc& data_feed_desc) {
...@@ -165,10 +453,32 @@ void MultiSlotDataFeed::Init( ...@@ -165,10 +453,32 @@ void MultiSlotDataFeed::Init(
} }
} }
feed_vec_.resize(use_slots_.size()); feed_vec_.resize(use_slots_.size());
pipe_command_ = data_feed_desc.pipe_command();
finish_init_ = true; finish_init_ = true;
} }
void MultiSlotDataFeed::ReadThread() {
#ifdef _LINUX
std::string filename;
while (PickOneFile(&filename)) {
int err_no = 0;
fp_ = fs_open_read(filename, &err_no, pipe_command_);
CHECK(fp_ != nullptr);
__fsetlocking(&*fp_, FSETLOCKING_BYCALLER);
std::vector<MultiSlotType> instance;
int ins_num = 0;
while (ParseOneInstanceFromPipe(&instance)) {
ins_num++;
queue_->Send(instance);
}
VLOG(3) << "filename: " << filename << " inst num: " << ins_num;
}
queue_->Close();
#endif
}
bool MultiSlotDataFeed::CheckFile(const char* filename) { bool MultiSlotDataFeed::CheckFile(const char* filename) {
#ifdef _LINUX
CheckInit(); // get info of slots CheckInit(); // get info of slots
std::ifstream fin(filename); std::ifstream fin(filename);
if (!fin.good()) { if (!fin.good()) {
...@@ -276,10 +586,68 @@ bool MultiSlotDataFeed::CheckFile(const char* filename) { ...@@ -276,10 +586,68 @@ bool MultiSlotDataFeed::CheckFile(const char* filename) {
} }
VLOG(3) << "instances cout: " << instance_cout; VLOG(3) << "instances cout: " << instance_cout;
VLOG(3) << "The file format is correct"; VLOG(3) << "The file format is correct";
#endif
return true;
}
bool MultiSlotDataFeed::ParseOneInstanceFromPipe(
std::vector<MultiSlotType>* instance) {
#ifdef _LINUX
thread_local string::LineFileReader reader;
if (!reader.getline(&*(fp_.get()))) {
return false;
} else {
int use_slots_num = use_slots_.size();
instance->resize(use_slots_num);
const char* str = reader.get();
std::string line = std::string(str);
// VLOG(3) << line;
char* endptr = const_cast<char*>(str);
int pos = 0;
for (size_t i = 0; i < use_slots_index_.size(); ++i) {
int idx = use_slots_index_[i];
int num = strtol(&str[pos], &endptr, 10);
PADDLE_ENFORCE(
num,
"The number of ids can not be zero, you need padding "
"it in data generator; or if there is something wrong with "
"the data, please check if the data contains unresolvable "
"characters.\nplease check this error line: %s",
str);
if (idx != -1) {
(*instance)[idx].Init(all_slots_type_[i]);
if ((*instance)[idx].GetType()[0] == 'f') { // float
for (int j = 0; j < num; ++j) {
float feasign = strtof(endptr, &endptr);
(*instance)[idx].AddValue(feasign);
}
} else if ((*instance)[idx].GetType()[0] == 'u') { // uint64
for (int j = 0; j < num; ++j) {
uint64_t feasign = (uint64_t)strtoull(endptr, &endptr, 10);
(*instance)[idx].AddValue(feasign);
}
}
pos = endptr - str;
} else {
for (int j = 0; j <= num; ++j) {
// pos = line.find_first_of(' ', pos + 1);
while (line[pos + 1] != ' ') {
pos++;
}
}
}
}
return true;
}
#else
return true; return true;
#endif
} }
bool MultiSlotDataFeed::ParseOneInstance(std::vector<MultiSlotType>* instance) { bool MultiSlotDataFeed::ParseOneInstance(std::vector<MultiSlotType>* instance) {
#ifdef _LINUX
std::string line; std::string line;
if (getline(file_, line)) { if (getline(file_, line)) {
int use_slots_num = use_slots_.size(); int use_slots_num = use_slots_.size();
...@@ -322,12 +690,14 @@ bool MultiSlotDataFeed::ParseOneInstance(std::vector<MultiSlotType>* instance) { ...@@ -322,12 +690,14 @@ bool MultiSlotDataFeed::ParseOneInstance(std::vector<MultiSlotType>* instance) {
} else { } else {
return false; return false;
} }
return true; #endif
return false;
} }
void MultiSlotDataFeed::AddInstanceToInsVec( void MultiSlotDataFeed::AddInstanceToInsVec(
std::vector<MultiSlotType>* ins_vec, std::vector<MultiSlotType>* ins_vec,
const std::vector<MultiSlotType>& instance, int index) { const std::vector<MultiSlotType>& instance, int index) {
#ifdef _LINUX
if (index == 0) { if (index == 0) {
ins_vec->resize(instance.size()); ins_vec->resize(instance.size());
for (size_t i = 0; i < instance.size(); ++i) { for (size_t i = 0; i < instance.size(); ++i) {
...@@ -339,10 +709,200 @@ void MultiSlotDataFeed::AddInstanceToInsVec( ...@@ -339,10 +709,200 @@ void MultiSlotDataFeed::AddInstanceToInsVec(
for (size_t i = 0; i < instance.size(); ++i) { for (size_t i = 0; i < instance.size(); ++i) {
(*ins_vec)[i].AddIns(instance[i]); (*ins_vec)[i].AddIns(instance[i]);
} }
#endif
} }
void MultiSlotDataFeed::PutToFeedVec( void MultiSlotDataFeed::PutToFeedVec(
const std::vector<MultiSlotType>& ins_vec) { const std::vector<MultiSlotType>& ins_vec) {
#ifdef _LINUX
for (size_t i = 0; i < use_slots_.size(); ++i) {
const auto& type = ins_vec[i].GetType();
const auto& offset = ins_vec[i].GetOffset();
int total_instance = static_cast<int>(offset.back());
if (type[0] == 'f') { // float
const auto& feasign = ins_vec[i].GetFloatData();
float* tensor_ptr = feed_vec_[i]->mutable_data<float>(
{total_instance, 1}, platform::CPUPlace());
memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(float));
} else if (type[0] == 'u') { // uint64
// no uint64_t type in paddlepaddle
const auto& feasign = ins_vec[i].GetUint64Data();
int64_t* tensor_ptr = feed_vec_[i]->mutable_data<int64_t>(
{total_instance, 1}, platform::CPUPlace());
memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(int64_t));
}
LoD data_lod{offset};
feed_vec_[i]->set_lod(data_lod);
if (use_slots_is_dense_[i]) {
int dim = total_instance / batch_size_;
feed_vec_[i]->Resize({batch_size_, dim});
}
}
#endif
}
void MultiSlotInMemoryDataFeed::Init(
const paddle::framework::DataFeedDesc& data_feed_desc) {
finish_init_ = false;
finish_set_filelist_ = false;
finish_start_ = false;
PADDLE_ENFORCE(data_feed_desc.has_multi_slot_desc(),
"Multi_slot_desc has not been set.");
paddle::framework::MultiSlotDesc multi_slot_desc =
data_feed_desc.multi_slot_desc();
SetBatchSize(data_feed_desc.batch_size());
SetQueueSize(data_feed_desc.batch_size());
size_t all_slot_num = multi_slot_desc.slots_size();
all_slots_.resize(all_slot_num);
all_slots_type_.resize(all_slot_num);
use_slots_index_.resize(all_slot_num);
use_slots_.clear();
use_slots_is_dense_.clear();
for (size_t i = 0; i < all_slot_num; ++i) {
const auto& slot = multi_slot_desc.slots(i);
all_slots_[i] = slot.name();
all_slots_type_[i] = slot.type();
use_slots_index_[i] = slot.is_used() ? use_slots_.size() : -1;
if (slot.is_used()) {
use_slots_.push_back(all_slots_[i]);
use_slots_is_dense_.push_back(slot.is_dense());
}
}
feed_vec_.resize(use_slots_.size());
pipe_command_ = data_feed_desc.pipe_command();
finish_init_ = true;
}
bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(
std::vector<MultiSlotType>* instance) {
#ifdef _LINUX
thread_local string::LineFileReader reader;
if (!reader.getline(&*(fp_.get()))) {
return false;
} else {
int use_slots_num = use_slots_.size();
instance->resize(use_slots_num);
const char* str = reader.get();
std::string line = std::string(str);
// VLOG(3) << line;
char* endptr = const_cast<char*>(str);
int pos = 0;
for (size_t i = 0; i < use_slots_index_.size(); ++i) {
int idx = use_slots_index_[i];
int num = strtol(&str[pos], &endptr, 10);
PADDLE_ENFORCE(
num,
"The number of ids can not be zero, you need padding "
"it in data generator; or if there is something wrong with "
"the data, please check if the data contains unresolvable "
"characters.\nplease check this error line: %s",
str);
if (idx != -1) {
(*instance)[idx].Init(all_slots_type_[i]);
if ((*instance)[idx].GetType()[0] == 'f') { // float
for (int j = 0; j < num; ++j) {
float feasign = strtof(endptr, &endptr);
(*instance)[idx].AddValue(feasign);
}
} else if ((*instance)[idx].GetType()[0] == 'u') { // uint64
for (int j = 0; j < num; ++j) {
uint64_t feasign = (uint64_t)strtoull(endptr, &endptr, 10);
(*instance)[idx].AddValue(feasign);
}
}
pos = endptr - str;
} else {
for (int j = 0; j <= num; ++j) {
// pos = line.find_first_of(' ', pos + 1);
while (line[pos + 1] != ' ') {
pos++;
}
}
}
}
return true;
}
#else
return false;
#endif
}
bool MultiSlotInMemoryDataFeed::ParseOneInstance(
std::vector<MultiSlotType>* instance) {
#ifdef _LINUX
std::string line;
if (getline(file_, line)) {
int use_slots_num = use_slots_.size();
instance->resize(use_slots_num);
VLOG(3) << line;
// parse line
const char* str = line.c_str();
char* endptr = const_cast<char*>(str);
int pos = 0;
for (size_t i = 0; i < use_slots_index_.size(); ++i) {
int idx = use_slots_index_[i];
int num = strtol(&str[pos], &endptr, 10);
PADDLE_ENFORCE(
num,
"The number of ids can not be zero, you need padding "
"it in data generator; or if there is something wrong with "
"the data, please check if the data contains unresolvable "
"characters.\nplease check this error line: %s",
str);
if (idx != -1) {
(*instance)[idx].Init(all_slots_type_[i]);
if ((*instance)[idx].GetType()[0] == 'f') { // float
for (int j = 0; j < num; ++j) {
float feasign = strtof(endptr, &endptr);
(*instance)[idx].AddValue(feasign);
}
} else if ((*instance)[idx].GetType()[0] == 'u') { // uint64
for (int j = 0; j < num; ++j) {
uint64_t feasign = (uint64_t)strtoull(endptr, &endptr, 10);
(*instance)[idx].AddValue(feasign);
}
}
pos = endptr - str;
} else {
for (int j = 0; j <= num; ++j) {
pos = line.find_first_of(' ', pos + 1);
}
}
}
} else {
return false;
}
#endif
return false;
}
void MultiSlotInMemoryDataFeed::AddInstanceToInsVec(
std::vector<MultiSlotType>* ins_vec,
const std::vector<MultiSlotType>& instance, int index) {
#ifdef _LINUX
if (index == 0) {
ins_vec->resize(instance.size());
for (size_t i = 0; i < instance.size(); ++i) {
(*ins_vec)[i].Init(instance[i].GetType());
(*ins_vec)[i].InitOffset();
}
}
for (size_t i = 0; i < instance.size(); ++i) {
(*ins_vec)[i].AddIns(instance[i]);
}
#endif
}
void MultiSlotInMemoryDataFeed::PutToFeedVec(
const std::vector<MultiSlotType>& ins_vec) {
#ifdef _LINUX
for (size_t i = 0; i < use_slots_.size(); ++i) { for (size_t i = 0; i < use_slots_.size(); ++i) {
const auto& type = ins_vec[i].GetType(); const auto& type = ins_vec[i].GetType();
const auto& offset = ins_vec[i].GetOffset(); const auto& offset = ins_vec[i].GetOffset();
...@@ -368,6 +928,20 @@ void MultiSlotDataFeed::PutToFeedVec( ...@@ -368,6 +928,20 @@ void MultiSlotDataFeed::PutToFeedVec(
feed_vec_[i]->Resize({batch_size_, dim}); feed_vec_[i]->Resize({batch_size_, dim});
} }
} }
#endif
}
// todo serialize ins in global shuffle
void MultiSlotInMemoryDataFeed::SerializeIns(
const std::vector<std::vector<MultiSlotType>*>& ins, std::string* str) {
auto fleet_ptr = FleetWrapper::GetInstance();
fleet_ptr->Serialize(ins, str);
}
// todo deserialize ins in global shuffle
void MultiSlotInMemoryDataFeed::DeserializeIns(
std::vector<std::vector<MultiSlotType>>* ins, const std::string& str) {
auto fleet_ptr = FleetWrapper::GetInstance();
fleet_ptr->Deserialize(ins, str);
} }
} // namespace framework } // namespace framework
......
...@@ -15,17 +15,23 @@ limitations under the License. */ ...@@ -15,17 +15,23 @@ limitations under the License. */
#pragma once #pragma once
#include <fstream> #include <fstream>
#include <future> // NOLINT
#include <memory> #include <memory>
#include <mutex> // NOLINT #include <mutex> // NOLINT
#include <sstream>
#include <string> #include <string>
#include <thread> // NOLINT #include <thread> // NOLINT
#include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/data_feed.pb.h" #include "paddle/fluid/framework/data_feed.pb.h"
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/reader.h" #include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/variable.h" #include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/operators/reader/blocking_queue.h" #include "paddle/fluid/operators/reader/blocking_queue.h"
#include "paddle/fluid/string/string_helper.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -48,7 +54,10 @@ namespace framework { ...@@ -48,7 +54,10 @@ namespace framework {
// } // }
class DataFeed { class DataFeed {
public: public:
DataFeed() {} DataFeed() {
mutex_for_pick_file_ = nullptr;
file_idx_ = nullptr;
}
virtual ~DataFeed() {} virtual ~DataFeed() {}
virtual void Init(const paddle::framework::DataFeedDesc& data_feed_desc) = 0; virtual void Init(const paddle::framework::DataFeedDesc& data_feed_desc) = 0;
virtual bool CheckFile(const char* filename) { virtual bool CheckFile(const char* filename) {
...@@ -59,6 +68,7 @@ class DataFeed { ...@@ -59,6 +68,7 @@ class DataFeed {
// Otherwise, Init() function will init finish_set_filelist_ flag. // Otherwise, Init() function will init finish_set_filelist_ flag.
virtual bool SetFileList(const std::vector<std::string>& files); virtual bool SetFileList(const std::vector<std::string>& files);
virtual bool Start() = 0; virtual bool Start() = 0;
// The trainer calls the Next() function, and the DataFeed will load a new // The trainer calls the Next() function, and the DataFeed will load a new
// batch to the feed_vec. The return value of this function is the batch // batch to the feed_vec. The return value of this function is the batch
// size of the current batch. // size of the current batch.
...@@ -74,6 +84,36 @@ class DataFeed { ...@@ -74,6 +84,36 @@ class DataFeed {
// This function is used for binding feed_vec memory // This function is used for binding feed_vec memory
virtual void AddFeedVar(Variable* var, const std::string& name); virtual void AddFeedVar(Variable* var, const std::string& name);
// This function will do nothing at default
virtual void SetMemoryData(void* memory_data) {}
// This function will do nothing at default
virtual void SetMemoryDataMutex(std::mutex* mutex) {}
// This function will do nothing at default
virtual void SetThreadId(int thread_id) {}
// This function will do nothing at default
virtual void SetThreadNum(int thread_num) {}
// This function will do nothing at default
virtual void SetTrainerNum(int trainer_num) {}
virtual void SetFileListMutex(std::mutex* mutex) {
mutex_for_pick_file_ = mutex;
}
virtual void SetFileListIndex(size_t* file_index) { file_idx_ = file_index; }
virtual void LoadIntoMemory() {
PADDLE_THROW("This function(LoadIntoMemory) is not implemented.");
}
virtual void LocalShuffle() {
PADDLE_THROW("This function(LocalShuffle) is not implemented.");
}
virtual void GlobalShuffle() {
PADDLE_THROW("This function(GlobalShuffle) is not implemented.");
}
// This function will do nothing at default
virtual void FillMemoryDataToChannel() {}
// This function will do nothing at default
virtual void FillChannelToMemoryData() {}
// This function will do nothing at default
virtual void PutInsToChannel(const std::string& ins_str) {}
protected: protected:
// The following three functions are used to check if it is executed in this // The following three functions are used to check if it is executed in this
// order: // order:
...@@ -87,9 +127,9 @@ class DataFeed { ...@@ -87,9 +127,9 @@ class DataFeed {
// safe). // safe).
virtual bool PickOneFile(std::string* filename); virtual bool PickOneFile(std::string* filename);
static std::vector<std::string> filelist_; std::vector<std::string> filelist_;
static size_t file_idx_; size_t* file_idx_;
static std::mutex mutex_for_pick_file_; std::mutex* mutex_for_pick_file_;
// the alias of used slots, and its order is determined by // the alias of used slots, and its order is determined by
// data_feed_desc(proto object) // data_feed_desc(proto object)
...@@ -112,8 +152,9 @@ class DataFeed { ...@@ -112,8 +152,9 @@ class DataFeed {
int batch_size_; int batch_size_;
bool finish_init_; bool finish_init_;
static bool finish_set_filelist_; bool finish_set_filelist_;
bool finish_start_; bool finish_start_;
std::string pipe_command_;
}; };
// PrivateQueueDataFeed is the base virtual class for ohther DataFeeds. // PrivateQueueDataFeed is the base virtual class for ohther DataFeeds.
...@@ -136,6 +177,7 @@ class PrivateQueueDataFeed : public DataFeed { ...@@ -136,6 +177,7 @@ class PrivateQueueDataFeed : public DataFeed {
virtual void SetQueueSize(int queue_size); virtual void SetQueueSize(int queue_size);
// The reading and parsing method called in the ReadThread. // The reading and parsing method called in the ReadThread.
virtual bool ParseOneInstance(T* instance) = 0; virtual bool ParseOneInstance(T* instance) = 0;
virtual bool ParseOneInstanceFromPipe(T* instance) = 0;
// This function is used to put instance to vec_ins // This function is used to put instance to vec_ins
virtual void AddInstanceToInsVec(T* vec_ins, const T& instance, virtual void AddInstanceToInsVec(T* vec_ins, const T& instance,
int index) = 0; int index) = 0;
...@@ -150,11 +192,58 @@ class PrivateQueueDataFeed : public DataFeed { ...@@ -150,11 +192,58 @@ class PrivateQueueDataFeed : public DataFeed {
// ifstream one line and one line parse: 6034 ms // ifstream one line and one line parse: 6034 ms
// fread one buffer and one buffer parse: 7097 ms // fread one buffer and one buffer parse: 7097 ms
std::ifstream file_; std::ifstream file_;
std::shared_ptr<FILE> fp_;
size_t queue_size_; size_t queue_size_;
string::LineFileReader reader_;
// The queue for store parsed data // The queue for store parsed data
std::unique_ptr<paddle::operators::reader::BlockingQueue<T>> queue_; std::unique_ptr<paddle::operators::reader::BlockingQueue<T>> queue_;
}; };
template <typename T>
class InMemoryDataFeed : public PrivateQueueDataFeed<T> {
public:
InMemoryDataFeed();
virtual ~InMemoryDataFeed() {}
virtual void Init(const paddle::framework::DataFeedDesc& data_feed_desc) = 0;
virtual bool Start();
virtual int Next();
virtual void SetMemoryData(void* memory_data);
virtual void SetMemoryDataMutex(std::mutex* mutex);
virtual void SetThreadId(int thread_id);
virtual void SetThreadNum(int thread_num);
virtual void SetTrainerNum(int trainer_num);
virtual void PutInsToChannel(const std::string& ins_str);
virtual void FillMemoryDataToChannel();
virtual void FillChannelToMemoryData();
virtual void LoadIntoMemory();
virtual void LocalShuffle();
virtual void GlobalShuffle();
protected:
virtual void AddInstanceToInsVec(T* vec_ins, const T& instance,
int index) = 0;
virtual bool ParseOneInstance(T* instance) = 0;
virtual bool ParseOneInstanceFromPipe(T* instance) = 0;
virtual void PutToFeedVec(const T& ins_vec) = 0;
virtual void SerializeIns(const std::vector<T*>& ins, std::string* str) = 0;
virtual void DeserializeIns(std::vector<T>* ins, const std::string& str) = 0;
virtual std::pair<int64_t, int64_t> GetMemoryDataInterval();
int thread_id_;
int thread_num_;
int trainer_num_;
uint32_t rand_seed;
std::vector<T>* memory_data_;
std::mutex* mutex_for_update_memory_data_;
// when read ins, we put ins from one channel to the other,
// and when finish reading, we set cur_channel = 1 - cur_channel,
// so if cur_channel=0, all data are in shuffled_ins_, else shuffled_ins_out_
int cur_channel_;
std::shared_ptr<paddle::framework::BlockingQueue<T>> shuffled_ins_;
std::shared_ptr<paddle::framework::BlockingQueue<T>> shuffled_ins_out_;
int64_t fleet_send_batch_size_;
};
// This class define the data type of instance(ins_vec) in MultiSlotDataFeed // This class define the data type of instance(ins_vec) in MultiSlotDataFeed
class MultiSlotType { class MultiSlotType {
public: public:
...@@ -176,6 +265,7 @@ class MultiSlotType { ...@@ -176,6 +265,7 @@ class MultiSlotType {
offset_[0] = 0; offset_[0] = 0;
} }
const std::vector<size_t>& GetOffset() const { return offset_; } const std::vector<size_t>& GetOffset() const { return offset_; }
std::vector<size_t>& MutableOffset() { return offset_; }
void AddValue(const float v) { void AddValue(const float v) {
CheckFloat(); CheckFloat();
float_feasign_.push_back(v); float_feasign_.push_back(v);
...@@ -198,8 +288,33 @@ class MultiSlotType { ...@@ -198,8 +288,33 @@ class MultiSlotType {
} }
} }
const std::vector<float>& GetFloatData() const { return float_feasign_; } const std::vector<float>& GetFloatData() const { return float_feasign_; }
std::vector<float>& MutableFloatData() { return float_feasign_; }
const std::vector<uint64_t>& GetUint64Data() const { return uint64_feasign_; } const std::vector<uint64_t>& GetUint64Data() const { return uint64_feasign_; }
std::vector<uint64_t>& MutableUint64Data() { return uint64_feasign_; }
const std::string& GetType() const { return type_; } const std::string& GetType() const { return type_; }
std::string& MutableType() { return type_; }
std::string DebugString() {
std::stringstream ss;
ss << "\ntype: " << type_ << "\n";
ss << "offset: ";
ss << "[";
for (const size_t& i : offset_) {
ss << offset_[i] << ",";
}
ss << "]\ndata: [";
if (type_[0] == 'f') {
for (const float& i : float_feasign_) {
ss << i << ",";
}
} else {
for (const uint64_t& i : uint64_feasign_) {
ss << i << ",";
}
}
ss << "]\n";
return ss.str();
}
private: private:
void CheckType(const std::string& type) const { void CheckType(const std::string& type) const {
...@@ -228,13 +343,37 @@ class MultiSlotDataFeed ...@@ -228,13 +343,37 @@ class MultiSlotDataFeed
virtual ~MultiSlotDataFeed() {} virtual ~MultiSlotDataFeed() {}
virtual void Init(const paddle::framework::DataFeedDesc& data_feed_desc); virtual void Init(const paddle::framework::DataFeedDesc& data_feed_desc);
virtual bool CheckFile(const char* filename); virtual bool CheckFile(const char* filename);
// virtual void ReadThread();
protected: protected:
virtual void ReadThread();
virtual void AddInstanceToInsVec(std::vector<MultiSlotType>* vec_ins, virtual void AddInstanceToInsVec(std::vector<MultiSlotType>* vec_ins,
const std::vector<MultiSlotType>& instance, const std::vector<MultiSlotType>& instance,
int index); int index);
virtual bool ParseOneInstance(std::vector<MultiSlotType>* instance); virtual bool ParseOneInstance(std::vector<MultiSlotType>* instance);
virtual bool ParseOneInstanceFromPipe(std::vector<MultiSlotType>* instance);
virtual void PutToFeedVec(const std::vector<MultiSlotType>& ins_vec); virtual void PutToFeedVec(const std::vector<MultiSlotType>& ins_vec);
}; };
class MultiSlotInMemoryDataFeed
: public InMemoryDataFeed<std::vector<MultiSlotType>> {
public:
MultiSlotInMemoryDataFeed() {}
virtual ~MultiSlotInMemoryDataFeed() {}
virtual void Init(const paddle::framework::DataFeedDesc& data_feed_desc);
protected:
virtual void AddInstanceToInsVec(std::vector<MultiSlotType>* vec_ins,
const std::vector<MultiSlotType>& instance,
int index);
virtual bool ParseOneInstance(std::vector<MultiSlotType>* instance);
virtual bool ParseOneInstanceFromPipe(std::vector<MultiSlotType>* instance);
virtual void PutToFeedVec(const std::vector<MultiSlotType>& ins_vec);
virtual void SerializeIns(const std::vector<std::vector<MultiSlotType>*>& ins,
std::string* str);
virtual void DeserializeIns(std::vector<std::vector<MultiSlotType>>* ins,
const std::string& str);
};
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -27,4 +27,6 @@ message DataFeedDesc { ...@@ -27,4 +27,6 @@ message DataFeedDesc {
optional string name = 1; optional string name = 1;
optional int32 batch_size = 2 [ default = 32 ]; optional int32 batch_size = 2 [ default = 32 ];
optional MultiSlotDesc multi_slot_desc = 3; optional MultiSlotDesc multi_slot_desc = 3;
optional string pipe_command = 4;
optional int32 thread_num = 5;
} }
...@@ -54,11 +54,15 @@ std::string DataFeedFactory::DataFeedTypeList() { ...@@ -54,11 +54,15 @@ std::string DataFeedFactory::DataFeedTypeList() {
std::shared_ptr<DataFeed> DataFeedFactory::CreateDataFeed( std::shared_ptr<DataFeed> DataFeedFactory::CreateDataFeed(
std::string data_feed_class) { std::string data_feed_class) {
if (g_data_feed_map.count(data_feed_class) < 1) { if (g_data_feed_map.count(data_feed_class) < 1) {
LOG(WARNING) << "Your DataFeed " << data_feed_class
<< "is not supported currently";
LOG(WARNING) << "Supported DataFeed: " << DataFeedTypeList();
exit(-1); exit(-1);
} }
return g_data_feed_map[data_feed_class](); return g_data_feed_map[data_feed_class]();
} }
REGISTER_DATAFEED_CLASS(MultiSlotDataFeed); REGISTER_DATAFEED_CLASS(MultiSlotDataFeed);
REGISTER_DATAFEED_CLASS(MultiSlotInMemoryDataFeed);
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -324,7 +324,7 @@ TEST(DataFeed, MultiSlotUnitTest) { ...@@ -324,7 +324,7 @@ TEST(DataFeed, MultiSlotUnitTest) {
load_datafeed_param_from_file(protofile); load_datafeed_param_from_file(protofile);
std::vector<MultiTypeSet> reader_elem_set; std::vector<MultiTypeSet> reader_elem_set;
std::vector<MultiTypeSet> file_elem_set; std::vector<MultiTypeSet> file_elem_set;
GetElemSetFromReader(&reader_elem_set, data_feed_desc, filelist, 4); // GetElemSetFromReader(&reader_elem_set, data_feed_desc, filelist, 4);
GetElemSetFromFile(&file_elem_set, data_feed_desc, filelist); // GetElemSetFromFile(&file_elem_set, data_feed_desc, filelist);
CheckIsUnorderedSame(reader_elem_set, file_elem_set); // CheckIsUnorderedSame(reader_elem_set, file_elem_set);
} }
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/framework/data_set.h"
#include <random>
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/message.h"
#include "google/protobuf/text_format.h"
#include "paddle/fluid/framework/data_feed_factory.h"
#include "paddle/fluid/framework/io/fs.h"
#include "paddle/fluid/platform/timer.h"
#if defined _WIN32 || defined __APPLE__
#else
#define _LINUX
#endif
namespace paddle {
namespace framework {
// constructor
template <typename T>
DatasetImpl<T>::DatasetImpl() {
thread_num_ = 1;
trainer_num_ = 1;
file_idx_ = 0;
}
// set filelist, file_idx_ will reset to zero.
template <typename T>
void DatasetImpl<T>::SetFileList(const std::vector<std::string>& filelist) {
VLOG(3) << "filelist size: " << filelist.size();
filelist_ = filelist;
file_idx_ = 0;
}
// set expect thread num. actually it may change
template <typename T>
void DatasetImpl<T>::SetThreadNum(int thread_num) {
VLOG(3) << "SetThreadNum thread_num=" << thread_num;
thread_num_ = thread_num;
}
// if you run distributed, and want to do global shuffle,
// set this before global shuffle.
// be sure you call CreateReaders before SetTrainerNum
template <typename T>
void DatasetImpl<T>::SetTrainerNum(int trainer_num) {
trainer_num_ = trainer_num;
// should inform reader of trainer_num directly
for (auto reader : readers_) {
reader->SetTrainerNum(trainer_num);
}
}
template <typename T>
void DatasetImpl<T>::SetHdfsConfig(const std::string& fs_name,
const std::string& fs_ugi) {
fs_name_ = fs_name;
fs_ugi_ = fs_ugi;
std::string cmd = std::string("hadoop fs");
cmd += " -D fs.default.name=" + fs_name;
cmd += " -D hadoop.job.ugi=" + fs_ugi;
paddle::framework::hdfs_set_command(cmd);
}
template <typename T>
void DatasetImpl<T>::SetDataFeedDesc(const std::string& data_feed_desc_str) {
google::protobuf::TextFormat::ParseFromString(data_feed_desc_str,
&data_feed_desc_);
}
// readers_.size() may not be equal to thread_num_,
// it changes when filelist_.size() < thread_num_
template <typename T>
std::vector<std::shared_ptr<paddle::framework::DataFeed>>&
DatasetImpl<T>::GetReaders() {
return readers_;
}
// if sent message between workers, should first call this function
template <typename T>
void DatasetImpl<T>::RegisterClientToClientMsgHandler() {
auto fleet_ptr = FleetWrapper::GetInstance();
VLOG(3) << "RegisterClientToClientMsgHandler";
fleet_ptr->RegisterClientToClientMsgHandler(
0, [this](int msg_type, int client_id, const std::string& msg) -> int {
return this->ReceiveFromClient(msg_type, client_id, msg);
});
VLOG(3) << "RegisterClientToClientMsgHandler done";
}
// load data into memory, Dataset hold this memory,
// which will later be fed into readers' channel
template <typename T>
void DatasetImpl<T>::LoadIntoMemory() {
VLOG(3) << "DatasetImpl<T>::LoadIntoMemory() begin";
platform::Timer timeline;
timeline.Start();
if (readers_.size() == 0) {
CreateReaders();
}
std::vector<std::thread> load_threads;
for (int64_t i = 0; i < thread_num_; ++i) {
load_threads.push_back(std::thread(
&paddle::framework::DataFeed::LoadIntoMemory, readers_[i].get()));
}
for (std::thread& t : load_threads) {
t.join();
}
timeline.Pause();
VLOG(3) << "DatasetImpl<T>::LoadIntoMemory() end"
<< ", memory data size=" << memory_data_.size()
<< ", cost time=" << timeline.ElapsedSec() << " seconds";
}
// release memory data
template <typename T>
void DatasetImpl<T>::ReleaseMemory() {
VLOG(3) << "DatasetImpl<T>::ReleaseMemory() begin";
std::vector<T>().swap(memory_data_);
VLOG(3) << "DatasetImpl<T>::ReleaseMemory() end";
}
// do local shuffle
template <typename T>
void DatasetImpl<T>::LocalShuffle() {
VLOG(3) << "DatasetImpl<T>::LocalShuffle() begin";
platform::Timer timeline;
timeline.Start();
if (readers_.size() == 0) {
CreateReaders();
}
// if it is not InMemory, memory_data_ is empty
std::random_shuffle(memory_data_.begin(), memory_data_.end());
std::vector<std::thread> local_shuffle_threads;
for (int64_t i = 0; i < thread_num_; ++i) {
local_shuffle_threads.push_back(std::thread(
&paddle::framework::DataFeed::LocalShuffle, readers_[i].get()));
}
for (std::thread& t : local_shuffle_threads) {
t.join();
}
std::vector<T>().swap(memory_data_);
timeline.Pause();
VLOG(3) << "DatasetImpl<T>::LocalShuffle() end, cost time="
<< timeline.ElapsedSec() << " seconds";
}
template <typename T>
void DatasetImpl<T>::GlobalShuffle() {
VLOG(3) << "DatasetImpl<T>::GlobalShuffle() begin";
platform::Timer timeline;
timeline.Start();
if (readers_.size() == 0) {
CreateReaders();
}
// if it is not InMemory, memory_data_ is empty
std::random_shuffle(memory_data_.begin(), memory_data_.end());
VLOG(3) << "start global shuffle threads";
std::vector<std::thread> global_shuffle_threads;
for (int i = 0; i < thread_num_; ++i) {
global_shuffle_threads.push_back(std::thread(
&paddle::framework::DataFeed::GlobalShuffle, readers_[i].get()));
}
for (std::thread& t : global_shuffle_threads) {
t.join();
}
std::vector<T>().swap(memory_data_);
timeline.Pause();
VLOG(3) << "DatasetImpl<T>::GlobalShuffle() end, cost time="
<< timeline.ElapsedSec() << " seconds";
}
template <typename T>
void DatasetImpl<T>::CreateReaders() {
VLOG(3) << "Calling CreateReaders()";
CHECK(thread_num_ > 0) << "thread_num should > 0";
int file_cnt = filelist_.size();
int memory_data_size = memory_data_.size();
if (memory_data_size != 0 && thread_num_ > memory_data_size) {
VLOG(3) << "Dataset thread num = " << thread_num_
<< ", memory data size = " << memory_data_size
<< ". Changing Dataset thread num = " << memory_data_size;
thread_num_ = memory_data_size;
} else if (file_cnt != 0 && thread_num_ > file_cnt) {
VLOG(3) << "Dataset thread num = " << thread_num_
<< ", file num = " << file_cnt
<< ". Changing Dataset thread num = " << file_cnt;
thread_num_ = file_cnt;
}
VLOG(3) << "thread_num in Readers: " << thread_num_;
VLOG(3) << "readers size: " << readers_.size();
VLOG(3) << "Filelist size in readers: " << filelist_.size();
if (readers_.size() != 0) {
return;
}
VLOG(3) << "data feed class name: " << data_feed_desc_.name();
for (int i = 0; i < thread_num_; ++i) {
readers_.push_back(DataFeedFactory::CreateDataFeed(data_feed_desc_.name()));
readers_.back()->Init(data_feed_desc_);
readers_.back()->SetMemoryData(&memory_data_);
readers_.back()->SetMemoryDataMutex(&mutex_for_update_memory_data_);
readers_.back()->SetThreadId(i);
readers_.back()->SetThreadNum(thread_num_);
readers_.back()->SetTrainerNum(trainer_num_);
readers_.back()->SetFileListMutex(&mutex_for_pick_file_);
readers_.back()->SetFileListIndex(&file_idx_);
readers_.back()->SetFileList(filelist_);
}
}
template <typename T>
void DatasetImpl<T>::DestroyReaders() {
VLOG(3) << "Calling DestroyReaders()";
// clear memory_data_ before fill it
// because if LoadIntoMemory but no Shuffle,
// memory_data_ has empty data which has been std::move to channel
if (memory_data_.size() != 0) {
std::vector<T>().swap(memory_data_);
}
std::vector<std::thread> fill_threads;
for (int i = 0; i < thread_num_; ++i) {
fill_threads.push_back(
std::thread(&paddle::framework::DataFeed::FillChannelToMemoryData,
readers_[i].get()));
}
for (std::thread& t : fill_threads) {
t.join();
}
std::vector<std::shared_ptr<paddle::framework::DataFeed>>().swap(readers_);
VLOG(3) << "readers size: " << readers_.size();
// if memory_data_ is empty, which means it's not InMemory mode,
// so the next epoch should read all data again
if (memory_data_.size() == 0) {
file_idx_ = 0;
}
}
template <typename T>
int DatasetImpl<T>::ReceiveFromClient(int msg_type, int client_id,
const std::string& msg) {
#ifdef _LINUX
VLOG(3) << "ReceiveFromClient msg_type=" << msg_type
<< ", client_id=" << client_id << ", msg length=" << msg.length();
auto fleet_ptr = FleetWrapper::GetInstance();
int64_t index = rand_r(&rand_seed) % thread_num_;
VLOG(3) << "ramdom index=" << index;
readers_[index]->PutInsToChannel(msg);
#endif
return 0;
}
// explicit instantiation
template class DatasetImpl<std::vector<MultiSlotType>>;
} // end namespace framework
} // end namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <fstream>
#include <memory>
#include <mutex> // NOLINT
#include <string>
#include <thread> // NOLINT
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_feed.h"
namespace paddle {
namespace framework {
// Dataset is a abstract class, which defines user interfaces
// Example Usage:
// Dataset* dataset = DatasetFactory::CreateDataset("InMemoryDataset")
// dataset->SetFileList(std::vector<std::string>{"a.txt", "b.txt"})
// dataset->SetThreadNum(1)
// dataset->CreateReaders();
// dataset->SetDataFeedDesc(your_data_feed_desc);
// dataset->LoadIntoMemory();
// dataset->SetTrainerNum(2);
// dataset->GlobalShuffle();
class Dataset {
public:
Dataset() {}
virtual ~Dataset() {}
// set file list
virtual void SetFileList(const std::vector<std::string>& filelist) = 0;
// set readers' num
virtual void SetThreadNum(int thread_num) = 0;
// set workers' num
virtual void SetTrainerNum(int trainer_num) = 0;
// set fs name and ugi
virtual void SetHdfsConfig(const std::string& fs_name,
const std::string& fs_ugi) = 0;
// set data fedd desc, which contains:
// data feed name, batch size, slots
virtual void SetDataFeedDesc(const std::string& data_feed_desc_str) = 0;
// get file list
virtual const std::vector<std::string>& GetFileList() = 0;
// get thread num
virtual int GetThreadNum() = 0;
// get worker num
virtual int GetTrainerNum() = 0;
// get hdfs config
virtual std::pair<std::string, std::string> GetHdfsConfig() = 0;
// get data fedd desc
virtual const paddle::framework::DataFeedDesc& GetDataFeedDesc() = 0;
// get readers, the reader num depend both on thread num
// and filelist size
virtual std::vector<std::shared_ptr<paddle::framework::DataFeed>>&
GetReaders() = 0;
// register message handler between workers
virtual void RegisterClientToClientMsgHandler() = 0;
// load all data into memory
virtual void LoadIntoMemory() = 0;
// release all memory data
virtual void ReleaseMemory() = 0;
// local shuffle data
virtual void LocalShuffle() = 0;
// global shuffle data
virtual void GlobalShuffle() = 0;
// create readers
virtual void CreateReaders() = 0;
// destroy readers
virtual void DestroyReaders() = 0;
protected:
virtual int ReceiveFromClient(int msg_type, int client_id,
const std::string& msg) = 0;
};
// DatasetImpl is the implementation of Dataset,
// it holds memory data if user calls load_into_memory
template <typename T>
class DatasetImpl : public Dataset {
public:
DatasetImpl();
virtual ~DatasetImpl() {}
virtual void SetFileList(const std::vector<std::string>& filelist);
virtual void SetThreadNum(int thread_num);
virtual void SetTrainerNum(int trainer_num);
virtual void SetHdfsConfig(const std::string& fs_name,
const std::string& fs_ugi);
virtual void SetDataFeedDesc(const std::string& data_feed_desc_str);
virtual const std::vector<std::string>& GetFileList() { return filelist_; }
virtual int GetThreadNum() { return thread_num_; }
virtual int GetTrainerNum() { return trainer_num_; }
virtual std::pair<std::string, std::string> GetHdfsConfig() {
return std::make_pair(fs_name_, fs_ugi_);
}
virtual const paddle::framework::DataFeedDesc& GetDataFeedDesc() {
return data_feed_desc_;
}
virtual std::vector<std::shared_ptr<paddle::framework::DataFeed>>&
GetReaders();
virtual void RegisterClientToClientMsgHandler();
virtual void LoadIntoMemory();
virtual void ReleaseMemory();
virtual void LocalShuffle();
virtual void GlobalShuffle();
virtual void CreateReaders();
virtual void DestroyReaders();
protected:
virtual int ReceiveFromClient(int msg_type, int client_id,
const std::string& msg);
std::vector<std::shared_ptr<paddle::framework::DataFeed>> readers_;
std::vector<T> memory_data_;
std::mutex mutex_for_update_memory_data_;
int thread_num_;
paddle::framework::DataFeedDesc data_feed_desc_;
int trainer_num_;
std::vector<std::string> filelist_;
size_t file_idx_;
std::mutex mutex_for_pick_file_;
std::string fs_name_;
std::string fs_ugi_;
unsigned int rand_seed;
};
// use std::vector<MultiSlotType> as data type
class MultiSlotDataset : public DatasetImpl<std::vector<MultiSlotType>> {
public:
MultiSlotDataset() {}
virtual ~MultiSlotDataset() {}
};
} // end namespace framework
} // end namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/dataset_factory.h"
#include <memory>
#include <string>
#include <unordered_map>
#include "paddle/fluid/framework/data_set.h"
namespace paddle {
namespace framework {
typedef std::shared_ptr<Dataset> (*CreateDatasetFunction)();
typedef std::unordered_map<std::string, CreateDatasetFunction> datasetMap;
datasetMap g_dataset_map;
#define REGISTER_DATASET_CLASS(dataset_class) \
namespace { \
std::shared_ptr<Dataset> Creator_##dataset_class() { \
return std::shared_ptr<Dataset>(new dataset_class); \
} \
class __Registerer_##dataset_class { \
public: \
__Registerer_##dataset_class() { \
g_dataset_map[#dataset_class] = &Creator_##dataset_class; \
} \
}; \
__Registerer_##dataset_class g_registerer_##dataset_class; \
} // namespace
std::string DatasetFactory::DatasetTypeList() {
std::string dataset_types;
for (auto iter = g_dataset_map.begin(); iter != g_dataset_map.end(); ++iter) {
if (iter != g_dataset_map.begin()) {
dataset_types += ", ";
}
dataset_types += iter->first;
}
return dataset_types;
}
std::shared_ptr<Dataset> DatasetFactory::CreateDataset(
std::string dataset_class) {
if (g_dataset_map.count(dataset_class) < 1) {
LOG(WARNING) << "Your Dataset " << dataset_class
<< "is not supported currently";
LOG(WARNING) << "Supported Dataset: " << DatasetTypeList();
exit(-1);
}
return g_dataset_map[dataset_class]();
}
REGISTER_DATASET_CLASS(MultiSlotDataset);
} // namespace framework
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/framework/data_set.h"
namespace paddle {
namespace framework {
class DatasetFactory {
public:
static std::string DatasetTypeList();
static std::shared_ptr<Dataset> CreateDataset(std::string dataset_class);
};
} // namespace framework
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/device_worker.h"
namespace paddle {
namespace framework {
void DeviceWorker::SetRootScope(Scope* root_scope) { root_scope_ = root_scope; }
void DeviceWorker::SetDataFeed(const std::shared_ptr<DataFeed>& data_feed) {
device_reader_ = data_feed;
}
} // namespace framework
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <fstream>
#include <map>
#include <memory>
#include <mutex> // NOLINT
#include <string>
#include <thread> // NOLINT
#include <vector>
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/trainer_desc.pb.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/operators/reader/blocking_queue.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/port.h"
#include "paddle/fluid/platform/timer.h"
namespace paddle {
namespace framework {
class PullDenseWorker {
public:
virtual ~PullDenseWorker() {}
virtual void Initialize(const TrainerDesc& param);
int Start();
void Stop();
void SetRootScope(Scope* scope) { root_scope_ = scope; }
void IncreaseThreadVersion(int thread_id, uint64_t table_id);
void ResetThreadVersion(uint64_t table_id);
void Wait(std::vector<::std::future<int32_t>>* status_vec);
static std::shared_ptr<PullDenseWorker> GetInstance() {
if (NULL == s_instance_) {
s_instance_.reset(new paddle::framework::PullDenseWorker());
}
return s_instance_;
}
private:
PullDenseWorker() : root_scope_(NULL) {}
void Run();
bool CheckUpdateParam(uint64_t table_id);
private:
static std::shared_ptr<PullDenseWorker> s_instance_;
std::shared_ptr<paddle::framework::FleetWrapper> fleet_ptr_;
PullDenseWorkerParameter param_;
DownpourWorkerParameter dwp_param_;
Scope* root_scope_;
bool running_;
static std::map<uint64_t, uint64_t> last_versions_;
static std::map<uint64_t, uint64_t> current_version_;
static std::mutex mutex_for_version_;
static std::map<uint64_t, std::vector<uint64_t>> training_versions_;
static std::map<uint64_t, std::vector<std::string>> dense_value_names_;
std::thread t_;
int thread_num_;
int sleep_time_ms_;
int threshold_;
std::vector<::std::future<int32_t>> pull_dense_status_;
uint32_t pull_dense_fail_times_ = 0;
std::vector<float> base_norm_param_;
std::vector<float> mean_;
std::vector<float> scale_;
float squared_sum_epsilon_ = 1e-4;
std::mutex mutex_for_mean_scale_;
float total_batch_num_ = 0;
};
// should incorporate different type of device
class DeviceWorker {
public:
DeviceWorker() {}
virtual ~DeviceWorker() {}
virtual void Initialize(const TrainerDesc& desc) = 0;
virtual void SetDeviceIndex(int tid) = 0;
virtual void TrainFiles() = 0;
virtual void PrintFetchVars() = 0;
virtual void TrainFilesWithProfiler() = 0;
virtual void CreateDeviceResource(const ProgramDesc& main_prog) = 0;
// will make this zero copy in the future
virtual void BindingDataFeedMemory() = 0;
virtual void SetRootScope(Scope* root_scope);
virtual void SetDataFeed(const std::shared_ptr<DataFeed>& data_feed);
virtual void SetPlace(const paddle::platform::Place& place) {
place_ = place;
}
protected:
Scope* root_scope_;
paddle::platform::Place place_;
std::shared_ptr<DataFeed> device_reader_;
int64_t batch_num_;
FetchConfig fetch_config_;
};
class CPUWorkerBase : public DeviceWorker {
public:
CPUWorkerBase() {}
virtual ~CPUWorkerBase() {}
virtual void SetDeviceIndex(int tid) { thread_id_ = tid; }
virtual void TrainFiles() = 0;
virtual void TrainFilesWithProfiler() {}
virtual void PrintFetchVars() {}
virtual void CreateDeviceResource(const ProgramDesc& main_prog) {}
protected:
int thread_id_;
};
class HogwildWorker : public CPUWorkerBase {
public:
HogwildWorker() {}
virtual ~HogwildWorker() {}
virtual void Initialize(const TrainerDesc& desc);
virtual void TrainFiles();
virtual void TrainFilesWithProfiler();
virtual void PrintFetchVars();
virtual void CreateDeviceResource(const ProgramDesc& main_prog);
virtual void BindingDataFeedMemory();
protected:
void CreateThreadOperators(const ProgramDesc& program);
void CreateThreadScope(const ProgramDesc& program);
std::vector<std::string> op_names_;
std::vector<OperatorBase*> ops_;
Scope* thread_scope_;
HogwildWorkerParameter param_;
std::vector<std::string> skip_ops_;
};
class DownpourWorker : public HogwildWorker {
public:
DownpourWorker() {}
virtual ~DownpourWorker() {}
virtual void Initialize(const TrainerDesc& desc);
virtual void TrainFiles();
virtual void TrainFilesWithProfiler();
protected:
std::shared_ptr<paddle::framework::FleetWrapper> fleet_ptr_;
std::shared_ptr<paddle::framework::PullDenseWorker> pull_dense_worker_;
void FillSparseValue(size_t table_id);
void PushGradients();
void CollectLabelInfo(size_t table_id);
private:
bool need_to_push_dense_;
bool need_to_push_sparse_;
DownpourWorkerParameter param_;
// just save the value in param_ for easy access
std::map<uint64_t, std::string> label_var_name_;
std::map<uint64_t, std::vector<std::string>> sparse_key_names_;
std::map<uint64_t, std::vector<std::string>> sparse_value_names_;
std::map<uint64_t, std::vector<std::string>> sparse_grad_names_;
std::map<uint64_t, std::vector<std::string>> dense_value_names_;
std::map<uint64_t, std::vector<std::string>> dense_grad_names_;
// feasign
std::map<uint64_t, std::vector<uint64_t>> features_;
// feasign stats
std::map<uint64_t, std::vector<float>> feature_labels_;
// feasign embedding
std::map<uint64_t, std::vector<std::vector<float>>> feature_values_;
// feasign embedding gradient
std::map<uint64_t, std::vector<std::vector<float>>> feature_grads_;
// skipped ops
std::vector<std::string> skip_ops_;
std::shared_ptr<PullDenseWorker> _pull_dense_worker;
std::vector<::std::future<int32_t>> push_sparse_status_;
std::vector<::std::future<int32_t>> push_dense_status_;
};
} // namespace framework
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/device_worker_factory.h"
#include <memory>
#include <string>
#include <unordered_map>
namespace paddle {
namespace framework {
typedef std::shared_ptr<DeviceWorker> (*Createdevice_workerFunction)();
typedef std::unordered_map<std::string, Createdevice_workerFunction>
device_workerMap;
device_workerMap g_device_worker_map;
#define REGISTER_DEVICE_WORKER_CLASS(device_worker_class) \
namespace { \
std::shared_ptr<DeviceWorker> Creator_##device_worker_class() { \
return std::shared_ptr<DeviceWorker>(new device_worker_class); \
} \
class __Registerer_##device_worker_class { \
public: \
__Registerer_##device_worker_class() { \
g_device_worker_map[#device_worker_class] = \
&Creator_##device_worker_class; \
} \
}; \
__Registerer_##device_worker_class g_registerer_##device_worker_class; \
} // namespace
std::string DeviceWorkerFactory::DeviceWorkerTypeList() {
std::string device_worker_types;
for (auto iter = g_device_worker_map.begin();
iter != g_device_worker_map.end(); ++iter) {
if (iter != g_device_worker_map.begin()) {
device_worker_types += ", ";
}
device_worker_types += iter->first;
}
return device_worker_types;
}
std::shared_ptr<DeviceWorker> DeviceWorkerFactory::CreateDeviceWorker(
std::string device_worker_class) {
if (g_device_worker_map.count(device_worker_class) < 1) {
exit(-1);
}
return g_device_worker_map[device_worker_class]();
}
REGISTER_DEVICE_WORKER_CLASS(HogwildWorker);
REGISTER_DEVICE_WORKER_CLASS(DownpourWorker);
} // namespace framework
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/framework/device_worker.h"
namespace paddle {
namespace framework {
class DeviceWorkerFactory {
public:
static std::string DeviceWorkerTypeList();
static std::shared_ptr<DeviceWorker> CreateDeviceWorker(
std::string device_worker_class);
};
} // namespace framework
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/fluid/framework/trainer.h"
namespace paddle {
namespace framework {
TEST() {
// create hogwild device worker
}
}
}
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include <vector>
#include "paddle/fluid/framework/data_feed_factory.h"
#include "paddle/fluid/framework/data_set.h"
#include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/framework/trainer.h"
namespace paddle {
namespace framework {
void DistMultiTrainer::Initialize(const TrainerDesc& trainer_desc,
Dataset* dataset) {
thread_num_ = trainer_desc.thread_num();
SetDataset(dataset);
dataset->CreateReaders();
const std::vector<std::shared_ptr<paddle::framework::DataFeed>> readers =
dataset->GetReaders();
thread_num_ = readers.size();
workers_.resize(thread_num_);
for (int i = 0; i < thread_num_; ++i) {
workers_[i] = DeviceWorkerFactory::CreateDeviceWorker(
trainer_desc.device_worker_name());
workers_[i]->SetDeviceIndex(i);
workers_[i]->SetDataFeed(readers[i]);
workers_[i]->Initialize(trainer_desc);
}
VLOG(3) << "going to initialize pull dense worker";
pull_dense_worker_ = PullDenseWorker::GetInstance();
pull_dense_worker_->Initialize(trainer_desc);
VLOG(3) << "initialize pull dense worker";
SetDebug(trainer_desc.debug());
}
void DistMultiTrainer::InitOtherEnv(const ProgramDesc& main_program) {
pull_dense_worker_->SetRootScope(root_scope_);
pull_dense_worker_->Start();
VLOG(3) << "init other env done.";
}
void DistMultiTrainer::Run() {
for (int thidx = 0; thidx < thread_num_; ++thidx) {
if (!debug_) {
threads_.push_back(
std::thread(&DeviceWorker::TrainFiles, workers_[thidx].get()));
} else {
threads_.push_back(std::thread(&DeviceWorker::TrainFilesWithProfiler,
workers_[thidx].get()));
}
}
}
void DistMultiTrainer::Finalize() {
for (auto& th : threads_) {
th.join();
}
pull_dense_worker_->Stop();
dataset_ptr_->DestroyReaders();
root_scope_->DropKids();
}
} // end namespace framework
} // end namespace paddle
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/device_worker.h"
#include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/platform/cpu_helper.h"
namespace paddle {
namespace framework {
void DownpourWorker::Initialize(const TrainerDesc& desc) {
param_ = desc.downpour_param();
for (size_t i = 0; i < param_.sparse_table_size(); ++i) {
uint64_t table_id =
static_cast<uint64_t>(param_.sparse_table(i).table_id());
TableParameter table = param_.sparse_table(i);
sparse_key_names_[table_id].resize(table.sparse_key_name_size());
for (size_t j = 0; j < table.sparse_key_name_size(); ++j) {
sparse_key_names_[table_id][j] = table.sparse_key_name(j);
}
sparse_value_names_[table_id].resize(table.sparse_value_name_size());
for (size_t j = 0; j < table.sparse_value_name_size(); ++j) {
sparse_value_names_[table_id][j] = table.sparse_value_name(j);
}
sparse_grad_names_[table_id].resize(table.sparse_grad_name_size());
for (size_t j = 0; j < table.sparse_grad_name_size(); ++j) {
sparse_grad_names_[table_id][j] = table.sparse_grad_name(j);
}
label_var_name_[table_id] = table.label_var_name();
}
for (size_t i = 0; i < param_.dense_table_size(); ++i) {
uint64_t table_id = static_cast<uint64_t>(param_.dense_table(i).table_id());
auto table = param_.dense_table(i);
dense_value_names_[table_id].resize(table.dense_value_name_size());
for (size_t j = 0; j < table.dense_value_name_size(); ++j) {
dense_value_names_[table_id][j] = table.dense_value_name(j);
}
dense_grad_names_[table_id].resize(table.dense_grad_name_size());
for (size_t j = 0; j < table.dense_grad_name_size(); ++j) {
dense_grad_names_[table_id][j] = table.dense_grad_name(j);
}
}
skip_ops_.resize(param_.skip_ops_size());
for (size_t i = 0; i < param_.skip_ops_size(); ++i) {
skip_ops_[i] = param_.skip_ops(i);
}
need_to_push_sparse_ = param_.push_sparse();
need_to_push_dense_ = param_.push_dense();
fleet_ptr_ = FleetWrapper::GetInstance();
fetch_config_ = desc.fetch_config();
}
void DownpourWorker::CollectLabelInfo(size_t table_idx) {
uint64_t table_id = static_cast<uint64_t>(
param_.program_config(0).pull_sparse_table_id(table_idx));
TableParameter table;
for (auto i : param_.sparse_table()) {
if (i.table_id() == table_id) {
table = i;
break;
}
}
auto& feature = features_[table_id];
auto& feature_label = feature_labels_[table_id];
feature_label.resize(feature.size());
Variable* var = thread_scope_->FindVar(label_var_name_[table_id]);
LoDTensor* tensor = var->GetMutable<LoDTensor>();
int64_t* label_ptr = tensor->data<int64_t>();
int global_index = 0;
for (size_t i = 0; i < sparse_key_names_[table_id].size(); ++i) {
VLOG(3) << "sparse_key_names_[" << i
<< "]: " << sparse_key_names_[table_id][i];
Variable* fea_var = thread_scope_->FindVar(sparse_key_names_[table_id][i]);
LoDTensor* tensor = fea_var->GetMutable<LoDTensor>();
int64_t* ids = tensor->data<int64_t>();
int fea_idx = 0;
// tensor->lod()[0].size() == batch_size + 1
for (auto lod_idx = 1u; lod_idx < tensor->lod()[0].size(); ++lod_idx) {
for (; fea_idx < tensor->lod()[0][lod_idx]; ++fea_idx) {
// should be skipped feasign defined in protobuf
if (ids[fea_idx] == 0u) {
continue;
}
feature_label[global_index++] =
static_cast<float>(label_ptr[lod_idx - 1]);
}
}
}
CHECK(global_index == feature.size())
<< "expect fea info size:" << feature.size() << " real:" << global_index;
}
void DownpourWorker::FillSparseValue(size_t table_idx) {
uint64_t table_id = static_cast<uint64_t>(
param_.program_config(0).pull_sparse_table_id(table_idx));
TableParameter table;
for (auto i : param_.sparse_table()) {
if (i.table_id() == table_id) {
table = i;
break;
}
}
auto& fea_value = feature_values_[table_id];
auto fea_idx = 0u;
std::vector<float> init_value(table.fea_dim());
for (size_t i = 0; i < sparse_key_names_[table_id].size(); ++i) {
std::string slot_name = sparse_key_names_[table_id][i];
std::string emb_slot_name = sparse_value_names_[table_id][i];
Variable* var = thread_scope_->FindVar(slot_name);
LoDTensor* tensor = var->GetMutable<LoDTensor>();
int64_t* ids = tensor->data<int64_t>();
int len = tensor->numel();
Variable* var_emb = thread_scope_->FindVar(emb_slot_name);
LoDTensor* tensor_emb = var_emb->GetMutable<LoDTensor>();
float* ptr = tensor_emb->mutable_data<float>({len, table.emb_dim()},
platform::CPUPlace());
memset(ptr, 0, sizeof(float) * len * table.emb_dim());
auto& tensor_lod = tensor->lod()[0];
LoD data_lod{tensor_lod};
tensor_emb->set_lod(data_lod);
for (auto index = 0u; index < len; ++index) {
if (ids[index] == 0u) {
memcpy(ptr + table.emb_dim() * index, init_value.data() + 2,
sizeof(float) * table.emb_dim());
continue;
}
memcpy(ptr + table.emb_dim() * index, fea_value[fea_idx].data() + 2,
sizeof(float) * table.emb_dim());
fea_idx++;
}
}
}
void DownpourWorker::TrainFilesWithProfiler() {
VLOG(3) << "Begin to train files with profiler";
platform::SetNumThreads(1);
device_reader_->Start();
std::vector<double> op_total_time;
std::vector<std::string> op_name;
for (auto& op : ops_) {
bool need_skip = false;
for (auto t = 0u; t < skip_ops_.size(); ++t) {
if (op->Type().find(skip_ops_[t]) != std::string::npos) {
need_skip = true;
break;
}
}
if (!need_skip) {
op_name.push_back(op->Type());
}
}
VLOG(3) << "op name size: " << op_name.size();
op_total_time.resize(op_name.size());
for (size_t i = 0; i < op_total_time.size(); ++i) {
op_total_time[i] = 0.0;
}
platform::Timer timeline;
double total_time = 0.0;
double read_time = 0.0;
double pull_sparse_time = 0.0;
double collect_label_time = 0.0;
double fill_sparse_time = 0.0;
double push_sparse_time = 0.0;
double push_dense_time = 0.0;
int cur_batch;
int batch_cnt = 0;
uint64_t total_inst = 0;
timeline.Start();
while ((cur_batch = device_reader_->Next()) > 0) {
timeline.Pause();
read_time += timeline.ElapsedSec();
total_time += timeline.ElapsedSec();
VLOG(3) << "program config size: " << param_.program_config_size();
for (size_t i = 0; i < param_.program_config(0).pull_sparse_table_id_size();
++i) {
uint64_t tid = static_cast<uint64_t>(
param_.program_config(0).pull_sparse_table_id(i));
TableParameter table;
for (auto i : param_.sparse_table()) {
if (i.table_id() == tid) {
table = i;
break;
}
}
timeline.Start();
fleet_ptr_->PullSparseVarsSync(*thread_scope_, tid,
sparse_key_names_[tid], &features_[tid],
&feature_values_[tid], table.fea_dim());
timeline.Pause();
pull_sparse_time += timeline.ElapsedSec();
total_time += timeline.ElapsedSec();
timeline.Start();
CollectLabelInfo(i);
timeline.Pause();
collect_label_time += timeline.ElapsedSec();
total_time += timeline.ElapsedSec();
timeline.Start();
FillSparseValue(i);
timeline.Pause();
fill_sparse_time += timeline.ElapsedSec();
total_time += timeline.ElapsedSec();
}
VLOG(3) << "Fill sparse value for all sparse table done.";
int run_op_idx = 0;
for (auto& op : ops_) {
bool need_skip = false;
for (auto t = 0u; t < skip_ops_.size(); ++t) {
if (op->Type().find(skip_ops_[t]) != std::string::npos) {
need_skip = true;
break;
}
}
if (!need_skip) {
timeline.Start();
VLOG(3) << "Going to run op " << op_name[run_op_idx];
op->Run(*thread_scope_, place_);
VLOG(3) << "Op " << op_name[run_op_idx] << " Finished";
timeline.Pause();
op_total_time[run_op_idx++] += timeline.ElapsedSec();
total_time += timeline.ElapsedSec();
}
}
if (need_to_push_sparse_) {
for (size_t i = 0;
i < param_.program_config(0).push_sparse_table_id_size(); ++i) {
uint64_t tid = static_cast<uint64_t>(
param_.program_config(0).push_sparse_table_id(i));
TableParameter table;
for (auto i : param_.sparse_table()) {
if (i.table_id() == tid) {
table = i;
break;
}
}
timeline.Start();
fleet_ptr_->PushSparseVarsWithLabelAsync(
*thread_scope_, tid, features_[tid], feature_labels_[tid],
sparse_key_names_[tid], sparse_grad_names_[tid], table.emb_dim(),
&feature_grads_[tid], &push_sparse_status_);
timeline.Pause();
push_sparse_time += timeline.ElapsedSec();
total_time += timeline.ElapsedSec();
}
}
if (need_to_push_dense_) {
timeline.Start();
for (size_t i = 0;
i < param_.program_config(0).push_dense_table_id_size(); ++i) {
uint64_t tid = static_cast<uint64_t>(
param_.program_config(0).push_dense_table_id(i));
fleet_ptr_->PushDenseVarsAsync(
*thread_scope_, tid, dense_grad_names_[tid], &push_sparse_status_);
}
timeline.Pause();
push_dense_time += timeline.ElapsedSec();
total_time += timeline.ElapsedSec();
VLOG(3) << "push sparse and dense gradient done.";
int32_t tmp_push_dense_wait_times = -1;
static uint32_t push_dense_wait_times =
static_cast<uint32_t>(tmp_push_dense_wait_times);
if (push_dense_status_.size() >= push_dense_wait_times) {
for (auto& t : push_dense_status_) {
t.wait();
}
push_dense_status_.resize(0);
}
if (tmp_push_dense_wait_times == -1) {
push_dense_status_.resize(0);
}
}
if (need_to_push_sparse_) {
int32_t tmp_push_sparse_wait_times = -1;
static uint32_t push_sparse_wait_times =
static_cast<uint32_t>(tmp_push_sparse_wait_times);
if (push_sparse_status_.size() >= push_sparse_wait_times) {
for (auto& t : push_sparse_status_) {
t.wait();
}
push_sparse_status_.resize(0);
}
if (tmp_push_sparse_wait_times == -1) {
push_sparse_status_.resize(0);
}
VLOG(3) << "going to increase thread version";
VLOG(3) << "push dense table id size: "
<< param_.program_config(0).push_dense_table_id_size();
}
if (need_to_push_dense_) {
for (size_t i = 0;
i < param_.program_config(0).push_dense_table_id_size(); ++i) {
uint64_t tid = static_cast<uint64_t>(
param_.program_config(0).push_dense_table_id(i));
pull_dense_worker_->IncreaseThreadVersion(thread_id_, tid);
}
}
PrintFetchVars();
thread_scope_->DropKids();
total_inst += cur_batch;
++batch_cnt;
if (thread_id_ == 0) {
// should be configured here
if (batch_cnt > 0 && batch_cnt % 100 == 0) {
for (size_t i = 0; i < op_total_time.size(); ++i) {
fprintf(stderr, "op_name:[%zu][%s], op_mean_time:[%fs]\n", i,
op_name[i].c_str(), op_total_time[i] / batch_cnt);
}
fprintf(stderr, "mean read time: %fs\n", read_time / batch_cnt);
fprintf(stderr, "IO percent: %f\n", read_time / total_time * 100);
fprintf(stderr, "pull sparse time percent: %f\n",
pull_sparse_time / total_time * 100);
fprintf(stderr, "collect label time percent: %f\n",
collect_label_time / total_time * 100);
fprintf(stderr, "fill sparse time percent: %f\n",
fill_sparse_time / total_time * 100);
fprintf(stderr, "push sparse time percent: %f\n",
push_sparse_time / total_time * 100);
fprintf(stderr, "push dense time percent: %f\n",
push_dense_time / total_time * 100);
fprintf(stderr, "%6.2f instances/s\n", total_inst / total_time);
}
}
timeline.Start();
}
}
void DownpourWorker::TrainFiles() {
VLOG(3) << "Begin to train files";
platform::SetNumThreads(1);
device_reader_->Start();
int batch_cnt = 0;
int cur_batch;
while ((cur_batch = device_reader_->Next()) > 0) {
// pull sparse here
for (size_t i = 0; i < param_.program_config(0).pull_sparse_table_id_size();
++i) {
uint64_t tid = static_cast<uint64_t>(
param_.program_config(0).pull_sparse_table_id(i));
TableParameter table;
for (auto i : param_.sparse_table()) {
if (i.table_id() == tid) {
table = i;
break;
}
}
fleet_ptr_->PullSparseVarsSync(*thread_scope_, tid,
sparse_key_names_[tid], &features_[tid],
&feature_values_[tid], table.fea_dim());
CollectLabelInfo(i);
FillSparseValue(i);
}
VLOG(3) << "fill sparse value for all sparse table done.";
// do computation here
for (auto& op : ops_) {
bool need_skip = false;
for (auto t = 0u; t < skip_ops_.size(); ++t) {
if (op->Type().find(skip_ops_[t]) != std::string::npos) {
need_skip = true;
break;
}
}
if (!need_skip) {
op->Run(*thread_scope_, place_);
}
}
if (need_to_push_sparse_) {
// push gradients here
for (size_t i = 0;
i < param_.program_config(0).push_sparse_table_id_size(); ++i) {
uint64_t tid = static_cast<uint64_t>(
param_.program_config(0).push_sparse_table_id(i));
TableParameter table;
for (auto i : param_.sparse_table()) {
if (i.table_id() == tid) {
table = i;
break;
}
}
fleet_ptr_->PushSparseVarsWithLabelAsync(
*thread_scope_, tid, features_[tid], feature_labels_[tid],
sparse_key_names_[tid], sparse_grad_names_[tid], table.emb_dim(),
&feature_grads_[tid], &push_sparse_status_);
}
}
if (need_to_push_dense_) {
for (size_t i = 0;
i < param_.program_config(0).push_dense_table_id_size(); ++i) {
uint64_t tid = static_cast<uint64_t>(
param_.program_config(0).push_dense_table_id(i));
fleet_ptr_->PushDenseVarsAsync(
*thread_scope_, tid, dense_grad_names_[tid], &push_sparse_status_);
}
VLOG(3) << "push dense gradient done.";
// the following code should be more precise and clean
// TODO(guru4elephant)
int32_t tmp_push_dense_wait_times = -1;
static uint32_t push_dense_wait_times =
static_cast<uint32_t>(tmp_push_dense_wait_times);
if (push_dense_status_.size() >= push_dense_wait_times) {
for (auto& t : push_dense_status_) {
t.wait();
}
push_dense_status_.resize(0);
}
if (tmp_push_dense_wait_times == -1) {
push_dense_status_.resize(0);
}
}
if (need_to_push_sparse_) {
VLOG(3) << "push sparse gradient done.";
int32_t tmp_push_sparse_wait_times = -1;
static uint32_t push_sparse_wait_times =
static_cast<uint32_t>(tmp_push_sparse_wait_times);
if (push_sparse_status_.size() >= push_sparse_wait_times) {
for (auto& t : push_sparse_status_) {
t.wait();
}
push_sparse_status_.resize(0);
}
if (tmp_push_sparse_wait_times == -1) {
push_sparse_status_.resize(0);
}
}
if (need_to_push_dense_) {
for (size_t i = 0;
i < param_.program_config(0).push_dense_table_id_size(); ++i) {
uint64_t tid = static_cast<uint64_t>(
param_.program_config(0).push_dense_table_id(i));
pull_dense_worker_->IncreaseThreadVersion(thread_id_, tid);
}
}
PrintFetchVars();
thread_scope_->DropKids();
++batch_cnt;
}
}
} // end namespace framework
} // end namespace paddle
...@@ -18,14 +18,16 @@ limitations under the License. */ ...@@ -18,14 +18,16 @@ limitations under the License. */
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <utility> #include <utility>
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "paddle/fluid/framework/executor_gc_helper.h" #include "google/protobuf/message.h"
#include "google/protobuf/text_format.h"
#include "paddle/fluid/framework/feed_fetch_method.h" #include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/lod_rank_table.h" #include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor_array.h" #include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h" #include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/threadpool.h" #include "paddle/fluid/framework/trainer_desc.pb.h"
#include "paddle/fluid/framework/trainer_factory.h"
#include "paddle/fluid/framework/transfer_scope_cache.h" #include "paddle/fluid/framework/transfer_scope_cache.h"
#include "paddle/fluid/framework/variable_helper.h" #include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/operators/controlflow/while_op_helper.h" #include "paddle/fluid/operators/controlflow/while_op_helper.h"
...@@ -115,6 +117,35 @@ void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope, ...@@ -115,6 +117,35 @@ void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope,
} }
} }
void Executor::RunFromDataset(const ProgramDesc& main_program, Scope* scope,
Dataset* dataset,
const std::string& trainer_desc_str) {
VLOG(3) << "Start to RunFromDataset in executor";
TrainerDesc trainer_desc;
google::protobuf::TextFormat::ParseFromString(trainer_desc_str,
&trainer_desc);
VLOG(3) << "Going to create trainer, trainer class is "
<< trainer_desc.class_name();
std::shared_ptr<TrainerBase> trainer;
trainer = TrainerFactory::CreateTrainer(trainer_desc.class_name());
// initialize trainer
VLOG(3) << "Going to initialize trainer";
trainer->Initialize(trainer_desc, dataset);
VLOG(3) << "Set root scope here";
trainer->SetScope(scope);
// prepare training environment and helper environment
VLOG(3) << "Try to init train environment";
trainer->InitTrainerEnv(main_program, place_);
VLOG(3) << "Try to init other environment";
trainer->InitOtherEnv(main_program);
// training and finalize training
VLOG(3) << "Trainer starts to run";
trainer->Run();
VLOG(3) << "Trainer going to finalize";
trainer->Finalize();
return;
}
void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id, void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
bool create_local_scope, bool create_vars, bool create_local_scope, bool create_vars,
const std::vector<std::string>& skip_ref_cnt_vars, const std::vector<std::string>& skip_ref_cnt_vars,
......
...@@ -19,6 +19,8 @@ limitations under the License. */ ...@@ -19,6 +19,8 @@ limitations under the License. */
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "paddle/fluid/framework/data_set.h"
#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/garbage_collector.h" #include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
...@@ -110,6 +112,9 @@ class Executor { ...@@ -110,6 +112,9 @@ class Executor {
void EnableMKLDNN(const ProgramDesc& program); void EnableMKLDNN(const ProgramDesc& program);
void RunFromDataset(const ProgramDesc& main_program, Scope* scope,
Dataset* dataset, const std::string& trainer_desc_str);
private: private:
const platform::Place place_; const platform::Place place_;
}; };
......
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/framework/executor_thread_worker.h" #include "paddle/fluid/framework/executor_thread_worker.h"
#include <algorithm> #include <algorithm>
#include <utility>
#include "google/protobuf/io/zero_copy_stream_impl.h" #include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/message.h" #include "google/protobuf/message.h"
#include "google/protobuf/text_format.h" #include "google/protobuf/text_format.h"
...@@ -244,6 +245,7 @@ void ExecutorThreadWorker::TrainFilesWithTimer() { ...@@ -244,6 +245,7 @@ void ExecutorThreadWorker::TrainFilesWithTimer() {
platform::SetNumThreads(1); platform::SetNumThreads(1);
SetDevice(); SetDevice();
thread_reader_->Start(); thread_reader_->Start();
std::vector<double> op_total_time; std::vector<double> op_total_time;
std::vector<std::string> op_name; std::vector<std::string> op_name;
for (auto& op : ops_) { for (auto& op : ops_) {
...@@ -273,7 +275,7 @@ void ExecutorThreadWorker::TrainFilesWithTimer() { ...@@ -273,7 +275,7 @@ void ExecutorThreadWorker::TrainFilesWithTimer() {
++batch_cnt; ++batch_cnt;
thread_scope_->DropKids(); thread_scope_->DropKids();
if (thread_id_ == 0) { if (thread_id_ == 0) {
if (batch_cnt > 0 && batch_cnt % 1000 == 0) { if (batch_cnt > 0 && batch_cnt % 100 == 0) {
for (size_t i = 0; i < ops_.size(); ++i) { for (size_t i = 0; i < ops_.size(); ++i) {
fprintf(stderr, "op_name:[%zu][%s], op_mean_time:[%fs]\n", i, fprintf(stderr, "op_name:[%zu][%s], op_mean_time:[%fs]\n", i,
op_name[i].c_str(), op_total_time[i] / batch_cnt); op_name[i].c_str(), op_total_time[i] / batch_cnt);
...@@ -283,6 +285,7 @@ void ExecutorThreadWorker::TrainFilesWithTimer() { ...@@ -283,6 +285,7 @@ void ExecutorThreadWorker::TrainFilesWithTimer() {
for (int i = 0; i < fetch_var_num; ++i) { for (int i = 0; i < fetch_var_num; ++i) {
print_fetch_var(thread_scope_, fetch_var_names_[i]); print_fetch_var(thread_scope_, fetch_var_names_[i]);
} }
fprintf(stderr, "IO percent: %f\n", read_time / total_time);
} }
} }
timeline.Start(); timeline.Start();
...@@ -293,7 +296,7 @@ void ExecutorThreadWorker::TrainFiles() { ...@@ -293,7 +296,7 @@ void ExecutorThreadWorker::TrainFiles() {
platform::SetNumThreads(1); platform::SetNumThreads(1);
// todo: configurable // todo: configurable
SetDevice(); // SetDevice();
int fetch_var_num = fetch_var_names_.size(); int fetch_var_num = fetch_var_names_.size();
fetch_values_.clear(); fetch_values_.clear();
...@@ -513,7 +516,6 @@ void AsyncExecutorThreadWorker::PullSparse(int table_id) { ...@@ -513,7 +516,6 @@ void AsyncExecutorThreadWorker::PullSparse(int table_id) {
auto& push_g = _feature_push_value[table_id]; auto& push_g = _feature_push_value[table_id];
check_pull_push_memory(features, &push_g, fea_dim); check_pull_push_memory(features, &push_g, fea_dim);
collect_feasign_info(table_id); collect_feasign_info(table_id);
} }
......
if(WITH_PSLIB)
cc_library(fleet_wrapper SRCS fleet_wrapper.cc DEPS framework_proto variable_helper scope pslib_brpc pslib)
else()
cc_library(fleet_wrapper SRCS fleet_wrapper.cc DEPS framework_proto variable_helper scope)
endif(WITH_PSLIB)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include <utility>
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/scope.h"
namespace paddle {
namespace framework {
const uint32_t MAX_FEASIGN_NUM = 1024 * 100 * 100;
std::shared_ptr<FleetWrapper> FleetWrapper::s_instance_ = NULL;
bool FleetWrapper::is_initialized_ = false;
#ifdef PADDLE_WITH_PSLIB
template <class AR>
paddle::ps::Archive<AR>& operator<<(paddle::ps::Archive<AR>& ar,
const MultiSlotType& ins) {
ar << ins.GetType();
ar << ins.GetOffset();
ar << ins.GetFloatData();
ar << ins.GetUint64Data();
return ar;
}
template <class AR>
paddle::ps::Archive<AR>& operator>>(paddle::ps::Archive<AR>& ar,
MultiSlotType& ins) {
ar >> ins.MutableType();
ar >> ins.MutableOffset();
ar >> ins.MutableFloatData();
ar >> ins.MutableUint64Data();
return ar;
}
#endif
#ifdef PADDLE_WITH_PSLIB
std::shared_ptr<paddle::distributed::PSlib> FleetWrapper::pslib_ptr_ = NULL;
#endif
void FleetWrapper::InitServer(const std::string& dist_desc, int index) {
#ifdef PADDLE_WITH_PSLIB
if (!is_initialized_) {
VLOG(3) << "Going to init server";
pslib_ptr_ = std::shared_ptr<paddle::distributed::PSlib>(
new paddle::distributed::PSlib());
pslib_ptr_->init_server(dist_desc, index);
is_initialized_ = true;
} else {
VLOG(3) << "Server can be initialized only once";
}
#endif
}
void FleetWrapper::InitWorker(const std::string& dist_desc,
const std::vector<uint64_t>& host_sign_list,
int node_num, int index) {
#ifdef PADDLE_WITH_PSLIB
if (!is_initialized_) {
VLOG(3) << "Going to init worker";
pslib_ptr_ = std::shared_ptr<paddle::distributed::PSlib>(
new paddle::distributed::PSlib());
pslib_ptr_->init_worker(dist_desc,
const_cast<uint64_t*>(host_sign_list.data()),
node_num, index);
is_initialized_ = true;
} else {
VLOG(3) << "Worker can be initialized only once";
}
#endif
}
void FleetWrapper::StopServer() {
#ifdef PADDLE_WITH_PSLIB
VLOG(3) << "Going to stop server";
pslib_ptr_->stop_server();
#endif
}
uint64_t FleetWrapper::RunServer() {
#ifdef PADDLE_WITH_PSLIB
VLOG(3) << "Going to run server";
return pslib_ptr_->run_server();
#else
return 0;
#endif
}
void FleetWrapper::GatherServers(const std::vector<uint64_t>& host_sign_list,
int node_num) {
#ifdef PADDLE_WITH_PSLIB
VLOG(3) << "Going to gather server ips";
pslib_ptr_->gather_servers(const_cast<uint64_t*>(host_sign_list.data()),
node_num);
#endif
}
void FleetWrapper::GatherClients(const std::vector<uint64_t>& host_sign_list) {
#ifdef PADDLE_WITH_PSLIB
VLOG(3) << "Going to gather client ips";
size_t len = host_sign_list.size();
pslib_ptr_->gather_clients(const_cast<uint64_t*>(host_sign_list.data()), len);
#endif
}
std::vector<uint64_t> FleetWrapper::GetClientsInfo() {
#ifdef PADDLE_WITH_PSLIB
VLOG(3) << "Going to get client info";
return pslib_ptr_->get_client_info();
#endif
return std::vector<uint64_t>();
}
void FleetWrapper::CreateClient2ClientConnection() {
#ifdef PADDLE_WITH_PSLIB
VLOG(3) << "Going to create client2client connection";
pslib_ptr_->create_client2client_connection();
#endif
}
void FleetWrapper::PullSparseVarsSync(
const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names, std::vector<uint64_t>* fea_keys,
std::vector<std::vector<float>>* fea_values, int fea_value_dim) {
#ifdef PADDLE_WITH_PSLIB
std::vector<::std::future<int32_t>> pull_sparse_status;
pull_sparse_status.resize(0);
fea_keys->clear();
fea_keys->resize(0);
fea_keys->reserve(MAX_FEASIGN_NUM);
for (auto name : var_names) {
Variable* var = scope.FindVar(name);
LoDTensor* tensor = var->GetMutable<LoDTensor>();
int64_t* ids = tensor->data<int64_t>();
int len = tensor->numel();
for (auto i = 0u; i < len; ++i) {
if (ids[i] == 0u) {
continue;
}
fea_keys->push_back(static_cast<uint64_t>(ids[i]));
}
}
fea_values->resize(fea_keys->size() + 1);
for (auto& t : *fea_values) {
t.resize(fea_value_dim);
}
std::vector<float*> pull_result_ptr;
for (auto& t : *fea_values) {
pull_result_ptr.push_back(t.data());
}
auto status = pslib_ptr_->_worker_ptr->pull_sparse(
pull_result_ptr.data(), table_id, fea_keys->data(), fea_keys->size());
pull_sparse_status.push_back(std::move(status));
for (auto& t : pull_sparse_status) {
t.wait();
auto status = t.get();
if (status != 0) {
LOG(ERROR) << "fleet pull sparse failed, status[" << status << "]";
exit(-1);
}
}
#endif
}
void FleetWrapper::PullDenseVarsAsync(
const Scope& scope, const uint64_t tid,
const std::vector<std::string>& var_names,
std::vector<::std::future<int32_t>>* pull_dense_status) {
#ifdef PADDLE_WITH_PSLIB
auto& regions = _regions[tid];
regions.clear();
regions.resize(var_names.size());
for (auto i = 0u; i < var_names.size(); ++i) {
Variable* var = scope.FindVar(var_names[i]);
LoDTensor* tensor = var->GetMutable<LoDTensor>();
float* w = tensor->data<float>();
paddle::ps::Region reg(w, tensor->numel());
regions[i] = std::move(reg);
}
auto status =
pslib_ptr_->_worker_ptr->pull_dense(regions.data(), regions.size(), tid);
pull_dense_status->push_back(std::move(status));
#endif
}
void FleetWrapper::PullDenseVarsSync(
const Scope& scope, const uint64_t tid,
const std::vector<std::string>& var_names) {
#ifdef PADDLE_WITH_PSLIB
auto& regions = _regions[tid];
regions.clear();
regions.reserve(var_names.size());
for (auto& t : var_names) {
Variable* var = scope.FindVar(t);
LoDTensor* tensor = var->GetMutable<LoDTensor>();
float* w = tensor->data<float>();
paddle::ps::Region reg(w, tensor->numel());
regions.emplace_back(std::move(reg));
}
auto status =
pslib_ptr_->_worker_ptr->pull_dense(regions.data(), regions.size(), tid);
status.wait();
#endif
}
void FleetWrapper::PushDenseParamSync(
const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names) {
#ifdef PADDLE_WITH_PSLIB
auto place = platform::CPUPlace();
std::vector<paddle::ps::Region> regions;
for (auto& t : var_names) {
Variable* var = scope.FindVar(t);
LoDTensor* tensor = var->GetMutable<LoDTensor>();
float* g = tensor->mutable_data<float>(place);
paddle::ps::Region reg(g, tensor->numel());
regions.emplace_back(std::move(reg));
}
auto push_status = pslib_ptr_->_worker_ptr->push_dense_param(
regions.data(), regions.size(), table_id);
push_status.wait();
auto status = push_status.get();
CHECK(status == 0) << "push dense param failed, status[" << status << "]";
#endif
}
void FleetWrapper::PushDenseVarsSync(
Scope* scope, const uint64_t table_id,
const std::vector<std::string>& var_names) {}
void FleetWrapper::PushDenseVarsAsync(
const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names,
std::vector<::std::future<int32_t>>* push_sparse_status) {
#ifdef PADDLE_WITH_PSLIB
std::vector<paddle::ps::Region> regions;
for (auto& t : var_names) {
Variable* var = scope.FindVar(t);
LoDTensor* tensor = var->GetMutable<LoDTensor>();
int count = tensor->numel();
float* g = tensor->data<float>();
paddle::ps::Region reg(g, count);
regions.emplace_back(std::move(reg));
}
auto status = pslib_ptr_->_worker_ptr->push_dense(regions.data(),
regions.size(), table_id);
push_sparse_status->push_back(std::move(status));
#endif
}
void FleetWrapper::PushSparseVarsWithLabelAsync(
const Scope& scope, const uint64_t table_id,
const std::vector<uint64_t>& fea_keys, const std::vector<float>& fea_labels,
const std::vector<std::string>& sparse_key_names,
const std::vector<std::string>& sparse_grad_names, const int emb_dim,
std::vector<std::vector<float>>* push_values,
std::vector<::std::future<int32_t>>* push_sparse_status) {
#ifdef PADDLE_WITH_PSLIB
int offset = 2;
uint64_t fea_idx = 0u;
for (size_t i = 0; i < sparse_key_names.size(); ++i) {
Variable* g_var = scope.FindVar(sparse_grad_names[i]);
CHECK(g_var != nullptr) << "var[" << sparse_grad_names[i] << "] not found";
LoDTensor* g_tensor = g_var->GetMutable<LoDTensor>();
if (g_tensor == NULL) {
LOG(ERROR) << "var[" << sparse_key_names[i] << "] not found";
exit(-1);
}
float* g = g_tensor->data<float>();
Variable* var = scope.FindVar(sparse_key_names[i]);
CHECK(var != nullptr) << "var[" << sparse_key_names[i] << "] not found";
LoDTensor* tensor = var->GetMutable<LoDTensor>();
if (tensor == NULL) {
LOG(ERROR) << "var[" << sparse_key_names[i] << "] not found";
exit(-1);
}
int len = tensor->numel();
int64_t* ids = tensor->data<int64_t>();
push_values->resize(fea_keys.size() + 1);
for (auto& t : *push_values) {
t.resize(emb_dim + offset);
}
for (auto id_idx = 0u; id_idx < len; ++id_idx) {
if (ids[id_idx] == 0) {
g += emb_dim;
continue;
}
CHECK(fea_idx < (*push_values).size());
CHECK(fea_idx < fea_labels.size());
memcpy((*push_values)[fea_idx].data() + offset, g,
sizeof(float) * emb_dim);
(*push_values)[fea_idx][0] = 1.0f;
(*push_values)[fea_idx][1] = static_cast<float>(fea_labels[fea_idx]);
g += emb_dim;
fea_idx++;
}
}
CHECK(fea_idx == fea_keys.size()) << "fea_idx: " << fea_idx
<< "features size: " << fea_keys.size();
std::vector<float*> push_g_vec;
for (auto i = 0u; i < fea_keys.size(); ++i) {
push_g_vec.push_back((*push_values)[i].data());
}
auto status = pslib_ptr_->_worker_ptr->push_sparse(
table_id, fea_keys.data(), (const float**)push_g_vec.data(),
fea_keys.size());
push_sparse_status->push_back(std::move(status));
#endif
}
int FleetWrapper::RegisterClientToClientMsgHandler(int msg_type,
MsgHandlerFunc handler) {
#ifdef PADDLE_WITH_PSLIB
VLOG(3) << "calling FleetWrapper::RegisterClientToClientMsgHandler";
VLOG(3) << "pslib_ptr_=" << pslib_ptr_;
VLOG(3) << "_worker_ptr=" << pslib_ptr_->_worker_ptr;
return pslib_ptr_->_worker_ptr->registe_client2client_msg_handler(msg_type,
handler);
#else
VLOG(0) << "FleetWrapper::RegisterClientToClientMsgHandler"
<< " does nothing when no pslib";
#endif
return 0;
}
std::future<int32_t> FleetWrapper::SendClientToClientMsg(
int msg_type, int to_client_id, const std::string& msg) {
#ifdef PADDLE_WITH_PSLIB
return pslib_ptr_->_worker_ptr->send_client2client_msg(msg_type, to_client_id,
msg);
#else
VLOG(0) << "FleetWrapper::SendClientToClientMsg"
<< " does nothing when no pslib";
#endif
return std::future<int32_t>();
}
template <typename T>
void FleetWrapper::Serialize(const std::vector<T*>& t, std::string* str) {
#ifdef PADDLE_WITH_PSLIB
paddle::ps::BinaryArchive ar;
for (size_t i = 0; i < t.size(); ++i) {
ar << *(t[i]);
}
*str = std::string(ar.buffer(), ar.length());
#else
VLOG(0) << "FleetWrapper::Serialize does nothing when no pslib";
#endif
}
template <typename T>
void FleetWrapper::Deserialize(std::vector<T>* t, const std::string& str) {
#ifdef PADDLE_WITH_PSLIB
if (str.length() == 0) {
return;
}
paddle::ps::BinaryArchive ar;
ar.set_read_buffer(const_cast<char*>(str.c_str()), str.length(), nullptr);
if (ar.cursor() == ar.finish()) {
return;
}
while (ar.cursor() < ar.finish()) {
t->push_back(ar.get<T>());
}
CHECK(ar.cursor() == ar.finish());
VLOG(3) << "Deserialize size " << t->size();
#else
VLOG(0) << "FleetWrapper::Deserialize does nothing when no pslib";
#endif
}
template void FleetWrapper::Serialize<std::vector<MultiSlotType>>(
const std::vector<std::vector<MultiSlotType>*>&, std::string*);
template void FleetWrapper::Deserialize<std::vector<MultiSlotType>>(
std::vector<std::vector<MultiSlotType>>*, const std::string&);
} // end namespace framework
} // end namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#ifdef PADDLE_WITH_PSLIB
#include <archive.h>
#include <pslib.h>
#endif
#include <atomic>
#include <ctime>
#include <map>
#include <random>
#include <string>
#include <vector>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
namespace paddle {
namespace framework {
// A wrapper class for pslib.h, this class follows Singleton pattern
// i.e. only initialized once in the current process
// Example:
// std::shared_ptr<FleetWrapper> fleet_ptr =
// FleetWrapper::GetInstance();
// string dist_desc;
// fleet_ptr->InitServer(dist_desc, 0);
// interface design principles:
// Pull
// Sync: PullSparseVarsSync
// Async: PullSparseVarsAsync(not implemented currently)
// Push
// Sync: PushSparseVarsSync
// Async: PushSparseVarsAsync(not implemented currently)
// Async: PushSparseVarsWithLabelAsync(with special usage)
// Push dense variables to server in Async mode
// Param<in>: scope, table_id, var_names
// Param<out>: push_sparse_status
class FleetWrapper {
public:
virtual ~FleetWrapper() {}
FleetWrapper() {}
// Pull sparse variables from server in Sync mode
// Param<in>: scope, table_id, var_names, fea_keys
// Param<out>: fea_values
void PullSparseVarsSync(const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names,
std::vector<uint64_t>* fea_keys,
std::vector<std::vector<float>>* fea_values,
int fea_dim);
void PullDenseVarsSync(const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names);
void PullDenseVarsAsync(
const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names,
std::vector<::std::future<int32_t>>* pull_dense_status);
void PushDenseParamSync(const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names);
// Push dense variables to server in async mode
// Param<in>: scope, table_id, var_names,
// Param<out>: push_sparse_status
void PushDenseVarsAsync(
const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names,
std::vector<::std::future<int32_t>>* push_sparse_status);
void PushDenseVarsSync(Scope* scope, const uint64_t table_id,
const std::vector<std::string>& var_names);
// Push sparse variables with labels to server in Async mode
// This is specially designed for click/show stats in server
// Param<in>: scope, table_id, var_grad_names,
// fea_keys, fea_labels, sparse_grad_names
// Param<out>: push_values, push_sparse_status
void PushSparseVarsWithLabelAsync(
const Scope& scope, const uint64_t table_id,
const std::vector<uint64_t>& fea_keys,
const std::vector<float>& fea_labels,
const std::vector<std::string>& sparse_key_names,
const std::vector<std::string>& sparse_grad_names, const int emb_dim,
std::vector<std::vector<float>>* push_values,
std::vector<::std::future<int32_t>>* push_sparse_status);
// Push sparse variables to server in Async mode
// Param<In>: scope, table_id, fea_keys, sparse_grad_names
// Param<Out>: push_values, push_sparse_status
/*
void PushSparseVarsAsync(
const Scope& scope,
const uint64_t table_id,
const std::vector<uint64_t>& fea_keys,
const std::vector<std::string>& sparse_grad_names,
std::vector<std::vector<float>>* push_values,
std::vector<::std::future<int32_t>>* push_sparse_status);
*/
void InitServer(const std::string& dist_desc, int index);
void InitWorker(const std::string& dist_desc,
const std::vector<uint64_t>& host_sign_list, int node_num,
int index);
void StopServer();
uint64_t RunServer();
void GatherServers(const std::vector<uint64_t>& host_sign_list, int node_num);
// gather client ip
void GatherClients(const std::vector<uint64_t>& host_sign_list);
// get client info
std::vector<uint64_t> GetClientsInfo();
// create client to client connection
void CreateClient2ClientConnection();
// register client to client communication
typedef std::function<int32_t(int, int, const std::string&)> MsgHandlerFunc;
int RegisterClientToClientMsgHandler(int msg_type, MsgHandlerFunc handler);
// send client to client message
std::future<int32_t> SendClientToClientMsg(int msg_type, int to_client_id,
const std::string& msg);
template <typename T>
void Serialize(const std::vector<T*>& t, std::string* str);
template <typename T>
void Deserialize(std::vector<T>* t, const std::string& str);
static std::shared_ptr<FleetWrapper> GetInstance() {
if (NULL == s_instance_) {
s_instance_.reset(new paddle::framework::FleetWrapper());
}
return s_instance_;
}
#ifdef PADDLE_WITH_PSLIB
static std::shared_ptr<paddle::distributed::PSlib> pslib_ptr_;
#endif
private:
static std::shared_ptr<FleetWrapper> s_instance_;
#ifdef PADDLE_WITH_PSLIB
std::map<uint64_t, std::vector<paddle::ps::Region>> _regions;
#endif
protected:
static bool is_initialized_;
DISABLE_COPY_AND_ASSIGN(FleetWrapper);
};
} // end namespace framework
} // end namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/device_worker.h"
#include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/platform/lodtensor_printer.h"
namespace paddle {
namespace framework {
void HogwildWorker::Initialize(const TrainerDesc& desc) {
fetch_config_ = desc.fetch_config();
param_ = desc.hogwild_param();
skip_ops_.resize(param_.skip_ops_size());
for (size_t i = 0; i < param_.skip_ops_size(); ++i) {
skip_ops_[i] = param_.skip_ops(i);
}
}
void HogwildWorker::CreateThreadOperators(const ProgramDesc& program) {
auto& block = program.Block(0);
op_names_.clear();
for (auto& op_desc : block.AllOps()) {
std::unique_ptr<OperatorBase> local_op = OpRegistry::CreateOp(*op_desc);
op_names_.push_back(op_desc->Type());
OperatorBase* local_op_ptr = local_op.release();
ops_.push_back(local_op_ptr);
continue;
}
}
void HogwildWorker::CreateThreadScope(const ProgramDesc& program) {
auto& block = program.Block(0);
PADDLE_ENFORCE_NOT_NULL(
root_scope_, "root_scope should be set before creating thread scope");
thread_scope_ = &root_scope_->NewScope();
for (auto& var : block.AllVars()) {
if (var->Persistable()) {
auto* ptr = root_scope_->Var(var->Name());
InitializeVariable(ptr, var->GetType());
} else {
auto* ptr = thread_scope_->Var(var->Name());
InitializeVariable(ptr, var->GetType());
}
}
}
void HogwildWorker::BindingDataFeedMemory() {
const std::vector<std::string>& input_feed =
device_reader_->GetUseSlotAlias();
for (auto name : input_feed) {
device_reader_->AddFeedVar(thread_scope_->Var(name), name);
}
}
void HogwildWorker::CreateDeviceResource(const ProgramDesc& main_prog) {
CreateThreadScope(main_prog);
CreateThreadOperators(main_prog);
}
void HogwildWorker::TrainFilesWithProfiler() {
platform::SetNumThreads(1);
device_reader_->Start();
std::vector<double> op_total_time;
std::vector<std::string> op_name;
for (auto& op : ops_) {
op_name.push_back(op->Type());
}
op_total_time.resize(ops_.size());
for (size_t i = 0; i < op_total_time.size(); ++i) {
op_total_time[i] = 0.0;
}
platform::Timer timeline;
double total_time = 0.0;
double read_time = 0.0;
int cur_batch;
int batch_cnt = 0;
timeline.Start();
uint64_t total_inst = 0;
while ((cur_batch = device_reader_->Next()) > 0) {
VLOG(3) << "read a batch in thread " << thread_id_;
timeline.Pause();
read_time += timeline.ElapsedSec();
total_time += timeline.ElapsedSec();
for (size_t i = 0; i < ops_.size(); ++i) {
bool need_skip = false;
for (auto t = 0u; t < skip_ops_.size(); ++t) {
if (ops_[i]->Type().find(skip_ops_[t]) != std::string::npos) {
need_skip = true;
break;
}
}
timeline.Start();
VLOG(3) << "Going to run op " << op_name[i];
if (!need_skip) {
ops_[i]->Run(*thread_scope_, place_);
}
VLOG(3) << "Op " << op_name[i] << " Finished";
timeline.Pause();
op_total_time[i] += timeline.ElapsedSec();
total_time += timeline.ElapsedSec();
}
total_inst += cur_batch;
++batch_cnt;
PrintFetchVars();
if (thread_id_ == 0) {
if (batch_cnt > 0 && batch_cnt % 100 == 0) {
for (size_t i = 0; i < ops_.size(); ++i) {
fprintf(stderr, "op_name:[%zu][%s], op_mean_time:[%fs]\n", i,
op_name[i].c_str(), op_total_time[i] / batch_cnt);
}
fprintf(stderr, "mean read time: %fs\n", read_time / batch_cnt);
fprintf(stderr, "IO percent: %f\n", read_time / total_time * 100);
fprintf(stderr, "%6.2f instances/s\n", total_inst / total_time);
}
}
thread_scope_->DropKids();
timeline.Start();
}
}
void HogwildWorker::TrainFiles() {
platform::SetNumThreads(1);
// how to accumulate fetched values here
device_reader_->Start();
int cur_batch;
while ((cur_batch = device_reader_->Next()) > 0) {
for (auto& op : ops_) {
bool need_skip = false;
for (auto t = 0u; t < skip_ops_.size(); ++t) {
if (op->Type().find(skip_ops_[t]) != std::string::npos) {
need_skip = true;
break;
}
}
if (!need_skip) {
op->Run(*thread_scope_, place_);
}
}
PrintFetchVars();
thread_scope_->DropKids();
}
}
void HogwildWorker::PrintFetchVars() {
// call count
batch_num_++;
int batch_per_print = fetch_config_.print_period();
if (thread_id_ == 0) {
if (batch_num_ % batch_per_print == 0) {
int fetch_var_num = fetch_config_.fetch_var_names_size();
for (int i = 0; i < fetch_var_num; ++i) {
platform::PrintVar(thread_scope_, fetch_config_.fetch_var_names(i),
fetch_config_.fetch_var_str_format(i));
}
}
}
}
} // end namespace framework
} // end namespace paddle
cc_library(fs SRCS fs.cc DEPS string_helper glog boost)
cc_library(shell SRCS shell.cc DEPS string_helper glog)
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/io/fs.h"
#include <memory>
namespace paddle {
namespace framework {
static void fs_add_read_converter_internal(std::string& path, // NOLINT
bool& is_pipe, // NOLINT
const std::string& converter) {
if (converter == "") {
return;
}
if (!is_pipe) {
path = string::format_string("( %s ) < \"%s\"", converter.c_str(),
path.c_str());
is_pipe = true;
} else {
path = string::format_string("%s | %s", path.c_str(), converter.c_str());
}
}
static void fs_add_write_converter_internal(std::string& path, // NOLINT
bool& is_pipe, // NOLINT
const std::string& converter) {
if (converter == "") {
return;
}
if (!is_pipe) {
path = string::format_string("( %s ) > \"%s\"", converter.c_str(),
path.c_str());
is_pipe = true;
} else {
path = string::format_string("%s | %s", converter.c_str(), path.c_str());
}
}
static std::shared_ptr<FILE> fs_open_internal(const std::string& path,
bool is_pipe,
const std::string& mode,
size_t buffer_size,
int* err_no = 0) {
std::shared_ptr<FILE> fp = nullptr;
if (!is_pipe) {
fp = shell_fopen(path, mode);
} else {
fp = shell_popen(path, mode, err_no);
}
if (buffer_size > 0) {
char* buffer = new char[buffer_size];
CHECK_EQ(0, setvbuf(&*fp, buffer, _IOFBF, buffer_size));
fp = {&*fp, [fp, buffer](FILE*) mutable { // NOLINT
CHECK(fp.unique()); // NOLINT
fp = nullptr;
delete[] buffer;
}};
}
return fp;
}
static bool fs_begin_with_internal(const std::string& path,
const std::string& str) {
return strncmp(path.c_str(), str.c_str(), str.length()) == 0;
}
static bool fs_end_with_internal(const std::string& path,
const std::string& str) {
return path.length() >= str.length() &&
strncmp(&path[path.length() - str.length()], str.c_str(),
str.length()) == 0;
}
static size_t& localfs_buffer_size_internal() {
static size_t x = 0;
return x;
}
size_t localfs_buffer_size() { return localfs_buffer_size_internal(); }
void localfs_set_buffer_size(size_t x) { localfs_buffer_size_internal() = x; }
std::shared_ptr<FILE> localfs_open_read(std::string path,
const std::string& converter) {
bool is_pipe = false;
if (fs_end_with_internal(path, ".gz")) {
fs_add_read_converter_internal(path, is_pipe, "zcat");
}
fs_add_read_converter_internal(path, is_pipe, converter);
return fs_open_internal(path, is_pipe, "r", localfs_buffer_size());
}
std::shared_ptr<FILE> localfs_open_write(std::string path,
const std::string& converter) {
shell_execute(
string::format_string("mkdir -p $(dirname \"%s\")", path.c_str()));
bool is_pipe = false;
if (fs_end_with_internal(path, ".gz")) {
fs_add_write_converter_internal(path, is_pipe, "gzip");
}
fs_add_write_converter_internal(path, is_pipe, converter);
return fs_open_internal(path, is_pipe, "w", localfs_buffer_size());
}
int64_t localfs_file_size(const std::string& path) {
struct stat buf;
if (0 != stat(path.c_str(), &buf)) {
LOG(FATAL) << "file stat not zero";
return -1;
}
return (int64_t)buf.st_size;
}
void localfs_remove(const std::string& path) {
if (path == "") {
return;
}
shell_execute(string::format_string("rm -rf %s", path.c_str()));
}
std::vector<std::string> localfs_list(const std::string& path) {
if (path == "") {
return {};
}
std::shared_ptr<FILE> pipe;
int err_no = 0;
pipe = shell_popen(
string::format_string("find %s -type f -maxdepth 1", path.c_str()), "r",
&err_no);
string::LineFileReader reader;
std::vector<std::string> list;
while (reader.getline(&*pipe)) {
list.push_back(reader.get());
}
return list;
}
std::string localfs_tail(const std::string& path) {
if (path == "") {
return "";
}
return shell_get_command_output(
string::format_string("tail -1 %s ", path.c_str()));
}
bool localfs_exists(const std::string& path) {
std::string test_f = shell_get_command_output(
string::format_string("[ -f %s ] ; echo $?", path.c_str()));
if (string::trim_spaces(test_f) == "0") {
return true;
}
std::string test_d = shell_get_command_output(
string::format_string("[ -d %s ] ; echo $?", path.c_str()));
if (string::trim_spaces(test_d) == "0") {
return true;
}
return false;
}
void localfs_mkdir(const std::string& path) {
if (path == "") {
return;
}
shell_execute(string::format_string("mkdir -p %s", path.c_str()));
}
static size_t& hdfs_buffer_size_internal() {
static size_t x = 0;
return x;
}
size_t hdfs_buffer_size() { return hdfs_buffer_size_internal(); }
void hdfs_set_buffer_size(size_t x) { hdfs_buffer_size_internal() = x; }
static std::string& hdfs_command_internal() {
static std::string x = "hadoop fs";
return x;
}
const std::string& hdfs_command() { return hdfs_command_internal(); }
void hdfs_set_command(const std::string& x) { hdfs_command_internal() = x; }
std::shared_ptr<FILE> hdfs_open_read(std::string path, int* err_no,
const std::string& converter) {
if (fs_end_with_internal(path, ".gz")) {
path = string::format_string("%s -text \"%s\"", hdfs_command().c_str(),
path.c_str());
} else {
path = string::format_string("%s -cat \"%s\"", hdfs_command().c_str(),
path.c_str());
}
bool is_pipe = true;
fs_add_read_converter_internal(path, is_pipe, converter);
return fs_open_internal(path, is_pipe, "r", hdfs_buffer_size(), err_no);
}
std::shared_ptr<FILE> hdfs_open_write(std::string path, int* err_no,
const std::string& converter) {
path = string::format_string("%s -put - \"%s\"", hdfs_command().c_str(),
path.c_str());
bool is_pipe = true;
if (fs_end_with_internal(path, ".gz\"")) {
fs_add_write_converter_internal(path, is_pipe, "gzip");
}
fs_add_write_converter_internal(path, is_pipe, converter);
return fs_open_internal(path, is_pipe, "w", hdfs_buffer_size(), err_no);
}
void hdfs_remove(const std::string& path) {
if (path == "") {
return;
}
shell_execute(string::format_string("%s -rmr %s &>/dev/null; true",
hdfs_command().c_str(), path.c_str()));
}
std::vector<std::string> hdfs_list(const std::string& path) {
if (path == "") {
return {};
}
std::string prefix = "hdfs:";
if (fs_begin_with_internal(path, "afs:")) {
prefix = "afs:";
}
int err_no = 0;
std::vector<std::string> list;
do {
err_no = 0;
std::shared_ptr<FILE> pipe;
pipe = shell_popen(
string::format_string("%s -ls %s | ( grep ^- ; [ $? != 2 ] )",
hdfs_command().c_str(), path.c_str()),
"r", &err_no);
string::LineFileReader reader;
list.clear();
while (reader.getline(&*pipe)) {
std::vector<std::string> line = string::split_string(reader.get());
if (line.size() != 8) {
continue;
}
list.push_back(prefix + line[7]);
}
} while (err_no == -1);
return list;
}
std::string hdfs_tail(const std::string& path) {
if (path == "") {
return "";
}
return shell_get_command_output(string::format_string(
"%s -text %s | tail -1 ", hdfs_command().c_str(), path.c_str()));
}
bool hdfs_exists(const std::string& path) {
std::string test = shell_get_command_output(string::format_string(
"%s -test -e %s ; echo $?", hdfs_command().c_str(), path.c_str()));
if (string::trim_spaces(test) == "0") {
return true;
}
return false;
}
void hdfs_mkdir(const std::string& path) {
if (path == "") {
return;
}
shell_execute(string::format_string("%s -mkdir %s; true",
hdfs_command().c_str(), path.c_str()));
}
int fs_select_internal(const std::string& path) {
if (fs_begin_with_internal(path, "hdfs:")) {
return 1;
} else if (fs_begin_with_internal(path, "afs:")) {
return 1;
}
return 0;
}
std::shared_ptr<FILE> fs_open_read(const std::string& path, int* err_no,
const std::string& converter) {
switch (fs_select_internal(path)) {
case 0:
return localfs_open_read(path, converter);
case 1:
return hdfs_open_read(path, err_no, converter);
default:
LOG(FATAL) << "Not supported";
}
return {};
}
std::shared_ptr<FILE> fs_open_write(const std::string& path, int* err_no,
const std::string& converter) {
switch (fs_select_internal(path)) {
case 0:
return localfs_open_write(path, converter);
case 1:
return hdfs_open_write(path, err_no, converter);
default:
LOG(FATAL) << "Not supported";
}
return {};
}
std::shared_ptr<FILE> fs_open(const std::string& path, const std::string& mode,
int* err_no, const std::string& converter) {
if (mode == "r" || mode == "rb") {
return fs_open_read(path, err_no, converter);
}
if (mode == "w" || mode == "wb") {
return fs_open_write(path, err_no, converter);
}
LOG(FATAL) << "Unknown mode: " << mode;
return {};
}
int64_t fs_file_size(const std::string& path) {
switch (fs_select_internal(path)) {
case 0:
return localfs_file_size(path);
default:
LOG(FATAL) << "Not supported";
}
return 0;
}
void fs_remove(const std::string& path) {
switch (fs_select_internal(path)) {
case 0:
return localfs_remove(path);
case 1:
return hdfs_remove(path);
default:
LOG(FATAL) << "Not supported";
}
}
std::vector<std::string> fs_list(const std::string& path) {
switch (fs_select_internal(path)) {
case 0:
return localfs_list(path);
case 1:
return hdfs_list(path);
default:
LOG(FATAL) << "Not supported";
}
return {};
}
std::string fs_tail(const std::string& path) {
switch (fs_select_internal(path)) {
case 0:
return localfs_tail(path);
case 1:
return hdfs_tail(path);
default:
LOG(FATAL) << "Not supported";
}
return "";
}
bool fs_exists(const std::string& path) {
switch (fs_select_internal(path)) {
case 0:
return localfs_exists(path);
case 1:
return hdfs_exists(path);
default:
LOG(FATAL) << "Not supported";
}
return false;
}
void fs_mkdir(const std::string& path) {
switch (fs_select_internal(path)) {
case 0:
return localfs_mkdir(path);
case 1:
return hdfs_mkdir(path);
default:
LOG(FATAL) << "Not supported";
}
}
} // end namespace framework
} // end namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <stdio.h>
#include <memory>
#include <string>
#include <vector>
#include "glog/logging.h"
#include "paddle/fluid/framework/io/shell.h"
#include "paddle/fluid/string/string_helper.h"
namespace paddle {
namespace framework {
int fs_select_internal(const std::string& path);
// localfs
extern size_t localfs_buffer_size();
extern void localfs_set_buffer_size(size_t x);
extern std::shared_ptr<FILE> localfs_open_read(std::string path,
const std::string& converter);
extern std::shared_ptr<FILE> localfs_open_write(std::string path,
const std::string& converter);
extern int64_t localfs_file_size(const std::string& path);
extern void localfs_remove(const std::string& path);
extern std::vector<std::string> localfs_list(const std::string& path);
extern std::string localfs_tail(const std::string& path);
extern bool localfs_exists(const std::string& path);
extern void localfs_mkdir(const std::string& path);
// hdfs
extern size_t hdfs_buffer_size();
extern void hdfs_set_buffer_size(size_t x);
extern const std::string& hdfs_command();
extern void hdfs_set_command(const std::string& x);
extern std::shared_ptr<FILE> hdfs_open_read(std::string path, int* err_no,
const std::string& converter);
extern std::shared_ptr<FILE> hdfs_open_write(std::string path, int* err_no,
const std::string& converter);
extern void hdfs_remove(const std::string& path);
extern std::vector<std::string> hdfs_list(const std::string& path);
extern std::string hdfs_tail(const std::string& path);
extern bool hdfs_exists(const std::string& path);
extern void hdfs_mkdir(const std::string& path);
// aut-detect fs
extern std::shared_ptr<FILE> fs_open_read(const std::string& path, int* err_no,
const std::string& converter);
extern std::shared_ptr<FILE> fs_open_write(const std::string& path, int* err_no,
const std::string& converter);
extern std::shared_ptr<FILE> fs_open(const std::string& path,
const std::string& mode, int* err_no,
const std::string& converter = "");
extern int64_t fs_file_size(const std::string& path);
extern void fs_remove(const std::string& path);
extern std::vector<std::string> fs_list(const std::string& path);
extern std::string fs_tail(const std::string& path);
extern bool fs_exists(const std::string& path);
extern void fs_mkdir(const std::string& path);
} // namespace framework
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/io/shell.h"
namespace paddle {
namespace framework {
std::shared_ptr<FILE> shell_fopen(const std::string& path,
const std::string& mode) {
#if defined _WIN32 || defined __APPLE__
return nullptr;
#else
if (shell_verbose()) {
LOG(INFO) << "Opening file[" << path << "] with mode[" << mode << "]";
}
FILE* fp;
if (!(fp = fopen(path.c_str(), mode.c_str()))) {
LOG(FATAL) << "fopen fail, path[" << path << "], mode[" << mode << "]";
}
return {fp, [path](FILE* fp) {
if (shell_verbose()) {
LOG(INFO) << "Closing file[" << path << "]";
}
if (0 != fclose(fp)) {
LOG(FATAL) << "fclose fail, path[" << path << "]";
}
}};
#endif
}
// Close all open file descriptors
// The implementation is async signal safe
// Mostly copy from CPython code
static int close_open_fds_internal() {
#if defined _WIN32 || defined __APPLE__
return 0;
#else
struct linux_dirent {
long d_ino = 0; // NOLINT
off_t d_off;
unsigned short d_reclen = 0; // NOLINT
char d_name[256];
};
int dir_fd = -1;
if ((dir_fd = open("/proc/self/fd", O_RDONLY)) < 0) {
LOG(FATAL) << "proc/self/fd open fail";
return -1;
}
char buffer[sizeof(linux_dirent)];
for (;;) {
int bytes = 0;
if ((bytes = syscall(SYS_getdents, dir_fd,
reinterpret_cast<linux_dirent*>(buffer),
sizeof(buffer))) < 0) {
LOG(FATAL) << "syscall fail";
return -1;
}
if (bytes == 0) {
break;
}
linux_dirent* entry = NULL;
for (int offset = 0; offset < bytes; offset += entry->d_reclen) {
entry = reinterpret_cast<linux_dirent*>(buffer + offset);
int fd = 0;
const char* s = entry->d_name;
while (*s >= '0' && *s <= '9') {
fd = fd * 10 + (*s - '0');
s++;
}
if (s != entry->d_name && fd != dir_fd && fd >= 3) {
close(fd);
}
}
}
close(dir_fd);
return 0;
#endif
}
static int shell_popen_fork_internal(const char* real_cmd, bool do_read,
int parent_end, int child_end) {
#if defined _WIN32 || defined __APPLE__
return 0;
#else
int child_pid = -1;
// Too frequent calls to fork() makes openmpi very slow. Use vfork() instead.
// But vfork() is very dangerous. Be careful.
if ((child_pid = vfork()) < 0) {
return -1;
}
// The following code is async signal safe (No memory allocation, no access to
// global data, etc.)
if (child_pid != 0) {
return child_pid;
}
int child_std_end = do_read ? 1 : 0;
close(parent_end);
if (child_end != child_std_end) {
if (dup2(child_end, child_std_end) != child_std_end) {
return -1;
}
close(child_end);
}
close_open_fds_internal();
if (execl("/bin/sh", "sh", "-c", real_cmd, NULL) < 0) {
return -1;
}
exit(127);
#endif
}
std::shared_ptr<FILE> shell_popen(const std::string& cmd,
const std::string& mode, int* err_no) {
#if defined _WIN32 || defined __APPLE__
return nullptr;
#else
bool do_read = mode == "r";
bool do_write = mode == "w";
if (!(do_read || do_write)) {
*err_no = -1;
return NULL;
}
if (shell_verbose()) {
LOG(INFO) << "Opening pipe[" << cmd << "] with mode[" << mode << "]";
}
std::string real_cmd = "set -o pipefail; " + cmd;
int pipe_fds[2];
if (pipe(pipe_fds) != 0) {
*err_no = -1;
return NULL;
}
int parent_end = 0;
int child_end = 0;
if (do_read) {
parent_end = pipe_fds[0];
child_end = pipe_fds[1];
} else if (do_write) {
parent_end = pipe_fds[1];
child_end = pipe_fds[0];
}
int child_pid = shell_popen_fork_internal(real_cmd.c_str(), do_read,
parent_end, child_end);
close(child_end);
fcntl(parent_end, F_SETFD, FD_CLOEXEC);
FILE* fp;
if ((fp = fdopen(parent_end, mode.c_str())) == NULL) {
*err_no = -1;
return NULL;
}
return {fp, [child_pid, cmd, err_no](FILE* fp) {
if (shell_verbose()) {
LOG(INFO) << "Closing pipe[" << cmd << "]";
}
if (fclose(fp) != 0) {
*err_no = -1;
}
int wstatus = -1;
waitpid(child_pid, &wstatus, 0);
if (wstatus == 0 || wstatus == (128 + SIGPIPE) * 256 ||
(wstatus == -1 && errno == ECHILD)) {
} else {
*err_no = -1;
LOG(WARNING) << "status[" << wstatus << "], cmd[" << cmd << "]"
<< ", err_no[" << *err_no << "]";
}
if (wstatus == -1 && errno == ECHILD) {
LOG(WARNING) << "errno is ECHILD";
}
}};
#endif
}
static int shell_p2open_fork_internal(const char* real_cmd, int pipein_fds[2],
int pipeout_fds[2]) {
#if defined _WIN32 || defined __APPLE__
return 0;
#else
int child_pid = -1;
if ((child_pid = fork()) < 0) {
return -1;
}
if (child_pid != 0) {
return child_pid;
}
close(pipein_fds[0]);
close(pipeout_fds[1]);
if (pipein_fds[1] != 1) {
if (dup2(pipein_fds[1], 1) != 1) {
return -1;
}
close(pipein_fds[1]);
}
if (pipeout_fds[0] != 0) {
if (dup2(pipeout_fds[0], 0) != 0) {
return -1;
}
close(pipeout_fds[0]);
}
close_open_fds_internal();
if (execl("/bin/sh", "sh", "-c", real_cmd, NULL) < 0) {
return -1;
}
exit(127);
#endif
}
std::pair<std::shared_ptr<FILE>, std::shared_ptr<FILE>> shell_p2open(
const std::string& cmd) {
#if defined _WIN32 || defined __APPLE__
return {};
#else
if (shell_verbose()) {
LOG(INFO) << "Opening bidirectional pipe[" << cmd << "]";
}
std::string real_cmd = "set -o pipefail; " + cmd;
int pipein_fds[2];
int pipeout_fds[2];
if (pipe(pipein_fds) != 0) {
return {NULL, NULL};
}
if (pipe(pipeout_fds) != 0) {
return {NULL, NULL};
}
int child_pid =
shell_p2open_fork_internal(real_cmd.c_str(), pipein_fds, pipeout_fds);
close(pipein_fds[1]);
close(pipeout_fds[0]);
fcntl(pipein_fds[0], F_SETFD, FD_CLOEXEC);
fcntl(pipeout_fds[1], F_SETFD, FD_CLOEXEC);
std::shared_ptr<int> child_life = {
NULL, [child_pid, cmd](void*) {
if (shell_verbose()) {
LOG(INFO) << "Closing bidirectional pipe[" << cmd << "]";
}
int wstatus, ret;
do {
PCHECK((ret = waitpid(child_pid, &wstatus, 0)) >= 0 ||
(ret == -1 && errno == EINTR));
} while (ret == -1 && errno == EINTR);
PCHECK(wstatus == 0 || wstatus == (128 + SIGPIPE) * 256 ||
(wstatus == -1 && errno == ECHILD))
<< "status[" << wstatus << "], cmd[" << cmd << "]";
if (wstatus == -1 && errno == ECHILD) {
LOG(WARNING) << "errno is ECHILD";
}
}};
FILE* in_fp;
PCHECK((in_fp = fdopen(pipein_fds[0], "r")) != NULL);
FILE* out_fp;
PCHECK((out_fp = fdopen(pipeout_fds[1], "w")) != NULL);
return {{in_fp, [child_life](FILE* fp) { PCHECK(fclose(fp) == 0); }},
{out_fp, [child_life](FILE* fp) { PCHECK(fclose(fp) == 0); }}};
#endif
}
std::string shell_get_command_output(const std::string& cmd) {
#if defined _WIN32 || defined __APPLE__
return "";
#else
int err_no = 0;
do {
err_no = 0;
std::shared_ptr<FILE> pipe = shell_popen(cmd, "r", &err_no);
string::LineFileReader reader;
if (reader.getdelim(&*pipe, 0)) {
pipe = nullptr;
if (err_no == 0) {
return reader.get();
}
}
} while (err_no == -1);
return "";
#endif
}
} // end namespace framework
} // end namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <fcntl.h>
#include <sys/stat.h>
#ifdef _WIN32
#include <windows.h>
#else
#include <sys/syscall.h>
#endif
#include <sys/types.h>
#ifndef _WIN32
#include <sys/wait.h>
#endif
#include <memory>
#include <string>
#include <utility>
#include "paddle/fluid/platform/port.h"
#include "paddle/fluid/string/string_helper.h"
namespace paddle {
namespace framework {
inline bool& shell_verbose_internal() {
static bool x = false;
return x;
}
inline bool shell_verbose() { return shell_verbose_internal(); }
inline void shell_set_verbose(bool x) { shell_verbose_internal() = x; }
extern std::shared_ptr<FILE> shell_fopen(const std::string& path,
const std::string& mode);
extern std::shared_ptr<FILE> shell_popen(const std::string& cmd,
const std::string& mode, int* err_no);
extern std::pair<std::shared_ptr<FILE>, std::shared_ptr<FILE>> shell_p2open(
const std::string& cmd);
inline void shell_execute(const std::string& cmd) {
int err_no = 0;
do {
err_no = 0;
shell_popen(cmd, "w", &err_no);
} while (err_no == -1);
}
extern std::string shell_get_command_output(const std::string& cmd);
} // namespace framework
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include <vector>
#include "paddle/fluid/framework/data_feed_factory.h"
#include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/framework/trainer.h"
namespace paddle {
namespace framework {
void MultiTrainer::Initialize(const TrainerDesc& trainer_desc,
Dataset* dataset) {
thread_num_ = trainer_desc.thread_num();
SetDataset(dataset);
// get filelist from trainer_desc here
dataset->CreateReaders();
VLOG(3) << "readers created";
const std::vector<std::shared_ptr<paddle::framework::DataFeed>> readers =
dataset->GetReaders();
VLOG(3) << "readers num: " << readers.size();
// change thread num to readers num
thread_num_ = readers.size();
VLOG(3) << "worker thread num: " << thread_num_;
workers_.resize(thread_num_);
for (int i = 0; i < thread_num_; ++i) {
workers_[i] = DeviceWorkerFactory::CreateDeviceWorker(
trainer_desc.device_worker_name());
workers_[i]->Initialize(trainer_desc);
workers_[i]->SetDeviceIndex(i);
workers_[i]->SetDataFeed(readers[i]);
}
// set debug here
SetDebug(trainer_desc.debug());
}
// call only after all resources are set in current trainer
void MultiTrainer::InitTrainerEnv(const ProgramDesc& main_program,
const platform::Place& place) {
for (int i = 0; i < thread_num_; ++i) {
workers_[i]->SetPlace(place);
workers_[i]->SetRootScope(root_scope_);
workers_[i]->CreateDeviceResource(main_program); // Program
workers_[i]->BindingDataFeedMemory();
}
}
void MultiTrainer::Run() {
VLOG(3) << "Going to run";
for (int thidx = 0; thidx < thread_num_; ++thidx) {
if (!debug_) {
threads_.push_back(
std::thread(&DeviceWorker::TrainFiles, workers_[thidx].get()));
} else {
threads_.push_back(std::thread(&DeviceWorker::TrainFilesWithProfiler,
workers_[thidx].get()));
}
}
}
void MultiTrainer::Finalize() {
for (auto& th : threads_) {
th.join();
}
dataset_ptr_->DestroyReaders();
root_scope_->DropKids();
}
} // end namespace framework
} // end namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <time.h>
#include "paddle/fluid/framework/device_worker.h"
namespace paddle {
namespace framework {
std::shared_ptr<PullDenseWorker> PullDenseWorker::s_instance_ = NULL;
std::mutex PullDenseWorker::mutex_for_version_;
std::map<uint64_t, uint64_t> PullDenseWorker::last_versions_;
std::map<uint64_t, uint64_t> PullDenseWorker::current_version_;
std::map<uint64_t, std::vector<uint64_t>> PullDenseWorker::training_versions_;
std::map<uint64_t, std::vector<std::string>>
PullDenseWorker::dense_value_names_;
void PullDenseWorker::Initialize(const TrainerDesc& param) {
running_ = false;
param_ = param.pull_dense_param();
dwp_param_ = param.downpour_param();
threshold_ = param_.threshold();
thread_num_ = param_.device_num();
sleep_time_ms_ = param_.sleep_time_ms();
for (size_t i = 0;
i < dwp_param_.program_config(0).pull_dense_table_id_size(); ++i) {
uint64_t tid = static_cast<uint64_t>(
dwp_param_.program_config(0).pull_dense_table_id(i));
TableParameter table;
for (auto i : param_.dense_table()) {
if (i.table_id() == tid) {
table = i;
break;
}
}
// setup dense variables for each table
int var_num = table.dense_value_name_size();
dense_value_names_[tid].resize(var_num);
for (int j = 0; j < var_num; ++j) {
dense_value_names_[tid][j] = table.dense_value_name(j);
}
// setup training version for each table
training_versions_[tid].resize(thread_num_, 0);
last_versions_[tid] = 0;
current_version_[tid] = 0;
}
fleet_ptr_ = FleetWrapper::GetInstance();
}
void PullDenseWorker::Wait(std::vector<::std::future<int32_t>>* status_vec) {
for (auto& t : *status_vec) {
t.wait();
auto status = t.get();
if (status != 0) {
LOG(WARNING) << "Current Pull Dense Thread Failed Times"
<< ++pull_dense_fail_times_;
}
}
int MAX_FAIL_NUM = 20;
if (pull_dense_fail_times_ > MAX_FAIL_NUM) {
LOG(FATAL) << "Pull Dense Failed Times More Than " << MAX_FAIL_NUM
<< " Times";
exit(-1);
}
status_vec->resize(0);
}
void PullDenseWorker::Stop() {
if (running_) {
running_ = false;
t_.join();
}
}
int PullDenseWorker::Start() {
running_ = true;
t_ = std::thread(&PullDenseWorker::Run, this);
return 0;
}
void PullDenseWorker::Run() {
while (running_) {
pull_dense_status_.resize(0);
for (size_t i = 0;
i < dwp_param_.program_config(0).pull_dense_table_id_size(); ++i) {
uint64_t tid = static_cast<uint64_t>(
dwp_param_.program_config(0).pull_dense_table_id(i));
if (CheckUpdateParam(tid)) {
fleet_ptr_->PullDenseVarsAsync(
*root_scope_, tid, dense_value_names_[tid], &pull_dense_status_);
ResetThreadVersion(tid);
}
}
if (pull_dense_status_.size() != 0) {
Wait(&pull_dense_status_);
}
#ifndef _WIN32
usleep(sleep_time_ms_ * 1000);
#endif
}
}
void PullDenseWorker::IncreaseThreadVersion(int thread_id, uint64_t table_id) {
std::lock_guard<std::mutex> lock(mutex_for_version_);
training_versions_[table_id][thread_id]++;
}
bool PullDenseWorker::CheckUpdateParam(uint64_t table_id) {
std::lock_guard<std::mutex> lock(mutex_for_version_);
auto& version = training_versions_[table_id];
current_version_[table_id] =
*(std::min_element(version.begin(), version.end()));
if (current_version_[table_id] - last_versions_[table_id] < threshold_) {
return false;
}
return true;
}
void PullDenseWorker::ResetThreadVersion(uint64_t table_id) {
std::lock_guard<std::mutex> lock(mutex_for_version_);
last_versions_[table_id] = current_version_[table_id];
}
} // namespace framework
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/trainer.h"
namespace paddle {
namespace framework {
void TrainerBase::SetScope(Scope* root_scope) { root_scope_ = root_scope; }
} // end namespace framework
} // end namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <fstream>
#include <memory>
#include <mutex> // NOLINT
#include <string>
#include <thread> // NOLINT
#include <vector>
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/data_set.h"
#include "paddle/fluid/framework/device_worker.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/trainer_desc.pb.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/operators/reader/blocking_queue.h"
#include "paddle/fluid/platform/port.h"
namespace paddle {
namespace framework {
class TrainerBase {
public:
TrainerBase() {}
virtual ~TrainerBase() {}
// model memory are hosted in root_scope
void SetScope(Scope* root_scope);
void SetDebug(const bool debug) { debug_ = debug; }
void SetDataset(Dataset* dataset_ptr) { dataset_ptr_ = dataset_ptr; }
virtual void Initialize(const TrainerDesc& trainer_desc,
Dataset* data_set) = 0;
virtual void InitTrainerEnv(const ProgramDesc& main_program,
const platform::Place& place) = 0;
virtual void InitOtherEnv(const ProgramDesc& main_program) = 0;
virtual void Run() = 0;
virtual void Finalize() = 0;
protected:
Scope* root_scope_;
bool debug_;
Dataset* dataset_ptr_;
};
// general trainer for async execution
// local trainer and distributed trainer are supported
// depends on the assigned device_worker
class MultiTrainer : public TrainerBase {
public:
MultiTrainer() {}
virtual ~MultiTrainer() {}
virtual void Initialize(const TrainerDesc& trainer_desc, Dataset* data_set);
virtual void InitTrainerEnv(const ProgramDesc& main_program,
const platform::Place& place);
virtual void InitOtherEnv(const ProgramDesc& main_program) {}
virtual void Run();
virtual void Finalize();
protected:
int thread_num_;
std::vector<std::thread> threads_;
std::vector<std::shared_ptr<DataFeed>> readers_;
std::vector<std::shared_ptr<DeviceWorker>> workers_;
};
class DistMultiTrainer : public MultiTrainer {
public:
DistMultiTrainer() {}
virtual ~DistMultiTrainer() {}
virtual void Initialize(const TrainerDesc& trainer_desc, Dataset* data_set);
virtual void InitOtherEnv(const ProgramDesc& main_program);
virtual void Run();
virtual void Finalize();
protected:
std::shared_ptr<paddle::framework::PullDenseWorker> pull_dense_worker_;
};
} // namespace framework
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
syntax = "proto2";
import "data_feed.proto";
package paddle.framework;
message TrainerDesc {
// class name for create trainer desc
// the matchness of trainer name and device worker name
// will be checked in python API
optional string class_name = 1;
// class name for creating device worker
optional string device_worker_name = 2;
// thread number
optional int32 thread_num = 3;
// if we need to binding cpu
optional bool binding_cpu = 4 [ default = false ];
repeated string filelist = 5;
optional bool debug = 6 [ default = false ];
optional FetchConfig fetch_config = 7;
// device worker parameters
optional HogwildWorkerParameter hogwild_param = 101;
optional DownpourWorkerParameter downpour_param = 103;
optional PullDenseWorkerParameter pull_dense_param = 102;
// datafeed desc
optional DataFeedDesc data_desc = 201;
}
message HogwildWorkerParameter { repeated string skip_ops = 1; }
message DownpourWorkerParameter {
repeated TableParameter sparse_table = 1;
repeated TableParameter dense_table = 2;
repeated string skip_ops = 3;
repeated ProgramConfig program_config = 4;
optional bool push_sparse = 5 [ default = true ];
optional bool push_dense = 6 [ default = true ];
}
message FetchConfig {
enum Method { PRINT = 0; }
repeated string fetch_var_names = 1;
repeated string fetch_var_str_format = 2;
optional int32 print_period = 3 [ default = 100 ];
optional Method method = 4 [ default = PRINT ];
}
message ProgramConfig {
required string program_id = 1;
repeated int32 push_sparse_table_id = 2;
repeated int32 push_dense_table_id = 3;
repeated int32 pull_sparse_table_id = 4;
repeated int32 pull_dense_table_id = 5;
}
message PullDenseWorkerParameter {
// dense table only and specialized usage
optional int32 threshold = 1 [ default = 1 ];
optional int32 device_num = 2;
optional int32 sleep_time_ms = 3 [ default = 2 ];
repeated TableParameter dense_table = 4;
}
message TableParameter {
// dense table only
optional int64 table_id = 1;
repeated string dense_value_name = 2;
repeated string dense_grad_name = 3;
repeated int32 push_dense_wait_times = 5;
// sparse table only
repeated string sparse_key_name = 6;
repeated string sparse_value_name = 7;
repeated string sparse_grad_name = 8;
repeated int32 push_sparse_wait_times = 9;
// sparse table only and specialized usage
optional int32 emb_dim = 10;
optional int32 fea_dim = 11;
optional string label_var_name = 12;
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/trainer_factory.h"
#include <memory>
#include <string>
#include <unordered_map>
#include "paddle/fluid/framework/trainer.h"
namespace paddle {
namespace framework {
typedef std::shared_ptr<TrainerBase> (*CreatetrainerFunction)();
typedef std::unordered_map<std::string, CreatetrainerFunction> trainerMap;
trainerMap g_trainer_map;
#define REGISTER_TRAINER_CLASS(trainer_class) \
namespace { \
std::shared_ptr<TrainerBase> Creator_##trainer_class() { \
return std::shared_ptr<TrainerBase>(new trainer_class); \
} \
class __Registerer_##trainer_class { \
public: \
__Registerer_##trainer_class() { \
g_trainer_map[#trainer_class] = &Creator_##trainer_class; \
} \
}; \
__Registerer_##trainer_class g_registerer_##trainer_class; \
} // namespace
std::string TrainerFactory::TrainerTypeList() {
std::string trainer_types;
for (auto iter = g_trainer_map.begin(); iter != g_trainer_map.end(); ++iter) {
if (iter != g_trainer_map.begin()) {
trainer_types += ", ";
}
trainer_types += iter->first;
}
return trainer_types;
}
std::shared_ptr<TrainerBase> TrainerFactory::CreateTrainer(
std::string trainer_class) {
if (g_trainer_map.count(trainer_class) < 1) {
LOG(WARNING) << "Trainer class: " << trainer_class << " not defined";
LOG(WARNING) << TrainerTypeList();
exit(-1);
}
return g_trainer_map[trainer_class]();
}
REGISTER_TRAINER_CLASS(MultiTrainer);
REGISTER_TRAINER_CLASS(DistMultiTrainer);
} // namespace framework
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/framework/trainer.h"
namespace paddle {
namespace framework {
class TrainerFactory {
public:
static std::string TrainerTypeList();
static std::shared_ptr<TrainerBase> CreateTrainer(std::string trainer_class);
};
} // namespace framework
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/trainer.h"
#include <gtest/gtest.h>
namespace paddle {
namespace framework {
TEST() {
// create multi trainer
// create hogwild device worker
// create dataset
// train for a while
}
}
}
...@@ -27,6 +27,7 @@ limitations under the License. */ ...@@ -27,6 +27,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
void InitializeVariable(Variable* var, proto::VarType::Type var_type) { void InitializeVariable(Variable* var, proto::VarType::Type var_type) {
if (var_type == proto::VarType::LOD_TENSOR) { if (var_type == proto::VarType::LOD_TENSOR) {
var->GetMutable<LoDTensor>(); var->GetMutable<LoDTensor>();
......
...@@ -18,5 +18,6 @@ limitations under the License. */ ...@@ -18,5 +18,6 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
void InitializeVariable(Variable *var, proto::VarType::Type var_type); void InitializeVariable(Variable *var, proto::VarType::Type var_type);
}
} } // end namespace framework
} // end namespace paddle
...@@ -93,6 +93,9 @@ nv_test(transform_test SRCS transform_test.cu DEPS memory place device_context) ...@@ -93,6 +93,9 @@ nv_test(transform_test SRCS transform_test.cu DEPS memory place device_context)
cc_library(timer SRCS timer.cc) cc_library(timer SRCS timer.cc)
cc_test(timer_test SRCS timer_test.cc DEPS timer) cc_test(timer_test SRCS timer_test.cc DEPS timer)
cc_library(lodtensor_printer SRCS lodtensor_printer.cc DEPS ddim place tensor scope lod_tensor variable_helper framework_proto)
cc_test(lodtensor_printer_test SRCS lodtensor_printer_test.cc DEPS lodtensor_printer)
cc_library(device_tracer SRCS device_tracer.cc DEPS boost profiler_proto framework_proto ${GPU_CTX_DEPS}) cc_library(device_tracer SRCS device_tracer.cc DEPS boost profiler_proto framework_proto ${GPU_CTX_DEPS})
if(WITH_GPU) if(WITH_GPU)
nv_library(profiler SRCS profiler.cc profiler.cu DEPS device_tracer gpu_info enforce) nv_library(profiler SRCS profiler.cc profiler.cu DEPS device_tracer gpu_info enforce)
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/lodtensor_printer.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable.h"
namespace paddle {
namespace platform {
template <typename T>
void print_lod_tensor(const std::string& var_name,
const framework::LoDTensor& lod_tensor,
const std::string& print_info) {
auto inspect = lod_tensor.data<T>();
auto element_num = lod_tensor.numel();
std::ostringstream sstream;
sstream << print_info << "\t";
sstream << var_name << "\t";
sstream << inspect[0];
for (int j = 1; j < element_num; ++j) {
sstream << " " << inspect[j];
}
std::cout << sstream.str() << std::endl;
}
void PrintVar(framework::Scope* scope, const std::string& var_name,
const std::string& print_info) {
framework::Variable* var = scope->FindVar(var_name);
if (var == nullptr) {
VLOG(1) << "Variable Name " << var_name << " does not exist in your scope";
return;
}
framework::LoDTensor* tensor = var->GetMutable<framework::LoDTensor>();
if (tensor == nullptr) {
VLOG(1) << "tensor of variable " << var_name
<< " does not exist in your scope";
return;
}
#define PrintLoDTensorCallback(cpp_type, proto_type) \
do { \
if (tensor->type() == proto_type) { \
print_lod_tensor<cpp_type>(var_name, *tensor, print_info); \
return; \
} \
} while (0)
_ForEachDataType_(PrintLoDTensorCallback);
VLOG(1) << "PrintVar: unrecognized data type:" << tensor->type();
}
} // end namespace platform
} // end namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/framework/scope.h"
namespace paddle {
namespace platform {
void PrintVar(framework::Scope* scope, const std::string& var_name,
const std::string& print_info);
} // end namespace platform
} // end namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/lodtensor_printer.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable.h"
TEST(LodTensorPrinter, PrintVar) {
paddle::framework::Scope scope;
paddle::platform::PrintVar(&scope, "NotAVar", "We don't have var");
}
set(PYBIND_DEPS pybind python proto_desc memory executor async_executor prune set(PYBIND_DEPS pybind python proto_desc memory executor async_executor fleet_wrapper prune
feed_fetch_method pass_builder parallel_executor profiler layer scope_pool feed_fetch_method pass_builder parallel_executor profiler layer scope_pool
tracer analysis_predictor imperative_profiler) tracer analysis_predictor imperative_profiler)
if(WITH_PYTHON) if(WITH_PYTHON)
list(APPEND PYBIND_DEPS py_func_op) list(APPEND PYBIND_DEPS py_func_op)
endif() endif()
set(PYBIND_SRCS pybind.cc exception.cc protobuf.cc const_value.cc recordio.cc reader_py.cc async_executor_py.cc imperative.cc ir.cc inference_api.cc) set(PYBIND_SRCS pybind.cc exception.cc protobuf.cc const_value.cc recordio.cc reader_py.cc async_executor_py.cc fleet_wrapper_py.cc data_set_py.cc imperative.cc ir.cc inference_api.cc)
if(WITH_PYTHON) if(WITH_PYTHON)
if(WITH_AMD_GPU) if(WITH_AMD_GPU)
......
...@@ -21,6 +21,7 @@ limitations under the License. */ ...@@ -21,6 +21,7 @@ limitations under the License. */
#ifdef _XOPEN_SOURCE #ifdef _XOPEN_SOURCE
#undef _XOPEN_SOURCE #undef _XOPEN_SOURCE
#endif #endif
#include <memory>
#include <string> #include <string>
#include <vector> #include <vector>
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <fcntl.h>
#ifdef _POSIX_C_SOURCE
#undef _POSIX_C_SOURCE
#endif
#ifdef _XOPEN_SOURCE
#undef _XOPEN_SOURCE
#endif
#include <memory>
#include <string>
#include <vector>
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/text_format.h"
#include "paddle/fluid/framework/async_executor.h"
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/data_feed.pb.h"
#include "paddle/fluid/framework/data_set.h"
#include "paddle/fluid/framework/dataset_factory.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/io.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/variant.h"
#include "paddle/fluid/pybind/data_set_py.h"
namespace py = pybind11;
namespace pd = paddle::framework;
namespace paddle {
namespace pybind {
void BindDataset(py::module* m) {
py::class_<framework::Dataset, std::shared_ptr<framework::Dataset>>(*m,
"Dataset")
.def(py::init([](const std::string& name = "MultiSlotDataset") {
return framework::DatasetFactory::CreateDataset(name);
}))
.def("set_filelist", &framework::Dataset::SetFileList)
.def("set_thread_num", &framework::Dataset::SetThreadNum)
.def("set_trainer_num", &framework::Dataset::SetTrainerNum)
.def("set_hdfs_config", &framework::Dataset::SetHdfsConfig)
.def("set_data_feed_desc", &framework::Dataset::SetDataFeedDesc)
.def("get_filelist", &framework::Dataset::GetFileList)
.def("get_thread_num", &framework::Dataset::GetThreadNum)
.def("get_trainer_num", &framework::Dataset::GetTrainerNum)
.def("get_hdfs_config", &framework::Dataset::GetHdfsConfig)
.def("get_data_feed_desc", &framework::Dataset::GetDataFeedDesc)
.def("register_client2client_msg_handler",
&framework::Dataset::RegisterClientToClientMsgHandler)
.def("load_into_memory", &framework::Dataset::LoadIntoMemory)
.def("release_memory", &framework::Dataset::ReleaseMemory)
.def("local_shuffle", &framework::Dataset::LocalShuffle)
.def("global_shuffle", &framework::Dataset::GlobalShuffle);
}
} // end namespace pybind
} // end namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
namespace py = pybind11;
namespace paddle {
namespace pybind {
void BindDataset(py::module* m);
} // namespace pybind
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <fcntl.h>
#ifdef _POSIX_C_SOURCE
#undef _POSIX_C_SOURCE
#endif
#ifdef _XOPEN_SOURCE
#undef _XOPEN_SOURCE
#endif
#include <string>
#include <vector>
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/text_format.h"
#include "paddle/fluid/framework/async_executor.h"
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/data_feed.pb.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/io.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/variant.h"
#include "paddle/fluid/pybind/fleet_wrapper_py.h"
namespace py = pybind11;
namespace pd = paddle::framework;
namespace paddle {
namespace pybind {
void BindFleetWrapper(py::module* m) {
py::class_<framework::FleetWrapper>(*m, "Fleet")
.def(py::init())
.def("push_dense", &framework::FleetWrapper::PushDenseVarsSync)
.def("init_server", &framework::FleetWrapper::InitServer)
.def("run_server", &framework::FleetWrapper::RunServer)
.def("init_worker", &framework::FleetWrapper::InitWorker)
.def("init_model", &framework::FleetWrapper::PushDenseParamSync)
.def("stop_server", &framework::FleetWrapper::StopServer)
.def("gather_servers", &framework::FleetWrapper::GatherServers)
.def("gather_clients", &framework::FleetWrapper::GatherClients)
.def("get_clients_info", &framework::FleetWrapper::GetClientsInfo)
.def("create_client2client_connection",
&framework::FleetWrapper::CreateClient2ClientConnection);
} // end FleetWrapper
} // end namespace pybind
} // end namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
namespace py = pybind11;
namespace paddle {
namespace pybind {
void BindFleetWrapper(py::module* m);
} // namespace pybind
} // namespace paddle
...@@ -50,7 +50,9 @@ limitations under the License. */ ...@@ -50,7 +50,9 @@ limitations under the License. */
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/pybind/async_executor_py.h" #include "paddle/fluid/pybind/async_executor_py.h"
#include "paddle/fluid/pybind/const_value.h" #include "paddle/fluid/pybind/const_value.h"
#include "paddle/fluid/pybind/data_set_py.h"
#include "paddle/fluid/pybind/exception.h" #include "paddle/fluid/pybind/exception.h"
#include "paddle/fluid/pybind/fleet_wrapper_py.h"
#include "paddle/fluid/pybind/imperative.h" #include "paddle/fluid/pybind/imperative.h"
#include "paddle/fluid/pybind/inference_api.h" #include "paddle/fluid/pybind/inference_api.h"
#include "paddle/fluid/pybind/ir.h" #include "paddle/fluid/pybind/ir.h"
...@@ -59,7 +61,6 @@ limitations under the License. */ ...@@ -59,7 +61,6 @@ limitations under the License. */
#include "paddle/fluid/pybind/reader_py.h" #include "paddle/fluid/pybind/reader_py.h"
#include "paddle/fluid/pybind/recordio.h" #include "paddle/fluid/pybind/recordio.h"
#include "paddle/fluid/pybind/tensor_py.h" #include "paddle/fluid/pybind/tensor_py.h"
#include "paddle/fluid/string/to_string.h" #include "paddle/fluid/string/to_string.h"
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
...@@ -922,6 +923,7 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -922,6 +923,7 @@ All parameter, weight, gradient are variables in Paddle.
py::class_<framework::Executor>(m, "Executor") py::class_<framework::Executor>(m, "Executor")
.def(py::init<const platform::Place &>()) .def(py::init<const platform::Place &>())
.def("close", &Executor::Close) .def("close", &Executor::Close)
.def("run_from_dataset", &Executor::RunFromDataset)
.def("run", [](Executor &self, const ProgramDesc &prog, Scope *scope, .def("run", [](Executor &self, const ProgramDesc &prog, Scope *scope,
int block_id, bool create_local_scope, bool create_vars, int block_id, bool create_local_scope, bool create_vars,
const std::vector<std::string> &fetch_vars) { const std::vector<std::string> &fetch_vars) {
...@@ -1356,9 +1358,11 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -1356,9 +1358,11 @@ All parameter, weight, gradient are variables in Paddle.
BindRecordIOWriter(&m); BindRecordIOWriter(&m);
BindAsyncExecutor(&m); BindAsyncExecutor(&m);
BindFleetWrapper(&m);
BindGraph(&m); BindGraph(&m);
BindNode(&m); BindNode(&m);
BindInferenceApi(&m); BindInferenceApi(&m);
BindDataset(&m);
} }
} // namespace pybind } // namespace pybind
} // namespace paddle } // namespace paddle
cc_library(stringpiece SRCS piece.cc) cc_library(stringpiece SRCS piece.cc)
cc_library(pretty_log SRCS pretty_log.cc) cc_library(pretty_log SRCS pretty_log.cc)
cc_library(string_helper SRCS string_helper.cc DEPS boost)
cc_test(stringpiece_test SRCS piece_test.cc DEPS stringpiece glog gflags) cc_test(stringpiece_test SRCS piece_test.cc DEPS stringpiece glog gflags)
cc_test(stringprintf_test SRCS printf_test.cc DEPS glog gflags) cc_test(stringprintf_test SRCS printf_test.cc DEPS glog gflags)
cc_test(to_string_test SRCS to_string_test.cc) cc_test(to_string_test SRCS to_string_test.cc)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/string/string_helper.h"
#include <ctype.h>
#include <stdio.h>
#include <cstring>
#include <string>
#include <vector>
#include "boost/lexical_cast.hpp"
#include "glog/logging.h"
namespace paddle {
namespace string {
inline size_t count_spaces(const char* s) {
size_t count = 0;
while (*s != 0 && isspace(*s++)) {
count++;
}
return count;
}
inline size_t count_nonspaces(const char* s) {
size_t count = 0;
while (*s != 0 && !isspace(*s++)) {
count++;
}
return count;
}
// remove leading and tailing spaces
std::string trim_spaces(const std::string& str) {
const char* p = str.c_str();
while (*p != 0 && isspace(*p)) {
p++;
}
size_t len = strlen(p);
while (len > 0 && isspace(p[len - 1])) {
len--;
}
return std::string(p, len);
}
inline int str_to_float(const char* str, float* v) {
const char* head = str;
char* cursor = NULL;
int index = 0;
while (*(head += count_spaces(head)) != 0) {
v[index++] = std::strtof(head, &cursor);
if (head == cursor) {
break;
}
head = cursor;
}
return index;
}
// A helper class for reading lines from file.
// A line buffer is maintained. It
// doesn't need to know the maximum possible length of a line.
char* LineFileReader::getdelim(FILE* f, char delim) {
#ifndef _WIN32
int32_t ret = ::getdelim(&_buffer, &_buf_size, delim, f);
if (ret >= 0) {
if (ret >= 1 && _buffer[ret - 1] == delim) {
_buffer[--ret] = 0;
}
_length = (size_t)ret;
return _buffer;
} else {
_length = 0;
CHECK(feof(f));
return NULL;
}
#else
return NULL;
#endif
}
} // end namespace string
} // end namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <ctype.h>
#include <stdio.h>
#include <cstring>
#include <string>
#include <utility>
#include <vector>
#include "boost/lexical_cast.hpp"
#include "glog/logging.h"
namespace paddle {
namespace string {
inline size_t count_spaces(const char* s);
inline size_t count_nonspaces(const char* s);
template <class... ARGS>
void format_string_append(std::string& str, const char* fmt, // NOLINT
ARGS&&... args) {
int len = snprintf(NULL, 0, fmt, args...);
CHECK_GE(len, 0);
size_t oldlen = str.length();
str.resize(oldlen + len + 1);
CHECK(snprintf(&str[oldlen], (size_t)len + 1, fmt, args...) == len);
str.resize(oldlen + len);
}
template <class... ARGS>
void format_string_append(std::string& str, const std::string& fmt, // NOLINT
ARGS&&... args) {
format_string_append(str, fmt.c_str(), args...);
}
template <class... ARGS>
std::string format_string(const char* fmt, ARGS&&... args) {
std::string str;
format_string_append(str, fmt, args...);
return std::move(str);
}
template <class... ARGS>
std::string format_string(const std::string& fmt, ARGS&&... args) {
return format_string(fmt.c_str(), args...);
}
// remove leading and tailing spaces
std::string trim_spaces(const std::string& str);
int str_to_float(const char* str, float* v);
// split string by delim
template <class T = std::string>
std::vector<T> split_string(const std::string& str, const std::string& delim) {
size_t pre_pos = 0;
size_t pos = 0;
std::string tmp_str;
std::vector<T> res_list;
res_list.clear();
if (str.empty()) {
return res_list;
}
while ((pos = str.find(delim, pre_pos)) != std::string::npos) {
tmp_str.assign(str, pre_pos, pos - pre_pos);
res_list.push_back(tmp_str);
pre_pos = pos + 1;
}
tmp_str.assign(str, pre_pos, str.length() - pre_pos);
if (!tmp_str.empty()) {
res_list.push_back(tmp_str);
}
return res_list;
}
// split string by spaces. Leading and tailing spaces are ignored. Consecutive
// spaces are treated as one delim.
template <class T = std::string>
std::vector<T> split_string(const std::string& str) {
std::vector<T> list;
const char* p;
int pre_pos = 0;
int pos = 0;
std::string tmp_str;
if (str.empty()) {
return list;
}
for (p = str.c_str(); *p != 0;) {
if (!isspace(*p)) {
pos = pre_pos;
p++;
while (*p != 0 && !isspace(*p)) {
pos++;
p++;
}
tmp_str.assign(str, pre_pos, pos - pre_pos + 1);
list.push_back(tmp_str);
pre_pos = pos + 1;
} else {
pre_pos++;
p++;
}
}
return list;
}
template <class T>
std::string join_strings(const std::vector<T>& strs, char delim) {
std::string str;
for (size_t i = 0; i < strs.size(); i++) {
if (i > 0) {
str += delim;
}
str += boost::lexical_cast<std::string>(strs[i]);
}
return str;
}
// A helper class for reading lines from file. A line buffer is maintained. It
// doesn't need to know the maximum possible length of a line.
class LineFileReader {
public:
LineFileReader() {}
LineFileReader(LineFileReader&&) = delete;
LineFileReader(const LineFileReader&) = delete;
~LineFileReader() { ::free(_buffer); }
char* getline(FILE* f) { return this->getdelim(f, '\n'); }
char* getdelim(FILE* f, char delim);
char* get() { return _buffer; }
size_t length() { return _length; }
private:
char* _buffer = NULL;
size_t _buf_size = 0;
size_t _length = 0;
};
} // end namespace string
} // end namespace paddle
...@@ -24,10 +24,13 @@ from .executor import * ...@@ -24,10 +24,13 @@ from .executor import *
from . import data_feed_desc from . import data_feed_desc
from .data_feed_desc import * from .data_feed_desc import *
from . import dataset
from .dataset import *
from . import async_executor from . import async_executor
from .async_executor import * from .async_executor import *
from . import trainer from . import trainer_desc
from . import inferencer from . import inferencer
from . import io from . import io
...@@ -43,10 +46,13 @@ from . import regularizer ...@@ -43,10 +46,13 @@ from . import regularizer
from . import average from . import average
from . import metrics from . import metrics
from . import transpiler from . import transpiler
from . import incubate
from . import distribute_lookup_table from . import distribute_lookup_table
from .param_attr import ParamAttr, WeightNormParamAttr from .param_attr import ParamAttr, WeightNormParamAttr
from .data_feeder import DataFeeder from .data_feeder import DataFeeder
from .core import LoDTensor, LoDTensorArray, CPUPlace, CUDAPlace, CUDAPinnedPlace, Scope, _Scope from .core import LoDTensor, LoDTensorArray, CPUPlace, CUDAPlace, CUDAPinnedPlace, Scope, _Scope
from .incubate import fleet
from .incubate import data_generator
from .transpiler import DistributeTranspiler, \ from .transpiler import DistributeTranspiler, \
memory_optimize, release_memory, DistributeTranspilerConfig memory_optimize, release_memory, DistributeTranspilerConfig
from .lod_tensor import create_lod_tensor, create_random_int_lodtensor from .lod_tensor import create_lod_tensor, create_random_int_lodtensor
...@@ -64,9 +70,9 @@ from . import install_check ...@@ -64,9 +70,9 @@ from . import install_check
Tensor = LoDTensor Tensor = LoDTensor
__all__ = framework.__all__ + executor.__all__ + \ __all__ = framework.__all__ + executor.__all__ + \
trainer.__all__ + inferencer.__all__ + transpiler.__all__ + \ trainer_desc.__all__ + inferencer.__all__ + transpiler.__all__ + \
parallel_executor.__all__ + lod_tensor.__all__ + \ parallel_executor.__all__ + lod_tensor.__all__ + \
data_feed_desc.__all__ + async_executor.__all__ + compiler.__all__ + [ data_feed_desc.__all__ + async_executor.__all__ + compiler.__all__ + [
'io', 'io',
'initializer', 'initializer',
'layers', 'layers',
......
...@@ -24,6 +24,7 @@ from paddle.fluid.proto import data_feed_pb2 ...@@ -24,6 +24,7 @@ from paddle.fluid.proto import data_feed_pb2
from google.protobuf import text_format from google.protobuf import text_format
from . import io from . import io
from .data_feed_desc import DataFeedDesc from .data_feed_desc import DataFeedDesc
from .trainer_desc import TrainerDesc, MultiTrainer, DistMultiTrainer
from .distributed import ps_instance from .distributed import ps_instance
from .contrib.utils import hdfs_utils as hdfs from .contrib.utils import hdfs_utils as hdfs
...@@ -77,6 +78,17 @@ class AsyncExecutor(object): ...@@ -77,6 +78,17 @@ class AsyncExecutor(object):
""" """
def __init__(self, place=None, run_mode=""): def __init__(self, place=None, run_mode=""):
"""
Init.
Example:
>>> place = fluid.CPUPlace()
>>> async_executor = fluid.AsyncExecutor(place)
Args:
place(Place): CPUPlace only
run_mode(str): default is empty string.
"""
if place is None: if place is None:
place = core.CPUPlace() place = core.CPUPlace()
if not isinstance(place, core.CPUPlace): if not isinstance(place, core.CPUPlace):
...@@ -159,7 +171,8 @@ class AsyncExecutor(object): ...@@ -159,7 +171,8 @@ class AsyncExecutor(object):
self.executor.run_from_files(program_desc, self.executor.run_from_files(program_desc,
data_feed.desc(), filelist, thread_num, data_feed.desc(), filelist, thread_num,
fetch_var_names, mode, debug) fetch_var_names, mode, debug,
str(id(program_desc)))
def download_data(self, def download_data(self,
afs_path, afs_path,
...@@ -172,18 +185,19 @@ class AsyncExecutor(object): ...@@ -172,18 +185,19 @@ class AsyncExecutor(object):
""" """
download_data is a default download method for distributed training download_data is a default download method for distributed training
a user download data without this method a user download data without this method
Example: Example:
>>> exe = fluid.AsyncExecutor() >>> exe = fluid.AsyncExecutor()
>>> exe.download_data("/xxx/xxx/xx/", >>> exe.download_data("/xxx/xxx/xx/",
>>> "./data", "afs:// >>> "./data", "afs://
>>> xxx.xxx.xxx.xxx:9901", "xxx,yyy") >>> xxx.xxx.xxx.xxx:9901", "xxx,yyy")
Args: Args:
afs_path(str): afs_path defined by users afs_path(str): afs_path defined by users
local_path(str): download data path local_path(str): download data path
fs_default_name(str): file system server address fs_default_name(str): file system server address
ugi(str): hadoop ugi ugi(str): hadoop ugi
file_cn(int): a user can specify file number for debugging file_cnt(int): a user can specify file number for debugging
hadoop_home(str): hadoop home path hadoop_home(str): hadoop home path
process_num(int): download process num process_num(int): download process num
""" """
...@@ -217,7 +231,7 @@ class AsyncExecutor(object): ...@@ -217,7 +231,7 @@ class AsyncExecutor(object):
def config_distributed_nodes(self): def config_distributed_nodes(self):
""" """
if a user needs to run distributed async executor if a user needs to run distributed async executor
he or she needs to do a global configuration so that he or she needs to do a global configuration so that
information of current process can be obtained information of current process can be obtained
""" """
self.instance = ps_instance.PaddlePSInstance(1, 2) self.instance = ps_instance.PaddlePSInstance(1, 2)
...@@ -241,16 +255,19 @@ class AsyncExecutor(object): ...@@ -241,16 +255,19 @@ class AsyncExecutor(object):
def init_server(self, dist_desc): def init_server(self, dist_desc):
""" """
initialize server of current node if current process is a server Initialize server of current node if current process is a server.
Args: Args:
dist_desc(str): a protobuf string that describes dist_desc(str): a protobuf string that describes
how to init a worker and a server how to init a worker and a server
""" """
if self.instance is None: if self.instance is None:
raise ValueError( raise ValueError(
'instance is None, please run config_distributed_nodes init instance' 'instance is None, please run config_distributed_nodes init instance'
) )
self.executor.init_server(dist_desc, self.instance._rankid) self.dist_desc_str = text_format.MessageToString(dist_desc)
self.dist_desc = dist_desc
self.executor.init_server(self.dist_desc_str, self.instance._rankid)
ip = self.executor.start_server() ip = self.executor.start_server()
self.instance.set_ip(ip) self.instance.set_ip(ip)
self.instance.barrier_all() #wait all server start self.instance.barrier_all() #wait all server start
...@@ -260,23 +277,31 @@ class AsyncExecutor(object): ...@@ -260,23 +277,31 @@ class AsyncExecutor(object):
def init_worker(self, dist_desc, startup_program): def init_worker(self, dist_desc, startup_program):
""" """
initialize worker of current node if current process is a worker Initialize worker of current node if current process is a worker.
Args: Args:
dist_desc(str): a protobuf string that describes dist_desc(str): a protobuf string that describes
how to init a worker and a server how to init a worker and a server
startup_program(fluid.Program): startup program of current process startup_program(fluid.Program): startup program of current process
""" """
if self.instance is None: if self.instance is None:
raise ValueError( raise ValueError(
'instance is None, please run config_distributed_nodes init instance' 'instance is None, please run config_distributed_nodes init instance'
) )
self.dist_desc_str = text_format.MessageToString(dist_desc)
self.dist_desc = dist_desc
place = core.CPUPlace() place = core.CPUPlace()
executor = Executor(place) executor = Executor(place)
executor.run(startup_program) if isinstance(startup_program, list):
for sp in startup_program:
executor.run(sp)
else:
executor.run(startup_program)
self.instance.barrier_all() #wait all server start self.instance.barrier_all() #wait all server start
ips = self.instance.gather_ips() ips = self.instance.gather_ips()
self.executor.init_worker(dist_desc, ips, self.executor.init_worker(self.dist_desc_str, ips,
self.instance.get_node_cnt(), self.instance.get_node_cnt(),
self.instance._rankid) self.instance._rankid)
self.instance.barrier_all() #wait all worker start self.instance.barrier_all() #wait all worker start
...@@ -298,9 +323,10 @@ class AsyncExecutor(object): ...@@ -298,9 +323,10 @@ class AsyncExecutor(object):
def save_model(self, save_path): def save_model(self, save_path):
""" """
save_model command that can be invoked from one of the worker save_model command that can be invoked from one of the worker
model parameters are saved in servers and upload to save_path of file system model parameters are saved in servers and upload to save_path of file system.
Args: Args:
save_path(str): save path to file system save_path(str): save path to file system
""" """
if self.instance is None: if self.instance is None:
raise ValueError( raise ValueError(
......
...@@ -68,6 +68,7 @@ class DataFeedDesc(object): ...@@ -68,6 +68,7 @@ class DataFeedDesc(object):
def __init__(self, proto_file): def __init__(self, proto_file):
self.proto_desc = data_feed_pb2.DataFeedDesc() self.proto_desc = data_feed_pb2.DataFeedDesc()
self.proto_desc.pipe_command = "cat"
with open(proto_file, 'r') as f: with open(proto_file, 'r') as f:
text_format.Parse(f.read(), self.proto_desc) text_format.Parse(f.read(), self.proto_desc)
if self.proto_desc.name == "MultiSlotDataFeed": if self.proto_desc.name == "MultiSlotDataFeed":
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.fluid.proto import data_feed_pb2
from google.protobuf import text_format
from . import core
__all__ = ['DatasetFactory']
class DatasetFactory(object):
"""
DatasetFactory is a factory which create dataset by its name,
you can create "QueueDataset" or "InMemoryDataset",
the default is "QueueDataset".
Example:
dataset = paddle.fluid.DatasetFactory.create_dataset("InMemoryDataset")
"""
def __init__(self):
"""
Init
"""
pass
def create_dataset(self, datafeed_class="QueueDataset"):
"""
Create "QueueDataset" or "InMemoryDataset",
the default is "QueueDataset".
"""
try:
dataset = globals()[datafeed_class]()
return dataset
except:
raise ValueError("datafeed class %s does not exist" %
datafeed_class)
class DatasetBase(object):
"""
Base dataset class
"""
def __init__(self):
"""
Init
"""
# define class name here
# to decide whether we need create in memory instance
self.proto_desc = data_feed_pb2.DataFeedDesc()
self.proto_desc.pipe_command = "cat"
self.dataset = core.Dataset("MultiSlotDataset")
self.thread_num = 0
def set_pipe_command(self, pipe_command):
"""
Set pipe command of current dataset
A pipe command is a UNIX pipeline command that can be used only
Example:
>>> dataset.set_pipe_command("python my_script.py")
Args:
pipe_command: pipe command
"""
self.proto_desc.pipe_command = pipe_command
def set_batch_size(self, batch_size):
"""
Set batch size. Will be effective during training
Example:
>>> dataset.set_batch_size(128)
Args:
batch_size: batch size
"""
self.proto_desc.batch_size = batch_size
def set_thread(self, thread_num):
"""
Set thread num, it is the num of readers.
Example:
>>> dataset.set_thread(12)
Args:
thread_num: thread num
"""
self.dataset.set_thread_num(thread_num)
self.thread_num = thread_num
def set_filelist(self, filelist):
"""
Set file list in current worker.
Example:
>>> dataset.set_filelist(['a.txt', 'b.txt'])
Args:
filelist: file list
"""
self.dataset.set_filelist(filelist)
def set_use_var(self, var_list):
"""
Set Variables which you will use.
Example:
>>> dataset.set_use_var([data, label])
Args:
var_list: variable list
"""
multi_slot = self.proto_desc.multi_slot_desc
for var in var_list:
slot_var = multi_slot.slots.add()
slot_var.is_used = True
slot_var.name = var.name
if var.lod_level == 0:
slot_var.is_dense = True
if var.dtype == core.VarDesc.VarType.FP32:
slot_var.type = "float"
elif var.dtype == core.VarDesc.VarType.INT64:
slot_var.type = "uint64"
else:
raise ValueError(
"Currently, fluid.dataset only supports dtype=float32 and dtype=int64"
)
def set_hdfs_config(self, fs_name, fs_ugi):
"""
Set hdfs config: fs name ad ugi
Example:
>>> dataset.set_hdfs_config("my_fs_name", "my_fs_ugi")
Args:
fs_name: fs name
fs_ugi: fs ugi
"""
self.dataset.set_hdfs_config(fs_name, fs_ugi)
def _prepare_to_run(self):
"""
Set data_feed_desc before load or shuffle,
user no need to call this function.
"""
self.dataset.set_data_feed_desc(self.desc())
def desc(self):
"""
Returns a protobuf message for this DataFeedDesc
Example:
>>> print(dataset.desc())
Returns:
A string message
"""
return text_format.MessageToString(self.proto_desc)
class InMemoryDataset(DatasetBase):
"""
InMemoryDataset, it will load data into memory
and shuffle data before training
Example:
dataset = paddle.fluid.DatasetFactory.create_dataset("InMemoryDataset")
"""
def __init__(self):
"""
Init
"""
super(InMemoryDataset, self).__init__()
self.proto_desc.name = "MultiSlotInMemoryDataFeed"
def load_into_memory(self):
"""
Load data into memory
Example:
>>> import paddle.fluid as fluid
>>> dataset = fluid.DatasetFactory.create_dataset("InMemoryDataset")
>>> filelist = ["a.txt", "b.txt"]
>>> dataset.set_filelist(filelist)
>>> dataset.load_into_memory()
"""
self._prepare_to_run()
self.dataset.load_into_memory()
def local_shuffle(self):
"""
Local shuffle
Example:
>>> import paddle.fluid as fluid
>>> dataset = fluid.DatasetFactory.create_dataset("InMemoryDataset")
>>> filelist = ["a.txt", "b.txt"]
>>> dataset.set_filelist(filelist)
>>> dataset.local_shuffle()
"""
self.dataset.local_shuffle()
def global_shuffle(self, fleet=None):
"""
Global shuffle.
Global shuffle can be used only in distributed mode. i.e. multiple
processes on single machine or multiple machines training together.
If you run in distributed mode, you should pass fleet instead of None.
Examples:
>>> import paddle.fluid as fluid
>>> import paddle.fluid.incubate.fleet.parameter_server as fleet
>>> dataset = fluid.DatasetFactory.create_dataset("InMemoryDataset")
>>> filelist = ["a.txt", "b.txt"]
>>> dataset.set_filelist(filelist)
>>> dataset.global_shuffle(fleet)
Args:
fleet: fleet singleton. Default None.
"""
trainer_num = 1
if fleet is not None:
fleet.fleet_instance.role_maker_._barrier_worker()
trainer_num = fleet.worker_num()
self.dataset.register_client2client_msg_handler()
self.dataset.set_trainer_num(trainer_num)
if fleet is not None:
fleet.fleet_instance.role_maker_._barrier_worker()
self.dataset.global_shuffle()
if fleet is not None:
fleet.fleet_instance.role_maker_._barrier_worker()
class QueueDataset(DatasetBase):
"""
QueueDataset, it will process data streamly.
Example:
import paddle.fluid as fluid
dataset = fluid.DatasetFactory.create_dataset("QueueDataset")
"""
def __init__(self):
"""
Init
"""
super(QueueDataset, self).__init__()
self.proto_desc.name = "MultiSlotDataFeed"
def local_shuffle(self):
"""
Local shuffle
QueueDataset does not support local shuffle
"""
raise NotImplementedError(
"QueueDataset does not support local shuffle, "
"please use InMemoryDataset for local_shuffle")
def global_shuffle(self, fleet=None):
"""
Global shuffle
"""
raise NotImplementedError(
"QueueDataset does not support global shuffle, "
"please use InMemoryDataset for global_shuffle")
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = ['DeviceWorker', 'Hogwild', 'DownpourSGD']
class DeviceWorker(object):
"""
DeviceWorker is an abstract class, which generates worker desc.
This class is an inner class that we do computation logics within
the implementation. For example, execution of a program or a graph.
"""
def __init__(self):
"""
Init.
"""
self.program_ = None
self.infer_ = None
def _set_infer(self, infer=False):
"""
set inference flag for current device worker
Args:
infer(bool): whether to do inference
"""
self.infer_ = infer
def _set_fleet_desc(self, fleet_desc):
"""
Set fleet desc.
Args:
fleet_desc(PSParameter): pslib.PSParameter object
"""
self.fleet_desc_ = fleet_desc
def _set_program(self, program):
"""
Set program.
Args:
program(Program): a Program object
"""
self.program_ = program
def _gen_worker_desc(self, trainer_desc):
"""
Generator worker desc.
Args:
trainer_desc(TrainerDesc): a TrainerDesc object
"""
raise NotImplementedError(
"DeviceWorker does not implement gen_worker_desc, "
"please use Hogwild or DownpourSGD, etc.")
class Hogwild(DeviceWorker):
"""
Hogwild is a kind of SGD algorithm.
"""
def __init__(self):
"""
Init.
"""
super(Hogwild, self).__init__()
def _gen_worker_desc(self, trainer_desc):
"""
Generator worker desc, which device worker is HogwildWorker.
Args:
trainer_desc(TrainerDesc): a TrainerDesc object
"""
trainer_desc.device_worker_name = "HogwildWorker"
if self.infer_:
# just ignore feed op for inference model
trainer_desc.hogwild_param.skip_ops.extend(["feed"])
class DownpourSGD(DeviceWorker):
"""
DownpourSGD is a kind of distributed SGD algorithm.
"""
def __init__(self):
"""
Init.
initialize downpourSGD device worker
"""
super(DownpourSGD, self).__init__()
def _gen_worker_desc(self, trainer_desc):
"""
Generator worker desc, which device worker is DownpourWorker.
Args:
trainer_desc(TrainerDesc): a TrainerDesc object
"""
dense_table_set = set()
program_id = str(id(self.program_))
if self.program_ == None:
print("program of current device worker is not configured")
exit(-1)
opt_info = self.program_._fleet_opt
program_configs = opt_info["program_configs"]
downpour = trainer_desc.downpour_param
for pid in program_configs:
if pid == program_id:
pc = downpour.program_config.add()
pc.program_id = program_id
for i in program_configs[program_id]["push_sparse"]:
pc.push_sparse_table_id.extend([i])
for i in program_configs[program_id]["push_dense"]:
pc.push_dense_table_id.extend([i])
dense_table_set.add(i)
for i in program_configs[program_id]["pull_sparse"]:
pc.pull_sparse_table_id.extend([i])
for i in program_configs[program_id]["pull_dense"]:
pc.pull_dense_table_id.extend([i])
dense_table_set.add(i)
break
trainer_desc.device_worker_name = "DownpourWorker"
pull_thread = trainer_desc.pull_dense_param
pull_thread.device_num = trainer_desc.thread_num
for i in self.fleet_desc_.trainer_param.dense_table:
if i.table_id in dense_table_set:
dense_table = pull_thread.dense_table.add()
dense_table.dense_value_name.extend(i.dense_variable_name)
dense_table.table_id = \
i.table_id
sparse_table = downpour.sparse_table.add()
sparse_table.table_id = \
self.fleet_desc_.trainer_param.sparse_table[0].table_id
sparse_table.sparse_key_name.extend(
self.fleet_desc_.trainer_param.sparse_table[0].slot_key)
sparse_table.sparse_value_name.extend(
self.fleet_desc_.trainer_param.sparse_table[0].slot_value)
sparse_table.sparse_grad_name.extend(
self.fleet_desc_.trainer_param.sparse_table[0].slot_gradient)
sparse_table.emb_dim = \
self.fleet_desc_.server_param.downpour_server_param.downpour_table_param[
0].accessor.fea_dim - 2
sparse_table.fea_dim = sparse_table.emb_dim + 2
# TODO(guru4elephant): hard code here, need to improve
sparse_table.label_var_name = "click"
for i in self.fleet_desc_.trainer_param.dense_table:
if i.table_id in dense_table_set:
dense_table = downpour.dense_table.add()
dense_table.table_id = i.table_id
dense_table.dense_value_name.extend(i.dense_variable_name)
dense_table.dense_grad_name.extend(
i.dense_gradient_variable_name)
downpour.skip_ops.extend(self.fleet_desc_.trainer_param.skip_op)
if self.infer_:
downpour.push_dense = False
downpour.push_sparse = False
class DeviceWorkerFactory(object):
def _create_device_worker(self, worker_type):
classname = worker_type.capitalize()
return globals()[classname]()
...@@ -33,6 +33,9 @@ class DownpourSGD(object): ...@@ -33,6 +33,9 @@ class DownpourSGD(object):
Examples: Examples:
.. code-block:: python .. code-block:: python
opt = fluid.DistributedOptimizer(sgd_opt)
opt.minimize()
downpour_sgd = fluid.distributed.DownpourSGD(learning_rate=0.2) downpour_sgd = fluid.distributed.DownpourSGD(learning_rate=0.2)
downpour_sgd.minimize(cost) downpour_sgd.minimize(cost)
""" """
...@@ -43,9 +46,13 @@ class DownpourSGD(object): ...@@ -43,9 +46,13 @@ class DownpourSGD(object):
self.learning_rate_ = learning_rate self.learning_rate_ = learning_rate
self.window_ = window self.window_ = window
self.type = "downpour" self.type = "downpour"
self.data_norm_name = [
".batch_size", ".batch_square_sum", ".batch_sum",
".batch_size@GRAD", ".batch_square_sum@GRAD", ".batch_sum@GRAD"
]
def minimize(self, def minimize(self,
loss, losses,
startup_program=None, startup_program=None,
parameter_list=None, parameter_list=None,
no_grad_set=None): no_grad_set=None):
...@@ -65,41 +72,97 @@ class DownpourSGD(object): ...@@ -65,41 +72,97 @@ class DownpourSGD(object):
worker_skipped_ops: operator names that need worker_skipped_ops: operator names that need
to be skipped during execution to be skipped during execution
""" """
params_grads = sorted( if not isinstance(losses, list):
append_backward(loss, parameter_list, no_grad_set), raise ValueError('losses is a list, just lick [model.cost]')
key=lambda x: x[0].name) table_name = find_distributed_lookup_table(losses[0].block.program)
table_name = find_distributed_lookup_table(loss.block.program)
prefetch_slots = find_distributed_lookup_table_inputs( prefetch_slots = find_distributed_lookup_table_inputs(
loss.block.program, table_name) losses[0].block.program, table_name)
prefetch_slots_emb = find_distributed_lookup_table_outputs( prefetch_slots_emb = find_distributed_lookup_table_outputs(
loss.block.program, table_name) losses[0].block.program, table_name)
ps_param = pslib.PSParameter()
server = DownpourServer() server = DownpourServer()
# window is communication strategy
worker = DownpourWorker(self.window_) worker = DownpourWorker(self.window_)
# Todo(guru4elephant): support multiple tables definitions
# currently support one big sparse table
sparse_table_index = 0 sparse_table_index = 0
# currently merge all dense parameters into one dense table
dense_table_index = 1
params = []
grads = []
for i in params_grads:
params.append(i[0])
for i in params_grads:
grads.append(i[1])
server.add_sparse_table(sparse_table_index, self.learning_rate_, server.add_sparse_table(sparse_table_index, self.learning_rate_,
prefetch_slots, prefetch_slots_emb) prefetch_slots, prefetch_slots_emb)
server.add_dense_table(dense_table_index, self.learning_rate_, params,
grads)
worker.add_sparse_table(sparse_table_index, self.learning_rate_, worker.add_sparse_table(sparse_table_index, self.learning_rate_,
prefetch_slots, prefetch_slots_emb) prefetch_slots, prefetch_slots_emb)
worker.add_dense_table(dense_table_index, self.learning_rate_, params, dense_table_index = 1
grads) program_configs = []
ps_param = pslib.PSParameter() param_grads_list = []
for loss_index in range(len(losses)):
program_config = ps_param.trainer_param.program_config.add()
program_config.program_id = str(
id(losses[loss_index].block.program))
program_config.pull_sparse_table_id.extend([sparse_table_index])
program_config.push_sparse_table_id.extend([sparse_table_index])
params_grads = sorted(
append_backward(losses[loss_index], parameter_list,
no_grad_set),
key=lambda x: x[0].name)
param_grads_list.append(params_grads)
params = []
grads = []
data_norm_params = []
data_norm_grads = []
for i in params_grads:
is_data_norm_data = False
for data_norm_name in self.data_norm_name:
if i[0].name.endswith(data_norm_name):
is_data_norm_data = True
data_norm_params.append(i[0])
if not is_data_norm_data:
params.append(i[0])
for i in params_grads:
is_data_norm_data = False
for data_norm_grad in self.data_norm_name:
if i[0].name.endswith(data_norm_grad):
is_data_norm_data = True
data_norm_grads.append(i[1])
if not is_data_norm_data:
grads.append(i[1])
server.add_dense_table(dense_table_index, self.learning_rate_,
params, grads)
worker.add_dense_table(dense_table_index, self.learning_rate_,
params, grads)
program_config.pull_dense_table_id.extend([dense_table_index])
program_config.push_dense_table_id.extend([dense_table_index])
if len(data_norm_params) != 0 and len(data_norm_grads) != 0:
dense_table_index += 1
server.add_data_norm_table(dense_table_index,
self.learning_rate_,
data_norm_params, data_norm_grads)
worker.add_dense_table(dense_table_index, self.learning_rate_,
data_norm_params, data_norm_grads)
program_config.pull_dense_table_id.extend([dense_table_index])
program_config.push_dense_table_id.extend([dense_table_index])
dense_table_index += 1
program_configs.append(program_config)
ps_param.server_param.CopyFrom(server.get_desc()) ps_param.server_param.CopyFrom(server.get_desc())
ps_param.trainer_param.CopyFrom(worker.get_desc()) ps_param.trainer_param.CopyFrom(worker.get_desc())
for program_config in program_configs:
ps_param.trainer_param.program_config.extend([program_config])
# Todo(guru4elephant): figure out how to support more sparse parameters # Todo(guru4elephant): figure out how to support more sparse parameters
# currently only support lookup_table # currently only support lookup_table
worker_skipped_ops = ["lookup_table", "lookup_table_grad"] worker_skipped_ops = ["lookup_table", "lookup_table_grad"]
ps_param.trainer_param.skip_op.extend(worker_skipped_ops) ps_param.trainer_param.skip_op.extend(worker_skipped_ops)
return [ps_param, worker_skipped_ops]
# all fleet operations should be defined in operators in the future
# we want to return an object here containing:
# 1) worker execution strategy
# 2) pserver execution strategy
# 3) fleet configurations
# 4) skipped operators in runtime
# 5) distributed optimization
opt_info = {}
opt_info["trainer"] = "DistMultiTrainer"
opt_info["device_worker"] = "DownpourSGD"
opt_info["optimizer"] = "DownpourSGD"
opt_info["fleet_desc"] = ps_param
opt_info["worker_skipped_ops"] = worker_skipped_ops
for loss in losses:
loss.block.program._fleet_opt = opt_info
return None, param_grads_list
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
import sys
from .. import core
from . import ps_instance
__all__ = ['Fleet']
class Fleet(object):
"""
"""
def __init__(self):
self.instance_ = ps_instance.PaddlePSInstance()
self.fleet_ = core.FleetWrapper()
def stop(self):
self.instance_.barrier_worker()
if self.instance.is_first_worker():
self.fleet_.stop_server()
self.instance_.barrier_worker()
self.instance_.barrier_all()
self.instance.finalize()
def init_pserver(self, opt_info):
if "fleet_desc" in opt_info:
self.dist_desc_str_ = text_format.MessageToString(opt_info[
"fleet_desc"])
self.dist_desc_ = opt_info["fleet_desc"]
else:
print(
"You should run distributed optimization to get opt_info first")
sys.exit(-1)
self.fleet_.init_server(self.dist_desc_str_)
ip = self.fleet_.start_server()
self.instance_.set_ip(ip)
self.instance.barrier_all()
ips = self.instance.gather_ips()
self.fleet.gather_servers(ips, self.instance_.get_node_cnt())
self.instance_.barrier_all()
def init_worker(self, opt_info):
if "fleet_desc" in opt_info:
self.dist_desc_str_ = text_format.MessageToString(opt_info[
"fleet_desc"])
self.dist_desc_ = opt_info["fleet_desc"]
else:
print(
"You should run distributed optimization to get opt_info first")
sys.exit(-1)
self.instance_.barrier_all()
ips = self.instance.gather_ips()
self.fleet_.init_worker(self.dist_desc_str_, ips,
self.instance_.get_node_cnt(),
self.instance._rankid)
self.instance.barrier_worker()
def init_pserver_model(self):
if self.instance_.is_first_worker():
self.fleet_.init_model()
self.instance_.barrier_worker()
def save_pserver_model(self, save_path):
self.fleet_.save_model(save_path)
...@@ -121,6 +121,18 @@ class PaddlePSInstance(object): ...@@ -121,6 +121,18 @@ class PaddlePSInstance(object):
""" """
return self._nodes return self._nodes
def get_worker_num(self):
"""
Return worker num
"""
return self._worker_num
def get_server_num(self):
"""
Return server num
"""
return self._server_num
def barrier_all(self): def barrier_all(self):
""" """
barrier workers and servers barrier workers and servers
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -10,6 +10,8 @@ ...@@ -10,6 +10,8 @@
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License.
# Generated by the protocol buffer compiler. DO NOT EDIT! # Generated by the protocol buffer compiler. DO NOT EDIT!
# source: ps.proto # source: ps.proto
...@@ -30,7 +32,7 @@ DESCRIPTOR = _descriptor.FileDescriptor( ...@@ -30,7 +32,7 @@ DESCRIPTOR = _descriptor.FileDescriptor(
package='paddle', package='paddle',
syntax='proto2', syntax='proto2',
serialized_pb=_b( serialized_pb=_b(
'\n\x08ps.proto\x12\x06paddle\"\x9e\x02\n\x0bPSParameter\x12\x14\n\x0cworker_class\x18\x01 \x01(\t\x12\x14\n\x0cserver_class\x18\x02 \x01(\t\x12\x16\n\x0einstance_class\x18\x03 \x01(\t\x12-\n\x0cworker_param\x18\x65 \x01(\x0b\x32\x17.paddle.WorkerParameter\x12-\n\x0cserver_param\x18\x66 \x01(\x0b\x32\x17.paddle.ServerParameter\x12\x38\n\rtrainer_param\x18\xad\x02 \x01(\x0b\x32 .paddle.DownpourTrainerParameter\x12\x33\n\x0f\x66s_client_param\x18\xf5\x03 \x01(\x0b\x32\x19.paddle.FsClientParameter\"Q\n\x0fWorkerParameter\x12>\n\x15\x64ownpour_worker_param\x18\x01 \x01(\x0b\x32\x1f.paddle.DownpourWorkerParameter\"Q\n\x0fServerParameter\x12>\n\x15\x64ownpour_server_param\x18\x01 \x01(\x0b\x32\x1f.paddle.DownpourServerParameter\"O\n\x17\x44ownpourWorkerParameter\x12\x34\n\x14\x64ownpour_table_param\x18\x01 \x03(\x0b\x32\x16.paddle.TableParameter\"\xce\x01\n\x18\x44ownpourTrainerParameter\x12\x30\n\x0b\x64\x65nse_table\x18\x01 \x03(\x0b\x32\x1b.paddle.DenseTableParameter\x12\x32\n\x0csparse_table\x18\x02 \x03(\x0b\x32\x1c.paddle.SparseTableParameter\x12\x1d\n\x15push_sparse_per_batch\x18\x03 \x01(\x05\x12\x1c\n\x14push_dense_per_batch\x18\x04 \x01(\x05\x12\x0f\n\x07skip_op\x18\x05 \x03(\t\"{\n\x13\x44\x65nseTableParameter\x12\x10\n\x08table_id\x18\x01 \x01(\x05\x12\x1b\n\x13\x64\x65nse_variable_name\x18\x02 \x03(\t\x12$\n\x1c\x64\x65nse_gradient_variable_name\x18\x03 \x03(\t\x12\x0f\n\x07\x66\x65\x61_dim\x18\x04 \x01(\x05\"z\n\x14SparseTableParameter\x12\x10\n\x08table_id\x18\x01 \x01(\x05\x12\x13\n\x0b\x66\x65\x61ture_dim\x18\x02 \x01(\x05\x12\x10\n\x08slot_key\x18\x03 \x03(\t\x12\x12\n\nslot_value\x18\x04 \x03(\t\x12\x15\n\rslot_gradient\x18\x05 \x03(\t\"\x86\x01\n\x17\x44ownpourServerParameter\x12\x34\n\x14\x64ownpour_table_param\x18\x01 \x03(\x0b\x32\x16.paddle.TableParameter\x12\x35\n\rservice_param\x18\x02 \x01(\x0b\x32\x1e.paddle.ServerServiceParameter\"\xd7\x01\n\x16ServerServiceParameter\x12*\n\x0cserver_class\x18\x01 \x01(\t:\x14\x44ownpourBrpcPsServer\x12*\n\x0c\x63lient_class\x18\x02 \x01(\t:\x14\x44ownpourBrpcPsClient\x12(\n\rservice_class\x18\x03 \x01(\t:\x11\x44ownpourPsService\x12\x1c\n\x11start_server_port\x18\x04 \x01(\r:\x01\x30\x12\x1d\n\x11server_thread_num\x18\x05 \x01(\r:\x02\x31\x32\"\xbf\x01\n\x0eTableParameter\x12\x10\n\x08table_id\x18\x01 \x01(\x04\x12\x13\n\x0btable_class\x18\x02 \x01(\t\x12\x12\n\nshared_num\x18\x03 \x01(\x04\x12\x30\n\x08\x61\x63\x63\x65ssor\x18\x04 \x01(\x0b\x32\x1e.paddle.TableAccessorParameter\x12\x1f\n\x04type\x18\x05 \x01(\x0e\x32\x11.paddle.TableType\x12\x1f\n\x10\x63ompress_in_save\x18\x06 \x01(\x08:\x05\x66\x61lse\"\xf1\x02\n\x16TableAccessorParameter\x12\x16\n\x0e\x61\x63\x63\x65ssor_class\x18\x01 \x01(\t\x12\x38\n\x10sparse_sgd_param\x18\x02 \x01(\x0b\x32\x1e.paddle.SparseSGDRuleParameter\x12\x36\n\x0f\x64\x65nse_sgd_param\x18\x03 \x01(\x0b\x32\x1d.paddle.DenseSGDRuleParameter\x12\x0f\n\x07\x66\x65\x61_dim\x18\x04 \x01(\r\x12\x12\n\nembedx_dim\x18\x05 \x01(\r\x12\x18\n\x10\x65mbedx_threshold\x18\x06 \x01(\r\x12G\n\x17\x64ownpour_accessor_param\x18\x07 \x01(\x0b\x32&.paddle.DownpourTableAccessorParameter\x12\x45\n\x19table_accessor_save_param\x18\x08 \x03(\x0b\x32\".paddle.TableAccessorSaveParameter\"\xce\x01\n\x1e\x44ownpourTableAccessorParameter\x12\x14\n\x0cnonclk_coeff\x18\x01 \x01(\x02\x12\x13\n\x0b\x63lick_coeff\x18\x02 \x01(\x02\x12\x16\n\x0e\x62\x61se_threshold\x18\x03 \x01(\x02\x12\x17\n\x0f\x64\x65lta_threshold\x18\x04 \x01(\x02\x12\x17\n\x0f\x64\x65lta_keep_days\x18\x05 \x01(\x02\x12\x1d\n\x15show_click_decay_rate\x18\x06 \x01(\x02\x12\x18\n\x10\x64\x65lete_threshold\x18\x07 \x01(\x02\"S\n\x1aTableAccessorSaveParameter\x12\r\n\x05param\x18\x01 \x01(\r\x12\x11\n\tconverter\x18\x02 \x01(\t\x12\x13\n\x0b\x64\x65\x63onverter\x18\x03 \x01(\t\"e\n\x10PsRequestMessage\x12\x0e\n\x06\x63md_id\x18\x01 \x02(\r\x12\x10\n\x08table_id\x18\x02 \x01(\r\x12\x0e\n\x06params\x18\x03 \x03(\x0c\x12\x11\n\tclient_id\x18\x04 \x01(\x05\x12\x0c\n\x04\x64\x61ta\x18\x05 \x01(\x0c\"w\n\x16SparseSGDRuleParameter\x12\x15\n\rlearning_rate\x18\x01 \x01(\x01\x12\x15\n\rinitial_g2sum\x18\x02 \x01(\x01\x12\x18\n\rinitial_range\x18\x03 \x01(\x01:\x01\x30\x12\x15\n\rweight_bounds\x18\x04 \x03(\x02\"\xe1\x01\n\x15\x44\x65nseSGDRuleParameter\x12\x0c\n\x04name\x18\x01 \x01(\t\x12&\n\x04\x61\x64\x61m\x18\x02 \x01(\x0b\x32\x18.paddle.AdamSGDParameter\x12(\n\x05naive\x18\x03 \x01(\x0b\x32\x19.paddle.NaiveSGDParameter\x12,\n\x07summary\x18\x04 \x01(\x0b\x32\x1b.paddle.SummarySGDParameter\x12:\n\x0emoving_average\x18\x05 \x01(\x0b\x32\".paddle.MovingAverageRuleParameter\"\x86\x01\n\x10\x41\x64\x61mSGDParameter\x12\x15\n\rlearning_rate\x18\x01 \x01(\x01\x12\x16\n\x0e\x61vg_decay_rate\x18\x02 \x01(\x01\x12\x16\n\x0e\x61\x64\x61_decay_rate\x18\x03 \x01(\x01\x12\x13\n\x0b\x61\x64\x61_epsilon\x18\x04 \x01(\x01\x12\x16\n\x0emom_decay_rate\x18\x05 \x01(\x01\"B\n\x11NaiveSGDParameter\x12\x15\n\rlearning_rate\x18\x01 \x01(\x01\x12\x16\n\x0e\x61vg_decay_rate\x18\x02 \x01(\x01\";\n\x13SummarySGDParameter\x12$\n\x12summary_decay_rate\x18\x01 \x01(\x01:\x08\x30.999999\".\n\x1aMovingAverageRuleParameter\x12\x10\n\x08momentum\x18\x01 \x01(\x01\"I\n\x11PsResponseMessage\x12\x13\n\x08\x65rr_code\x18\x01 \x02(\x05:\x01\x30\x12\x11\n\x07\x65rr_msg\x18\x02 \x02(\t:\x00\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\"\xd5\x01\n\x11\x46sClientParameter\x12:\n\x07\x66s_type\x18\x01 \x01(\x0e\x32#.paddle.FsClientParameter.FsApiType:\x04HDFS\x12\x0b\n\x03uri\x18\x02 \x01(\t\x12\x0c\n\x04user\x18\x03 \x01(\t\x12\x0e\n\x06passwd\x18\x04 \x01(\t\x12\x13\n\x0b\x62uffer_size\x18\x05 \x01(\x05\x12\x12\n\nhadoop_bin\x18\x33 \x01(\t\x12\x10\n\x08\x61\x66s_conf\x18\x65 \x01(\t\"\x1e\n\tFsApiType\x12\x08\n\x04HDFS\x10\x00\x12\x07\n\x03\x41\x46S\x10\x01*4\n\tTableType\x12\x13\n\x0fPS_SPARSE_TABLE\x10\x00\x12\x12\n\x0ePS_DENSE_TABLE\x10\x01*\xbd\x02\n\x07PsCmdID\x12\x17\n\x13PS_PULL_DENSE_TABLE\x10\x00\x12\x17\n\x13PS_PUSH_DENSE_TABLE\x10\x01\x12\x18\n\x14PS_PULL_SPARSE_TABLE\x10\x02\x12\x18\n\x14PS_PUSH_SPARSE_TABLE\x10\x03\x12\x13\n\x0fPS_SHRINK_TABLE\x10\x04\x12\x15\n\x11PS_SAVE_ONE_TABLE\x10\x05\x12\x15\n\x11PS_SAVE_ALL_TABLE\x10\x06\x12\x15\n\x11PS_LOAD_ONE_TABLE\x10\x07\x12\x15\n\x11PS_LOAD_ALL_TABLE\x10\x08\x12\x16\n\x12PS_CLEAR_ONE_TABLE\x10\t\x12\x16\n\x12PS_CLEAR_ALL_TABLE\x10\n\x12\x17\n\x13PS_PUSH_DENSE_PARAM\x10\x0b\x12\x12\n\x0ePS_STOP_SERVER\x10\x0c\x32K\n\tPsService\x12>\n\x07service\x12\x18.paddle.PsRequestMessage\x1a\x19.paddle.PsResponseMessageB\x03\x80\x01\x01' '\n\x08ps.proto\x12\x06paddle\"\x9e\x02\n\x0bPSParameter\x12\x14\n\x0cworker_class\x18\x01 \x01(\t\x12\x14\n\x0cserver_class\x18\x02 \x01(\t\x12\x16\n\x0einstance_class\x18\x03 \x01(\t\x12-\n\x0cworker_param\x18\x65 \x01(\x0b\x32\x17.paddle.WorkerParameter\x12-\n\x0cserver_param\x18\x66 \x01(\x0b\x32\x17.paddle.ServerParameter\x12\x38\n\rtrainer_param\x18\xad\x02 \x01(\x0b\x32 .paddle.DownpourTrainerParameter\x12\x33\n\x0f\x66s_client_param\x18\xf5\x03 \x01(\x0b\x32\x19.paddle.FsClientParameter\"Q\n\x0fWorkerParameter\x12>\n\x15\x64ownpour_worker_param\x18\x01 \x01(\x0b\x32\x1f.paddle.DownpourWorkerParameter\"Q\n\x0fServerParameter\x12>\n\x15\x64ownpour_server_param\x18\x01 \x01(\x0b\x32\x1f.paddle.DownpourServerParameter\"O\n\x17\x44ownpourWorkerParameter\x12\x34\n\x14\x64ownpour_table_param\x18\x01 \x03(\x0b\x32\x16.paddle.TableParameter\"\xfd\x01\n\x18\x44ownpourTrainerParameter\x12\x30\n\x0b\x64\x65nse_table\x18\x01 \x03(\x0b\x32\x1b.paddle.DenseTableParameter\x12\x32\n\x0csparse_table\x18\x02 \x03(\x0b\x32\x1c.paddle.SparseTableParameter\x12\x1d\n\x15push_sparse_per_batch\x18\x03 \x01(\x05\x12\x1c\n\x14push_dense_per_batch\x18\x04 \x01(\x05\x12\x0f\n\x07skip_op\x18\x05 \x03(\t\x12-\n\x0eprogram_config\x18\x06 \x03(\x0b\x32\x15.paddle.ProgramConfig\"\x99\x01\n\rProgramConfig\x12\x12\n\nprogram_id\x18\x01 \x02(\t\x12\x1c\n\x14push_sparse_table_id\x18\x02 \x03(\x05\x12\x1b\n\x13push_dense_table_id\x18\x03 \x03(\x05\x12\x1c\n\x14pull_sparse_table_id\x18\x04 \x03(\x05\x12\x1b\n\x13pull_dense_table_id\x18\x05 \x03(\x05\"{\n\x13\x44\x65nseTableParameter\x12\x10\n\x08table_id\x18\x01 \x01(\x05\x12\x1b\n\x13\x64\x65nse_variable_name\x18\x02 \x03(\t\x12$\n\x1c\x64\x65nse_gradient_variable_name\x18\x03 \x03(\t\x12\x0f\n\x07\x66\x65\x61_dim\x18\x04 \x01(\x05\"z\n\x14SparseTableParameter\x12\x10\n\x08table_id\x18\x01 \x01(\x05\x12\x13\n\x0b\x66\x65\x61ture_dim\x18\x02 \x01(\x05\x12\x10\n\x08slot_key\x18\x03 \x03(\t\x12\x12\n\nslot_value\x18\x04 \x03(\t\x12\x15\n\rslot_gradient\x18\x05 \x03(\t\"\x86\x01\n\x17\x44ownpourServerParameter\x12\x34\n\x14\x64ownpour_table_param\x18\x01 \x03(\x0b\x32\x16.paddle.TableParameter\x12\x35\n\rservice_param\x18\x02 \x01(\x0b\x32\x1e.paddle.ServerServiceParameter\"\xd7\x01\n\x16ServerServiceParameter\x12*\n\x0cserver_class\x18\x01 \x01(\t:\x14\x44ownpourBrpcPsServer\x12*\n\x0c\x63lient_class\x18\x02 \x01(\t:\x14\x44ownpourBrpcPsClient\x12(\n\rservice_class\x18\x03 \x01(\t:\x11\x44ownpourPsService\x12\x1c\n\x11start_server_port\x18\x04 \x01(\r:\x01\x30\x12\x1d\n\x11server_thread_num\x18\x05 \x01(\r:\x02\x31\x32\"\xbf\x01\n\x0eTableParameter\x12\x10\n\x08table_id\x18\x01 \x01(\x04\x12\x13\n\x0btable_class\x18\x02 \x01(\t\x12\x12\n\nshared_num\x18\x03 \x01(\x04\x12\x30\n\x08\x61\x63\x63\x65ssor\x18\x04 \x01(\x0b\x32\x1e.paddle.TableAccessorParameter\x12\x1f\n\x04type\x18\x05 \x01(\x0e\x32\x11.paddle.TableType\x12\x1f\n\x10\x63ompress_in_save\x18\x06 \x01(\x08:\x05\x66\x61lse\"\xf1\x02\n\x16TableAccessorParameter\x12\x16\n\x0e\x61\x63\x63\x65ssor_class\x18\x01 \x01(\t\x12\x38\n\x10sparse_sgd_param\x18\x02 \x01(\x0b\x32\x1e.paddle.SparseSGDRuleParameter\x12\x36\n\x0f\x64\x65nse_sgd_param\x18\x03 \x01(\x0b\x32\x1d.paddle.DenseSGDRuleParameter\x12\x0f\n\x07\x66\x65\x61_dim\x18\x04 \x01(\r\x12\x12\n\nembedx_dim\x18\x05 \x01(\r\x12\x18\n\x10\x65mbedx_threshold\x18\x06 \x01(\r\x12G\n\x17\x64ownpour_accessor_param\x18\x07 \x01(\x0b\x32&.paddle.DownpourTableAccessorParameter\x12\x45\n\x19table_accessor_save_param\x18\x08 \x03(\x0b\x32\".paddle.TableAccessorSaveParameter\"\xce\x01\n\x1e\x44ownpourTableAccessorParameter\x12\x14\n\x0cnonclk_coeff\x18\x01 \x01(\x02\x12\x13\n\x0b\x63lick_coeff\x18\x02 \x01(\x02\x12\x16\n\x0e\x62\x61se_threshold\x18\x03 \x01(\x02\x12\x17\n\x0f\x64\x65lta_threshold\x18\x04 \x01(\x02\x12\x17\n\x0f\x64\x65lta_keep_days\x18\x05 \x01(\x02\x12\x1d\n\x15show_click_decay_rate\x18\x06 \x01(\x02\x12\x18\n\x10\x64\x65lete_threshold\x18\x07 \x01(\x02\"S\n\x1aTableAccessorSaveParameter\x12\r\n\x05param\x18\x01 \x01(\r\x12\x11\n\tconverter\x18\x02 \x01(\t\x12\x13\n\x0b\x64\x65\x63onverter\x18\x03 \x01(\t\"e\n\x10PsRequestMessage\x12\x0e\n\x06\x63md_id\x18\x01 \x02(\r\x12\x10\n\x08table_id\x18\x02 \x01(\r\x12\x0e\n\x06params\x18\x03 \x03(\x0c\x12\x11\n\tclient_id\x18\x04 \x01(\x05\x12\x0c\n\x04\x64\x61ta\x18\x05 \x01(\x0c\"w\n\x16SparseSGDRuleParameter\x12\x15\n\rlearning_rate\x18\x01 \x01(\x01\x12\x15\n\rinitial_g2sum\x18\x02 \x01(\x01\x12\x18\n\rinitial_range\x18\x03 \x01(\x01:\x01\x30\x12\x15\n\rweight_bounds\x18\x04 \x03(\x02\"\xe1\x01\n\x15\x44\x65nseSGDRuleParameter\x12\x0c\n\x04name\x18\x01 \x01(\t\x12&\n\x04\x61\x64\x61m\x18\x02 \x01(\x0b\x32\x18.paddle.AdamSGDParameter\x12(\n\x05naive\x18\x03 \x01(\x0b\x32\x19.paddle.NaiveSGDParameter\x12,\n\x07summary\x18\x04 \x01(\x0b\x32\x1b.paddle.SummarySGDParameter\x12:\n\x0emoving_average\x18\x05 \x01(\x0b\x32\".paddle.MovingAverageRuleParameter\"\x86\x01\n\x10\x41\x64\x61mSGDParameter\x12\x15\n\rlearning_rate\x18\x01 \x01(\x01\x12\x16\n\x0e\x61vg_decay_rate\x18\x02 \x01(\x01\x12\x16\n\x0e\x61\x64\x61_decay_rate\x18\x03 \x01(\x01\x12\x13\n\x0b\x61\x64\x61_epsilon\x18\x04 \x01(\x01\x12\x16\n\x0emom_decay_rate\x18\x05 \x01(\x01\"B\n\x11NaiveSGDParameter\x12\x15\n\rlearning_rate\x18\x01 \x01(\x01\x12\x16\n\x0e\x61vg_decay_rate\x18\x02 \x01(\x01\";\n\x13SummarySGDParameter\x12$\n\x12summary_decay_rate\x18\x01 \x01(\x01:\x08\x30.999999\".\n\x1aMovingAverageRuleParameter\x12\x10\n\x08momentum\x18\x01 \x01(\x01\"I\n\x11PsResponseMessage\x12\x13\n\x08\x65rr_code\x18\x01 \x02(\x05:\x01\x30\x12\x11\n\x07\x65rr_msg\x18\x02 \x02(\t:\x00\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\"\xd5\x01\n\x11\x46sClientParameter\x12:\n\x07\x66s_type\x18\x01 \x01(\x0e\x32#.paddle.FsClientParameter.FsApiType:\x04HDFS\x12\x0b\n\x03uri\x18\x02 \x01(\t\x12\x0c\n\x04user\x18\x03 \x01(\t\x12\x0e\n\x06passwd\x18\x04 \x01(\t\x12\x13\n\x0b\x62uffer_size\x18\x05 \x01(\x05\x12\x12\n\nhadoop_bin\x18\x33 \x01(\t\x12\x10\n\x08\x61\x66s_conf\x18\x65 \x01(\t\"\x1e\n\tFsApiType\x12\x08\n\x04HDFS\x10\x00\x12\x07\n\x03\x41\x46S\x10\x01*4\n\tTableType\x12\x13\n\x0fPS_SPARSE_TABLE\x10\x00\x12\x12\n\x0ePS_DENSE_TABLE\x10\x01*\xbd\x02\n\x07PsCmdID\x12\x17\n\x13PS_PULL_DENSE_TABLE\x10\x00\x12\x17\n\x13PS_PUSH_DENSE_TABLE\x10\x01\x12\x18\n\x14PS_PULL_SPARSE_TABLE\x10\x02\x12\x18\n\x14PS_PUSH_SPARSE_TABLE\x10\x03\x12\x13\n\x0fPS_SHRINK_TABLE\x10\x04\x12\x15\n\x11PS_SAVE_ONE_TABLE\x10\x05\x12\x15\n\x11PS_SAVE_ALL_TABLE\x10\x06\x12\x15\n\x11PS_LOAD_ONE_TABLE\x10\x07\x12\x15\n\x11PS_LOAD_ALL_TABLE\x10\x08\x12\x16\n\x12PS_CLEAR_ONE_TABLE\x10\t\x12\x16\n\x12PS_CLEAR_ALL_TABLE\x10\n\x12\x17\n\x13PS_PUSH_DENSE_PARAM\x10\x0b\x12\x12\n\x0ePS_STOP_SERVER\x10\x0c\x32K\n\tPsService\x12>\n\x07service\x12\x18.paddle.PsRequestMessage\x1a\x19.paddle.PsResponseMessageB\x03\x80\x01\x01'
)) ))
_sym_db.RegisterFileDescriptor(DESCRIPTOR) _sym_db.RegisterFileDescriptor(DESCRIPTOR)
...@@ -47,8 +49,8 @@ _TABLETYPE = _descriptor.EnumDescriptor( ...@@ -47,8 +49,8 @@ _TABLETYPE = _descriptor.EnumDescriptor(
], ],
containing_type=None, containing_type=None,
options=None, options=None,
serialized_start=3286, serialized_start=3489,
serialized_end=3338, ) serialized_end=3541, )
_sym_db.RegisterEnumDescriptor(_TABLETYPE) _sym_db.RegisterEnumDescriptor(_TABLETYPE)
TableType = enum_type_wrapper.EnumTypeWrapper(_TABLETYPE) TableType = enum_type_wrapper.EnumTypeWrapper(_TABLETYPE)
...@@ -132,8 +134,8 @@ _PSCMDID = _descriptor.EnumDescriptor( ...@@ -132,8 +134,8 @@ _PSCMDID = _descriptor.EnumDescriptor(
], ],
containing_type=None, containing_type=None,
options=None, options=None,
serialized_start=3341, serialized_start=3544,
serialized_end=3658, ) serialized_end=3861, )
_sym_db.RegisterEnumDescriptor(_PSCMDID) _sym_db.RegisterEnumDescriptor(_PSCMDID)
PsCmdID = enum_type_wrapper.EnumTypeWrapper(_PSCMDID) PsCmdID = enum_type_wrapper.EnumTypeWrapper(_PSCMDID)
...@@ -166,8 +168,8 @@ _FSCLIENTPARAMETER_FSAPITYPE = _descriptor.EnumDescriptor( ...@@ -166,8 +168,8 @@ _FSCLIENTPARAMETER_FSAPITYPE = _descriptor.EnumDescriptor(
], ],
containing_type=None, containing_type=None,
options=None, options=None,
serialized_start=3254, serialized_start=3457,
serialized_end=3284, ) serialized_end=3487, )
_sym_db.RegisterEnumDescriptor(_FSCLIENTPARAMETER_FSAPITYPE) _sym_db.RegisterEnumDescriptor(_FSCLIENTPARAMETER_FSAPITYPE)
_PSPARAMETER = _descriptor.Descriptor( _PSPARAMETER = _descriptor.Descriptor(
...@@ -493,6 +495,22 @@ _DOWNPOURTRAINERPARAMETER = _descriptor.Descriptor( ...@@ -493,6 +495,22 @@ _DOWNPOURTRAINERPARAMETER = _descriptor.Descriptor(
is_extension=False, is_extension=False,
extension_scope=None, extension_scope=None,
options=None), options=None),
_descriptor.FieldDescriptor(
name='program_config',
full_name='paddle.DownpourTrainerParameter.program_config',
index=5,
number=6,
type=11,
cpp_type=10,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
], ],
extensions=[], extensions=[],
nested_types=[], nested_types=[],
...@@ -503,7 +521,106 @@ _DOWNPOURTRAINERPARAMETER = _descriptor.Descriptor( ...@@ -503,7 +521,106 @@ _DOWNPOURTRAINERPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=557, serialized_start=557,
serialized_end=763, ) serialized_end=810, )
_PROGRAMCONFIG = _descriptor.Descriptor(
name='ProgramConfig',
full_name='paddle.ProgramConfig',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='program_id',
full_name='paddle.ProgramConfig.program_id',
index=0,
number=1,
type=9,
cpp_type=9,
label=2,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='push_sparse_table_id',
full_name='paddle.ProgramConfig.push_sparse_table_id',
index=1,
number=2,
type=5,
cpp_type=1,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='push_dense_table_id',
full_name='paddle.ProgramConfig.push_dense_table_id',
index=2,
number=3,
type=5,
cpp_type=1,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='pull_sparse_table_id',
full_name='paddle.ProgramConfig.pull_sparse_table_id',
index=3,
number=4,
type=5,
cpp_type=1,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='pull_dense_table_id',
full_name='paddle.ProgramConfig.pull_dense_table_id',
index=4,
number=5,
type=5,
cpp_type=1,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[],
enum_types=[],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=813,
serialized_end=966, )
_DENSETABLEPARAMETER = _descriptor.Descriptor( _DENSETABLEPARAMETER = _descriptor.Descriptor(
name='DenseTableParameter', name='DenseTableParameter',
...@@ -585,8 +702,8 @@ _DENSETABLEPARAMETER = _descriptor.Descriptor( ...@@ -585,8 +702,8 @@ _DENSETABLEPARAMETER = _descriptor.Descriptor(
syntax='proto2', syntax='proto2',
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=765, serialized_start=968,
serialized_end=888, ) serialized_end=1091, )
_SPARSETABLEPARAMETER = _descriptor.Descriptor( _SPARSETABLEPARAMETER = _descriptor.Descriptor(
name='SparseTableParameter', name='SparseTableParameter',
...@@ -684,8 +801,8 @@ _SPARSETABLEPARAMETER = _descriptor.Descriptor( ...@@ -684,8 +801,8 @@ _SPARSETABLEPARAMETER = _descriptor.Descriptor(
syntax='proto2', syntax='proto2',
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=890, serialized_start=1093,
serialized_end=1012, ) serialized_end=1215, )
_DOWNPOURSERVERPARAMETER = _descriptor.Descriptor( _DOWNPOURSERVERPARAMETER = _descriptor.Descriptor(
name='DownpourServerParameter', name='DownpourServerParameter',
...@@ -735,8 +852,8 @@ _DOWNPOURSERVERPARAMETER = _descriptor.Descriptor( ...@@ -735,8 +852,8 @@ _DOWNPOURSERVERPARAMETER = _descriptor.Descriptor(
syntax='proto2', syntax='proto2',
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=1015, serialized_start=1218,
serialized_end=1149, ) serialized_end=1352, )
_SERVERSERVICEPARAMETER = _descriptor.Descriptor( _SERVERSERVICEPARAMETER = _descriptor.Descriptor(
name='ServerServiceParameter', name='ServerServiceParameter',
...@@ -834,8 +951,8 @@ _SERVERSERVICEPARAMETER = _descriptor.Descriptor( ...@@ -834,8 +951,8 @@ _SERVERSERVICEPARAMETER = _descriptor.Descriptor(
syntax='proto2', syntax='proto2',
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=1152, serialized_start=1355,
serialized_end=1367, ) serialized_end=1570, )
_TABLEPARAMETER = _descriptor.Descriptor( _TABLEPARAMETER = _descriptor.Descriptor(
name='TableParameter', name='TableParameter',
...@@ -949,8 +1066,8 @@ _TABLEPARAMETER = _descriptor.Descriptor( ...@@ -949,8 +1066,8 @@ _TABLEPARAMETER = _descriptor.Descriptor(
syntax='proto2', syntax='proto2',
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=1370, serialized_start=1573,
serialized_end=1561, ) serialized_end=1764, )
_TABLEACCESSORPARAMETER = _descriptor.Descriptor( _TABLEACCESSORPARAMETER = _descriptor.Descriptor(
name='TableAccessorParameter', name='TableAccessorParameter',
...@@ -1096,8 +1213,8 @@ _TABLEACCESSORPARAMETER = _descriptor.Descriptor( ...@@ -1096,8 +1213,8 @@ _TABLEACCESSORPARAMETER = _descriptor.Descriptor(
syntax='proto2', syntax='proto2',
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=1564, serialized_start=1767,
serialized_end=1933, ) serialized_end=2136, )
_DOWNPOURTABLEACCESSORPARAMETER = _descriptor.Descriptor( _DOWNPOURTABLEACCESSORPARAMETER = _descriptor.Descriptor(
name='DownpourTableAccessorParameter', name='DownpourTableAccessorParameter',
...@@ -1227,8 +1344,8 @@ _DOWNPOURTABLEACCESSORPARAMETER = _descriptor.Descriptor( ...@@ -1227,8 +1344,8 @@ _DOWNPOURTABLEACCESSORPARAMETER = _descriptor.Descriptor(
syntax='proto2', syntax='proto2',
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=1936, serialized_start=2139,
serialized_end=2142, ) serialized_end=2345, )
_TABLEACCESSORSAVEPARAMETER = _descriptor.Descriptor( _TABLEACCESSORSAVEPARAMETER = _descriptor.Descriptor(
name='TableAccessorSaveParameter', name='TableAccessorSaveParameter',
...@@ -1294,8 +1411,8 @@ _TABLEACCESSORSAVEPARAMETER = _descriptor.Descriptor( ...@@ -1294,8 +1411,8 @@ _TABLEACCESSORSAVEPARAMETER = _descriptor.Descriptor(
syntax='proto2', syntax='proto2',
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=2144, serialized_start=2347,
serialized_end=2227, ) serialized_end=2430, )
_PSREQUESTMESSAGE = _descriptor.Descriptor( _PSREQUESTMESSAGE = _descriptor.Descriptor(
name='PsRequestMessage', name='PsRequestMessage',
...@@ -1393,8 +1510,8 @@ _PSREQUESTMESSAGE = _descriptor.Descriptor( ...@@ -1393,8 +1510,8 @@ _PSREQUESTMESSAGE = _descriptor.Descriptor(
syntax='proto2', syntax='proto2',
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=2229, serialized_start=2432,
serialized_end=2330, ) serialized_end=2533, )
_SPARSESGDRULEPARAMETER = _descriptor.Descriptor( _SPARSESGDRULEPARAMETER = _descriptor.Descriptor(
name='SparseSGDRuleParameter', name='SparseSGDRuleParameter',
...@@ -1476,8 +1593,8 @@ _SPARSESGDRULEPARAMETER = _descriptor.Descriptor( ...@@ -1476,8 +1593,8 @@ _SPARSESGDRULEPARAMETER = _descriptor.Descriptor(
syntax='proto2', syntax='proto2',
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=2332, serialized_start=2535,
serialized_end=2451, ) serialized_end=2654, )
_DENSESGDRULEPARAMETER = _descriptor.Descriptor( _DENSESGDRULEPARAMETER = _descriptor.Descriptor(
name='DenseSGDRuleParameter', name='DenseSGDRuleParameter',
...@@ -1575,8 +1692,8 @@ _DENSESGDRULEPARAMETER = _descriptor.Descriptor( ...@@ -1575,8 +1692,8 @@ _DENSESGDRULEPARAMETER = _descriptor.Descriptor(
syntax='proto2', syntax='proto2',
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=2454, serialized_start=2657,
serialized_end=2679, ) serialized_end=2882, )
_ADAMSGDPARAMETER = _descriptor.Descriptor( _ADAMSGDPARAMETER = _descriptor.Descriptor(
name='AdamSGDParameter', name='AdamSGDParameter',
...@@ -1674,8 +1791,8 @@ _ADAMSGDPARAMETER = _descriptor.Descriptor( ...@@ -1674,8 +1791,8 @@ _ADAMSGDPARAMETER = _descriptor.Descriptor(
syntax='proto2', syntax='proto2',
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=2682, serialized_start=2885,
serialized_end=2816, ) serialized_end=3019, )
_NAIVESGDPARAMETER = _descriptor.Descriptor( _NAIVESGDPARAMETER = _descriptor.Descriptor(
name='NaiveSGDParameter', name='NaiveSGDParameter',
...@@ -1725,8 +1842,8 @@ _NAIVESGDPARAMETER = _descriptor.Descriptor( ...@@ -1725,8 +1842,8 @@ _NAIVESGDPARAMETER = _descriptor.Descriptor(
syntax='proto2', syntax='proto2',
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=2818, serialized_start=3021,
serialized_end=2884, ) serialized_end=3087, )
_SUMMARYSGDPARAMETER = _descriptor.Descriptor( _SUMMARYSGDPARAMETER = _descriptor.Descriptor(
name='SummarySGDParameter', name='SummarySGDParameter',
...@@ -1760,8 +1877,8 @@ _SUMMARYSGDPARAMETER = _descriptor.Descriptor( ...@@ -1760,8 +1877,8 @@ _SUMMARYSGDPARAMETER = _descriptor.Descriptor(
syntax='proto2', syntax='proto2',
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=2886, serialized_start=3089,
serialized_end=2945, ) serialized_end=3148, )
_MOVINGAVERAGERULEPARAMETER = _descriptor.Descriptor( _MOVINGAVERAGERULEPARAMETER = _descriptor.Descriptor(
name='MovingAverageRuleParameter', name='MovingAverageRuleParameter',
...@@ -1795,8 +1912,8 @@ _MOVINGAVERAGERULEPARAMETER = _descriptor.Descriptor( ...@@ -1795,8 +1912,8 @@ _MOVINGAVERAGERULEPARAMETER = _descriptor.Descriptor(
syntax='proto2', syntax='proto2',
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=2947, serialized_start=3150,
serialized_end=2993, ) serialized_end=3196, )
_PSRESPONSEMESSAGE = _descriptor.Descriptor( _PSRESPONSEMESSAGE = _descriptor.Descriptor(
name='PsResponseMessage', name='PsResponseMessage',
...@@ -1862,8 +1979,8 @@ _PSRESPONSEMESSAGE = _descriptor.Descriptor( ...@@ -1862,8 +1979,8 @@ _PSRESPONSEMESSAGE = _descriptor.Descriptor(
syntax='proto2', syntax='proto2',
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=2995, serialized_start=3198,
serialized_end=3068, ) serialized_end=3271, )
_FSCLIENTPARAMETER = _descriptor.Descriptor( _FSCLIENTPARAMETER = _descriptor.Descriptor(
name='FsClientParameter', name='FsClientParameter',
...@@ -1993,8 +2110,8 @@ _FSCLIENTPARAMETER = _descriptor.Descriptor( ...@@ -1993,8 +2110,8 @@ _FSCLIENTPARAMETER = _descriptor.Descriptor(
syntax='proto2', syntax='proto2',
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=3071, serialized_start=3274,
serialized_end=3284, ) serialized_end=3487, )
_PSPARAMETER.fields_by_name['worker_param'].message_type = _WORKERPARAMETER _PSPARAMETER.fields_by_name['worker_param'].message_type = _WORKERPARAMETER
_PSPARAMETER.fields_by_name['server_param'].message_type = _SERVERPARAMETER _PSPARAMETER.fields_by_name['server_param'].message_type = _SERVERPARAMETER
...@@ -2011,6 +2128,8 @@ _DOWNPOURTRAINERPARAMETER.fields_by_name[ ...@@ -2011,6 +2128,8 @@ _DOWNPOURTRAINERPARAMETER.fields_by_name[
'dense_table'].message_type = _DENSETABLEPARAMETER 'dense_table'].message_type = _DENSETABLEPARAMETER
_DOWNPOURTRAINERPARAMETER.fields_by_name[ _DOWNPOURTRAINERPARAMETER.fields_by_name[
'sparse_table'].message_type = _SPARSETABLEPARAMETER 'sparse_table'].message_type = _SPARSETABLEPARAMETER
_DOWNPOURTRAINERPARAMETER.fields_by_name[
'program_config'].message_type = _PROGRAMCONFIG
_DOWNPOURSERVERPARAMETER.fields_by_name[ _DOWNPOURSERVERPARAMETER.fields_by_name[
'downpour_table_param'].message_type = _TABLEPARAMETER 'downpour_table_param'].message_type = _TABLEPARAMETER
_DOWNPOURSERVERPARAMETER.fields_by_name[ _DOWNPOURSERVERPARAMETER.fields_by_name[
...@@ -2042,6 +2161,7 @@ DESCRIPTOR.message_types_by_name[ ...@@ -2042,6 +2161,7 @@ DESCRIPTOR.message_types_by_name[
'DownpourWorkerParameter'] = _DOWNPOURWORKERPARAMETER 'DownpourWorkerParameter'] = _DOWNPOURWORKERPARAMETER
DESCRIPTOR.message_types_by_name[ DESCRIPTOR.message_types_by_name[
'DownpourTrainerParameter'] = _DOWNPOURTRAINERPARAMETER 'DownpourTrainerParameter'] = _DOWNPOURTRAINERPARAMETER
DESCRIPTOR.message_types_by_name['ProgramConfig'] = _PROGRAMCONFIG
DESCRIPTOR.message_types_by_name['DenseTableParameter'] = _DENSETABLEPARAMETER DESCRIPTOR.message_types_by_name['DenseTableParameter'] = _DENSETABLEPARAMETER
DESCRIPTOR.message_types_by_name['SparseTableParameter'] = _SPARSETABLEPARAMETER DESCRIPTOR.message_types_by_name['SparseTableParameter'] = _SPARSETABLEPARAMETER
DESCRIPTOR.message_types_by_name[ DESCRIPTOR.message_types_by_name[
...@@ -2120,6 +2240,16 @@ DownpourTrainerParameter = _reflection.GeneratedProtocolMessageType( ...@@ -2120,6 +2240,16 @@ DownpourTrainerParameter = _reflection.GeneratedProtocolMessageType(
)) ))
_sym_db.RegisterMessage(DownpourTrainerParameter) _sym_db.RegisterMessage(DownpourTrainerParameter)
ProgramConfig = _reflection.GeneratedProtocolMessageType(
'ProgramConfig',
(_message.Message, ),
dict(
DESCRIPTOR=_PROGRAMCONFIG,
__module__='ps_pb2'
# @@protoc_insertion_point(class_scope:paddle.ProgramConfig)
))
_sym_db.RegisterMessage(ProgramConfig)
DenseTableParameter = _reflection.GeneratedProtocolMessageType( DenseTableParameter = _reflection.GeneratedProtocolMessageType(
'DenseTableParameter', 'DenseTableParameter',
(_message.Message, ), (_message.Message, ),
......
...@@ -23,6 +23,7 @@ from .framework import Program, default_main_program, Variable ...@@ -23,6 +23,7 @@ from .framework import Program, default_main_program, Variable
from . import core from . import core
from . import compiler from . import compiler
from .. import compat as cpt from .. import compat as cpt
from .trainer_factory import TrainerFactory
__all__ = ['Executor', 'global_scope', 'scope_guard'] __all__ = ['Executor', 'global_scope', 'scope_guard']
...@@ -610,3 +611,209 @@ class Executor(object): ...@@ -610,3 +611,209 @@ class Executor(object):
def _run_inference(self, exe, feed): def _run_inference(self, exe, feed):
return exe.run(feed) return exe.run(feed)
def _dump_debug_info(self, program=None, trainer=None):
with open(str(id(program)) + "_train_desc.prototxt", "w") as fout:
fout.write(trainer._desc())
if program._fleet_opt:
with open("fleet_desc.prototxt", "w") as fout:
fout.write(str(program._fleet_opt["fleet_desc"]))
def _prepare_trainer(self,
program=None,
dataset=None,
scope=None,
thread=0,
debug=False,
fetch_list=None,
fetch_info=None,
print_period=100):
if scope is None:
scope = global_scope()
if fetch_list is None:
fetch_list = []
if fetch_info is None:
fetch_info = []
assert len(fetch_list) == len(fetch_info)
compiled = isinstance(program, compiler.CompiledProgram)
if not compiled:
trainer = TrainerFactory()._create_trainer(program._fleet_opt)
trainer._set_program(program)
else:
trainer = TrainerFactory()._create_trainer(
program.program._fleet_opt)
trainer._set_program(program.program)
if thread <= 0:
if dataset.thread_num <= 0:
raise RuntimeError(
"You should set thread num first, either in Dataset"
"or in Executor.train_from_dataset")
else:
trainer._set_thread(dataset.thread_num)
else:
trainer._set_thread(thread)
trainer._set_debug(debug)
trainer._set_fetch_var_and_info(fetch_list, fetch_info, print_period)
return scope, trainer
def infer_from_dataset(self,
program=None,
dataset=None,
scope=None,
thread=0,
debug=False,
fetch_list=None,
fetch_info=None,
print_period=100):
"""
The document of infer_from_dataset is almost the same as
train_from_dataset, except that in distributed training,
push gradients will be disabled in infer_from_dataset.
infer_from_dataset() can be used for evaluation in multi-thread
very easily.
Args:
program(Program|CompiledProgram): the program that needs to be run,
if not provided, then default_main_program (not compiled) will be used.
dataset(paddle.fluid.Dataset): dataset created outside this function,
a user should provide a well-defined dataset before calling this function.
Please check the document of Dataset if needed. default is None
scope(Scope): the scope used to run this program, you can switch it to different scope
for each run. default is global_scope
thread(int): number of thread a user wants to run in this function. The actual number
of thread will be min(Dataset.thread_num, thread) if thread > 0, default is 0
debug(bool): whether a user wants to run infer_from_dataset, default is False
fetch_list(Variable List): fetch variable list, each variable
will be printed during training, default is None
fetch_info(String List): print information for each variable, default is None
print_period(int): the number of mini-batches for each print, default is 100
Returns:
None
Examples:
.. code-block:: python
import paddle.fluid as fluid
place = fluid.CPUPlace()
exe = fluid.Executor(place)
x = fluid.layers.data(name="x", type="int64")
y = fluid.layers.data(name="y", type="int64")
dataset = fluid.DatasetFactory().create_dataset()
dataset.set_use_var([x, y])
filelist = ["dataA.txt", "dataB.txt"]
dataset.set_filelist(filelist)
exe.run(fluid.default_startup_program())
exe.infer_from_dataset(program=fluid.default_main_program(),
dataset=dataset)
"""
if dataset == None:
raise RuntimeError("dataset is needed and should be initialized")
if self.place == paddle.fluid.CUDAPlace():
raise RuntimeError("infer_from_dataset is verified on CPUPlace"
"We will open CUDAPlace in the future")
scope, trainer = self._prepare_trainer(
program=program,
dataset=dataset,
scope=scope,
thread=thread,
debug=debug,
fetch_list=fetch_list,
fetch_info=fetch_info,
print_period=print_period)
trainer._set_infer(True)
trainer._gen_trainer_desc()
dataset._prepare_to_run()
if debug:
self._dump_debug_info(program=program, trainer=trainer)
self._default_executor.run_from_dataset(program.desc, scope,
dataset.dataset,
trainer._desc())
return None
def train_from_dataset(self,
program=None,
dataset=None,
scope=None,
thread=0,
debug=False,
fetch_list=None,
fetch_info=None,
print_period=100):
"""
Train from a pre-defined Dataset. Dataset is defined in paddle.fluid.dataset.
Given a program, either a program or compiled program, train_from_dataset will
consume all data samples in dataset. Input scope can be given by users. By default,
scope is global_scope(). The total number of thread run in training is `thread`.
Thread number used in training will be minimum value of threadnum in Dataset and
the value of thread in this interface. Debug can be set so that executor will display
Run-Time for all operators and the throughputs of current training task.
Note: train_from_dataset will destroy all resources created within executor for each run.
Args:
program(Program|CompiledProgram): the program that needs to be run,
if not provided, then default_main_program (not compiled) will be used.
dataset(paddle.fluid.Dataset): dataset created outside this function,
a user should provide a well-defined dataset before calling this function.
Please check the document of Dataset if needed.
scope(Scope): the scope used to run this program, you can switch it to different scope
for each run. default is global_scope
thread(int): number of thread a user wants to run in this function. The actual number
of thread will be min(Dataset.thread_num, thread)
debug(bool): whether a user wants to run train_from_dataset
fetch_list(Variable List): fetch variable list, each variable
will be printed during training
fetch_info(String List): print information for each variable
print_period(int): the number of mini-batches for each print
Returns:
None
Examples:
.. code-block:: python
import paddle.fluid as fluid
place = fluid.CPUPlace()
exe = fluid.Executor(place)
x = fluid.layers.data(name="x", type="int64")
y = fluid.layers.data(name="y", type="int64")
dataset = fluid.DatasetFactory().create_dataset()
dataset.set_use_var([x, y])
dataset.set_thread(2)
filelist = ["dataA.txt", "dataB.txt"]
dataset.set_filelist(filelist)
exe.run(fluid.default_startup_program())
exe.train_from_dataset(program=fluid.default_main_program(),
dataset=dataset)
"""
if dataset == None:
raise RuntimeError("dataset is need and should be initialized")
if self.place == paddle.fluid.CUDAPlace():
raise RuntimeError("train_from_dataset is verified on CPUPlace"
"We will open CUDAPlace in the future")
scope, trainer = self._prepare_trainer(
program=program,
dataset=dataset,
scope=scope,
thread=thread,
debug=debug,
fetch_list=fetch_list,
fetch_info=fetch_info,
print_period=print_period)
trainer._gen_trainer_desc()
dataset._prepare_to_run()
if debug:
self._dump_debug_info(program=program, trainer=trainer)
self._default_executor.run_from_dataset(program.desc, scope,
dataset.dataset,
trainer._desc())
return None
...@@ -2715,6 +2715,11 @@ class Program(object): ...@@ -2715,6 +2715,11 @@ class Program(object):
# whether the program is optimized by memory_optimize_transpiler # whether the program is optimized by memory_optimize_transpiler
self.__is_mem_optimized = False self.__is_mem_optimized = False
# if this program has been optimized by distributed optimizer
# fleet_opt will be given a value
self._fleet_opt = None
self._program_config = None
@property @property
def _is_mem_optimized(self): def _is_mem_optimized(self):
# if the program is optimized, operator input/outputs # if the program is optimized, operator input/outputs
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# incubate directory is mainly for internal use
# after we have tested incubate APIs in industrial application for a period
# we will move stable functions into fluid
__version__ = '0.1.0'
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
__all__ = ['MultiSlotDataGenerator']
class DataGenerator(object):
"""
DataGenerator is a general Base class for user to inherit
A user who wants to define his/her own python processing logic
with paddle.fluid.dataset should inherit this class.
"""
def __init__(self):
self._proto_info = None
self.batch_size_ = 32
def _set_line_limit(self, line_limit):
if not isinstance(line_limit, int):
raise ValueError("line_limit%s must be in int type" %
type(line_limit))
if line_limit < 1:
raise ValueError("line_limit can not less than 1")
self._line_limit = line_limit
def set_batch(self, batch_size):
'''
Set batch size of current DataGenerator
This is necessary only if a user wants to define generator_batch
Example:
.. code-block:: python
import paddle.fluid.incubate.data_generator as dg
class MyData(dg.DataGenerator):
def generate_sample(self, line):
def local_iter():
int_words = [int(x) for x in line.split()]
yield ("words", int_words)
return local_iter
def generate_batch(self, samples):
def local_iter():
for s in samples:
yield ("words", s[1].extend([s[1][0]]))
mydata = MyData()
mydata.set_batch(128)
'''
self.batch_size_ = batch_size
def run_from_memory(self):
'''
This function generator data from memory, it is usually used for
debug and benchmarking
Example:
.. code-block:: python
import paddle.fluid.incubate.data_generator as dg
class MyData(dg.DataGenerator):
def generate_sample(self, line):
def local_iter():
yield ("words", [1, 2, 3, 4])
return local_iter
mydata = MyData()
mydata.run_from_memory()
'''
batch_samples = []
line_iter = self.generate_sample(None)
for user_parsed_line in line_iter():
if user_parsed_line == None:
continue
batch_samples.append(user_parsed_line)
if len(batch_samples) == self.batch_size_:
batch_iter = self.generate_batch(batch_samples)
for sample in batch_iter():
sys.stdout.write(self._gen_str(sample))
batch_samples = []
if len(batch_samples) > 0:
batch_iter = self.generate_batch(batch_samples)
for sample in batch_iter():
sys.stdout.write(self._gen_str(sample))
def run_from_stdin(self):
'''
This function reads the data row from stdin, parses it with the
process function, and further parses the return value of the
process function with the _gen_str function. The parsed data will
be wrote to stdout and the corresponding protofile will be
generated.
Example:
.. code-block:: python
import paddle.fluid.incubate.data_generator as dg
class MyData(dg.DataGenerator):
def generate_sample(self, line):
def local_iter():
int_words = [int(x) for x in line.split()]
yield ("words", [int_words])
return local_iter
mydata = MyData()
mydata.run_from_stdin()
'''
batch_samples = []
for line in sys.stdin:
line_iter = self.generate_sample(line)
for user_parsed_line in line_iter():
if user_parsed_line == None:
continue
batch_samples.append(user_parsed_line)
if len(batch_samples) == self.batch_size_:
batch_iter = self.generate_batch(batch_samples)
for sample in batch_iter():
sys.stdout.write(self._gen_str(sample))
batch_samples = []
if len(batch_samples) > 0:
batch_iter = self.generate_batch(batch_samples)
for sample in batch_iter():
sys.stdout.write(self._gen_str(sample))
def _gen_str(self, line):
'''
Further processing the output of the process() function rewritten by
user, outputting data that can be directly read by the datafeed,and
updating proto_info infomation.
Args:
line(str): the output of the process() function rewritten by user.
Returns:
Return a string data that can be read directly by the datafeed.
'''
raise NotImplementedError(
"pls use MultiSlotDataGenerator or PairWiseDataGenerator")
def generate_sample(self, line):
'''
This function needs to be overridden by the user to process the
original data row into a list or tuple.
Args:
line(str): the original data row
Returns:
Returns the data processed by the user.
The data format is list or tuple:
[(name, [feasign, ...]), ...]
or ((name, [feasign, ...]), ...)
For example:
[("words", [1926, 08, 17]), ("label", [1])]
or (("words", [1926, 08, 17]), ("label", [1]))
Note:
The type of feasigns must be in int or float. Once the float
element appears in the feasign, the type of that slot will be
processed into a float.
Example:
.. code-block:: python
import paddle.fluid.incubate.data_generator as dg
class MyData(dg.DataGenerator):
def generate_sample(self, line):
def local_iter():
int_words = [int(x) for x in line.split()]
yield ("words", [int_words])
return local_iter
'''
raise NotImplementedError(
"Please rewrite this function to return a list or tuple: " +
"[(name, [feasign, ...]), ...] or ((name, [feasign, ...]), ...)")
def generate_batch(self, samples):
'''
This function needs to be overridden by the user to process the
generated samples from generate_sample(self, str) function
It is usually used as batch processing when a user wants to
do preprocessing on a batch of samples, e.g. padding according to
the max length of a sample in the batch
Args:
samples(list tuple): generated sample from generate_sample
Returns:
a python generator, the same format as return value of generate_sample
Example:
.. code-block:: python
import paddle.fluid.incubate.data_generator as dg
class MyData(dg.DataGenerator):
def generate_sample(self, line):
def local_iter():
int_words = [int(x) for x in line.split()]
yield ("words", int_words)
return local_iter
def generate_batch(self, samples):
def local_iter():
for s in samples:
yield ("words", s[1].extend([s[1][0]]))
mydata = MyData()
mydata.set_batch(128)
'''
def local_iter():
for sample in samples:
yield sample
return local_iter
class MultiSlotDataGenerator(DataGenerator):
def _gen_str(self, line):
'''
Further processing the output of the process() function rewritten by
user, outputting data that can be directly read by the MultiSlotDataFeed,
and updating proto_info infomation.
The input line will be in this format:
>>> [(name, [feasign, ...]), ...]
>>> or ((name, [feasign, ...]), ...)
The output will be in this format:
>>> [ids_num id1 id2 ...] ...
The proto_info will be in this format:
>>> [(name, type), ...]
For example, if the input is like this:
>>> [("words", [1926, 08, 17]), ("label", [1])]
>>> or (("words", [1926, 08, 17]), ("label", [1]))
the output will be:
>>> 3 1234 2345 3456 1 1
the proto_info will be:
>>> [("words", "uint64"), ("label", "uint64")]
Args:
line(str): the output of the process() function rewritten by user.
Returns:
Return a string data that can be read directly by the MultiSlotDataFeed.
'''
if not isinstance(line, list) and not isinstance(line, tuple):
raise ValueError(
"the output of process() must be in list or tuple type")
output = ""
if self._proto_info is None:
self._proto_info = []
for item in line:
name, elements = item
if not isinstance(name, str):
raise ValueError("name%s must be in str type" % type(name))
if not isinstance(elements, list):
raise ValueError("elements%s must be in list type" %
type(elements))
if not elements:
raise ValueError(
"the elements of each field can not be empty, you need padding it in process()."
)
self._proto_info.append((name, "uint64"))
if output:
output += " "
output += str(len(elements))
for elem in elements:
if isinstance(elem, float):
self._proto_info[-1] = (name, "float")
elif not isinstance(elem, int) and not isinstance(elem,
long):
raise ValueError(
"the type of element%s must be in int or float" %
type(elem))
output += " " + str(elem)
else:
if len(line) != len(self._proto_info):
raise ValueError(
"the complete field set of two given line are inconsistent.")
for index, item in enumerate(line):
name, elements = item
if not isinstance(name, str):
raise ValueError("name%s must be in str type" % type(name))
if not isinstance(elements, list):
raise ValueError("elements%s must be in list type" %
type(elements))
if not elements:
raise ValueError(
"the elements of each field can not be empty, you need padding it in process()."
)
if name != self._proto_info[index][0]:
raise ValueError(
"the field name of two given line are not match: require<%s>, get<%d>."
% (self._proto_info[index][0], name))
if output:
output += " "
output += str(len(elements))
for elem in elements:
if self._proto_info[index][1] != "float":
if isinstance(elem, float):
self._proto_info[index] = (name, "float")
elif not isinstance(elem, int) and not isinstance(elem,
long):
raise ValueError(
"the type of element%s must be in int or float"
% type(elem))
output += " " + str(elem)
return output + "\n"
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
from __init__ import *
class SyntheticData(MultiSlotDataGenerator):
def generate_sample(self, line):
def data_iter():
for i in range(10000):
yield ("words", [1, 2, 3, 4]), ("label", [0])
return data_iter
sd = SyntheticData()
sd.run_from_memory()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -10,7 +10,5 @@ ...@@ -10,7 +10,5 @@
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License.
# NOTE: Trainer is moved into fluid.contrib.trainer. __version__ = '0.1.0'
__all__ = []
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
class RoleMakerBase(object):
"""
RoleMakerBase is a base class for assigning a role to current process
in distributed training.
A paddle developer can implement RoleMakerBase to design a role maker
for worker or pserver assignment.
"""
def __init__(self):
self.role_maker_name_ = ""
self.trainer_endpoints_ = []
self.pserver_endpoints_ = []
self.role_is_generated_ = False
def _is_worker(self):
"""
return is_worker() of current process
"""
raise NotImplementedError("Please implement this method in child class")
def _is_server(self):
"""
return is_server() of current process
"""
raise NotImplementedError("Please implement this method in child class")
def _get_local_ip(self):
"""
return get local ip
"""
import socket
self.ip_ = socket.gethostbyname(socket.gethostname())
return self.ip_
def _get_trainer_endpoints(self):
"""
return trainer endpoints
"""
return self.trainer_endpoints_
def _get_pserver_endpoints(self):
"""
return pserver endpoints
"""
return self.pserver_endpoints_
def _generate_role(self):
"""
generate_role() should be called to identify current process's role
"""
raise NotImplementedError("Please implement this method in child class")
class MPIRoleMaker(RoleMakerBase):
"""
MPIRoleMaker is a MPI-API based role maker which is a counter-part of K8SRoleMaker
mpi4py will be used if a developer inherits MPIRoleMaker
"""
def __init__(self):
super(MPIRoleMaker, self).__init__()
from mpi4py import MPI
self.comm_ = MPI.COMM_WORLD
self.MPI = MPI
self.ips_ = None
def _get_rank(self):
"""
return rank
"""
self.rank_ = self.comm_.Get_rank()
return self.rank_
def _get_size(self):
"""
return size
"""
self.size_ = self.comm_.Get_size()
return self.size_
def _all_gather(self, obj):
"""
all_gather(obj) will call MPI's allgather function
"""
self._barrier_all()
return self.comm_.allgather(obj)
def _worker_gather(self, obj):
"""
worker_gather(obj) will call MPI's allgather function
"""
if self._is_worker():
self.node_type_comm_.barrier()
return self.node_type_comm_.allgather(obj)
return None
def _barrier_all(self):
"""
barrier_all() will call MPI's barrier_all function
"""
self.comm_.barrier()
def _get_ips(self):
"""
collect current distributed job's ip list
"""
if self.ips_ == None:
self.ips_ = self.comm_.allgather(self._get_local_ip())
return self.ips_
def _finalize(self):
"""
finalize the current MPI instance.
"""
self.comm_.finalize()
class MPISymetricRoleMaker(MPIRoleMaker):
"""
MPISymetricRoleMaker is designed for worker and server assignment
under MPI. Typically, a worker and a server node will be appointed
on each physical node. This role maker can be only used under MPI.
"""
def __init__(self):
super(MPISymetricRoleMaker, self).__init__()
self.node_type_ = None
self.proc_per_node_ = 2
def _check_role_generation(self):
if not self.role_is_generated_:
sys.stderr.write("generate_role() should be called first")
sys.exit(-1)
return False
return True
def _is_first_worker(self):
"""
return whether current process is the first worker assigned by role maker
"""
if self._check_role_generation():
return self._is_worker() and 0 == self._worker_index()
return False
def _is_worker(self):
"""
return whether current process is worker assigned by role maker
"""
if self._check_role_generation():
return self.node_type_ == 1
return False
def _is_server(self):
"""
return whether current process is server assigned by role maker
"""
if self._check_role_generation():
return self.node_type_ == 0
return False
def _worker_num(self):
"""
return the current number of worker
"""
if self._check_role_generation():
if self._is_worker():
return self._get_size() / 2
return 0
def _server_num(self):
"""
return the current number of server
"""
if self._check_role_generation():
if self._is_server():
return self._get_size() / 2
return 0
def _worker_index(self):
"""
return the index of worker
"""
if self._check_role_generation():
return self.rank_ / self.proc_per_node_
return 0
def _server_index(self):
"""
return the index of server
"""
if self._check_role_generation():
return self.rank_ / self.proc_per_node_
return 0
def _barrier_worker(self):
"""
barrier all workers in current distributed job
"""
if self._check_role_generation():
if self._is_worker():
self.node_type_comm_.barrier()
def _barrier_server(self):
"""
barrier all servers in current distributed job
"""
if self._check_role_generation():
if self._is_server():
self.node_type_comm_.barrier()
def _generate_role(self):
"""
generate currently process's role
"""
if not self.role_is_generated_:
# TODO(guru4elephant): only allow to be called once
self.trainer_endpoints_ = self._get_ips()
self.pserver_endpoints_ = self._get_ips()
if 0 == self._get_rank() % self.proc_per_node_ % 2:
self.node_type_ = 0
else:
self.node_type_ = 1
self.node_type_comm_ = self.comm_.Split(self.node_type_)
self.role_is_generated_ = True
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
import sys
import os
from ..base.role_maker import MPISymetricRoleMaker
from .optimizer_factory import *
from google.protobuf import text_format
import paddle.fluid.optimizer as local_optimizer
import paddle.fluid as fluid
class Fleet(object):
"""
Fleet in Python. Fleet is used in distributed training. It is designed as a singlton instance
in c++. A Fleet() object will be initialized automatically when a user import this package as
fleet. The General interface Fleet supports are:
init(): which should be called only once in user's python scripts. init() will initialize
FleetWrapper in CPP, it will also initialize a RoleMaker which is used for identifying
current node's role, e.g. worker, server, etc.
stop(): will be called after a user finishes his/her training task. Fleet instance will be
destroyed when stop() is called.
init_pserver(): will be called by user. When a user knows current process is_worker(), he/she
should call init_pserver() to initialize global information about parameter server
init_worker(): will be called by user. When a user knows current process is_server(), he/she
should call init_worker() to initialize global information about worker and connect
worker with pserver.
get_worker_num(): return the number of current task's worker node
get_server_num(): return the number of current task's pserver node
is_worker(): return whether current process is a worker
is_server(): return thether current process is a server
init_pserver_model(): initialize model parameters in pserver, called from a worker node
save_pserver_model(): save model parameters in pserver, called from a server node
Example:
.. code-block:: python
import paddle.fluid.incubate.fleet.parameter_server as fleet
from my_model import bow_net
model = bow_net()
fleet.init()
sgd_optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.0001)
sgd_optimizer = fleet.DistributedOptimizer(sgd_optimizer)
sgd_optimizer.minimize(model.loss)
exe = paddle.fluid.Executor(paddle.fluid.CPUPlace())
if fleet.is_worker():
exe.run(paddle.fluid.default_startup_program())
fleet.init_worker() # init worker should be called before training
# do other things like training
elif fleet.is_server():
fleet.init_pserver()
fleet.stop()
"""
def __init__(self):
self._opt_info = None # for fleet only
self.role_maker_ = None
self.local_ip_ = 0
self.is_initialized_ = False
def init(self):
# TODO(guru4elephant)
# this is a temporary solution
# we will support more configurable RoleMaker for users in the future
"""
init(): which should be called only once in user's python scripts. init() will initialize
FleetWrapper in CPP, it will also initialize a RoleMaker which is used for identifying
current node's role, e.g. worker, server, etc.
"""
if not self.is_initialized_:
self.role_maker_ = MPISymetricRoleMaker()
self.role_maker_._generate_role()
self._fleet_ptr = fluid.core.Fleet()
self.is_initialized_ = True
def stop(self):
"""
stop(): will be called after a user finishes his/her training task. Fleet instance will be
destroyed when stop() is called.
"""
self.role_maker_._barrier_worker()
if self.role_maker_._is_first_worker():
self._fleet_ptr.stop_server()
self.role_maker_._barrier_worker()
self.role_maker_._barrier_all()
self.role_maker_._finalize()
def init_pserver(self):
"""
init_pserver(): will be called by user. When a user knows current process is_worker(), he/she
should call init_pserver() to initialize global information about parameter server
"""
if self._opt_info:
if "fleet_desc" in self._opt_info:
self._dist_desc_str = text_format.MessageToString(
self._opt_info["fleet_desc"])
self._dist_desc = self._opt_info["fleet_desc"]
else:
print("You should run DistributedOptimizer.minimize() first")
sys.exit(-1)
self._fleet_ptr.init_server(self._dist_desc_str,
self.role_maker_._get_rank())
self.local_ip_ = self._fleet_ptr.run_server()
# barrier_all for init_server
self.role_maker_._barrier_all()
self.all_ips_ = self.role_maker_._all_gather(self.local_ip_)
self._fleet_ptr.gather_servers(self.all_ips_,
self.role_maker_._get_size())
# barrier_all for init_worker, wait all workers start
self.role_maker_._barrier_all()
else:
print("You should run DistributedOptimizer.minimize() first")
sys.exit(-1)
def init_worker(self, programs):
"""
init_worker(): will be called by user. When a user knows current process is_server(), he/she
should call init_worker() to initialize global information about worker and connect
worker with pserver.
Args:
programs(Program|list): a Program or a list of Programs
"""
if not isinstance(programs, list):
programs = [programs]
if self._opt_info:
if "fleet_desc" in self._opt_info:
self._dist_desc_str = text_format.MessageToString(
self._opt_info["fleet_desc"])
self._dist_desc = self._opt_info["fleet_desc"]
else:
print("You should run DistributedOptimizer.minimize() first")
sys.exit(-1)
# barrier_all for init_server, wait for server starts
self.role_maker_._barrier_all()
self.all_ips_ = self.role_maker_._all_gather(self.local_ip_)
self._fleet_ptr.init_worker(self._dist_desc_str, self.all_ips_,
self.role_maker_._get_size(),
self.role_maker_._get_rank())
# barrier_all for init_worker
self.role_maker_._barrier_all()
# prepare for client to client communication
info = self._fleet_ptr.get_clients_info()
all_info = self.role_maker_._worker_gather(info[0])
self._fleet_ptr.gather_clients(all_info)
self._fleet_ptr.create_client2client_connection()
# barrier for init model
self.role_maker_._barrier_worker()
if self.role_maker_._is_first_worker():
tables = self._dist_desc.trainer_param.dense_table
for prog in programs:
prog_id = str(id(prog))
prog_conf = self._opt_info['program_configs'][prog_id]
prog_tables = {}
for key in prog_conf:
if "dense" not in key:
continue
for table_id in prog_conf[key]:
prog_tables[int(table_id)] = 0
for table in tables:
if int(table.table_id) not in prog_tables:
continue
var_name_list = []
for i in range(0, len(table.dense_variable_name)):
var_name_list.append(table.dense_variable_name[i])
self._fleet_ptr.init_model(prog.desc,
int(table.table_id),
var_name_list)
# barrier for init model done
self.role_maker_._barrier_worker()
else:
print("You should run DistributedOptimizer.minimize() first")
sys.exit(-1)
def get_worker_num(self):
"""
return the number of current job's worker num
"""
return self.role_maker_._worker_num()
def get_server_num(self):
"""
return the number of current job's server num
"""
return self.role_maker_._server_num()
def get_worker_index(self):
"""
return the mpi rank of current worker
"""
return self.role_maker_._worker_index()
def is_worker(self):
"""
return whether current node is a worker
"""
return self.role_maker_._is_worker()
def is_server(self):
"""
return whether current node is pserver
"""
return self.role_maker_._is_server()
def init_pserver_model(self):
"""
init pserver model called from pserver
"""
if self.role_maker_._is_first_worker():
self._fleet_ptr.init_model()
self.role_maker_._barrier_worker()
def save_pserver_model(self, save_path):
"""
save pserver model called from a worker
"""
self._fleet_ptr.save_model(save_path)
def _set_opt_info(self, opt_info):
"""
this function saves the result from DistributedOptimizer.minimize()
"""
self._opt_info = opt_info
class DistributedOptimizer(object):
"""
DistributedOptimizer is a wrapper for paddle.fluid.optimizer
A user should pass a paddle.fluid.optimizer to DistributedOptimizer
minimize() function is implemented.
DistributedOptimizer is the starting point for a user who wants to
run distributed training. The optimized information will be stored in
Fleet() instance who holds the global information about current distributed
training.
"""
def __init__(self, optimizer, dist_config={}):
super(DistributedOptimizer, self).__init__()
self._optimizer = optimizer
self._optimizer_name = "Distributed%s" % optimizer.type.capitalize()
if optimizer.type != "adam":
print("Currently, distributed optimizer only supports Adam"
"Will config built-in adam for you."
"We will support more functions in DistributedOptimizer",
sys.stderr)
self._optimizer_name = "DistributedAdam"
self._distributed_optimizer = globals()[self._optimizer_name](optimizer)
def backward(self,
loss,
startup_program=None,
parameter_list=None,
no_grad_set=None,
callbacks=None):
"""
Currently, backward function can not be called through DistributedOptimizer
"""
raise NotImplementedError()
def apply_gradients(self, params_grads):
"""
Currently, apply_gradients function can not be called through DistributedOptimizer
"""
raise NotImplementedError()
def minimize(self,
loss,
startup_program=None,
parameter_list=None,
no_grad_set=None):
"""
minimize a program through loss, loss can be a list in DistributedOptimizer
Args:
loss (Variable|Variable List): loss variable or loss variable list to run optimization.
startup_program (Program): startup_program for initializing parameters
in `parameter_list`.
parameter_list (list): list of Variables to update.
no_grad_set (set|None): set of Variables should be ignored.
Returns:
tuple: (optimize_ops, params_grads) which are, list of operators appended;
and list of (param, grad) Variables pair for optimization.
Note that in parameter server mode, a worker will not get anything about optimize_os
Because optmizer algorithms run on pserver side. We will make this usable in pserver
process, but currently the optimization part is written into Fleet(). A user does not
need to care about how to startup a pserver node.
"""
optimize_ops, param_grads, opt_info = \
self._distributed_optimizer._minimize(
loss,
startup_program,
parameter_list,
no_grad_set)
fleet_instance._set_opt_info(opt_info)
return [optimize_ops, param_grads]
# this is a temporary solution
# TODO(guru4elephant)
# will make this more flexible for more Parameter Server Archs
fleet_instance = Fleet()
init = fleet_instance.init
stop = fleet_instance.stop
init_pserver = fleet_instance.init_pserver
init_worker = fleet_instance.init_worker
is_worker = fleet_instance.is_worker
is_server = fleet_instance.is_server
init_pserver_model = fleet_instance.init_pserver_model
save_pserver_model = fleet_instance.save_pserver_model
worker_num = fleet_instance.get_worker_num
server_num = fleet_instance.get_server_num
worker_index = fleet_instance.get_worker_index
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
import ps_pb2 as pslib
class Server(object):
"""
A Server basic class.
"""
def __init__(self):
pass
class Worker(object):
"""
A Worker basic class.
"""
def __init__(self):
pass
class DownpourServer(Server):
"""
DownpourServer class is used to generate server program_desc
Args:
server: it is pslib.ServerParameter()
Examples:
server = DownpourServer()
"""
def __init__(self):
self.server_ = pslib.ServerParameter()
self.server_.downpour_server_param.service_param.start_server_port = 0
self.server_.downpour_server_param.service_param.server_class = "DownpourBrpcPsServer"
self.server_.downpour_server_param.service_param.client_class = "DownpourBrpcPsClient"
self.server_.downpour_server_param.service_param.service_class = "DownpourPsService"
self.server_.downpour_server_param.service_param.start_server_port = 0
self.server_.downpour_server_param.service_param.server_thread_num = 12
def add_sparse_table(self, table_id, learning_rate, slot_key_vars,
slot_value_var):
"""
Args:
table_id(int): id of sparse params table
learning_rate(float): the learning rate used to update parameters. \
Can be a float value
slot_key_vars(string): slot key id
slot_value_var(string): slot key value after embedding
Returns:
return None
"""
table = self.server_.downpour_server_param.downpour_table_param.add()
table.table_id = table_id
table.table_class = "DownpourSparseTable"
table.type = pslib.PS_SPARSE_TABLE
table.accessor.accessor_class = "DownpourFeatureValueAccessor"
table.accessor.sparse_sgd_param.learning_rate = learning_rate
table.accessor.sparse_sgd_param.initial_g2sum = 3
table.accessor.sparse_sgd_param.initial_range = 1e-4
table.accessor.sparse_sgd_param.weight_bounds.extend([-10, 10])
table.accessor.embedx_dim = 8
table.accessor.embedx_threshold = 5
table.accessor.fea_dim = 11
table.accessor.downpour_accessor_param.nonclk_coeff = 0.1
table.accessor.downpour_accessor_param.click_coeff = 2
table.accessor.downpour_accessor_param.base_threshold = 0.2
table.accessor.downpour_accessor_param.delta_threshold = 0.15
table.accessor.downpour_accessor_param.delta_keep_days = 31
table.accessor.downpour_accessor_param.show_click_decay_rate = 0.999
table.accessor.downpour_accessor_param.delete_threshold = 0.8
def add_dense_table(self, table_id, learning_rate, param_var, grad_var):
"""
Args:
table_id(int): id of sparse params table
learning_rate(float): the learning rate used to update parameters. \
Can be a float value
param_var(list): all dense param. it is a list.
grad_var(list): all dense grad parm it is a list.
Returns:
return None
"""
table = self.server_.downpour_server_param.downpour_table_param.add()
table.table_id = table_id
table.table_class = "DownpourDenseTable"
table.type = pslib.PS_DENSE_TABLE
table.accessor.accessor_class = "DownpourDenseValueAccessor"
table.accessor.dense_sgd_param.name = "adam"
table.accessor.dense_sgd_param.adam.learning_rate = learning_rate
table.accessor.dense_sgd_param.adam.avg_decay_rate = 0.999993
table.accessor.dense_sgd_param.adam.ada_decay_rate = 0.9999
table.accessor.dense_sgd_param.adam.ada_epsilon = 1e-8
table.accessor.dense_sgd_param.adam.mom_decay_rate = 0.99
table.accessor.dense_sgd_param.naive.learning_rate = 0.0002
fea_dim = 0
for param in filter(lambda x: x.name.find("embedding") == -1,
param_var):
fea_dim += reduce(lambda x, y: x * y, param.shape, 1)
table.accessor.fea_dim = fea_dim
def add_data_norm_table(self, table_id, learning_rate, param_var, grad_var):
"""
Args:
table_id(int): id of sparse params table
learning_rate(float): the learning rate used to update parameters. \
Can be a float value
param_var(list): all dense param. it is a list.
grad_var(list): all dense grad parm it is a list.
Returns:
return None
"""
table = self.server_.downpour_server_param.downpour_table_param.add()
table.table_id = table_id
table.table_class = "DownpourDenseTable"
table.type = pslib.PS_DENSE_TABLE
table.accessor.accessor_class = "DownpourDenseValueAccessor"
table.accessor.dense_sgd_param.name = "summary"
table.accessor.dense_sgd_param.summary.summary_decay_rate = 0.999999
fea_dim = 0
for param in filter(lambda x: x.name.find("embedding") == -1,
param_var):
fea_dim += reduce(lambda x, y: x * y, param.shape, 1)
table.accessor.fea_dim = fea_dim
def get_desc(self):
"""
Return downpour server program_desc
"""
return self.server_
class DownpourWorker(Worker):
"""
DownpourWorker class is used to generate worker program_desc
Args:
window (int): push params frequency
worker: it is pslib.DownpourTrainerParameter
Examples:
worker = DownpourWorker(1)
"""
def __init__(self, window):
self.window = window
self.worker_ = pslib.DownpourTrainerParameter()
def add_sparse_table(self, table_id, learning_rate, slot_key_vars,
slot_value_vars):
"""
Args:
table_id(int): id of sparse params table
learning_rate(float): the learning rate used to update parameters. \
Can be a float value
slot_key_vars(string): slot key id
slot_value_var(string): slot key value after embedding
Returns:
return None
"""
table = self.worker_.sparse_table.add()
table.table_id = table_id
table.slot_key.extend([var.name for var in slot_key_vars])
table.slot_value.extend([var.name for var in slot_value_vars])
table.slot_gradient.extend(
[var.name + "@GRAD" for var in slot_value_vars])
def add_dense_table(self, table_id, learning_rate, param_vars, grad_vars):
"""
Args:
table_id(int): id of sparse params table
learning_rate(float): the learning rate used to update parameters. \
Can be a float value
param_var(list): all dense param. it is a list.
grad_var(list): all dense grad parm it is a list.
Returns:
return None
"""
table = self.worker_.dense_table.add()
table.table_id = table_id
table.dense_variable_name.extend(
filter(lambda x: x.find("embedding") == -1,
[p.name for p in param_vars]))
table.dense_gradient_variable_name.extend(
filter(lambda x: x.find("embedding") == -1,
[g.name for g in grad_vars]))
def get_desc(self):
"""
Return downpour worker program_desc
"""
return self.worker_
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = ["DistributedAdam"]
import ps_pb2 as pslib
import paddle.fluid as fluid
from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table
from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table_inputs
from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table_outputs
from google.protobuf import text_format
from .node import DownpourWorker, DownpourServer
class DistributedOptimizerImplBase(object):
def __init__(self, optimizer):
self.optimizer_ = optimizer
self.learning_rate_ = optimizer._learning_rate
self.regularization_ = optimizer.regularization
def minimize(self,
losses,
startup_program=None,
parameter_list=None,
no_grad_set=None):
pass
class DistributedAdam(DistributedOptimizerImplBase):
def __init__(self, optimizer):
# todo(guru4elephant): add more optimizers here as argument
# todo(guru4elephant): make learning_rate as a variable
super(DistributedAdam, self).__init__(optimizer)
self.window_ = 1
self.type = "downpour"
self.data_norm_name = [
".batch_size", ".batch_square_sum", ".batch_sum",
".batch_size@GRAD", ".batch_square_sum@GRAD", ".batch_sum@GRAD"
]
def _minimize(self,
losses,
startup_program=None,
parameter_list=None,
no_grad_set=None):
"""
DownpounSGD is a distributed optimizer so
that user can call minimize to generate backward
operators and optimization operators within minmize function
Args:
loss(Variable): loss variable defined by user
startup_program(Program): startup program that defined by user
parameter_list(str list): parameter names defined by users
no_grad_set(set): a set of variables that is defined by users
so that these variables do not need gradient computation
Returns:
[optimize_ops, grads_and_weights]
"""
if not isinstance(losses, list):
losses = [losses]
table_name = find_distributed_lookup_table(losses[0].block.program)
prefetch_slots = find_distributed_lookup_table_inputs(
losses[0].block.program, table_name)
prefetch_slots_emb = find_distributed_lookup_table_outputs(
losses[0].block.program, table_name)
ps_param = pslib.PSParameter()
server = DownpourServer()
worker = DownpourWorker(self.window_)
sparse_table_index = 0
server.add_sparse_table(sparse_table_index, self.learning_rate_,
prefetch_slots, prefetch_slots_emb)
worker.add_sparse_table(sparse_table_index, self.learning_rate_,
prefetch_slots, prefetch_slots_emb)
dense_table_index = 1
program_configs = {}
param_grads_list = []
for loss_index in range(len(losses)):
#program_config = ps_param.trainer_param.program_config.add()
#program_config.program_id = str(
# id(losses[loss_index].block.program))
program_id = str(id(losses[loss_index].block.program))
program_configs[program_id] = {
"pull_sparse": [sparse_table_index],
"push_sparse": [sparse_table_index]
}
#program_config.pull_sparse_table_id.extend([sparse_table_index])
#program_config.push_sparse_table_id.extend([sparse_table_index])
params_grads = sorted(
fluid.backward.append_backward(losses[loss_index],
parameter_list, no_grad_set),
key=lambda x: x[0].name)
param_grads_list.append(params_grads)
params = []
grads = []
data_norm_params = []
data_norm_grads = []
for i in params_grads:
is_data_norm_data = False
for data_norm_name in self.data_norm_name:
if i[0].name.endswith(data_norm_name):
is_data_norm_data = True
data_norm_params.append(i[0])
if not is_data_norm_data:
params.append(i[0])
for i in params_grads:
is_data_norm_data = False
for data_norm_grad in self.data_norm_name:
if i[0].name.endswith(data_norm_grad):
is_data_norm_data = True
data_norm_grads.append(i[1])
if not is_data_norm_data:
grads.append(i[1])
server.add_dense_table(dense_table_index, self.learning_rate_,
params, grads)
worker.add_dense_table(dense_table_index, self.learning_rate_,
params, grads)
program_configs[program_id]["pull_dense"] = [dense_table_index]
program_configs[program_id]["push_dense"] = [dense_table_index]
#program_config.pull_dense_table_id.extend([dense_table_index])
#program_config.push_dense_table_id.extend([dense_table_index])
if len(data_norm_params) != 0 and len(data_norm_grads) != 0:
dense_table_index += 1
server.add_data_norm_table(dense_table_index,
self.learning_rate_,
data_norm_params, data_norm_grads)
worker.add_dense_table(dense_table_index, self.learning_rate_,
data_norm_params, data_norm_grads)
#program_config.pull_dense_table_id.extend([dense_table_index])
#program_config.push_dense_table_id.extend([dense_table_index])
program_configs[program_id]["pull_dense"].extend(
[dense_table_index])
program_configs[program_id]["push_dense"].extend(
[dense_table_index])
dense_table_index += 1
#program_configs.append(program_config)
ps_param.server_param.CopyFrom(server.get_desc())
ps_param.trainer_param.CopyFrom(worker.get_desc())
#for program_config in program_configs:
# ps_param.trainer_param.program_config.extend([program_config])
# Todo(guru4elephant): figure out how to support more sparse parameters
# currently only support lookup_table
worker_skipped_ops = ["lookup_table", "lookup_table_grad"]
ps_param.trainer_param.skip_op.extend(worker_skipped_ops)
opt_info = {}
opt_info["program_configs"] = program_configs
opt_info["trainer"] = "DistMultiTrainer"
opt_info["device_worker"] = "DownpourSGD"
opt_info["optimizer"] = "DownpourSGD"
opt_info["fleet_desc"] = ps_param
opt_info["worker_skipped_ops"] = worker_skipped_ops
for loss in losses:
loss.block.program._fleet_opt = opt_info
return None, param_grads_list[0], opt_info
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: ps.proto
import sys
_b = sys.version_info[0] < 3 and (lambda x: x) or (lambda x: x.encode('latin1'))
from google.protobuf.internal import enum_type_wrapper
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
from google.protobuf import descriptor_pb2
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor.FileDescriptor(
name='ps.proto',
package='paddle',
syntax='proto2',
serialized_pb=_b(
'\n\x08ps.proto\x12\x06paddle\"\x9e\x02\n\x0bPSParameter\x12\x14\n\x0cworker_class\x18\x01 \x01(\t\x12\x14\n\x0cserver_class\x18\x02 \x01(\t\x12\x16\n\x0einstance_class\x18\x03 \x01(\t\x12-\n\x0cworker_param\x18\x65 \x01(\x0b\x32\x17.paddle.WorkerParameter\x12-\n\x0cserver_param\x18\x66 \x01(\x0b\x32\x17.paddle.ServerParameter\x12\x38\n\rtrainer_param\x18\xad\x02 \x01(\x0b\x32 .paddle.DownpourTrainerParameter\x12\x33\n\x0f\x66s_client_param\x18\xf5\x03 \x01(\x0b\x32\x19.paddle.FsClientParameter\"Q\n\x0fWorkerParameter\x12>\n\x15\x64ownpour_worker_param\x18\x01 \x01(\x0b\x32\x1f.paddle.DownpourWorkerParameter\"Q\n\x0fServerParameter\x12>\n\x15\x64ownpour_server_param\x18\x01 \x01(\x0b\x32\x1f.paddle.DownpourServerParameter\"O\n\x17\x44ownpourWorkerParameter\x12\x34\n\x14\x64ownpour_table_param\x18\x01 \x03(\x0b\x32\x16.paddle.TableParameter\"\xfd\x01\n\x18\x44ownpourTrainerParameter\x12\x30\n\x0b\x64\x65nse_table\x18\x01 \x03(\x0b\x32\x1b.paddle.DenseTableParameter\x12\x32\n\x0csparse_table\x18\x02 \x03(\x0b\x32\x1c.paddle.SparseTableParameter\x12\x1d\n\x15push_sparse_per_batch\x18\x03 \x01(\x05\x12\x1c\n\x14push_dense_per_batch\x18\x04 \x01(\x05\x12\x0f\n\x07skip_op\x18\x05 \x03(\t\x12-\n\x0eprogram_config\x18\x06 \x03(\x0b\x32\x15.paddle.ProgramConfig\"\x99\x01\n\rProgramConfig\x12\x12\n\nprogram_id\x18\x01 \x02(\t\x12\x1c\n\x14push_sparse_table_id\x18\x02 \x03(\x05\x12\x1b\n\x13push_dense_table_id\x18\x03 \x03(\x05\x12\x1c\n\x14pull_sparse_table_id\x18\x04 \x03(\x05\x12\x1b\n\x13pull_dense_table_id\x18\x05 \x03(\x05\"{\n\x13\x44\x65nseTableParameter\x12\x10\n\x08table_id\x18\x01 \x01(\x05\x12\x1b\n\x13\x64\x65nse_variable_name\x18\x02 \x03(\t\x12$\n\x1c\x64\x65nse_gradient_variable_name\x18\x03 \x03(\t\x12\x0f\n\x07\x66\x65\x61_dim\x18\x04 \x01(\x05\"z\n\x14SparseTableParameter\x12\x10\n\x08table_id\x18\x01 \x01(\x05\x12\x13\n\x0b\x66\x65\x61ture_dim\x18\x02 \x01(\x05\x12\x10\n\x08slot_key\x18\x03 \x03(\t\x12\x12\n\nslot_value\x18\x04 \x03(\t\x12\x15\n\rslot_gradient\x18\x05 \x03(\t\"\x86\x01\n\x17\x44ownpourServerParameter\x12\x34\n\x14\x64ownpour_table_param\x18\x01 \x03(\x0b\x32\x16.paddle.TableParameter\x12\x35\n\rservice_param\x18\x02 \x01(\x0b\x32\x1e.paddle.ServerServiceParameter\"\xd7\x01\n\x16ServerServiceParameter\x12*\n\x0cserver_class\x18\x01 \x01(\t:\x14\x44ownpourBrpcPsServer\x12*\n\x0c\x63lient_class\x18\x02 \x01(\t:\x14\x44ownpourBrpcPsClient\x12(\n\rservice_class\x18\x03 \x01(\t:\x11\x44ownpourPsService\x12\x1c\n\x11start_server_port\x18\x04 \x01(\r:\x01\x30\x12\x1d\n\x11server_thread_num\x18\x05 \x01(\r:\x02\x31\x32\"\xbf\x01\n\x0eTableParameter\x12\x10\n\x08table_id\x18\x01 \x01(\x04\x12\x13\n\x0btable_class\x18\x02 \x01(\t\x12\x12\n\nshared_num\x18\x03 \x01(\x04\x12\x30\n\x08\x61\x63\x63\x65ssor\x18\x04 \x01(\x0b\x32\x1e.paddle.TableAccessorParameter\x12\x1f\n\x04type\x18\x05 \x01(\x0e\x32\x11.paddle.TableType\x12\x1f\n\x10\x63ompress_in_save\x18\x06 \x01(\x08:\x05\x66\x61lse\"\xf1\x02\n\x16TableAccessorParameter\x12\x16\n\x0e\x61\x63\x63\x65ssor_class\x18\x01 \x01(\t\x12\x38\n\x10sparse_sgd_param\x18\x02 \x01(\x0b\x32\x1e.paddle.SparseSGDRuleParameter\x12\x36\n\x0f\x64\x65nse_sgd_param\x18\x03 \x01(\x0b\x32\x1d.paddle.DenseSGDRuleParameter\x12\x0f\n\x07\x66\x65\x61_dim\x18\x04 \x01(\r\x12\x12\n\nembedx_dim\x18\x05 \x01(\r\x12\x18\n\x10\x65mbedx_threshold\x18\x06 \x01(\r\x12G\n\x17\x64ownpour_accessor_param\x18\x07 \x01(\x0b\x32&.paddle.DownpourTableAccessorParameter\x12\x45\n\x19table_accessor_save_param\x18\x08 \x03(\x0b\x32\".paddle.TableAccessorSaveParameter\"\xce\x01\n\x1e\x44ownpourTableAccessorParameter\x12\x14\n\x0cnonclk_coeff\x18\x01 \x01(\x02\x12\x13\n\x0b\x63lick_coeff\x18\x02 \x01(\x02\x12\x16\n\x0e\x62\x61se_threshold\x18\x03 \x01(\x02\x12\x17\n\x0f\x64\x65lta_threshold\x18\x04 \x01(\x02\x12\x17\n\x0f\x64\x65lta_keep_days\x18\x05 \x01(\x02\x12\x1d\n\x15show_click_decay_rate\x18\x06 \x01(\x02\x12\x18\n\x10\x64\x65lete_threshold\x18\x07 \x01(\x02\"S\n\x1aTableAccessorSaveParameter\x12\r\n\x05param\x18\x01 \x01(\r\x12\x11\n\tconverter\x18\x02 \x01(\t\x12\x13\n\x0b\x64\x65\x63onverter\x18\x03 \x01(\t\"e\n\x10PsRequestMessage\x12\x0e\n\x06\x63md_id\x18\x01 \x02(\r\x12\x10\n\x08table_id\x18\x02 \x01(\r\x12\x0e\n\x06params\x18\x03 \x03(\x0c\x12\x11\n\tclient_id\x18\x04 \x01(\x05\x12\x0c\n\x04\x64\x61ta\x18\x05 \x01(\x0c\"w\n\x16SparseSGDRuleParameter\x12\x15\n\rlearning_rate\x18\x01 \x01(\x01\x12\x15\n\rinitial_g2sum\x18\x02 \x01(\x01\x12\x18\n\rinitial_range\x18\x03 \x01(\x01:\x01\x30\x12\x15\n\rweight_bounds\x18\x04 \x03(\x02\"\xe1\x01\n\x15\x44\x65nseSGDRuleParameter\x12\x0c\n\x04name\x18\x01 \x01(\t\x12&\n\x04\x61\x64\x61m\x18\x02 \x01(\x0b\x32\x18.paddle.AdamSGDParameter\x12(\n\x05naive\x18\x03 \x01(\x0b\x32\x19.paddle.NaiveSGDParameter\x12,\n\x07summary\x18\x04 \x01(\x0b\x32\x1b.paddle.SummarySGDParameter\x12:\n\x0emoving_average\x18\x05 \x01(\x0b\x32\".paddle.MovingAverageRuleParameter\"\x86\x01\n\x10\x41\x64\x61mSGDParameter\x12\x15\n\rlearning_rate\x18\x01 \x01(\x01\x12\x16\n\x0e\x61vg_decay_rate\x18\x02 \x01(\x01\x12\x16\n\x0e\x61\x64\x61_decay_rate\x18\x03 \x01(\x01\x12\x13\n\x0b\x61\x64\x61_epsilon\x18\x04 \x01(\x01\x12\x16\n\x0emom_decay_rate\x18\x05 \x01(\x01\"B\n\x11NaiveSGDParameter\x12\x15\n\rlearning_rate\x18\x01 \x01(\x01\x12\x16\n\x0e\x61vg_decay_rate\x18\x02 \x01(\x01\";\n\x13SummarySGDParameter\x12$\n\x12summary_decay_rate\x18\x01 \x01(\x01:\x08\x30.999999\".\n\x1aMovingAverageRuleParameter\x12\x10\n\x08momentum\x18\x01 \x01(\x01\"I\n\x11PsResponseMessage\x12\x13\n\x08\x65rr_code\x18\x01 \x02(\x05:\x01\x30\x12\x11\n\x07\x65rr_msg\x18\x02 \x02(\t:\x00\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\"\xd5\x01\n\x11\x46sClientParameter\x12:\n\x07\x66s_type\x18\x01 \x01(\x0e\x32#.paddle.FsClientParameter.FsApiType:\x04HDFS\x12\x0b\n\x03uri\x18\x02 \x01(\t\x12\x0c\n\x04user\x18\x03 \x01(\t\x12\x0e\n\x06passwd\x18\x04 \x01(\t\x12\x13\n\x0b\x62uffer_size\x18\x05 \x01(\x05\x12\x12\n\nhadoop_bin\x18\x33 \x01(\t\x12\x10\n\x08\x61\x66s_conf\x18\x65 \x01(\t\"\x1e\n\tFsApiType\x12\x08\n\x04HDFS\x10\x00\x12\x07\n\x03\x41\x46S\x10\x01*4\n\tTableType\x12\x13\n\x0fPS_SPARSE_TABLE\x10\x00\x12\x12\n\x0ePS_DENSE_TABLE\x10\x01*\xbd\x02\n\x07PsCmdID\x12\x17\n\x13PS_PULL_DENSE_TABLE\x10\x00\x12\x17\n\x13PS_PUSH_DENSE_TABLE\x10\x01\x12\x18\n\x14PS_PULL_SPARSE_TABLE\x10\x02\x12\x18\n\x14PS_PUSH_SPARSE_TABLE\x10\x03\x12\x13\n\x0fPS_SHRINK_TABLE\x10\x04\x12\x15\n\x11PS_SAVE_ONE_TABLE\x10\x05\x12\x15\n\x11PS_SAVE_ALL_TABLE\x10\x06\x12\x15\n\x11PS_LOAD_ONE_TABLE\x10\x07\x12\x15\n\x11PS_LOAD_ALL_TABLE\x10\x08\x12\x16\n\x12PS_CLEAR_ONE_TABLE\x10\t\x12\x16\n\x12PS_CLEAR_ALL_TABLE\x10\n\x12\x17\n\x13PS_PUSH_DENSE_PARAM\x10\x0b\x12\x12\n\x0ePS_STOP_SERVER\x10\x0c\x32K\n\tPsService\x12>\n\x07service\x12\x18.paddle.PsRequestMessage\x1a\x19.paddle.PsResponseMessageB\x03\x80\x01\x01'
))
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
_TABLETYPE = _descriptor.EnumDescriptor(
name='TableType',
full_name='paddle.TableType',
filename=None,
file=DESCRIPTOR,
values=[
_descriptor.EnumValueDescriptor(
name='PS_SPARSE_TABLE', index=0, number=0, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='PS_DENSE_TABLE', index=1, number=1, options=None, type=None),
],
containing_type=None,
options=None,
serialized_start=3489,
serialized_end=3541, )
_sym_db.RegisterEnumDescriptor(_TABLETYPE)
TableType = enum_type_wrapper.EnumTypeWrapper(_TABLETYPE)
_PSCMDID = _descriptor.EnumDescriptor(
name='PsCmdID',
full_name='paddle.PsCmdID',
filename=None,
file=DESCRIPTOR,
values=[
_descriptor.EnumValueDescriptor(
name='PS_PULL_DENSE_TABLE',
index=0,
number=0,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='PS_PUSH_DENSE_TABLE',
index=1,
number=1,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='PS_PULL_SPARSE_TABLE',
index=2,
number=2,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='PS_PUSH_SPARSE_TABLE',
index=3,
number=3,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='PS_SHRINK_TABLE', index=4, number=4, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='PS_SAVE_ONE_TABLE',
index=5,
number=5,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='PS_SAVE_ALL_TABLE',
index=6,
number=6,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='PS_LOAD_ONE_TABLE',
index=7,
number=7,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='PS_LOAD_ALL_TABLE',
index=8,
number=8,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='PS_CLEAR_ONE_TABLE',
index=9,
number=9,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='PS_CLEAR_ALL_TABLE',
index=10,
number=10,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='PS_PUSH_DENSE_PARAM',
index=11,
number=11,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='PS_STOP_SERVER', index=12, number=12, options=None,
type=None),
],
containing_type=None,
options=None,
serialized_start=3544,
serialized_end=3861, )
_sym_db.RegisterEnumDescriptor(_PSCMDID)
PsCmdID = enum_type_wrapper.EnumTypeWrapper(_PSCMDID)
PS_SPARSE_TABLE = 0
PS_DENSE_TABLE = 1
PS_PULL_DENSE_TABLE = 0
PS_PUSH_DENSE_TABLE = 1
PS_PULL_SPARSE_TABLE = 2
PS_PUSH_SPARSE_TABLE = 3
PS_SHRINK_TABLE = 4
PS_SAVE_ONE_TABLE = 5
PS_SAVE_ALL_TABLE = 6
PS_LOAD_ONE_TABLE = 7
PS_LOAD_ALL_TABLE = 8
PS_CLEAR_ONE_TABLE = 9
PS_CLEAR_ALL_TABLE = 10
PS_PUSH_DENSE_PARAM = 11
PS_STOP_SERVER = 12
_FSCLIENTPARAMETER_FSAPITYPE = _descriptor.EnumDescriptor(
name='FsApiType',
full_name='paddle.FsClientParameter.FsApiType',
filename=None,
file=DESCRIPTOR,
values=[
_descriptor.EnumValueDescriptor(
name='HDFS', index=0, number=0, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='AFS', index=1, number=1, options=None, type=None),
],
containing_type=None,
options=None,
serialized_start=3457,
serialized_end=3487, )
_sym_db.RegisterEnumDescriptor(_FSCLIENTPARAMETER_FSAPITYPE)
_PSPARAMETER = _descriptor.Descriptor(
name='PSParameter',
full_name='paddle.PSParameter',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='worker_class',
full_name='paddle.PSParameter.worker_class',
index=0,
number=1,
type=9,
cpp_type=9,
label=1,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='server_class',
full_name='paddle.PSParameter.server_class',
index=1,
number=2,
type=9,
cpp_type=9,
label=1,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='instance_class',
full_name='paddle.PSParameter.instance_class',
index=2,
number=3,
type=9,
cpp_type=9,
label=1,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='worker_param',
full_name='paddle.PSParameter.worker_param',
index=3,
number=101,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='server_param',
full_name='paddle.PSParameter.server_param',
index=4,
number=102,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='trainer_param',
full_name='paddle.PSParameter.trainer_param',
index=5,
number=301,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='fs_client_param',
full_name='paddle.PSParameter.fs_client_param',
index=6,
number=501,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[],
enum_types=[],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=21,
serialized_end=307, )
_WORKERPARAMETER = _descriptor.Descriptor(
name='WorkerParameter',
full_name='paddle.WorkerParameter',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='downpour_worker_param',
full_name='paddle.WorkerParameter.downpour_worker_param',
index=0,
number=1,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[],
enum_types=[],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=309,
serialized_end=390, )
_SERVERPARAMETER = _descriptor.Descriptor(
name='ServerParameter',
full_name='paddle.ServerParameter',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='downpour_server_param',
full_name='paddle.ServerParameter.downpour_server_param',
index=0,
number=1,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[],
enum_types=[],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=392,
serialized_end=473, )
_DOWNPOURWORKERPARAMETER = _descriptor.Descriptor(
name='DownpourWorkerParameter',
full_name='paddle.DownpourWorkerParameter',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='downpour_table_param',
full_name='paddle.DownpourWorkerParameter.downpour_table_param',
index=0,
number=1,
type=11,
cpp_type=10,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[],
enum_types=[],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=475,
serialized_end=554, )
_DOWNPOURTRAINERPARAMETER = _descriptor.Descriptor(
name='DownpourTrainerParameter',
full_name='paddle.DownpourTrainerParameter',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='dense_table',
full_name='paddle.DownpourTrainerParameter.dense_table',
index=0,
number=1,
type=11,
cpp_type=10,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='sparse_table',
full_name='paddle.DownpourTrainerParameter.sparse_table',
index=1,
number=2,
type=11,
cpp_type=10,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='push_sparse_per_batch',
full_name='paddle.DownpourTrainerParameter.push_sparse_per_batch',
index=2,
number=3,
type=5,
cpp_type=1,
label=1,
has_default_value=False,
default_value=0,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='push_dense_per_batch',
full_name='paddle.DownpourTrainerParameter.push_dense_per_batch',
index=3,
number=4,
type=5,
cpp_type=1,
label=1,
has_default_value=False,
default_value=0,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='skip_op',
full_name='paddle.DownpourTrainerParameter.skip_op',
index=4,
number=5,
type=9,
cpp_type=9,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='program_config',
full_name='paddle.DownpourTrainerParameter.program_config',
index=5,
number=6,
type=11,
cpp_type=10,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[],
enum_types=[],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=557,
serialized_end=810, )
_PROGRAMCONFIG = _descriptor.Descriptor(
name='ProgramConfig',
full_name='paddle.ProgramConfig',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='program_id',
full_name='paddle.ProgramConfig.program_id',
index=0,
number=1,
type=9,
cpp_type=9,
label=2,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='push_sparse_table_id',
full_name='paddle.ProgramConfig.push_sparse_table_id',
index=1,
number=2,
type=5,
cpp_type=1,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='push_dense_table_id',
full_name='paddle.ProgramConfig.push_dense_table_id',
index=2,
number=3,
type=5,
cpp_type=1,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='pull_sparse_table_id',
full_name='paddle.ProgramConfig.pull_sparse_table_id',
index=3,
number=4,
type=5,
cpp_type=1,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='pull_dense_table_id',
full_name='paddle.ProgramConfig.pull_dense_table_id',
index=4,
number=5,
type=5,
cpp_type=1,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[],
enum_types=[],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=813,
serialized_end=966, )
_DENSETABLEPARAMETER = _descriptor.Descriptor(
name='DenseTableParameter',
full_name='paddle.DenseTableParameter',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='table_id',
full_name='paddle.DenseTableParameter.table_id',
index=0,
number=1,
type=5,
cpp_type=1,
label=1,
has_default_value=False,
default_value=0,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='dense_variable_name',
full_name='paddle.DenseTableParameter.dense_variable_name',
index=1,
number=2,
type=9,
cpp_type=9,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='dense_gradient_variable_name',
full_name='paddle.DenseTableParameter.dense_gradient_variable_name',
index=2,
number=3,
type=9,
cpp_type=9,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='fea_dim',
full_name='paddle.DenseTableParameter.fea_dim',
index=3,
number=4,
type=5,
cpp_type=1,
label=1,
has_default_value=False,
default_value=0,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[],
enum_types=[],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=968,
serialized_end=1091, )
_SPARSETABLEPARAMETER = _descriptor.Descriptor(
name='SparseTableParameter',
full_name='paddle.SparseTableParameter',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='table_id',
full_name='paddle.SparseTableParameter.table_id',
index=0,
number=1,
type=5,
cpp_type=1,
label=1,
has_default_value=False,
default_value=0,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='feature_dim',
full_name='paddle.SparseTableParameter.feature_dim',
index=1,
number=2,
type=5,
cpp_type=1,
label=1,
has_default_value=False,
default_value=0,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='slot_key',
full_name='paddle.SparseTableParameter.slot_key',
index=2,
number=3,
type=9,
cpp_type=9,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='slot_value',
full_name='paddle.SparseTableParameter.slot_value',
index=3,
number=4,
type=9,
cpp_type=9,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='slot_gradient',
full_name='paddle.SparseTableParameter.slot_gradient',
index=4,
number=5,
type=9,
cpp_type=9,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[],
enum_types=[],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=1093,
serialized_end=1215, )
_DOWNPOURSERVERPARAMETER = _descriptor.Descriptor(
name='DownpourServerParameter',
full_name='paddle.DownpourServerParameter',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='downpour_table_param',
full_name='paddle.DownpourServerParameter.downpour_table_param',
index=0,
number=1,
type=11,
cpp_type=10,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='service_param',
full_name='paddle.DownpourServerParameter.service_param',
index=1,
number=2,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[],
enum_types=[],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=1218,
serialized_end=1352, )
_SERVERSERVICEPARAMETER = _descriptor.Descriptor(
name='ServerServiceParameter',
full_name='paddle.ServerServiceParameter',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='server_class',
full_name='paddle.ServerServiceParameter.server_class',
index=0,
number=1,
type=9,
cpp_type=9,
label=1,
has_default_value=True,
default_value=_b("DownpourBrpcPsServer").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='client_class',
full_name='paddle.ServerServiceParameter.client_class',
index=1,
number=2,
type=9,
cpp_type=9,
label=1,
has_default_value=True,
default_value=_b("DownpourBrpcPsClient").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='service_class',
full_name='paddle.ServerServiceParameter.service_class',
index=2,
number=3,
type=9,
cpp_type=9,
label=1,
has_default_value=True,
default_value=_b("DownpourPsService").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='start_server_port',
full_name='paddle.ServerServiceParameter.start_server_port',
index=3,
number=4,
type=13,
cpp_type=3,
label=1,
has_default_value=True,
default_value=0,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='server_thread_num',
full_name='paddle.ServerServiceParameter.server_thread_num',
index=4,
number=5,
type=13,
cpp_type=3,
label=1,
has_default_value=True,
default_value=12,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[],
enum_types=[],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=1355,
serialized_end=1570, )
_TABLEPARAMETER = _descriptor.Descriptor(
name='TableParameter',
full_name='paddle.TableParameter',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='table_id',
full_name='paddle.TableParameter.table_id',
index=0,
number=1,
type=4,
cpp_type=4,
label=1,
has_default_value=False,
default_value=0,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='table_class',
full_name='paddle.TableParameter.table_class',
index=1,
number=2,
type=9,
cpp_type=9,
label=1,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='shared_num',
full_name='paddle.TableParameter.shared_num',
index=2,
number=3,
type=4,
cpp_type=4,
label=1,
has_default_value=False,
default_value=0,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='accessor',
full_name='paddle.TableParameter.accessor',
index=3,
number=4,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='type',
full_name='paddle.TableParameter.type',
index=4,
number=5,
type=14,
cpp_type=8,
label=1,
has_default_value=False,
default_value=0,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='compress_in_save',
full_name='paddle.TableParameter.compress_in_save',
index=5,
number=6,
type=8,
cpp_type=7,
label=1,
has_default_value=True,
default_value=False,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[],
enum_types=[],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=1573,
serialized_end=1764, )
_TABLEACCESSORPARAMETER = _descriptor.Descriptor(
name='TableAccessorParameter',
full_name='paddle.TableAccessorParameter',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='accessor_class',
full_name='paddle.TableAccessorParameter.accessor_class',
index=0,
number=1,
type=9,
cpp_type=9,
label=1,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='sparse_sgd_param',
full_name='paddle.TableAccessorParameter.sparse_sgd_param',
index=1,
number=2,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='dense_sgd_param',
full_name='paddle.TableAccessorParameter.dense_sgd_param',
index=2,
number=3,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='fea_dim',
full_name='paddle.TableAccessorParameter.fea_dim',
index=3,
number=4,
type=13,
cpp_type=3,
label=1,
has_default_value=False,
default_value=0,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='embedx_dim',
full_name='paddle.TableAccessorParameter.embedx_dim',
index=4,
number=5,
type=13,
cpp_type=3,
label=1,
has_default_value=False,
default_value=0,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='embedx_threshold',
full_name='paddle.TableAccessorParameter.embedx_threshold',
index=5,
number=6,
type=13,
cpp_type=3,
label=1,
has_default_value=False,
default_value=0,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='downpour_accessor_param',
full_name='paddle.TableAccessorParameter.downpour_accessor_param',
index=6,
number=7,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='table_accessor_save_param',
full_name='paddle.TableAccessorParameter.table_accessor_save_param',
index=7,
number=8,
type=11,
cpp_type=10,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[],
enum_types=[],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=1767,
serialized_end=2136, )
_DOWNPOURTABLEACCESSORPARAMETER = _descriptor.Descriptor(
name='DownpourTableAccessorParameter',
full_name='paddle.DownpourTableAccessorParameter',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='nonclk_coeff',
full_name='paddle.DownpourTableAccessorParameter.nonclk_coeff',
index=0,
number=1,
type=2,
cpp_type=6,
label=1,
has_default_value=False,
default_value=float(0),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='click_coeff',
full_name='paddle.DownpourTableAccessorParameter.click_coeff',
index=1,
number=2,
type=2,
cpp_type=6,
label=1,
has_default_value=False,
default_value=float(0),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='base_threshold',
full_name='paddle.DownpourTableAccessorParameter.base_threshold',
index=2,
number=3,
type=2,
cpp_type=6,
label=1,
has_default_value=False,
default_value=float(0),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='delta_threshold',
full_name='paddle.DownpourTableAccessorParameter.delta_threshold',
index=3,
number=4,
type=2,
cpp_type=6,
label=1,
has_default_value=False,
default_value=float(0),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='delta_keep_days',
full_name='paddle.DownpourTableAccessorParameter.delta_keep_days',
index=4,
number=5,
type=2,
cpp_type=6,
label=1,
has_default_value=False,
default_value=float(0),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='show_click_decay_rate',
full_name='paddle.DownpourTableAccessorParameter.show_click_decay_rate',
index=5,
number=6,
type=2,
cpp_type=6,
label=1,
has_default_value=False,
default_value=float(0),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='delete_threshold',
full_name='paddle.DownpourTableAccessorParameter.delete_threshold',
index=6,
number=7,
type=2,
cpp_type=6,
label=1,
has_default_value=False,
default_value=float(0),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[],
enum_types=[],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=2139,
serialized_end=2345, )
_TABLEACCESSORSAVEPARAMETER = _descriptor.Descriptor(
name='TableAccessorSaveParameter',
full_name='paddle.TableAccessorSaveParameter',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='param',
full_name='paddle.TableAccessorSaveParameter.param',
index=0,
number=1,
type=13,
cpp_type=3,
label=1,
has_default_value=False,
default_value=0,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='converter',
full_name='paddle.TableAccessorSaveParameter.converter',
index=1,
number=2,
type=9,
cpp_type=9,
label=1,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='deconverter',
full_name='paddle.TableAccessorSaveParameter.deconverter',
index=2,
number=3,
type=9,
cpp_type=9,
label=1,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[],
enum_types=[],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=2347,
serialized_end=2430, )
_PSREQUESTMESSAGE = _descriptor.Descriptor(
name='PsRequestMessage',
full_name='paddle.PsRequestMessage',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='cmd_id',
full_name='paddle.PsRequestMessage.cmd_id',
index=0,
number=1,
type=13,
cpp_type=3,
label=2,
has_default_value=False,
default_value=0,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='table_id',
full_name='paddle.PsRequestMessage.table_id',
index=1,
number=2,
type=13,
cpp_type=3,
label=1,
has_default_value=False,
default_value=0,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='params',
full_name='paddle.PsRequestMessage.params',
index=2,
number=3,
type=12,
cpp_type=9,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='client_id',
full_name='paddle.PsRequestMessage.client_id',
index=3,
number=4,
type=5,
cpp_type=1,
label=1,
has_default_value=False,
default_value=0,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='data',
full_name='paddle.PsRequestMessage.data',
index=4,
number=5,
type=12,
cpp_type=9,
label=1,
has_default_value=False,
default_value=_b(""),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[],
enum_types=[],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=2432,
serialized_end=2533, )
_SPARSESGDRULEPARAMETER = _descriptor.Descriptor(
name='SparseSGDRuleParameter',
full_name='paddle.SparseSGDRuleParameter',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='learning_rate',
full_name='paddle.SparseSGDRuleParameter.learning_rate',
index=0,
number=1,
type=1,
cpp_type=5,
label=1,
has_default_value=False,
default_value=float(0),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='initial_g2sum',
full_name='paddle.SparseSGDRuleParameter.initial_g2sum',
index=1,
number=2,
type=1,
cpp_type=5,
label=1,
has_default_value=False,
default_value=float(0),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='initial_range',
full_name='paddle.SparseSGDRuleParameter.initial_range',
index=2,
number=3,
type=1,
cpp_type=5,
label=1,
has_default_value=True,
default_value=float(0),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='weight_bounds',
full_name='paddle.SparseSGDRuleParameter.weight_bounds',
index=3,
number=4,
type=2,
cpp_type=6,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[],
enum_types=[],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=2535,
serialized_end=2654, )
_DENSESGDRULEPARAMETER = _descriptor.Descriptor(
name='DenseSGDRuleParameter',
full_name='paddle.DenseSGDRuleParameter',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='name',
full_name='paddle.DenseSGDRuleParameter.name',
index=0,
number=1,
type=9,
cpp_type=9,
label=1,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='adam',
full_name='paddle.DenseSGDRuleParameter.adam',
index=1,
number=2,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='naive',
full_name='paddle.DenseSGDRuleParameter.naive',
index=2,
number=3,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='summary',
full_name='paddle.DenseSGDRuleParameter.summary',
index=3,
number=4,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='moving_average',
full_name='paddle.DenseSGDRuleParameter.moving_average',
index=4,
number=5,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[],
enum_types=[],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=2657,
serialized_end=2882, )
_ADAMSGDPARAMETER = _descriptor.Descriptor(
name='AdamSGDParameter',
full_name='paddle.AdamSGDParameter',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='learning_rate',
full_name='paddle.AdamSGDParameter.learning_rate',
index=0,
number=1,
type=1,
cpp_type=5,
label=1,
has_default_value=False,
default_value=float(0),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='avg_decay_rate',
full_name='paddle.AdamSGDParameter.avg_decay_rate',
index=1,
number=2,
type=1,
cpp_type=5,
label=1,
has_default_value=False,
default_value=float(0),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='ada_decay_rate',
full_name='paddle.AdamSGDParameter.ada_decay_rate',
index=2,
number=3,
type=1,
cpp_type=5,
label=1,
has_default_value=False,
default_value=float(0),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='ada_epsilon',
full_name='paddle.AdamSGDParameter.ada_epsilon',
index=3,
number=4,
type=1,
cpp_type=5,
label=1,
has_default_value=False,
default_value=float(0),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='mom_decay_rate',
full_name='paddle.AdamSGDParameter.mom_decay_rate',
index=4,
number=5,
type=1,
cpp_type=5,
label=1,
has_default_value=False,
default_value=float(0),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[],
enum_types=[],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=2885,
serialized_end=3019, )
_NAIVESGDPARAMETER = _descriptor.Descriptor(
name='NaiveSGDParameter',
full_name='paddle.NaiveSGDParameter',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='learning_rate',
full_name='paddle.NaiveSGDParameter.learning_rate',
index=0,
number=1,
type=1,
cpp_type=5,
label=1,
has_default_value=False,
default_value=float(0),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='avg_decay_rate',
full_name='paddle.NaiveSGDParameter.avg_decay_rate',
index=1,
number=2,
type=1,
cpp_type=5,
label=1,
has_default_value=False,
default_value=float(0),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[],
enum_types=[],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=3021,
serialized_end=3087, )
_SUMMARYSGDPARAMETER = _descriptor.Descriptor(
name='SummarySGDParameter',
full_name='paddle.SummarySGDParameter',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='summary_decay_rate',
full_name='paddle.SummarySGDParameter.summary_decay_rate',
index=0,
number=1,
type=1,
cpp_type=5,
label=1,
has_default_value=True,
default_value=float(0.999999),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[],
enum_types=[],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=3089,
serialized_end=3148, )
_MOVINGAVERAGERULEPARAMETER = _descriptor.Descriptor(
name='MovingAverageRuleParameter',
full_name='paddle.MovingAverageRuleParameter',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='momentum',
full_name='paddle.MovingAverageRuleParameter.momentum',
index=0,
number=1,
type=1,
cpp_type=5,
label=1,
has_default_value=False,
default_value=float(0),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[],
enum_types=[],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=3150,
serialized_end=3196, )
_PSRESPONSEMESSAGE = _descriptor.Descriptor(
name='PsResponseMessage',
full_name='paddle.PsResponseMessage',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='err_code',
full_name='paddle.PsResponseMessage.err_code',
index=0,
number=1,
type=5,
cpp_type=1,
label=2,
has_default_value=True,
default_value=0,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='err_msg',
full_name='paddle.PsResponseMessage.err_msg',
index=1,
number=2,
type=9,
cpp_type=9,
label=2,
has_default_value=True,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='data',
full_name='paddle.PsResponseMessage.data',
index=2,
number=3,
type=12,
cpp_type=9,
label=1,
has_default_value=False,
default_value=_b(""),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[],
enum_types=[],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=3198,
serialized_end=3271, )
_FSCLIENTPARAMETER = _descriptor.Descriptor(
name='FsClientParameter',
full_name='paddle.FsClientParameter',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='fs_type',
full_name='paddle.FsClientParameter.fs_type',
index=0,
number=1,
type=14,
cpp_type=8,
label=1,
has_default_value=True,
default_value=0,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='uri',
full_name='paddle.FsClientParameter.uri',
index=1,
number=2,
type=9,
cpp_type=9,
label=1,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='user',
full_name='paddle.FsClientParameter.user',
index=2,
number=3,
type=9,
cpp_type=9,
label=1,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='passwd',
full_name='paddle.FsClientParameter.passwd',
index=3,
number=4,
type=9,
cpp_type=9,
label=1,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='buffer_size',
full_name='paddle.FsClientParameter.buffer_size',
index=4,
number=5,
type=5,
cpp_type=1,
label=1,
has_default_value=False,
default_value=0,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='hadoop_bin',
full_name='paddle.FsClientParameter.hadoop_bin',
index=5,
number=51,
type=9,
cpp_type=9,
label=1,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='afs_conf',
full_name='paddle.FsClientParameter.afs_conf',
index=6,
number=101,
type=9,
cpp_type=9,
label=1,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[],
enum_types=[_FSCLIENTPARAMETER_FSAPITYPE, ],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=3274,
serialized_end=3487, )
_PSPARAMETER.fields_by_name['worker_param'].message_type = _WORKERPARAMETER
_PSPARAMETER.fields_by_name['server_param'].message_type = _SERVERPARAMETER
_PSPARAMETER.fields_by_name[
'trainer_param'].message_type = _DOWNPOURTRAINERPARAMETER
_PSPARAMETER.fields_by_name['fs_client_param'].message_type = _FSCLIENTPARAMETER
_WORKERPARAMETER.fields_by_name[
'downpour_worker_param'].message_type = _DOWNPOURWORKERPARAMETER
_SERVERPARAMETER.fields_by_name[
'downpour_server_param'].message_type = _DOWNPOURSERVERPARAMETER
_DOWNPOURWORKERPARAMETER.fields_by_name[
'downpour_table_param'].message_type = _TABLEPARAMETER
_DOWNPOURTRAINERPARAMETER.fields_by_name[
'dense_table'].message_type = _DENSETABLEPARAMETER
_DOWNPOURTRAINERPARAMETER.fields_by_name[
'sparse_table'].message_type = _SPARSETABLEPARAMETER
_DOWNPOURTRAINERPARAMETER.fields_by_name[
'program_config'].message_type = _PROGRAMCONFIG
_DOWNPOURSERVERPARAMETER.fields_by_name[
'downpour_table_param'].message_type = _TABLEPARAMETER
_DOWNPOURSERVERPARAMETER.fields_by_name[
'service_param'].message_type = _SERVERSERVICEPARAMETER
_TABLEPARAMETER.fields_by_name[
'accessor'].message_type = _TABLEACCESSORPARAMETER
_TABLEPARAMETER.fields_by_name['type'].enum_type = _TABLETYPE
_TABLEACCESSORPARAMETER.fields_by_name[
'sparse_sgd_param'].message_type = _SPARSESGDRULEPARAMETER
_TABLEACCESSORPARAMETER.fields_by_name[
'dense_sgd_param'].message_type = _DENSESGDRULEPARAMETER
_TABLEACCESSORPARAMETER.fields_by_name[
'downpour_accessor_param'].message_type = _DOWNPOURTABLEACCESSORPARAMETER
_TABLEACCESSORPARAMETER.fields_by_name[
'table_accessor_save_param'].message_type = _TABLEACCESSORSAVEPARAMETER
_DENSESGDRULEPARAMETER.fields_by_name['adam'].message_type = _ADAMSGDPARAMETER
_DENSESGDRULEPARAMETER.fields_by_name['naive'].message_type = _NAIVESGDPARAMETER
_DENSESGDRULEPARAMETER.fields_by_name[
'summary'].message_type = _SUMMARYSGDPARAMETER
_DENSESGDRULEPARAMETER.fields_by_name[
'moving_average'].message_type = _MOVINGAVERAGERULEPARAMETER
_FSCLIENTPARAMETER.fields_by_name[
'fs_type'].enum_type = _FSCLIENTPARAMETER_FSAPITYPE
_FSCLIENTPARAMETER_FSAPITYPE.containing_type = _FSCLIENTPARAMETER
DESCRIPTOR.message_types_by_name['PSParameter'] = _PSPARAMETER
DESCRIPTOR.message_types_by_name['WorkerParameter'] = _WORKERPARAMETER
DESCRIPTOR.message_types_by_name['ServerParameter'] = _SERVERPARAMETER
DESCRIPTOR.message_types_by_name[
'DownpourWorkerParameter'] = _DOWNPOURWORKERPARAMETER
DESCRIPTOR.message_types_by_name[
'DownpourTrainerParameter'] = _DOWNPOURTRAINERPARAMETER
DESCRIPTOR.message_types_by_name['ProgramConfig'] = _PROGRAMCONFIG
DESCRIPTOR.message_types_by_name['DenseTableParameter'] = _DENSETABLEPARAMETER
DESCRIPTOR.message_types_by_name['SparseTableParameter'] = _SPARSETABLEPARAMETER
DESCRIPTOR.message_types_by_name[
'DownpourServerParameter'] = _DOWNPOURSERVERPARAMETER
DESCRIPTOR.message_types_by_name[
'ServerServiceParameter'] = _SERVERSERVICEPARAMETER
DESCRIPTOR.message_types_by_name['TableParameter'] = _TABLEPARAMETER
DESCRIPTOR.message_types_by_name[
'TableAccessorParameter'] = _TABLEACCESSORPARAMETER
DESCRIPTOR.message_types_by_name[
'DownpourTableAccessorParameter'] = _DOWNPOURTABLEACCESSORPARAMETER
DESCRIPTOR.message_types_by_name[
'TableAccessorSaveParameter'] = _TABLEACCESSORSAVEPARAMETER
DESCRIPTOR.message_types_by_name['PsRequestMessage'] = _PSREQUESTMESSAGE
DESCRIPTOR.message_types_by_name[
'SparseSGDRuleParameter'] = _SPARSESGDRULEPARAMETER
DESCRIPTOR.message_types_by_name[
'DenseSGDRuleParameter'] = _DENSESGDRULEPARAMETER
DESCRIPTOR.message_types_by_name['AdamSGDParameter'] = _ADAMSGDPARAMETER
DESCRIPTOR.message_types_by_name['NaiveSGDParameter'] = _NAIVESGDPARAMETER
DESCRIPTOR.message_types_by_name['SummarySGDParameter'] = _SUMMARYSGDPARAMETER
DESCRIPTOR.message_types_by_name[
'MovingAverageRuleParameter'] = _MOVINGAVERAGERULEPARAMETER
DESCRIPTOR.message_types_by_name['PsResponseMessage'] = _PSRESPONSEMESSAGE
DESCRIPTOR.message_types_by_name['FsClientParameter'] = _FSCLIENTPARAMETER
DESCRIPTOR.enum_types_by_name['TableType'] = _TABLETYPE
DESCRIPTOR.enum_types_by_name['PsCmdID'] = _PSCMDID
PSParameter = _reflection.GeneratedProtocolMessageType(
'PSParameter',
(_message.Message, ),
dict(
DESCRIPTOR=_PSPARAMETER,
__module__='ps_pb2'
# @@protoc_insertion_point(class_scope:paddle.PSParameter)
))
_sym_db.RegisterMessage(PSParameter)
WorkerParameter = _reflection.GeneratedProtocolMessageType(
'WorkerParameter',
(_message.Message, ),
dict(
DESCRIPTOR=_WORKERPARAMETER,
__module__='ps_pb2'
# @@protoc_insertion_point(class_scope:paddle.WorkerParameter)
))
_sym_db.RegisterMessage(WorkerParameter)
ServerParameter = _reflection.GeneratedProtocolMessageType(
'ServerParameter',
(_message.Message, ),
dict(
DESCRIPTOR=_SERVERPARAMETER,
__module__='ps_pb2'
# @@protoc_insertion_point(class_scope:paddle.ServerParameter)
))
_sym_db.RegisterMessage(ServerParameter)
DownpourWorkerParameter = _reflection.GeneratedProtocolMessageType(
'DownpourWorkerParameter',
(_message.Message, ),
dict(
DESCRIPTOR=_DOWNPOURWORKERPARAMETER,
__module__='ps_pb2'
# @@protoc_insertion_point(class_scope:paddle.DownpourWorkerParameter)
))
_sym_db.RegisterMessage(DownpourWorkerParameter)
DownpourTrainerParameter = _reflection.GeneratedProtocolMessageType(
'DownpourTrainerParameter',
(_message.Message, ),
dict(
DESCRIPTOR=_DOWNPOURTRAINERPARAMETER,
__module__='ps_pb2'
# @@protoc_insertion_point(class_scope:paddle.DownpourTrainerParameter)
))
_sym_db.RegisterMessage(DownpourTrainerParameter)
ProgramConfig = _reflection.GeneratedProtocolMessageType(
'ProgramConfig',
(_message.Message, ),
dict(
DESCRIPTOR=_PROGRAMCONFIG,
__module__='ps_pb2'
# @@protoc_insertion_point(class_scope:paddle.ProgramConfig)
))
_sym_db.RegisterMessage(ProgramConfig)
DenseTableParameter = _reflection.GeneratedProtocolMessageType(
'DenseTableParameter',
(_message.Message, ),
dict(
DESCRIPTOR=_DENSETABLEPARAMETER,
__module__='ps_pb2'
# @@protoc_insertion_point(class_scope:paddle.DenseTableParameter)
))
_sym_db.RegisterMessage(DenseTableParameter)
SparseTableParameter = _reflection.GeneratedProtocolMessageType(
'SparseTableParameter',
(_message.Message, ),
dict(
DESCRIPTOR=_SPARSETABLEPARAMETER,
__module__='ps_pb2'
# @@protoc_insertion_point(class_scope:paddle.SparseTableParameter)
))
_sym_db.RegisterMessage(SparseTableParameter)
DownpourServerParameter = _reflection.GeneratedProtocolMessageType(
'DownpourServerParameter',
(_message.Message, ),
dict(
DESCRIPTOR=_DOWNPOURSERVERPARAMETER,
__module__='ps_pb2'
# @@protoc_insertion_point(class_scope:paddle.DownpourServerParameter)
))
_sym_db.RegisterMessage(DownpourServerParameter)
ServerServiceParameter = _reflection.GeneratedProtocolMessageType(
'ServerServiceParameter',
(_message.Message, ),
dict(
DESCRIPTOR=_SERVERSERVICEPARAMETER,
__module__='ps_pb2'
# @@protoc_insertion_point(class_scope:paddle.ServerServiceParameter)
))
_sym_db.RegisterMessage(ServerServiceParameter)
TableParameter = _reflection.GeneratedProtocolMessageType(
'TableParameter',
(_message.Message, ),
dict(
DESCRIPTOR=_TABLEPARAMETER,
__module__='ps_pb2'
# @@protoc_insertion_point(class_scope:paddle.TableParameter)
))
_sym_db.RegisterMessage(TableParameter)
TableAccessorParameter = _reflection.GeneratedProtocolMessageType(
'TableAccessorParameter',
(_message.Message, ),
dict(
DESCRIPTOR=_TABLEACCESSORPARAMETER,
__module__='ps_pb2'
# @@protoc_insertion_point(class_scope:paddle.TableAccessorParameter)
))
_sym_db.RegisterMessage(TableAccessorParameter)
DownpourTableAccessorParameter = _reflection.GeneratedProtocolMessageType(
'DownpourTableAccessorParameter',
(_message.Message, ),
dict(
DESCRIPTOR=_DOWNPOURTABLEACCESSORPARAMETER,
__module__='ps_pb2'
# @@protoc_insertion_point(class_scope:paddle.DownpourTableAccessorParameter)
))
_sym_db.RegisterMessage(DownpourTableAccessorParameter)
TableAccessorSaveParameter = _reflection.GeneratedProtocolMessageType(
'TableAccessorSaveParameter',
(_message.Message, ),
dict(
DESCRIPTOR=_TABLEACCESSORSAVEPARAMETER,
__module__='ps_pb2'
# @@protoc_insertion_point(class_scope:paddle.TableAccessorSaveParameter)
))
_sym_db.RegisterMessage(TableAccessorSaveParameter)
PsRequestMessage = _reflection.GeneratedProtocolMessageType(
'PsRequestMessage',
(_message.Message, ),
dict(
DESCRIPTOR=_PSREQUESTMESSAGE,
__module__='ps_pb2'
# @@protoc_insertion_point(class_scope:paddle.PsRequestMessage)
))
_sym_db.RegisterMessage(PsRequestMessage)
SparseSGDRuleParameter = _reflection.GeneratedProtocolMessageType(
'SparseSGDRuleParameter',
(_message.Message, ),
dict(
DESCRIPTOR=_SPARSESGDRULEPARAMETER,
__module__='ps_pb2'
# @@protoc_insertion_point(class_scope:paddle.SparseSGDRuleParameter)
))
_sym_db.RegisterMessage(SparseSGDRuleParameter)
DenseSGDRuleParameter = _reflection.GeneratedProtocolMessageType(
'DenseSGDRuleParameter',
(_message.Message, ),
dict(
DESCRIPTOR=_DENSESGDRULEPARAMETER,
__module__='ps_pb2'
# @@protoc_insertion_point(class_scope:paddle.DenseSGDRuleParameter)
))
_sym_db.RegisterMessage(DenseSGDRuleParameter)
AdamSGDParameter = _reflection.GeneratedProtocolMessageType(
'AdamSGDParameter',
(_message.Message, ),
dict(
DESCRIPTOR=_ADAMSGDPARAMETER,
__module__='ps_pb2'
# @@protoc_insertion_point(class_scope:paddle.AdamSGDParameter)
))
_sym_db.RegisterMessage(AdamSGDParameter)
NaiveSGDParameter = _reflection.GeneratedProtocolMessageType(
'NaiveSGDParameter',
(_message.Message, ),
dict(
DESCRIPTOR=_NAIVESGDPARAMETER,
__module__='ps_pb2'
# @@protoc_insertion_point(class_scope:paddle.NaiveSGDParameter)
))
_sym_db.RegisterMessage(NaiveSGDParameter)
SummarySGDParameter = _reflection.GeneratedProtocolMessageType(
'SummarySGDParameter',
(_message.Message, ),
dict(
DESCRIPTOR=_SUMMARYSGDPARAMETER,
__module__='ps_pb2'
# @@protoc_insertion_point(class_scope:paddle.SummarySGDParameter)
))
_sym_db.RegisterMessage(SummarySGDParameter)
MovingAverageRuleParameter = _reflection.GeneratedProtocolMessageType(
'MovingAverageRuleParameter',
(_message.Message, ),
dict(
DESCRIPTOR=_MOVINGAVERAGERULEPARAMETER,
__module__='ps_pb2'
# @@protoc_insertion_point(class_scope:paddle.MovingAverageRuleParameter)
))
_sym_db.RegisterMessage(MovingAverageRuleParameter)
PsResponseMessage = _reflection.GeneratedProtocolMessageType(
'PsResponseMessage',
(_message.Message, ),
dict(
DESCRIPTOR=_PSRESPONSEMESSAGE,
__module__='ps_pb2'
# @@protoc_insertion_point(class_scope:paddle.PsResponseMessage)
))
_sym_db.RegisterMessage(PsResponseMessage)
FsClientParameter = _reflection.GeneratedProtocolMessageType(
'FsClientParameter',
(_message.Message, ),
dict(
DESCRIPTOR=_FSCLIENTPARAMETER,
__module__='ps_pb2'
# @@protoc_insertion_point(class_scope:paddle.FsClientParameter)
))
_sym_db.RegisterMessage(FsClientParameter)
DESCRIPTOR.has_options = True
DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(),
_b('\200\001\001'))
# @@protoc_insertion_point(module_scope)
...@@ -81,62 +81,6 @@ class TestAsyncExecutor(unittest.TestCase): ...@@ -81,62 +81,6 @@ class TestAsyncExecutor(unittest.TestCase):
tarf.extractall(path='./') tarf.extractall(path='./')
tarf.close() tarf.close()
def test_data_feed_desc(self):
data_feed = fluid.DataFeedDesc('./data.prototxt')
# assertEqueal(data_feed.proto_desc.batch, 2)
# assertEqual(len(data_feed.proto_desc.multi_slot_desc), 2)
self.assertEqual(" ".join(data_feed.desc().split()),
" ".join(proto_str.split()))
def test_run(self):
# Initialize dataset description
data_feed = fluid.DataFeedDesc('train_data/data.prototxt')
data_feed.set_batch_size(
128) # See API doc for how to change other fields
# define network
# input text data
data = fluid.layers.data(
name="words", shape=[1], dtype="int64", lod_level=1)
# label data
label = fluid.layers.data(name="label", shape=[1], dtype="int64")
avg_cost, acc, prediction = bow_net(data, label)
sgd_optimizer = fluid.optimizer.Adagrad(learning_rate=0.002)
opt_ops, weight_and_grad = sgd_optimizer.minimize(avg_cost)
# Run startup program
startup_program = fluid.default_startup_program()
place = fluid.CPUPlace()
executor = fluid.Executor(place)
executor.run(startup_program)
main_program = fluid.default_main_program()
async_executor = fluid.AsyncExecutor(place)
self.assertRaises(TypeError, async_executor.run)
self.assertRaises(TypeError, async_executor.run, main_program)
self.assertRaises(TypeError, async_executor.run, main_program,
data_feed)
filelist = ['train_data/part-%d' % i for i in range(10)]
self.assertRaises(TypeError, async_executor.run, main_program,
data_feed, filelist)
thread_num = 4
self.assertRaises(TypeError, async_executor.run, main_program,
data_feed, filelist, thread_num)
async_executor.run(main_program, data_feed, filelist, thread_num, [acc])
fluid.io.save_inference_model("imdb.model", [data.name, label.name],
[acc], executor)
statinfo = os.stat('imdb.model/__model__')
self.assertGreater(statinfo.st_size, 0)
os.remove('./data.prototxt')
shutil.rmtree('./train_data')
shutil.rmtree('./imdb.model')
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
TestCases for Dataset,
including create, config, run, etc.
"""
from __future__ import print_function
import paddle.fluid as fluid
import numpy as np
import os
import shutil
import unittest
class TestDataset(unittest.TestCase):
""" TestCases for Dataset. """
def test_dataset_create(self):
""" Testcase for dataset create. """
return
try:
dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
except:
self.assertTrue(False)
try:
dataset = fluid.DatasetFactory().create_dataset("QueueDataset")
except:
self.assertTrue(False)
try:
dataset = fluid.DatasetFactory().create_dataset("MyOwnDataset")
self.assertTrue(False)
except:
self.assertTrue(True)
def test_dataset_config(self):
""" Testcase for dataset configuration. """
return
dataset = fluid.core.Dataset("MultiSlotDataset")
dataset.set_thread_num(12)
dataset.set_filelist(["a.txt", "b.txt", "c.txt"])
dataset.set_trainer_num(4)
dataset.set_hdfs_config("my_fs_name", "my_fs_ugi")
thread_num = dataset.get_thread_num()
self.assertEqual(thread_num, 12)
filelist = dataset.get_filelist()
self.assertEqual(len(filelist), 3)
self.assertEqual(filelist[0], "a.txt")
self.assertEqual(filelist[1], "b.txt")
self.assertEqual(filelist[2], "c.txt")
trainer_num = dataset.get_trainer_num()
self.assertEqual(trainer_num, 4)
name, ugi = dataset.get_hdfs_config()
self.assertEqual(name, "my_fs_name")
self.assertEqual(ugi, "my_fs_ugi")
def test_in_memory_dataset_run(self):
"""
Testcase for InMemoryDataset from create to run.
"""
return
with open("test_in_memory_dataset_run_a.txt", "w") as f:
data = "1 1 2 3 3 4 5 5 5 5 1 1\n"
data += "1 2 2 3 4 4 6 6 6 6 1 2\n"
data += "1 3 2 3 5 4 7 7 7 7 1 3\n"
f.write(data)
with open("test_in_memory_dataset_run_b.txt", "w") as f:
data = "1 4 2 3 3 4 5 5 5 5 1 4\n"
data += "1 5 2 3 4 4 6 6 6 6 1 5\n"
data += "1 6 2 3 5 4 7 7 7 7 1 6\n"
data += "1 7 2 3 6 4 8 8 8 8 1 7\n"
f.write(data)
slots = ["slot1", "slot2", "slot3", "slot4"]
slots_vars = []
for slot in slots:
var = fluid.layers.data(
name=slot, shape=[1], dtype="int64", lod_level=1)
slots_vars.append(var)
dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
dataset.set_batch_size(32)
dataset.set_thread(3)
dataset.set_filelist([
"test_in_memory_dataset_run_a.txt",
"test_in_memory_dataset_run_b.txt"
])
dataset.set_pipe_command("cat")
dataset.set_use_var(slots_vars)
dataset.load_into_memory()
dataset.local_shuffle()
exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
for i in range(2):
try:
exe.train_from_dataset(fluid.default_main_program(), dataset)
except:
#self.assertTrue(False)
pass
os.remove("./test_in_memory_dataset_run_a.txt")
os.remove("./test_in_memory_dataset_run_b.txt")
def test_queue_dataset_run(self):
"""
Testcase for QueueDataset from create to run.
"""
return
with open("test_queue_dataset_run_a.txt", "w") as f:
data = "1 1 2 3 3 4 5 5 5 5 1 1\n"
data += "1 2 2 3 4 4 6 6 6 6 1 2\n"
data += "1 3 2 3 5 4 7 7 7 7 1 3\n"
f.write(data)
with open("test_queue_dataset_run_b.txt", "w") as f:
data = "1 4 2 3 3 4 5 5 5 5 1 4\n"
data += "1 5 2 3 4 4 6 6 6 6 1 5\n"
data += "1 6 2 3 5 4 7 7 7 7 1 6\n"
data += "1 7 2 3 6 4 8 8 8 8 1 7\n"
f.write(data)
slots = ["slot1", "slot2", "slot3", "slot4"]
slots_vars = []
for slot in slots:
var = fluid.layers.data(
name=slot, shape=[1], dtype="int64", lod_level=1)
slots_vars.append(var)
dataset = fluid.DatasetFactory().create_dataset("QueueDataset")
dataset.set_batch_size(32)
dataset.set_thread(3)
dataset.set_filelist(
["test_queue_dataset_run_a.txt", "test_queue_dataset_run_b.txt"])
dataset.set_pipe_command("cat")
dataset.set_use_var(slots_vars)
exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
for i in range(2):
try:
exe.train_from_dataset(fluid.default_main_program(), dataset)
except:
#self.assertTrue(False)
pass
os.remove("./test_queue_dataset_run_a.txt")
os.remove("./test_queue_dataset_run_b.txt")
if __name__ == '__main__':
#unittest.main()
import sys
sys.exit(0)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = ['TrainerDesc', 'MultiTrainer', 'DistMultiTrainer']
# can be initialized from train_desc,
class TrainerDesc(object):
def __init__(self):
'''
self.proto_desc = data_feed_pb2.DataFeedDesc()
with open(proto_file, 'r') as f:
text_format.Parse(f.read(), self.proto_desc)
'''
from proto import trainer_desc_pb2
self.proto_desc = trainer_desc_pb2.TrainerDesc()
import multiprocessing as mp
# set default thread num == cpu count
self.proto_desc.thread_num = mp.cpu_count()
self.fleet_desc_ = None
self.device_worker_ = None
self.program_ = None
self.infer_ = False
def _set_fetch_var_and_info(self, fetch_vars, fetch_info, print_period):
for i, v in enumerate(fetch_vars):
self.proto_desc.fetch_config.fetch_var_names.extend([v.name])
self.proto_desc.fetch_config.fetch_var_str_format.extend(
[fetch_info[i]])
self.proto_desc.fetch_config.print_period = print_period
def _set_debug(self, debug):
self.proto_desc.debug = debug
def _set_thread(self, thread_num):
self.proto_desc.thread_num = thread_num
def _set_device_worker(self, device_worker):
self.device_worker_ = device_worker
def _set_infer(self, infer):
self.infer_ = infer
def _set_fleet_desc(self, fleet_desc):
self.fleet_desc_ = fleet_desc
def _gen_trainer_desc(self):
pass
def _set_program(self, program):
self.program_ = program
def _desc(self):
from google.protobuf import text_format
return text_format.MessageToString(self.proto_desc)
class MultiTrainer(TrainerDesc):
def __init__(self):
super(MultiTrainer, self).__init__()
pass
def _set_program(self, program):
super(MultiTrainer, self)._set_program(program)
self.program_ = program
def _gen_trainer_desc(self):
super(MultiTrainer, self)._gen_trainer_desc()
self.proto_desc.class_name = "MultiTrainer"
self.device_worker_._set_infer(self.infer_)
self.device_worker_._gen_worker_desc(self.proto_desc)
class DistMultiTrainer(TrainerDesc):
def __init__(self):
super(DistMultiTrainer, self).__init__()
pass
def _set_program(self, program):
super(DistMultiTrainer, self)._set_program(program)
self.program_ = program
def _gen_trainer_desc(self):
super(DistMultiTrainer, self)._gen_trainer_desc()
self.proto_desc.class_name = "DistMultiTrainer"
if self.program_ == None:
raise RuntimeError("None Program")
self.device_worker_._set_infer(self.infer_)
self.device_worker_._set_program(self.program_)
self.device_worker_._gen_worker_desc(self.proto_desc)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = ["TrainerFactory"]
class TrainerFactory(object):
def __init__(self):
pass
def _create_trainer(self, opt_info=None):
from .trainer_desc import MultiTrainer, DistMultiTrainer
from .device_worker import Hogwild, DownpourSGD
trainer = None
device_worker = None
if opt_info == None:
# default is MultiTrainer + Hogwild
trainer = MultiTrainer()
device_worker = Hogwild()
trainer._set_device_worker(device_worker)
else:
trainer_class = opt_info["trainer"]
device_worker_class = opt_info["device_worker"]
trainer = globals()[trainer_class]()
device_worker = globals()[device_worker_class]()
device_worker._set_fleet_desc(opt_info["fleet_desc"])
trainer._set_device_worker(device_worker)
trainer._set_fleet_desc(opt_info["fleet_desc"])
return trainer
...@@ -121,7 +121,13 @@ packages=['paddle', ...@@ -121,7 +121,13 @@ packages=['paddle',
'paddle.fluid.contrib.utils', 'paddle.fluid.contrib.utils',
'paddle.fluid.contrib.extend_optimizer', 'paddle.fluid.contrib.extend_optimizer',
'paddle.fluid.transpiler', 'paddle.fluid.transpiler',
'paddle.fluid.transpiler.details'] 'paddle.fluid.transpiler.details',
'paddle.fluid.incubate',
'paddle.fluid.incubate.data_generator',
'paddle.fluid.incubate.fleet',
'paddle.fluid.incubate.fleet.base',
'paddle.fluid.incubate.fleet.parameter_server',
'paddle.fluid.incubate.fleet.p2p']
with open('@PADDLE_SOURCE_DIR@/python/requirements.txt') as f: with open('@PADDLE_SOURCE_DIR@/python/requirements.txt') as f:
setup_requires = f.read().splitlines() setup_requires = f.read().splitlines()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册