提交 a5b1a0e1 编写于 作者: X xujiaqi01 提交者: dongdaxiang

support multi dataset && add init model && fix bug

上级 3c65cc1b
......@@ -155,7 +155,8 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
}
#ifdef PADDLE_WITH_PSLIB
if (mode == "mpi") {
_pull_dense_thread->stop();
// todo ?
//_pull_dense_thread->stop();
}
#endif
VLOG(3) << "start to run from files in async_executor";
......
......@@ -23,15 +23,11 @@ limitations under the License. */
#include "io/shell.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/platform/timer.h"
namespace paddle {
namespace framework {
std::vector<std::string> DataFeed::filelist_;
size_t DataFeed::file_idx_;
std::mutex DataFeed::mutex_for_pick_file_;
bool DataFeed::finish_set_filelist_;
void DataFeed::AddFeedVar(Variable* var, const std::string& name) {
CheckInit();
for (size_t i = 0; i < use_slots_.size(); ++i) {
......@@ -42,7 +38,7 @@ void DataFeed::AddFeedVar(Variable* var, const std::string& name) {
}
bool DataFeed::SetFileList(const std::vector<std::string>& files) {
std::unique_lock<std::mutex> lock(mutex_for_pick_file_);
std::unique_lock<std::mutex> lock(*mutex_for_pick_file_);
CheckInit();
// Do not set finish_set_filelist_ flag,
// since a user may set file many times after init reader
......@@ -52,9 +48,8 @@ bool DataFeed::SetFileList(const std::vector<std::string>& files) {
return false;
}
*/
PADDLE_ENFORCE(files.size(), "You have set an empty filelist.");
//PADDLE_ENFORCE(files.size(), "You have set an empty filelist.");
filelist_.assign(files.begin(), files.end());
file_idx_ = 0;
finish_set_filelist_ = true;
return true;
......@@ -66,13 +61,17 @@ void DataFeed::SetBatchSize(int batch_size) {
}
bool DataFeed::PickOneFile(std::string* filename) {
std::unique_lock<std::mutex> lock(mutex_for_pick_file_);
if (file_idx_ == filelist_.size()) {
PADDLE_ENFORCE(mutex_for_pick_file_ != nullptr,
"should call SetFileListMutex before PickOneFile");
PADDLE_ENFORCE(file_idx_ != nullptr,
"should call SetFileListIndex before PickOneFile");
std::unique_lock<std::mutex> lock(*mutex_for_pick_file_);
if (*file_idx_ == filelist_.size()) {
VLOG(3) << "DataFeed::PickOneFile no more file to pick";
return false;
}
VLOG(3) << "file_idx_=" << file_idx_;
*filename = filelist_[file_idx_++];
VLOG(3) << "file_idx_=" << *file_idx_;
*filename = filelist_[(*file_idx_)++];
// LOG(ERROR) << "pick file:" << *filename;
return true;
}
......@@ -150,7 +149,11 @@ InMemoryDataFeed<T>::InMemoryDataFeed() {
cur_channel_ = 0;
shuffled_ins_ = std::make_shared<paddle::framework::BlockingQueue<T>>();
shuffled_ins_out_ = std::make_shared<paddle::framework::BlockingQueue<T>>();
fleet_send_batch_size_ = 10000;
fleet_send_batch_size_ = 80000;
memory_data_ = nullptr;
mutex_for_update_memory_data_ = nullptr;
this->file_idx_ = nullptr;
this->mutex_for_pick_file_ = nullptr;
}
template <typename T>
......@@ -192,6 +195,8 @@ int InMemoryDataFeed<T>::Next() {
out_channel->Push(std::move(instance));
}
DataFeed::batch_size_ = index;
VLOG(3) << "batch_size_=" << DataFeed::batch_size_
<< ", thread_id=" << thread_id_;
if (DataFeed::batch_size_ != 0) {
PutToFeedVec(ins_vec);
} else {
......@@ -227,25 +232,22 @@ void InMemoryDataFeed<T>::SetTrainerNum(int trainer_num) {
template <typename T>
void InMemoryDataFeed<T>::PutInsToChannel(const std::string& ins_str) {
T ins;
std::vector<T> ins;
DeserializeIns(&ins, ins_str);
shuffled_ins_->Push(std::move(ins));
shuffled_ins_->Extend(std::move(ins));
VLOG(3) << "PutInsToChannel put ins num=" << ins.size()
<< " to channel, channel size=" << shuffled_ins_->Size()
<< " thread_id=" << thread_id_;
}
template <typename T>
void InMemoryDataFeed<T>::FillMemoryDataToChannel() {
VLOG(3) << "FillMemoryDataToChannel, thread_id=" << thread_id_;
int64_t start = 0;
int64_t end = 0;
int64_t size = memory_data_->size();
VLOG(3) << "memory_data size=" << size;
for (int64_t i = 0; i <= static_cast<int64_t>(thread_id_); ++i) {
int64_t len = size / static_cast<int64_t>(thread_num_) +
(i < (size % static_cast<int64_t>(thread_num_)));
start = end;
end += len;
}
for (int64_t i = start; i < end; ++i) {
auto interval = GetMemoryDataInterval();
VLOG(3) << "memory data size=" << memory_data_->size()
<< ", fill data from [" << interval.first << ", "
<< interval.second << "), thread_id=" << thread_id_;
for (int64_t i = interval.first; i < interval.second; ++i) {
T& t = (*memory_data_)[i];
shuffled_ins_->Push(std::move(t));
}
......@@ -256,14 +258,19 @@ void InMemoryDataFeed<T>::FillChannelToMemoryData() {
VLOG(3) << "FillChannelToMemoryData, thread_id=" << thread_id_;
std::vector<T> local_vec;
std::shared_ptr<paddle::framework::BlockingQueue<T>> channel = nullptr;
std::shared_ptr<paddle::framework::BlockingQueue<T>> pre_channel = nullptr;
if (cur_channel_ == 0) {
channel = shuffled_ins_;
pre_channel = shuffled_ins_out_;
} else {
channel = shuffled_ins_out_;
pre_channel = shuffled_ins_;
}
CHECK(channel != nullptr);
CHECK(pre_channel != nullptr);
CHECK(pre_channel->Size() == 0);
local_vec.resize(channel->Size());
for (int64_t i = 0; i < channel->Size(); ++i) {
for (int64_t i = 0; i < local_vec.size(); ++i) {
channel->Pop(local_vec[i]);
}
VLOG(3) << "local_vec size=" << local_vec.size() <<", thread_id=" << thread_id_;
......@@ -289,20 +296,32 @@ void InMemoryDataFeed<T>::LoadIntoMemory() {
int err_no = 0;
PrivateQueueDataFeed<T>::fp_ =
fs_open_read(filename, &err_no, PrivateQueueDataFeed<T>::pipe_command_);
CHECK(PrivateQueueDataFeed<T>::fp_ != nullptr);
__fsetlocking(&*PrivateQueueDataFeed<T>::fp_, FSETLOCKING_BYCALLER);
T instance;
platform::Timer timeline;
timeline.Start();
while (ParseOneInstanceFromPipe(&instance)) {
local_vec.push_back(instance);
}
timeline.Pause();
VLOG(3) << "LoadIntoMemory() read all lines, file="
<< filename <<", thread_id=" << thread_id_;
<< filename << ", cost time=" << timeline.ElapsedSec()
<< " seconds, thread_id=" << thread_id_;
{
std::lock_guard<std::mutex> lock(*mutex_for_update_memory_data_);
timeline.Start();
memory_data_->insert(memory_data_->end(),
local_vec.begin(), local_vec.end());
std::make_move_iterator(local_vec.begin()),
std::make_move_iterator(local_vec.end()));
timeline.Pause();
VLOG(3) << "LoadIntoMemory() memory_data insert, cost time="
<< timeline.ElapsedSec() << " seconds, thread_id="
<< thread_id_;
}
std::vector<T>().swap(local_vec);
local_vec.clear();
}
std::vector<T>().swap(local_vec);
VLOG(3) << "LoadIntoMemory() end, thread_id=" << thread_id_;
}
......@@ -315,30 +334,66 @@ void InMemoryDataFeed<T>::LocalShuffle() {
template <typename T>
void InMemoryDataFeed<T>::GlobalShuffle() {
VLOG(3) << "GlobalShuffle(), thread_id=" << thread_id_;
VLOG(3) << "GlobalShuffle() begin, thread_id=" << thread_id_;
auto fleet_ptr = FleetWrapper::GetInstance();
std::vector<std::string> send_str_vec(trainer_num_);
for (int64_t i = 0; i < memory_data_->size(); ++i) {
// todo get ins id
std::vector<std::vector<T*>> send_vec(trainer_num_);
for (auto& vec : send_vec) {
vec.reserve(fleet_send_batch_size_);
}
std::vector<std::future<int32_t>> total_status;
auto interval = GetMemoryDataInterval();
VLOG(3) << "global shuffle data from [" << interval.first << ", "
<< interval.second << "), thread_id=" << thread_id_;
for (int64_t i = interval.first; i < interval.second; ++i) {
// if get ins id, can also use hash
// std::string ins_id = memory_data_[i].ins_id;
// todo hash
int64_t random_num = fleet_ptr->LocalRandomEngine()();
int64_t node_id = random_num % trainer_num_;
std::string str;
SerializeIns((*memory_data_)[i], &str);
send_str_vec[node_id] += str;
send_vec[node_id].push_back(&((*memory_data_)[i]));
if (i % fleet_send_batch_size_ == 0 && i != 0) {
for (int j = 0; j < send_str_vec.size(); ++j) {
fleet_ptr->SendClientToClientMsg(0, j, send_str_vec[j]);
send_str_vec[j] = "";
for (int j = 0; j < send_vec.size(); ++j) {
std::string send_str;
SerializeIns(send_vec[j], &send_str);
VLOG(3) << "send str_length=" << send_str.length()
<< ", ins num=" << send_vec[j].size() << " to node_id="
<< j << ", thread_id=" << thread_id_;
auto ret = fleet_ptr->SendClientToClientMsg(0, j, send_str);
VLOG(3) << "end send, thread_id=" << thread_id_;
send_vec[j].clear();
total_status.push_back(std::move(ret));
}
}
}
for (int j = 0; j < send_vec.size(); ++j) {
if (send_vec[j].size() != 0) {
std::string send_str;
SerializeIns(send_vec[j], &send_str);
VLOG(3) << "send str_length=" << send_str.length()
<< " to node_id=" << j << ", thread_id=" << thread_id_;
auto ret = fleet_ptr->SendClientToClientMsg(0, j, send_str);
VLOG(3) << "end send, thread_id=" << thread_id_;
total_status.push_back(std::move(ret));
}
for (int j = 0; j < send_str_vec.size(); ++j) {
if (send_str_vec[j].length() != 0) {
fleet_ptr->SendClientToClientMsg(0, j, send_str_vec[j]);
std::vector<T*>().swap(send_vec[j]);
}
for (auto& t : total_status) {
t.wait();
}
VLOG(3) << "GlobalShuffle() end, thread_id=" << thread_id_;
}
template <typename T>
std::pair<int64_t, int64_t> InMemoryDataFeed<T>::GetMemoryDataInterval() {
int64_t start = 0;
int64_t end = 0;
int64_t size = memory_data_->size();
for (int64_t i = 0; i <= static_cast<int64_t>(thread_id_); ++i) {
int64_t len = size / static_cast<int64_t>(thread_num_) +
(i < (size % static_cast<int64_t>(thread_num_)));
start = end;
end += len;
}
return std::make_pair(start, end);
}
// explicit instantiation
......@@ -519,7 +574,7 @@ bool MultiSlotDataFeed::ParseOneInstanceFromPipe(
const char* str = reader.get();
std::string line = std::string(str);
VLOG(3) << line;
//VLOG(3) << line;
char* endptr = const_cast<char*>(str);
int pos = 0;
for (size_t i = 0; i < use_slots_index_.size(); ++i) {
......@@ -695,7 +750,7 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(
const char* str = reader.get();
std::string line = std::string(str);
VLOG(3) << line;
//VLOG(3) << line;
char* endptr = const_cast<char*>(str);
int pos = 0;
for (size_t i = 0; i < use_slots_index_.size(); ++i) {
......@@ -830,12 +885,14 @@ void MultiSlotInMemoryDataFeed::PutToFeedVec(
// todo serialize ins in global shuffle
void MultiSlotInMemoryDataFeed::SerializeIns(
const std::vector<MultiSlotType>& ins, std::string* str) {
const std::vector<std::vector<MultiSlotType>*>& ins,
std::string* str) {
auto fleet_ptr = FleetWrapper::GetInstance();
fleet_ptr->Serialize(ins, str);
}
// todo deserialize ins in global shuffle
void MultiSlotInMemoryDataFeed::DeserializeIns(std::vector<MultiSlotType>* ins,
void MultiSlotInMemoryDataFeed::DeserializeIns(
std::vector<std::vector<MultiSlotType>>* ins,
const std::string& str) {
auto fleet_ptr = FleetWrapper::GetInstance();
fleet_ptr->Deserialize(ins, str);
......
......@@ -21,6 +21,7 @@ limitations under the License. */
#include <thread> // NOLINT
#include <vector>
#include <sstream>
#include <future>
#include "paddle/fluid/framework/data_feed.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
......@@ -52,7 +53,10 @@ namespace framework {
// }
class DataFeed {
public:
DataFeed() {}
DataFeed() {
mutex_for_pick_file_ = nullptr;
file_idx_ = nullptr;
}
virtual ~DataFeed() {}
virtual void Init(const paddle::framework::DataFeedDesc& data_feed_desc) = 0;
virtual bool CheckFile(const char* filename) {
......@@ -89,6 +93,12 @@ class DataFeed {
virtual void SetThreadNum(int thread_num) { }
// This function will do nothing at default
virtual void SetTrainerNum(int trainer_num) { }
virtual void SetFileListMutex(std::mutex* mutex) {
mutex_for_pick_file_ = mutex;
}
virtual void SetFileListIndex(size_t* file_index) {
file_idx_ = file_index;
}
virtual void LoadIntoMemory() {
PADDLE_THROW("This function(LoadIntoMemory) is not implemented.");
}
......@@ -100,7 +110,9 @@ class DataFeed {
}
// This function will do nothing at default
virtual void FillMemoryDataToChannel() { }
// This function will do nothing at default
virtual void FillChannelToMemoryData() { }
// This function will do nothing at default
virtual void PutInsToChannel(const std::string& ins_str) { }
protected:
......@@ -116,9 +128,9 @@ class DataFeed {
// safe).
virtual bool PickOneFile(std::string* filename);
static std::vector<std::string> filelist_;
static size_t file_idx_;
static std::mutex mutex_for_pick_file_;
std::vector<std::string> filelist_;
size_t* file_idx_;
std::mutex* mutex_for_pick_file_;
// the alias of used slots, and its order is determined by
// data_feed_desc(proto object)
......@@ -141,7 +153,7 @@ class DataFeed {
int batch_size_;
bool finish_init_;
static bool finish_set_filelist_;
bool finish_set_filelist_;
bool finish_start_;
std::string pipe_command_;
};
......@@ -215,8 +227,9 @@ class InMemoryDataFeed : public PrivateQueueDataFeed<T> {
virtual bool ParseOneInstance(T* instance) = 0;
virtual bool ParseOneInstanceFromPipe(T* instance) = 0;
virtual void PutToFeedVec(const T& ins_vec) = 0;
virtual void SerializeIns(const T& ins, std::string* str) = 0;
virtual void DeserializeIns(T* ins, const std::string& str) = 0;
virtual void SerializeIns(const std::vector<T*>& ins, std::string* str) = 0;
virtual void DeserializeIns(std::vector<T>* ins, const std::string& str) = 0;
virtual std::pair<int64_t, int64_t> GetMemoryDataInterval();
int thread_id_;
int thread_num_;
......@@ -284,13 +297,13 @@ class MultiSlotType {
std::string DebugString() {
std::stringstream ss;
ss << "type: " << type_ << "\n";
ss << "offset:\n";
ss << "\ntype: " << type_ << "\n";
ss << "offset: ";
ss << "[";
for (const size_t& i : offset_) {
ss << offset_[i] << ",";
}
ss << "]\ndata:\n[";
ss << "]\ndata: [";
if (type_[0] == 'f') {
for (const float& i : float_feasign_) {
ss << i << ",";
......@@ -356,9 +369,9 @@ class MultiSlotInMemoryDataFeed
virtual bool ParseOneInstance(std::vector<MultiSlotType>* instance);
virtual bool ParseOneInstanceFromPipe(std::vector<MultiSlotType>* instance);
virtual void PutToFeedVec(const std::vector<MultiSlotType>& ins_vec);
virtual void SerializeIns(const std::vector<MultiSlotType>& ins,
virtual void SerializeIns(const std::vector<std::vector<MultiSlotType>*>& ins,
std::string* str);
virtual void DeserializeIns(std::vector<MultiSlotType>* ins,
virtual void DeserializeIns(std::vector<std::vector<MultiSlotType>>* ins,
const std::string& str);
};
......
......@@ -18,6 +18,8 @@
#include "google/protobuf/message.h"
#include "google/protobuf/text_format.h"
#include "paddle/fluid/framework/data_feed_factory.h"
#include "paddle/fluid/platform/timer.h"
#include "paddle/fluid/framework/io/fs.h"
namespace paddle {
namespace framework {
......@@ -25,12 +27,15 @@ namespace framework {
template <typename T>
DatasetImpl<T>::DatasetImpl() {
thread_num_ = 1;
trainer_num_ = 1;
file_idx_ = 0;
}
template <typename T>
void DatasetImpl<T>::SetFileList(const std::vector<std::string>& filelist) {
VLOG(3) << "filelist size: " << filelist.size();
filelist_ = filelist;
file_idx_ = 0;
/*
int file_cnt = filelist_.size();
if (thread_num_ > file_cnt) {
......@@ -45,19 +50,34 @@ void DatasetImpl<T>::SetFileList(const std::vector<std::string>& filelist) {
// not user friendly
template <typename T>
void DatasetImpl<T>::SetThreadNum(int thread_num) {
int file_cnt = filelist_.size();
VLOG(3) << "SetThreadNum thread_num=" << thread_num;
//int file_cnt = filelist_.size();
/*
if (file_cnt != 0 && thread_num > file_cnt) {
VLOG(3) << "DataSet thread num = " << thread_num
<< ", file num = " << file_cnt
<< ". Changing DataSet thread num = " << file_cnt;
thread_num = file_cnt;
}
}*/
thread_num_ = thread_num;
}
template <typename T>
void DatasetImpl<T>::SetTrainerNum(int trainer_num) {
trainer_num_ = trainer_num;
// should inform reader of trainer_num directly
for (auto reader : readers_) {
reader->SetTrainerNum(trainer_num);
}
}
template <typename T>
void DatasetImpl<T>::SetHdfsConfig(const std::string& fs_name,
const std::string& fs_ugi) {
std::string cmd = std::string("hadoop fs");
cmd += " -D fs.default.name=" + fs_name;
cmd += " -D hadoop.job.ugi=" + fs_ugi;
paddle::framework::hdfs_set_command(cmd);
}
template <typename T>
......@@ -75,6 +95,8 @@ DatasetImpl<T>::GetReaders() {
template <typename T>
void DatasetImpl<T>::LoadIntoMemory() {
VLOG(3) << "DatasetImpl<T>::LoadIntoMemory() begin";
platform::Timer timeline;
timeline.Start();
if (readers_.size() == 0) {
CreateReaders();
}
......@@ -86,12 +108,17 @@ void DatasetImpl<T>::LoadIntoMemory() {
for (std::thread& t : load_threads) {
t.join();
}
VLOG(3) << "DatasetImpl<T>::LoadIntoMemory() end";
timeline.Pause();
VLOG(3) << "DatasetImpl<T>::LoadIntoMemory() end"
<< ", memory data size=" << memory_data_.size()
<< ", cost time=" << timeline.ElapsedSec() << " seconds";
}
template <typename T>
void DatasetImpl<T>::LocalShuffle() {
VLOG(3) << "DatasetImpl<T>::LocalShuffle() begin";
platform::Timer timeline;
timeline.Start();
if (readers_.size() == 0) {
CreateReaders();
}
......@@ -107,23 +134,27 @@ void DatasetImpl<T>::LocalShuffle() {
t.join();
}
std::vector<T>().swap(memory_data_);
VLOG(3) << "DatasetImpl<T>::LocalShuffle() end";
timeline.Pause();
VLOG(3) << "DatasetImpl<T>::LocalShuffle() end, cost time="
<< timeline.ElapsedSec() << " seconds";
}
template <typename T>
void DatasetImpl<T>::GlobalShuffle() {
VLOG(3) << "DatasetImpl<T>::GlobalShuffle() begin";
if (readers_.size() == 0) {
CreateReaders();
}
// if it is not InMemory, memory_data_ is empty
std::random_shuffle(memory_data_.begin(), memory_data_.end());
platform::Timer timeline;
timeline.Start();
auto fleet_ptr = FleetWrapper::GetInstance();
VLOG(3) << "RegisterClientToClientMsgHandler";
fleet_ptr->RegisterClientToClientMsgHandler(
0, [this](int msg_type, int client_id, const std::string& msg) -> int {
return this->ReceiveFromClient(msg_type, client_id, msg);
});
if (readers_.size() == 0) {
CreateReaders();
}
// if it is not InMemory, memory_data_ is empty
std::random_shuffle(memory_data_.begin(), memory_data_.end());
VLOG(3) << "start global shuffle threads";
std::vector<std::thread> global_shuffle_threads;
for (int i = 0; i < thread_num_; ++i) {
......@@ -133,15 +164,32 @@ void DatasetImpl<T>::GlobalShuffle() {
for (std::thread& t : global_shuffle_threads) {
t.join();
}
VLOG(3) << "DatasetImpl<T>::GlobalShuffle() end";
std::vector<T>().swap(memory_data_);
timeline.Pause();
VLOG(3) << "DatasetImpl<T>::GlobalShuffle() end, cost time="
<< timeline.ElapsedSec() << " seconds";
}
template <typename T>
void DatasetImpl<T>::CreateReaders() {
VLOG(3) << "Calling CreateReaders()";
CHECK(thread_num_ > 0) << "thread_num should > 0";
int file_cnt = filelist_.size();
int memory_data_size = memory_data_.size();
if (memory_data_size != 0 && thread_num_ > memory_data_size) {
VLOG(3) << "Dataset thread num = " << thread_num_
<< ", memory data size = " << memory_data_size
<< ". Changing Dataset thread num = " << memory_data_size;
thread_num_ = memory_data_size;
} else if (file_cnt != 0 && thread_num_ > file_cnt) {
VLOG(3) << "Dataset thread num = " << thread_num_
<< ", file num = " << file_cnt
<< ". Changing Dataset thread num = " << file_cnt;
thread_num_ = file_cnt;
}
VLOG(3) << "thread_num in Readers: " << thread_num_;
VLOG(3) << "readers size: " << readers_.size();
VLOG(3) << "Filelist size in readers: " << filelist_.size();
if (readers_.size() != 0) {
return;
}
......@@ -154,9 +202,10 @@ void DatasetImpl<T>::CreateReaders() {
readers_.back()->SetThreadId(i);
readers_.back()->SetThreadNum(thread_num_);
readers_.back()->SetTrainerNum(trainer_num_);
readers_.back()->SetFileListMutex(&mutex_for_pick_file_);
readers_.back()->SetFileListIndex(&file_idx_);
readers_.back()->SetFileList(filelist_);
}
VLOG(3) << "Filelist size in readers: " << filelist_.size();
readers_[0]->SetFileList(filelist_);
}
template <typename T>
......@@ -184,9 +233,12 @@ void DatasetImpl<T>::DestroyReaders() {
template <typename T>
int DatasetImpl<T>::ReceiveFromClient(int msg_type, int client_id,
const std::string& msg) {
// todo random
// int64_t index = paddle::ps::local_random_engine()() % thread_num_;
int64_t index = 0;
VLOG(3) << "ReceiveFromClient msg_type=" << msg_type
<< ", client_id=" << client_id << ", msg length="
<< msg.length();
auto fleet_ptr = FleetWrapper::GetInstance();
int64_t index = fleet_ptr->LocalRandomEngine()() % thread_num_;
VLOG(3) << "ramdom index=" << index;
readers_[index]->PutInsToChannel(msg);
return 0;
}
......
......@@ -33,6 +33,8 @@ class Dataset {
virtual void SetFileList(const std::vector<std::string>& filelist) = 0;
virtual void SetThreadNum(int thread_num) = 0;
virtual void SetTrainerNum(int trainer_num) = 0;
virtual void SetHdfsConfig(const std::string& fs_name,
const std::string& fs_ugi) = 0;
virtual void SetDataFeedDesc(const std::string& data_feed_desc_str) = 0;
virtual const std::vector<std::string>& GetFileList() = 0;
virtual int GetThreadNum() = 0;
......@@ -60,6 +62,8 @@ class DatasetImpl : public Dataset {
virtual void SetFileList(const std::vector<std::string>& filelist);
virtual void SetThreadNum(int thread_num);
virtual void SetTrainerNum(int trainer_num);
virtual void SetHdfsConfig(const std::string& fs_name,
const std::string& fs_ugi);
virtual void SetDataFeedDesc(const std::string& data_feed_desc_str);
virtual const std::vector<std::string>& GetFileList() { return filelist_; }
......@@ -85,8 +89,10 @@ class DatasetImpl : public Dataset {
std::mutex mutex_for_update_memory_data_;
int thread_num_;
paddle::framework::DataFeedDesc data_feed_desc_;
std::vector<std::string> filelist_;
int trainer_num_;
std::vector<std::string> filelist_;
size_t file_idx_;
std::mutex mutex_for_pick_file_;
};
class MultiSlotDataset : public DatasetImpl<std::vector<MultiSlotType>> {
......
......@@ -26,12 +26,14 @@ void DistMultiTrainer::Initialize(const TrainerDesc& trainer_desc,
Dataset* dataset) {
thread_num_ = trainer_desc.thread_num();
SetDataset(dataset);
workers_.resize(thread_num_);
dataset->CreateReaders();
const std::vector<std::shared_ptr<paddle::framework::DataFeed>> readers =
dataset->GetReaders();
thread_num_ = readers.size();
workers_.resize(thread_num_);
for (int i = 0; i < thread_num_; ++i) {
workers_[i] = DeviceWorkerFactory::CreateDeviceWorker(
trainer_desc.device_worker_name());
......
......@@ -29,6 +29,7 @@ limitations under the License. */
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include <utility>
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/scope.h"
namespace paddle {
namespace framework {
......@@ -203,6 +204,60 @@ void FleetWrapper::PullDenseVarsSync(
#endif
}
void FleetWrapper::PushDenseParamSync(
const ProgramDesc& program, const uint64_t table_id,
const std::vector<std::string>& var_names) {
#ifdef PADDLE_WITH_PSLIB
paddle::framework::Scope scope;
auto& block = program.Block(0);
for (auto& var : block.AllVars()) {
if (var->Persistable()) {
auto* ptr = scope.Var(var->Name());
InitializeVariable(ptr, var->GetType());
} else {
auto* ptr = scope.Var(var->Name());
InitializeVariable(ptr, var->GetType());
}
}
auto place = platform::CPUPlace();
std::vector<paddle::ps::Region> regions;
for (auto& t : var_names) {
Variable* var = scope.FindVar(t);
CHECK(var != nullptr) << "var[" << t << "] not found";
LoDTensor* tensor = var->GetMutable<LoDTensor>();
std::vector<int64_t> dim;
for (auto& var : block.AllVars()) {
if (var->Name() == t) {
dim = var->GetShape();
break;
}
}
int cnt = 1;
for (auto& i: dim) {
cnt *= i;
}
DDim d(std::vector<int64_t>{cnt}.data(), 1);
float* g = tensor->mutable_data<float>(d, place);
CHECK(g != nullptr) << "var[" << t << "] value not initialized";
float init_range = 0.2;
int rown = tensor->dims()[0];
init_range /= sqrt(rown);
std::normal_distribution<float> ndistr(0.0, 1.0);
for (auto i = 0u; i < tensor->numel(); ++i) {
g[i] = ndistr(LocalRandomEngine()) * init_range;
}
paddle::ps::Region reg(g, tensor->numel());
regions.emplace_back(std::move(reg));
auto push_status = pslib_ptr_->_worker_ptr->push_dense_param(
regions.data(), regions.size(), table_id);
push_status.wait();
auto status = push_status.get();
CHECK(status == 0) << "push dense param failed, status["
<< status << "]";
}
#endif
}
void FleetWrapper::PushDenseVarsSync(
Scope* scope, const uint64_t table_id,
const std::vector<std::string>& var_names) {}
......@@ -269,6 +324,8 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
continue;
}
LOG(WARNING) << "going to memcpy";
CHECK(fea_idx < (*push_values).size());
CHECK(fea_idx < fea_labels.size());
memcpy((*push_values)[fea_idx].data() + offset, g,
sizeof(float) * emb_dim);
LOG(WARNING) << "show";
......@@ -294,13 +351,13 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
#endif
}
int FleetWrapper::RegisterClientToClientMsgHandler(int msg_type,
MsgHandlerFunc handler) {
int FleetWrapper::RegisterClientToClientMsgHandler(
int msg_type, MsgHandlerFunc handler) {
#ifdef PADDLE_WITH_PSLIB
VLOG(3) << "calling FleetWrapper::RegisterClientToClientMsgHandler";
VLOG(3) << "pslib_ptr_=" << pslib_ptr_;
VLOG(3) << "_worker_ptr=" << pslib_ptr_->_worker_ptr;
pslib_ptr_->_worker_ptr->registe_client2client_msg_handler(msg_type, handler);
return pslib_ptr_->_worker_ptr->registe_client2client_msg_handler(msg_type, handler);
#else
VLOG(0) << "FleetWrapper::RegisterClientToClientMsgHandler"
<< " does nothing when no pslib";
......@@ -308,15 +365,15 @@ int FleetWrapper::RegisterClientToClientMsgHandler(int msg_type,
return 0;
}
int FleetWrapper::SendClientToClientMsg(int msg_type, int to_client_id,
const std::string& msg) {
std::future<int32_t> FleetWrapper::SendClientToClientMsg(
int msg_type, int to_client_id, const std::string& msg) {
#ifdef PADDLE_WITH_PSLIB
pslib_ptr_->_worker_ptr->send_client2client_msg(msg_type, to_client_id, msg);
return pslib_ptr_->_worker_ptr->send_client2client_msg(msg_type, to_client_id, msg);
#else
VLOG(0) << "FleetWrapper::SendClientToClientMsg"
<< " does nothing when no pslib";
#endif
return 0;
return std::future<int32_t>();
}
std::default_random_engine& FleetWrapper::LocalRandomEngine() {
......@@ -336,10 +393,12 @@ std::default_random_engine& FleetWrapper::LocalRandomEngine() {
}
template <typename T>
void FleetWrapper::Serialize(const T& t, std::string* str) {
void FleetWrapper::Serialize(const std::vector<T*>& t, std::string* str) {
#ifdef PADDLE_WITH_PSLIB
paddle::ps::BinaryArchive ar;
ar << t;
for (size_t i = 0; i < t.size(); ++i) {
ar << *(t[i]);
}
*str = std::string(ar.buffer(), ar.length());
#else
VLOG(0) << "FleetWrapper::Serialize does nothing when no pslib";
......@@ -347,20 +406,30 @@ void FleetWrapper::Serialize(const T& t, std::string* str) {
}
template <typename T>
void FleetWrapper::Deserialize(T* t, const std::string& str) {
void FleetWrapper::Deserialize(std::vector<T>* t, const std::string& str) {
#ifdef PADDLE_WITH_PSLIB
if (str.length() == 0) {
return;
}
paddle::ps::BinaryArchive ar;
ar.set_read_buffer(const_cast<char*>(str.c_str()), str.length(), nullptr);
*t = ar.get<T>();
if (ar.cursor() == ar.finish()) {
return;
}
while (ar.cursor() < ar.finish()) {
t->push_back(ar.get<T>());
}
CHECK(ar.cursor() == ar.finish());
VLOG(3) << "Deserialize size " << t->size();
#else
VLOG(0) << "FleetWrapper::Deserialize does nothing when no pslib";
#endif
}
template void FleetWrapper::Serialize<std::vector<MultiSlotType>>(
const std::vector<MultiSlotType>&, std::string*);
template void FleetWrapper::Deserialize(std::vector<MultiSlotType>*,
const std::string&);
const std::vector<std::vector<MultiSlotType>*>&, std::string*);
template void FleetWrapper::Deserialize<std::vector<MultiSlotType>>(
std::vector<std::vector<MultiSlotType>>*, const std::string&);
} // end namespace framework
} // end namespace paddle
......@@ -27,6 +27,7 @@ limitations under the License. */
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
#include "paddle/fluid/framework/program_desc.h"
namespace paddle {
namespace framework {
......@@ -71,6 +72,10 @@ class FleetWrapper {
const std::vector<std::string>& var_names,
std::vector<::std::future<int32_t>>* pull_dense_status);
void PushDenseParamSync(
const ProgramDesc& program, const uint64_t table_id,
const std::vector<std::string>& var_names);
// Push dense variables to server in async mode
// Param<in>: scope, table_id, var_names,
// Param<out>: push_sparse_status
......@@ -119,16 +124,15 @@ class FleetWrapper {
typedef std::function<int32_t (int, int, const std::string&)> MsgHandlerFunc;
int RegisterClientToClientMsgHandler(int msg_type, MsgHandlerFunc handler);
int SendClientToClientMsg(int msg_type,
std::future<int32_t> SendClientToClientMsg(int msg_type,
int to_client_id,
const std::string& msg);
std::default_random_engine& LocalRandomEngine();
template <typename T>
void Serialize(const T& t, std::string* str);
void Serialize(const std::vector<T*>& t, std::string* str);
template <typename T>
void Deserialize(T* t, const std::string& str);
void Deserialize(std::vector<T>* t, const std::string& str);
static std::shared_ptr<FleetWrapper> GetInstance() {
if (NULL == s_instance_) {
s_instance_.reset(new paddle::framework::FleetWrapper());
......
......@@ -26,13 +26,15 @@ void MultiTrainer::Initialize(const TrainerDesc& trainer_desc,
thread_num_ = trainer_desc.thread_num();
SetDataset(dataset);
// get filelist from trainer_desc here
workers_.resize(thread_num_);
VLOG(3) << "worker thread num: " << thread_num_;
dataset->CreateReaders();
VLOG(3) << "readers created";
const std::vector<std::shared_ptr<paddle::framework::DataFeed>> readers =
dataset->GetReaders();
VLOG(3) << "readers num: " << readers.size();
// change thread num to readers num
thread_num_ = readers.size();
VLOG(3) << "worker thread num: " << thread_num_;
workers_.resize(thread_num_);
for (int i = 0; i < thread_num_; ++i) {
workers_[i] = DeviceWorkerFactory::CreateDeviceWorker(
trainer_desc.device_worker_name());
......
......@@ -49,7 +49,7 @@ void BindAsyncExecutor(py::module* m) {
new framework::AsyncExecutor(scope, place));
}))
.def("run_from_files", &framework::AsyncExecutor::RunFromFile)
.def("run_from_dataset", &framework::AsyncExecutor::RunFromDataset)
//.def("run_from_dataset", &framework::AsyncExecutor::RunFromDataset)
.def("init_server", &framework::AsyncExecutor::InitServer)
.def("init_worker", &framework::AsyncExecutor::InitWorker)
.def("start_server", &framework::AsyncExecutor::StartServer)
......
......@@ -50,6 +50,7 @@ void BindDataset(py::module* m) {
.def("set_filelist", &framework::Dataset::SetFileList)
.def("set_thread_num", &framework::Dataset::SetThreadNum)
.def("set_trainer_num", &framework::Dataset::SetTrainerNum)
.def("set_hdfs_config", &framework::Dataset::SetHdfsConfig)
.def("set_data_feed_desc", &framework::Dataset::SetDataFeedDesc)
.def("load_into_memory", &framework::Dataset::LoadIntoMemory)
.def("local_shuffle", &framework::Dataset::LocalShuffle)
......
......@@ -47,6 +47,7 @@ void BindFleetWrapper(py::module* m) {
.def("init_server", &framework::FleetWrapper::InitServer)
.def("run_server", &framework::FleetWrapper::RunServer)
.def("init_worker", &framework::FleetWrapper::InitWorker)
.def("init_model", &framework::FleetWrapper::PushDenseParamSync)
.def("stop_server", &framework::FleetWrapper::StopServer)
.def("gather_servers", &framework::FleetWrapper::GatherServers);
} // end FleetWrapper
......
......@@ -86,6 +86,9 @@ class DatasetBase(object):
"Currently, fluid.dataset only supports dtype=float32 and dtype=int64"
)
def set_hdfs_config(self, fs_name, fs_ugi):
self.dataset.set_hdfs_config(fs_name, fs_ugi)
def _prepare_to_run(self):
self.dataset.set_data_feed_desc(self.desc())
......@@ -115,11 +118,15 @@ class InMemoryDataset(DatasetBase):
def local_shuffle(self):
self.dataset.local_shuffle()
def global_shuffle(self):
from .distributed import ps_instance
instance = ps_instance.PaddlePSInstance(1, 2)
self.dataset.set_trainer_num(instance.get_worker_num())
def global_shuffle(self, fleet=None):
trainer_num = 1
if fleet is not None:
fleet.fleet_instance.role_maker_.barrier_worker()
trainer_num = fleet.worker_num()
self.dataset.set_trainer_num(trainer_num)
self.dataset.global_shuffle()
if fleet is not None:
fleet.fleet_instance.role_maker_.barrier_worker()
class QueueDataset(DatasetBase):
......@@ -130,5 +137,5 @@ class QueueDataset(DatasetBase):
def local_shuffle(self):
pass
def global_shuffle(self):
def global_shuffle(self, fleet=None):
pass
......@@ -170,7 +170,7 @@ class MPISymetricRoleMaker(MPIRoleMaker):
"""
if self._check_role_generation():
if self.is_worker():
return self.get_size()
return self.get_size() / 2;
return 0
def server_num(self):
......@@ -179,7 +179,7 @@ class MPISymetricRoleMaker(MPIRoleMaker):
"""
if self._check_role_generation():
if self.is_server():
return self.get_size()
return self.get_size() / 2;
return 0
def worker_index(self):
......
......@@ -122,7 +122,7 @@ class Fleet(object):
print("You should run DistributedOptimizer.minimize() first")
sys.exit(-1)
def init_worker(self):
def init_worker(self, program):
"""
init_worker(): will be called by user. When a user knows current process is_server(), he/she
should call init_worker() to initialize global information about worker and connect
......@@ -143,6 +143,19 @@ class Fleet(object):
self.role_maker_.get_rank())
self.role_maker_.barrier_all()
self.role_maker_.barrier_worker()
if self.role_maker_.is_first_worker():
tables = self._dist_desc.trainer_param.dense_table._values
for i in range(0, len(tables)):
table = tables[i];
var_name_list = []
for i in range(0, len(table.dense_variable_name)):
var_name_list.append(table.dense_variable_name[i])
#print "table id ", table.table_id
#print "var_name_list ", var_name_list
self._fleet_ptr.init_model(program.desc,
int(table.table_id),
var_name_list)
self.role_maker_.barrier_worker()
else:
print("You should run DistributedOptimizer.minimize() first")
sys.exit(-1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册