提交 38175506 编写于 作者: X xiexionghang

fix shuffler bug

......@@ -161,7 +161,12 @@ public:
protected:
virtual void print_log(EnvironmentRole role, EnvironmentLogType type,
EnvironmentLogLevel level, const std::string& log_str) {
if (type == EnvironmentLogType::MASTER_LOG && !is_master_node(role)) {
if (type == EnvironmentLogType::MASTER_LOG) {
if (is_master_node(role)) {
fprintf(stdout, log_str.c_str());
fprintf(stdout, "\n");
fflush(stdout);
}
return;
}
VLOG(static_cast<int>(level)) << log_str;
......
......@@ -75,8 +75,8 @@ public:
// 环境定制化log
template<class... ARGS>
void log(EnvironmentRole role, EnvironmentLogType type,
EnvironmentLogLevel level, const char* fmt, ARGS && ... args) {
print_log(role, type, level, paddle::string::format_string(fmt, args...));
EnvironmentLogLevel level, ARGS && ... args) {
print_log(role, type, level, paddle::string::format_string(args...));
}
// 多线程可调用接口 End
......@@ -106,14 +106,14 @@ protected:
};
REGIST_REGISTERER(RuntimeEnvironment);
#define ENVLOG_WORKER_ALL_NOTICE \
environment->log(EnvironmentRole::WORKER, EnvironmentLogType::ALL_LOG, EnvironmentLogType::NOTICE,
#define ENVLOG_WORKER_MASTER_NOTICE \
environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogType::NOTICE,
#define ENVLOG_WORKER_ALL_ERROR \
environment->log(EnvironmentRole::WORKER, EnvironmentLogType::ALL_LOG, EnvironmentLogType::ERROR,
#define ENVLOG_WORKER_MASTER_ERROR \
environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogType::ERROR,
#define ENVLOG_WORKER_ALL_NOTICE(...) \
environment->log(EnvironmentRole::WORKER, EnvironmentLogType::ALL_LOG, EnvironmentLogLevel::NOTICE, __VA_ARGS__);
#define ENVLOG_WORKER_MASTER_NOTICE(...) \
environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE, __VA_ARGS__);
#define ENVLOG_WORKER_ALL_ERROR(...) \
environment->log(EnvironmentRole::WORKER, EnvironmentLogType::ALL_LOG, EnvironmentLogLevel::ERROR, __VA_ARGS__);
#define ENVLOG_WORKER_MASTER_ERROR(...) \
environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::ERROR, __VA_ARGS__);
std::string format_timestamp(time_t time, const char* format);
inline std::string format_timestamp(time_t time, const std::string& format) {
......
......@@ -137,7 +137,7 @@ paddle::framework::Channel<DataItem> MultiThreadExecutor::run(
paddle::framework::Channel<DataItem> input, const DataParser* parser) {
uint64_t epoch_id = _trainer_context->epoch_accessor->current_epoch_id();
auto* environment = _trainer_context->environment.get();
// 输入流
PipelineOptions input_pipe_option;
input_pipe_option.need_hold_input_data = true;
......@@ -243,8 +243,8 @@ paddle::framework::Channel<DataItem> MultiThreadExecutor::run(
for (auto& monitor : _monitors) {
if (monitor->need_compute_result(epoch_id)) {
monitor->compute_result();
VLOG(2) << "[Monitor]" << _train_exe_name << ", monitor:" << monitor->get_name()
<< ", result:" << monitor->format_result();
ENVLOG_WORKER_MASTER_NOTICE("[Monitor]%s, monitor:%s, result:%s",
_train_exe_name.c_str(), monitor->get_name().c_str(), monitor->format_result().c_str());
_trainer_context->monitor_ssm << _train_exe_name << ":" <<
monitor->get_name() << ":" << monitor->format_result() << ",";
monitor->reset();
......
......@@ -46,7 +46,7 @@ private:
uint64_t _total_time_ms;
uint64_t _total_cnt;
uint64_t _avg_time_ms;
uint32_t _compute_interval;
uint32_t _compute_interval;
};
} // namespace feed
......
......@@ -5,6 +5,7 @@
#include "paddle/fluid/train/custom_trainer/feed/trainer_context.h"
#include "paddle/fluid/train/custom_trainer/feed/executor/executor.h"
#include "paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h"
#include "paddle/fluid/train/custom_trainer/feed/common/runtime_environment.h"
namespace paddle {
namespace custom_trainer {
......
......@@ -155,10 +155,11 @@ int LearnerProcess::load_model(uint64_t epoch_id) {
if (!environment->is_master_node(EnvironmentRole::WORKER)) {
return 0;
}
VLOG(2) << "Start Load Model";
auto* fs = _context_ptr->file_system.get();
std::set<uint32_t> loaded_table_set;
auto model_dir = _context_ptr->epoch_accessor->checkpoint_path();
paddle::platform::Timer timer;
timer.Start();
for (auto& executor : _executors) {
const auto& table_accessors = executor->table_accessors();
for (auto& itr : table_accessors) {
......@@ -172,6 +173,7 @@ int LearnerProcess::load_model(uint64_t epoch_id) {
auto scope = std::move(executor->fetch_scope());
CHECK(itr.second[0]->create(scope.get()) == 0);
} else {
ENVLOG_WORKER_MASTER_NOTICE("Loading model %s", model_dir.c_str());
auto status = _context_ptr->ps_client()->load(itr.first,
model_dir, std::to_string((int)ModelSaveWay::ModelSaveTrainCheckpoint));
CHECK(status.get() == 0) << "table load failed, id:" << itr.first;
......@@ -179,7 +181,8 @@ int LearnerProcess::load_model(uint64_t epoch_id) {
loaded_table_set.insert(itr.first);
}
}
VLOG(2) << "Finish Load Model";
timer.Pause();
ENVLOG_WORKER_MASTER_NOTICE("Finished loading model, cost:%f", timer.ElapsedSec());
return 0;
}
......@@ -189,9 +192,7 @@ int LearnerProcess::run() {
auto* epoch_accessor = _context_ptr->epoch_accessor.get();
uint64_t epoch_id = epoch_accessor->current_epoch_id();
environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE,
"Resume train with epoch_id:%d %s", epoch_id, _context_ptr->epoch_accessor->text(epoch_id).c_str());
ENVLOG_WORKER_MASTER_NOTICE("Resume train with epoch_id:%d %s", epoch_id, _context_ptr->epoch_accessor->text(epoch_id).c_str());
//尝试加载模型 or 初始化
CHECK(load_model(epoch_id) == 0);
environment->barrier(EnvironmentRole::WORKER);
......@@ -208,19 +209,16 @@ int LearnerProcess::run() {
std::string epoch_log_title = paddle::string::format_string(
"train epoch_id:%d label:%s", epoch_id, epoch_accessor->text(epoch_id).c_str());
std::string data_path = paddle::string::to_string<std::string>(dataset->epoch_data_path(epoch_id));
ENVLOG_WORKER_MASTER_NOTICE(" ==== begin %s ====", epoch_accessor->text(epoch_id).c_str());
//Step1. 等待样本ready
{
environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE,
"%s, wait data ready:%s", epoch_log_title.c_str(), data_path.c_str());
ENVLOG_WORKER_MASTER_NOTICE(" %s, wait data ready:%s", epoch_log_title.c_str(), data_path.c_str());
while (dataset->epoch_data_status(epoch_id) != DatasetStatus::Ready) {
sleep(30);
dataset->pre_detect_data(epoch_id);
environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE,
"data not ready, wait 30s");
ENVLOG_WORKER_MASTER_NOTICE(" epoch_id:%d data not ready, wait 30s", epoch_id);
}
environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE,
"Start %s, data is ready", epoch_log_title.c_str());
ENVLOG_WORKER_MASTER_NOTICE(" Start %s, data is ready", epoch_log_title.c_str());
environment->barrier(EnvironmentRole::WORKER);
}
......@@ -232,7 +230,7 @@ int LearnerProcess::run() {
environment->barrier(EnvironmentRole::WORKER);
paddle::platform::Timer timer;
timer.Start();
VLOG(2) << "Start executor:" << executor->train_exe_name();
ENVLOG_WORKER_MASTER_NOTICE("Start executor:%s", executor->train_exe_name().c_str());
auto data_name = executor->train_data_name();
paddle::framework::Channel<DataItem> input_channel;
if (backup_input_map.count(data_name)) {
......@@ -242,7 +240,7 @@ int LearnerProcess::run() {
}
input_channel = executor->run(input_channel, dataset->data_parser(data_name));
timer.Pause();
VLOG(2) << "End executor:" << executor->train_exe_name() << ", cost:" << timer.ElapsedSec();
ENVLOG_WORKER_MASTER_NOTICE("End executor:%s, cost:%f", executor->train_exe_name().c_str(), timer.ElapsedSec());
// 等待异步梯度完成
_context_ptr->ps_client()->flush();
......@@ -273,21 +271,22 @@ int LearnerProcess::run() {
environment->is_master_node(EnvironmentRole::WORKER)) {
paddle::platform::Timer timer;
timer.Start();
VLOG(2) << "Start shrink table";
ENVLOG_WORKER_MASTER_NOTICE("Start shrink table");
for (auto& executor : _executors) {
const auto& table_accessors = executor->table_accessors();
for (auto& itr : table_accessors) {
CHECK(itr.second[0]->shrink() == 0);
}
}
VLOG(2) << "End shrink table, cost:" << timer.ElapsedSec();
timer.Pause();
ENVLOG_WORKER_MASTER_NOTICE("End shrink table, cost:%f", timer.ElapsedSec());
}
environment->barrier(EnvironmentRole::WORKER);
epoch_accessor->epoch_done(epoch_id);
environment->barrier(EnvironmentRole::WORKER);
}
ENVLOG_WORKER_MASTER_NOTICE(" ==== end %s ====", epoch_accessor->text(epoch_id).c_str());
//Step4. Output Monitor && RunStatus
//TODO
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册