提交 68d7bf3d 编写于 作者: D dongdaxiang

add fetch var function

test=develop
上级 767bf0c8
......@@ -96,7 +96,7 @@ class DeviceWorker {
virtual void Initialize(const TrainerDesc& desc) = 0;
virtual void SetDeviceIndex(int tid) = 0;
virtual void TrainFiles() = 0;
virtual void PrintFetchVars(int batch_cnt) = 0;
virtual void PrintFetchVars() = 0;
virtual void TrainFilesWithProfiler() = 0;
virtual void CreateDeviceResource(const ProgramDesc& main_prog) = 0;
// will make this zero copy in the future
......@@ -111,6 +111,8 @@ class DeviceWorker {
Scope* root_scope_;
paddle::platform::Place place_;
std::shared_ptr<DataFeed> device_reader_;
int64_t batch_num_;
FetchConfig fetch_config_;
};
class CPUWorkerBase : public DeviceWorker {
......@@ -120,7 +122,7 @@ class CPUWorkerBase : public DeviceWorker {
virtual void SetDeviceIndex(int tid) { thread_id_ = tid; }
virtual void TrainFiles() = 0;
virtual void TrainFilesWithProfiler() {}
virtual void PrintFetchVars(int batch_cnt) {}
virtual void PrintFetchVars() {}
virtual void CreateDeviceResource(const ProgramDesc& main_prog) {}
protected:
......@@ -134,7 +136,7 @@ class HogwildWorker : public CPUWorkerBase {
virtual void Initialize(const TrainerDesc& desc);
virtual void TrainFiles();
virtual void TrainFilesWithProfiler();
virtual void PrintFetchVars(int batch_cnt);
virtual void PrintFetchVars();
virtual void CreateDeviceResource(const ProgramDesc& main_prog);
virtual void BindingDataFeedMemory();
......@@ -144,9 +146,6 @@ class HogwildWorker : public CPUWorkerBase {
std::vector<std::string> op_names_;
std::vector<OperatorBase*> ops_;
Scope* thread_scope_;
std::vector<std::string> fetch_var_names_;
std::vector<std::vector<float>> fetch_values_;
int batch_cnt_per_print_;
};
class DownpourWorker : public HogwildWorker {
......
......@@ -58,14 +58,8 @@ void DownpourWorker::Initialize(const TrainerDesc& desc) {
skip_ops_[i] = param_.skip_ops(i);
}
fetch_var_names_.resize(desc.fetch_var_names_size());
for (size_t i = 0; i < desc.fetch_var_names_size(); ++i) {
fetch_var_names_[i] = desc.fetch_var_names(i);
}
batch_cnt_per_print_ = static_cast<int>(desc.batch_per_print());
skip_ops_.resize(param_.skip_ops_size());
fleet_ptr_ = FleetWrapper::GetInstance();
fetch_config_ = desc.fetch_config();
}
void DownpourWorker::CollectLabelInfo(size_t table_idx) {
......@@ -334,6 +328,7 @@ void DownpourWorker::TrainFilesWithProfiler() {
}
}
timeline.Start();
PrintFetchVars();
}
}
......@@ -445,6 +440,7 @@ void DownpourWorker::TrainFiles() {
thread_scope_->DropKids();
++batch_cnt;
PrintFetchVars();
}
}
......
......@@ -21,11 +21,7 @@ namespace paddle {
namespace framework {
void HogwildWorker::Initialize(const TrainerDesc& desc) {
fetch_var_names_.resize(desc.fetch_var_names_size());
for (size_t i = 0; i < desc.fetch_var_names_size(); ++i) {
fetch_var_names_[i] = desc.fetch_var_names(i);
}
batch_cnt_per_print_ = static_cast<int>(desc.batch_per_print());
fetch_config_ = desc.fetch_config();
}
void HogwildWorker::CreateThreadOperators(const ProgramDesc& program) {
......@@ -119,6 +115,7 @@ void HogwildWorker::TrainFilesWithProfiler() {
}
}
timeline.Start();
PrintFetchVars();
}
}
......@@ -136,15 +133,20 @@ void HogwildWorker::TrainFiles() {
++batch_cnt;
thread_scope_->DropKids();
PrintFetchVars();
}
}
void HogwildWorker::PrintFetchVars(int batch_cnt) {
void HogwildWorker::PrintFetchVars() {
// call count
batch_num_++;
int batch_per_print = fetch_config_.print_period();
if (thread_id_ == 0) {
if (batch_cnt > 0 && batch_cnt % batch_cnt_per_print_ == 0) {
int fetch_var_num = fetch_var_names_.size();
if (batch_num_ % batch_per_print == 0) {
int fetch_var_num = fetch_config_.fetch_var_names_size();
for (int i = 0; i < fetch_var_num; ++i) {
platform::PrintVar(thread_scope_, fetch_var_names_[i], "None");
platform::PrintVar(thread_scope_, fetch_config_.fetch_var_names(i),
"None");
}
}
}
......
......@@ -28,9 +28,8 @@ message TrainerDesc {
// if we need to binding cpu
optional bool binding_cpu = 4 [ default = false ];
repeated string filelist = 5;
repeated string fetch_var_names = 6;
optional int32 batch_per_print = 7 [ default = 100 ];
optional bool debug = 8 [ default = false ];
optional bool debug = 6 [ default = false ];
optional FetchConfig fetch_config = 7;
// device worker parameters
optional HogwildWorkerParameter hogwild_param = 101;
......@@ -49,6 +48,14 @@ message DownpourWorkerParameter {
repeated ProgramConfig program_config = 4;
}
// Configuration for periodically fetching variable values from the training
// scope and reporting them (consumed by DeviceWorker::PrintFetchVars).
message FetchConfig {
// How fetched variables are reported; PRINT writes them via PrintVar.
enum Method { PRINT = 0; }
// Names of the variables to fetch from the thread scope.
repeated string fetch_var_names = 1;
// Format string used when printing the fetched variables.
optional string fetch_var_str_format = 2;
// Report once every `print_period` batches.
optional int32 print_period = 3 [ default = 100 ];
optional Method method = 4 [ default = PRINT ];
}
message ProgramConfig {
required string program_id = 1;
repeated int32 push_sparse_table_id = 2;
......
......@@ -621,13 +621,17 @@ class Executor(object):
opt_info=None):
pass
fluid.Logger("Loss: {0}", loss)
def train_from_dataset(self,
program=None,
dataset=None,
fetch_list=None,
scope=None,
thread=0,
debug=False):
debug=False,
fetch_list=None,
fetch_info=None,
print_period=100):
if scope is None:
scope = global_scope()
if fetch_list is None:
......@@ -650,6 +654,7 @@ class Executor(object):
else:
trainer.set_thread(thread)
trainer.set_debug(debug)
trainer.set_fetch_var_and_info(fetch_list, fetch_info, print_period)
trainer.gen_trainer_desc()
dataset._prepare_to_run()
if debug:
......
......@@ -36,6 +36,12 @@ class TrainerDesc(object):
self.device_worker_ = None
self.program_ = None
def set_fetch_var_and_info(self, fetch_vars, fetch_info, print_period):
    """Record which variables to fetch and print during training.

    Args:
        fetch_vars: list of Variables whose values should be printed
            periodically (may be None for "fetch nothing").
        fetch_info: label/format text for the fetched variables.
            NOTE(review): the proto field fetch_var_str_format is an
            optional *string*, while the executor appears to pass a list
            -- confirm the expected type with the caller.
        print_period: print once every `print_period` batches.
    """
    if fetch_vars is None:
        fetch_vars = []
    for var in fetch_vars:
        # append, not extend: extend() over a str would add one repeated
        # entry per *character* of the variable name.
        self.proto_desc.fetch_config.fetch_var_names.append(var.name)
    if fetch_info is not None:
        self.proto_desc.fetch_config.fetch_var_str_format = fetch_info
    # print_period lives inside fetch_config; TrainerDesc itself no longer
    # has a batch_per_print/print_period field after this change.
    self.proto_desc.fetch_config.print_period = print_period
def set_debug(self, debug):
    """Enable or disable debug mode on the underlying trainer proto."""
    self.proto_desc.debug = debug
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册