diff --git a/paddle/fluid/framework/device_worker.h b/paddle/fluid/framework/device_worker.h index 310a6e2beb52a7a89f33f2f3fc73bd671b3a448d..9a3c5c51b5b509b3df82f2f87e51aec2dc6f5405 100644 --- a/paddle/fluid/framework/device_worker.h +++ b/paddle/fluid/framework/device_worker.h @@ -96,7 +96,7 @@ class DeviceWorker { virtual void Initialize(const TrainerDesc& desc) = 0; virtual void SetDeviceIndex(int tid) = 0; virtual void TrainFiles() = 0; - virtual void PrintFetchVars(int batch_cnt) = 0; + virtual void PrintFetchVars() = 0; virtual void TrainFilesWithProfiler() = 0; virtual void CreateDeviceResource(const ProgramDesc& main_prog) = 0; // will make this zero copy in the future @@ -111,6 +111,8 @@ class DeviceWorker { Scope* root_scope_; paddle::platform::Place place_; std::shared_ptr device_reader_; + int64_t batch_num_ = 0; + FetchConfig fetch_config_; }; class CPUWorkerBase : public DeviceWorker { @@ -120,7 +122,7 @@ class CPUWorkerBase : public DeviceWorker { virtual void SetDeviceIndex(int tid) { thread_id_ = tid; } virtual void TrainFiles() = 0; virtual void TrainFilesWithProfiler() {} - virtual void PrintFetchVars(int batch_cnt) {} + virtual void PrintFetchVars() {} virtual void CreateDeviceResource(const ProgramDesc& main_prog) {} protected: @@ -134,7 +136,7 @@ class HogwildWorker : public CPUWorkerBase { virtual void Initialize(const TrainerDesc& desc); virtual void TrainFiles(); virtual void TrainFilesWithProfiler(); - virtual void PrintFetchVars(int batch_cnt); + virtual void PrintFetchVars(); virtual void CreateDeviceResource(const ProgramDesc& main_prog); virtual void BindingDataFeedMemory(); @@ -144,9 +146,6 @@ class HogwildWorker : public CPUWorkerBase { std::vector op_names_; std::vector ops_; Scope* thread_scope_; - std::vector fetch_var_names_; - std::vector> fetch_values_; - int batch_cnt_per_print_; }; class DownpourWorker : public HogwildWorker { diff --git a/paddle/fluid/framework/downpour_worker.cc b/paddle/fluid/framework/downpour_worker.cc index 
36282e5be7ebfa4ea3874fca1a773aea95172299..6b2852adc7a08072da00eb4258701f31f36be998 100644 --- a/paddle/fluid/framework/downpour_worker.cc +++ b/paddle/fluid/framework/downpour_worker.cc @@ -58,14 +58,8 @@ void DownpourWorker::Initialize(const TrainerDesc& desc) { skip_ops_[i] = param_.skip_ops(i); } - fetch_var_names_.resize(desc.fetch_var_names_size()); - for (size_t i = 0; i < desc.fetch_var_names_size(); ++i) { - fetch_var_names_[i] = desc.fetch_var_names(i); - } - - batch_cnt_per_print_ = static_cast(desc.batch_per_print()); - skip_ops_.resize(param_.skip_ops_size()); fleet_ptr_ = FleetWrapper::GetInstance(); + fetch_config_ = desc.fetch_config(); } void DownpourWorker::CollectLabelInfo(size_t table_idx) { @@ -334,6 +328,7 @@ void DownpourWorker::TrainFilesWithProfiler() { } } timeline.Start(); + PrintFetchVars(); } } @@ -445,6 +440,7 @@ void DownpourWorker::TrainFiles() { thread_scope_->DropKids(); ++batch_cnt; + PrintFetchVars(); } } diff --git a/paddle/fluid/framework/hogwild_worker.cc b/paddle/fluid/framework/hogwild_worker.cc index 64f2e75a20a3912960e92cfdba5c0fa4d82d2367..d4e3d24921907c7e6dc9bdced322eeb9e37cef94 100644 --- a/paddle/fluid/framework/hogwild_worker.cc +++ b/paddle/fluid/framework/hogwild_worker.cc @@ -21,11 +21,7 @@ namespace paddle { namespace framework { void HogwildWorker::Initialize(const TrainerDesc& desc) { - fetch_var_names_.resize(desc.fetch_var_names_size()); - for (size_t i = 0; i < desc.fetch_var_names_size(); ++i) { - fetch_var_names_[i] = desc.fetch_var_names(i); - } - batch_cnt_per_print_ = static_cast(desc.batch_per_print()); + fetch_config_ = desc.fetch_config(); } void HogwildWorker::CreateThreadOperators(const ProgramDesc& program) { @@ -119,6 +115,7 @@ void HogwildWorker::TrainFilesWithProfiler() { } } timeline.Start(); + PrintFetchVars(); } } @@ -136,15 +133,20 @@ void HogwildWorker::TrainFiles() { ++batch_cnt; thread_scope_->DropKids(); + PrintFetchVars(); } } -void HogwildWorker::PrintFetchVars(int batch_cnt) { +void 
HogwildWorker::PrintFetchVars() { + // call count + batch_num_++; + int batch_per_print = fetch_config_.print_period(); if (thread_id_ == 0) { - if (batch_cnt > 0 && batch_cnt % batch_cnt_per_print_ == 0) { - int fetch_var_num = fetch_var_names_.size(); + if (batch_per_print > 0 && batch_num_ % batch_per_print == 0) { + int fetch_var_num = fetch_config_.fetch_var_names_size(); for (int i = 0; i < fetch_var_num; ++i) { - platform::PrintVar(thread_scope_, fetch_var_names_[i], "None"); + platform::PrintVar(thread_scope_, fetch_config_.fetch_var_names(i), + "None"); } } } diff --git a/paddle/fluid/framework/trainer_desc.proto b/paddle/fluid/framework/trainer_desc.proto index f422d226ca525f7b50e00f85928b212384aa152e..4941ea0f8fc3f2ec7bb6f832353bf0ca397a6eb8 100644 --- a/paddle/fluid/framework/trainer_desc.proto +++ b/paddle/fluid/framework/trainer_desc.proto @@ -28,9 +28,8 @@ message TrainerDesc { // if we need to binding cpu optional bool binding_cpu = 4 [ default = false ]; repeated string filelist = 5; - repeated string fetch_var_names = 6; - optional int32 batch_per_print = 7 [ default = 100 ]; - optional bool debug = 8 [ default = false ]; + optional bool debug = 6 [ default = false ]; + optional FetchConfig fetch_config = 7; // device worker parameters optional HogwildWorkerParameter hogwild_param = 101; @@ -49,6 +48,14 @@ message DownpourWorkerParameter { repeated ProgramConfig program_config = 4; } +message FetchConfig { + enum Method { PRINT = 0; } + repeated string fetch_var_names = 1; + optional string fetch_var_str_format = 2; + optional int32 print_period = 3 [ default = 100 ]; + optional Method method = 4 [ default = PRINT ]; +} + message ProgramConfig { required string program_id = 1; repeated int32 push_sparse_table_id = 2; diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index bf4edf5be20d8d0bc35ddeee17c46e1d6fd625b2..ed3907e5a00a1df3b2e82649e2a5848fafc065fa 100644 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ 
-621,13 +621,16 @@ class Executor(object): opt_info=None): pass + def train_from_dataset(self, program=None, dataset=None, - fetch_list=None, scope=None, thread=0, - debug=False): + debug=False, + fetch_list=None, + fetch_info=None, + print_period=100): if scope is None: scope = global_scope() if fetch_list is None: @@ -650,6 +653,7 @@ class Executor(object): else: trainer.set_thread(thread) trainer.set_debug(debug) + trainer.set_fetch_var_and_info(fetch_list, fetch_info, print_period) trainer.gen_trainer_desc() dataset._prepare_to_run() if debug: diff --git a/python/paddle/fluid/trainer_desc.py b/python/paddle/fluid/trainer_desc.py index 8bc739707b35b7a6d238ffe085456d3487ff279d..97d3298fa12ed13384ab98db7d5fd2689f69d70f 100644 --- a/python/paddle/fluid/trainer_desc.py +++ b/python/paddle/fluid/trainer_desc.py @@ -36,6 +36,13 @@ class TrainerDesc(object): self.device_worker_ = None self.program_ = None + def set_fetch_var_and_info(self, fetch_vars, fetch_info, print_period): + for v in fetch_vars: + self.proto_desc.fetch_config.fetch_var_names.append(v.name) + if fetch_info is not None: + self.proto_desc.fetch_config.fetch_var_str_format = fetch_info + self.proto_desc.fetch_config.print_period = print_period + def set_debug(self, debug): self.proto_desc.debug = debug