Commit 825e947c authored by R rensilin

dump params

Change-Id: I073a955ac9aa13e4afdfb869121b83f95062cdac
Parent 79133eae
@@ -14,18 +14,18 @@ namespace feed {
         VLOG(0) << "file_system is not initialized";
         return -1;
     }
+    auto fs = _trainer_context->file_system.get();
     if (config["donefile"]) {
-        _done_file_path = _trainer_context->file_system->path_join(_model_root_path, config["donefile"].as<std::string>());
+        _done_file_path = fs->path_join(_model_root_path, config["donefile"].as<std::string>());
     } else {
-        _done_file_path = _trainer_context->file_system->path_join(_model_root_path, "epoch_donefile.txt");
+        _done_file_path = fs->path_join(_model_root_path, "epoch_donefile.txt");
     }
-    if (!_trainer_context->file_system->exists(_done_file_path)) {
+    if (!fs->exists(_done_file_path)) {
         VLOG(0) << "missing done file, path:" << _done_file_path;
     }
-    std::string done_text = _trainer_context->file_system->tail(_done_file_path);
+    std::string done_text = fs->tail(_done_file_path);
     _done_status = paddle::string::split_string(done_text, std::string("\t"));
     _current_epoch_id = get_status<uint64_t>(EpochStatusFiled::EpochIdField);
     _last_checkpoint_epoch_id = get_status<uint64_t>(EpochStatusFiled::CheckpointIdField);
......
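For reference, the done file tailed above is a single tab-separated status line whose fields are indexed through EpochStatusFiled. A minimal std-only sketch of that parsing is below; the field layout and values are invented for illustration, and the real code uses paddle::string::split_string rather than this helper.

```cpp
#include <cstdint>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// Split a tab-separated done-file line into fields, mirroring what
// paddle::string::split_string(done_text, "\t") does in the hunk above.
std::vector<std::string> split_fields(const std::string& line, char delim = '\t') {
    std::vector<std::string> fields;
    std::stringstream ss(line);
    std::string field;
    while (std::getline(ss, field, delim)) {
        fields.push_back(field);
    }
    return fields;
}

int main() {
    // Hypothetical done line: the real field order is defined by
    // EpochStatusFiled and is not visible in this diff.
    const std::string done_text = "20190720\t12345\t12300";
    const auto status = split_fields(done_text);
    const uint64_t epoch_id = std::stoull(status[1]);       // e.g. EpochIdField
    const uint64_t checkpoint_id = std::stoull(status[2]);  // e.g. CheckpointIdField
    std::cout << "epoch_id=" << epoch_id
              << " last_checkpoint=" << checkpoint_id << std::endl;
    return 0;
}
```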
@@ -534,7 +534,7 @@ public:
         size_t buffer_size = 0;
         ssize_t line_len = 0;
         while ((line_len = getline(&buffer, &buffer_size, fin.get())) != -1) {
             // strip the trailing newline
             if (line_len > 0 && buffer[line_len - 1] == '\n') {
                 buffer[--line_len] = '\0';
             }
@@ -547,7 +547,8 @@ public:
             VLOG(5) << "parse data: " << data_item.id << " " << data_item.data << ", filename: " << filepath << ", thread_num: " << thread_num << ", max_threads: " << max_threads;
             if (writer == nullptr) {
                 if (!data_channel->Put(std::move(data_item))) {
-                    VLOG(2) << "fail to put data, thread_num: " << thread_num;
+                    LOG(WARNING) << "fail to put data, thread_num: " << thread_num;
+                    is_failed = true;
                 }
             } else {
                 (*writer) << std::move(data_item);
......
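The second hunk promotes a failed Put from a low-priority verbose log to a WARNING and records it in a failure flag. A rough sketch of that pattern follows, using a hypothetical stand-in queue; Paddle's real Channel API is not reproduced here.

```cpp
#include <iostream>
#include <queue>
#include <string>

// Hypothetical bounded queue standing in for the project's data channel.
struct BoundedQueue {
    size_t capacity;
    std::queue<std::string> items;
    bool Put(std::string item) {
        if (items.size() >= capacity) {
            return false;  // full (or closed) channel: Put reports failure
        }
        items.push(std::move(item));
        return true;
    }
};

int main() {
    BoundedQueue data_channel{1};
    bool is_failed = false;
    for (std::string data_item : {"first line", "second line"}) {
        if (!data_channel.Put(std::move(data_item))) {
            // Mirrors the hunk: log loudly and remember the failure instead of
            // emitting only a low-priority verbose log.
            std::cerr << "fail to put data" << std::endl;
            is_failed = true;
        }
    }
    std::cout << "is_failed=" << std::boolalpha << is_failed << std::endl;
    return 0;
}
```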
@@ -77,7 +77,7 @@ public:
     FileSystem* get_file_system(const std::string& path) {
         auto pos = path.find_first_of(":");
         if (pos != std::string::npos) {
-            auto substr = path.substr(0, pos + 1);
+            auto substr = path.substr(0, pos);  // example: afs:/xxx -> afs
             auto fs_it = _file_system.find(substr);
             if (fs_it != _file_system.end()) {
                 return fs_it->second.get();
......
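With the substr fix, the registry is keyed by the bare scheme name, so "afs:/..." and "hdfs:/..." paths resolve by their prefix before the colon. A small self-contained sketch of that lookup; the map contents are made up for illustration.

```cpp
#include <iostream>
#include <map>
#include <string>

// Return the scheme prefix of a path, e.g. "afs:/user/x" -> "afs";
// a path with no colon falls back to the empty key (local file system).
std::string scheme_of(const std::string& path) {
    auto pos = path.find_first_of(":");
    return pos != std::string::npos ? path.substr(0, pos) : std::string();
}

int main() {
    // Hypothetical registry mirroring _file_system after this change:
    // the keys are bare scheme names, matching the test config later in the diff.
    std::map<std::string, std::string> file_system = {
        {"afs", "HadoopFileSystem"},
        {"hdfs", "HadoopFileSystem"},
        {"", "LocalFileSystem"},
    };
    for (const std::string path : {"afs:/user/feed/data", "hdfs:/app/model", "/tmp/local.txt"}) {
        std::cout << path << " -> " << file_system[scheme_of(path)] << std::endl;
    }
    return 0;
}
```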
@@ -76,7 +76,7 @@ int LearnerProcess::run() {
     uint64_t epoch_id = epoch_accessor->current_epoch_id();
     environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE,
-        "Resume trainer with epoch_id:%d label:%s", epoch_id, _context_ptr->epoch_accessor->text(epoch_id).c_str());
+        "Resume training with epoch_id:%d label:%s", epoch_id, _context_ptr->epoch_accessor->text(epoch_id).c_str());
     // decide whether to dump the base model first
     wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceBase);
@@ -108,7 +108,7 @@ int LearnerProcess::run() {
         for (int thread_id = 0; thread_id < _train_thread_num; ++thread_id) {
             train_threads[i].reset(new std::thread([this](int exe_idx, int thread_idx) {
                 auto* executor = _threads_executor[thread_idx][exe_idx].get();
                 run_executor(executor);
             }, i, thread_id));
         }
         for (int i = 0; i < _train_thread_num; ++i) {
......
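The second hunk above spawns one std::thread per training thread, passing the executor and thread indices into the lambda, and joins them in the following loop. A minimal self-contained sketch of that fan-out/join shape; run_executor here is only a placeholder.

```cpp
#include <cstdio>
#include <memory>
#include <thread>
#include <vector>

// Placeholder for the real run_executor(executor) call in the hunk above.
void run_executor(int exe_idx, int thread_idx) {
    std::printf("executor %d running on training thread %d\n", exe_idx, thread_idx);
}

int main() {
    const int train_thread_num = 4;
    const int exe_idx = 0;  // in the real code this comes from an outer executor loop
    std::vector<std::unique_ptr<std::thread>> train_threads(train_thread_num);
    for (int thread_id = 0; thread_id < train_thread_num; ++thread_id) {
        // Pass the indices through the thread arguments, as the lambda in the diff does.
        train_threads[thread_id].reset(new std::thread(
            [](int e, int t) { run_executor(e, t); }, exe_idx, thread_id));
    }
    for (int i = 0; i < train_thread_num; ++i) {
        train_threads[i]->join();  // wait for every training thread to finish
    }
    return 0;
}
```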
@@ -119,11 +119,12 @@ class ModelBuilder:
             'inputs': [{"name": var.name, "shape": var.shape} for var in inputs],
             'outputs': [{"name": var.name, "shape": var.shape} for var in outputs],
             'labels': [{"name": var.name, "shape": var.shape} for var in labels],
+            'vars': [{"name": var.name, "shape": var.shape} for var in main_program.list_vars() if fluid.io.is_parameter(var)],
             'loss': loss.name,
         }
         with open(model_desc_path, 'w') as f:
-            yaml.safe_dump(model_desc, f, encoding='utf-8', allow_unicode=True)
+            yaml.safe_dump(model_desc, f, encoding='utf-8', allow_unicode=True, default_flow_style=None)
 def main(argv):
......
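With the added 'vars' entry, the dumped model description now lists each parameter with its shape, and default_flow_style=None lets PyYAML keep leaf collections such as the shape lists in compact flow style while the rest stays in block style. Below is a rough sketch of how a consumer could read that section; yaml-cpp is an assumption, and the parameter names and shapes are invented.

```cpp
#include <iostream>
#include <string>
#include <yaml-cpp/yaml.h>

int main() {
    // Rough shape of the dumped model description after this commit; the
    // parameter names and shapes below are invented for illustration.
    const char* text =
        "loss: mean_0.tmp_0\n"
        "vars:\n"
        "- name: fc_0.w_0\n"
        "  shape: [128, 64]\n"
        "- name: fc_0.b_0\n"
        "  shape: [64]\n";
    YAML::Node desc = YAML::Load(text);
    // Walk the dumped parameter list, e.g. to decide which params to save or load.
    for (const auto& var : desc["vars"]) {
        std::cout << var["name"].as<std::string>() << " dims:";
        for (const auto& dim : var["shape"]) {
            std::cout << " " << dim.as<int>();
        }
        std::cout << std::endl;
    }
    return 0;
}
```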
@@ -40,7 +40,7 @@ void ReadBinaryFile(const std::string& filename, std::string* contents) {
 std::unique_ptr<paddle::framework::ProgramDesc> Load(
     paddle::framework::Executor* executor, const std::string& model_filename) {
-    LOG(DEBUG) << "loading model from " << model_filename;
+    VLOG(3) << "loading model from " << model_filename;
     std::string program_desc_str;
     ReadBinaryFile(model_filename, &program_desc_str);
......
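glog has no DEBUG severity, which is why the load message moves to verbose logging here. A short sketch of how VLOG(3) is gated by the --v verbosity flag; the model path is a placeholder.

```cpp
#include <glog/logging.h>

int main(int argc, char* argv[]) {
    (void)argc;
    google::InitGoogleLogging(argv[0]);
    FLAGS_logtostderr = true;
    FLAGS_v = 3;  // equivalent to passing --v=3 on the command line
    // VLOG(n) only emits when the verbosity level is >= n, so this message is
    // silent by default and becomes visible once --v=3 (or higher) is set.
    VLOG(3) << "loading model from " << "path/to/model";
    LOG(WARNING) << "severity logs are emitted regardless of --v";
    return 0;
}
```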
@@ -193,14 +193,14 @@ TEST_F(DataReaderTest, LineDataReader_FileSystem) {
         "file_system:\n"
         "  class: AutoFileSystem\n"
         "  file_systems:\n"
-        "    'afs:': &HDFS \n"
+        "    'afs': &HDFS \n"
         "      class: HadoopFileSystem\n"
         "      hdfs_command: 'hadoop fs'\n"
         "      ugis:\n"
         "        'default': 'feed_video,D3a0z8'\n"
         "        'xingtian.afs.baidu.com:9902': 'feed_video,D3a0z8'\n"
         "    \n"
-        "    'hdfs:': *HDFS\n");
+        "    'hdfs': *HDFS\n");
     ASSERT_EQ(0, data_reader->initialize(config, context_ptr));
     {
         auto data_file_list = data_reader->data_file_list(test_data_dir);
......
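After dropping the trailing colon from the keys, the test config maps bare scheme names to file systems, and the &HDFS/*HDFS anchor lets 'hdfs' reuse the settings defined under 'afs'. A sketch of parsing such a config follows; yaml-cpp is an assumption (the config["donefile"] access earlier in the diff suggests it), and the ugi credentials are omitted.

```cpp
#include <iostream>
#include <string>
#include <yaml-cpp/yaml.h>

int main() {
    // Trimmed-down version of the test config above; credentials are left out.
    const char* text =
        "file_system:\n"
        "  class: AutoFileSystem\n"
        "  file_systems:\n"
        "    'afs': &HDFS\n"
        "      class: HadoopFileSystem\n"
        "      hdfs_command: 'hadoop fs'\n"
        "    'hdfs': *HDFS\n";
    YAML::Node config = YAML::Load(text);
    // Iterate the scheme -> file-system map; the *HDFS alias resolves to the
    // same settings that were anchored under the 'afs' key.
    for (const auto& kv : config["file_system"]["file_systems"]) {
        std::cout << kv.first.as<std::string>() << " -> "
                  << kv.second["class"].as<std::string>() << std::endl;
    }
    return 0;
}
```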