Commit 6bf796df authored by dongdaxiang

refine print fetch list

Parent 1ec8fab7
......@@ -311,6 +311,7 @@ void DownpourWorker::TrainFilesWithProfiler() {
pull_dense_worker_->IncreaseThreadVersion(thread_id_, tid);
}
PrintFetchVars();
thread_scope_->DropKids();
total_inst += cur_batch;
++batch_cnt;
......@@ -328,7 +329,6 @@ void DownpourWorker::TrainFilesWithProfiler() {
}
}
timeline.Start();
PrintFetchVars();
}
}
......@@ -438,9 +438,9 @@ void DownpourWorker::TrainFiles() {
pull_dense_worker_->IncreaseThreadVersion(thread_id_, tid);
}
PrintFetchVars();
thread_scope_->DropKids();
++batch_cnt;
PrintFetchVars();
}
}
......
......@@ -102,7 +102,7 @@ void HogwildWorker::TrainFilesWithProfiler() {
}
total_inst += cur_batch;
++batch_cnt;
thread_scope_->DropKids();
PrintFetchVars();
if (thread_id_ == 0) {
if (batch_cnt > 0 && batch_cnt % 100 == 0) {
for (size_t i = 0; i < ops_.size(); ++i) {
......@@ -114,8 +114,8 @@ void HogwildWorker::TrainFilesWithProfiler() {
fprintf(stderr, "%6.2f instances/s\n", total_inst / total_time);
}
}
thread_scope_->DropKids();
timeline.Start();
PrintFetchVars();
}
}
......@@ -125,15 +125,13 @@ void HogwildWorker::TrainFiles() {
// how to accumulate fetched values here
device_reader_->Start();
int cur_batch;
int batch_cnt = 0;
while ((cur_batch = device_reader_->Next()) > 0) {
for (auto& op : ops_) {
op->Run(*thread_scope_, place_);
}
++batch_cnt;
thread_scope_->DropKids();
PrintFetchVars();
thread_scope_->DropKids();
}
}
......@@ -146,7 +144,7 @@ void HogwildWorker::PrintFetchVars() {
int fetch_var_num = fetch_config_.fetch_var_names_size();
for (int i = 0; i < fetch_var_num; ++i) {
platform::PrintVar(thread_scope_, fetch_config_.fetch_var_names(i),
"None");
fetch_config_.fetch_var_str_format(i));
}
}
}
......
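In both DownpourWorker and HogwildWorker the change is the same reorder: PrintFetchVars() now runs right after the batch's ops and before thread_scope_->DropKids(), instead of at the end of the loop body. If any fetched value lives in a per-batch child scope, printing after DropKids would read an already-released variable. A toy Python sketch of that hazard, using stand-in classes rather than Paddle's Scope API:

```python
class Scope:
    """Toy stand-in for a scope that owns per-batch child scopes."""

    def __init__(self):
        self.vars = {}   # variables stored directly in this scope
        self.kids = []   # child scopes created while running a batch

    def new_child(self):
        child = Scope()
        self.kids.append(child)
        return child

    def drop_kids(self):
        # Anything stored only in a child scope is unreachable after this.
        self.kids.clear()


def print_fetch_vars(scope, names):
    for child in scope.kids:
        for name in names:
            if name in child.vars:
                print(name, child.vars[name])


thread_scope = Scope()
batch_scope = thread_scope.new_child()
batch_scope.vars["loss"] = 0.42           # produced while running one batch

print_fetch_vars(thread_scope, ["loss"])  # must come first ...
thread_scope.drop_kids()                  # ... because this releases "loss"
```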
......@@ -38,6 +38,7 @@ void MultiTrainer::Initialize(const TrainerDesc& trainer_desc,
for (int i = 0; i < thread_num_; ++i) {
workers_[i] = DeviceWorkerFactory::CreateDeviceWorker(
trainer_desc.device_worker_name());
workers_[i]->Initialize(trainer_desc);
workers_[i]->SetDeviceIndex(i);
workers_[i]->SetDataFeed(readers[i]);
}
......
......@@ -51,7 +51,7 @@ message DownpourWorkerParameter {
message FetchConfig {
enum Method { PRINT = 0; }
repeated string fetch_var_names = 1;
optional string fetch_var_str_format = 2;
repeated string fetch_var_str_format = 2;
optional int32 print_period = 3 [ default = 100 ];
optional Method method = 4 [ default = PRINT ];
}
......
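The FetchConfig change turns fetch_var_str_format from a single optional string into a repeated field, so each entry of fetch_var_names gets its own display label, printed every print_period batches. A sketch of filling the message from Python; the import path of the generated module is an assumption, not something shown in this diff:

```python
# Assumed location of the generated protobuf module; adjust to your build.
from paddle.fluid.proto import trainer_desc_pb2

fetch_config = trainer_desc_pb2.FetchConfig()
fetch_config.method = trainer_desc_pb2.FetchConfig.PRINT
fetch_config.print_period = 100

# The two repeated fields are parallel: fetch_var_str_format[i] is the
# label printed next to the values of fetch_var_names[i].
for name, label in [("loss", "train loss"), ("acc", "train acc")]:
    fetch_config.fetch_var_names.append(name)
    fetch_config.fetch_var_str_format.append(label)
```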
......@@ -27,14 +27,12 @@ void print_lod_tensor(const std::string& var_name,
auto element_num = lod_tensor.numel();
std::ostringstream sstream;
sstream << "user info: " << print_info << "\t";
sstream << "var name: " << var_name << "\t";
sstream << "numel: " << element_num << "\t";
sstream << "value: " << inspect[0];
sstream << print_info << "\t";
sstream << var_name << "\t";
sstream << inspect[0];
for (int j = 1; j < element_num; ++j) {
sstream << " " << inspect[j];
}
sstream << "]";
std::cout << sstream.str() << std::endl;
}
......
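print_lod_tensor now emits one tab-separated line per variable: the caller-supplied label, the variable name, then the flattened values separated by spaces; the old "user info:"/"var name:"/"numel:" prefixes and the unmatched closing bracket are dropped. The equivalent formatting in Python, purely as a reference for the new output shape:

```python
def format_lod_tensor(print_info, var_name, values):
    # <print_info> \t <var_name> \t v0 v1 v2 ...
    return "{}\t{}\t{}".format(
        print_info, var_name, " ".join(str(v) for v in values))

print(format_lod_tensor("train loss", "loss", [0.42, 0.40, 0.39]))
```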
......@@ -634,6 +634,9 @@ class Executor(object):
scope = global_scope()
if fetch_list is None:
fetch_list = []
if fetch_info is None:
fetch_info = []
assert len(fetch_list) == len(fetch_info)
compiled = isinstance(program, compiler.CompiledProgram)
if not compiled:
trainer = TrainerFactory().create_trainer(program._fleet_opt)
......
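On the Python side the executor now defaults fetch_list and fetch_info to empty lists and asserts they have equal length before the trainer is created. A hedged usage sketch; the enclosing method is not shown in this hunk, so the train_from_dataset call below is an assumption about the entry point that consumes these arguments, and the dataset setup is omitted:

```python
import paddle.fluid as fluid

x = fluid.layers.data(name="x", shape=[13], dtype="float32")
y = fluid.layers.data(name="y", shape=[1], dtype="float32")
pred = fluid.layers.fc(input=x, size=1)
loss = fluid.layers.reduce_mean(fluid.layers.square_error_cost(pred, y))
fluid.optimizer.SGD(learning_rate=0.01).minimize(loss)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())

# `dataset` is assumed to be prepared elsewhere (e.g. via
# fluid.DatasetFactory()); building it is outside this sketch.
exe.train_from_dataset(
    program=fluid.default_main_program(),
    dataset=dataset,
    fetch_list=[loss],                # variables to print
    fetch_info=["train loss"],        # one label per fetched variable
    print_period=100)                 # print every 100 batches
```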
......@@ -37,10 +37,11 @@ class TrainerDesc(object):
self.program_ = None
def set_fetch_var_and_info(self, fetch_vars, fetch_info, print_period):
for v in fetch_vars:
self.proto_desc.fetch_config.fetch_var_names.extend(v.name)
self.proto_desc.fetch_config.fetch_var_str_format = fetch_info
self.proto_desc.print_period = print_period
for i, v in enumerate(fetch_vars):
self.proto_desc.fetch_config.fetch_var_names.extend([v.name])
self.proto_desc.fetch_config.fetch_var_str_format.extend(
[fetch_info[i]])
self.proto_desc.fetch_config.print_period = print_period
def set_debug(self, debug):
self.proto_desc.debug = debug
......
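The rewritten set_fetch_var_and_info extends fetch_var_names with a one-element list (extend(v.name) would have spread the name into single characters), fills the new repeated fetch_var_str_format field with the matching fetch_info entry, and writes print_period onto fetch_config instead of the top-level descriptor. The string-spreading pitfall, shown in plain Python:

```python
names = []
names.extend("loss")     # old pattern: a string is iterable, so this
print(names)             # appends its characters -> ['l', 'o', 's', 's']

names = []
names.extend(["loss"])   # fixed pattern: append the whole name
print(names)             # -> ['loss']
```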