From 6bf796df14ef6194c16d321f9da5401aa6f7cf2c Mon Sep 17 00:00:00 2001 From: dongdaxiang Date: Thu, 21 Mar 2019 16:27:51 +0800 Subject: [PATCH] refine print fetch list --- paddle/fluid/framework/downpour_worker.cc | 4 ++-- paddle/fluid/framework/hogwild_worker.cc | 10 ++++------ paddle/fluid/framework/multi_trainer.cc | 1 + paddle/fluid/framework/trainer_desc.proto | 2 +- paddle/fluid/platform/lodtensor_printer.cc | 8 +++----- python/paddle/fluid/executor.py | 3 +++ python/paddle/fluid/trainer_desc.py | 9 +++++---- 7 files changed, 19 insertions(+), 18 deletions(-) diff --git a/paddle/fluid/framework/downpour_worker.cc b/paddle/fluid/framework/downpour_worker.cc index 6b2852adc7a..e64d0c77d7a 100644 --- a/paddle/fluid/framework/downpour_worker.cc +++ b/paddle/fluid/framework/downpour_worker.cc @@ -311,6 +311,7 @@ void DownpourWorker::TrainFilesWithProfiler() { pull_dense_worker_->IncreaseThreadVersion(thread_id_, tid); } + PrintFetchVars(); thread_scope_->DropKids(); total_inst += cur_batch; ++batch_cnt; @@ -328,7 +329,6 @@ void DownpourWorker::TrainFilesWithProfiler() { } } timeline.Start(); - PrintFetchVars(); } } @@ -438,9 +438,9 @@ void DownpourWorker::TrainFiles() { pull_dense_worker_->IncreaseThreadVersion(thread_id_, tid); } + PrintFetchVars(); thread_scope_->DropKids(); ++batch_cnt; - PrintFetchVars(); } } diff --git a/paddle/fluid/framework/hogwild_worker.cc b/paddle/fluid/framework/hogwild_worker.cc index d4e3d249219..1f5389c9c5e 100644 --- a/paddle/fluid/framework/hogwild_worker.cc +++ b/paddle/fluid/framework/hogwild_worker.cc @@ -102,7 +102,7 @@ void HogwildWorker::TrainFilesWithProfiler() { } total_inst += cur_batch; ++batch_cnt; - thread_scope_->DropKids(); + PrintFetchVars(); if (thread_id_ == 0) { if (batch_cnt > 0 && batch_cnt % 100 == 0) { for (size_t i = 0; i < ops_.size(); ++i) { @@ -114,8 +114,8 @@ void HogwildWorker::TrainFilesWithProfiler() { fprintf(stderr, "%6.2f instances/s\n", total_inst / total_time); } } + thread_scope_->DropKids(); timeline.Start(); - PrintFetchVars(); } } @@ -125,15 +125,13 @@ void HogwildWorker::TrainFiles() { // how to accumulate fetched values here device_reader_->Start(); int cur_batch; - int batch_cnt = 0; while ((cur_batch = device_reader_->Next()) > 0) { for (auto& op : ops_) { op->Run(*thread_scope_, place_); } - ++batch_cnt; - thread_scope_->DropKids(); PrintFetchVars(); + thread_scope_->DropKids(); } } @@ -146,7 +144,7 @@ void HogwildWorker::PrintFetchVars() { int fetch_var_num = fetch_config_.fetch_var_names_size(); for (int i = 0; i < fetch_var_num; ++i) { platform::PrintVar(thread_scope_, fetch_config_.fetch_var_names(i), - "None"); + fetch_config_.fetch_var_str_format(i)); } } } diff --git a/paddle/fluid/framework/multi_trainer.cc b/paddle/fluid/framework/multi_trainer.cc index 7f955e35506..409c2f435f8 100644 --- a/paddle/fluid/framework/multi_trainer.cc +++ b/paddle/fluid/framework/multi_trainer.cc @@ -38,6 +38,7 @@ void MultiTrainer::Initialize(const TrainerDesc& trainer_desc, for (int i = 0; i < thread_num_; ++i) { workers_[i] = DeviceWorkerFactory::CreateDeviceWorker( trainer_desc.device_worker_name()); + workers_[i]->Initialize(trainer_desc); workers_[i]->SetDeviceIndex(i); workers_[i]->SetDataFeed(readers[i]); } diff --git a/paddle/fluid/framework/trainer_desc.proto b/paddle/fluid/framework/trainer_desc.proto index 4941ea0f8fc..6acadfb2da4 100644 --- a/paddle/fluid/framework/trainer_desc.proto +++ b/paddle/fluid/framework/trainer_desc.proto @@ -51,7 +51,7 @@ message DownpourWorkerParameter { message FetchConfig { enum Method { PRINT = 0; } repeated string fetch_var_names = 1; - optional string fetch_var_str_format = 2; + repeated string fetch_var_str_format = 2; optional int32 print_period = 3 [ default = 100 ]; optional Method method = 4 [ default = PRINT ]; } diff --git a/paddle/fluid/platform/lodtensor_printer.cc b/paddle/fluid/platform/lodtensor_printer.cc index 5bfbcdeecfb..b9ab19a154f 100644 --- a/paddle/fluid/platform/lodtensor_printer.cc +++ b/paddle/fluid/platform/lodtensor_printer.cc @@ -27,14 +27,12 @@ void print_lod_tensor(const std::string& var_name, auto element_num = lod_tensor.numel(); std::ostringstream sstream; - sstream << "user info: " << print_info << "\t"; - sstream << "var name: " << var_name << "\t"; - sstream << "numel: " << element_num << "\t"; - sstream << "value: " << inspect[0]; + sstream << print_info << "\t"; + sstream << var_name << "\t"; + sstream << inspect[0]; for (int j = 1; j < element_num; ++j) { sstream << " " << inspect[j]; } - sstream << "]"; std::cout << sstream.str() << std::endl; } diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index 8c3f947b6bd..d7e125f4847 100644 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -634,6 +634,9 @@ class Executor(object): scope = global_scope() if fetch_list is None: fetch_list = [] + if fetch_info is None: + fetch_info = [] + assert len(fetch_list) == len(fetch_info) compiled = isinstance(program, compiler.CompiledProgram) if not compiled: trainer = TrainerFactory().create_trainer(program._fleet_opt) diff --git a/python/paddle/fluid/trainer_desc.py b/python/paddle/fluid/trainer_desc.py index 97d3298fa12..4d61a09fb92 100644 --- a/python/paddle/fluid/trainer_desc.py +++ b/python/paddle/fluid/trainer_desc.py @@ -37,10 +37,11 @@ class TrainerDesc(object): self.program_ = None def set_fetch_var_and_info(self, fetch_vars, fetch_info, print_period): - for v in fetch_vars: - self.proto_desc.fetch_config.fetch_var_names.extend(v.name) - self.proto_desc.fetch_config.fetch_var_str_format = fetch_info - self.proto_desc.print_period = print_period + for i, v in enumerate(fetch_vars): + self.proto_desc.fetch_config.fetch_var_names.extend([v.name]) + self.proto_desc.fetch_config.fetch_var_str_format.extend( + [fetch_info[i]]) + self.proto_desc.fetch_config.print_period = print_period def set_debug(self, debug): self.proto_desc.debug = debug -- GitLab