提交 38175506 编写于 作者: X xiexionghang

fix shuffler bug

...@@ -161,7 +161,12 @@ public: ...@@ -161,7 +161,12 @@ public:
protected: protected:
virtual void print_log(EnvironmentRole role, EnvironmentLogType type, virtual void print_log(EnvironmentRole role, EnvironmentLogType type,
EnvironmentLogLevel level, const std::string& log_str) { EnvironmentLogLevel level, const std::string& log_str) {
if (type == EnvironmentLogType::MASTER_LOG && !is_master_node(role)) { if (type == EnvironmentLogType::MASTER_LOG) {
if (is_master_node(role)) {
fprintf(stdout, log_str.c_str());
fprintf(stdout, "\n");
fflush(stdout);
}
return; return;
} }
VLOG(static_cast<int>(level)) << log_str; VLOG(static_cast<int>(level)) << log_str;
......
...@@ -75,8 +75,8 @@ public: ...@@ -75,8 +75,8 @@ public:
// 环境定制化log // 环境定制化log
template<class... ARGS> template<class... ARGS>
void log(EnvironmentRole role, EnvironmentLogType type, void log(EnvironmentRole role, EnvironmentLogType type,
EnvironmentLogLevel level, const char* fmt, ARGS && ... args) { EnvironmentLogLevel level, ARGS && ... args) {
print_log(role, type, level, paddle::string::format_string(fmt, args...)); print_log(role, type, level, paddle::string::format_string(args...));
} }
// 多线程可调用接口 End // 多线程可调用接口 End
...@@ -106,14 +106,14 @@ protected: ...@@ -106,14 +106,14 @@ protected:
}; };
REGIST_REGISTERER(RuntimeEnvironment); REGIST_REGISTERER(RuntimeEnvironment);
#define ENVLOG_WORKER_ALL_NOTICE \ #define ENVLOG_WORKER_ALL_NOTICE(...) \
environment->log(EnvironmentRole::WORKER, EnvironmentLogType::ALL_LOG, EnvironmentLogType::NOTICE, environment->log(EnvironmentRole::WORKER, EnvironmentLogType::ALL_LOG, EnvironmentLogLevel::NOTICE, __VA_ARGS__);
#define ENVLOG_WORKER_MASTER_NOTICE \ #define ENVLOG_WORKER_MASTER_NOTICE(...) \
environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogType::NOTICE, environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE, __VA_ARGS__);
#define ENVLOG_WORKER_ALL_ERROR \ #define ENVLOG_WORKER_ALL_ERROR(...) \
environment->log(EnvironmentRole::WORKER, EnvironmentLogType::ALL_LOG, EnvironmentLogType::ERROR, environment->log(EnvironmentRole::WORKER, EnvironmentLogType::ALL_LOG, EnvironmentLogLevel::ERROR, __VA_ARGS__);
#define ENVLOG_WORKER_MASTER_ERROR \ #define ENVLOG_WORKER_MASTER_ERROR(...) \
environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogType::ERROR, environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::ERROR, __VA_ARGS__);
std::string format_timestamp(time_t time, const char* format); std::string format_timestamp(time_t time, const char* format);
inline std::string format_timestamp(time_t time, const std::string& format) { inline std::string format_timestamp(time_t time, const std::string& format) {
......
...@@ -137,7 +137,7 @@ paddle::framework::Channel<DataItem> MultiThreadExecutor::run( ...@@ -137,7 +137,7 @@ paddle::framework::Channel<DataItem> MultiThreadExecutor::run(
paddle::framework::Channel<DataItem> input, const DataParser* parser) { paddle::framework::Channel<DataItem> input, const DataParser* parser) {
uint64_t epoch_id = _trainer_context->epoch_accessor->current_epoch_id(); uint64_t epoch_id = _trainer_context->epoch_accessor->current_epoch_id();
auto* environment = _trainer_context->environment.get();
// 输入流 // 输入流
PipelineOptions input_pipe_option; PipelineOptions input_pipe_option;
input_pipe_option.need_hold_input_data = true; input_pipe_option.need_hold_input_data = true;
...@@ -243,8 +243,8 @@ paddle::framework::Channel<DataItem> MultiThreadExecutor::run( ...@@ -243,8 +243,8 @@ paddle::framework::Channel<DataItem> MultiThreadExecutor::run(
for (auto& monitor : _monitors) { for (auto& monitor : _monitors) {
if (monitor->need_compute_result(epoch_id)) { if (monitor->need_compute_result(epoch_id)) {
monitor->compute_result(); monitor->compute_result();
VLOG(2) << "[Monitor]" << _train_exe_name << ", monitor:" << monitor->get_name() ENVLOG_WORKER_MASTER_NOTICE("[Monitor]%s, monitor:%s, result:%s",
<< ", result:" << monitor->format_result(); _train_exe_name.c_str(), monitor->get_name().c_str(), monitor->format_result().c_str());
_trainer_context->monitor_ssm << _train_exe_name << ":" << _trainer_context->monitor_ssm << _train_exe_name << ":" <<
monitor->get_name() << ":" << monitor->format_result() << ","; monitor->get_name() << ":" << monitor->format_result() << ",";
monitor->reset(); monitor->reset();
......
...@@ -46,7 +46,7 @@ private: ...@@ -46,7 +46,7 @@ private:
uint64_t _total_time_ms; uint64_t _total_time_ms;
uint64_t _total_cnt; uint64_t _total_cnt;
uint64_t _avg_time_ms; uint64_t _avg_time_ms;
uint32_t _compute_interval; uint32_t _compute_interval;
}; };
} // namespace feed } // namespace feed
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include "paddle/fluid/train/custom_trainer/feed/trainer_context.h" #include "paddle/fluid/train/custom_trainer/feed/trainer_context.h"
#include "paddle/fluid/train/custom_trainer/feed/executor/executor.h" #include "paddle/fluid/train/custom_trainer/feed/executor/executor.h"
#include "paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h" #include "paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h"
#include "paddle/fluid/train/custom_trainer/feed/common/runtime_environment.h"
namespace paddle { namespace paddle {
namespace custom_trainer { namespace custom_trainer {
......
...@@ -155,10 +155,11 @@ int LearnerProcess::load_model(uint64_t epoch_id) { ...@@ -155,10 +155,11 @@ int LearnerProcess::load_model(uint64_t epoch_id) {
if (!environment->is_master_node(EnvironmentRole::WORKER)) { if (!environment->is_master_node(EnvironmentRole::WORKER)) {
return 0; return 0;
} }
VLOG(2) << "Start Load Model";
auto* fs = _context_ptr->file_system.get(); auto* fs = _context_ptr->file_system.get();
std::set<uint32_t> loaded_table_set; std::set<uint32_t> loaded_table_set;
auto model_dir = _context_ptr->epoch_accessor->checkpoint_path(); auto model_dir = _context_ptr->epoch_accessor->checkpoint_path();
paddle::platform::Timer timer;
timer.Start();
for (auto& executor : _executors) { for (auto& executor : _executors) {
const auto& table_accessors = executor->table_accessors(); const auto& table_accessors = executor->table_accessors();
for (auto& itr : table_accessors) { for (auto& itr : table_accessors) {
...@@ -172,6 +173,7 @@ int LearnerProcess::load_model(uint64_t epoch_id) { ...@@ -172,6 +173,7 @@ int LearnerProcess::load_model(uint64_t epoch_id) {
auto scope = std::move(executor->fetch_scope()); auto scope = std::move(executor->fetch_scope());
CHECK(itr.second[0]->create(scope.get()) == 0); CHECK(itr.second[0]->create(scope.get()) == 0);
} else { } else {
ENVLOG_WORKER_MASTER_NOTICE("Loading model %s", model_dir.c_str());
auto status = _context_ptr->ps_client()->load(itr.first, auto status = _context_ptr->ps_client()->load(itr.first,
model_dir, std::to_string((int)ModelSaveWay::ModelSaveTrainCheckpoint)); model_dir, std::to_string((int)ModelSaveWay::ModelSaveTrainCheckpoint));
CHECK(status.get() == 0) << "table load failed, id:" << itr.first; CHECK(status.get() == 0) << "table load failed, id:" << itr.first;
...@@ -179,7 +181,8 @@ int LearnerProcess::load_model(uint64_t epoch_id) { ...@@ -179,7 +181,8 @@ int LearnerProcess::load_model(uint64_t epoch_id) {
loaded_table_set.insert(itr.first); loaded_table_set.insert(itr.first);
} }
} }
VLOG(2) << "Finish Load Model"; timer.Pause();
ENVLOG_WORKER_MASTER_NOTICE("Finished loading model, cost:%f", timer.ElapsedSec());
return 0; return 0;
} }
...@@ -189,9 +192,7 @@ int LearnerProcess::run() { ...@@ -189,9 +192,7 @@ int LearnerProcess::run() {
auto* epoch_accessor = _context_ptr->epoch_accessor.get(); auto* epoch_accessor = _context_ptr->epoch_accessor.get();
uint64_t epoch_id = epoch_accessor->current_epoch_id(); uint64_t epoch_id = epoch_accessor->current_epoch_id();
environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE, ENVLOG_WORKER_MASTER_NOTICE("Resume train with epoch_id:%d %s", epoch_id, _context_ptr->epoch_accessor->text(epoch_id).c_str());
"Resume train with epoch_id:%d %s", epoch_id, _context_ptr->epoch_accessor->text(epoch_id).c_str());
//尝试加载模型 or 初始化 //尝试加载模型 or 初始化
CHECK(load_model(epoch_id) == 0); CHECK(load_model(epoch_id) == 0);
environment->barrier(EnvironmentRole::WORKER); environment->barrier(EnvironmentRole::WORKER);
...@@ -208,19 +209,16 @@ int LearnerProcess::run() { ...@@ -208,19 +209,16 @@ int LearnerProcess::run() {
std::string epoch_log_title = paddle::string::format_string( std::string epoch_log_title = paddle::string::format_string(
"train epoch_id:%d label:%s", epoch_id, epoch_accessor->text(epoch_id).c_str()); "train epoch_id:%d label:%s", epoch_id, epoch_accessor->text(epoch_id).c_str());
std::string data_path = paddle::string::to_string<std::string>(dataset->epoch_data_path(epoch_id)); std::string data_path = paddle::string::to_string<std::string>(dataset->epoch_data_path(epoch_id));
ENVLOG_WORKER_MASTER_NOTICE(" ==== begin %s ====", epoch_accessor->text(epoch_id).c_str());
//Step1. 等待样本ready //Step1. 等待样本ready
{ {
environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE, ENVLOG_WORKER_MASTER_NOTICE(" %s, wait data ready:%s", epoch_log_title.c_str(), data_path.c_str());
"%s, wait data ready:%s", epoch_log_title.c_str(), data_path.c_str());
while (dataset->epoch_data_status(epoch_id) != DatasetStatus::Ready) { while (dataset->epoch_data_status(epoch_id) != DatasetStatus::Ready) {
sleep(30); sleep(30);
dataset->pre_detect_data(epoch_id); dataset->pre_detect_data(epoch_id);
environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE, ENVLOG_WORKER_MASTER_NOTICE(" epoch_id:%d data not ready, wait 30s", epoch_id);
"data not ready, wait 30s");
} }
environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE, ENVLOG_WORKER_MASTER_NOTICE(" Start %s, data is ready", epoch_log_title.c_str());
"Start %s, data is ready", epoch_log_title.c_str());
environment->barrier(EnvironmentRole::WORKER); environment->barrier(EnvironmentRole::WORKER);
} }
...@@ -232,7 +230,7 @@ int LearnerProcess::run() { ...@@ -232,7 +230,7 @@ int LearnerProcess::run() {
environment->barrier(EnvironmentRole::WORKER); environment->barrier(EnvironmentRole::WORKER);
paddle::platform::Timer timer; paddle::platform::Timer timer;
timer.Start(); timer.Start();
VLOG(2) << "Start executor:" << executor->train_exe_name(); ENVLOG_WORKER_MASTER_NOTICE("Start executor:%s", executor->train_exe_name().c_str());
auto data_name = executor->train_data_name(); auto data_name = executor->train_data_name();
paddle::framework::Channel<DataItem> input_channel; paddle::framework::Channel<DataItem> input_channel;
if (backup_input_map.count(data_name)) { if (backup_input_map.count(data_name)) {
...@@ -242,7 +240,7 @@ int LearnerProcess::run() { ...@@ -242,7 +240,7 @@ int LearnerProcess::run() {
} }
input_channel = executor->run(input_channel, dataset->data_parser(data_name)); input_channel = executor->run(input_channel, dataset->data_parser(data_name));
timer.Pause(); timer.Pause();
VLOG(2) << "End executor:" << executor->train_exe_name() << ", cost:" << timer.ElapsedSec(); ENVLOG_WORKER_MASTER_NOTICE("End executor:%s, cost:%f", executor->train_exe_name().c_str(), timer.ElapsedSec());
// 等待异步梯度完成 // 等待异步梯度完成
_context_ptr->ps_client()->flush(); _context_ptr->ps_client()->flush();
...@@ -273,21 +271,22 @@ int LearnerProcess::run() { ...@@ -273,21 +271,22 @@ int LearnerProcess::run() {
environment->is_master_node(EnvironmentRole::WORKER)) { environment->is_master_node(EnvironmentRole::WORKER)) {
paddle::platform::Timer timer; paddle::platform::Timer timer;
timer.Start(); timer.Start();
VLOG(2) << "Start shrink table"; ENVLOG_WORKER_MASTER_NOTICE("Start shrink table");
for (auto& executor : _executors) { for (auto& executor : _executors) {
const auto& table_accessors = executor->table_accessors(); const auto& table_accessors = executor->table_accessors();
for (auto& itr : table_accessors) { for (auto& itr : table_accessors) {
CHECK(itr.second[0]->shrink() == 0); CHECK(itr.second[0]->shrink() == 0);
} }
} }
VLOG(2) << "End shrink table, cost:" << timer.ElapsedSec(); timer.Pause();
ENVLOG_WORKER_MASTER_NOTICE("End shrink table, cost:%f", timer.ElapsedSec());
} }
environment->barrier(EnvironmentRole::WORKER); environment->barrier(EnvironmentRole::WORKER);
epoch_accessor->epoch_done(epoch_id); epoch_accessor->epoch_done(epoch_id);
environment->barrier(EnvironmentRole::WORKER); environment->barrier(EnvironmentRole::WORKER);
} }
ENVLOG_WORKER_MASTER_NOTICE(" ==== end %s ====", epoch_accessor->text(epoch_id).c_str());
//Step4. Output Monitor && RunStatus //Step4. Output Monitor && RunStatus
//TODO //TODO
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册