Commit a5b1a0e1 authored by xujiaqi01, committed by dongdaxiang

support multi dataset && add init model && fix bug

Parent 3c65cc1b
@@ -155,7 +155,8 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
   }
 #ifdef PADDLE_WITH_PSLIB
   if (mode == "mpi") {
-    _pull_dense_thread->stop();
+    // todo ?
+    //_pull_dense_thread->stop();
   }
 #endif
   VLOG(3) << "start to run from files in async_executor";
......
@@ -23,15 +23,11 @@ limitations under the License. */
 #include "io/shell.h"
 #include "paddle/fluid/framework/feed_fetch_method.h"
 #include "paddle/fluid/framework/feed_fetch_type.h"
+#include "paddle/fluid/platform/timer.h"
 
 namespace paddle {
 namespace framework {
 
-std::vector<std::string> DataFeed::filelist_;
-size_t DataFeed::file_idx_;
-std::mutex DataFeed::mutex_for_pick_file_;
-bool DataFeed::finish_set_filelist_;
-
 void DataFeed::AddFeedVar(Variable* var, const std::string& name) {
   CheckInit();
   for (size_t i = 0; i < use_slots_.size(); ++i) {
@@ -42,7 +38,7 @@ void DataFeed::AddFeedVar(Variable* var, const std::string& name) {
 }
 
 bool DataFeed::SetFileList(const std::vector<std::string>& files) {
-  std::unique_lock<std::mutex> lock(mutex_for_pick_file_);
+  std::unique_lock<std::mutex> lock(*mutex_for_pick_file_);
   CheckInit();
   // Do not set finish_set_filelist_ flag,
   // since a user may set file many times after init reader
@@ -52,9 +48,8 @@ bool DataFeed::SetFileList(const std::vector<std::string>& files) {
     return false;
   }
   */
-  PADDLE_ENFORCE(files.size(), "You have set an empty filelist.");
+  //PADDLE_ENFORCE(files.size(), "You have set an empty filelist.");
   filelist_.assign(files.begin(), files.end());
-  file_idx_ = 0;
 
   finish_set_filelist_ = true;
   return true;
@@ -66,13 +61,17 @@ void DataFeed::SetBatchSize(int batch_size) {
 }
 
 bool DataFeed::PickOneFile(std::string* filename) {
-  std::unique_lock<std::mutex> lock(mutex_for_pick_file_);
-  if (file_idx_ == filelist_.size()) {
+  PADDLE_ENFORCE(mutex_for_pick_file_ != nullptr,
+                 "should call SetFileListMutex before PickOneFile");
+  PADDLE_ENFORCE(file_idx_ != nullptr,
+                 "should call SetFileListIndex before PickOneFile");
+  std::unique_lock<std::mutex> lock(*mutex_for_pick_file_);
+  if (*file_idx_ == filelist_.size()) {
     VLOG(3) << "DataFeed::PickOneFile no more file to pick";
     return false;
   }
-  VLOG(3) << "file_idx_=" << file_idx_;
-  *filename = filelist_[file_idx_++];
+  VLOG(3) << "file_idx_=" << *file_idx_;
+  *filename = filelist_[(*file_idx_)++];
   // LOG(ERROR) << "pick file:" << *filename;
   return true;
 }
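With filelist_, file_idx_, and the mutex no longer static, every reader in a Dataset now shares one file cursor owned by the Dataset (wired up in CreateReaders below). A minimal Python sketch of the same pick-one-file protocol, with hypothetical names:

    import threading

    class SharedFileList(object):
        """Stands in for the Dataset-owned filelist_ / *file_idx_ / mutex."""
        def __init__(self, files):
            self.files = list(files)
            self.idx = 0                  # shared cursor, like *file_idx_
            self.lock = threading.Lock()  # like *mutex_for_pick_file_

        def pick_one_file(self):
            with self.lock:
                if self.idx == len(self.files):
                    return None           # no more file to pick
                filename = self.files[self.idx]
                self.idx += 1
                return filename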
@@ -150,7 +149,11 @@ InMemoryDataFeed<T>::InMemoryDataFeed() {
   cur_channel_ = 0;
   shuffled_ins_ = std::make_shared<paddle::framework::BlockingQueue<T>>();
   shuffled_ins_out_ = std::make_shared<paddle::framework::BlockingQueue<T>>();
-  fleet_send_batch_size_ = 10000;
+  fleet_send_batch_size_ = 80000;
+  memory_data_ = nullptr;
+  mutex_for_update_memory_data_ = nullptr;
+  this->file_idx_ = nullptr;
+  this->mutex_for_pick_file_ = nullptr;
 }
 
 template <typename T>
@@ -192,6 +195,8 @@ int InMemoryDataFeed<T>::Next() {
     out_channel->Push(std::move(instance));
   }
   DataFeed::batch_size_ = index;
+  VLOG(3) << "batch_size_=" << DataFeed::batch_size_
+          << ", thread_id=" << thread_id_;
   if (DataFeed::batch_size_ != 0) {
     PutToFeedVec(ins_vec);
   } else {
@@ -227,25 +232,22 @@ void InMemoryDataFeed<T>::SetTrainerNum(int trainer_num) {
 
 template <typename T>
 void InMemoryDataFeed<T>::PutInsToChannel(const std::string& ins_str) {
-  T ins;
+  std::vector<T> ins;
   DeserializeIns(&ins, ins_str);
-  shuffled_ins_->Push(std::move(ins));
+  shuffled_ins_->Extend(std::move(ins));
+  VLOG(3) << "PutInsToChannel put ins num=" << ins.size()
+          << " to channel, channel size=" << shuffled_ins_->Size()
+          << " thread_id=" << thread_id_;
 }
 
 template <typename T>
 void InMemoryDataFeed<T>::FillMemoryDataToChannel() {
   VLOG(3) << "FillMemoryDataToChannel, thread_id=" << thread_id_;
-  int64_t start = 0;
-  int64_t end = 0;
-  int64_t size = memory_data_->size();
-  VLOG(3) << "memory_data size=" << size;
-  for (int64_t i = 0; i <= static_cast<int64_t>(thread_id_); ++i) {
-    int64_t len = size / static_cast<int64_t>(thread_num_) +
-                  (i < (size % static_cast<int64_t>(thread_num_)));
-    start = end;
-    end += len;
-  }
-  for (int64_t i = start; i < end; ++i) {
+  auto interval = GetMemoryDataInterval();
+  VLOG(3) << "memory data size=" << memory_data_->size()
+          << ", fill data from [" << interval.first << ", "
+          << interval.second << "), thread_id=" << thread_id_;
+  for (int64_t i = interval.first; i < interval.second; ++i) {
     T& t = (*memory_data_)[i];
     shuffled_ins_->Push(std::move(t));
   }
@@ -256,14 +258,19 @@ void InMemoryDataFeed<T>::FillChannelToMemoryData() {
   VLOG(3) << "FillChannelToMemoryData, thread_id=" << thread_id_;
   std::vector<T> local_vec;
   std::shared_ptr<paddle::framework::BlockingQueue<T>> channel = nullptr;
+  std::shared_ptr<paddle::framework::BlockingQueue<T>> pre_channel = nullptr;
   if (cur_channel_ == 0) {
     channel = shuffled_ins_;
+    pre_channel = shuffled_ins_out_;
   } else {
     channel = shuffled_ins_out_;
+    pre_channel = shuffled_ins_;
   }
   CHECK(channel != nullptr);
+  CHECK(pre_channel != nullptr);
+  CHECK(pre_channel->Size() == 0);
   local_vec.resize(channel->Size());
-  for (int64_t i = 0; i < channel->Size(); ++i) {
+  for (int64_t i = 0; i < local_vec.size(); ++i) {
     channel->Pop(local_vec[i]);
   }
   VLOG(3) << "local_vec size=" << local_vec.size() <<", thread_id=" << thread_id_;
@@ -289,20 +296,32 @@ void InMemoryDataFeed<T>::LoadIntoMemory() {
     int err_no = 0;
     PrivateQueueDataFeed<T>::fp_ =
         fs_open_read(filename, &err_no, PrivateQueueDataFeed<T>::pipe_command_);
+    CHECK(PrivateQueueDataFeed<T>::fp_ != nullptr);
     __fsetlocking(&*PrivateQueueDataFeed<T>::fp_, FSETLOCKING_BYCALLER);
     T instance;
+    platform::Timer timeline;
+    timeline.Start();
     while (ParseOneInstanceFromPipe(&instance)) {
      local_vec.push_back(instance);
     }
+    timeline.Pause();
     VLOG(3) << "LoadIntoMemory() read all lines, file="
-            << filename <<", thread_id=" << thread_id_;
+            << filename << ", cost time=" << timeline.ElapsedSec()
+            << " seconds, thread_id=" << thread_id_;
     {
       std::lock_guard<std::mutex> lock(*mutex_for_update_memory_data_);
+      timeline.Start();
       memory_data_->insert(memory_data_->end(),
-                           local_vec.begin(), local_vec.end());
+                           std::make_move_iterator(local_vec.begin()),
+                           std::make_move_iterator(local_vec.end()));
+      timeline.Pause();
+      VLOG(3) << "LoadIntoMemory() memory_data insert, cost time="
+              << timeline.ElapsedSec() << " seconds, thread_id="
+              << thread_id_;
     }
-    std::vector<T>().swap(local_vec);
+    local_vec.clear();
   }
+  std::vector<T>().swap(local_vec);
   VLOG(3) << "LoadIntoMemory() end, thread_id=" << thread_id_;
 }
@@ -315,30 +334,66 @@ void InMemoryDataFeed<T>::LocalShuffle() {
 
 template <typename T>
 void InMemoryDataFeed<T>::GlobalShuffle() {
-  VLOG(3) << "GlobalShuffle(), thread_id=" << thread_id_;
+  VLOG(3) << "GlobalShuffle() begin, thread_id=" << thread_id_;
   auto fleet_ptr = FleetWrapper::GetInstance();
-  std::vector<std::string> send_str_vec(trainer_num_);
-  for (int64_t i = 0; i < memory_data_->size(); ++i) {
-    // todo get ins id
+  std::vector<std::vector<T*>> send_vec(trainer_num_);
+  for (auto& vec : send_vec) {
+    vec.reserve(fleet_send_batch_size_);
+  }
+  std::vector<std::future<int32_t>> total_status;
+  auto interval = GetMemoryDataInterval();
+  VLOG(3) << "global shuffle data from [" << interval.first << ", "
+          << interval.second << "), thread_id=" << thread_id_;
+  for (int64_t i = interval.first; i < interval.second; ++i) {
+    // if get ins id, can also use hash
     // std::string ins_id = memory_data_[i].ins_id;
-    // todo hash
     int64_t random_num = fleet_ptr->LocalRandomEngine()();
     int64_t node_id = random_num % trainer_num_;
-    std::string str;
-    SerializeIns((*memory_data_)[i], &str);
-    send_str_vec[node_id] += str;
+    send_vec[node_id].push_back(&((*memory_data_)[i]));
    if (i % fleet_send_batch_size_ == 0 && i != 0) {
-      for (int j = 0; j < send_str_vec.size(); ++j) {
-        fleet_ptr->SendClientToClientMsg(0, j, send_str_vec[j]);
-        send_str_vec[j] = "";
+      for (int j = 0; j < send_vec.size(); ++j) {
+        std::string send_str;
+        SerializeIns(send_vec[j], &send_str);
+        VLOG(3) << "send str_length=" << send_str.length()
+                << ", ins num=" << send_vec[j].size() << " to node_id="
+                << j << ", thread_id=" << thread_id_;
+        auto ret = fleet_ptr->SendClientToClientMsg(0, j, send_str);
+        VLOG(3) << "end send, thread_id=" << thread_id_;
+        send_vec[j].clear();
+        total_status.push_back(std::move(ret));
       }
     }
   }
-  for (int j = 0; j < send_str_vec.size(); ++j) {
-    if (send_str_vec[j].length() != 0) {
-      fleet_ptr->SendClientToClientMsg(0, j, send_str_vec[j]);
+  for (int j = 0; j < send_vec.size(); ++j) {
+    if (send_vec[j].size() != 0) {
+      std::string send_str;
+      SerializeIns(send_vec[j], &send_str);
+      VLOG(3) << "send str_length=" << send_str.length()
+              << " to node_id=" << j << ", thread_id=" << thread_id_;
+      auto ret = fleet_ptr->SendClientToClientMsg(0, j, send_str);
+      VLOG(3) << "end send, thread_id=" << thread_id_;
+      total_status.push_back(std::move(ret));
     }
+    std::vector<T*>().swap(send_vec[j]);
   }
+  for (auto& t : total_status) {
+    t.wait();
+  }
+  VLOG(3) << "GlobalShuffle() end, thread_id=" << thread_id_;
+}
+
+template <typename T>
+std::pair<int64_t, int64_t> InMemoryDataFeed<T>::GetMemoryDataInterval() {
+  int64_t start = 0;
+  int64_t end = 0;
+  int64_t size = memory_data_->size();
+  for (int64_t i = 0; i <= static_cast<int64_t>(thread_id_); ++i) {
+    int64_t len = size / static_cast<int64_t>(thread_num_) +
+                  (i < (size % static_cast<int64_t>(thread_num_)));
+    start = end;
+    end += len;
+  }
+  return std::make_pair(start, end);
 }
 
 // explicit instantiation
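GetMemoryDataInterval deals thread i a contiguous slice of memory_data_, with the first size % thread_num_ slices one element longer. The arithmetic, as a quick hypothetical Python check:

    def memory_data_interval(thread_id, thread_num, size):
        # thread i gets size // thread_num instances, plus one more
        # while i < size % thread_num
        start = end = 0
        for i in range(thread_id + 1):
            length = size // thread_num + (1 if i < size % thread_num else 0)
            start, end = end, end + length
        return (start, end)

    # 10 instances over 3 threads -> (0, 4), (4, 7), (7, 10)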
@@ -519,7 +574,7 @@ bool MultiSlotDataFeed::ParseOneInstanceFromPipe(
     const char* str = reader.get();
     std::string line = std::string(str);
-    VLOG(3) << line;
+    //VLOG(3) << line;
     char* endptr = const_cast<char*>(str);
     int pos = 0;
     for (size_t i = 0; i < use_slots_index_.size(); ++i) {
@@ -695,7 +750,7 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(
     const char* str = reader.get();
     std::string line = std::string(str);
-    VLOG(3) << line;
+    //VLOG(3) << line;
     char* endptr = const_cast<char*>(str);
     int pos = 0;
     for (size_t i = 0; i < use_slots_index_.size(); ++i) {
@@ -830,13 +885,15 @@ void MultiSlotInMemoryDataFeed::PutToFeedVec(
 
 // todo serialize ins in global shuffle
 void MultiSlotInMemoryDataFeed::SerializeIns(
-    const std::vector<MultiSlotType>& ins, std::string* str) {
+    const std::vector<std::vector<MultiSlotType>*>& ins,
+    std::string* str) {
   auto fleet_ptr = FleetWrapper::GetInstance();
   fleet_ptr->Serialize(ins, str);
 }
 
 // todo deserialize ins in global shuffle
-void MultiSlotInMemoryDataFeed::DeserializeIns(std::vector<MultiSlotType>* ins,
-                                               const std::string& str) {
+void MultiSlotInMemoryDataFeed::DeserializeIns(
+    std::vector<std::vector<MultiSlotType>>* ins,
+    const std::string& str) {
   auto fleet_ptr = FleetWrapper::GetInstance();
   fleet_ptr->Deserialize(ins, str);
 }
......
@@ -21,6 +21,7 @@ limitations under the License. */
 #include <thread>  // NOLINT
 #include <vector>
 #include <sstream>
+#include <future>
 
 #include "paddle/fluid/framework/data_feed.pb.h"
 #include "paddle/fluid/framework/lod_tensor.h"
@@ -52,7 +53,10 @@ namespace framework {
 // }
 class DataFeed {
  public:
-  DataFeed() {}
+  DataFeed() {
+    mutex_for_pick_file_ = nullptr;
+    file_idx_ = nullptr;
+  }
   virtual ~DataFeed() {}
   virtual void Init(const paddle::framework::DataFeedDesc& data_feed_desc) = 0;
   virtual bool CheckFile(const char* filename) {
@@ -89,6 +93,12 @@ class DataFeed {
   virtual void SetThreadNum(int thread_num) { }
   // This function will do nothing at default
   virtual void SetTrainerNum(int trainer_num) { }
+  virtual void SetFileListMutex(std::mutex* mutex) {
+    mutex_for_pick_file_ = mutex;
+  }
+  virtual void SetFileListIndex(size_t* file_index) {
+    file_idx_ = file_index;
+  }
   virtual void LoadIntoMemory() {
     PADDLE_THROW("This function(LoadIntoMemory) is not implemented.");
   }
@@ -100,7 +110,9 @@ class DataFeed {
   }
   // This function will do nothing at default
   virtual void FillMemoryDataToChannel() { }
+  // This function will do nothing at default
   virtual void FillChannelToMemoryData() { }
+  // This function will do nothing at default
   virtual void PutInsToChannel(const std::string& ins_str) { }
 
  protected:
@@ -116,9 +128,9 @@ class DataFeed {
   // safe).
   virtual bool PickOneFile(std::string* filename);
 
-  static std::vector<std::string> filelist_;
-  static size_t file_idx_;
-  static std::mutex mutex_for_pick_file_;
+  std::vector<std::string> filelist_;
+  size_t* file_idx_;
+  std::mutex* mutex_for_pick_file_;
 
   // the alias of used slots, and its order is determined by
   // data_feed_desc(proto object)
@@ -141,7 +153,7 @@ class DataFeed {
   int batch_size_;
 
   bool finish_init_;
-  static bool finish_set_filelist_;
+  bool finish_set_filelist_;
   bool finish_start_;
   std::string pipe_command_;
 };
@@ -215,8 +227,9 @@ class InMemoryDataFeed : public PrivateQueueDataFeed<T> {
   virtual bool ParseOneInstance(T* instance) = 0;
   virtual bool ParseOneInstanceFromPipe(T* instance) = 0;
   virtual void PutToFeedVec(const T& ins_vec) = 0;
-  virtual void SerializeIns(const T& ins, std::string* str) = 0;
-  virtual void DeserializeIns(T* ins, const std::string& str) = 0;
+  virtual void SerializeIns(const std::vector<T*>& ins, std::string* str) = 0;
+  virtual void DeserializeIns(std::vector<T>* ins, const std::string& str) = 0;
+  virtual std::pair<int64_t, int64_t> GetMemoryDataInterval();
 
   int thread_id_;
   int thread_num_;
@@ -284,13 +297,13 @@ class MultiSlotType {
   std::string DebugString() {
     std::stringstream ss;
-    ss << "type: " << type_ << "\n";
-    ss << "offset:\n";
+    ss << "\ntype: " << type_ << "\n";
+    ss << "offset: ";
     ss << "[";
     for (const size_t& i : offset_) {
       ss << offset_[i] << ",";
     }
-    ss << "]\ndata:\n[";
+    ss << "]\ndata: [";
     if (type_[0] == 'f') {
       for (const float& i : float_feasign_) {
         ss << i << ",";
@@ -356,9 +369,9 @@ class MultiSlotInMemoryDataFeed
   virtual bool ParseOneInstance(std::vector<MultiSlotType>* instance);
   virtual bool ParseOneInstanceFromPipe(std::vector<MultiSlotType>* instance);
   virtual void PutToFeedVec(const std::vector<MultiSlotType>& ins_vec);
-  virtual void SerializeIns(const std::vector<MultiSlotType>& ins,
+  virtual void SerializeIns(const std::vector<std::vector<MultiSlotType>*>& ins,
                             std::string* str);
-  virtual void DeserializeIns(std::vector<MultiSlotType>* ins,
+  virtual void DeserializeIns(std::vector<std::vector<MultiSlotType>>* ins,
                               const std::string& str);
 };
......
@@ -18,6 +18,8 @@
 #include "google/protobuf/message.h"
 #include "google/protobuf/text_format.h"
 #include "paddle/fluid/framework/data_feed_factory.h"
+#include "paddle/fluid/platform/timer.h"
+#include "paddle/fluid/framework/io/fs.h"
 
 namespace paddle {
 namespace framework {
@@ -25,12 +27,15 @@ namespace framework {
 template <typename T>
 DatasetImpl<T>::DatasetImpl() {
   thread_num_ = 1;
+  trainer_num_ = 1;
+  file_idx_ = 0;
 }
 
 template <typename T>
 void DatasetImpl<T>::SetFileList(const std::vector<std::string>& filelist) {
   VLOG(3) << "filelist size: " << filelist.size();
   filelist_ = filelist;
+  file_idx_ = 0;
   /*
   int file_cnt = filelist_.size();
   if (thread_num_ > file_cnt) {
@@ -45,19 +50,34 @@ void DatasetImpl<T>::SetFileList(const std::vector<std::string>& filelist) {
 // not user friendly
 template <typename T>
 void DatasetImpl<T>::SetThreadNum(int thread_num) {
-  int file_cnt = filelist_.size();
+  VLOG(3) << "SetThreadNum thread_num=" << thread_num;
+  //int file_cnt = filelist_.size();
+  /*
   if (file_cnt != 0 && thread_num > file_cnt) {
     VLOG(3) << "DataSet thread num = " << thread_num
             << ", file num = " << file_cnt
             << ". Changing DataSet thread num = " << file_cnt;
     thread_num = file_cnt;
-  }
+  }*/
   thread_num_ = thread_num;
 }
 
 template <typename T>
 void DatasetImpl<T>::SetTrainerNum(int trainer_num) {
   trainer_num_ = trainer_num;
+  // should inform reader of trainer_num directly
+  for (auto reader : readers_) {
+    reader->SetTrainerNum(trainer_num);
+  }
+}
+
+template <typename T>
+void DatasetImpl<T>::SetHdfsConfig(const std::string& fs_name,
+                                   const std::string& fs_ugi) {
+  std::string cmd = std::string("hadoop fs");
+  cmd += " -D fs.default.name=" + fs_name;
+  cmd += " -D hadoop.job.ugi=" + fs_ugi;
+  paddle::framework::hdfs_set_command(cmd);
 }
 
 template <typename T>
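SetHdfsConfig only records the hadoop client prefix that the fs helpers prepend when a reader opens an hdfs:// path; with made-up credentials the stored command would be:

    # hypothetical values, shown only to illustrate the string being built
    fs_name, fs_ugi = "hdfs://nn:9000", "user,passwd"
    cmd = "hadoop fs -D fs.default.name=%s -D hadoop.job.ugi=%s" % (fs_name, fs_ugi)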
@@ -75,6 +95,8 @@ DatasetImpl<T>::GetReaders() {
 template <typename T>
 void DatasetImpl<T>::LoadIntoMemory() {
   VLOG(3) << "DatasetImpl<T>::LoadIntoMemory() begin";
+  platform::Timer timeline;
+  timeline.Start();
   if (readers_.size() == 0) {
     CreateReaders();
   }
@@ -86,12 +108,17 @@ void DatasetImpl<T>::LoadIntoMemory() {
   for (std::thread& t : load_threads) {
     t.join();
   }
-  VLOG(3) << "DatasetImpl<T>::LoadIntoMemory() end";
+  timeline.Pause();
+  VLOG(3) << "DatasetImpl<T>::LoadIntoMemory() end"
+          << ", memory data size=" << memory_data_.size()
+          << ", cost time=" << timeline.ElapsedSec() << " seconds";
 }
 
 template <typename T>
 void DatasetImpl<T>::LocalShuffle() {
   VLOG(3) << "DatasetImpl<T>::LocalShuffle() begin";
+  platform::Timer timeline;
+  timeline.Start();
   if (readers_.size() == 0) {
     CreateReaders();
   }
@@ -107,23 +134,27 @@ void DatasetImpl<T>::LocalShuffle() {
     t.join();
   }
   std::vector<T>().swap(memory_data_);
-  VLOG(3) << "DatasetImpl<T>::LocalShuffle() end";
+  timeline.Pause();
+  VLOG(3) << "DatasetImpl<T>::LocalShuffle() end, cost time="
+          << timeline.ElapsedSec() << " seconds";
 }
 
 template <typename T>
 void DatasetImpl<T>::GlobalShuffle() {
   VLOG(3) << "DatasetImpl<T>::GlobalShuffle() begin";
-  if (readers_.size() == 0) {
-    CreateReaders();
-  }
-  // if it is not InMemory, memory_data_ is empty
-  std::random_shuffle(memory_data_.begin(), memory_data_.end());
+  platform::Timer timeline;
+  timeline.Start();
   auto fleet_ptr = FleetWrapper::GetInstance();
   VLOG(3) << "RegisterClientToClientMsgHandler";
   fleet_ptr->RegisterClientToClientMsgHandler(
       0, [this](int msg_type, int client_id, const std::string& msg) -> int {
         return this->ReceiveFromClient(msg_type, client_id, msg);
       });
+  if (readers_.size() == 0) {
+    CreateReaders();
+  }
+  // if it is not InMemory, memory_data_ is empty
+  std::random_shuffle(memory_data_.begin(), memory_data_.end());
   VLOG(3) << "start global shuffle threads";
   std::vector<std::thread> global_shuffle_threads;
   for (int i = 0; i < thread_num_; ++i) {
@@ -133,15 +164,32 @@ void DatasetImpl<T>::GlobalShuffle() {
   for (std::thread& t : global_shuffle_threads) {
     t.join();
   }
-  VLOG(3) << "DatasetImpl<T>::GlobalShuffle() end";
+  std::vector<T>().swap(memory_data_);
+  timeline.Pause();
+  VLOG(3) << "DatasetImpl<T>::GlobalShuffle() end, cost time="
+          << timeline.ElapsedSec() << " seconds";
 }
 template <typename T>
 void DatasetImpl<T>::CreateReaders() {
   VLOG(3) << "Calling CreateReaders()";
   CHECK(thread_num_ > 0) << "thread_num should > 0";
+  int file_cnt = filelist_.size();
+  int memory_data_size = memory_data_.size();
+  if (memory_data_size != 0 && thread_num_ > memory_data_size) {
+    VLOG(3) << "Dataset thread num = " << thread_num_
+            << ", memory data size = " << memory_data_size
+            << ". Changing Dataset thread num = " << memory_data_size;
+    thread_num_ = memory_data_size;
+  } else if (file_cnt != 0 && thread_num_ > file_cnt) {
+    VLOG(3) << "Dataset thread num = " << thread_num_
+            << ", file num = " << file_cnt
+            << ". Changing Dataset thread num = " << file_cnt;
+    thread_num_ = file_cnt;
+  }
   VLOG(3) << "thread_num in Readers: " << thread_num_;
   VLOG(3) << "readers size: " << readers_.size();
+  VLOG(3) << "Filelist size in readers: " << filelist_.size();
   if (readers_.size() != 0) {
     return;
   }
@@ -154,9 +202,10 @@ void DatasetImpl<T>::CreateReaders() {
     readers_.back()->SetThreadId(i);
     readers_.back()->SetThreadNum(thread_num_);
     readers_.back()->SetTrainerNum(trainer_num_);
+    readers_.back()->SetFileListMutex(&mutex_for_pick_file_);
+    readers_.back()->SetFileListIndex(&file_idx_);
+    readers_.back()->SetFileList(filelist_);
   }
-  VLOG(3) << "Filelist size in readers: " << filelist_.size();
-  readers_[0]->SetFileList(filelist_);
 }
 
 template <typename T>
@@ -184,9 +233,12 @@ void DatasetImpl<T>::DestroyReaders() {
 template <typename T>
 int DatasetImpl<T>::ReceiveFromClient(int msg_type, int client_id,
                                       const std::string& msg) {
-  // todo random
-  // int64_t index = paddle::ps::local_random_engine()() % thread_num_;
-  int64_t index = 0;
+  VLOG(3) << "ReceiveFromClient msg_type=" << msg_type
+          << ", client_id=" << client_id << ", msg length="
+          << msg.length();
+  auto fleet_ptr = FleetWrapper::GetInstance();
+  int64_t index = fleet_ptr->LocalRandomEngine()() % thread_num_;
+  VLOG(3) << "ramdom index=" << index;
   readers_[index]->PutInsToChannel(msg);
   return 0;
 }
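Taken together, the shuffle protocol is: register a message handler, have every trainer shard its in-memory instances to random trainers, then push each incoming batch into a random local reader's channel. A single-process stand-in (hypothetical; plain lists replace the PSLib transport):

    import random

    trainers = [{"memory_data": list(range(i * 4, i * 4 + 4)), "channel": []}
                for i in range(3)]  # three made-up trainers

    for sender in trainers:
        for ins in sender["memory_data"]:
            # SerializeIns + SendClientToClientMsg on the sender;
            # ReceiveFromClient + PutInsToChannel on a random receiver
            random.choice(trainers)["channel"].append(ins)
        sender["memory_data"] = []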
......
@@ -33,6 +33,8 @@ class Dataset {
   virtual void SetFileList(const std::vector<std::string>& filelist) = 0;
   virtual void SetThreadNum(int thread_num) = 0;
   virtual void SetTrainerNum(int trainer_num) = 0;
+  virtual void SetHdfsConfig(const std::string& fs_name,
+                             const std::string& fs_ugi) = 0;
   virtual void SetDataFeedDesc(const std::string& data_feed_desc_str) = 0;
   virtual const std::vector<std::string>& GetFileList() = 0;
   virtual int GetThreadNum() = 0;
@@ -60,6 +62,8 @@ class DatasetImpl : public Dataset {
   virtual void SetFileList(const std::vector<std::string>& filelist);
   virtual void SetThreadNum(int thread_num);
   virtual void SetTrainerNum(int trainer_num);
+  virtual void SetHdfsConfig(const std::string& fs_name,
+                             const std::string& fs_ugi);
   virtual void SetDataFeedDesc(const std::string& data_feed_desc_str);
   virtual const std::vector<std::string>& GetFileList() { return filelist_; }
@@ -85,8 +89,10 @@ class DatasetImpl : public Dataset {
   std::mutex mutex_for_update_memory_data_;
   int thread_num_;
   paddle::framework::DataFeedDesc data_feed_desc_;
-  std::vector<std::string> filelist_;
   int trainer_num_;
+  std::vector<std::string> filelist_;
+  size_t file_idx_;
+  std::mutex mutex_for_pick_file_;
 };
 
 class MultiSlotDataset : public DatasetImpl<std::vector<MultiSlotType>> {
......
@@ -26,12 +26,14 @@ void DistMultiTrainer::Initialize(const TrainerDesc& trainer_desc,
                                   Dataset* dataset) {
   thread_num_ = trainer_desc.thread_num();
   SetDataset(dataset);
-  workers_.resize(thread_num_);
 
   dataset->CreateReaders();
   const std::vector<std::shared_ptr<paddle::framework::DataFeed>> readers =
       dataset->GetReaders();
 
+  thread_num_ = readers.size();
+  workers_.resize(thread_num_);
+
   for (int i = 0; i < thread_num_; ++i) {
     workers_[i] = DeviceWorkerFactory::CreateDeviceWorker(
         trainer_desc.device_worker_name());
......
@@ -29,6 +29,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/fleet/fleet_wrapper.h"
 #include <utility>
 #include "paddle/fluid/framework/data_feed.h"
+#include "paddle/fluid/framework/scope.h"
 
 namespace paddle {
 namespace framework {
@@ -203,6 +204,60 @@ void FleetWrapper::PullDenseVarsSync(
 #endif
 }
 
+void FleetWrapper::PushDenseParamSync(
+    const ProgramDesc& program, const uint64_t table_id,
+    const std::vector<std::string>& var_names) {
+#ifdef PADDLE_WITH_PSLIB
+  paddle::framework::Scope scope;
+  auto& block = program.Block(0);
+  for (auto& var : block.AllVars()) {
+    if (var->Persistable()) {
+      auto* ptr = scope.Var(var->Name());
+      InitializeVariable(ptr, var->GetType());
+    } else {
+      auto* ptr = scope.Var(var->Name());
+      InitializeVariable(ptr, var->GetType());
+    }
+  }
+  auto place = platform::CPUPlace();
+  std::vector<paddle::ps::Region> regions;
+  for (auto& t : var_names) {
+    Variable* var = scope.FindVar(t);
+    CHECK(var != nullptr) << "var[" << t << "] not found";
+    LoDTensor* tensor = var->GetMutable<LoDTensor>();
+    std::vector<int64_t> dim;
+    for (auto& var : block.AllVars()) {
+      if (var->Name() == t) {
+        dim = var->GetShape();
+        break;
+      }
+    }
+    int cnt = 1;
+    for (auto& i : dim) {
+      cnt *= i;
+    }
+    DDim d(std::vector<int64_t>{cnt}.data(), 1);
+    float* g = tensor->mutable_data<float>(d, place);
+    CHECK(g != nullptr) << "var[" << t << "] value not initialized";
+    float init_range = 0.2;
+    int rown = tensor->dims()[0];
+    init_range /= sqrt(rown);
+    std::normal_distribution<float> ndistr(0.0, 1.0);
+    for (auto i = 0u; i < tensor->numel(); ++i) {
+      g[i] = ndistr(LocalRandomEngine()) * init_range;
+    }
+    paddle::ps::Region reg(g, tensor->numel());
+    regions.emplace_back(std::move(reg));
+    auto push_status = pslib_ptr_->_worker_ptr->push_dense_param(
+        regions.data(), regions.size(), table_id);
+    push_status.wait();
+    auto status = push_status.get();
+    CHECK(status == 0) << "push dense param failed, status["
+                       << status << "]";
+  }
+#endif
+}
+
 void FleetWrapper::PushDenseVarsSync(
     Scope* scope, const uint64_t table_id,
     const std::vector<std::string>& var_names) {}
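PushDenseParamSync flattens each dense variable to 1-D, so rown equals the element count and every value is drawn as N(0, 1) * 0.2 / sqrt(numel) before being pushed to the dense table. The initialization rule as a hypothetical sketch:

    import math
    import random

    def init_dense_param(numel):
        # the tensor is reshaped to 1-D above, so rown == numel
        init_range = 0.2 / math.sqrt(numel)
        return [random.gauss(0.0, 1.0) * init_range for _ in range(numel)]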
@@ -269,6 +324,8 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
         continue;
       }
       LOG(WARNING) << "going to memcpy";
+      CHECK(fea_idx < (*push_values).size());
+      CHECK(fea_idx < fea_labels.size());
       memcpy((*push_values)[fea_idx].data() + offset, g,
              sizeof(float) * emb_dim);
       LOG(WARNING) << "show";
@@ -294,13 +351,13 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
 #endif
 }
 
-int FleetWrapper::RegisterClientToClientMsgHandler(int msg_type,
-                                                   MsgHandlerFunc handler) {
+int FleetWrapper::RegisterClientToClientMsgHandler(
+    int msg_type, MsgHandlerFunc handler) {
 #ifdef PADDLE_WITH_PSLIB
   VLOG(3) << "calling FleetWrapper::RegisterClientToClientMsgHandler";
   VLOG(3) << "pslib_ptr_=" << pslib_ptr_;
   VLOG(3) << "_worker_ptr=" << pslib_ptr_->_worker_ptr;
-  pslib_ptr_->_worker_ptr->registe_client2client_msg_handler(msg_type, handler);
+  return pslib_ptr_->_worker_ptr->registe_client2client_msg_handler(msg_type, handler);
 #else
   VLOG(0) << "FleetWrapper::RegisterClientToClientMsgHandler"
           << " does nothing when no pslib";
@@ -308,15 +365,15 @@ int FleetWrapper::RegisterClientToClientMsgHandler(int msg_type,
   return 0;
 }
 
-int FleetWrapper::SendClientToClientMsg(int msg_type, int to_client_id,
-                                        const std::string& msg) {
+std::future<int32_t> FleetWrapper::SendClientToClientMsg(
+    int msg_type, int to_client_id, const std::string& msg) {
 #ifdef PADDLE_WITH_PSLIB
-  pslib_ptr_->_worker_ptr->send_client2client_msg(msg_type, to_client_id, msg);
+  return pslib_ptr_->_worker_ptr->send_client2client_msg(msg_type, to_client_id, msg);
 #else
   VLOG(0) << "FleetWrapper::SendClientToClientMsg"
           << " does nothing when no pslib";
 #endif
-  return 0;
+  return std::future<int32_t>();
 }
 
 std::default_random_engine& FleetWrapper::LocalRandomEngine() {
@@ -336,10 +393,12 @@ std::default_random_engine& FleetWrapper::LocalRandomEngine() {
 }
 
 template <typename T>
-void FleetWrapper::Serialize(const T& t, std::string* str) {
+void FleetWrapper::Serialize(const std::vector<T*>& t, std::string* str) {
 #ifdef PADDLE_WITH_PSLIB
   paddle::ps::BinaryArchive ar;
-  ar << t;
+  for (size_t i = 0; i < t.size(); ++i) {
+    ar << *(t[i]);
+  }
   *str = std::string(ar.buffer(), ar.length());
 #else
   VLOG(0) << "FleetWrapper::Serialize does nothing when no pslib";
@@ -347,20 +406,30 @@ void FleetWrapper::Serialize(const T& t, std::string* str) {
 }
 
 template <typename T>
-void FleetWrapper::Deserialize(T* t, const std::string& str) {
+void FleetWrapper::Deserialize(std::vector<T>* t, const std::string& str) {
 #ifdef PADDLE_WITH_PSLIB
+  if (str.length() == 0) {
+    return;
+  }
   paddle::ps::BinaryArchive ar;
   ar.set_read_buffer(const_cast<char*>(str.c_str()), str.length(), nullptr);
-  *t = ar.get<T>();
+  if (ar.cursor() == ar.finish()) {
+    return;
+  }
+  while (ar.cursor() < ar.finish()) {
+    t->push_back(ar.get<T>());
+  }
+  CHECK(ar.cursor() == ar.finish());
+  VLOG(3) << "Deserialize size " << t->size();
 #else
   VLOG(0) << "FleetWrapper::Deserialize does nothing when no pslib";
 #endif
 }
 
 template void FleetWrapper::Serialize<std::vector<MultiSlotType>>(
-    const std::vector<MultiSlotType>&, std::string*);
-template void FleetWrapper::Deserialize(std::vector<MultiSlotType>*,
-                                        const std::string&);
+    const std::vector<std::vector<MultiSlotType>*>&, std::string*);
+template void FleetWrapper::Deserialize<std::vector<MultiSlotType>>(
+    std::vector<std::vector<MultiSlotType>>*, const std::string&);
 
 }  // end namespace framework
 }  // end namespace paddle
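The reworked pair writes any number of instances into one BinaryArchive and reads until the cursor reaches the end, so a single client-to-client message can carry a whole batch. The same framing in a hedged Python sketch (pickle standing in for BinaryArchive):

    import io
    import pickle

    def serialize_ins(instances):
        buf = io.BytesIO()
        for ins in instances:          # like: ar << *(t[i]) in a loop
            pickle.dump(ins, buf)
        return buf.getvalue()

    def deserialize_ins(data):
        out, buf = [], io.BytesIO(data)
        while buf.tell() < len(data):  # like: while cursor < finish
            out.append(pickle.load(buf))
        return out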
@@ -27,6 +27,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/scope.h"
 #include "paddle/fluid/framework/variable_helper.h"
 #include "paddle/fluid/platform/macros.h"  // for DISABLE_COPY_AND_ASSIGN
+#include "paddle/fluid/framework/program_desc.h"
 
 namespace paddle {
 namespace framework {
@@ -71,6 +72,10 @@ class FleetWrapper {
       const std::vector<std::string>& var_names,
       std::vector<::std::future<int32_t>>* pull_dense_status);
 
+  void PushDenseParamSync(
+      const ProgramDesc& program, const uint64_t table_id,
+      const std::vector<std::string>& var_names);
+
   // Push dense variables to server in async mode
   // Param<in>: scope, table_id, var_names,
   // Param<out>: push_sparse_status
@@ -119,16 +124,15 @@ class FleetWrapper {
   typedef std::function<int32_t (int, int, const std::string&)> MsgHandlerFunc;
   int RegisterClientToClientMsgHandler(int msg_type, MsgHandlerFunc handler);
-  int SendClientToClientMsg(int msg_type,
-                            int to_client_id,
-                            const std::string& msg);
+  std::future<int32_t> SendClientToClientMsg(int msg_type,
+                                             int to_client_id,
+                                             const std::string& msg);
   std::default_random_engine& LocalRandomEngine();
   template <typename T>
-  void Serialize(const T& t, std::string* str);
+  void Serialize(const std::vector<T*>& t, std::string* str);
   template <typename T>
-  void Deserialize(T* t, const std::string& str);
-
+  void Deserialize(std::vector<T>* t, const std::string& str);
   static std::shared_ptr<FleetWrapper> GetInstance() {
     if (NULL == s_instance_) {
       s_instance_.reset(new paddle::framework::FleetWrapper());
......
@@ -26,13 +26,15 @@ void MultiTrainer::Initialize(const TrainerDesc& trainer_desc,
   thread_num_ = trainer_desc.thread_num();
   SetDataset(dataset);
   // get filelist from trainer_desc here
-  workers_.resize(thread_num_);
-  VLOG(3) << "worker thread num: " << thread_num_;
   dataset->CreateReaders();
   VLOG(3) << "readers created";
   const std::vector<std::shared_ptr<paddle::framework::DataFeed>> readers =
       dataset->GetReaders();
   VLOG(3) << "readers num: " << readers.size();
+  // change thread num to readers num
+  thread_num_ = readers.size();
+  VLOG(3) << "worker thread num: " << thread_num_;
+  workers_.resize(thread_num_);
   for (int i = 0; i < thread_num_; ++i) {
     workers_[i] = DeviceWorkerFactory::CreateDeviceWorker(
         trainer_desc.device_worker_name());
......
@@ -49,7 +49,7 @@ void BindAsyncExecutor(py::module* m) {
         new framework::AsyncExecutor(scope, place));
       }))
       .def("run_from_files", &framework::AsyncExecutor::RunFromFile)
-      .def("run_from_dataset", &framework::AsyncExecutor::RunFromDataset)
+      //.def("run_from_dataset", &framework::AsyncExecutor::RunFromDataset)
       .def("init_server", &framework::AsyncExecutor::InitServer)
       .def("init_worker", &framework::AsyncExecutor::InitWorker)
       .def("start_server", &framework::AsyncExecutor::StartServer)
......
@@ -50,6 +50,7 @@ void BindDataset(py::module* m) {
       .def("set_filelist", &framework::Dataset::SetFileList)
       .def("set_thread_num", &framework::Dataset::SetThreadNum)
       .def("set_trainer_num", &framework::Dataset::SetTrainerNum)
+      .def("set_hdfs_config", &framework::Dataset::SetHdfsConfig)
       .def("set_data_feed_desc", &framework::Dataset::SetDataFeedDesc)
       .def("load_into_memory", &framework::Dataset::LoadIntoMemory)
       .def("local_shuffle", &framework::Dataset::LocalShuffle)
......
@@ -47,6 +47,7 @@ void BindFleetWrapper(py::module* m) {
       .def("init_server", &framework::FleetWrapper::InitServer)
       .def("run_server", &framework::FleetWrapper::RunServer)
       .def("init_worker", &framework::FleetWrapper::InitWorker)
+      .def("init_model", &framework::FleetWrapper::PushDenseParamSync)
      .def("stop_server", &framework::FleetWrapper::StopServer)
      .def("gather_servers", &framework::FleetWrapper::GatherServers);
 }  // end FleetWrapper
......
@@ -86,6 +86,9 @@ class DatasetBase(object):
                 "Currently, fluid.dataset only supports dtype=float32 and dtype=int64"
             )
 
+    def set_hdfs_config(self, fs_name, fs_ugi):
+        self.dataset.set_hdfs_config(fs_name, fs_ugi)
+
     def _prepare_to_run(self):
         self.dataset.set_data_feed_desc(self.desc())
 
@@ -115,11 +118,15 @@ class InMemoryDataset(DatasetBase):
     def local_shuffle(self):
         self.dataset.local_shuffle()
 
-    def global_shuffle(self):
-        from .distributed import ps_instance
-        instance = ps_instance.PaddlePSInstance(1, 2)
-        self.dataset.set_trainer_num(instance.get_worker_num())
+    def global_shuffle(self, fleet=None):
+        trainer_num = 1
+        if fleet is not None:
+            fleet.fleet_instance.role_maker_.barrier_worker()
+            trainer_num = fleet.worker_num()
+        self.dataset.set_trainer_num(trainer_num)
         self.dataset.global_shuffle()
+        if fleet is not None:
+            fleet.fleet_instance.role_maker_.barrier_worker()
 
 
 class QueueDataset(DatasetBase):
@@ -130,5 +137,5 @@ class QueueDataset(DatasetBase):
     def local_shuffle(self):
         pass
 
-    def global_shuffle(self):
+    def global_shuffle(self, fleet=None):
         pass
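With the new bindings in place, the intended flow for an in-memory dataset is roughly the following (a hedged sketch; the DatasetFactory construction and file names are assumptions, while set_hdfs_config and global_shuffle's fleet argument are from this commit):

    import paddle.fluid as fluid
    import paddle.fluid.incubate.fleet.parameter_server as fleet

    dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
    dataset.set_thread_num(10)
    dataset.set_filelist(["part-00000", "part-00001"])        # made-up files
    dataset.set_hdfs_config("hdfs://nn:9000", "user,passwd")  # made-up config
    dataset.load_into_memory()
    dataset.global_shuffle(fleet)  # barriers + shuffle across fleet.worker_num() trainers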
@@ -170,7 +170,7 @@ class MPISymetricRoleMaker(MPIRoleMaker):
         """
         if self._check_role_generation():
             if self.is_worker():
-                return self.get_size()
+                return self.get_size() / 2;
         return 0
 
     def server_num(self):
@@ -179,7 +179,7 @@ class MPISymetricRoleMaker(MPIRoleMaker):
         """
         if self._check_role_generation():
             if self.is_server():
-                return self.get_size()
+                return self.get_size() / 2;
         return 0
 
     def worker_index(self):
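The / 2 reflects MPISymetricRoleMaker's symmetric layout: half of the MPI ranks act as workers and half as servers (the stray semicolons and Python 2 integer division are as committed). Illustratively:

    # 8 MPI ranks under the symmetric role maker (made-up size)
    get_size = 8
    assert get_size / 2 == 4  # worker_num() == server_num() == 4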
......
@@ -43,7 +43,7 @@ class Fleet(object):
       save_pserver_model(): save model parameters in pserver, called from a server node
 
     Example:
-        .. code-block:: python 
+        .. code-block:: python
 
           import paddle.fluid.incubate.fleet.parameter_server as fleet
           from my_model import bow_net
@@ -58,7 +58,7 @@ class Fleet(object):
              fleet.init_worker()  # init worker should be called before training
              # do other things like training
           elif fleet.is_server():
-             fleet.init_pserver() 
+             fleet.init_pserver()
           fleet.stop()
     """
@@ -75,7 +75,7 @@ class Fleet(object):
         """
         init(): which should be called only once in user's python scripts. init() will initialize
             FleetWrapper in CPP, it will also initialize a RoleMaker which is used for identifying
-            current node's role, e.g. worker, server, etc. 
+            current node's role, e.g. worker, server, etc.
         """
         if not self.is_initialized_:
             self.role_maker_ = MPISymetricRoleMaker()
@@ -122,7 +122,7 @@ class Fleet(object):
             print("You should run DistributedOptimizer.minimize() first")
             sys.exit(-1)
 
-    def init_worker(self):
+    def init_worker(self, program):
         """
         init_worker(): will be called by user. When a user knows current process is_server(), he/she
             should call init_worker() to initialize global information about worker and connect
@@ -143,6 +143,19 @@ class Fleet(object):
                 self.role_maker_.get_rank())
             self.role_maker_.barrier_all()
             self.role_maker_.barrier_worker()
+            if self.role_maker_.is_first_worker():
+                tables = self._dist_desc.trainer_param.dense_table._values
+                for i in range(0, len(tables)):
+                    table = tables[i];
+                    var_name_list = []
+                    for i in range(0, len(table.dense_variable_name)):
+                        var_name_list.append(table.dense_variable_name[i])
+                    #print "table id ", table.table_id
+                    #print "var_name_list ", var_name_list
+                    self._fleet_ptr.init_model(program.desc,
+                                               int(table.table_id),
+                                               var_name_list)
+            self.role_maker_.barrier_worker()
         else:
             print("You should run DistributedOptimizer.minimize() first")
             sys.exit(-1)
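init_worker now takes the startup ProgramDesc so that the first worker can seed the dense tables (init_model, i.e. PushDenseParamSync above) before training begins; the surrounding barrier_worker calls keep the other workers from reading uninitialized tables. A hedged end-to-end sketch following the module docstring's flow, where startup_program is assumed to be the program handled by DistributedOptimizer.minimize():

    import paddle.fluid.incubate.fleet.parameter_server as fleet

    fleet.init()
    # ... build the model and run DistributedOptimizer.minimize() ...
    if fleet.is_server():
        fleet.init_pserver()
    elif fleet.is_worker():
        fleet.init_worker(startup_program)  # first worker pushes init_model
        # ... training ...
    fleet.stop()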
......