提交 825e947c 编写于 作者: R rensilin

dump params

Change-Id: I073a955ac9aa13e4afdfb869121b83f95062cdac
上级 79133eae
......@@ -14,18 +14,18 @@ namespace feed {
VLOG(0) << "file_system is not initialized";
return -1;
}
auto fs = _trainer_context->file_system.get();
if (config["donefile"]) {
_done_file_path = _trainer_context->file_system->path_join(_model_root_path, config["donefile"].as<std::string>());
_done_file_path = fs->path_join(_model_root_path, config["donefile"].as<std::string>());
} else {
_done_file_path = _trainer_context->file_system->path_join(_model_root_path, "epoch_donefile.txt");
_done_file_path = fs->path_join(_model_root_path, "epoch_donefile.txt");
}
if (!_trainer_context->file_system->exists(_done_file_path)) {
if (!fs->exists(_done_file_path)) {
VLOG(0) << "missing done file, path:" << _done_file_path;
}
std::string done_text = _trainer_context->file_system->tail(_done_file_path);
std::string done_text = fs->tail(_done_file_path);
_done_status = paddle::string::split_string(done_text, std::string("\t"));
_current_epoch_id = get_status<uint64_t>(EpochStatusFiled::EpochIdField);
_last_checkpoint_epoch_id = get_status<uint64_t>(EpochStatusFiled::CheckpointIdField);
......
......@@ -534,7 +534,7 @@ public:
size_t buffer_size = 0;
ssize_t line_len = 0;
while ((line_len = getline(&buffer, &buffer_size, fin.get())) != -1) {
// 去掉行回车
// 去掉行回车
if (line_len > 0 && buffer[line_len - 1] == '\n') {
buffer[--line_len] = '\0';
}
......@@ -547,7 +547,8 @@ public:
VLOG(5) << "parse data: " << data_item.id << " " << data_item.data << ", filename: " << filepath << ", thread_num: " << thread_num << ", max_threads: " << max_threads;
if (writer == nullptr) {
if (!data_channel->Put(std::move(data_item))) {
VLOG(2) << "fail to put data, thread_num: " << thread_num;
LOG(WARNING) << "fail to put data, thread_num: " << thread_num;
is_failed = true;
}
} else {
(*writer) << std::move(data_item);
......
......@@ -77,7 +77,7 @@ public:
FileSystem* get_file_system(const std::string& path) {
auto pos = path.find_first_of(":");
if (pos != std::string::npos) {
auto substr = path.substr(0, pos + 1);
auto substr = path.substr(0, pos); // example: afs:/xxx -> afs
auto fs_it = _file_system.find(substr);
if (fs_it != _file_system.end()) {
return fs_it->second.get();
......
......@@ -76,7 +76,7 @@ int LearnerProcess::run() {
uint64_t epoch_id = epoch_accessor->current_epoch_id();
environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE,
"Resume trainer with epoch_id:%d label:%s", epoch_id, _context_ptr->epoch_accessor->text(epoch_id).c_str());
"Resume training with epoch_id:%d label:%s", epoch_id, _context_ptr->epoch_accessor->text(epoch_id).c_str());
//判断是否先dump出base
wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceBase);
......
......@@ -119,11 +119,12 @@ class ModelBuilder:
'inputs': [{"name": var.name, "shape": var.shape} for var in inputs],
'outputs': [{"name": var.name, "shape": var.shape} for var in outputs],
'labels': [{"name": var.name, "shape": var.shape} for var in labels],
'vars': [{"name": var.name, "shape": var.shape} for var in main_program.list_vars() if fluid.io.is_parameter(var)],
'loss': loss.name,
}
with open(model_desc_path, 'w') as f:
yaml.safe_dump(model_desc, f, encoding='utf-8', allow_unicode=True)
yaml.safe_dump(model_desc, f, encoding='utf-8', allow_unicode=True, default_flow_style=None)
def main(argv):
......
......@@ -40,7 +40,7 @@ void ReadBinaryFile(const std::string& filename, std::string* contents) {
std::unique_ptr<paddle::framework::ProgramDesc> Load(
paddle::framework::Executor* executor, const std::string& model_filename) {
LOG(DEBUG) << "loading model from " << model_filename;
VLOG(3) << "loading model from " << model_filename;
std::string program_desc_str;
ReadBinaryFile(model_filename, &program_desc_str);
......
......@@ -193,14 +193,14 @@ TEST_F(DataReaderTest, LineDataReader_FileSystem) {
"file_system:\n"
" class: AutoFileSystem\n"
" file_systems:\n"
" 'afs:': &HDFS \n"
" 'afs': &HDFS \n"
" class: HadoopFileSystem\n"
" hdfs_command: 'hadoop fs'\n"
" ugis:\n"
" 'default': 'feed_video,D3a0z8'\n"
" 'xingtian.afs.baidu.com:9902': 'feed_video,D3a0z8'\n"
" \n"
" 'hdfs:': *HDFS\n");
" 'hdfs': *HDFS\n");
ASSERT_EQ(0, data_reader->initialize(config, context_ptr));
{
auto data_file_list = data_reader->data_file_list(test_data_dir);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册