提交 b23282f3 编写于 作者: X xiexionghang

fix shuffler bug

上级 7654080c
...@@ -77,6 +77,7 @@ int LearnerProcess::update_cache_model(uint64_t epoch_id, ModelSaveWay way) { ...@@ -77,6 +77,7 @@ int LearnerProcess::update_cache_model(uint64_t epoch_id, ModelSaveWay way) {
return 0; return 0;
} }
int LearnerProcess::wait_save_model(uint64_t epoch_id, ModelSaveWay way, bool is_force_dump) { int LearnerProcess::wait_save_model(uint64_t epoch_id, ModelSaveWay way, bool is_force_dump) {
ContextStatusGurad status_guard(_context_ptr, TrainerStatus::Saving);
auto fs = _context_ptr->file_system; auto fs = _context_ptr->file_system;
auto* ps_client = _context_ptr->pslib->ps_client(); auto* ps_client = _context_ptr->pslib->ps_client();
auto* environment = _context_ptr->environment.get(); auto* environment = _context_ptr->environment.get();
...@@ -154,6 +155,7 @@ int LearnerProcess::load_model(uint64_t epoch_id) { ...@@ -154,6 +155,7 @@ int LearnerProcess::load_model(uint64_t epoch_id) {
if (!environment->is_master_node(EnvironmentRole::WORKER)) { if (!environment->is_master_node(EnvironmentRole::WORKER)) {
return 0; return 0;
} }
VLOG(2) << "Start Load Model";
auto* fs = _context_ptr->file_system.get(); auto* fs = _context_ptr->file_system.get();
std::set<uint32_t> loaded_table_set; std::set<uint32_t> loaded_table_set;
auto model_dir = _context_ptr->epoch_accessor->checkpoint_path(); auto model_dir = _context_ptr->epoch_accessor->checkpoint_path();
...@@ -177,6 +179,7 @@ int LearnerProcess::load_model(uint64_t epoch_id) { ...@@ -177,6 +179,7 @@ int LearnerProcess::load_model(uint64_t epoch_id) {
loaded_table_set.insert(itr.first); loaded_table_set.insert(itr.first);
} }
} }
VLOG(2) << "Finish Load Model";
return 0; return 0;
} }
...@@ -223,6 +226,7 @@ int LearnerProcess::run() { ...@@ -223,6 +226,7 @@ int LearnerProcess::run() {
//Step2. 运行训练网络 //Step2. 运行训练网络
{ {
ContextStatusGurad status_guard(_context_ptr, TrainerStatus::Training);
std::map<std::string, paddle::framework::Channel<DataItem>> backup_input_map; std::map<std::string, paddle::framework::Channel<DataItem>> backup_input_map;
for (auto& executor : _executors) { for (auto& executor : _executors) {
environment->barrier(EnvironmentRole::WORKER); environment->barrier(EnvironmentRole::WORKER);
......
...@@ -39,8 +39,8 @@ public: ...@@ -39,8 +39,8 @@ public:
virtual int initialize(YAML::Node config, virtual int initialize(YAML::Node config,
std::shared_ptr<TrainerContext> context_ptr) { std::shared_ptr<TrainerContext> context_ptr) {
Shuffler::initialize(config, context_ptr); Shuffler::initialize(config, context_ptr);
_max_concurrent_num = config["max_concurrent_num"].as<int>(4); // 最大并发发送数 _max_concurrent_num = config["max_concurrent_num"].as<int>(6); // 最大并发发送数
_max_package_size = config["max_package_size"].as<int>(1024); // 最大包个数,一次发送package个数据 _max_package_size = config["max_package_size"].as<int>(256); // 最大包个数,一次发送package个数据
_shuffle_data_msg_type = config["shuffle_data_msg_type"].as<int>(3); // c2c msg type _shuffle_data_msg_type = config["shuffle_data_msg_type"].as<int>(3); // c2c msg type
_finish_msg_type = config["finish_msg_type"].as<int>(4); // c2c msg type _finish_msg_type = config["finish_msg_type"].as<int>(4); // c2c msg type
...@@ -62,6 +62,8 @@ public: ...@@ -62,6 +62,8 @@ public:
data_channel.swap(input_channel); data_channel.swap(input_channel);
set_channel(data_channel); set_channel(data_channel);
_item_send_count = 0;
_item_receive_count = 0;
auto* environment = _trainer_context->environment.get(); auto* environment = _trainer_context->environment.get();
auto worker_num = environment->node_num(EnvironmentRole::WORKER); auto worker_num = environment->node_num(EnvironmentRole::WORKER);
std::vector<std::vector<std::future<int>>> waits(concurrent_num); std::vector<std::vector<std::future<int>>> waits(concurrent_num);
...@@ -86,8 +88,9 @@ public: ...@@ -86,8 +88,9 @@ public:
status = 1; status = 1;
break; break;
} }
_item_send_count += read_size;
for (int i = 0; i < worker_num; ++i) { for (int i = 0; i < worker_num; ++i) {
send_buffer_worker.clear(); send_buffer_worker[i].clear();
} }
for (int i = 0; i < read_size; ++i) { for (int i = 0; i < read_size; ++i) {
auto worker_idx = _shuffle_key_func(send_buffer[i].id) % worker_num; auto worker_idx = _shuffle_key_func(send_buffer[i].id) % worker_num;
...@@ -119,19 +122,19 @@ public: ...@@ -119,19 +122,19 @@ public:
} }
} }
} }
VLOG(5) << "start send finish, worker_num: " << worker_num; VLOG(2) << "start send finish, worker_num: " << worker_num;
waits[0].clear(); waits[0].clear();
for (int i = 0; i < worker_num; ++i) { for (int i = 0; i < worker_num; ++i) {
waits[0].push_back(send_finish(i)); waits[0].push_back(send_finish(i));
} }
VLOG(5) << "wait all finish"; VLOG(2) << "wait all finish";
for (int i = 0; i < worker_num; ++i) { for (int i = 0; i < worker_num; ++i) {
if (waits[0][i].get() != 0) { if (waits[0][i].get() != 0) {
LOG(WARNING) << "fail to send finish " << i; LOG(WARNING) << "fail to send finish " << i;
status = -1; status = -1;
} }
} }
VLOG(5) << "finish shuffler, status: " << status; VLOG(2) << "finish shuffler_send_channel, total_send:" << _item_send_count;
return status < 0 ? status : 0; return status < 0 ? status : 0;
} }
...@@ -174,6 +177,7 @@ private: ...@@ -174,6 +177,7 @@ private:
// 同步所有worker,在所有写入完成后,c2c_msg返回前,重置channel // 同步所有worker,在所有写入完成后,c2c_msg返回前,重置channel
if (wait_num == 0) { if (wait_num == 0) {
reset_channel(); reset_channel();
VLOG(2) << "finish shuffle_receive_channel, receive_count: " << _item_receive_count;
_wait_num_mutex.unlock(); _wait_num_mutex.unlock();
} else { } else {
std::lock_guard<bthread::Mutex> lock(_wait_num_mutex); std::lock_guard<bthread::Mutex> lock(_wait_num_mutex);
...@@ -182,7 +186,7 @@ private: ...@@ -182,7 +186,7 @@ private:
} }
int32_t write_to_channel(std::vector<DataItem>&& items) { int32_t write_to_channel(std::vector<DataItem>&& items) {
size_t items_size = items.size(); size_t items_size = items.size();
VLOG(5) << "write_to_channel, items_size: " << items_size; _item_receive_count += items_size;
return _out_channel->Write(std::move(items)) == items_size ? 0 : -1; return _out_channel->Write(std::move(items)) == items_size ? 0 : -1;
} }
...@@ -207,6 +211,8 @@ private: ...@@ -207,6 +211,8 @@ private:
} }
std::future<int32_t> send_shuffle_data(int to_client_id, std::vector<DataItem>& items) { std::future<int32_t> send_shuffle_data(int to_client_id, std::vector<DataItem>& items) {
// server端也开启了client, worker节点为偶数编号
to_client_id = 2 * to_client_id;
VLOG(5) << "send_shuffle_data, to_client_id: " << to_client_id << ", items_size: " << items.size(); VLOG(5) << "send_shuffle_data, to_client_id: " << to_client_id << ", items_size: " << items.size();
paddle::framework::BinaryArchive ar; paddle::framework::BinaryArchive ar;
ar << items; ar << items;
...@@ -215,6 +221,8 @@ private: ...@@ -215,6 +221,8 @@ private:
} }
std::future<int32_t> send_finish(int to_client_id) { std::future<int32_t> send_finish(int to_client_id) {
// server端也开启了client, worker节点为偶数编号
to_client_id = 2 * to_client_id;
VLOG(5) << "send_finish, to_client_id: " << to_client_id; VLOG(5) << "send_finish, to_client_id: " << to_client_id;
static const std::string empty_str; static const std::string empty_str;
return _trainer_context->pslib->ps_client()->send_client2client_msg(_finish_msg_type, to_client_id, empty_str); return _trainer_context->pslib->ps_client()->send_client2client_msg(_finish_msg_type, to_client_id, empty_str);
...@@ -230,6 +238,8 @@ private: ...@@ -230,6 +238,8 @@ private:
bthread::Mutex _wait_num_mutex; bthread::Mutex _wait_num_mutex;
std::atomic<int> _wait_num; std::atomic<int> _wait_num;
std::atomic<uint32_t> _item_send_count;
std::atomic<uint32_t> _item_receive_count;
}; };
REGIST_CLASS(Shuffler, GlobalShuffler); REGIST_CLASS(Shuffler, GlobalShuffler);
......
...@@ -95,20 +95,23 @@ private: ...@@ -95,20 +95,23 @@ private:
class TrainerContext { class TrainerContext {
public: public:
// Pre-size the per-status flag vector so is_status/set_status can index
// it directly. 2 slots — presumably one per TrainerStatus value
// (Training and Saving, judging by the call sites); TODO confirm this
// matches the TrainerStatus enum's cardinality. All flags start off (0).
TrainerContext() {
    trainer_status.resize(2, 0);
}
inline paddle::ps::PSClient* ps_client() { inline paddle::ps::PSClient* ps_client() {
return pslib->ps_client(); return pslib->ps_client();
} }
inline bool is_status(TrainerStatus status) { inline bool is_status(TrainerStatus status) {
auto bit_idx = static_cast<uint32_t>(status); auto status_idx = static_cast<uint32_t>(status);
return ((trainer_status >> bit_idx) & 1) > 0; return trainer_status[status_idx] > 0;
} }
// 非线程安全, 其实TrainerContext所有成员的线程安全性 取决于 成员本身的线程安全性 // 非线程安全, 其实TrainerContext所有成员的线程安全性 取决于 成员本身的线程安全性
inline void set_status(TrainerStatus status, bool on) { inline void set_status(TrainerStatus status, bool on) {
auto bit_idx = static_cast<uint32_t>(status); auto status_idx = static_cast<uint32_t>(status);
trainer_status = trainer_status & (1L << bit_idx); trainer_status[status_idx] = on ? 1 : 0;
} }
uint32_t trainer_status; // trainer当前,由于可同时处于多种状态,这里分bit存储状态 std::vector<uint32_t> trainer_status;
YAML::Node trainer_config; YAML::Node trainer_config;
paddle::platform::CPUPlace cpu_place; paddle::platform::CPUPlace cpu_place;
...@@ -122,6 +125,20 @@ public: ...@@ -122,6 +125,20 @@ public:
std::shared_ptr<SignCacheDict> cache_dict; //大模型cache词典 std::shared_ptr<SignCacheDict> cache_dict; //大模型cache词典
}; };
// RAII scope guard that marks a TrainerStatus as active on the given
// TrainerContext for the lifetime of the guard object: the status flag
// is set on construction and cleared again in the destructor.
// NOTE(review): the "Gurad" spelling is kept because existing callers
// (e.g. wait_save_model, LearnerProcess::run) already use this name.
// set_status is documented as not thread-safe, so a guard instance must
// not race with other status writers — TODO confirm single-writer usage.
class ContextStatusGurad {
public:
    // Members are initialized in declaration order (_status, then
    // _context); the initializer list is written in that same order to
    // avoid -Wreorder and to make the actual order obvious.
    ContextStatusGurad(TrainerContext* context, TrainerStatus status) :
        _status(status), _context(context) {
        _context->set_status(_status, true);
    }
    // Copying a scope guard would clear the status twice; forbid it.
    ContextStatusGurad(const ContextStatusGurad&) = delete;
    ContextStatusGurad& operator=(const ContextStatusGurad&) = delete;
    virtual ~ContextStatusGurad() {
        _context->set_status(_status, false);
    }
private:
    TrainerStatus _status;
    TrainerContext* _context = nullptr;
};
} // namespace feed } // namespace feed
} // namespace custom_trainer } // namespace custom_trainer
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册