提交 79133eae 编写于 作者: R rensilin

clean code

Change-Id: I402094ec45a96482f0d57d1dbe37ac909e01246c
上级 d21f279a
......@@ -7,20 +7,25 @@ namespace custom_trainer {
namespace feed {
int EpochAccessor::initialize(YAML::Node config,
std::shared_ptr<TrainerContext> context_ptr) {
_model_root_path = config["model_root_path"].as<std::string>() + "/";
_model_root_path = config["model_root_path"].as<std::string>();
_trainer_context = context_ptr.get();
if (context_ptr->file_system == nullptr) {
VLOG(0) << "file_system is not initialized";
return -1;
}
_done_file_path = _model_root_path;
if (config["donefile"]) {
_done_file_path.append(config["donefile"].as<std::string>());
_done_file_path = _trainer_context->file_system->path_join(_model_root_path, config["donefile"].as<std::string>());
} else {
_done_file_path.append("epoch_donefile.txt");
_done_file_path = _trainer_context->file_system->path_join(_model_root_path, "epoch_donefile.txt");
}
if (!context_ptr->file_system->exists(_done_file_path)) {
if (!_trainer_context->file_system->exists(_done_file_path)) {
VLOG(0) << "missing done file, path:" << _done_file_path;
}
std::string done_text = context_ptr->file_system->tail(_done_file_path);
std::string done_text = _trainer_context->file_system->tail(_done_file_path);
_done_status = paddle::string::split_string(done_text, std::string("\t"));
_current_epoch_id = get_status<uint64_t>(EpochStatusFiled::EpochIdField);
_last_checkpoint_epoch_id = get_status<uint64_t>(EpochStatusFiled::CheckpointIdField);
......@@ -67,23 +72,25 @@ namespace feed {
if (epoch_id == 0) {
return false;
}
if (save_way == ModelSaveWay::ModelSaveInferenceDelta) {
return true;
} else if (save_way == ModelSaveWay::ModelSaveInferenceBase) {
return is_last_epoch(epoch_id);
} else if (save_way == ModelSaveWay::ModelSaveTrainCheckpoint) {
return ((epoch_id / 3600) % 8) == 0;
switch (save_way) {
case ModelSaveWay::ModelSaveInferenceDelta:
return true;
case ModelSaveWay::ModelSaveInferenceBase:
return is_last_epoch(epoch_id);
case ModelSaveWay::ModelSaveTrainCheckpoint:
return ((epoch_id / 3600) % 8) == 0;
}
return false;
}
std::string HourlyEpochAccessor::model_save_path(uint64_t epoch_id, ModelSaveWay save_way) {
if (save_way == ModelSaveWay::ModelSaveInferenceDelta) {
return _model_root_path + "/xbox/delta-" + std::to_string(epoch_id);
} else if (save_way == ModelSaveWay::ModelSaveInferenceBase) {
return _model_root_path + "/xbox/base";
} else if (save_way == ModelSaveWay::ModelSaveTrainCheckpoint) {
return _model_root_path + "/xbox/checkpoint";
switch (save_way) {
case ModelSaveWay::ModelSaveInferenceDelta:
return _trainer_context->file_system->path_join(_model_root_path, "/xbox/delta-" + std::to_string(epoch_id));
case ModelSaveWay::ModelSaveInferenceBase:
return _trainer_context->file_system->path_join(_model_root_path, "/xbox/base");
case ModelSaveWay::ModelSaveTrainCheckpoint:
return _trainer_context->file_system->path_join(_model_root_path, "/xbox/checkpoint");
}
return "";
}
......
......@@ -52,6 +52,7 @@ public:
virtual bool need_save_model(uint64_t epoch_id, ModelSaveWay save_way) = 0;
virtual std::string model_save_path(uint64_t epoch_id, ModelSaveWay save_way) = 0;
protected:
TrainerContext* _trainer_context;
std::string _done_file_path;
std::string _model_root_path;
uint64_t _current_epoch_id = 0;
......
......@@ -11,14 +11,14 @@ io :
ugis :
'default': 'feed_video,D3a0z8'
'xingtian.afs.baidu.com:9902': 'feed_video,D3a0z8'
local :
default :
class : LocalFileSystem
buffer_size : 1024000
dataset :
data_list :
train_sample :
prefetch_num : 2
root_path : ./sample
root_path : [./sample]
data_spit_interval : 300
data_path_formater : '%Y%m%d/%H%M'
data_reader : LineDataReader
......
......@@ -7,14 +7,14 @@ namespace feed {
int Dataset::initialize(
const YAML::Node& config, std::shared_ptr<TrainerContext> context) {
if (config["data_list"].Type() != YAML::NodeType::Map) {
VLOG(0) << "miss data_list config in dataset, or type error please check";
LOG(FATAL) << "miss data_list config in dataset, or type error please check";
return -1;
}
for (auto& data_config : config["data_list"]) {
std::string name = data_config.first.as<std::string>();
auto data_ptr = std::make_shared<DatasetContainer>();
if (data_ptr->initialize(data_config.second, context) != 0) {
VLOG(0) << "dataset initialize failed, name:" << name;
LOG(FATAL) << "dataset initialize failed, name:" << name;
return -1;
}
_data_containers[name] = data_ptr;
......
......@@ -6,10 +6,10 @@
#include <vector>
#include <memory>
#include <yaml-cpp/yaml.h>
#include "paddle/fluid/framework/io/shell.h"
#include "paddle/fluid/string/string_helper.h"
#include "paddle/fluid/train/custom_trainer/feed/trainer_context.h"
#include "paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h"
#include "paddle/fluid/train/custom_trainer/feed/io/file_system.h"
#include "paddle/fluid/train/custom_trainer/feed/dataset/dataset_container.h"
namespace paddle {
......@@ -27,8 +27,7 @@ int DatasetContainer::initialize(
_dataset_list[i].reset(new DatasetInfo);
}
_data_root_paths = paddle::string::split_string(
config["root_path"].as<std::string>(), " ");
_data_root_paths = config["root_path"].as<std::vector<std::string>>();
_data_split_interval = config["data_spit_interval"].as<int>();
_data_path_formater = config["data_path_formater"].as<std::string>();
std::string data_reader_class = config["data_reader"].as<std::string>();
......@@ -66,7 +65,7 @@ void DatasetContainer::pre_detect_data(uint64_t epoch_id) {
for (int i = 0; i < _data_root_paths.size() && status == 0; ++i) {
for (int j = 0; j < data_num && status == 0; ++j) {
std::string path_suffix = format_timestamp(data_timestamp + j * _data_split_interval, _data_path_formater);
std::string data_dir = _data_root_paths[i] + "/" + path_suffix;
std::string data_dir = _trainer_context->file_system->path_join(_data_root_paths[i], path_suffix);
status = read_data_list(data_dir, data_path_list);
}
}
......
......@@ -17,7 +17,7 @@ namespace {
int ReadBinaryFile(const std::string& filename, std::string* contents) {
std::ifstream fin(filename, std::ios::in | std::ios::binary);
if (!fin) {
VLOG(2) << "Cannot open file " << filename;
LOG(FATAL) << "Cannot open file " << filename;
return -1;
}
fin.seekg(0, std::ios::end);
......@@ -31,7 +31,7 @@ int ReadBinaryFile(const std::string& filename, std::string* contents) {
std::unique_ptr<paddle::framework::ProgramDesc> Load(
paddle::framework::Executor* /*executor*/, const std::string& model_filename) {
VLOG(3) << "loading model from " << model_filename;
LOG(INFO) << "loading model from " << model_filename;
std::string program_desc_str;
if (ReadBinaryFile(model_filename, &program_desc_str) != 0) {
return nullptr;
......
......@@ -19,17 +19,18 @@ public:
for (auto& prefix_fs: config["file_systems"]) {
std::unique_ptr<FileSystem> fs(CREATE_CLASS(FileSystem, prefix_fs.second["class"].as<std::string>("")));
if (fs == nullptr) {
VLOG(2) << "fail to create class: " << prefix_fs.second["class"].as<std::string>("");
LOG(FATAL) << "fail to create class: " << prefix_fs.second["class"].as<std::string>("");
return -1;
}
if (fs->initialize(prefix_fs.second, context) != 0) {
VLOG(2) << "fail to initialize class: " << prefix_fs.second["class"].as<std::string>("");
return 0;
LOG(FATAL) << "fail to initialize class: " << prefix_fs.second["class"].as<std::string>("");
return -1;
}
_file_system.emplace(prefix_fs.first.as<std::string>(""), std::move(fs));
}
}
if (_file_system.find("default") == _file_system.end()) {
LOG(WARNING) << "miss default file_system, use LocalFileSystem as default";
std::unique_ptr<FileSystem> fs(CREATE_CLASS(FileSystem, "LocalFileSystem"));
if (fs == nullptr || fs->initialize(YAML::Load(""), context) != 0) {
return -1;
......@@ -82,7 +83,6 @@ public:
return fs_it->second.get();
}
}
VLOG(5) << "path: " << path << ", select default file system";
return _file_system["default"].get();
}
......
......@@ -25,7 +25,7 @@ public:
}
}
if (_ugi.find("default") == _ugi.end()) {
VLOG(2) << "fail to load default ugi";
LOG(FATAL) << "fail to load default ugi";
return -1;
}
return 0;
......@@ -62,7 +62,7 @@ public:
int64_t file_size(const std::string& path) override {
_err_no = -1;
VLOG(2) << "not support";
LOG(FATAL) << "not support";
return 0;
}
......
......@@ -13,6 +13,7 @@ using namespace paddle::custom_trainer::feed;
DEFINE_string(feed_trainer_conf_path, "./conf/trainer.yaml", "path of trainer conf");
int main(int argc, char* argv[]) {
google::InitGoogleLogging(argv[0]);
//gflags
google::ParseCommandLineFlags(&argc, &argv, true);
std::string gflag_conf = "./conf/gflags.conf";
......
......@@ -46,7 +46,7 @@ int InitEnvProcess::initialize(std::shared_ptr<TrainerContext> context_ptr) {
return -1;
}
VLOG(3) << "Env initialize success";
VLOG(3) << "Env initialize success";
return 0;
}
......
......@@ -76,7 +76,7 @@ int LearnerProcess::run() {
uint64_t epoch_id = epoch_accessor->current_epoch_id();
environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE,
"Resume traine with epoch_id:%d label:%s", epoch_id, _context_ptr->epoch_accessor->text(epoch_id).c_str());
"Resume trainer with epoch_id:%d label:%s", epoch_id, _context_ptr->epoch_accessor->text(epoch_id).c_str());
//判断是否先dump出base
wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceBase);
......
#!bash
#!/bin/bash
export LD_LIBRARY_PATH=LD_LIBRARY_PATH:./so
./bin/feed_trainer
./bin/feed_trainer "$@"
......@@ -40,7 +40,7 @@ void ReadBinaryFile(const std::string& filename, std::string* contents) {
std::unique_ptr<paddle::framework::ProgramDesc> Load(
paddle::framework::Executor* executor, const std::string& model_filename) {
VLOG(3) << "loading model from " << model_filename;
LOG(DEBUG) << "loading model from " << model_filename;
std::string program_desc_str;
ReadBinaryFile(model_filename, &program_desc_str);
......
......@@ -133,9 +133,9 @@ TEST_F(CreateProgramsTest, example_network) {
auto output_var = executor->var<::paddle::framework::LoDTensor>(output_name);
auto output = output_var.data<float>()[0];
VLOG(3) << "loss: " << loss << std::endl;
VLOG(3) << "label: " << label_data[0] << std::endl;
VLOG(3) << "output: " << output << std::endl;
LOG(INFO) << "loss: " << loss << std::endl;
LOG(INFO) << "label: " << label_data[0] << std::endl;
LOG(INFO) << "output: " << output << std::endl;
ASSERT_NEAR(loss, pow(output - label_data[0], 2), 1e-8);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册