提交 162d5fc7 编写于 作者: R rensilin

global_shuffle

Change-Id: I431825fe28853a81151febcd0a45fd77584a6e2a
上级 229964e4
...@@ -68,6 +68,16 @@ public: ...@@ -68,6 +68,16 @@ public:
std::string data;//样本数据, maybe压缩格式 std::string data;//样本数据, maybe压缩格式
}; };
template <class AR>
paddle::framework::Archive<AR>& operator>>(paddle::framework::Archive<AR>& ar, DataItem& x) {
return ar >> x.id >> x.data;
}
template <class AR>
paddle::framework::Archive<AR>& operator<<(paddle::framework::Archive<AR>& ar, const DataItem& x) {
return ar << x.id << x.data;
}
typedef std::shared_ptr<Pipeline<DataItem, SampleInstance>> SampleInstancePipe; typedef std::shared_ptr<Pipeline<DataItem, SampleInstance>> SampleInstancePipe;
inline SampleInstancePipe make_sample_instance_channel() { inline SampleInstancePipe make_sample_instance_channel() {
return std::make_shared<Pipeline<DataItem, SampleInstance>>(); return std::make_shared<Pipeline<DataItem, SampleInstance>>();
......
#include "paddle/fluid/framework/archive.h" #include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/train/custom_trainer/feed/trainer_context.h" #include "paddle/fluid/train/custom_trainer/feed/trainer_context.h"
#include "paddle/fluid/train/custom_trainer/feed/shuffler/shuffler.h" #include "paddle/fluid/train/custom_trainer/feed/shuffler/shuffler.h"
#include <bthread/butex.h>
namespace paddle { namespace paddle {
namespace custom_trainer { namespace custom_trainer {
...@@ -40,73 +41,195 @@ public: ...@@ -40,73 +41,195 @@ public:
Shuffler::initialize(config, context_ptr); Shuffler::initialize(config, context_ptr);
_max_concurrent_num = config["max_concurrent_num"].as<int>(4); // 最大并发发送数 _max_concurrent_num = config["max_concurrent_num"].as<int>(4); // 最大并发发送数
_max_package_size = config["max_package_size"].as<int>(1024); // 最大包个数,一次发送package个数据 _max_package_size = config["max_package_size"].as<int>(1024); // 最大包个数,一次发送package个数据
_shuffle_data_msg_type = config["shuffle_data_msg_type"].as<int>(3); // c2c msg type
_finish_msg_type = config["finish_msg_type"].as<int>(4); // c2c msg type
reset_channel();
auto binded = std::bind(&GlobalShuffler::get_client2client_msg, this, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3);
_trainer_context->pslib->ps_client()->registe_client2client_msg_handler(_shuffle_data_msg_type,
binded);
_trainer_context->pslib->ps_client()->registe_client2client_msg_handler(_finish_msg_type,
binded);
return 0; return 0;
} }
// 所有worker必须都调用shuffle,并且shuffler同时只能有一个shuffle任务
virtual int shuffle(::paddle::framework::Channel<DataItem>& data_channel) { virtual int shuffle(::paddle::framework::Channel<DataItem>& data_channel) {
uint32_t send_count = 0; uint32_t send_count = 0;
uint32_t package_size = _max_package_size; uint32_t package_size = _max_package_size;
uint32_t concurrent_num = _max_concurrent_num; uint32_t concurrent_num = _max_concurrent_num;
uint32_t current_wait_idx = 0; ::paddle::framework::Channel<DataItem> input_channel = ::paddle::framework::MakeChannel<DataItem>(data_channel);
data_channel.swap(input_channel);
set_channel(data_channel);
auto* environment = _trainer_context->environment.get(); auto* environment = _trainer_context->environment.get();
auto worker_num = environment->node_num(EnvironmentRole::WORKER); auto worker_num = environment->node_num(EnvironmentRole::WORKER);
std::vector<std::vector<std::future<int>>> waits(concurrent_num); std::vector<std::vector<std::future<int>>> waits(concurrent_num);
std::vector<DataItem> send_buffer(concurrent_num * package_size); std::vector<DataItem> send_buffer(package_size);
std::vector<paddle::framework::BinaryArchive> request_data_buffer(worker_num); std::vector<std::vector<DataItem>> send_buffer_worker(worker_num);
while (true) {
auto read_size = data_channel->Read(concurrent_num * package_size, &send_buffer[0]); int status = 0;// >0: finish; =0: running; <0: fail
if (read_size == 0) { while (status == 0) {
break; // update status
// 如果在训练期,则限速shuffle
// 如果在wait状态,全速shuffle
if (_trainer_context->is_status(TrainerStatus::Training)) {
concurrent_num = 1;
package_size = _max_concurrent_num / 2;
} else {
package_size = _max_package_size;
concurrent_num = _max_concurrent_num;
} }
for (size_t idx = 0; idx < read_size; idx += package_size) { for (uint32_t current_wait_idx = 0; status == 0 && current_wait_idx < concurrent_num; ++current_wait_idx) {
// data shard && seriliaze auto read_size = input_channel->Read(package_size, send_buffer.data());
for (size_t i = 0; i < worker_num; ++i) { if (read_size == 0) {
request_data_buffer[i].Clear(); status = 1;
break;
} }
for (size_t i = idx; i < package_size && i < read_size; ++i) { for (int i = 0; i < worker_num; ++i) {
auto worker_idx = _shuffle_key_func(send_buffer[i].id) % worker_num; send_buffer_worker.clear();
// TODO Serialize To Arcive
//request_data_buffer[worker_idx] << send_buffer[i];
} }
std::string data_vec[worker_num]; for (int i = 0; i < read_size; ++i) {
for (size_t i = 0; i < worker_num; ++i) { auto worker_idx = _shuffle_key_func(send_buffer[i].id) % worker_num;
auto& buffer = request_data_buffer[i]; send_buffer_worker[worker_idx].push_back(std::move(send_buffer[i]));
data_vec[i].assign(buffer.Buffer(), buffer.Length());
} }
// wait async done
for (auto& wait_s : waits[current_wait_idx]) { for (auto& wait_s : waits[current_wait_idx]) {
if (!wait_s.valid()) { if (wait_s.get() != 0) {
LOG(WARNING) << "fail to send shuffle data";
status = -1;
break; break;
} }
CHECK(wait_s.get() == 0);
} }
if (status != 0) {
// send shuffle data break;
for (size_t i = 0; i < worker_num; ++i) { }
waits[current_wait_idx][i] = _trainer_context->pslib->ps_client()->send_client2client_msg(3, i * 2, data_vec[i]); waits[current_wait_idx].clear();
for (int i = 0; i < worker_num; ++i) {
if (!send_buffer_worker[i].empty()) {
waits[current_wait_idx].push_back(send_shuffle_data(i, send_buffer_worker[i]));
}
} }
}
// update status }
// 如果在训练期,则限速shuffle for (auto& waits_s : waits) {
// 如果在wait状态,全速shuffle for (auto& wait_s : waits_s) {
if (_trainer_context->is_status(TrainerStatus::Training)) { if (wait_s.get() != 0) {
concurrent_num = 1; LOG(WARNING) << "fail to send shuffle data";
package_size = _max_concurrent_num / 2; status = -1;
} else {
package_size = _max_package_size;
concurrent_num = _max_concurrent_num;
} }
++current_wait_idx;
current_wait_idx = current_wait_idx >= concurrent_num ? 0 : current_wait_idx;
} }
} }
return 0; VLOG(5) << "start send finish, worker_num: " << worker_num;
waits[0].clear();
for (int i = 0; i < worker_num; ++i) {
waits[0].push_back(send_finish(i));
}
VLOG(5) << "wait all finish";
for (int i = 0; i < worker_num; ++i) {
if (waits[0][i].get() != 0) {
LOG(WARNING) << "fail to send finish " << i;
status = -1;
}
}
VLOG(5) << "finish shuffler, status: " << status;
return status < 0 ? status : 0;
} }
private: private:
/*
1. 部分c2c send_shuffle_data先到, 此时channel未设置, 等待wait_channel
2. shuffle中调用set_channel, 先reset_wait_num, 再解锁channel
3. 当接收到所有worker的finish请求后,先reset_channel, 再同时返回
*/
bool wait_channel() {
VLOG(5) << "wait_channel";
std::lock_guard<bthread::Mutex> lock(_channel_mutex);
return _out_channel != nullptr;
}
void reset_channel() {
VLOG(5) << "reset_channel";
_channel_mutex.lock();
if (_out_channel != nullptr) {
_out_channel->Close();
}
_out_channel = nullptr;
}
void reset_wait_num() {
_wait_num_mutex.lock();
_wait_num = _trainer_context->environment->node_num(EnvironmentRole::WORKER);
VLOG(5) << "reset_wait_num: " << _wait_num;
}
void set_channel(paddle::framework::Channel<DataItem>& channel) {
VLOG(5) << "set_channel";
// 在节点开始写入channel之前,重置wait_num
CHECK(_out_channel == nullptr);
_out_channel = channel;
reset_wait_num();
_channel_mutex.unlock();
}
int32_t finish_write_channel() {
int wait_num = --_wait_num;
VLOG(5) << "finish_write_channel, wait_num: " << wait_num;
// 同步所有worker,在所有写入完成后,c2c_msg返回前,重置channel
if (wait_num == 0) {
reset_channel();
_wait_num_mutex.unlock();
} else {
std::lock_guard<bthread::Mutex> lock(_wait_num_mutex);
}
return 0;
}
int32_t write_to_channel(std::vector<DataItem>&& items) {
size_t items_size = items.size();
VLOG(5) << "write_to_channel, items_size: " << items_size;
return _out_channel->Write(std::move(items)) == items_size ? 0 : -1;
}
int32_t get_client2client_msg(int msg_type, int from_client, const std::string& msg) {
// wait channel
if (!wait_channel()) {
LOG(FATAL) << "out_channel is null";
return -1;
}
VLOG(5) << "get c2c msg, type: " << msg_type << ", from_client: " << from_client << ", msg_size: " << msg.size();
if (msg_type == _shuffle_data_msg_type) {
paddle::framework::BinaryArchive ar;
ar.SetReadBuffer(const_cast<char*>(msg.data()), msg.size(), [](char*){});
std::vector<DataItem> items;
ar >> items;
return write_to_channel(std::move(items));
} else if (msg_type == _finish_msg_type) {
return finish_write_channel();
}
LOG(FATAL) << "no such msg type: " << msg_type;
return -1;
}
std::future<int32_t> send_shuffle_data(int to_client_id, std::vector<DataItem>& items) {
VLOG(5) << "send_shuffle_data, to_client_id: " << to_client_id << ", items_size: " << items.size();
paddle::framework::BinaryArchive ar;
ar << items;
return _trainer_context->pslib->ps_client()->send_client2client_msg(_shuffle_data_msg_type, to_client_id,
std::string(ar.Buffer(), ar.Length()));
}
std::future<int32_t> send_finish(int to_client_id) {
VLOG(5) << "send_finish, to_client_id: " << to_client_id;
static const std::string empty_str;
return _trainer_context->pslib->ps_client()->send_client2client_msg(_finish_msg_type, to_client_id, empty_str);
}
uint32_t _max_package_size = 0; uint32_t _max_package_size = 0;
uint32_t _max_concurrent_num = 0; uint32_t _max_concurrent_num = 0;
int _shuffle_data_msg_type = 3;
int _finish_msg_type = 4;
bthread::Mutex _channel_mutex;
paddle::framework::Channel<DataItem> _out_channel = nullptr;
bthread::Mutex _wait_num_mutex;
std::atomic<int> _wait_num;
}; };
REGIST_CLASS(Shuffler, GlobalShuffler); REGIST_CLASS(Shuffler, GlobalShuffler);
......
#include <gtest/gtest.h>
#include "paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h"
TEST(Archive, DataItem) {
paddle::custom_trainer::feed::DataItem item;
paddle::custom_trainer::feed::DataItem item2;
item.id = "123";
item.data = "name";
paddle::framework::BinaryArchive ar;
ar << item;
ar >> item2;
ASSERT_EQ(item.id, item2.id);
ASSERT_EQ(item.data, item2.data);
item.id += "~";
item.data += "~";
ASSERT_NE(item.id, item2.id);
ASSERT_NE(item.data, item2.data);
}
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册