提交 79133eae 编写于 作者: R rensilin

clean code

Change-Id: I402094ec45a96482f0d57d1dbe37ac909e01246c
上级 d21f279a
...@@ -7,20 +7,25 @@ namespace custom_trainer { ...@@ -7,20 +7,25 @@ namespace custom_trainer {
namespace feed { namespace feed {
int EpochAccessor::initialize(YAML::Node config, int EpochAccessor::initialize(YAML::Node config,
std::shared_ptr<TrainerContext> context_ptr) { std::shared_ptr<TrainerContext> context_ptr) {
_model_root_path = config["model_root_path"].as<std::string>() + "/"; _model_root_path = config["model_root_path"].as<std::string>();
_trainer_context = context_ptr.get();
if (context_ptr->file_system == nullptr) {
VLOG(0) << "file_system is not initialized";
return -1;
}
_done_file_path = _model_root_path;
if (config["donefile"]) { if (config["donefile"]) {
_done_file_path.append(config["donefile"].as<std::string>()); _done_file_path = _trainer_context->file_system->path_join(_model_root_path, config["donefile"].as<std::string>());
} else { } else {
_done_file_path.append("epoch_donefile.txt"); _done_file_path = _trainer_context->file_system->path_join(_model_root_path, "epoch_donefile.txt");
} }
if (!context_ptr->file_system->exists(_done_file_path)) { if (!_trainer_context->file_system->exists(_done_file_path)) {
VLOG(0) << "missing done file, path:" << _done_file_path; VLOG(0) << "missing done file, path:" << _done_file_path;
} }
std::string done_text = context_ptr->file_system->tail(_done_file_path); std::string done_text = _trainer_context->file_system->tail(_done_file_path);
_done_status = paddle::string::split_string(done_text, std::string("\t")); _done_status = paddle::string::split_string(done_text, std::string("\t"));
_current_epoch_id = get_status<uint64_t>(EpochStatusFiled::EpochIdField); _current_epoch_id = get_status<uint64_t>(EpochStatusFiled::EpochIdField);
_last_checkpoint_epoch_id = get_status<uint64_t>(EpochStatusFiled::CheckpointIdField); _last_checkpoint_epoch_id = get_status<uint64_t>(EpochStatusFiled::CheckpointIdField);
...@@ -67,23 +72,25 @@ namespace feed { ...@@ -67,23 +72,25 @@ namespace feed {
if (epoch_id == 0) { if (epoch_id == 0) {
return false; return false;
} }
if (save_way == ModelSaveWay::ModelSaveInferenceDelta) { switch (save_way) {
case ModelSaveWay::ModelSaveInferenceDelta:
return true; return true;
} else if (save_way == ModelSaveWay::ModelSaveInferenceBase) { case ModelSaveWay::ModelSaveInferenceBase:
return is_last_epoch(epoch_id); return is_last_epoch(epoch_id);
} else if (save_way == ModelSaveWay::ModelSaveTrainCheckpoint) { case ModelSaveWay::ModelSaveTrainCheckpoint:
return ((epoch_id / 3600) % 8) == 0; return ((epoch_id / 3600) % 8) == 0;
} }
return false; return false;
} }
std::string HourlyEpochAccessor::model_save_path(uint64_t epoch_id, ModelSaveWay save_way) { std::string HourlyEpochAccessor::model_save_path(uint64_t epoch_id, ModelSaveWay save_way) {
if (save_way == ModelSaveWay::ModelSaveInferenceDelta) { switch (save_way) {
return _model_root_path + "/xbox/delta-" + std::to_string(epoch_id); case ModelSaveWay::ModelSaveInferenceDelta:
} else if (save_way == ModelSaveWay::ModelSaveInferenceBase) { return _trainer_context->file_system->path_join(_model_root_path, "/xbox/delta-" + std::to_string(epoch_id));
return _model_root_path + "/xbox/base"; case ModelSaveWay::ModelSaveInferenceBase:
} else if (save_way == ModelSaveWay::ModelSaveTrainCheckpoint) { return _trainer_context->file_system->path_join(_model_root_path, "/xbox/base");
return _model_root_path + "/xbox/checkpoint"; case ModelSaveWay::ModelSaveTrainCheckpoint:
return _trainer_context->file_system->path_join(_model_root_path, "/xbox/checkpoint");
} }
return ""; return "";
} }
......
...@@ -52,6 +52,7 @@ public: ...@@ -52,6 +52,7 @@ public:
virtual bool need_save_model(uint64_t epoch_id, ModelSaveWay save_way) = 0; virtual bool need_save_model(uint64_t epoch_id, ModelSaveWay save_way) = 0;
virtual std::string model_save_path(uint64_t epoch_id, ModelSaveWay save_way) = 0; virtual std::string model_save_path(uint64_t epoch_id, ModelSaveWay save_way) = 0;
protected: protected:
TrainerContext* _trainer_context;
std::string _done_file_path; std::string _done_file_path;
std::string _model_root_path; std::string _model_root_path;
uint64_t _current_epoch_id = 0; uint64_t _current_epoch_id = 0;
......
...@@ -11,14 +11,14 @@ io : ...@@ -11,14 +11,14 @@ io :
ugis : ugis :
'default': 'feed_video,D3a0z8' 'default': 'feed_video,D3a0z8'
'xingtian.afs.baidu.com:9902': 'feed_video,D3a0z8' 'xingtian.afs.baidu.com:9902': 'feed_video,D3a0z8'
local : default :
class : LocalFileSystem class : LocalFileSystem
buffer_size : 1024000 buffer_size : 1024000
dataset : dataset :
data_list : data_list :
train_sample : train_sample :
prefetch_num : 2 prefetch_num : 2
root_path : ./sample root_path : [./sample]
data_spit_interval : 300 data_spit_interval : 300
data_path_formater : '%Y%m%d/%H%M' data_path_formater : '%Y%m%d/%H%M'
data_reader : LineDataReader data_reader : LineDataReader
......
...@@ -7,14 +7,14 @@ namespace feed { ...@@ -7,14 +7,14 @@ namespace feed {
int Dataset::initialize( int Dataset::initialize(
const YAML::Node& config, std::shared_ptr<TrainerContext> context) { const YAML::Node& config, std::shared_ptr<TrainerContext> context) {
if (config["data_list"].Type() != YAML::NodeType::Map) { if (config["data_list"].Type() != YAML::NodeType::Map) {
VLOG(0) << "miss data_list config in dataset, or type error please check"; LOG(FATAL) << "miss data_list config in dataset, or type error please check";
return -1; return -1;
} }
for (auto& data_config : config["data_list"]) { for (auto& data_config : config["data_list"]) {
std::string name = data_config.first.as<std::string>(); std::string name = data_config.first.as<std::string>();
auto data_ptr = std::make_shared<DatasetContainer>(); auto data_ptr = std::make_shared<DatasetContainer>();
if (data_ptr->initialize(data_config.second, context) != 0) { if (data_ptr->initialize(data_config.second, context) != 0) {
VLOG(0) << "dataset initialize failed, name:" << name; LOG(FATAL) << "dataset initialize failed, name:" << name;
return -1; return -1;
} }
_data_containers[name] = data_ptr; _data_containers[name] = data_ptr;
......
...@@ -6,10 +6,10 @@ ...@@ -6,10 +6,10 @@
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <yaml-cpp/yaml.h> #include <yaml-cpp/yaml.h>
#include "paddle/fluid/framework/io/shell.h"
#include "paddle/fluid/string/string_helper.h" #include "paddle/fluid/string/string_helper.h"
#include "paddle/fluid/train/custom_trainer/feed/trainer_context.h" #include "paddle/fluid/train/custom_trainer/feed/trainer_context.h"
#include "paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h" #include "paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h"
#include "paddle/fluid/train/custom_trainer/feed/io/file_system.h"
#include "paddle/fluid/train/custom_trainer/feed/dataset/dataset_container.h" #include "paddle/fluid/train/custom_trainer/feed/dataset/dataset_container.h"
namespace paddle { namespace paddle {
...@@ -27,8 +27,7 @@ int DatasetContainer::initialize( ...@@ -27,8 +27,7 @@ int DatasetContainer::initialize(
_dataset_list[i].reset(new DatasetInfo); _dataset_list[i].reset(new DatasetInfo);
} }
_data_root_paths = paddle::string::split_string( _data_root_paths = config["root_path"].as<std::vector<std::string>>();
config["root_path"].as<std::string>(), " ");
_data_split_interval = config["data_spit_interval"].as<int>(); _data_split_interval = config["data_spit_interval"].as<int>();
_data_path_formater = config["data_path_formater"].as<std::string>(); _data_path_formater = config["data_path_formater"].as<std::string>();
std::string data_reader_class = config["data_reader"].as<std::string>(); std::string data_reader_class = config["data_reader"].as<std::string>();
...@@ -66,7 +65,7 @@ void DatasetContainer::pre_detect_data(uint64_t epoch_id) { ...@@ -66,7 +65,7 @@ void DatasetContainer::pre_detect_data(uint64_t epoch_id) {
for (int i = 0; i < _data_root_paths.size() && status == 0; ++i) { for (int i = 0; i < _data_root_paths.size() && status == 0; ++i) {
for (int j = 0; j < data_num && status == 0; ++j) { for (int j = 0; j < data_num && status == 0; ++j) {
std::string path_suffix = format_timestamp(data_timestamp + j * _data_split_interval, _data_path_formater); std::string path_suffix = format_timestamp(data_timestamp + j * _data_split_interval, _data_path_formater);
std::string data_dir = _data_root_paths[i] + "/" + path_suffix; std::string data_dir = _trainer_context->file_system->path_join(_data_root_paths[i], path_suffix);
status = read_data_list(data_dir, data_path_list); status = read_data_list(data_dir, data_path_list);
} }
} }
......
...@@ -17,7 +17,7 @@ namespace { ...@@ -17,7 +17,7 @@ namespace {
int ReadBinaryFile(const std::string& filename, std::string* contents) { int ReadBinaryFile(const std::string& filename, std::string* contents) {
std::ifstream fin(filename, std::ios::in | std::ios::binary); std::ifstream fin(filename, std::ios::in | std::ios::binary);
if (!fin) { if (!fin) {
VLOG(2) << "Cannot open file " << filename; LOG(FATAL) << "Cannot open file " << filename;
return -1; return -1;
} }
fin.seekg(0, std::ios::end); fin.seekg(0, std::ios::end);
...@@ -31,7 +31,7 @@ int ReadBinaryFile(const std::string& filename, std::string* contents) { ...@@ -31,7 +31,7 @@ int ReadBinaryFile(const std::string& filename, std::string* contents) {
std::unique_ptr<paddle::framework::ProgramDesc> Load( std::unique_ptr<paddle::framework::ProgramDesc> Load(
paddle::framework::Executor* /*executor*/, const std::string& model_filename) { paddle::framework::Executor* /*executor*/, const std::string& model_filename) {
VLOG(3) << "loading model from " << model_filename; LOG(INFO) << "loading model from " << model_filename;
std::string program_desc_str; std::string program_desc_str;
if (ReadBinaryFile(model_filename, &program_desc_str) != 0) { if (ReadBinaryFile(model_filename, &program_desc_str) != 0) {
return nullptr; return nullptr;
......
...@@ -19,17 +19,18 @@ public: ...@@ -19,17 +19,18 @@ public:
for (auto& prefix_fs: config["file_systems"]) { for (auto& prefix_fs: config["file_systems"]) {
std::unique_ptr<FileSystem> fs(CREATE_CLASS(FileSystem, prefix_fs.second["class"].as<std::string>(""))); std::unique_ptr<FileSystem> fs(CREATE_CLASS(FileSystem, prefix_fs.second["class"].as<std::string>("")));
if (fs == nullptr) { if (fs == nullptr) {
VLOG(2) << "fail to create class: " << prefix_fs.second["class"].as<std::string>(""); LOG(FATAL) << "fail to create class: " << prefix_fs.second["class"].as<std::string>("");
return -1; return -1;
} }
if (fs->initialize(prefix_fs.second, context) != 0) { if (fs->initialize(prefix_fs.second, context) != 0) {
VLOG(2) << "fail to initialize class: " << prefix_fs.second["class"].as<std::string>(""); LOG(FATAL) << "fail to initialize class: " << prefix_fs.second["class"].as<std::string>("");
return 0; return -1;
} }
_file_system.emplace(prefix_fs.first.as<std::string>(""), std::move(fs)); _file_system.emplace(prefix_fs.first.as<std::string>(""), std::move(fs));
} }
} }
if (_file_system.find("default") == _file_system.end()) { if (_file_system.find("default") == _file_system.end()) {
LOG(WARNING) << "miss default file_system, use LocalFileSystem as default";
std::unique_ptr<FileSystem> fs(CREATE_CLASS(FileSystem, "LocalFileSystem")); std::unique_ptr<FileSystem> fs(CREATE_CLASS(FileSystem, "LocalFileSystem"));
if (fs == nullptr || fs->initialize(YAML::Load(""), context) != 0) { if (fs == nullptr || fs->initialize(YAML::Load(""), context) != 0) {
return -1; return -1;
...@@ -82,7 +83,6 @@ public: ...@@ -82,7 +83,6 @@ public:
return fs_it->second.get(); return fs_it->second.get();
} }
} }
VLOG(5) << "path: " << path << ", select default file system";
return _file_system["default"].get(); return _file_system["default"].get();
} }
......
...@@ -25,7 +25,7 @@ public: ...@@ -25,7 +25,7 @@ public:
} }
} }
if (_ugi.find("default") == _ugi.end()) { if (_ugi.find("default") == _ugi.end()) {
VLOG(2) << "fail to load default ugi"; LOG(FATAL) << "fail to load default ugi";
return -1; return -1;
} }
return 0; return 0;
...@@ -62,7 +62,7 @@ public: ...@@ -62,7 +62,7 @@ public:
int64_t file_size(const std::string& path) override { int64_t file_size(const std::string& path) override {
_err_no = -1; _err_no = -1;
VLOG(2) << "not support"; LOG(FATAL) << "not support";
return 0; return 0;
} }
......
...@@ -13,6 +13,7 @@ using namespace paddle::custom_trainer::feed; ...@@ -13,6 +13,7 @@ using namespace paddle::custom_trainer::feed;
DEFINE_string(feed_trainer_conf_path, "./conf/trainer.yaml", "path of trainer conf"); DEFINE_string(feed_trainer_conf_path, "./conf/trainer.yaml", "path of trainer conf");
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
google::InitGoogleLogging(argv[0]);
//gflags //gflags
google::ParseCommandLineFlags(&argc, &argv, true); google::ParseCommandLineFlags(&argc, &argv, true);
std::string gflag_conf = "./conf/gflags.conf"; std::string gflag_conf = "./conf/gflags.conf";
......
...@@ -76,7 +76,7 @@ int LearnerProcess::run() { ...@@ -76,7 +76,7 @@ int LearnerProcess::run() {
uint64_t epoch_id = epoch_accessor->current_epoch_id(); uint64_t epoch_id = epoch_accessor->current_epoch_id();
environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE, environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE,
"Resume traine with epoch_id:%d label:%s", epoch_id, _context_ptr->epoch_accessor->text(epoch_id).c_str()); "Resume trainer with epoch_id:%d label:%s", epoch_id, _context_ptr->epoch_accessor->text(epoch_id).c_str());
//判断是否先dump出base //判断是否先dump出base
wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceBase); wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceBase);
......
#!bash #!/bin/bash
export LD_LIBRARY_PATH=LD_LIBRARY_PATH:./so export LD_LIBRARY_PATH=LD_LIBRARY_PATH:./so
./bin/feed_trainer ./bin/feed_trainer "$@"
...@@ -40,7 +40,7 @@ void ReadBinaryFile(const std::string& filename, std::string* contents) { ...@@ -40,7 +40,7 @@ void ReadBinaryFile(const std::string& filename, std::string* contents) {
std::unique_ptr<paddle::framework::ProgramDesc> Load( std::unique_ptr<paddle::framework::ProgramDesc> Load(
paddle::framework::Executor* executor, const std::string& model_filename) { paddle::framework::Executor* executor, const std::string& model_filename) {
VLOG(3) << "loading model from " << model_filename; LOG(DEBUG) << "loading model from " << model_filename;
std::string program_desc_str; std::string program_desc_str;
ReadBinaryFile(model_filename, &program_desc_str); ReadBinaryFile(model_filename, &program_desc_str);
......
...@@ -133,9 +133,9 @@ TEST_F(CreateProgramsTest, example_network) { ...@@ -133,9 +133,9 @@ TEST_F(CreateProgramsTest, example_network) {
auto output_var = executor->var<::paddle::framework::LoDTensor>(output_name); auto output_var = executor->var<::paddle::framework::LoDTensor>(output_name);
auto output = output_var.data<float>()[0]; auto output = output_var.data<float>()[0];
VLOG(3) << "loss: " << loss << std::endl; LOG(INFO) << "loss: " << loss << std::endl;
VLOG(3) << "label: " << label_data[0] << std::endl; LOG(INFO) << "label: " << label_data[0] << std::endl;
VLOG(3) << "output: " << output << std::endl; LOG(INFO) << "output: " << output << std::endl;
ASSERT_NEAR(loss, pow(output - label_data[0], 2), 1e-8); ASSERT_NEAR(loss, pow(output - label_data[0], 2), 1e-8);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册