Commit 2644b886 authored by D dongdaxiang

add comment for MPI Symmetric role maker

test=develop
Parent ea5851fa
@@ -155,6 +155,7 @@ class DownpourWorker : public HogwildWorker {
  virtual ~DownpourWorker() {}
  virtual void Initialize(const TrainerDesc& desc);
  virtual void TrainFiles();
  virtual void TrainFilesWithProfiler();

 protected:
  std::shared_ptr<paddle::framework::FleetWrapper> fleet_ptr_;
......
@@ -44,6 +44,7 @@ void DistMultiTrainer::Initialize(const TrainerDesc& trainer_desc,
  pull_dense_worker_ = PullDenseWorker::GetInstance();
  pull_dense_worker_->Initialize(trainer_desc);
  VLOG(3) << "initialize pull dense worker";
  SetDebug(trainer_desc.debug());
}

void DistMultiTrainer::InitOtherEnv(const ProgramDesc& main_program) {
......
@@ -70,7 +70,7 @@ void DownpourWorker::Initialize(const TrainerDesc& desc) {
void DownpourWorker::CollectLabelInfo(size_t table_idx) {
  uint64_t table_id = static_cast<uint64_t>(
      param_.program_config(0).pull_sparse_table_id(table_idx));
  TableParameter table;
  for (auto i : param_.sparse_table()) {
@@ -82,16 +82,23 @@ void DownpourWorker::CollectLabelInfo(size_t table_idx) {
  auto& feature = features_[table_id];
  auto& feature_label = feature_labels_[table_id];
  feature_label.resize(feature.size());
  VLOG(3) << "going to get label_var_name " << label_var_name_[table_id];
  Variable* var = thread_scope_->FindVar(label_var_name_[table_id]);
  VLOG(3) << "going to get tensor";
  LoDTensor* tensor = var->GetMutable<LoDTensor>();
  VLOG(3) << "going to get ptr";
  int64_t* label_ptr = tensor->data<int64_t>();
  VLOG(3) << "got label ptr";
  int global_index = 0;
  for (size_t i = 0; i < sparse_key_names_[table_id].size(); ++i) {
    VLOG(3) << "sparse_key_names_[" << i
            << "]: " << sparse_key_names_[table_id][i];
    Variable* fea_var = thread_scope_->FindVar(sparse_key_names_[table_id][i]);
    LoDTensor* tensor = fea_var->GetMutable<LoDTensor>();
    int64_t* ids = tensor->data<int64_t>();
    int fea_idx = 0;
    VLOG(3) << "collecting labels for feature slot " << i;
    // tensor->lod()[0].size() == batch_size + 1
    for (auto lod_idx = 1u; lod_idx < tensor->lod()[0].size(); ++lod_idx) {
      for (; fea_idx < tensor->lod()[0][lod_idx]; ++fea_idx) {
@@ -103,6 +110,7 @@ void DownpourWorker::CollectLabelInfo(size_t table_idx) {
            static_cast<float>(label_ptr[lod_idx - 1]);
      }
    }
    VLOG(3) << "done with feature slot " << i;
  }
  CHECK(global_index == feature.size())
      << "expect fea info size:" << feature.size() << " real:" << global_index;
@@ -110,7 +118,7 @@ void DownpourWorker::CollectLabelInfo(size_t table_idx) {
void DownpourWorker::FillSparseValue(size_t table_idx) {
  uint64_t table_id = static_cast<uint64_t>(
      param_.program_config(0).pull_sparse_table_id(table_idx));
  TableParameter table;
  for (auto i : param_.sparse_table()) {
@@ -152,6 +160,11 @@ void DownpourWorker::FillSparseValue(size_t table_idx) {
  }
}

void DownpourWorker::TrainFilesWithProfiler() {
  VLOG(3) << "Begin to train files with profiler";
  platform::SetNumThreads(1);
}

void DownpourWorker::TrainFiles() {
  VLOG(3) << "Begin to train files";
  platform::SetNumThreads(1);
......
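The trickiest part of the downpour_worker.cc changes is the LoD arithmetic in CollectLabelInfo: tensor->lod()[0] holds batch_size + 1 offsets, and every sparse feature position that falls inside an instance's offset range is paired with that instance's label. Below is a small self-contained Python sketch of that expansion (plain lists stand in for LoDTensor; the function name is illustrative, not a Paddle API).

# Toy illustration of the LoD-offset loops in CollectLabelInfo (not Paddle code).
def expand_labels(lod, labels):
    """Pair every sparse feature position with the label of the instance that owns it."""
    feature_labels = []
    fea_idx = 0
    for lod_idx in range(1, len(lod)):    # lod has batch_size + 1 offsets
        while fea_idx < lod[lod_idx]:     # positions owned by instance lod_idx - 1
            feature_labels.append(float(labels[lod_idx - 1]))
            fea_idx += 1
    return feature_labels

# Example: a batch of 3 instances holding 2, 1 and 3 sparse ids respectively.
lod = [0, 2, 3, 6]
labels = [1, 0, 1]
print(expand_labels(lod, labels))  # [1.0, 1.0, 0.0, 1.0, 1.0, 1.0]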
@@ -41,6 +41,7 @@ void MultiTrainer::Initialize(const TrainerDesc& trainer_desc,
  }
  // set debug here
  SetDebug(trainer_desc.debug());
}

// call only after all resources are set in current trainer
@@ -57,8 +58,13 @@ void MultiTrainer::InitTrainerEnv(const ProgramDesc& main_program,
void MultiTrainer::Run() {
  VLOG(3) << "Going to run";
  for (int thidx = 0; thidx < thread_num_; ++thidx) {
    if (!debug_) {
      threads_.push_back(
          std::thread(&DeviceWorker::TrainFiles, workers_[thidx].get()));
    } else {
      threads_.push_back(std::thread(&DeviceWorker::TrainFilesWithProfiler,
                                     workers_[thidx].get()));
    }
  }
}
......
@@ -30,6 +30,7 @@ message TrainerDesc {
  repeated string filelist = 5;
  repeated string fetch_var_names = 6;
  optional int32 batch_per_print = 7 [ default = 100 ];
  optional bool debug = 8 [ default = false ];
  // device worker parameters
  optional HogwildWorkerParameter hogwild_param = 101;
......
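For context, the new debug field above is what feeds the profiler switch earlier in the diff: the trainers store it via SetDebug(trainer_desc.debug()), and MultiTrainer::Run launches TrainFilesWithProfiler instead of TrainFiles when it is set. A toy Python sketch of that dispatch pattern (illustrative classes only, not Paddle's API):

# Toy sketch of the debug-driven worker dispatch added in this commit (not Paddle code).
import threading

class ToyWorker:
    def train_files(self):
        print("normal training loop")

    def train_files_with_profiler(self):
        print("training loop with per-op timing")

class ToyTrainer:
    def __init__(self, debug, num_threads=2):
        self.debug_ = debug  # mirrors SetDebug(trainer_desc.debug())
        self.workers_ = [ToyWorker() for _ in range(num_threads)]
        self.threads_ = []

    def run(self):
        for w in self.workers_:
            # pick the profiling variant of the training loop in debug mode
            target = w.train_files_with_profiler if self.debug_ else w.train_files
            t = threading.Thread(target=target)
            self.threads_.append(t)
            t.start()
        for t in self.threads_:
            t.join()

ToyTrainer(debug=True).run()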
@@ -28,6 +28,7 @@ class Fleet(object):
    def __init__(self):
        self._opt_info = None  # for fleet only
        self.role_maker_ = None
        self.local_ip_ = 0

    def init(self):
        # TODO(guru4elephant)
@@ -57,9 +58,12 @@ class Fleet(object):
            self._fleet_ptr.init_server(self._dist_desc_str,
                                        self.role_maker_.get_rank())
            self.local_ip_ = self._fleet_ptr.run_server()
            self.role_maker_.barrier_all()
            self.all_ips_ = self.role_maker_.all_gather(self.local_ip_)
            self._fleet_ptr.gather_servers(self.all_ips_,
                                           self.role_maker_.get_size())
            # wait all workers start
            self.role_maker_.barrier_all()
        else:
            print("You should run DistributedOptimizer.minimize() first")
@@ -74,10 +78,12 @@ class Fleet(object):
            else:
                print("You should run DistributedOptimizer.minimize() first")
                sys.exit(-1)
            self.role_maker_.barrier_all()  # wait for server starts
            self.all_ips_ = self.role_maker_.all_gather(self.local_ip_)
            self._fleet_ptr.init_worker(self._dist_desc_str, self.all_ips_,
                                        self.role_maker_.get_size(),
                                        self.role_maker_.get_rank())
            self.role_maker_.barrier_all()
            self.role_maker_.barrier_worker()
        else:
            print("You should run DistributedOptimizer.minimize() first")
......
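The fleet.py changes implement a start-up rendezvous on top of the MPI symmetric role maker mentioned in the commit message: servers publish the address returned by run_server(), every node collects the full list through all_gather, and workers block on barrier_all() until the servers are registered before calling init_worker with the gathered endpoints. Below is a self-contained toy emulation of that barrier/all-gather pattern using Python threads in place of MPI ranks (all names are illustrative, not Fleet's API).

# Toy emulation of the barrier_all()/all_gather() rendezvous used in fleet.py
# (threads stand in for MPI ranks; not Paddle/Fleet code).
import threading

NRANKS = 4
barrier = threading.Barrier(NRANKS)  # plays the role of role_maker.barrier_all()
gathered = [None] * NRANKS
lock = threading.Lock()

def all_gather(rank, value):
    """Every rank contributes a value, then all ranks see the full list."""
    with lock:
        gathered[rank] = value
    barrier.wait()  # wait until every rank has contributed
    return list(gathered)

def node(rank):
    is_server = rank < 2  # pretend the first two ranks are servers
    local_ip = "10.0.0.%d:8000" % rank if is_server else 0
    barrier.wait()        # barrier_all(): workers wait for servers to start
    ips = all_gather(rank, local_ip)  # everyone learns the server endpoints
    if not is_server:
        servers = [ip for ip in ips if ip != 0]
        print("worker %d connects to %s" % (rank, servers))
    barrier.wait()        # barrier_all(): servers wait until workers are up

threads = [threading.Thread(target=node, args=(r,)) for r in range(NRANKS)]
for t in threads:
    t.start()
for t in threads:
    t.join()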