diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt
index 361f25d09e70006e0dcba31977379a756ba8c96f..95a261be1a70131dd1282c2ce1f027c52a41c5ce 100644
--- a/paddle/fluid/framework/CMakeLists.txt
+++ b/paddle/fluid/framework/CMakeLists.txt
@@ -66,11 +66,9 @@ else()
   cc_test(mixed_vector_test SRCS mixed_vector_test.cc DEPS place memory device_context tensor)
 endif()
 cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto version)
-cc_library(device_worker SRCS device_worker.cc DEPS trainer_desc_proto lod_tensor)
 cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor memory)
 nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor)
-cc_test(device_worker_test SRCS device_worker_test.cc DEPS device_worker)

 cc_library(garbage_collector SRCS garbage_collector.cc DEPS device_context memory gflags glog)

@@ -87,6 +85,8 @@ endif()
 cc_test(var_type_traits_test SRCS var_type_traits_test.cc DEPS var_type_traits)

 cc_library(scope SRCS scope.cc DEPS glog threadpool xxhash var_type_traits)
+cc_library(device_worker SRCS device_worker.cc DEPS trainer_desc_proto lod_tensor scope)
+cc_test(device_worker_test SRCS device_worker_test.cc DEPS device_worker)
 cc_library(scope_pool SRCS scope_pool.cc DEPS scope)
 cc_test(scope_test SRCS scope_test.cc DEPS scope)
diff --git a/paddle/fluid/framework/data_feed.cc b/paddle/fluid/framework/data_feed.cc
index 22c08eba664e5a6fab973248268050300c52b9b9..7ac023f140ecbd209e902ba67dd64bf8f5fef806 100644
--- a/paddle/fluid/framework/data_feed.cc
+++ b/paddle/fluid/framework/data_feed.cc
@@ -881,6 +881,7 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(Record* instance) {
       uint32_t rank;
       GetMsgFromLogKey(log_key, &search_id, &cmatch, &rank);

+      instance->ins_id_ = log_key;
       instance->search_id = search_id;
       instance->cmatch = cmatch;
       instance->rank = rank;
diff --git a/paddle/fluid/framework/device_worker.cc b/paddle/fluid/framework/device_worker.cc
index 6ba596ab1592dddddf2111ba67578bd98a450056..e39ebf8a7d49e7537da723cf946419bca0c1dd9a 100644
--- a/paddle/fluid/framework/device_worker.cc
+++ b/paddle/fluid/framework/device_worker.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/fluid/framework/device_worker.h"
+#include "xxhash.h"  // NOLINT

 namespace paddle {
 namespace framework {
@@ -91,5 +92,109 @@ bool CheckValidOutput(LoDTensor* tensor, size_t batch_size) {
   return true;
 }

+void DeviceWorker::DumpParam(const Scope& scope, const int batch_id) {
+  std::ostringstream os;
+  for (auto& param : *dump_param_) {
+    os.str("");
+    Variable* var = scope.FindVar(param);
+    if (var == nullptr) {
+      continue;
+    }
+    LoDTensor* tensor = var->GetMutable<LoDTensor>();
+    framework::LoDTensor cpu_tensor;
+    if (platform::is_gpu_place(tensor->place())) {
+      TensorCopySync(*tensor, platform::CPUPlace(), &cpu_tensor);
+      tensor = &cpu_tensor;
+    }
+    int64_t len = tensor->numel();
+    os << "(" << batch_id << "," << param << ")"
+       << PrintLodTensor(tensor, 0, len);
+    writer_ << os.str();
+  }
+}
+void DeviceWorker::InitRandomDumpConfig(const TrainerDesc& desc) {
+  bool enable_random_dump = desc.enable_random_dump();
+  if (!enable_random_dump) {
+    dump_mode_ = 0;
+  } else {
+    if (desc.random_with_lineid()) {
+      dump_mode_ = 1;
+    } else {
+      dump_mode_ = 2;
+    }
+  }
+  dump_interval_ = desc.dump_interval();
+}
+
+void DeviceWorker::DumpField(const Scope& scope, int dump_mode,
+                             int dump_interval) {  // dump_mode: 0: no random,
+                                                   // 1: random with insid hash,
+                                                   // 2: random with random
+                                                   // number
+  size_t batch_size = device_reader_->GetCurBatchSize();
+  auto& ins_id_vec = device_reader_->GetInsIdVec();
+  auto& ins_content_vec = device_reader_->GetInsContentVec();
+  if (ins_id_vec.size() > 0) {
+    batch_size = ins_id_vec.size();
+  }
+  std::vector<std::string> ars(batch_size);
+  std::vector<bool> hit(batch_size, false);
+
+  std::default_random_engine engine(0);
+  std::uniform_int_distribution<size_t> dist(0U, INT_MAX);
+  for (size_t i = 0; i < batch_size; i++) {
+    size_t r = 0;
+    if (dump_mode == 1) {
+      r = XXH64(ins_id_vec[i].data(), ins_id_vec[i].length(), 0);
+    } else if (dump_mode == 2) {
+      r = dist(engine);
+    }
+    if (r % dump_interval != 0) {
+      continue;
+    }
+    hit[i] = true;
+  }
+  for (size_t i = 0; i < ins_id_vec.size(); i++) {
+    if (!hit[i]) {
+      continue;
+    }
+    ars[i] += ins_id_vec[i];
+    ars[i] = ars[i] + "\t" + ins_content_vec[i];
+  }
+  for (auto& field : *dump_fields_) {
+    Variable* var = scope.FindVar(field);
+    if (var == nullptr) {
+      continue;
+    }
+    LoDTensor* tensor = var->GetMutable<LoDTensor>();
+    framework::LoDTensor cpu_tensor;
+    if (platform::is_gpu_place(tensor->place())) {
+      TensorCopySync(*tensor, platform::CPUPlace(), &cpu_tensor);
+      tensor = &cpu_tensor;
+    }
+    if (!CheckValidOutput(tensor, batch_size)) {
+      continue;
+    }
+    for (size_t i = 0; i < batch_size; ++i) {
+      if (!hit[i]) {
+        continue;
+      }
+      auto output_dim = tensor->dims()[1];
+      std::string output_dimstr = boost::lexical_cast<std::string>(output_dim);
+      ars[i] = ars[i] + "\t" + field + ":" + output_dimstr;
+      auto bound = GetTensorBound(tensor, i);
+      ars[i] += PrintLodTensor(tensor, bound.first, bound.second);
+    }
+  }
+  // #pragma omp parallel for
+  for (size_t i = 0; i < ars.size(); i++) {
+    if (ars[i].length() == 0) {
+      continue;
+    }
+    writer_ << ars[i];
+  }
+}
+
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/device_worker.h b/paddle/fluid/framework/device_worker.h
index f75d7593fe9a50d385d0e669e5037fbc5d4eea5b..8d50f476eaeee410a608dfdd2ee05b836c10c8a0 100644
--- a/paddle/fluid/framework/device_worker.h
+++ b/paddle/fluid/framework/device_worker.h
@@ -120,6 +120,7 @@ class DeviceWorker {
   }
   virtual ~DeviceWorker() {}
   virtual void Initialize(const TrainerDesc& desc) = 0;
+  virtual void InitRandomDumpConfig(const TrainerDesc& desc);
   virtual void SetDeviceIndex(int tid) = 0;
   virtual void TrainFiles() = 0;
   virtual void PrintFetchVars() = 0;
@@ -129,8 +130,21 @@ class DeviceWorker {
   virtual void BindingDataFeedMemory() = 0;
   virtual void SetRootScope(Scope* root_scope);
   virtual void SetDataFeed(DataFeed* data_feed);
-  virtual void SetNeedDump(bool need_dump_field) {}
-  virtual void SetChannelWriter(ChannelObject<std::string>* queue) {}
+  virtual void SetNeedDumpField(bool need_dump_field) {
+    need_dump_field_ = need_dump_field;
+  }
+  virtual void SetNeedDumpParam(bool need_dump_param) {
+    need_dump_param_ = need_dump_param;
+  }
+  virtual void SetDumpFieldVector(const std::vector<std::string>& dump_fields) {
+    dump_fields_ = &dump_fields;
+  }
+  virtual void SetDumpParamVector(const std::vector<std::string>& dump_param) {
+    dump_param_ = &dump_param;
+  }
+  virtual void SetChannelWriter(ChannelObject<std::string>* queue) {
+    writer_.Reset(queue);
+  }
   virtual void SetPlace(const paddle::platform::Place& place) {
     place_ = place;
   }
@@ -140,6 +154,9 @@ class DeviceWorker {
   virtual Scope* GetThreadScope() { return thread_scope_; }

 protected:
+  virtual void DumpParam(const Scope& scope, const int batch_id);
+  virtual void DumpField(const Scope& scope, int dump_mode,
+                         int dump_interval = 10000);
   Scope* root_scope_ = nullptr;
   Scope* thread_scope_;
   paddle::platform::Place place_;
@@ -148,6 +165,16 @@ class DeviceWorker {
   FetchConfig fetch_config_;
   bool use_cvm_;
   bool no_cvm_;
+
+  // dump params or grads for debug
+  bool need_dump_param_;
+  bool need_dump_field_;
+  const std::vector<std::string>* dump_param_;
+  const std::vector<std::string>* dump_fields_;
+
+  int dump_mode_ = 0;
+  int dump_interval_ = 10000;
+  ChannelWriter<std::string> writer_;
 };

 class CPUWorkerBase : public DeviceWorker {
@@ -176,8 +203,6 @@ class HogwildWorker : public CPUWorkerBase {
   virtual void Initialize(const TrainerDesc& desc);
   virtual void TrainFiles();
   virtual void TrainFilesWithProfiler();
-  virtual void SetNeedDump(bool need_dump_field);
-  virtual void SetChannelWriter(ChannelObject<std::string>* queue);
   virtual void PrintFetchVars();
   virtual void CreateDeviceResource(const ProgramDesc& main_prog);
   virtual void BindingDataFeedMemory();
@@ -187,7 +212,6 @@ class HogwildWorker : public CPUWorkerBase {
 protected:
   void CreateThreadOperators(const ProgramDesc& program);
   void CreateThreadScope(const ProgramDesc& program);
-  virtual void DumpParam(const int batch_id);

   std::vector<std::string> op_names_;
   std::vector<OperatorBase*> ops_;
@@ -196,12 +220,6 @@ class HogwildWorker : public CPUWorkerBase {
   HogwildWorkerParameter param_;
   std::vector<std::string> skip_ops_;
   std::map<std::string, int> stat_var_name_map_;
-  // dump params or grads for debug
-  bool need_dump_param_;
-  bool need_dump_field_;
-  std::vector<std::string> dump_param_;
-  std::vector<std::string> dump_fields_;
-  ChannelWriter<std::string> writer_;
 };

 class DownpourWorker : public HogwildWorker {
@@ -211,8 +229,6 @@ class DownpourWorker : public HogwildWorker {
   virtual void Initialize(const TrainerDesc& desc);
   virtual void TrainFiles();
   virtual void TrainFilesWithProfiler();
-  virtual void SetNeedDump(bool need_dump_field);
-  virtual void SetChannelWriter(ChannelObject<std::string>* queue);

 protected:
   std::shared_ptr<FleetWrapper> fleet_ptr_;
@@ -224,7 +240,6 @@ class DownpourWorker : public HogwildWorker {
   void CopySparseTable();
   void CopyDenseTable();
   void CopyDenseVars();
-  virtual void DumpParam(const int batch_id);

   DownpourWorkerParameter param_;
   // copy table
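The selection logic introduced above is small: `InitRandomDumpConfig` maps the proto flags to `dump_mode_` (0 = dump every instance, 1 = hash of the instance's line id, 2 = pseudo-random), and `DumpField` keeps an instance only when `r % dump_interval == 0`. A rough Python sketch of that rule, for illustration only (Python's `zlib.crc32` stands in for the XXH64 hash used in the C++ code):

```python
# Illustrative sketch of the instance-selection rule in DeviceWorker::DumpField.
# Not the real implementation; crc32 replaces XXH64 here.
import random
import zlib


def pick_instances(ins_ids, dump_mode, dump_interval):
    """Return the indices whose tensors would be dumped."""
    rng = random.Random(0)
    hit = []
    for i, ins_id in enumerate(ins_ids):
        if dump_mode == 0:        # no sampling: dump every instance
            r = 0
        elif dump_mode == 1:      # deterministic: hash of the line id
            r = zlib.crc32(ins_id.encode())
        else:                     # dump_mode == 2: plain random sampling
            r = rng.randint(0, 2**31 - 1)
        if r % dump_interval == 0:
            hit.append(i)
    return hit


# roughly 1 out of dump_interval instances survive in modes 1 and 2
print(pick_instances(["line_%d" % i for i in range(100)], 2, 10))
```

Mode 1 is deterministic across runs for the same line ids, which is what `random_with_lineid` is intended to provide.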
diff --git a/paddle/fluid/framework/dist_multi_trainer.cc b/paddle/fluid/framework/dist_multi_trainer.cc
index 9fe28bddd1f04a80c1ede7466bf6c881c0f6c817..6ed68bb09644b7b9984ebf0df656256622a332f4 100644
--- a/paddle/fluid/framework/dist_multi_trainer.cc
+++ b/paddle/fluid/framework/dist_multi_trainer.cc
@@ -29,18 +29,7 @@ void DistMultiTrainer::Initialize(const TrainerDesc &trainer_desc,
   thread_num_ = trainer_desc.thread_num();
   SetDataset(dataset);

-  dump_fields_path_ = trainer_desc.dump_fields_path();
-  dump_converter_ = trainer_desc.dump_converter();
-  need_dump_field_ = false;
-  if (trainer_desc.dump_fields_size() != 0 && dump_fields_path_ != "") {
-    need_dump_field_ = true;
-  }
-  if (need_dump_field_) {
-    auto &file_list = dataset->GetFileList();
-    if (file_list.size() == 0) {
-      need_dump_field_ = false;
-    }
-  }
+  ParseDumpConfig(trainer_desc);
   mpi_rank_ = trainer_desc.mpi_rank();
   mpi_size_ = trainer_desc.mpi_size();
   dump_file_num_ = trainer_desc.dump_file_num();
@@ -60,8 +49,12 @@ void DistMultiTrainer::Initialize(const TrainerDesc &trainer_desc,
         trainer_desc.device_worker_name());
     workers_[i]->SetDeviceIndex(i);
     workers_[i]->SetDataFeed(readers[i]);
+    workers_[i]->SetNeedDumpField(need_dump_field_);
+    workers_[i]->SetNeedDumpParam(need_dump_param_);
+    workers_[i]->SetDumpFieldVector(dump_fields_);
+    workers_[i]->SetDumpParamVector(dump_param_);
+    workers_[i]->InitRandomDumpConfig(trainer_desc);
     workers_[i]->Initialize(trainer_desc);
-    workers_[i]->SetNeedDump(need_dump_field_);
   }

   VLOG(3) << "going to initialize pull dense worker";
@@ -71,33 +64,6 @@ void DistMultiTrainer::Initialize(const TrainerDesc &trainer_desc,
   SetDebug(trainer_desc.debug());
 }

-void DistMultiTrainer::DumpWork(int tid) {
-#ifdef _LINUX
-  int err_no = 0;
-  std::string path = string::format_string(
-      "%s/part-%03d-%05d", dump_fields_path_.c_str(), mpi_rank_, tid);
-
-  std::shared_ptr<FILE> fp = fs_open_write(path, &err_no, dump_converter_);
-  while (1) {
-    std::string out_str;
-    if (!queue_->Get(out_str)) {
-      break;
-    }
-    size_t write_count =
-        fwrite_unlocked(out_str.data(), 1, out_str.length(), fp.get());
-    if (write_count != out_str.length()) {
-      VLOG(3) << "dump text failed";
-      continue;
-    }
-    write_count = fwrite_unlocked("\n", 1, 1, fp.get());
-    if (write_count != 1) {
-      VLOG(3) << "dump text failed";
-      continue;
-    }
-  }
-#endif
-}
-
 void DistMultiTrainer::InitDumpEnv() {
   queue_ = paddle::framework::MakeChannel<std::string>();
   for (int i = 0; i < thread_num_; ++i) {
@@ -112,16 +78,8 @@ void DistMultiTrainer::InitDumpEnv() {
   }
   for (int i = 0; i < dump_thread_num_; i++) {
     dump_thread_.push_back(
-        std::thread(std::bind(&DistMultiTrainer::DumpWork, this, i)));
-  }
-}
-
-void DistMultiTrainer::FinalizeDumpEnv() {
-  queue_->Close();
-  for (auto &th : dump_thread_) {
-    th.join();
+        std::thread(std::bind(&TrainerBase::DumpWork, this, i)));
   }
-  queue_.reset();
 }

 void DistMultiTrainer::InitTrainerEnv(const ProgramDesc &main_program,
diff --git a/paddle/fluid/framework/downpour_worker.cc b/paddle/fluid/framework/downpour_worker.cc
index b1a1b73a66e72d95c68089832b0f0381e9382f95..243e7b97c2a75a46c37ad6e72c8de34838680b03 100644
--- a/paddle/fluid/framework/downpour_worker.cc
+++ b/paddle/fluid/framework/downpour_worker.cc
@@ -80,19 +80,7 @@ void DownpourWorker::Initialize(const TrainerDesc& desc) {
   no_cvm_ = desc.no_cvm();
   scale_datanorm_ = desc.scale_datanorm();
   dump_slot_ = desc.dump_slot();
-  dump_fields_.resize(desc.dump_fields_size());
-  for (int i = 0; i < desc.dump_fields_size(); ++i) {
-    dump_fields_[i] = desc.dump_fields(i);
-  }
   adjust_ins_weight_config_ = desc.adjust_ins_weight_config();
-  need_dump_param_ = false;
-  dump_param_.resize(desc.dump_param_size());
-  for (int i = 0; i < desc.dump_param_size(); ++i) {
-    dump_param_[i] = desc.dump_param(i);
-  }
-  if (desc.dump_param_size() != 0) {
-    need_dump_param_ = true;
-  }
   for (int i = 0; i < desc.check_nan_var_names_size(); ++i) {
     check_nan_var_names_.push_back(desc.check_nan_var_names(i));
   }
@@ -121,30 +109,6 @@ void DownpourWorker::Initialize(const TrainerDesc& desc) {
   }
 }

-void DownpourWorker::SetChannelWriter(ChannelObject<std::string>* queue) {
-  writer_.Reset(queue);
-}
-
-void DownpourWorker::SetNeedDump(bool need_dump_field) {
-  need_dump_field_ = need_dump_field;
-}
-
-void DownpourWorker::DumpParam(const int batch_id) {
-  std::ostringstream os;
-  for (auto& param : dump_param_) {
-    os.str("");
-    Variable* var = thread_scope_->FindVar(param);
-    if (var == nullptr) {
-      continue;
-    }
-    LoDTensor* tensor = var->GetMutable<LoDTensor>();
-    int64_t len = tensor->numel();
-    os << "(" << batch_id << "," << param << ")"
-       << PrintLodTensor(tensor, 0, len);
-    writer_ << os.str();
-  }
-}
-
 void DownpourWorker::CollectLabelInfo(size_t table_idx) {
   if (no_cvm_) {
     return;
@@ -915,52 +879,17 @@ void DownpourWorker::TrainFiles() {
       }
     }
     if (need_dump_field_) {
-      size_t batch_size = device_reader_->GetCurBatchSize();
-      std::vector<std::string> ars(batch_size);
-      for (auto& ar : ars) {
-        ar.clear();
-      }
-      auto& ins_id_vec = device_reader_->GetInsIdVec();
-      auto& ins_content_vec = device_reader_->GetInsContentVec();
-      for (size_t i = 0; i < ins_id_vec.size(); i++) {
-        ars[i] += ins_id_vec[i];
-        ars[i] = ars[i] + "\t" + ins_content_vec[i];
-      }
-      for (auto& field : dump_fields_) {
-        Variable* var = thread_scope_->FindVar(field);
-        if (var == nullptr) {
-          continue;
-        }
-        LoDTensor* tensor = var->GetMutable<LoDTensor>();
-        if (!CheckValidOutput(tensor, batch_size)) {
-          continue;
-        }
-        for (size_t i = 0; i < batch_size; ++i) {
-          auto output_dim = tensor->dims()[1];
-          std::string output_dimstr =
-              boost::lexical_cast<std::string>(output_dim);
-          ars[i] = ars[i] + "\t" + field + ":" + output_dimstr;
-          auto bound = GetTensorBound(tensor, i);
-          ars[i] += PrintLodTensor(tensor, bound.first, bound.second);
-        }
-      }
-      // #pragma omp parallel for
-      for (size_t i = 0; i < ars.size(); i++) {
-        if (ars[i].length() == 0) {
-          continue;
-        }
-        writer_ << ars[i];
-      }
-      if (need_dump_param_ && thread_id_ == 0) {
-        DumpParam(batch_cnt);
-      }
+      DumpField(*thread_scope_, dump_mode_, dump_interval_);
+    }
+    if (need_dump_param_ && thread_id_ == 0) {
+      DumpParam(*thread_scope_, batch_cnt);
     }
     PrintFetchVars();
     thread_scope_->DropKids();
     ++batch_cnt;
   }
-  if (need_dump_field_) {
+  if (need_dump_field_ || need_dump_param_) {
     writer_.Flush();
   }
   if (copy_table_config_.need_copy()) {
diff --git a/paddle/fluid/framework/downpour_worker_opt.cc b/paddle/fluid/framework/downpour_worker_opt.cc
index 79f80a373a26af241d9d1a3d62010d7b1520d85d..b40a00ef9cb8cf3f51fdca4d71a905ac912db51f 100644
--- a/paddle/fluid/framework/downpour_worker_opt.cc
+++ b/paddle/fluid/framework/downpour_worker_opt.cc
@@ -156,19 +156,7 @@ void DownpourWorkerOpt::Initialize(const TrainerDesc& desc) {
   no_cvm_ = desc.no_cvm();
   scale_datanorm_ = desc.scale_datanorm();
   dump_slot_ = desc.dump_slot();
-  dump_fields_.resize(desc.dump_fields_size());
-  for (int i = 0; i < desc.dump_fields_size(); ++i) {
-    dump_fields_[i] = desc.dump_fields(i);
-  }
   adjust_ins_weight_config_ = desc.adjust_ins_weight_config();
-  need_dump_param_ = false;
-  dump_param_.resize(desc.dump_param_size());
-  for (int i = 0; i < desc.dump_param_size(); ++i) {
-    dump_param_[i] = desc.dump_param(i);
-  }
-  if (desc.dump_param_size() != 0) {
-    need_dump_param_ = true;
-  }
   for (int i = 0; i < desc.loss_names_size(); ++i) {
     loss_names_.push_back(desc.loss_names(i));
   }
@@ -527,52 +515,17 @@ void DownpourWorkerOpt::TrainFiles() {
       }
     }
     if (need_dump_field_) {
-      size_t batch_size = device_reader_->GetCurBatchSize();
-      std::vector<std::string> ars(batch_size);
-      for (auto& ar : ars) {
-        ar.clear();
-      }
-      auto& ins_id_vec = device_reader_->GetInsIdVec();
-      auto& ins_content_vec = device_reader_->GetInsContentVec();
-      for (size_t i = 0; i < ins_id_vec.size(); i++) {
-        ars[i] += ins_id_vec[i];
-        ars[i] = ars[i] + "\t" + ins_content_vec[i];
-      }
-      for (auto& field : dump_fields_) {
-        Variable* var = thread_scope_->FindVar(field);
-        if (var == nullptr) {
-          continue;
-        }
-        LoDTensor* tensor = var->GetMutable<LoDTensor>();
-        if (!CheckValidOutput(tensor, batch_size)) {
-          continue;
-        }
-        for (size_t i = 0; i < batch_size; ++i) {
-          auto output_dim = tensor->dims()[1];
-          std::string output_dimstr =
-              boost::lexical_cast<std::string>(output_dim);
-          ars[i] = ars[i] + "\t" + field + ":" + output_dimstr;
-          auto bound = GetTensorBound(tensor, i);
-          ars[i] += PrintLodTensor(tensor, bound.first, bound.second);
-        }
-      }
-      // #pragma omp parallel for
-      for (size_t i = 0; i < ars.size(); i++) {
-        if (ars[i].length() == 0) {
-          continue;
-        }
-        writer_ << ars[i];
-      }
-      if (need_dump_param_ && thread_id_ == 0) {
-        DumpParam(batch_cnt);
-      }
+      DumpField(*thread_scope_, dump_mode_, dump_interval_);
+    }
+    if (need_dump_param_ && thread_id_ == 0) {
+      DumpParam(*thread_scope_, batch_cnt);
     }
     PrintFetchVars();
     thread_scope_->DropKids();
     ++batch_cnt;
   }
-  if (need_dump_field_) {
+  if (need_dump_field_ || need_dump_param_) {
     writer_.Flush();
   }
   if (copy_table_config_.need_copy()) {
diff --git a/paddle/fluid/framework/hogwild_worker.cc b/paddle/fluid/framework/hogwild_worker.cc
index db6231e99193d96714e4205fa174444a4ffede83..4d930337e845db9cea7aa8af0c5d3acbd15b11f0 100644
--- a/paddle/fluid/framework/hogwild_worker.cc
+++ b/paddle/fluid/framework/hogwild_worker.cc
@@ -32,23 +32,9 @@ void HogwildWorker::Initialize(const TrainerDesc &desc) {
   use_cvm_ = desc.use_cvm();
   thread_barrier_ = desc.thread_barrier();

-  dump_fields_.resize(desc.dump_fields_size());
-  for (int i = 0; i < desc.dump_fields_size(); ++i) {
-    dump_fields_[i] = desc.dump_fields(i);
-  }
-
   for (int i = 0; i < param_.stat_var_names_size(); ++i) {
     stat_var_name_map_[param_.stat_var_names(i)] = 1;
   }
-
-  need_dump_param_ = false;
-  dump_param_.resize(desc.dump_param_size());
-  for (int i = 0; i < desc.dump_param_size(); ++i) {
-    dump_param_[i] = desc.dump_param(i);
-  }
-  if (desc.dump_param_size() != 0) {
-    need_dump_param_ = true;
-  }
 }

 void HogwildWorker::CreateThreadOperators(const ProgramDesc &program) {
@@ -163,45 +149,10 @@ void HogwildWorker::TrainFilesWithProfiler() {
     }
     if (need_dump_field_) {
-      size_t batch_size = device_reader_->GetCurBatchSize();
-      std::vector<std::string> ars(batch_size);
-      for (auto &ar : ars) {
-        ar.clear();
-      }
-      auto &ins_id_vec = device_reader_->GetInsIdVec();
-      auto &ins_content_vec = device_reader_->GetInsContentVec();
-      for (size_t i = 0; i < ins_id_vec.size(); i++) {
-        ars[i] += ins_id_vec[i];
-        ars[i] = ars[i] + "\t" + ins_content_vec[i];
-      }
-      for (auto &field : dump_fields_) {
-        Variable *var = thread_scope_->FindVar(field);
-        if (var == nullptr) {
-          continue;
-        }
-        LoDTensor *tensor = var->GetMutable<LoDTensor>();
-        if (!CheckValidOutput(tensor, batch_size)) {
-          continue;
-        }
-        for (size_t i = 0; i < batch_size; ++i) {
-          auto output_dim = tensor->dims()[1];
-          std::string output_dimstr =
-              boost::lexical_cast<std::string>(output_dim);
-          ars[i] = ars[i] + "\t" + field + ":" + output_dimstr;
-          auto bound = GetTensorBound(tensor, i);
-          ars[i] += PrintLodTensor(tensor, bound.first, bound.second);
-        }
-      }
-      // #pragma omp parallel for
-      for (size_t i = 0; i < ars.size(); i++) {
-        if (ars[i].length() == 0) {
-          continue;
-        }
-        writer_ << ars[i];
-      }
-      if (need_dump_param_ && thread_id_ == 0) {
-        DumpParam(batch_cnt);
-      }
+      DumpField(*thread_scope_, dump_mode_, dump_interval_);
+    }
+    if (need_dump_param_ && thread_id_ == 0) {
+      DumpParam(*thread_scope_, batch_cnt);
     }

     total_inst += cur_batch;
@@ -222,7 +173,7 @@ void HogwildWorker::TrainFilesWithProfiler() {
       timeline.Start();
     }
   }
-  if (need_dump_field_) {
+  if (need_dump_field_ || need_dump_param_) {
     writer_.Flush();
   }

@@ -234,10 +185,6 @@ void HogwildWorker::TrainFilesWithProfiler() {
 #endif
 }

-void HogwildWorker::SetChannelWriter(ChannelObject<std::string> *queue) {
-  writer_.Reset(queue);
-}
-
 void HogwildWorker::TrainFiles() {
   platform::SetNumThreads(1);

@@ -284,25 +231,5 @@ void HogwildWorker::PrintFetchVars() {
   }
 }

-void HogwildWorker::SetNeedDump(bool need_dump_field) {
-  need_dump_field_ = need_dump_field;
-}
-
-void HogwildWorker::DumpParam(const int batch_id) {
-  std::ostringstream os;
-  for (auto &param : dump_param_) {
-    os.str("");
-    Variable *var = thread_scope_->FindVar(param);
-    if (var == nullptr) {
-      continue;
-    }
-    LoDTensor *tensor = var->GetMutable<LoDTensor>();
-    int64_t len = tensor->numel();
-    os << "(" << batch_id << "," << param << ")"
-       << PrintLodTensor(tensor, 0, len);
-    writer_ << os.str();
-  }
-}
-
 }  // end namespace framework
 }  // end namespace paddle
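Each surviving instance ends up as one tab-separated line: the instance id, the instance content, then one `field:dim` chunk per dumped tensor with the values appended by `PrintLodTensor`. A rough reader for such lines is sketched below; the exact value text produced by `PrintLodTensor` is not part of this patch, so the example line and its parsing are illustrative only.

```python
# Assumed layout per dumped line: ins_id \t ins_content \t field:dim<values> ...
def parse_dump_line(line):
    parts = line.rstrip("\n").split("\t")
    record = {"ins_id": parts[0], "ins_content": parts[1], "fields": {}}
    for chunk in parts[2:]:
        name, _, rest = chunk.partition(":")
        record["fields"][name] = rest  # "dim" followed by the printed values
    return record


# hypothetical example line, not actual output
example = "12345_67\tsome_raw_line\tfc.tmp_0:1 0.25\tfc.tmp_0@GRAD:1 -0.01"
print(parse_dump_line(example))
```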
diff --git a/paddle/fluid/framework/multi_trainer.cc b/paddle/fluid/framework/multi_trainer.cc
index 0faf96195403faeead00c56353cd5ad965269e13..4ffd9a2f9cbe036bb80512339cf832d1ea1c53bb 100644
--- a/paddle/fluid/framework/multi_trainer.cc
+++ b/paddle/fluid/framework/multi_trainer.cc
@@ -14,7 +14,6 @@ limitations under the License. */

 #include <string>
 #include <vector>
-#include "io/fs.h"
 #include "paddle/fluid/framework/data_feed_factory.h"
 #include "paddle/fluid/framework/device_worker_factory.h"
 #include "paddle/fluid/framework/trainer.h"
@@ -28,18 +27,7 @@ void MultiTrainer::Initialize(const TrainerDesc& trainer_desc,
   thread_num_ = trainer_desc.thread_num();
   SetDataset(dataset);

-  dump_fields_path_ = trainer_desc.dump_fields_path();
-  dump_converter_ = trainer_desc.dump_converter();
-  need_dump_field_ = false;
-  if (trainer_desc.dump_fields_size() != 0 && dump_fields_path_ != "") {
-    need_dump_field_ = true;
-  }
-  if (need_dump_field_) {
-    auto& file_list = dataset->GetFileList();
-    if (file_list.size() == 0) {
-      need_dump_field_ = false;
-    }
-  }
+  ParseDumpConfig(trainer_desc);
   mpi_rank_ = trainer_desc.mpi_rank();
   mpi_size_ = trainer_desc.mpi_size();
   dump_file_num_ = trainer_desc.dump_file_num();
@@ -68,41 +56,23 @@ void MultiTrainer::Initialize(const TrainerDesc& trainer_desc,
   for (int i = 0; i < thread_num_; ++i) {
     workers_[i] = DeviceWorkerFactory::CreateDeviceWorker(
         trainer_desc.device_worker_name());
+    workers_[i]->SetNeedDumpField(need_dump_field_);
+    workers_[i]->SetNeedDumpParam(need_dump_param_);
+    workers_[i]->SetDumpFieldVector(dump_fields_);
+    workers_[i]->SetDumpParamVector(dump_param_);
+    workers_[i]->InitRandomDumpConfig(trainer_desc);
     workers_[i]->Initialize(trainer_desc);
     workers_[i]->SetDeviceIndex(i);
     workers_[i]->SetDataFeed(readers[i]);
-    workers_[i]->SetNeedDump(need_dump_field_);
   }

   // set debug here
   SetDebug(trainer_desc.debug());
 }

-void MultiTrainer::DumpWork(int tid) {
-#ifdef _LINUX
-  int err_no = 0;
-  std::string path = string::format_string(
-      "%s/part-%03d-%05d", dump_fields_path_.c_str(), mpi_rank_, tid);
-
-  std::shared_ptr<FILE> fp = fs_open_write(path, &err_no, dump_converter_);
-  while (1) {
-    std::string out_str;
-    if (!queue_->Get(out_str)) {
-      break;
-    }
-    size_t write_count =
-        fwrite_unlocked(out_str.data(), 1, out_str.length(), fp.get());
-    if (write_count != out_str.length()) {
-      VLOG(3) << "dump text failed";
-      continue;
-    }
-    write_count = fwrite_unlocked("\n", 1, 1, fp.get());
-    if (write_count != 1) {
-      VLOG(3) << "dump text failed";
-      continue;
-    }
-  }
-#endif
+std::string MultiTrainer::GetDumpPath(int tid) {
+  return string::format_string("%s/part-%03d-%05d", dump_fields_path_.c_str(),
+                               mpi_rank_, tid);
 }

 void MultiTrainer::InitDumpEnv() {
@@ -119,16 +89,8 @@ void MultiTrainer::InitDumpEnv() {
   }
   for (int i = 0; i < dump_thread_num_; i++) {
     dump_thread_.push_back(
-        std::thread(std::bind(&MultiTrainer::DumpWork, this, i)));
-  }
-}
-
-void MultiTrainer::FinalizeDumpEnv() {
-  queue_->Close();
-  for (auto& th : dump_thread_) {
-    th.join();
+        std::thread(std::bind(&TrainerBase::DumpWork, this, i)));
   }
-  queue_.reset();
 }

 // call only after all resources are set in current trainer
diff --git a/paddle/fluid/framework/pipeline_trainer.cc b/paddle/fluid/framework/pipeline_trainer.cc
index 478d8c6143655041bb35452cb3f22e6668a035cf..47e962a4825369020535905dab2859fd9be0398b 100644
--- a/paddle/fluid/framework/pipeline_trainer.cc
+++ b/paddle/fluid/framework/pipeline_trainer.cc
@@ -27,6 +27,7 @@ void PipelineTrainer::Initialize(const TrainerDesc& trainer_desc,
   VLOG(3) << "pipeline num: " << pipeline_num_;

   SetDataset(dataset);
+  ParseDumpConfig(trainer_desc);
   // get filelist from trainer_desc here
   const std::vector<paddle::framework::DataFeed*> readers =
       dataset->GetReaders();
@@ -103,8 +104,15 @@ void PipelineTrainer::Initialize(const TrainerDesc& trainer_desc,
         this_worker->SetDataFeed(readers[reader_index++]);
         this_worker->SetReaderPlace(place);
       }
+      if (i == section_num_ - 1) {
+        this_worker->SetNeedDumpField(need_dump_field_);
+        this_worker->SetNeedDumpParam(need_dump_param_);
+        this_worker->SetDumpFieldVector(dump_fields_);
+        this_worker->SetDumpParamVector(dump_param_);
+      }
       this_worker->SetPlace(place);
       this_worker->Initialize(trainer_desc);
+      this_worker->InitRandomDumpConfig(trainer_desc);
     }
   }
 }
@@ -119,6 +127,33 @@ void PipelineTrainer::Initialize(const TrainerDesc& trainer_desc,
   SetDebug(trainer_desc.debug());
 }

+void PipelineTrainer::InitOtherEnv(const ProgramDesc& main_program) {
+  if (need_dump_field_) {
+    InitDumpEnv();
+  }
+  VLOG(3) << "init other env done.";
+}
+
+std::string PipelineTrainer::GetDumpPath(int tid) {
+  return string::format_string("%s/part-%05d", dump_fields_path_.c_str(), tid);
+}
+
+void PipelineTrainer::InitDumpEnv() {
+  queue_ = paddle::framework::MakeChannel<std::string>();
+  // Only set dump channel on the last section
+  for (int j = 0; j < pipeline_num_; ++j) {
+    for (size_t k = 0; k < workers_[section_num_ - 1][j].size(); ++k) {
+      workers_[section_num_ - 1][j][k]->SetChannelWriter(queue_.get());
+    }
+  }
+  // TODO(hutuxian): should make it as a config
+  dump_thread_num_ = 1;
+  for (int i = 0; i < dump_thread_num_; i++) {
+    dump_thread_.push_back(
+        std::thread(std::bind(&TrainerBase::DumpWork, this, i)));
+  }
+}
+
 void PipelineTrainer::InitFirstScopeQueue(ScopeQueue* scope_queue,
                                           int pipeline_id,
                                           const ProgramDesc& main_program,
@@ -271,6 +306,9 @@ void PipelineTrainer::Finalize() {
   for (auto& th : section_threads_) {
     th.join();
   }
+  if (need_dump_field_) {
+    FinalizeDumpEnv();
+  }
   for (const auto& var : persistable_vars_) {
     auto* root_tensor = root_scope_->Var(var)->GetMutable<LoDTensor>();
     // TODO(hutuxian): Add a final all-reduce?
diff --git a/paddle/fluid/framework/section_worker.cc b/paddle/fluid/framework/section_worker.cc
index 01d07f9b2eb61263237dc021c605cd24b6a444b2..1d644cdd7fb76ff731c4533b3129ad3fa2c724c2 100644
--- a/paddle/fluid/framework/section_worker.cc
+++ b/paddle/fluid/framework/section_worker.cc
@@ -95,11 +95,11 @@ void SyncFunctor::Synchronize() {
 }

 std::atomic<int> SectionWorker::cpu_id_(0);
-void SectionWorker::Initialize(const TrainerDesc& trainer_desc) {
+void SectionWorker::Initialize(const TrainerDesc& desc) {
   dev_ctx_ = platform::DeviceContextPool::Instance().Get(place_);
   std::shared_ptr<framework::ProgramDesc> program;
   program.reset(new ProgramDesc(
-      trainer_desc.section_param().section_config(section_id_).program_desc()));
+      desc.section_param().section_config(section_id_).program_desc()));
   for (auto& op_desc : program->Block(0).AllOps()) {
     ops_.push_back(OpRegistry::CreateOp(*op_desc));
   }
@@ -373,6 +373,12 @@ void SectionWorker::TrainFilesWithProfiler() {
       metric_msg->add_data(exe_scope);
     }
 #endif
+    if (need_dump_field_) {
+      DumpField(*scope, dump_mode_, dump_interval_);
+    }
+    if (need_dump_param_ && pipeline_id_ == 0) {
+      DumpParam(*scope, step_cnt);
+    }

     if (section_id_ != section_num_ - 1 && platform::is_gpu_place(place_)) {
       // FIXME: Temporarily we assume two adjacent sections are in different
@@ -410,6 +416,9 @@ void SectionWorker::TrainFilesWithProfiler() {
     accum_num += batch_size;
     main_timer.Pause();
   }
+  if (need_dump_field_ || need_dump_param_) {
+    writer_.Flush();
+  }
   outer_timer.Pause();

   worker_count_mutex_->lock();
diff --git a/paddle/fluid/framework/trainer.cc b/paddle/fluid/framework/trainer.cc
index 644bd33a1420aa0ff54e34005eedd10c28342665..99a1589200f72ef6fa33c03c0a72f27482e149e0 100644
--- a/paddle/fluid/framework/trainer.cc
+++ b/paddle/fluid/framework/trainer.cc
@@ -13,11 +13,77 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/fluid/framework/trainer.h"
+#include "io/fs.h"

 namespace paddle {
 namespace framework {

 void TrainerBase::SetScope(Scope* root_scope) { root_scope_ = root_scope; }

+void TrainerBase::ParseDumpConfig(const TrainerDesc& desc) {
+  dump_fields_path_ = desc.dump_fields_path();
+  if (dump_fields_path_ == "") {
+    VLOG(2) << "dump_fields_path_ is empty";
+    return;
+  }
+  auto& file_list = dataset_ptr_->GetFileList();
+  if (file_list.size() == 0) {
+    VLOG(2) << "file_list is empty";
+    return;
+  }
+
+  dump_converter_ = desc.dump_converter();
+  if (desc.dump_fields_size() != 0) {
+    need_dump_field_ = true;
+    dump_fields_.resize(desc.dump_fields_size());
+    for (int i = 0; i < desc.dump_fields_size(); ++i) {
+      dump_fields_[i] = desc.dump_fields(i);
+    }
+  }
+
+  if (desc.dump_param_size() != 0) {
+    need_dump_param_ = true;
+    dump_param_.resize(desc.dump_param_size());
+    for (int i = 0; i < desc.dump_param_size(); ++i) {
+      dump_param_[i] = desc.dump_param(i);
+    }
+  }
+}
+
+void TrainerBase::DumpWork(int tid) {
+#ifdef _LINUX
+  int err_no = 0;
+  // GetDumpPath is implemented in each Trainer
+  std::string path = GetDumpPath(tid);
+
+  std::shared_ptr<FILE> fp = fs_open_write(path, &err_no, dump_converter_);
+  while (1) {
+    std::string out_str;
+    if (!queue_->Get(out_str)) {
+      break;
+    }
+    size_t write_count =
+        fwrite_unlocked(out_str.data(), 1, out_str.length(), fp.get());
+    if (write_count != out_str.length()) {
+      VLOG(3) << "dump text failed";
+      continue;
+    }
+    write_count = fwrite_unlocked("\n", 1, 1, fp.get());
+    if (write_count != 1) {
+      VLOG(3) << "dump text failed";
+      continue;
+    }
+  }
+#endif
+}
+
+void TrainerBase::FinalizeDumpEnv() {
+  queue_->Close();
+  for (auto& th : dump_thread_) {
+    th.join();
+  }
+  queue_.reset();
+}
+
 }  // end namespace framework
 }  // end namespace paddle
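`TrainerBase::DumpWork` takes its output file name from the trainer-specific `GetDumpPath` override. For reference, the two formats used in this patch, expressed in Python:

```python
# Python equivalents of the GetDumpPath overrides in this patch.
def multi_trainer_dump_path(dump_fields_path, mpi_rank, tid):
    # MultiTrainer / DistMultiTrainer: one file per (mpi rank, dump thread)
    return "%s/part-%03d-%05d" % (dump_fields_path, mpi_rank, tid)


def pipeline_trainer_dump_path(dump_fields_path, tid):
    # PipelineTrainer: one file per dump thread
    return "%s/part-%05d" % (dump_fields_path, tid)


assert multi_trainer_dump_path("./dump_log", 3, 1) == "./dump_log/part-003-00001"
assert pipeline_trainer_dump_path("./dump_log", 1) == "./dump_log/part-00001"
```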
diff --git a/paddle/fluid/framework/trainer.h b/paddle/fluid/framework/trainer.h
index e22d659a367df8d1c6daf24b989cf5420b5609d3..c18ea33d041b9518fb60d2453830de8e4b4ff033 100644
--- a/paddle/fluid/framework/trainer.h
+++ b/paddle/fluid/framework/trainer.h
@@ -51,11 +51,28 @@ class TrainerBase {
   virtual void Run() = 0;
   virtual void Finalize() = 0;
   virtual Scope* GetWorkerScope(int thread_id) = 0;
+  virtual void InitDumpEnv() = 0;
+  virtual void DumpWork(int tid);

 protected:
+  virtual std::string GetDumpPath(int tid) = 0;
+  virtual void ParseDumpConfig(const TrainerDesc& trainer_desc);
+  virtual void FinalizeDumpEnv();
+
   Scope* root_scope_;
   bool debug_;
   Dataset* dataset_ptr_;
+
+  // For dump param or field
+  bool need_dump_field_ = false;
+  bool need_dump_param_ = false;
+  std::string dump_fields_path_;
+  std::string dump_converter_;
+  std::vector<std::string> dump_param_;
+  std::vector<std::string> dump_fields_;
+  int dump_thread_num_;
+  std::vector<std::thread> dump_thread_;
+  std::shared_ptr<paddle::framework::ChannelObject<std::string>> queue_;
 };

 // general trainer for async execution
@@ -71,10 +88,9 @@ class MultiTrainer : public TrainerBase {
   virtual void InitOtherEnv(const ProgramDesc& main_program);
   virtual void Run();
   virtual void Finalize();
-  virtual void FinalizeDumpEnv();
   virtual void InitDumpEnv();
   virtual Scope* GetWorkerScope(int thread_id);
-  virtual void DumpWork(int tid);
+  virtual std::string GetDumpPath(int tid);

 protected:
   int thread_num_;
@@ -83,16 +99,9 @@ class MultiTrainer : public TrainerBase {
   std::vector<std::shared_ptr<DeviceWorker>> workers_;
   std::vector<std::string> need_merge_var_names_;

-  bool need_dump_field_;
-  std::string dump_fields_path_;
-  std::string dump_converter_;
   int mpi_rank_;
   int mpi_size_;
   int dump_file_num_;
-
-  std::vector<std::thread> dump_thread_;
-  int dump_thread_num_;
-  std::shared_ptr<paddle::framework::ChannelObject<std::string>> queue_;
 };

 class DistMultiTrainer : public MultiTrainer {
@@ -107,10 +116,8 @@ class DistMultiTrainer : public MultiTrainer {
   virtual void Finalize();
   template <typename T>
   void MergeToRootScope(LoDTensor* root_tensor, LoDTensor* thread_tensor);
-  virtual void FinalizeDumpEnv();
   virtual void InitDumpEnv();
   virtual Scope* GetWorkerScope(int thread_id);
-  virtual void DumpWork(int tid);

 protected:
   std::shared_ptr<paddle::framework::PullDenseWorker> pull_dense_worker_;
@@ -124,10 +131,12 @@ class PipelineTrainer : public TrainerBase {
   void Initialize(const TrainerDesc& trainer_desc, Dataset* data_set) override;
   void InitTrainerEnv(const ProgramDesc& main_program,
                       const platform::Place& place) override;
-  void InitOtherEnv(const ProgramDesc& main_program) override {}
+  void InitOtherEnv(const ProgramDesc& main_program) override;
   void Run() override;
   void Finalize() override;
   virtual Scope* GetWorkerScope(int thread_id);
+  void InitDumpEnv() override;
+  virtual std::string GetDumpPath(int tid);

 protected:
   int section_num_;
diff --git a/paddle/fluid/framework/trainer_desc.proto b/paddle/fluid/framework/trainer_desc.proto
index f442063313f03321931112ed293ccdf8ebabeb89..9cbb063a3fab6810709c1504deed2ccf40743123 100644
--- a/paddle/fluid/framework/trainer_desc.proto
+++ b/paddle/fluid/framework/trainer_desc.proto
@@ -49,6 +49,9 @@ message TrainerDesc {
   optional bool no_cvm = 21 [ default = false ];
   optional bool thread_barrier = 22;
   repeated string loss_names = 23;
+  optional bool enable_random_dump = 24 [ default = false ];
+  optional bool random_with_lineid = 25 [ default = false ];
+  optional int32 dump_interval = 26 [ default = 10000 ];

   // device worker parameters
   optional HogwildWorkerParameter hogwild_param = 101;
diff --git a/python/paddle/fluid/tests/unittests/test_boxps.py b/python/paddle/fluid/tests/unittests/test_boxps.py
index 17e378115a3ace9310785b24d64a0f320c8b1abf..4403f99b610d5a54a9741ed169eb8fabd77b0b15 100644
--- a/python/paddle/fluid/tests/unittests/test_boxps.py
+++ b/python/paddle/fluid/tests/unittests/test_boxps.py
@@ -16,6 +16,7 @@ import paddle.fluid as fluid
 import paddle.fluid.layers as layers
 import numpy as np
 import os
+import shutil
 import paddle.fluid.core as core
 import unittest
 from paddle.fluid.layers.nn import _pull_box_sparse
@@ -90,87 +91,105 @@ class TestBoxPSPreload(unittest.TestCase):
     """ TestCases for BoxPS Preload """

     def test_boxps_cpu(self):
-        self.run_boxps_preload(True)
+        self.run_boxps_preload(True, True)
+        self.run_boxps_preload(True, False)

     def test_boxps_gpu(self):
-        self.run_boxps_preload(False)
-
-    def run_boxps_preload(self, is_cpu=True):
-        x = fluid.layers.data(name='x', shape=[1], dtype='int64', lod_level=0)
-        y = fluid.layers.data(name='y', shape=[1], dtype='int64', lod_level=0)
-        emb_x, emb_y = _pull_box_sparse([x, y], size=2)
-        emb_xp = _pull_box_sparse(x, size=2)
-        concat = layers.concat([emb_x, emb_y], axis=1)
-        fc = layers.fc(input=concat,
-                       name="fc",
-                       size=1,
-                       num_flatten_dims=1,
-                       bias_attr=False)
-        loss = layers.reduce_mean(fc)
-        layers.Print(loss)
-        place = fluid.CPUPlace() if is_cpu or not core.is_compiled_with_cuda(
-        ) else fluid.CUDAPlace(0)
-        exe = fluid.Executor(place)
-        batch_size = 2
-
-        def binary_print(slot, fout):
-            fout.write(str(len(slot)) + " ")
-            for e in slot:
-                fout.write(str(e) + " ")
-
-        batch1 = np.ones(
-            (batch_size, 2, 1)).astype("int64").reshape(batch_size, 2, 1)
-        filelist = []
-        place_str = "cpu" if is_cpu else "gpu"
-        for i in range(2):
-            filelist.append("test_hdfs_" + place_str + "_" + str(i))
-        for f in filelist:
-            with open(f, "w") as fout:
-                for ins in batch1:
-                    for slot in ins:
-                        binary_print(slot, fout)
-                    fout.write("\n")
-
-        def create_dataset():
-            dataset = fluid.DatasetFactory().create_dataset("BoxPSDataset")
-            dataset.set_date("20190930")
-            dataset.set_use_var([x, y])
-            dataset.set_batch_size(2)
-            dataset.set_thread(1)
-            dataset.set_filelist(filelist)
-            return dataset
-
-        datasets = []
-        datasets.append(create_dataset())
-        datasets.append(create_dataset())
-        optimizer = fluid.optimizer.SGD(learning_rate=0.5)
-        optimizer = fluid.optimizer.PipelineOptimizer(
-            optimizer,
-            cut_list=[],
-            place_list=[place],
-            concurrency_list=[1],
-            queue_size=1,
-            sync_steps=-1)
-        optimizer.minimize(loss)
-        exe.run(fluid.default_startup_program())
-        datasets[0].load_into_memory()
-        datasets[0].begin_pass()
-        datasets[1].preload_into_memory()
-        exe.train_from_dataset(
-            program=fluid.default_main_program(),
-            dataset=datasets[0],
-            print_period=1)
-        datasets[0].end_pass(True)
-        datasets[1].wait_preload_done()
-        datasets[1].begin_pass()
-        exe.train_from_dataset(
-            program=fluid.default_main_program(),
-            dataset=datasets[1],
-            print_period=1,
-            debug=True)
-        datasets[1].end_pass(False)
-        for f in filelist:
-            os.remove(f)
+        self.run_boxps_preload(False, True)
+        self.run_boxps_preload(False, False)
+
+    def run_boxps_preload(self, is_cpu=True, random_with_lineid=False):
+        program = fluid.Program()
+        with fluid.program_guard(program):
+            x = fluid.layers.data(
+                name='x', shape=[1], dtype='int64', lod_level=0)
+            y = fluid.layers.data(
+                name='y', shape=[1], dtype='int64', lod_level=0)
+            emb_x, emb_y = _pull_box_sparse([x, y], size=2)
+            emb_xp = _pull_box_sparse(x, size=2)
+            concat = layers.concat([emb_x, emb_y], axis=1)
+            fc = layers.fc(input=concat,
+                           name="fc",
+                           size=1,
+                           num_flatten_dims=1,
+                           bias_attr=False)
+            loss = layers.reduce_mean(fc)
+            layers.Print(loss)
+            place = fluid.CPUPlace(
+            ) if is_cpu or not core.is_compiled_with_cuda(
+            ) else fluid.CUDAPlace(0)
+            exe = fluid.Executor(place)
+            batch_size = 100

+            def binary_print(slot, fout):
+                fout.write(str(len(slot)) + " ")
+                for e in slot:
+                    fout.write(str(e) + " ")
+
+            batch1 = np.ones(
+                (batch_size, 2, 1)).astype("int64").reshape(batch_size, 2, 1)
+            filelist = []
+            place_str = "cpu" if is_cpu else "gpu"
+            for i in range(2):
+                filelist.append("test_hdfs_" + place_str + "_" + str(i))
+            for f in filelist:
+                with open(f, "w") as fout:
+                    for ins in batch1:
+                        for slot in ins:
+                            binary_print(slot, fout)
+                        fout.write("\n")
+
+            def create_dataset():
+                dataset = fluid.DatasetFactory().create_dataset("BoxPSDataset")
+                dataset.set_date("20190930")
+                dataset.set_use_var([x, y])
+                dataset.set_batch_size(2)
+                dataset.set_thread(1)
+                dataset.set_filelist(filelist)
+                return dataset
+
+            datasets = []
+            datasets.append(create_dataset())
+            datasets.append(create_dataset())
+            optimizer = fluid.optimizer.SGD(learning_rate=0.5)
+            optimizer = fluid.optimizer.PipelineOptimizer(
+                optimizer,
+                cut_list=[],
+                place_list=[place],
+                concurrency_list=[1],
+                queue_size=1,
+                sync_steps=-1)
+            optimizer.minimize(loss)
+
+            program._pipeline_opt[
+                "dump_fields"] = ["fc.tmp_0", "fc.tmp_0@GRAD", "hehe"]
+            program._pipeline_opt["dump_fields_path"] = "./dump_log/"
+            program._pipeline_opt["dump_param"] = ["fc.w_0"]
+            program._pipeline_opt["enable_random_dump"] = True
+            program._pipeline_opt["dump_interval"] = 10
+            program._pipeline_opt["random_with_lineid"] = random_with_lineid
+
+            exe.run(fluid.default_startup_program())
+            datasets[0].load_into_memory()
+            datasets[0].begin_pass()
+            datasets[1].preload_into_memory()
+            exe.train_from_dataset(
+                program=fluid.default_main_program(),
+                dataset=datasets[0],
+                print_period=1)
+            datasets[0].end_pass(True)
+            datasets[1].wait_preload_done()
+            datasets[1].begin_pass()
+            exe.train_from_dataset(
+                program=fluid.default_main_program(),
+                dataset=datasets[1],
+                print_period=1,
+                debug=True)
+            datasets[1].end_pass(False)
+            for f in filelist:
+                os.remove(f)
+            if os.path.isdir("dump_log"):
+                shutil.rmtree("dump_log")


 if __name__ == '__main__':
diff --git a/python/paddle/fluid/tests/unittests/test_pipeline.py b/python/paddle/fluid/tests/unittests/test_pipeline.py
index 8c70a28be3743b96776b617be43b471c0371e89b..dbf14e047579437b2540ab2552e64ff8dc90099e 100644
--- a/python/paddle/fluid/tests/unittests/test_pipeline.py
+++ b/python/paddle/fluid/tests/unittests/test_pipeline.py
@@ -147,7 +147,7 @@ class TestPipeline(unittest.TestCase):
         for f in filelist:
             os.remove(f)

-    def test_pipeline_single_section(self):
+    def single_section(self, random_dump):
         program = fluid.Program()
         with fluid.program_guard(program):
             x = fluid.layers.data(
@@ -179,11 +179,20 @@ class TestPipeline(unittest.TestCase):
             optimizer = fluid.optimizer.PipelineOptimizer(
                 optimizer,
                 cut_list=[],
+                #place_list=[fluid.CPUPlace()],
                 place_list=[fluid.CUDAPlace(0)],
                 concurrency_list=[1],
                 queue_size=1,
                 sync_steps=-1)
             optimizer.minimize(loss)
+
+            program._pipeline_opt["dump_fields"] = ["fc.tmp_0", "fc.tmp_0@GRAD"]
+            program._pipeline_opt["dump_fields_path"] = "./dump_log/"
+            program._pipeline_opt["dump_param"] = ["embx"]
+            program._pipeline_opt["enable_random_dump"] = random_dump
+            program._pipeline_opt["dump_interval"] = 10
+            program._pipeline_opt["random_with_lineid"] = False
+            #print(program._pipeline_opt)
             place = fluid.CPUPlace()
             exe = fluid.Executor(place)
             exe.run(fluid.default_startup_program())
@@ -225,13 +234,19 @@ class TestPipeline(unittest.TestCase):
                 fluid.default_main_program(),
                 dataset,
                 thread=1,
-                debug=False,
+                debug=True,
                 fetch_list=[],
                 fetch_info=[],
                 print_period=1)
             for f in filelist:
                 os.remove(f)
+            if os.path.isdir("dump_log"):
+                shutil.rmtree("dump_log")
+
+    def test_pipeline(self):
+        self.single_section(True)
+        self.single_section(False)


 if __name__ == '__main__':
diff --git a/python/paddle/fluid/trainer_desc.py b/python/paddle/fluid/trainer_desc.py
index 2a206e7fa2f359d15fa28d5a44b7b7c2869f1dd1..ffef35b7acc2769495a91e412fa2373552a2f71e 100644
--- a/python/paddle/fluid/trainer_desc.py
+++ b/python/paddle/fluid/trainer_desc.py
@@ -107,6 +107,15 @@ class TrainerDesc(object):
     def _set_dump_converter(self, converter):
         self.proto_desc.dump_converter = converter

+    def _set_enable_random_dump(self, enable_random_dump):
+        self.proto_desc.enable_random_dump = enable_random_dump
+
+    def _set_dump_interval(self, dump_interval):
+        self.proto_desc.dump_interval = dump_interval
+
+    def _set_random_with_lineid(self, random_with_lineid):
+        self.proto_desc.random_with_lineid = random_with_lineid
+
     def _set_dump_param(self, dump_param):
         for param in dump_param:
             self.proto_desc.dump_param.append(param)
diff --git a/python/paddle/fluid/trainer_factory.py b/python/paddle/fluid/trainer_factory.py
index 0e071251bb2cd319152e98d95a0632b996913cc6..143d0c9b7ddb00811775bc76f0629cb6b188df95 100644
--- a/python/paddle/fluid/trainer_factory.py
+++ b/python/paddle/fluid/trainer_factory.py
@@ -72,6 +72,14 @@ class TrainerFactory(object):
                     trainer._set_dump_converter(opt_info["dump_converter"])
                 if opt_info.get("dump_param") is not None:
                     trainer._set_dump_param(opt_info["dump_param"])
+                if opt_info.get("enable_random_dump") is not None:
+                    trainer._set_enable_random_dump(opt_info[
+                        "enable_random_dump"])
+                if opt_info.get("dump_interval") is not None:
+                    trainer._set_dump_interval(opt_info["dump_interval"])
+                if opt_info.get("random_with_lineid") is not None:
+                    trainer._set_random_with_lineid(opt_info[
+                        "random_with_lineid"])

             if "fleet_desc" in opt_info:
                 device_worker._set_fleet_desc(opt_info["fleet_desc"])
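For the non-pipeline trainers these knobs reach `TrainerDesc` through the `opt_info` dictionary handled by `TrainerFactory` above; the pipeline path sets the same keys on `program._pipeline_opt`, as the updated tests show. A minimal sketch of the dictionary, assuming only the keys visible in this patch (the values are illustrative, and how `opt_info` is populated depends on the training setup in use):

```python
# Illustrative dump configuration; keys mirror those read by TrainerFactory above.
opt_info = {
    "dump_fields": ["fc.tmp_0", "fc.tmp_0@GRAD"],  # per-instance tensors to dump
    "dump_fields_path": "./dump_log/",             # one part-* file per dump thread
    "dump_param": ["fc.w_0"],                      # parameters dumped by thread 0
    "enable_random_dump": True,                    # turn sampling on
    "dump_interval": 10,                           # keep roughly 1 in 10 instances
    "random_with_lineid": False,                   # True: sample by hashing the line id
}
print(opt_info)
```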