提交 6bf796df 编写于 作者: D dongdaxiang

refine print fetch list

上级 1ec8fab7
...@@ -311,6 +311,7 @@ void DownpourWorker::TrainFilesWithProfiler() { ...@@ -311,6 +311,7 @@ void DownpourWorker::TrainFilesWithProfiler() {
pull_dense_worker_->IncreaseThreadVersion(thread_id_, tid); pull_dense_worker_->IncreaseThreadVersion(thread_id_, tid);
} }
PrintFetchVars();
thread_scope_->DropKids(); thread_scope_->DropKids();
total_inst += cur_batch; total_inst += cur_batch;
++batch_cnt; ++batch_cnt;
...@@ -328,7 +329,6 @@ void DownpourWorker::TrainFilesWithProfiler() { ...@@ -328,7 +329,6 @@ void DownpourWorker::TrainFilesWithProfiler() {
} }
} }
timeline.Start(); timeline.Start();
PrintFetchVars();
} }
} }
...@@ -438,9 +438,9 @@ void DownpourWorker::TrainFiles() { ...@@ -438,9 +438,9 @@ void DownpourWorker::TrainFiles() {
pull_dense_worker_->IncreaseThreadVersion(thread_id_, tid); pull_dense_worker_->IncreaseThreadVersion(thread_id_, tid);
} }
PrintFetchVars();
thread_scope_->DropKids(); thread_scope_->DropKids();
++batch_cnt; ++batch_cnt;
PrintFetchVars();
} }
} }
......
...@@ -102,7 +102,7 @@ void HogwildWorker::TrainFilesWithProfiler() { ...@@ -102,7 +102,7 @@ void HogwildWorker::TrainFilesWithProfiler() {
} }
total_inst += cur_batch; total_inst += cur_batch;
++batch_cnt; ++batch_cnt;
thread_scope_->DropKids(); PrintFetchVars();
if (thread_id_ == 0) { if (thread_id_ == 0) {
if (batch_cnt > 0 && batch_cnt % 100 == 0) { if (batch_cnt > 0 && batch_cnt % 100 == 0) {
for (size_t i = 0; i < ops_.size(); ++i) { for (size_t i = 0; i < ops_.size(); ++i) {
...@@ -114,8 +114,8 @@ void HogwildWorker::TrainFilesWithProfiler() { ...@@ -114,8 +114,8 @@ void HogwildWorker::TrainFilesWithProfiler() {
fprintf(stderr, "%6.2f instances/s\n", total_inst / total_time); fprintf(stderr, "%6.2f instances/s\n", total_inst / total_time);
} }
} }
thread_scope_->DropKids();
timeline.Start(); timeline.Start();
PrintFetchVars();
} }
} }
...@@ -125,15 +125,13 @@ void HogwildWorker::TrainFiles() { ...@@ -125,15 +125,13 @@ void HogwildWorker::TrainFiles() {
// how to accumulate fetched values here // how to accumulate fetched values here
device_reader_->Start(); device_reader_->Start();
int cur_batch; int cur_batch;
int batch_cnt = 0;
while ((cur_batch = device_reader_->Next()) > 0) { while ((cur_batch = device_reader_->Next()) > 0) {
for (auto& op : ops_) { for (auto& op : ops_) {
op->Run(*thread_scope_, place_); op->Run(*thread_scope_, place_);
} }
++batch_cnt;
thread_scope_->DropKids();
PrintFetchVars(); PrintFetchVars();
thread_scope_->DropKids();
} }
} }
...@@ -146,7 +144,7 @@ void HogwildWorker::PrintFetchVars() { ...@@ -146,7 +144,7 @@ void HogwildWorker::PrintFetchVars() {
int fetch_var_num = fetch_config_.fetch_var_names_size(); int fetch_var_num = fetch_config_.fetch_var_names_size();
for (int i = 0; i < fetch_var_num; ++i) { for (int i = 0; i < fetch_var_num; ++i) {
platform::PrintVar(thread_scope_, fetch_config_.fetch_var_names(i), platform::PrintVar(thread_scope_, fetch_config_.fetch_var_names(i),
"None"); fetch_config_.fetch_var_str_format(i));
} }
} }
} }
......
...@@ -38,6 +38,7 @@ void MultiTrainer::Initialize(const TrainerDesc& trainer_desc, ...@@ -38,6 +38,7 @@ void MultiTrainer::Initialize(const TrainerDesc& trainer_desc,
for (int i = 0; i < thread_num_; ++i) { for (int i = 0; i < thread_num_; ++i) {
workers_[i] = DeviceWorkerFactory::CreateDeviceWorker( workers_[i] = DeviceWorkerFactory::CreateDeviceWorker(
trainer_desc.device_worker_name()); trainer_desc.device_worker_name());
workers_[i]->Initialize(trainer_desc);
workers_[i]->SetDeviceIndex(i); workers_[i]->SetDeviceIndex(i);
workers_[i]->SetDataFeed(readers[i]); workers_[i]->SetDataFeed(readers[i]);
} }
......
...@@ -51,7 +51,7 @@ message DownpourWorkerParameter { ...@@ -51,7 +51,7 @@ message DownpourWorkerParameter {
message FetchConfig { message FetchConfig {
enum Method { PRINT = 0; } enum Method { PRINT = 0; }
repeated string fetch_var_names = 1; repeated string fetch_var_names = 1;
optional string fetch_var_str_format = 2; repeated string fetch_var_str_format = 2;
optional int32 print_period = 3 [ default = 100 ]; optional int32 print_period = 3 [ default = 100 ];
optional Method method = 4 [ default = PRINT ]; optional Method method = 4 [ default = PRINT ];
} }
......
...@@ -27,14 +27,12 @@ void print_lod_tensor(const std::string& var_name, ...@@ -27,14 +27,12 @@ void print_lod_tensor(const std::string& var_name,
auto element_num = lod_tensor.numel(); auto element_num = lod_tensor.numel();
std::ostringstream sstream; std::ostringstream sstream;
sstream << "user info: " << print_info << "\t"; sstream << print_info << "\t";
sstream << "var name: " << var_name << "\t"; sstream << var_name << "\t";
sstream << "numel: " << element_num << "\t"; sstream << inspect[0];
sstream << "value: " << inspect[0];
for (int j = 1; j < element_num; ++j) { for (int j = 1; j < element_num; ++j) {
sstream << " " << inspect[j]; sstream << " " << inspect[j];
} }
sstream << "]";
std::cout << sstream.str() << std::endl; std::cout << sstream.str() << std::endl;
} }
......
...@@ -634,6 +634,9 @@ class Executor(object): ...@@ -634,6 +634,9 @@ class Executor(object):
scope = global_scope() scope = global_scope()
if fetch_list is None: if fetch_list is None:
fetch_list = [] fetch_list = []
if fetch_info is None:
fetch_info = []
assert len(fetch_list) == len(fetch_info)
compiled = isinstance(program, compiler.CompiledProgram) compiled = isinstance(program, compiler.CompiledProgram)
if not compiled: if not compiled:
trainer = TrainerFactory().create_trainer(program._fleet_opt) trainer = TrainerFactory().create_trainer(program._fleet_opt)
......
...@@ -37,10 +37,11 @@ class TrainerDesc(object): ...@@ -37,10 +37,11 @@ class TrainerDesc(object):
self.program_ = None self.program_ = None
def set_fetch_var_and_info(self, fetch_vars, fetch_info, print_period): def set_fetch_var_and_info(self, fetch_vars, fetch_info, print_period):
for v in fetch_vars: for i, v in enumerate(fetch_vars):
self.proto_desc.fetch_config.fetch_var_names.extend(v.name) self.proto_desc.fetch_config.fetch_var_names.extend([v.name])
self.proto_desc.fetch_config.fetch_var_str_format = fetch_info self.proto_desc.fetch_config.fetch_var_str_format.extend(
self.proto_desc.print_period = print_period [fetch_info[i]])
self.proto_desc.fetch_config.print_period = print_period
def set_debug(self, debug): def set_debug(self, debug):
self.proto_desc.debug = debug self.proto_desc.debug = debug
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册