提交 68d7bf3d 编写于 作者: D dongdaxiang

add fetch var function

test=develop
上级 767bf0c8
...@@ -96,7 +96,7 @@ class DeviceWorker { ...@@ -96,7 +96,7 @@ class DeviceWorker {
virtual void Initialize(const TrainerDesc& desc) = 0; virtual void Initialize(const TrainerDesc& desc) = 0;
virtual void SetDeviceIndex(int tid) = 0; virtual void SetDeviceIndex(int tid) = 0;
virtual void TrainFiles() = 0; virtual void TrainFiles() = 0;
virtual void PrintFetchVars(int batch_cnt) = 0; virtual void PrintFetchVars() = 0;
virtual void TrainFilesWithProfiler() = 0; virtual void TrainFilesWithProfiler() = 0;
virtual void CreateDeviceResource(const ProgramDesc& main_prog) = 0; virtual void CreateDeviceResource(const ProgramDesc& main_prog) = 0;
// will make this zero copy in the future // will make this zero copy in the future
...@@ -111,6 +111,8 @@ class DeviceWorker { ...@@ -111,6 +111,8 @@ class DeviceWorker {
Scope* root_scope_; Scope* root_scope_;
paddle::platform::Place place_; paddle::platform::Place place_;
std::shared_ptr<DataFeed> device_reader_; std::shared_ptr<DataFeed> device_reader_;
int64_t batch_num_;
FetchConfig fetch_config_;
}; };
class CPUWorkerBase : public DeviceWorker { class CPUWorkerBase : public DeviceWorker {
...@@ -120,7 +122,7 @@ class CPUWorkerBase : public DeviceWorker { ...@@ -120,7 +122,7 @@ class CPUWorkerBase : public DeviceWorker {
virtual void SetDeviceIndex(int tid) { thread_id_ = tid; } virtual void SetDeviceIndex(int tid) { thread_id_ = tid; }
virtual void TrainFiles() = 0; virtual void TrainFiles() = 0;
virtual void TrainFilesWithProfiler() {} virtual void TrainFilesWithProfiler() {}
virtual void PrintFetchVars(int batch_cnt) {} virtual void PrintFetchVars() {}
virtual void CreateDeviceResource(const ProgramDesc& main_prog) {} virtual void CreateDeviceResource(const ProgramDesc& main_prog) {}
protected: protected:
...@@ -134,7 +136,7 @@ class HogwildWorker : public CPUWorkerBase { ...@@ -134,7 +136,7 @@ class HogwildWorker : public CPUWorkerBase {
virtual void Initialize(const TrainerDesc& desc); virtual void Initialize(const TrainerDesc& desc);
virtual void TrainFiles(); virtual void TrainFiles();
virtual void TrainFilesWithProfiler(); virtual void TrainFilesWithProfiler();
virtual void PrintFetchVars(int batch_cnt); virtual void PrintFetchVars();
virtual void CreateDeviceResource(const ProgramDesc& main_prog); virtual void CreateDeviceResource(const ProgramDesc& main_prog);
virtual void BindingDataFeedMemory(); virtual void BindingDataFeedMemory();
...@@ -144,9 +146,6 @@ class HogwildWorker : public CPUWorkerBase { ...@@ -144,9 +146,6 @@ class HogwildWorker : public CPUWorkerBase {
std::vector<std::string> op_names_; std::vector<std::string> op_names_;
std::vector<OperatorBase*> ops_; std::vector<OperatorBase*> ops_;
Scope* thread_scope_; Scope* thread_scope_;
std::vector<std::string> fetch_var_names_;
std::vector<std::vector<float>> fetch_values_;
int batch_cnt_per_print_;
}; };
class DownpourWorker : public HogwildWorker { class DownpourWorker : public HogwildWorker {
......
...@@ -58,14 +58,8 @@ void DownpourWorker::Initialize(const TrainerDesc& desc) { ...@@ -58,14 +58,8 @@ void DownpourWorker::Initialize(const TrainerDesc& desc) {
skip_ops_[i] = param_.skip_ops(i); skip_ops_[i] = param_.skip_ops(i);
} }
fetch_var_names_.resize(desc.fetch_var_names_size());
for (size_t i = 0; i < desc.fetch_var_names_size(); ++i) {
fetch_var_names_[i] = desc.fetch_var_names(i);
}
batch_cnt_per_print_ = static_cast<int>(desc.batch_per_print());
skip_ops_.resize(param_.skip_ops_size());
fleet_ptr_ = FleetWrapper::GetInstance(); fleet_ptr_ = FleetWrapper::GetInstance();
fetch_config_ = desc.fetch_config();
} }
void DownpourWorker::CollectLabelInfo(size_t table_idx) { void DownpourWorker::CollectLabelInfo(size_t table_idx) {
...@@ -334,6 +328,7 @@ void DownpourWorker::TrainFilesWithProfiler() { ...@@ -334,6 +328,7 @@ void DownpourWorker::TrainFilesWithProfiler() {
} }
} }
timeline.Start(); timeline.Start();
PrintFetchVars();
} }
} }
...@@ -445,6 +440,7 @@ void DownpourWorker::TrainFiles() { ...@@ -445,6 +440,7 @@ void DownpourWorker::TrainFiles() {
thread_scope_->DropKids(); thread_scope_->DropKids();
++batch_cnt; ++batch_cnt;
PrintFetchVars();
} }
} }
......
...@@ -21,11 +21,7 @@ namespace paddle { ...@@ -21,11 +21,7 @@ namespace paddle {
namespace framework { namespace framework {
void HogwildWorker::Initialize(const TrainerDesc& desc) { void HogwildWorker::Initialize(const TrainerDesc& desc) {
fetch_var_names_.resize(desc.fetch_var_names_size()); fetch_config_ = desc.fetch_config();
for (size_t i = 0; i < desc.fetch_var_names_size(); ++i) {
fetch_var_names_[i] = desc.fetch_var_names(i);
}
batch_cnt_per_print_ = static_cast<int>(desc.batch_per_print());
} }
void HogwildWorker::CreateThreadOperators(const ProgramDesc& program) { void HogwildWorker::CreateThreadOperators(const ProgramDesc& program) {
...@@ -119,6 +115,7 @@ void HogwildWorker::TrainFilesWithProfiler() { ...@@ -119,6 +115,7 @@ void HogwildWorker::TrainFilesWithProfiler() {
} }
} }
timeline.Start(); timeline.Start();
PrintFetchVars();
} }
} }
...@@ -136,15 +133,20 @@ void HogwildWorker::TrainFiles() { ...@@ -136,15 +133,20 @@ void HogwildWorker::TrainFiles() {
++batch_cnt; ++batch_cnt;
thread_scope_->DropKids(); thread_scope_->DropKids();
PrintFetchVars();
} }
} }
void HogwildWorker::PrintFetchVars(int batch_cnt) { void HogwildWorker::PrintFetchVars() {
// call count
batch_num_++;
int batch_per_print = fetch_config_.print_period();
if (thread_id_ == 0) { if (thread_id_ == 0) {
if (batch_cnt > 0 && batch_cnt % batch_cnt_per_print_ == 0) { if (batch_num_ % batch_per_print == 0) {
int fetch_var_num = fetch_var_names_.size(); int fetch_var_num = fetch_config_.fetch_var_names_size();
for (int i = 0; i < fetch_var_num; ++i) { for (int i = 0; i < fetch_var_num; ++i) {
platform::PrintVar(thread_scope_, fetch_var_names_[i], "None"); platform::PrintVar(thread_scope_, fetch_config_.fetch_var_names(i),
"None");
} }
} }
} }
......
...@@ -28,9 +28,8 @@ message TrainerDesc { ...@@ -28,9 +28,8 @@ message TrainerDesc {
// if we need to binding cpu // if we need to binding cpu
optional bool binding_cpu = 4 [ default = false ]; optional bool binding_cpu = 4 [ default = false ];
repeated string filelist = 5; repeated string filelist = 5;
repeated string fetch_var_names = 6; optional bool debug = 6 [ default = false ];
optional int32 batch_per_print = 7 [ default = 100 ]; optional FetchConfig fetch_config = 7;
optional bool debug = 8 [ default = false ];
// device worker parameters // device worker parameters
optional HogwildWorkerParameter hogwild_param = 101; optional HogwildWorkerParameter hogwild_param = 101;
...@@ -49,6 +48,14 @@ message DownpourWorkerParameter { ...@@ -49,6 +48,14 @@ message DownpourWorkerParameter {
repeated ProgramConfig program_config = 4; repeated ProgramConfig program_config = 4;
} }
// Configuration describing which variables a device worker fetches during
// training and how often they are reported. HogwildWorker::PrintFetchVars
// reads fetch_var_names and print_period from this message.
message FetchConfig {
// How fetched variables are reported; PRINT is the only method defined.
enum Method { PRINT = 0; }
// Names of the variables to fetch; the worker passes each name, together
// with its thread scope, to platform::PrintVar.
repeated string fetch_var_names = 1;
// Format string supplied from the Python side (fetch_info).
// NOTE(review): not read by the worker code visible in this change — confirm
// where it is consumed.
optional string fetch_var_str_format = 2;
// The worker prints when its batch counter is a multiple of this value.
// NOTE(review): it is used as a modulus, so 0 presumably must be avoided —
// confirm whether callers validate it.
optional int32 print_period = 3 [ default = 100 ];
// Selected reporting method; defaults to PRINT.
optional Method method = 4 [ default = PRINT ];
}
message ProgramConfig { message ProgramConfig {
required string program_id = 1; required string program_id = 1;
repeated int32 push_sparse_table_id = 2; repeated int32 push_sparse_table_id = 2;
......
...@@ -621,13 +621,17 @@ class Executor(object): ...@@ -621,13 +621,17 @@ class Executor(object):
opt_info=None): opt_info=None):
pass pass
fluid.Logger("Loss: {0}", loss)
def train_from_dataset(self, def train_from_dataset(self,
program=None, program=None,
dataset=None, dataset=None,
fetch_list=None,
scope=None, scope=None,
thread=0, thread=0,
debug=False): debug=False,
fetch_list=None,
fetch_info=None,
print_period=100):
if scope is None: if scope is None:
scope = global_scope() scope = global_scope()
if fetch_list is None: if fetch_list is None:
...@@ -650,6 +654,7 @@ class Executor(object): ...@@ -650,6 +654,7 @@ class Executor(object):
else: else:
trainer.set_thread(thread) trainer.set_thread(thread)
trainer.set_debug(debug) trainer.set_debug(debug)
trainer.set_fetch_var_and_info(fetch_list, fetch_info, print_period)
trainer.gen_trainer_desc() trainer.gen_trainer_desc()
dataset._prepare_to_run() dataset._prepare_to_run()
if debug: if debug:
......
...@@ -36,6 +36,12 @@ class TrainerDesc(object): ...@@ -36,6 +36,12 @@ class TrainerDesc(object):
self.device_worker_ = None self.device_worker_ = None
self.program_ = None self.program_ = None
def set_fetch_var_and_info(self, fetch_vars, fetch_info, print_period):
    """Record the fetch configuration on the trainer description.

    Args:
        fetch_vars: iterable of variables whose ``name`` attribute is stored
            in ``fetch_config.fetch_var_names``; ``None`` is treated as empty.
        fetch_info: format string stored in
            ``fetch_config.fetch_var_str_format``; left unset when ``None``.
        print_period: value stored in ``fetch_config.print_period``.
    """
    fetch_config = self.proto_desc.fetch_config
    # extend() needs an iterable of names; passing a bare str would add one
    # entry per character instead of one entry per variable.
    fetch_config.fetch_var_names.extend([v.name for v in (fetch_vars or [])])
    # The caller's default for fetch_info is None, which cannot be assigned
    # to a string field, so only set the format when one was given.
    if fetch_info is not None:
        fetch_config.fetch_var_str_format = fetch_info
    # print_period is declared on FetchConfig, not on the trainer desc itself.
    fetch_config.print_period = print_period
def set_debug(self, debug): def set_debug(self, debug):
self.proto_desc.debug = debug self.proto_desc.debug = debug
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册