提交 b23282f3 编写于 作者: X xiexionghang

fix shuffler bug

上级 7654080c
......@@ -77,6 +77,7 @@ int LearnerProcess::update_cache_model(uint64_t epoch_id, ModelSaveWay way) {
return 0;
}
int LearnerProcess::wait_save_model(uint64_t epoch_id, ModelSaveWay way, bool is_force_dump) {
ContextStatusGurad status_guard(_context_ptr, TrainerStatus::Saving);
auto fs = _context_ptr->file_system;
auto* ps_client = _context_ptr->pslib->ps_client();
auto* environment = _context_ptr->environment.get();
......@@ -154,6 +155,7 @@ int LearnerProcess::load_model(uint64_t epoch_id) {
if (!environment->is_master_node(EnvironmentRole::WORKER)) {
return 0;
}
VLOG(2) << "Start Load Model";
auto* fs = _context_ptr->file_system.get();
std::set<uint32_t> loaded_table_set;
auto model_dir = _context_ptr->epoch_accessor->checkpoint_path();
......@@ -177,6 +179,7 @@ int LearnerProcess::load_model(uint64_t epoch_id) {
loaded_table_set.insert(itr.first);
}
}
VLOG(2) << "Finish Load Model";
return 0;
}
......@@ -223,6 +226,7 @@ int LearnerProcess::run() {
//Step2. 运行训练网络
{
ContextStatusGurad status_guard(_context_ptr, TrainerStatus::Training);
std::map<std::string, paddle::framework::Channel<DataItem>> backup_input_map;
for (auto& executor : _executors) {
environment->barrier(EnvironmentRole::WORKER);
......
......@@ -39,8 +39,8 @@ public:
virtual int initialize(YAML::Node config,
std::shared_ptr<TrainerContext> context_ptr) {
Shuffler::initialize(config, context_ptr);
_max_concurrent_num = config["max_concurrent_num"].as<int>(4); // 最大并发发送数
_max_package_size = config["max_package_size"].as<int>(1024); // 最大包个数,一次发送package个数据
_max_concurrent_num = config["max_concurrent_num"].as<int>(6); // 最大并发发送数
_max_package_size = config["max_package_size"].as<int>(256); // 最大包个数,一次发送package个数据
_shuffle_data_msg_type = config["shuffle_data_msg_type"].as<int>(3); // c2c msg type
_finish_msg_type = config["finish_msg_type"].as<int>(4); // c2c msg type
......@@ -62,6 +62,8 @@ public:
data_channel.swap(input_channel);
set_channel(data_channel);
_item_send_count = 0;
_item_receive_count = 0;
auto* environment = _trainer_context->environment.get();
auto worker_num = environment->node_num(EnvironmentRole::WORKER);
std::vector<std::vector<std::future<int>>> waits(concurrent_num);
......@@ -86,8 +88,9 @@ public:
status = 1;
break;
}
_item_send_count += read_size;
for (int i = 0; i < worker_num; ++i) {
send_buffer_worker.clear();
send_buffer_worker[i].clear();
}
for (int i = 0; i < read_size; ++i) {
auto worker_idx = _shuffle_key_func(send_buffer[i].id) % worker_num;
......@@ -119,19 +122,19 @@ public:
}
}
}
VLOG(5) << "start send finish, worker_num: " << worker_num;
VLOG(2) << "start send finish, worker_num: " << worker_num;
waits[0].clear();
for (int i = 0; i < worker_num; ++i) {
waits[0].push_back(send_finish(i));
}
VLOG(5) << "wait all finish";
VLOG(2) << "wait all finish";
for (int i = 0; i < worker_num; ++i) {
if (waits[0][i].get() != 0) {
LOG(WARNING) << "fail to send finish " << i;
status = -1;
}
}
VLOG(5) << "finish shuffler, status: " << status;
VLOG(2) << "finish shuffler_send_channel, total_send:" << _item_send_count;
return status < 0 ? status : 0;
}
......@@ -174,6 +177,7 @@ private:
// 同步所有worker,在所有写入完成后,c2c_msg返回前,重置channel
if (wait_num == 0) {
reset_channel();
VLOG(2) << "finish shuffle_receive_channel, receive_count: " << _item_receive_count;
_wait_num_mutex.unlock();
} else {
std::lock_guard<bthread::Mutex> lock(_wait_num_mutex);
......@@ -182,7 +186,7 @@ private:
}
// Drains one shuffled batch into the output channel.
// NOTE(review): this span mixed a removed VLOG line with its replacement
// (stripped diff); resolved by keeping both the debug trace and the new
// receive counter, since they are independent and both useful.
// _item_receive_count feeds the end-of-shuffle accounting log emitted when
// the channel is reset.
// Returns 0 when the channel accepted every item, -1 on a short write.
int32_t write_to_channel(std::vector<DataItem>&& items) {
    size_t items_size = items.size();
    VLOG(5) << "write_to_channel, items_size: " << items_size;
    _item_receive_count += items_size;
    return _out_channel->Write(std::move(items)) == items_size ? 0 : -1;
}
......@@ -207,6 +211,8 @@ private:
}
std::future<int32_t> send_shuffle_data(int to_client_id, std::vector<DataItem>& items) {
// server端也开启了client, worker节点为偶数编号
to_client_id = 2 * to_client_id;
VLOG(5) << "send_shuffle_data, to_client_id: " << to_client_id << ", items_size: " << items.size();
paddle::framework::BinaryArchive ar;
ar << items;
......@@ -215,6 +221,8 @@ private:
}
std::future<int32_t> send_finish(int to_client_id) {
// server端也开启了client, worker节点为偶数编号
to_client_id = 2 * to_client_id;
VLOG(5) << "send_finish, to_client_id: " << to_client_id;
static const std::string empty_str;
return _trainer_context->pslib->ps_client()->send_client2client_msg(_finish_msg_type, to_client_id, empty_str);
......@@ -230,6 +238,8 @@ private:
bthread::Mutex _wait_num_mutex;
std::atomic<int> _wait_num;
std::atomic<uint32_t> _item_send_count;
std::atomic<uint32_t> _item_receive_count;
};
REGIST_CLASS(Shuffler, GlobalShuffler);
......
......@@ -95,20 +95,23 @@ private:
class TrainerContext {
public:
// Pre-size the status table with one zeroed slot per TrainerStatus value
// (currently 2 states are tracked).
TrainerContext() {
    trainer_status.assign(2, 0);
}
// Convenience accessor for the parameter-server client owned by pslib.
// Non-owning pointer; lifetime is managed by pslib.
inline paddle::ps::PSClient* ps_client() {
return pslib->ps_client();
}
// Returns whether the trainer is currently in the given status. Each status
// occupies its own slot of trainer_status (indexed by the enum value), so
// several statuses can be active simultaneously.
// NOTE(review): the scraped span interleaved the old bit-mask body with its
// replacement (stripped diff markers); only the vector-based implementation,
// which matches the std::vector<uint32_t> trainer_status member, is kept.
inline bool is_status(TrainerStatus status) {
    auto status_idx = static_cast<uint32_t>(status);
    return trainer_status[status_idx] > 0;
}
// Not thread-safe; in fact the thread-safety of every TrainerContext member
// depends on that member itself. Sets (on=true) or clears (on=false) the
// slot for `status`.
// NOTE(review): the scraped span interleaved the old bit-mask body with its
// replacement (stripped diff markers); the old line
// `trainer_status = trainer_status & (1L << bit_idx);` was also logically
// wrong (masks instead of setting a bit). Only the vector-based
// implementation is kept.
inline void set_status(TrainerStatus status, bool on) {
    auto status_idx = static_cast<uint32_t>(status);
    trainer_status[status_idx] = on ? 1 : 0;
}
uint32_t trainer_status; // trainer当前,由于可同时处于多种状态,这里分bit存储状态
std::vector<uint32_t> trainer_status;
YAML::Node trainer_config;
paddle::platform::CPUPlace cpu_place;
......@@ -122,6 +125,20 @@ public:
std::shared_ptr<SignCacheDict> cache_dict; //大模型cache词典
};
// RAII guard that marks a TrainerStatus as active on the shared
// TrainerContext for the duration of a scope: the status is set in the
// constructor and cleared in the destructor (e.g. Saving during
// wait_save_model, Training during the executor loop).
// NOTE(review): the class name keeps the original "Gurad" spelling so the
// existing call sites continue to compile.
class ContextStatusGurad {
public:
    ContextStatusGurad(TrainerContext* context, TrainerStatus status) :
        _context(context), _status(status) {
        _context->set_status(_status, true);
    }
    virtual ~ContextStatusGurad() {
        _context->set_status(_status, false);
    }
    // A guard must own its status flag exclusively: copying one would clear
    // the status twice (and possibly while the other copy still needs it),
    // so copy and assignment are disabled (moves are implicitly disabled too).
    ContextStatusGurad(const ContextStatusGurad&) = delete;
    ContextStatusGurad& operator=(const ContextStatusGurad&) = delete;
private:
    TrainerStatus _status;
    TrainerContext* _context = nullptr;
};
} // namespace feed
} // namespace custom_trainer
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册