提交 3cea00bd 编写于 作者: X xujiaqi01 提交者: dongdaxiang

store memory data in Dataset && fix bug

上级 ff87698a
...@@ -68,8 +68,10 @@ void DataFeed::SetBatchSize(int batch_size) { ...@@ -68,8 +68,10 @@ void DataFeed::SetBatchSize(int batch_size) {
bool DataFeed::PickOneFile(std::string* filename) { bool DataFeed::PickOneFile(std::string* filename) {
std::unique_lock<std::mutex> lock(mutex_for_pick_file_); std::unique_lock<std::mutex> lock(mutex_for_pick_file_);
if (file_idx_ == filelist_.size()) { if (file_idx_ == filelist_.size()) {
VLOG(3) << "DataFeed::PickOneFile no more file to pick";
return false; return false;
} }
VLOG(3) << "file_idx_=" << file_idx_;
*filename = filelist_[file_idx_++]; *filename = filelist_[file_idx_++];
// LOG(ERROR) << "pick file:" << *filename; // LOG(ERROR) << "pick file:" << *filename;
return true; return true;
...@@ -146,17 +148,18 @@ template class PrivateQueueDataFeed<std::vector<MultiSlotType>>; ...@@ -146,17 +148,18 @@ template class PrivateQueueDataFeed<std::vector<MultiSlotType>>;
template <typename T> template <typename T>
InMemoryDataFeed<T>::InMemoryDataFeed() { InMemoryDataFeed<T>::InMemoryDataFeed() {
cur_channel_ = 0; cur_channel_ = 0;
shuffled_ins_ = nullptr; shuffled_ins_ = std::make_shared<paddle::framework::BlockingQueue<T>>();
shuffled_ins_out_ = nullptr; shuffled_ins_out_ = std::make_shared<paddle::framework::BlockingQueue<T>>();
fleet_send_batch_size_ = 10000;
} }
template <typename T> template <typename T>
bool InMemoryDataFeed<T>::Start() { bool InMemoryDataFeed<T>::Start() {
DataFeed::CheckSetFileList(); DataFeed::CheckSetFileList();
if (memory_data_.size() != 0) { if (shuffled_ins_->Size() == 0 && shuffled_ins_out_->Size() == 0) {
CHECK_EQ(cur_channel_, 0); FillMemoryDataToChannel();
shuffled_ins_->Extend(std::move(memory_data_)); //std::unique_lock<std::mutex> lock(*mutex_for_update_memory_data_);
std::vector<T>().swap(memory_data_); //std::vector<T>().swap(memory_data_);
} }
DataFeed::finish_start_ = true; DataFeed::finish_start_ = true;
return true; return true;
...@@ -196,6 +199,31 @@ int InMemoryDataFeed<T>::Next() { ...@@ -196,6 +199,31 @@ int InMemoryDataFeed<T>::Next() {
return DataFeed::batch_size_; return DataFeed::batch_size_;
} }
// Stores the pointer to the instance vector this feed reads from and appends
// to. The argument arrives type-erased (void*) and is cast back to
// std::vector<T>* without any check, so the caller must pass a pointer to a
// std::vector<T>. The feed does not take ownership of the vector.
template <typename T>
void InMemoryDataFeed<T>::SetMemoryData(void* memory_data) {
  auto* typed_store = static_cast<std::vector<T>*>(memory_data);
  this->memory_data_ = typed_store;
}
// Stores the mutex that this feed locks while appending to *memory_data_.
// Only the pointer is kept; the feed does not own the mutex.
template <typename T>
void InMemoryDataFeed<T>::SetMemoryDataMutex(std::mutex* mutex) {
  this->mutex_for_update_memory_data_ = mutex;
}
// Records the index of this reader. FillMemoryDataToChannel uses it to pick
// which slice of *memory_data_ this reader takes.
template <typename T>
void InMemoryDataFeed<T>::SetThreadId(int thread_id) {
  this->thread_id_ = thread_id;
}
// Records the total number of readers. FillMemoryDataToChannel divides
// *memory_data_ into this many slices.
template <typename T>
void InMemoryDataFeed<T>::SetThreadNum(int thread_num) {
  this->thread_num_ = thread_num;
}
// Records the number of trainers. GlobalShuffle uses it as the modulus when
// choosing a destination node and as the number of send buffers.
template <typename T>
void InMemoryDataFeed<T>::SetTrainerNum(int trainer_num) {
  this->trainer_num_ = trainer_num;
}
template <typename T> template <typename T>
void InMemoryDataFeed<T>::PutInsToChannel(const std::string& ins_str) { void InMemoryDataFeed<T>::PutInsToChannel(const std::string& ins_str) {
T ins; T ins;
...@@ -203,11 +231,54 @@ void InMemoryDataFeed<T>::PutInsToChannel(const std::string& ins_str) { ...@@ -203,11 +231,54 @@ void InMemoryDataFeed<T>::PutInsToChannel(const std::string& ins_str) {
shuffled_ins_->Push(std::move(ins)); shuffled_ins_->Push(std::move(ins));
} }
// Moves this reader's slice of *memory_data_ into shuffled_ins_.
//
// The vector is split into thread_num_ contiguous slices whose lengths differ
// by at most one: the first (size % thread_num_) slices hold one extra
// element. This reader takes slice number thread_id_. Elements are moved, so
// the corresponding entries of *memory_data_ are left in a moved-from state.
//
// Fails a CHECK if memory_data_ has not been set, if thread_num_ is not
// positive, or if thread_id_ is outside [0, thread_num_); without these
// guards the code would dereference null, divide by zero, or index past the
// end of the vector.
template <typename T>
void InMemoryDataFeed<T>::FillMemoryDataToChannel() {
  VLOG(3) << "InMemoryDataFeed<T>::FillMemoryDataToChannel, thread_id="
          << thread_id_;
  CHECK(memory_data_ != nullptr) << "memory_data_ is not set";
  CHECK(thread_num_ > 0) << "thread_num should > 0";
  CHECK(thread_id_ >= 0 && thread_id_ < thread_num_)
      << "thread_id should be in [0, thread_num)";

  const int64_t size = static_cast<int64_t>(memory_data_->size());
  VLOG(3) << "memory_data size=" << size;

  const int64_t workers = static_cast<int64_t>(thread_num_);
  const int64_t id = static_cast<int64_t>(thread_id_);
  const int64_t base_len = size / workers;
  const int64_t extra = size % workers;
  // Closed form of summing the lengths of slices 0..id-1: each has base_len
  // elements, and the first min(id, extra) of them have one more.
  const int64_t start = id * base_len + (id < extra ? id : extra);
  const int64_t end = start + base_len + (id < extra ? 1 : 0);

  for (int64_t i = start; i < end; ++i) {
    shuffled_ins_->Push(std::move((*memory_data_)[i]));
  }
}
// Drains the currently active channel (shuffled_ins_ when cur_channel_ == 0,
// otherwise shuffled_ins_out_) and appends its instances to *memory_data_.
//
// The channel is drained into a local vector first so that the shared mutex
// is held only for the final append.
template <typename T>
void InMemoryDataFeed<T>::FillChannelToMemoryData() {
  VLOG(3) << "InMemoryDataFeed<T>::FillChannelToMemoryData, thread_id="
          << thread_id_;
  std::shared_ptr<paddle::framework::BlockingQueue<T>> channel =
      (cur_channel_ == 0) ? shuffled_ins_ : shuffled_ins_out_;
  CHECK(channel != nullptr);
  CHECK(memory_data_ != nullptr);
  CHECK(mutex_for_update_memory_data_ != nullptr);

  // Snapshot the element count once: the queue shrinks with every Pop, so
  // re-reading Size() in the loop condition would stop after roughly half
  // of the elements.
  const int64_t pending = static_cast<int64_t>(channel->Size());
  // The vector must be sized, not merely reserved, because Pop writes into
  // local_vec[i].
  std::vector<T> local_vec(static_cast<size_t>(pending));
  for (int64_t i = 0; i < pending; ++i) {
    channel->Pop(local_vec[i]);
  }

  // lock_guard acquires the mutex exactly once and releases it on scope
  // exit. Constructing a unique_lock and then calling lock() on it again
  // would throw std::system_error.
  std::lock_guard<std::mutex> lock(*mutex_for_update_memory_data_);
  for (T& ins : local_vec) {
    memory_data_->push_back(std::move(ins));
  }
}
template <typename T> template <typename T>
void InMemoryDataFeed<T>::LoadIntoMemory() { void InMemoryDataFeed<T>::LoadIntoMemory() {
VLOG(3) << "InMemoryDataFeed<T>::LoadIntoMemory() begin, thread_id=" << thread_id_;
std::vector<T> local_vec; std::vector<T> local_vec;
std::string filename; std::string filename;
while (DataFeed::PickOneFile(&filename)) { while (DataFeed::PickOneFile(&filename)) {
VLOG(3) << "PickOneFile, filename=" << filename << ", thread_id=" << thread_id_;
int err_no = 0; int err_no = 0;
PrivateQueueDataFeed<T>::fp_ = PrivateQueueDataFeed<T>::fp_ =
fs_open_read(filename, &err_no, PrivateQueueDataFeed<T>::pipe_command_); fs_open_read(filename, &err_no, PrivateQueueDataFeed<T>::pipe_command_);
...@@ -216,35 +287,50 @@ void InMemoryDataFeed<T>::LoadIntoMemory() { ...@@ -216,35 +287,50 @@ void InMemoryDataFeed<T>::LoadIntoMemory() {
while (ParseOneInstanceFromPipe(&instance)) { while (ParseOneInstanceFromPipe(&instance)) {
local_vec.push_back(instance); local_vec.push_back(instance);
} }
memory_data_.insert(memory_data_.end(), local_vec.begin(), local_vec.end()); VLOG(3) << "InMemoryDataFeed<T>::LoadIntoMemory() read all lines, thread_id=" << thread_id_;
{
std::lock_guard<std::mutex> lock(*mutex_for_update_memory_data_);
memory_data_->insert(memory_data_->end(), local_vec.begin(), local_vec.end());
}
std::vector<T>().swap(local_vec); std::vector<T>().swap(local_vec);
} }
VLOG(3) << "InMemoryDataFeed<T>::LoadIntoMemory() end, thread_id=" << thread_id_;
} }
template <typename T> template <typename T>
void InMemoryDataFeed<T>::LocalShuffle() { void InMemoryDataFeed<T>::LocalShuffle() {
std::random_shuffle(memory_data_.begin(), memory_data_.end()); VLOG(3) << "InMemoryDataFeed<T>::LocalShuffle() begin, thread_id=" << thread_id_;
FillMemoryDataToChannel();
VLOG(3) << "InMemoryDataFeed<T>::LocalShuffle() end, thread_id=" << thread_id_;
} }
// todo global shuffle
/*
template <typename T> template <typename T>
void InMemoryDataFeed<T>::GlobalShuffle(int trainer_num) { void InMemoryDataFeed<T>::GlobalShuffle() {
std::random_shuffle(memory_data_.begin(), memory_data_.end()); auto fleet_ptr = FleetWrapper::GetInstance();
for (int64_t i = 0; i < memory_data_.size(); ++i) { std::vector<std::string> send_str_vec(trainer_num_);
for (int64_t i = 0; i < memory_data_->size(); ++i) {
// todo get ins id // todo get ins id
//std::string ins_id = memory_data_[i].ins_id; //std::string ins_id = memory_data_[i].ins_id;
// todo hash // todo hash
int64_t hash_id = paddle::ps::local_random_engine()(); //int64_t hash_id = paddle::ps::local_random_engine()();
//int64_t hash_id = hash(ins_id); int64_t hash_id = 0;
int64_t node_id = hash_id % trainer_num_; int64_t node_id = hash_id % trainer_num_;
std::string str; std::string str;
SerializeIns(memory_data_[i], str); SerializeIns((*memory_data_)[i], str);
auto fleet_ptr = FleetWrapper::GetInstance(); send_str_vec[node_id] += str;
auto ret = fleet_ptr->send_client2client_msg(0, node_id, str); if (i % fleet_send_batch_size_ == 0 && i != 0) {
for (int j = 0; j < send_str_vec.size(); ++j) {
fleet_ptr->send_client2client_msg(0, j, send_str_vec[j]);
send_str_vec[j] = "";
}
}
}
for (int j = 0; j < send_str_vec.size(); ++j) {
if (send_str_vec[j].length() != 0) {
fleet_ptr->send_client2client_msg(0, j, send_str_vec[j]);
}
} }
} }
*/
// explicit instantiation // explicit instantiation
template class InMemoryDataFeed<std::vector<MultiSlotType>>; template class InMemoryDataFeed<std::vector<MultiSlotType>>;
...@@ -646,6 +732,7 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstance( ...@@ -646,6 +732,7 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstance(
if (getline(file_, line)) { if (getline(file_, line)) {
int use_slots_num = use_slots_.size(); int use_slots_num = use_slots_.size();
instance->resize(use_slots_num); instance->resize(use_slots_num);
VLOG(3) << line;
// parse line // parse line
const char* str = line.c_str(); const char* str = line.c_str();
char* endptr = const_cast<char*>(str); char* endptr = const_cast<char*>(str);
...@@ -735,12 +822,14 @@ void MultiSlotInMemoryDataFeed::PutToFeedVec( ...@@ -735,12 +822,14 @@ void MultiSlotInMemoryDataFeed::PutToFeedVec(
// todo serialize ins in global shuffle // todo serialize ins in global shuffle
void MultiSlotInMemoryDataFeed::SerializeIns( void MultiSlotInMemoryDataFeed::SerializeIns(
const std::vector<MultiSlotType>& ins, std::string& str) { const std::vector<MultiSlotType>& ins, std::string& str) {
return; auto fleet_ptr = FleetWrapper::GetInstance();
fleet_ptr->Serialize(ins, str);
} }
// todo deserialize ins in global shuffle // todo deserialize ins in global shuffle
void MultiSlotInMemoryDataFeed::DeserializeIns(std::vector<MultiSlotType>& ins, void MultiSlotInMemoryDataFeed::DeserializeIns(std::vector<MultiSlotType>& ins,
const std::string& str) { const std::string& str) {
return; auto fleet_ptr = FleetWrapper::GetInstance();
fleet_ptr->Deserialize(ins, str);
} }
} // namespace framework } // namespace framework
......
...@@ -20,6 +20,7 @@ limitations under the License. */ ...@@ -20,6 +20,7 @@ limitations under the License. */
#include <string> #include <string>
#include <thread> // NOLINT #include <thread> // NOLINT
#include <vector> #include <vector>
#include <sstream>
#include "paddle/fluid/framework/data_feed.pb.h" #include "paddle/fluid/framework/data_feed.pb.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
...@@ -78,17 +79,33 @@ class DataFeed { ...@@ -78,17 +79,33 @@ class DataFeed {
// This function is used for binding feed_vec memory // This function is used for binding feed_vec memory
virtual void AddFeedVar(Variable* var, const std::string& name); virtual void AddFeedVar(Variable* var, const std::string& name);
// This function will do nothing at default
virtual void SetMemoryData(void* memory_data) { }
// This function will do nothing at default
virtual void SetMemoryDataMutex(std::mutex* mutex) { }
// This function will do nothing at default
virtual void SetThreadId(int thread_id) { }
// This function will do nothing at default
virtual void SetThreadNum(int thread_num) { }
// This function will do nothing at default
virtual void SetTrainerNum(int trainer_num) { }
virtual void LoadIntoMemory() { virtual void LoadIntoMemory() {
PADDLE_THROW("This function(LoadIntoMemory) is not implemented."); PADDLE_THROW("This function(LoadIntoMemory) is not implemented.");
} }
virtual void LocalShuffle() { virtual void LocalShuffle() {
PADDLE_THROW("This function(LocalShuffle) is not implemented."); PADDLE_THROW("This function(LocalShuffle) is not implemented.");
} }
virtual void GlobalShuffle(int trainer_num) { virtual void GlobalShuffle() {
PADDLE_THROW("This function(GlobalShuffle) is not implemented."); PADDLE_THROW("This function(GlobalShuffle) is not implemented.");
} }
virtual void FillMemoryDataToChannel() {
PADDLE_THROW("This function(FillMemoryDataToChannel) is not implemented.");
}
virtual void FillChannelToMemoryData() {
PADDLE_THROW("This function(FillChannelToMemoryData) is not implemented.");
}
virtual void PutInsToChannel(const std::string& ins_str) { virtual void PutInsToChannel(const std::string& ins_str) {
PADDLE_THROW("This function(PutToChannel) is not implemented."); PADDLE_THROW("This function(PutInsToChannel) is not implemented.");
} }
protected: protected:
...@@ -181,13 +198,20 @@ class InMemoryDataFeed : public PrivateQueueDataFeed<T> { ...@@ -181,13 +198,20 @@ class InMemoryDataFeed : public PrivateQueueDataFeed<T> {
public: public:
InMemoryDataFeed(); InMemoryDataFeed();
virtual ~InMemoryDataFeed() {} virtual ~InMemoryDataFeed() {}
virtual void Init(const paddle::framework::DataFeedDesc& data_feed_desc) = 0;
virtual bool Start(); virtual bool Start();
virtual int Next(); virtual int Next();
virtual void SetMemoryData(void* memory_data);
virtual void SetMemoryDataMutex(std::mutex* mutex);
virtual void SetThreadId(int thread_id);
virtual void SetThreadNum(int thread_num);
virtual void SetTrainerNum(int trainer_num);
virtual void PutInsToChannel(const std::string& ins_str); virtual void PutInsToChannel(const std::string& ins_str);
virtual void FillMemoryDataToChannel();
virtual void FillChannelToMemoryData();
virtual void LoadIntoMemory(); virtual void LoadIntoMemory();
virtual void LocalShuffle(); virtual void LocalShuffle();
// todo global shuffle virtual void GlobalShuffle();
//virtual void GlobalShuffle(int trainer_num);
protected: protected:
virtual void AddInstanceToInsVec(T* vec_ins, const T& instance, int index) = 0; virtual void AddInstanceToInsVec(T* vec_ins, const T& instance, int index) = 0;
virtual bool ParseOneInstance(T* instance) = 0; virtual bool ParseOneInstance(T* instance) = 0;
...@@ -196,13 +220,18 @@ class InMemoryDataFeed : public PrivateQueueDataFeed<T> { ...@@ -196,13 +220,18 @@ class InMemoryDataFeed : public PrivateQueueDataFeed<T> {
virtual void SerializeIns(const T& ins, std::string& str) = 0; virtual void SerializeIns(const T& ins, std::string& str) = 0;
virtual void DeserializeIns(T& ins, const std::string& str) = 0; virtual void DeserializeIns(T& ins, const std::string& str) = 0;
std::vector<T> memory_data_; int thread_id_;
int thread_num_;
int trainer_num_;
std::vector<T>* memory_data_;
std::mutex* mutex_for_update_memory_data_;
// when read ins, we put ins from one channel to the other, // when read ins, we put ins from one channel to the other,
// and when finish reading, we set cur_channel = 1 - cur_channel, // and when finish reading, we set cur_channel = 1 - cur_channel,
// so if cur_channel=0, all data are in shuffled_ins_, else shuffled_ins_out_ // so if cur_channel=0, all data are in shuffled_ins_, else shuffled_ins_out_
int cur_channel_; int cur_channel_;
std::shared_ptr<paddle::framework::BlockingQueue<T>> shuffled_ins_; std::shared_ptr<paddle::framework::BlockingQueue<T>> shuffled_ins_;
std::shared_ptr<paddle::framework::BlockingQueue<T>> shuffled_ins_out_; std::shared_ptr<paddle::framework::BlockingQueue<T>> shuffled_ins_out_;
int64_t fleet_send_batch_size_;
}; };
// This class define the data type of instance(ins_vec) in MultiSlotDataFeed // This class define the data type of instance(ins_vec) in MultiSlotDataFeed
...@@ -226,6 +255,7 @@ class MultiSlotType { ...@@ -226,6 +255,7 @@ class MultiSlotType {
offset_[0] = 0; offset_[0] = 0;
} }
const std::vector<size_t>& GetOffset() const { return offset_; } const std::vector<size_t>& GetOffset() const { return offset_; }
std::vector<size_t>& MutableOffset() { return offset_; }
void AddValue(const float v) { void AddValue(const float v) {
CheckFloat(); CheckFloat();
float_feasign_.push_back(v); float_feasign_.push_back(v);
...@@ -248,8 +278,11 @@ class MultiSlotType { ...@@ -248,8 +278,11 @@ class MultiSlotType {
} }
} }
const std::vector<float>& GetFloatData() const { return float_feasign_; } const std::vector<float>& GetFloatData() const { return float_feasign_; }
std::vector<float>& MutableFloatData() { return float_feasign_; }
const std::vector<uint64_t>& GetUint64Data() const { return uint64_feasign_; } const std::vector<uint64_t>& GetUint64Data() const { return uint64_feasign_; }
std::vector<uint64_t>& MutableUint64Data() { return uint64_feasign_; }
const std::string& GetType() const { return type_; } const std::string& GetType() const { return type_; }
std::string& MutableType() { return type_; }
private: private:
void CheckType(const std::string& type) const { void CheckType(const std::string& type) const {
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. */ * limitations under the License. */
#include <random>
#include "paddle/fluid/framework/data_set.h" #include "paddle/fluid/framework/data_set.h"
#include "google/protobuf/io/zero_copy_stream_impl.h" #include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/message.h" #include "google/protobuf/message.h"
...@@ -21,23 +22,27 @@ ...@@ -21,23 +22,27 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
Dataset::Dataset() { thread_num_ = 1; } template <typename T>
DatasetImpl<T>::DatasetImpl() { thread_num_ = 1; }
void Dataset::SetFileList(const std::vector<std::string>& filelist) { template <typename T>
void DatasetImpl<T>::SetFileList(const std::vector<std::string>& filelist) {
VLOG(3) << "filelist size: " << filelist.size(); VLOG(3) << "filelist size: " << filelist.size();
filelist_ = filelist; filelist_ = filelist;
/*
int file_cnt = filelist_.size(); int file_cnt = filelist_.size();
if (thread_num_ > file_cnt) { if (thread_num_ > file_cnt) {
VLOG(1) << "DataSet thread num = " << thread_num_ VLOG(1) << "DataSet thread num = " << thread_num_
<< ", file num = " << file_cnt << ", file num = " << file_cnt
<< ". Changing DataSet thread num = " << file_cnt; << ". Changing DataSet thread num = " << file_cnt;
thread_num_ = file_cnt; thread_num_ = file_cnt;
} }*/
} }
// buggy here, a user should set filelist first before this function // buggy here, a user should set filelist first before this function
// not user friendly // not user friendly
void Dataset::SetThreadNum(int thread_num) { template <typename T>
void DatasetImpl<T>::SetThreadNum(int thread_num) {
int file_cnt = filelist_.size(); int file_cnt = filelist_.size();
if (file_cnt != 0 && thread_num > file_cnt) { if (file_cnt != 0 && thread_num > file_cnt) {
VLOG(1) << "DataSet thread num = " << thread_num VLOG(1) << "DataSet thread num = " << thread_num
...@@ -48,19 +53,24 @@ void Dataset::SetThreadNum(int thread_num) { ...@@ -48,19 +53,24 @@ void Dataset::SetThreadNum(int thread_num) {
thread_num_ = thread_num; thread_num_ = thread_num;
} }
void Dataset::SetTrainerNum(int trainer_num) { trainer_num_ = trainer_num; } template <typename T>
void DatasetImpl<T>::SetTrainerNum(int trainer_num) { trainer_num_ = trainer_num; }
void Dataset::SetDataFeedDesc(const std::string& data_feed_desc_str) { template <typename T>
void DatasetImpl<T>::SetDataFeedDesc(const std::string& data_feed_desc_str) {
google::protobuf::TextFormat::ParseFromString(data_feed_desc_str, google::protobuf::TextFormat::ParseFromString(data_feed_desc_str,
&data_feed_desc_); &data_feed_desc_);
} }
const std::vector<std::shared_ptr<paddle::framework::DataFeed>>& template <typename T>
Dataset::GetReaders() { std::vector<std::shared_ptr<paddle::framework::DataFeed>>&
DatasetImpl<T>::GetReaders() {
return readers_; return readers_;
} }
void Dataset::LoadIntoMemory() { template <typename T>
void DatasetImpl<T>::LoadIntoMemory() {
VLOG(3) << "DatasetImpl<T>::LoadIntoMemory() begin";
if (readers_.size() == 0) { if (readers_.size() == 0) {
CreateReaders(); CreateReaders();
} }
...@@ -72,12 +82,18 @@ void Dataset::LoadIntoMemory() { ...@@ -72,12 +82,18 @@ void Dataset::LoadIntoMemory() {
for (std::thread& t : load_threads) { for (std::thread& t : load_threads) {
t.join(); t.join();
} }
VLOG(3) << "DatasetImpl<T>::LoadIntoMemory() end";
} }
void Dataset::LocalShuffle() { template <typename T>
void DatasetImpl<T>::LocalShuffle() {
VLOG(3) << "DatasetImpl<T>::LocalShuffle() begin";
if (readers_.size() == 0) { if (readers_.size() == 0) {
CreateReaders(); CreateReaders();
} }
// if it is not InMemory, memory_data_ is empty
std::random_shuffle(memory_data_.begin(), memory_data_.end());
std::vector<std::thread> local_shuffle_threads; std::vector<std::thread> local_shuffle_threads;
for (int64_t i = 0; i < thread_num_; ++i) { for (int64_t i = 0; i < thread_num_; ++i) {
local_shuffle_threads.push_back(std::thread( local_shuffle_threads.push_back(std::thread(
...@@ -86,30 +102,37 @@ void Dataset::LocalShuffle() { ...@@ -86,30 +102,37 @@ void Dataset::LocalShuffle() {
for (std::thread& t : local_shuffle_threads) { for (std::thread& t : local_shuffle_threads) {
t.join(); t.join();
} }
std::vector<T>().swap(memory_data_);
VLOG(3) << "DatasetImpl<T>::LocalShuffle() end";
} }
// todo global shuffle template <typename T>
void Dataset::GlobalShuffle() { void DatasetImpl<T>::GlobalShuffle() {
/* VLOG(3) << "DatasetImpl<T>::GlobalShuffle() begin";
if (readers_.size() == 0) {
CreateReaders();
}
// if it is not InMemory, memory_data_ is empty
std::random_shuffle(memory_data_.begin(), memory_data_.end());
auto fleet_ptr = FleetWrapper::GetInstance(); auto fleet_ptr = FleetWrapper::GetInstance();
fleet_ptr->registe_client2client_msg_handler(0, fleet_ptr->registe_client2client_msg_handler(0,
[this](int msg_type, int client_id, const std::string& msg) -> int { [this](int msg_type, int client_id, const std::string& msg) -> int {
return this->ReceiveFromClient(msg_type, client_id, msg); return this->ReceiveFromClient(msg_type, client_id, msg);
}); });
if (readers_.size() == 0) {
CreateReaders();
}
std::vector<std::thread> global_shuffle_threads; std::vector<std::thread> global_shuffle_threads;
for (int64_t i = 0; i < thread_num_; ++i) { for (int i = 0; i < thread_num_; ++i) {
global_shuffle_threads.push_back(std::thread(&paddle::framework::DataFeed::GlobalShuffle, global_shuffle_threads.push_back(
readers_[i].get(), trainer_num_)); std::thread(&paddle::framework::DataFeed::GlobalShuffle,
readers_[i].get()));
} }
for (std::thread& t : global_shuffle_threads) { for (std::thread& t : global_shuffle_threads) {
t.join(); t.join();
}*/ }
VLOG(3) << "DatasetImpl<T>::GlobalShuffle() end";
} }
void Dataset::CreateReaders() { template <typename T>
void DatasetImpl<T>::CreateReaders() {
VLOG(3) << "Calling CreateReaders()"; VLOG(3) << "Calling CreateReaders()";
CHECK(thread_num_ > 0) << "thread_num should > 0"; CHECK(thread_num_ > 0) << "thread_num should > 0";
VLOG(3) << "thread_num in Readers: " << thread_num_; VLOG(3) << "thread_num in Readers: " << thread_num_;
...@@ -118,22 +141,53 @@ void Dataset::CreateReaders() { ...@@ -118,22 +141,53 @@ void Dataset::CreateReaders() {
return; return;
} }
VLOG(3) << "data feed class name: " << data_feed_desc_.name(); VLOG(3) << "data feed class name: " << data_feed_desc_.name();
for (int64_t i = 0; i < thread_num_; ++i) { for (int i = 0; i < thread_num_; ++i) {
readers_.push_back(DataFeedFactory::CreateDataFeed(data_feed_desc_.name())); readers_.push_back(DataFeedFactory::CreateDataFeed(data_feed_desc_.name()));
readers_.back()->Init(data_feed_desc_); readers_.back()->Init(data_feed_desc_);
readers_.back()->SetMemoryData(&memory_data_);
readers_.back()->SetMemoryDataMutex(&mutex_for_update_memory_data_);
readers_.back()->SetThreadId(i);
readers_.back()->SetThreadNum(thread_num_);
readers_.back()->SetTrainerNum(trainer_num_);
} }
VLOG(3) << "Filelist size in readers: " << filelist_.size(); VLOG(3) << "Filelist size in readers: " << filelist_.size();
readers_[0]->SetFileList(filelist_); readers_[0]->SetFileList(filelist_);
} }
int Dataset::ReceiveFromClient(int msg_type, int client_id, template <typename T>
// Drains every reader's channel back into memory_data_ (one thread per
// reader), then releases the file list and the readers.
//
// Iterates over readers_ itself rather than counting to thread_num_, so an
// empty or shorter readers_ vector is never indexed out of bounds.
void DatasetImpl<T>::DestroyReaders() {
  VLOG(3) << "Calling DestroyReaders()";
  // Clear memory_data_ before refilling it: after LoadIntoMemory without a
  // shuffle, it still holds the moved-from shells of instances that were
  // std::move'd into the channels.
  if (memory_data_.size() != 0) {
    std::vector<T>().swap(memory_data_);
  }
  std::vector<std::thread> fill_threads;
  fill_threads.reserve(readers_.size());
  for (auto& reader : readers_) {
    fill_threads.push_back(std::thread(
        &paddle::framework::DataFeed::FillChannelToMemoryData, reader.get()));
  }
  for (std::thread& t : fill_threads) {
    t.join();
  }
  // Swap with empty temporaries so the storage is actually released.
  std::vector<std::string>().swap(filelist_);
  std::vector<std::shared_ptr<paddle::framework::DataFeed>>().swap(readers_);
}
template <typename T>
int DatasetImpl<T>::ReceiveFromClient(int msg_type, int client_id,
const std::string& msg) { const std::string& msg) {
// can also use hash // todo random
// int64_t index = paddle::ps::local_random_engine()() % thread_num_; // int64_t index = paddle::ps::local_random_engine()() % thread_num_;
int64_t index = 0; int64_t index = 0;
readers_[index]->PutInsToChannel(msg); readers_[index]->PutInsToChannel(msg);
return 0; return 0;
} }
// explicit instantiation
template class DatasetImpl<std::vector<MultiSlotType>>;
} // end namespace framework } // end namespace framework
} // end namespace paddle } // end namespace paddle
...@@ -28,8 +28,33 @@ namespace framework { ...@@ -28,8 +28,33 @@ namespace framework {
class Dataset { class Dataset {
public: public:
Dataset(); Dataset() {};
virtual ~Dataset() {} virtual ~Dataset() {};
virtual void SetFileList(const std::vector<std::string>& filelist) = 0;
virtual void SetThreadNum(int thread_num) = 0;
virtual void SetTrainerNum(int trainer_num) = 0;
virtual void SetDataFeedDesc(const std::string& data_feed_desc_str) = 0;
virtual const std::vector<std::string>& GetFileList() = 0;
virtual int GetThreadNum() = 0;
virtual int GetTrainerNum() = 0;
virtual const paddle::framework::DataFeedDesc& GetDataFeedDesc() = 0;
virtual std::vector<std::shared_ptr<paddle::framework::DataFeed>>&
GetReaders() = 0;
virtual void LoadIntoMemory() = 0;
virtual void LocalShuffle() = 0;
virtual void GlobalShuffle() = 0;
virtual void CreateReaders() = 0;
virtual void DestroyReaders() = 0;
protected:
virtual int ReceiveFromClient(int msg_type, int client_id,
const std::string& msg) = 0;
};
template<typename T>
class DatasetImpl : public Dataset {
public:
DatasetImpl();
virtual ~DatasetImpl() {}
virtual void SetFileList(const std::vector<std::string>& filelist); virtual void SetFileList(const std::vector<std::string>& filelist);
virtual void SetThreadNum(int thread_num); virtual void SetThreadNum(int thread_num);
...@@ -43,25 +68,34 @@ class Dataset { ...@@ -43,25 +68,34 @@ class Dataset {
return data_feed_desc_; return data_feed_desc_;
} }
virtual const std::vector<std::shared_ptr<paddle::framework::DataFeed>>& virtual std::vector<std::shared_ptr<paddle::framework::DataFeed>>&
GetReaders(); GetReaders();
virtual void LoadIntoMemory(); virtual void LoadIntoMemory();
virtual void LocalShuffle(); virtual void LocalShuffle();
// todo global shuffle
virtual void GlobalShuffle(); virtual void GlobalShuffle();
virtual void CreateReaders(); virtual void CreateReaders();
virtual void DestroyReaders();
protected: protected:
virtual int ReceiveFromClient(int msg_type, int client_id, virtual int ReceiveFromClient(int msg_type, int client_id,
const std::string& msg); const std::string& msg);
std::vector<std::shared_ptr<paddle::framework::DataFeed>> readers_; std::vector<std::shared_ptr<paddle::framework::DataFeed>> readers_;
std::vector<T> memory_data_;
std::mutex mutex_for_update_memory_data_;
std::vector<std::shared_ptr<paddle::framework::BlockingQueue<T>>> shuffled_ins_vec_;
std::vector<std::shared_ptr<paddle::framework::BlockingQueue<T>>> shuffled_ins_out_vec_;
int thread_num_; int thread_num_;
std::string fs_name_;
std::string fs_ugi_;
paddle::framework::DataFeedDesc data_feed_desc_; paddle::framework::DataFeedDesc data_feed_desc_;
std::vector<std::string> filelist_; std::vector<std::string> filelist_;
int trainer_num_; int trainer_num_;
}; };
// Concrete dataset whose in-memory instance type is
// std::vector<MultiSlotType>. It adds no members or overrides of its own;
// all behaviour comes from DatasetImpl.
class MultiSlotDataset : public DatasetImpl<std::vector<MultiSlotType>> {
public:
MultiSlotDataset() {}
virtual ~MultiSlotDataset() {}
};
} // end namespace framework } // end namespace framework
} // end namespace paddle } // end namespace paddle
...@@ -27,6 +27,7 @@ See the License for the specific language governing permissions and ...@@ -27,6 +27,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/fleet/fleet_wrapper.h" #include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include "paddle/fluid/framework/data_feed.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -35,6 +36,30 @@ const uint32_t MAX_FEASIGN_NUM = 1024 * 100 * 100; ...@@ -35,6 +36,30 @@ const uint32_t MAX_FEASIGN_NUM = 1024 * 100 * 100;
std::shared_ptr<FleetWrapper> FleetWrapper::s_instance_ = NULL; std::shared_ptr<FleetWrapper> FleetWrapper::s_instance_ = NULL;
bool FleetWrapper::is_initialized_ = false; bool FleetWrapper::is_initialized_ = false;
#ifdef PADDLE_WITH_PSLIB
// Writes a MultiSlotType into a pslib archive.
// Fields are written in a fixed order: type, offset, float data, uint64
// data. The matching operator>> reads them back in the same order, so the
// two must be changed together.
template<class AR>
paddle::ps::Archive<AR>& operator << (
paddle::ps::Archive<AR>& ar,
const MultiSlotType& ins) {
ar << ins.GetType();
ar << ins.GetOffset();
ar << ins.GetFloatData();
ar << ins.GetUint64Data();
// Return the archive so calls can be chained.
return ar;
}
// Reads a MultiSlotType from a pslib archive.
// Fields are read directly into the object's members through the Mutable*
// accessors, in the order type, offset, float data, uint64 data. This is the
// same order the matching operator<< writes them in.
template<class AR>
paddle::ps::Archive<AR>& operator >> (
paddle::ps::Archive<AR>& ar,
MultiSlotType& ins) {
ar >> ins.MutableType();
ar >> ins.MutableOffset();
ar >> ins.MutableFloatData();
ar >> ins.MutableUint64Data();
// Return the archive so calls can be chained.
return ar;
}
#endif
#ifdef PADDLE_WITH_PSLIB #ifdef PADDLE_WITH_PSLIB
std::shared_ptr<paddle::distributed::PSlib> FleetWrapper::pslib_ptr_ = NULL; std::shared_ptr<paddle::distributed::PSlib> FleetWrapper::pslib_ptr_ = NULL;
#endif #endif
...@@ -266,5 +291,42 @@ void FleetWrapper::PushSparseVarsWithLabelAsync( ...@@ -266,5 +291,42 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
#endif #endif
} }
// TODO: registe_client2client_msg_handler is still a stub. It ignores both
// arguments, registers nothing, and always returns 0.
int FleetWrapper::registe_client2client_msg_handler(int msg_type,
                                                    MsgHandlerFunc handler) {
  const int status = 0;
  return status;
}
// TODO: send_client2client_msg is still a stub. It ignores all three
// arguments, sends nothing, and always returns 0.
int FleetWrapper::send_client2client_msg(int msg_type, int to_client_id,
                                         const std::string& msg) {
  const int status = 0;
  return status;
}
// Serializes `t` into `str` through a pslib BinaryArchive.
// With PADDLE_WITH_PSLIB defined, `str` is replaced by the archive's bytes.
// Without it, the function only logs a message and leaves `str` untouched.
template <typename T>
void FleetWrapper::Serialize(const T& t, std::string& str) {
#ifdef PADDLE_WITH_PSLIB
  paddle::ps::BinaryArchive archive;
  archive << t;
  std::string encoded(archive.buffer(), archive.length());
  str.swap(encoded);
#else
  VLOG(0) << "FleetWrapper::Serialize do nothing when no pslib";
#endif
}
// Rebuilds `t` from the bytes in `str` through a pslib BinaryArchive.
// With PADDLE_WITH_PSLIB defined, `t` is overwritten with the decoded value.
// Without it, the function only logs a message and leaves `t` untouched.
template <typename T>
void FleetWrapper::Deserialize(T& t, const std::string& str) {
#ifdef PADDLE_WITH_PSLIB
  paddle::ps::BinaryArchive archive;
  // set_read_buffer is passed a non-const char*, hence the const_cast.
  // NOTE(review): presumably the archive only reads from it - confirm
  // against pslib. The meaning of the trailing nullptr argument is not
  // visible here.
  char* raw_bytes = const_cast<char*>(str.c_str());
  archive.set_read_buffer(raw_bytes, str.length(), nullptr);
  t = archive.get<T>();
#else
  VLOG(0) << "FleetWrapper::Deserialize do nothing when no pslib";
#endif
}
template void FleetWrapper::Serialize<std::vector<MultiSlotType>>(
const std::vector<MultiSlotType>&, std::string&);
template void FleetWrapper::Deserialize(
std::vector<MultiSlotType>&, const std::string&);
} // end namespace framework } // end namespace framework
} // end namespace paddle } // end namespace paddle
...@@ -17,7 +17,11 @@ limitations under the License. */ ...@@ -17,7 +17,11 @@ limitations under the License. */
#include <memory> #include <memory>
#ifdef PADDLE_WITH_PSLIB #ifdef PADDLE_WITH_PSLIB
#include <pslib.h> #include <pslib.h>
#include <archive.h>
#endif #endif
#include <random>
#include <atomic>
#include <time.h>
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
...@@ -110,6 +114,16 @@ class FleetWrapper { ...@@ -110,6 +114,16 @@ class FleetWrapper {
uint64_t RunServer(); uint64_t RunServer();
void GatherServers(const std::vector<uint64_t>& host_sign_list, int node_num); void GatherServers(const std::vector<uint64_t>& host_sign_list, int node_num);
typedef std::function<int32_t (int, int, const std::string&)> MsgHandlerFunc;
int registe_client2client_msg_handler(int msg_type, MsgHandlerFunc handler);
int send_client2client_msg(int msg_type, int to_client_id, const std::string& msg);
std::default_random_engine& local_random_engine();
template<typename T>
void Serialize(const T& t, std::string& str);
template<typename T>
void Deserialize(T& t, const std::string& str);
static std::shared_ptr<FleetWrapper> GetInstance() { static std::shared_ptr<FleetWrapper> GetInstance() {
if (NULL == s_instance_) { if (NULL == s_instance_) {
s_instance_.reset(new paddle::framework::FleetWrapper()); s_instance_.reset(new paddle::framework::FleetWrapper());
......
...@@ -41,17 +41,17 @@ namespace paddle { ...@@ -41,17 +41,17 @@ namespace paddle {
namespace pybind { namespace pybind {
void BindDataset(py::module* m) { void BindDataset(py::module* m) {
py::class_<framework::Dataset>(*m, "Dataset") py::class_<framework::MultiSlotDataset>(*m, "MultiSlotDataset")
.def(py::init([]() { .def(py::init([]() {
return std::unique_ptr<framework::Dataset>(new framework::Dataset()); return std::unique_ptr<framework::MultiSlotDataset>(new framework::MultiSlotDataset());
})) }))
.def("set_filelist", &framework::Dataset::SetFileList) .def("set_filelist", &framework::MultiSlotDataset::SetFileList)
.def("set_thread_num", &framework::Dataset::SetThreadNum) .def("set_thread_num", &framework::MultiSlotDataset::SetThreadNum)
.def("set_trainer_num", &framework::Dataset::SetTrainerNum) .def("set_trainer_num", &framework::MultiSlotDataset::SetTrainerNum)
.def("set_data_feed_desc", &framework::Dataset::SetDataFeedDesc) .def("set_data_feed_desc", &framework::MultiSlotDataset::SetDataFeedDesc)
.def("load_into_memory", &framework::Dataset::LoadIntoMemory) .def("load_into_memory", &framework::MultiSlotDataset::LoadIntoMemory)
.def("local_shuffle", &framework::Dataset::LocalShuffle) .def("local_shuffle", &framework::MultiSlotDataset::LocalShuffle)
.def("global_shuffle", &framework::Dataset::GlobalShuffle); .def("global_shuffle", &framework::MultiSlotDataset::GlobalShuffle);
} }
} // end namespace pybind } // end namespace pybind
......
...@@ -30,7 +30,7 @@ from .dataset import * ...@@ -30,7 +30,7 @@ from .dataset import *
from . import async_executor from . import async_executor
from .async_executor import * from .async_executor import *
from . import trainer from . import trainer_desc
from . import inferencer from . import inferencer
from . import io from . import io
...@@ -67,7 +67,7 @@ from . import install_check ...@@ -67,7 +67,7 @@ from . import install_check
Tensor = LoDTensor Tensor = LoDTensor
__all__ = framework.__all__ + executor.__all__ + \ __all__ = framework.__all__ + executor.__all__ + \
trainer.__all__ + inferencer.__all__ + transpiler.__all__ + \ trainer_desc.__all__ + inferencer.__all__ + transpiler.__all__ + \
parallel_executor.__all__ + lod_tensor.__all__ + \ parallel_executor.__all__ + lod_tensor.__all__ + \
data_feed_desc.__all__ + async_executor.__all__ + compiler.__all__ + [ data_feed_desc.__all__ + async_executor.__all__ + compiler.__all__ + [
'io', 'io',
......
...@@ -37,7 +37,7 @@ class DatasetBase(object): ...@@ -37,7 +37,7 @@ class DatasetBase(object):
# to decide whether we need create in memory instance # to decide whether we need create in memory instance
self.proto_desc = data_feed_pb2.DataFeedDesc() self.proto_desc = data_feed_pb2.DataFeedDesc()
self.proto_desc.pipe_command = "cat" self.proto_desc.pipe_command = "cat"
self.dataset = core.Dataset() self.dataset = core.MultiSlotDataset()
self.thread_num = 0 self.thread_num = 0
def set_pipe_command(self, pipe_command): def set_pipe_command(self, pipe_command):
...@@ -109,7 +109,7 @@ class InMemoryDataset(DatasetBase): ...@@ -109,7 +109,7 @@ class InMemoryDataset(DatasetBase):
self.proto_desc.name = "MultiSlotInMemoryDataFeed" self.proto_desc.name = "MultiSlotInMemoryDataFeed"
def load_into_memory(self): def load_into_memory(self):
_prepare_to_run() self._prepare_to_run()
self.dataset.load_into_memory() self.dataset.load_into_memory()
def local_shuffle(self): def local_shuffle(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册