diff --git a/paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h b/paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h
index 00dc1138776274426252db7488e1154f2cfb736b..b190f5ea28e68dc57f5b5cd3e89da22106679a1f 100755
--- a/paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h
+++ b/paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h
@@ -68,6 +68,16 @@ public:
     std::string data;//样本数据, maybe压缩格式
 };
 
+template <class AR>
+paddle::framework::Archive<AR>& operator>>(paddle::framework::Archive<AR>& ar, DataItem& x) {
+    return ar >> x.id >> x.data;
+}
+
+template <class AR>
+paddle::framework::Archive<AR>& operator<<(paddle::framework::Archive<AR>& ar, const DataItem& x) {
+    return ar << x.id << x.data;
+}
+
 typedef std::shared_ptr> SampleInstancePipe;
 inline SampleInstancePipe make_sample_instance_channel() {
     return std::make_shared>();
diff --git a/paddle/fluid/train/custom_trainer/feed/shuffler/shuffler.cc b/paddle/fluid/train/custom_trainer/feed/shuffler/shuffler.cc
index 318beca6a72da51d4cb01b13cf8102a9b37b1a4b..10a61fcbdafb03466b07c11f8509badb0dfb3819 100644
--- a/paddle/fluid/train/custom_trainer/feed/shuffler/shuffler.cc
+++ b/paddle/fluid/train/custom_trainer/feed/shuffler/shuffler.cc
@@ -1,6 +1,7 @@
 #include "paddle/fluid/framework/archive.h"
 #include "paddle/fluid/train/custom_trainer/feed/trainer_context.h"
 #include "paddle/fluid/train/custom_trainer/feed/shuffler/shuffler.h"
+#include <bthread/mutex.h>
 
 namespace paddle {
 namespace custom_trainer {
@@ -40,73 +41,195 @@ public:
         Shuffler::initialize(config, context_ptr);
         _max_concurrent_num = config["max_concurrent_num"].as<uint32_t>(4);  // 最大并发发送数
         _max_package_size = config["max_package_size"].as<uint32_t>(1024);   // 最大包个数,一次发送package个数据
+        _shuffle_data_msg_type = config["shuffle_data_msg_type"].as<int>(3); // c2c msg type
+        _finish_msg_type = config["finish_msg_type"].as<int>(4);             // c2c msg type
+
+        reset_channel();
+        auto binded = std::bind(&GlobalShuffler::get_client2client_msg, this, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3);
+        _trainer_context->pslib->ps_client()->registe_client2client_msg_handler(_shuffle_data_msg_type,
+            binded);
+        _trainer_context->pslib->ps_client()->registe_client2client_msg_handler(_finish_msg_type,
+            binded);
         return 0;
     }
+
+    // Every worker must call shuffle, and the shuffler can run only one shuffle task at a time
     virtual int shuffle(::paddle::framework::Channel<DataItem>& data_channel) {
         uint32_t send_count = 0;
         uint32_t package_size = _max_package_size;
         uint32_t concurrent_num = _max_concurrent_num;
-        uint32_t current_wait_idx = 0;
+        ::paddle::framework::Channel<DataItem> input_channel = ::paddle::framework::MakeChannel<DataItem>(data_channel);
+        data_channel.swap(input_channel);
+        set_channel(data_channel);
+
         auto* environment = _trainer_context->environment.get();
         auto worker_num = environment->node_num(EnvironmentRole::WORKER);
         std::vector<std::vector<std::future<int32_t>>> waits(concurrent_num);
-        std::vector<DataItem> send_buffer(concurrent_num * package_size);
-        std::vector<paddle::framework::BinaryArchive> request_data_buffer(worker_num);
-        while (true) {
-            auto read_size = data_channel->Read(concurrent_num * package_size, &send_buffer[0]);
-            if (read_size == 0) {
-                break;
+        std::vector<DataItem> send_buffer(package_size);
+        std::vector<std::vector<DataItem>> send_buffer_worker(worker_num);
+
+        int status = 0; // >0: finish; =0: running; <0: fail
+        while (status == 0) {
+            // update status
+            // throttle the shuffle while training is running
+            // shuffle at full speed while in the wait state
+            if (_trainer_context->is_status(TrainerStatus::Training)) {
+                concurrent_num = 1;
+                package_size = _max_concurrent_num / 2;
+            } else {
+                package_size = _max_package_size;
+                concurrent_num = _max_concurrent_num;
             }
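+            // Fill the concurrent send slots one by one: read up to package_size items,
+            // shard them by _shuffle_key_func(id) % worker_num, wait for this slot's previous
+            // round of sends, then send each worker its non-empty shard asynchronously.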
-            for (size_t idx = 0; idx < read_size; idx += package_size) {
-                // data shard && seriliaze
-                for (size_t i = 0; i < worker_num; ++i) {
-                    request_data_buffer[i].Clear();
+            for (uint32_t current_wait_idx = 0; status == 0 && current_wait_idx < concurrent_num; ++current_wait_idx) {
+                auto read_size = input_channel->Read(package_size, send_buffer.data());
+                if (read_size == 0) {
+                    status = 1;
+                    break;
                 }
-                for (size_t i = idx; i < package_size && i < read_size; ++i) {
-                    auto worker_idx = _shuffle_key_func(send_buffer[i].id) % worker_num;
-                    // TODO Serialize To Arcive
-                    //request_data_buffer[worker_idx] << send_buffer[i];
+                for (int i = 0; i < worker_num; ++i) {
+                    send_buffer_worker[i].clear();
                 }
-                std::string data_vec[worker_num];
-                for (size_t i = 0; i < worker_num; ++i) {
-                    auto& buffer = request_data_buffer[i];
-                    data_vec[i].assign(buffer.Buffer(), buffer.Length());
+                for (int i = 0; i < read_size; ++i) {
+                    auto worker_idx = _shuffle_key_func(send_buffer[i].id) % worker_num;
+                    send_buffer_worker[worker_idx].push_back(std::move(send_buffer[i]));
                 }
-
-                // wait async done
                 for (auto& wait_s : waits[current_wait_idx]) {
-                    if (!wait_s.valid()) {
+                    if (wait_s.get() != 0) {
+                        LOG(WARNING) << "fail to send shuffle data";
+                        status = -1;
                         break;
                     }
-                    CHECK(wait_s.get() == 0);
                 }
-
-                // send shuffle data
-                for (size_t i = 0; i < worker_num; ++i) {
-                    waits[current_wait_idx][i] = _trainer_context->pslib->ps_client()->send_client2client_msg(3, i * 2, data_vec[i]);
+                if (status != 0) {
+                    break;
+                }
+                waits[current_wait_idx].clear();
+                for (int i = 0; i < worker_num; ++i) {
+                    if (!send_buffer_worker[i].empty()) {
+                        waits[current_wait_idx].push_back(send_shuffle_data(i, send_buffer_worker[i]));
+                    }
                 }
-
-                // update status
-                // 如果在训练期,则限速shuffle
-                // 如果在wait状态,全速shuffle
-                if (_trainer_context->is_status(TrainerStatus::Training)) {
-                    concurrent_num = 1;
-                    package_size = _max_concurrent_num / 2;
-                } else {
-                    package_size = _max_package_size;
-                    concurrent_num = _max_concurrent_num;
+            }
+        }
+        for (auto& waits_s : waits) {
+            for (auto& wait_s : waits_s) {
+                if (wait_s.get() != 0) {
+                    LOG(WARNING) << "fail to send shuffle data";
+                    status = -1;
                 }
-                ++current_wait_idx;
-                current_wait_idx = current_wait_idx >= concurrent_num ? 0 : current_wait_idx;
             }
         }
-        return 0;
+        VLOG(5) << "start send finish, worker_num: " << worker_num;
+        waits[0].clear();
+        for (int i = 0; i < worker_num; ++i) {
+            waits[0].push_back(send_finish(i));
+        }
+        VLOG(5) << "wait all finish";
+        for (int i = 0; i < worker_num; ++i) {
+            if (waits[0][i].get() != 0) {
+                LOG(WARNING) << "fail to send finish " << i;
+                status = -1;
+            }
+        }
+        VLOG(5) << "finish shuffler, status: " << status;
+        return status < 0 ? status : 0;
     }
 
 private:
+    /*
+        1. Some c2c send_shuffle_data messages arrive before the channel has been set; they block in wait_channel
+        2. shuffle() calls set_channel, which first resets wait_num and then unlocks the channel
+        3. Once finish requests from all workers have been received, reset_channel runs first, then all handlers return together
+    */
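+    // Lock discipline: _channel_mutex stays locked from reset_channel() until set_channel()
+    // releases it, so data receivers block in wait_channel() until shuffle() has installed the
+    // output channel; _wait_num_mutex stays locked from reset_wait_num() until the last
+    // finish_write_channel() releases it, so all finish handlers return together.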
+    bool wait_channel() {
+        VLOG(5) << "wait_channel";
+        std::lock_guard<bthread::Mutex> lock(_channel_mutex);
+        return _out_channel != nullptr;
+    }
+    void reset_channel() {
+        VLOG(5) << "reset_channel";
+        _channel_mutex.lock();
+        if (_out_channel != nullptr) {
+            _out_channel->Close();
+        }
+        _out_channel = nullptr;
+    }
+    void reset_wait_num() {
+        _wait_num_mutex.lock();
+        _wait_num = _trainer_context->environment->node_num(EnvironmentRole::WORKER);
+        VLOG(5) << "reset_wait_num: " << _wait_num;
+    }
+    void set_channel(paddle::framework::Channel<DataItem>& channel) {
+        VLOG(5) << "set_channel";
+        // reset wait_num before any node starts writing to the channel
+        CHECK(_out_channel == nullptr);
+        _out_channel = channel;
+        reset_wait_num();
+        _channel_mutex.unlock();
+    }
+
+    int32_t finish_write_channel() {
+        int wait_num = --_wait_num;
+        VLOG(5) << "finish_write_channel, wait_num: " << wait_num;
+        // synchronize all workers: reset the channel after all writes have finished and before the c2c msg handlers return
+        if (wait_num == 0) {
+            reset_channel();
+            _wait_num_mutex.unlock();
+        } else {
+            std::lock_guard<bthread::Mutex> lock(_wait_num_mutex);
+        }
+        return 0;
+    }
+    int32_t write_to_channel(std::vector<DataItem>&& items) {
+        size_t items_size = items.size();
+        VLOG(5) << "write_to_channel, items_size: " << items_size;
+        return _out_channel->Write(std::move(items)) == items_size ? 0 : -1;
+    }
+
+    int32_t get_client2client_msg(int msg_type, int from_client, const std::string& msg) {
+        // wait channel
+        if (!wait_channel()) {
+            LOG(FATAL) << "out_channel is null";
+            return -1;
+        }
+        VLOG(5) << "get c2c msg, type: " << msg_type << ", from_client: " << from_client << ", msg_size: " << msg.size();
+        if (msg_type == _shuffle_data_msg_type) {
+            paddle::framework::BinaryArchive ar;
+            ar.SetReadBuffer(const_cast<char*>(msg.data()), msg.size(), [](char*){});
+            std::vector<DataItem> items;
+            ar >> items;
+            return write_to_channel(std::move(items));
+        } else if (msg_type == _finish_msg_type) {
+            return finish_write_channel();
+        }
+        LOG(FATAL) << "no such msg type: " << msg_type;
+        return -1;
+    }
+
+    std::future<int32_t> send_shuffle_data(int to_client_id, std::vector<DataItem>& items) {
+        VLOG(5) << "send_shuffle_data, to_client_id: " << to_client_id << ", items_size: " << items.size();
+        paddle::framework::BinaryArchive ar;
+        ar << items;
+        return _trainer_context->pslib->ps_client()->send_client2client_msg(_shuffle_data_msg_type, to_client_id,
+            std::string(ar.Buffer(), ar.Length()));
+    }
+
+    std::future<int32_t> send_finish(int to_client_id) {
+        VLOG(5) << "send_finish, to_client_id: " << to_client_id;
+        static const std::string empty_str;
+        return _trainer_context->pslib->ps_client()->send_client2client_msg(_finish_msg_type, to_client_id, empty_str);
+    }
+
     uint32_t _max_package_size = 0;
     uint32_t _max_concurrent_num = 0;
+    int _shuffle_data_msg_type = 3;
+    int _finish_msg_type = 4;
+    bthread::Mutex _channel_mutex;
+    paddle::framework::Channel<DataItem> _out_channel = nullptr;
+
+    bthread::Mutex _wait_num_mutex;
+    std::atomic<int> _wait_num;
 };
 
 REGIST_CLASS(Shuffler, GlobalShuffler);
diff --git a/paddle/fluid/train/custom_trainer/feed/unit_test/test_archive_dataitem.cc b/paddle/fluid/train/custom_trainer/feed/unit_test/test_archive_dataitem.cc
new file mode 100644
index 0000000000000000000000000000000000000000..a85da7fb0d4cacacbde813dec73ae65818b552db
--- /dev/null
+++ b/paddle/fluid/train/custom_trainer/feed/unit_test/test_archive_dataitem.cc
@@ -0,0 +1,20 @@
+#include <gtest/gtest.h>
+#include "paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h"
+
+TEST(Archive, DataItem) {
+    paddle::custom_trainer::feed::DataItem item;
+    paddle::custom_trainer::feed::DataItem item2;
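+    // Populate item, round-trip it through a BinaryArchive into item2, and check that both fields survive.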
+    item.id = "123";
+    item.data = "name";
+
+    paddle::framework::BinaryArchive ar;
+    ar << item;
+    ar >> item2;
+
+    ASSERT_EQ(item.id, item2.id);
+    ASSERT_EQ(item.data, item2.data);
+    item.id += "~";
+    item.data += "~";
+    ASSERT_NE(item.id, item2.id);
+    ASSERT_NE(item.data, item2.data);
+}
\ No newline at end of file