Commit 9bca1926 authored by H heqiaozhi, committed by dongdaxiang

refactor & fix bug

Parent 2e9a836c
......@@ -29,6 +29,7 @@ add_subdirectory(io)
#ddim lib
proto_library(framework_proto SRCS framework.proto)
proto_library(data_feed_proto SRCS data_feed.proto)
proto_library(async_executor_proto SRCS data_feed.proto)
proto_library(trainer_desc_proto SRCS trainer_desc.proto)
cc_library(ddim SRCS ddim.cc DEPS eigen3 boost enforce)
......@@ -174,12 +175,19 @@ endif()
cc_library(executor_gc_helper SRCS executor_gc_helper.cc DEPS scope proto_desc operator garbage_collector)
if(WITH_DISTRIBUTE)
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog
lod_rank_table feed_fetch_method sendrecvop_rpc ${GLOB_DISTRIBUTE_DEPS} graph_to_program_pass variable_helper ${NGRAPH_EXE_DEPS} trainer_library)
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog fleet_wrapper
lod_rank_table feed_fetch_method sendrecvop_rpc ${GLOB_DISTRIBUTE_DEPS}
graph_to_program_pass variable_helper trainer_library data_feed_proto ${NGRAPH_EXE_DEPS})
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
else()
cc_library(executor SRCS executor.cc multi_trainer.cc dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc data_feed.cc device_worker.cc hogwild_worker.cc downpour_worker.cc pull_dense_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry device_context scope framework_proto data_feed_proto trainer_desc_proto glog lod_rank_table fs shell fleet_wrapper lodtensor_printer feed_fetch_method graph_to_program_pass variable_helper ${NGRAPH_EXE_DEPS} timer)
cc_library(executor SRCS executor.cc multi_trainer.cc
dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc
data_feed.cc device_worker.cc hogwild_worker.cc downpour_worker.cc
pull_dense_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
device_context scope framework_proto data_feed_proto trainer_desc_proto glog
lod_rank_table fs shell fleet_wrapper lodtensor_printer feed_fetch_method
graph_to_program_pass variable_helper ${NGRAPH_EXE_DEPS} timer data_feed_proto)
cc_test(test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op)
endif()
......@@ -190,8 +198,6 @@ cc_library(parallel_executor SRCS parallel_executor.cc DEPS
graph build_strategy
fast_threaded_ssa_graph_executor variable_helper)
<<<<<<< HEAD
=======
if(WITH_PSLIB)
cc_library(async_executor SRCS async_executor.cc data_feed.cc data_feed_factory.cc
executor_thread_worker.cc multi_trainer.cc dist_multi_trainer.cc
......@@ -201,7 +207,7 @@ if(WITH_PSLIB)
DEPS op_registry device_context scope framework_proto
trainer_desc_proto glog lod_rank_table fleet_wrapper lodtensor_printer
feed_fetch_method graph_to_program_pass async_executor_proto
variable_helper pslib_brpc pslib timer)
variable_helper pslib_brpc pslib timer fs shell)
else()
cc_library(async_executor SRCS async_executor.cc data_feed.cc data_feed_factory.cc
executor_thread_worker.cc multi_trainer.cc dist_multi_trainer.cc
......@@ -211,18 +217,9 @@ else()
DEPS op_registry device_context scope framework_proto
trainer_desc_proto glog lod_rank_table fleet_wrapper lodtensor_printer
feed_fetch_method graph_to_program_pass async_executor_proto
variable_helper timer)
variable_helper timer fs shell)
endif(WITH_PSLIB)
>>>>>>> 870b88bbd7... add DataSet and InMemoryDataFeed, support load data into memory and shuffle data
cc_library(async_executor SRCS async_executor.cc data_feed.cc data_feed_factory.cc
executor_thread_worker.cc multi_trainer.cc dist_multi_trainer.cc
trainer_factory.cc trainer.cc device_worker.cc hogwild_worker.cc
downpour_worker.cc pull_dense_worker.cc device_worker_factory.cc
data_set.cc DEPS op_registry device_context scope framework_proto
trainer_desc_proto glog lod_rank_table fleet_wrapper lodtensor_printer
feed_fetch_method graph_to_program_pass data_feed_proto
variable_helper timer)
cc_test(data_feed_test SRCS data_feed_test.cc DEPS async_executor)
cc_library(prune SRCS prune.cc DEPS framework_proto)
......
......@@ -220,111 +220,8 @@ void InMemoryDataFeed<T>::LocalShuffle() {
std::random_shuffle(memory_data_.begin(), memory_data_.end());
}
// todo global shuffle
/*
template <typename T>
void InMemoryDataFeed<T>::GlobalShuffle(int trainer_num) {
std::random_shuffle(memory_data_.begin(), memory_data_.end());
for (int64_t i = 0; i < memory_data_.size(); ++i) {
// todo get ins id
//std::string ins_id = memory_data_[i].ins_id;
// todo hash
int64_t hash_id = paddle::ps::local_random_engine()();
//int64_t hash_id = hash(ins_id);
int64_t node_id = hash_id % trainer_num_;
std::string str;
SerializeIns(memory_data_[i], str);
auto fleet_ptr = FleetWrapper::GetInstance();
auto ret = fleet_ptr->send_client2client_msg(0, node_id, str);
}
}
*/
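A minimal sketch of the hash-partitioned routing the commented-out GlobalShuffle() above describes: each instance is hashed (the ins_id hash is still a todo in the source, so std::hash stands in here), and the hash picks exactly one trainer node to receive the serialized instance. Names below are illustrative, not part of this commit.

#include <cstdint>
#include <functional>
#include <string>

// Illustrative only: maps an instance id onto one of trainer_num nodes,
// mirroring the `hash_id % trainer_num_` routing in the comment above.
int64_t RouteToTrainer(const std::string& ins_id, int trainer_num) {
  int64_t hash_id = static_cast<int64_t>(std::hash<std::string>()(ins_id));
  return (hash_id & 0x7fffffffffffffffLL) % trainer_num;
}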
template class InMemoryDataFeed<std::vector<MultiSlotType>>;
template <typename T>
InMemoryDataFeed<T>::InMemoryDataFeed() {
cur_channel_ = 0;
shuffled_ins_ = nullptr;
shuffled_ins_out_ = nullptr;
}
template <typename T>
bool InMemoryDataFeed<T>::Start() {
DataFeed::CheckSetFileList();
if (memory_data_.size() != 0) {
CHECK_EQ(cur_channel_, 0);
shuffled_ins_->Extend(std::move(memory_data_));
std::vector<T>().swap(memory_data_);
}
DataFeed::finish_start_ = true;
return true;
}
template <typename T>
int InMemoryDataFeed<T>::Next() {
DataFeed::CheckStart();
std::shared_ptr<paddle::framework::BlockingQueue<T>> in_channel = nullptr;
std::shared_ptr<paddle::framework::BlockingQueue<T>> out_channel = nullptr;
if (cur_channel_ == 0) {
in_channel = shuffled_ins_;
out_channel = shuffled_ins_out_;
} else {
in_channel = shuffled_ins_out_;
out_channel = shuffled_ins_;
}
CHECK(in_channel != nullptr);
CHECK(out_channel != nullptr);
int index = 0;
T instance;
T ins_vec;
while (index < DataFeed::default_batch_size_) {
if (in_channel->Size() == 0) {
break;
}
in_channel->Pop(instance);
AddInstanceToInsVec(&ins_vec, instance, index++);
out_channel->Push(std::move(instance));
}
DataFeed::batch_size_ = index;
if (DataFeed::batch_size_ != 0) {
PutToFeedVec(ins_vec);
} else {
cur_channel_ = 1 - cur_channel_;
}
return DataFeed::batch_size_;
}
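Next() above implements a ping-pong buffer: batches are popped from one channel and re-pushed into the other, and the channels swap once the in-side drains, so the in-memory data can be replayed pass after pass. A self-contained sketch of that pattern, with std::deque standing in for paddle::framework::BlockingQueue (names hypothetical):

#include <deque>
#include <utility>

template <typename T>
class PingPongFeed {
 public:
  // Pops up to batch_size items from the current "in" channel and
  // re-pushes them into the "out" channel; swapping channels once the
  // "in" side drains lets the same data be replayed on the next pass.
  int NextBatch(int batch_size) {
    std::deque<T>& in = channels_[cur_];
    std::deque<T>& out = channels_[1 - cur_];
    int index = 0;
    while (index < batch_size && !in.empty()) {
      T instance = std::move(in.front());
      in.pop_front();
      // ... consume `instance` into the current batch here ...
      out.push_back(std::move(instance));
      ++index;
    }
    if (index == 0) cur_ = 1 - cur_;  // pass finished: swap channels
    return index;                     // 0 signals end of pass
  }

 private:
  std::deque<T> channels_[2];
  int cur_ = 0;
};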
template <typename T>
void InMemoryDataFeed<T>::PutInsToChannel(const std::string& ins_str) {
T ins;
DeserializeIns(ins, ins_str);
shuffled_ins_->Push(std::move(ins));
}
template <typename T>
void InMemoryDataFeed<T>::LoadIntoMemory() {
std::vector<T> local_vec;
std::string filename;
while (DataFeed::PickOneFile(&filename)) {
int err_no = 0;
PrivateQueueDataFeed<T>::fp_ =
fs_open_read(filename, &err_no, PrivateQueueDataFeed<T>::pipe_command_);
__fsetlocking(&*PrivateQueueDataFeed<T>::fp_, FSETLOCKING_BYCALLER);
T instance;
while (ParseOneInstanceFromPipe(&instance)) {
local_vec.push_back(instance);
}
memory_data_.insert(memory_data_.end(), local_vec.begin(), local_vec.end());
std::vector<T>().swap(local_vec);
}
}
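LoadIntoMemory() reads each file through a pipe command via fs_open_read() (the fs/shell utilities the CMake changes above now link in). A popen-based stand-in for that call, assuming a POSIX shell; a sketch only, not the actual fs implementation:

#include <cstdio>
#include <memory>
#include <string>

// Hypothetical analogue of fs_open_read(): run `cmd < path` and hand the
// parser a FILE* over the command's output, so compressed or preprocessed
// files can be consumed transparently.
std::shared_ptr<FILE> OpenWithPipe(const std::string& path,
                                   const std::string& cmd, int* err_no) {
  std::string full_cmd = cmd + " < \"" + path + "\"";
  FILE* fp = popen(full_cmd.c_str(), "r");  // POSIX popen
  *err_no = (fp == nullptr) ? -1 : 0;
  return std::shared_ptr<FILE>(fp, [](FILE* f) {
    if (f != nullptr) pclose(f);
  });
}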
template <typename T>
void InMemoryDataFeed<T>::LocalShuffle() {
std::random_shuffle(memory_data_.begin(), memory_data_.end());
}
// todo global shuffle
/*
template <typename T>
......
......@@ -63,6 +63,7 @@ class PullDenseWorker {
static std::shared_ptr<PullDenseWorker> s_instance_;
std::shared_ptr<paddle::framework::FleetWrapper> fleet_ptr_;
PullDenseWorkerParameter param_;
DownpourWorkerParameter dwp_param_;
Scope* root_scope_;
bool running_;
......
......@@ -69,10 +69,16 @@ void DownpourWorker::Initialize(const TrainerDesc& desc) {
}
void DownpourWorker::CollectLabelInfo(size_t table_idx) {
auto table = param_.sparse_table(table_idx);
uint64_t table_id =
static_cast<uint64_t>(param_.sparse_table(table_idx).table_id());
uint64_t table_id = static_cast<uint64_t>(
param_.program_config(0).pull_sparse_table_id(table_idx));
TableParameter table;
for (auto i : param_.sparse_table()) {
if (i.table_id() == table_id) {
table = i;
break;
}
}
auto& feature = features_[table_id];
auto& feature_label = feature_labels_[table_id];
feature_label.resize(feature.size());
......@@ -103,10 +109,17 @@ void DownpourWorker::CollectLabelInfo(size_t table_idx) {
}
void DownpourWorker::FillSparseValue(size_t table_idx) {
auto table = param_.sparse_table(table_idx);
uint64_t table_id = static_cast<uint64_t>(
param_.program_config(0).pull_sparse_table_id(table_idx));
TableParameter table;
for (auto i : param_.sparse_table()) {
if (i.table_id() == table_id) {
table = i;
break;
}
}
uint64_t table_id =
static_cast<uint64_t>(param_.sparse_table(table_idx).table_id());
auto& fea_value = feature_values_[table_id];
auto fea_idx = 0u;
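CollectLabelInfo(), FillSparseValue(), TrainFiles() and PullDenseWorker::Initialize() below all resolve a TableParameter with the same linear scan over the repeated table field. A hypothetical helper that factors out that lookup (a sketch, not part of this commit):

#include <cstdint>

// Scans a repeated TableParameter field for a matching table_id; returns a
// default-constructed table when none matches, which is also what the
// inlined loops fall back to.
template <typename RepeatedTables>
TableParameter FindTableById(const RepeatedTables& tables, uint64_t table_id) {
  for (const auto& t : tables) {
    if (t.table_id() == table_id) return t;
  }
  return TableParameter();
}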
......@@ -147,11 +160,20 @@ void DownpourWorker::TrainFiles() {
int cur_batch;
while ((cur_batch = device_reader_->Next()) > 0) {
// pull sparse here
for (size_t i = 0; i < param_.sparse_table_size(); ++i) {
uint64_t tid = static_cast<uint64_t>(param_.sparse_table(i).table_id());
fleet_ptr_->PullSparseVarsSync(
*thread_scope_, tid, sparse_key_names_[tid], &features_[tid],
&feature_values_[tid], param_.sparse_table(i).fea_dim());
for (size_t i = 0; i < param_.program_config(0).pull_sparse_table_id_size();
++i) {
uint64_t tid = static_cast<uint64_t>(
param_.program_config(0).pull_sparse_table_id(i));
TableParameter table;
for (auto i : param_.sparse_table()) {
if (i.table_id() == tid) {
table = i;
break;
}
}
fleet_ptr_->PullSparseVarsSync(*thread_scope_, tid,
sparse_key_names_[tid], &features_[tid],
&feature_values_[tid], table.fea_dim());
CollectLabelInfo(i);
FillSparseValue(i);
}
......@@ -172,17 +194,27 @@ void DownpourWorker::TrainFiles() {
}
// push gradients here
for (size_t i = 0; i < param_.sparse_table_size(); ++i) {
uint64_t tid = static_cast<uint64_t>(param_.sparse_table(i).table_id());
for (size_t i = 0; i < param_.program_config(0).push_sparse_table_id_size();
++i) {
uint64_t tid = static_cast<uint64_t>(
param_.program_config(0).push_sparse_table_id(i));
TableParameter table;
for (auto i : param_.sparse_table()) {
if (i.table_id() == tid) {
table = i;
break;
}
}
fleet_ptr_->PushSparseVarsWithLabelAsync(
*thread_scope_, tid, features_[tid], feature_labels_[tid],
sparse_key_names_[tid], sparse_grad_names_[tid],
param_.sparse_table(i).emb_dim(), &feature_grads_[tid],
&push_sparse_status_);
sparse_key_names_[tid], sparse_grad_names_[tid], table.emb_dim(),
&feature_grads_[tid], &push_sparse_status_);
}
for (size_t i = 0; i < param_.dense_table_size(); ++i) {
uint64_t tid = static_cast<uint64_t>(param_.dense_table(i).table_id());
for (size_t i = 0; i < param_.program_config(0).push_dense_table_id_size();
++i) {
uint64_t tid = static_cast<uint64_t>(
param_.program_config(0).push_dense_table_id(i));
fleet_ptr_->PushDenseVarsAsync(
*thread_scope_, tid, dense_grad_names_[tid], &push_sparse_status_);
}
......@@ -219,8 +251,10 @@ void DownpourWorker::TrainFiles() {
push_sparse_status_.resize(0);
}
for (size_t i = 0; i < param_.dense_table_size(); ++i) {
uint64_t tid = static_cast<uint64_t>(param_.dense_table(i).table_id());
for (size_t i = 0; i < param_.program_config(0).push_dense_table_id_size();
++i) {
uint64_t tid = static_cast<uint64_t>(
param_.program_config(0).push_dense_table_id(i));
pull_dense_worker_->IncreaseThreadVersion(thread_id_, tid);
}
......
......@@ -28,16 +28,26 @@ std::map<uint64_t, std::vector<std::string>>
void PullDenseWorker::Initialize(const TrainerDesc& param) {
running_ = false;
param_ = param.pull_dense_param();
dwp_param_ = param.downpour_param();
threshold_ = param_.threshold();
thread_num_ = param_.device_num();
sleep_time_ms_ = param_.sleep_time_ms();
for (size_t i = 0; i < param_.dense_table_size(); ++i) {
for (size_t i = 0;
i < dwp_param_.program_config(0).pull_dense_table_id_size(); ++i) {
uint64_t tid = static_cast<uint64_t>(
dwp_param_.program_config(0).pull_dense_table_id(i));
TableParameter table;
for (auto i : param_.dense_table()) {
if (i.table_id() == tid) {
table = i;
break;
}
}
// setup dense variables for each table
int var_num = param_.dense_table(i).dense_value_name_size();
uint64_t tid = static_cast<uint64_t>(param_.dense_table(i).table_id());
int var_num = table.dense_value_name_size();
dense_value_names_[tid].resize(var_num);
for (int j = 0; j < var_num; ++j) {
dense_value_names_[tid][j] = param_.dense_table(i).dense_value_name(j);
dense_value_names_[tid][j] = table.dense_value_name(j);
}
// setup training version for each table
training_versions_[tid].resize(thread_num_, 0);
......@@ -82,8 +92,10 @@ int PullDenseWorker::Start() {
void PullDenseWorker::Run() {
while (running_) {
pull_dense_status_.resize(0);
for (size_t i = 0; i < param_.dense_table_size(); ++i) {
uint64_t tid = static_cast<uint64_t>(param_.dense_table(i).table_id());
for (size_t i = 0;
i < dwp_param_.program_config(0).pull_dense_table_id_size(); ++i) {
uint64_t tid = static_cast<uint64_t>(
dwp_param_.program_config(0).pull_dense_table_id(i));
if (CheckUpdateParam(tid)) {
fleet_ptr_->PullDenseVarsAsync(
*root_scope_, tid, dense_value_names_[tid], &pull_dense_status_);
......
......@@ -45,6 +45,15 @@ message DownpourWorkerParameter {
repeated TableParameter sparse_table = 1;
repeated TableParameter dense_table = 2;
repeated string skip_ops = 3;
repeated ProgramConfig program_config = 4;
}
message ProgramConfig {
required string program_id = 1;
repeated int32 push_sparse_table_id = 2;
repeated int32 push_dense_table_id = 3;
repeated int32 pull_sparse_table_id = 4;
repeated int32 pull_dense_table_id = 5;
}
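For reference, filling the new ProgramConfig message through the generated C++ API might look like the following. A sketch only: the header path and the paddle::framework package are assumed from the surrounding code, and the ids are made up.

#include "paddle/fluid/framework/trainer_desc.pb.h"  // assumed generated header

void FillExampleProgramConfig(paddle::framework::ProgramConfig* pc) {
  pc->set_program_id("demo_program");  // hypothetical program id
  pc->add_pull_sparse_table_id(0);     // tables this program pulls from
  pc->add_push_sparse_table_id(0);     // tables it pushes sparse grads to
  pc->add_pull_dense_table_id(1);
  pc->add_push_dense_table_id(1);
}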
message PullDenseWorkerParameter {
......
......@@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <fcntl.h>
// To avoid conflicting definition in gcc-4.8.2 headers and pyconfig.h (2.7.3)
#ifdef _POSIX_C_SOURCE
#undef _POSIX_C_SOURCE
#endif
......@@ -43,7 +41,7 @@ namespace paddle {
namespace pybind {
void BindDataset(py::module* m) {
py::class_<framework::DataSet>(*m, "Dataset")
py::class_<framework::Dataset>(*m, "Dataset")
.def(py::init([]() {
return std::unique_ptr<framework::Dataset>(new framework::Dataset());
}))
......@@ -53,7 +51,7 @@ void BindDataset(py::module* m) {
.def("set_data_feed_desc", &framework::Dataset::SetDataFeedDesc)
.def("load_into_memory", &framework::Dataset::LoadIntoMemory)
.def("local_shuffle", &framework::Dataset::LocalShuffle)
.def("global_shuffle", &framework::Dataset::GLobalShuffle)
.def("global_shuffle", &framework::Dataset::GlobalShuffle);
}
} // end namespace pybind
......
......@@ -118,12 +118,13 @@ class AsyncExecutor(object):
trainer.set_thread(thread_num)
trainer.set_filelist(filelist)
trainer.set_data_feed(data_feed)
if not is_local:
trainer.set_program_config(self.dist_desc, str(id(program)))
with open("trainer_desc.proto", "w") as fout:
fout.write(trainer._desc())
# define a trainer and a device_worker here
self.executor.run_from_files(program_desc,
trainer._desc(), debug,
str(id(program_desc)))
trainer._desc(), debug)
'''
def run(self,
......
......@@ -78,3 +78,18 @@ class DistMultiTrainer(TrainerDesc):
worker_builder = DeviceWorkerFactory()
device_worker = worker_builder.create_device_worker("Downpour")
device_worker.gen_worker_desc(self.proto_desc, fleet_desc)
def set_program_config(self, fleet_desc, program_id):
for program_config in fleet_desc.trainer_param.program_config:
if program_config.program_id == program_id:
pc = self.proto_desc.downpour_param.program_config.add()
pc.program_id = program_config.program_id
for i in program_config.push_sparse_table_id:
pc.push_sparse_table_id.extend([i])
for i in program_config.push_dense_table_id:
pc.push_dense_table_id.extend([i])
for i in program_config.pull_sparse_table_id:
pc.pull_sparse_table_id.extend([i])
for i in program_config.pull_dense_table_id:
pc.pull_dense_table_id.extend([i])
break