Commit 3cea00bd authored by xujiaqi01, committed by dongdaxiang

store memory data in Dataset && fix bug

Parent ff87698a
......@@ -68,8 +68,10 @@ void DataFeed::SetBatchSize(int batch_size) {
bool DataFeed::PickOneFile(std::string* filename) {
std::unique_lock<std::mutex> lock(mutex_for_pick_file_);
if (file_idx_ == filelist_.size()) {
VLOG(3) << "DataFeed::PickOneFile no more file to pick";
return false;
}
VLOG(3) << "file_idx_=" << file_idx_;
*filename = filelist_[file_idx_++];
// LOG(ERROR) << "pick file:" << *filename;
return true;
......@@ -146,17 +148,18 @@ template class PrivateQueueDataFeed<std::vector<MultiSlotType>>;
template <typename T>
InMemoryDataFeed<T>::InMemoryDataFeed() {
cur_channel_ = 0;
shuffled_ins_ = nullptr;
shuffled_ins_out_ = nullptr;
shuffled_ins_ = std::make_shared<paddle::framework::BlockingQueue<T>>();
shuffled_ins_out_ = std::make_shared<paddle::framework::BlockingQueue<T>>();
fleet_send_batch_size_ = 10000;
}
template <typename T>
bool InMemoryDataFeed<T>::Start() {
DataFeed::CheckSetFileList();
if (memory_data_.size() != 0) {
CHECK_EQ(cur_channel_, 0);
shuffled_ins_->Extend(std::move(memory_data_));
std::vector<T>().swap(memory_data_);
if (shuffled_ins_->Size() == 0 && shuffled_ins_out_->Size() == 0) {
FillMemoryDataToChannel();
//std::unique_lock<std::mutex> lock(*mutex_for_update_memory_data_);
//std::vector<T>().swap(memory_data_);
}
DataFeed::finish_start_ = true;
return true;
......@@ -196,6 +199,31 @@ int InMemoryDataFeed<T>::Next() {
return DataFeed::batch_size_;
}
template <typename T>
void InMemoryDataFeed<T>::SetMemoryData(void* memory_data) {
memory_data_ = static_cast<std::vector<T>*>(memory_data);
}
template <typename T>
void InMemoryDataFeed<T>::SetMemoryDataMutex(std::mutex* mutex) {
mutex_for_update_memory_data_ = mutex;
}
template <typename T>
void InMemoryDataFeed<T>::SetThreadId(int thread_id) {
thread_id_ = thread_id;
}
template <typename T>
void InMemoryDataFeed<T>::SetThreadNum(int thread_num) {
thread_num_ = thread_num;
}
template <typename T>
void InMemoryDataFeed<T>::SetTrainerNum(int trainer_num) {
trainer_num_ = trainer_num;
}
template <typename T>
void InMemoryDataFeed<T>::PutInsToChannel(const std::string& ins_str) {
T ins;
......@@ -203,11 +231,54 @@ void InMemoryDataFeed<T>::PutInsToChannel(const std::string& ins_str) {
shuffled_ins_->Push(std::move(ins));
}
template <typename T>
void InMemoryDataFeed<T>::FillMemoryDataToChannel() {
VLOG(3) << "InMemoryDataFeed<T>::FillMemoryDataToChannel, thread_id=" << thread_id_;
int64_t start = 0;
int64_t end = 0;
int64_t size = memory_data_->size();
VLOG(3) << "memory_data size=" << size;
for (int64_t i = 0; i <= static_cast<int64_t>(thread_id_); ++i) {
int64_t len = size / static_cast<int64_t>(thread_num_) +
(i < (size % static_cast<int64_t>(thread_num_)));
start = end;
end += len;
}
for (int64_t i = start; i < end; ++i) {
T& t = (*memory_data_)[i];
shuffled_ins_->Push(std::move(t));
}
}
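The slice arithmetic above is easier to check in isolation. A minimal standalone sketch (the function name and main() are illustrative, not part of the patch):

#include <cstdint>
#include <iostream>
#include <utility>

// Returns the [start, end) slice of `size` items owned by `thread_id` out of
// `thread_num` threads; the first (size % thread_num) threads get one extra.
std::pair<int64_t, int64_t> ThreadRange(int64_t size, int64_t thread_num,
                                        int64_t thread_id) {
  int64_t start = 0, end = 0;
  for (int64_t i = 0; i <= thread_id; ++i) {
    int64_t len = size / thread_num + (i < (size % thread_num));
    start = end;
    end += len;
  }
  return {start, end};
}

int main() {
  // 10 instances over 3 threads -> [0,4) [4,7) [7,10)
  for (int64_t t = 0; t < 3; ++t) {
    auto r = ThreadRange(10, 3, t);
    std::cout << "thread " << t << ": [" << r.first << "," << r.second << ")\n";
  }
  return 0;
}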
template <typename T>
void InMemoryDataFeed<T>::FillChannelToMemoryData() {
VLOG(3) << "InMemoryDataFeed<T>::FillChannelToMemoryData, thread_id=" << thread_id_;
std::vector<T> local_vec;
std::shared_ptr<paddle::framework::BlockingQueue<T>> channel = nullptr;
if (cur_channel_ == 0) {
channel = shuffled_ins_;
} else {
channel = shuffled_ins_out_;
}
CHECK(channel != nullptr);
// snapshot the size before draining: Pop() shrinks the queue, and the
// vector must be resized (not just reserved) before indexed writes
int64_t channel_size = channel->Size();
local_vec.resize(channel_size);
for (int64_t i = 0; i < channel_size; ++i) {
channel->Pop(local_vec[i]);
}
// std::unique_lock acquires the mutex on construction; calling lock()
// again on it would throw
std::unique_lock<std::mutex> lock(*mutex_for_update_memory_data_);
memory_data_->insert(memory_data_->end(), local_vec.begin(), local_vec.end());
lock.unlock();
std::vector<T>().swap(local_vec);
}
template <typename T>
void InMemoryDataFeed<T>::LoadIntoMemory() {
VLOG(3) << "InMemoryDataFeed<T>::LoadIntoMemory() begin, thread_id=" << thread_id_;
std::vector<T> local_vec;
std::string filename;
while (DataFeed::PickOneFile(&filename)) {
VLOG(3) << "PickOneFile, filename=" << filename << ", thread_id=" << thread_id_;
int err_no = 0;
PrivateQueueDataFeed<T>::fp_ =
fs_open_read(filename, &err_no, PrivateQueueDataFeed<T>::pipe_command_);
......@@ -216,35 +287,50 @@ void InMemoryDataFeed<T>::LoadIntoMemory() {
while (ParseOneInstanceFromPipe(&instance)) {
local_vec.push_back(instance);
}
memory_data_.insert(memory_data_.end(), local_vec.begin(), local_vec.end());
VLOG(3) << "InMemoryDataFeed<T>::LoadIntoMemory() read all lines, thread_id=" << thread_id_;
{
std::lock_guard<std::mutex> lock(*mutex_for_update_memory_data_);
memory_data_->insert(memory_data_->end(), local_vec.begin(), local_vec.end());
}
std::vector<T>().swap(local_vec);
}
VLOG(3) << "InMemoryDataFeed<T>::LoadIntoMemory() end, thread_id=" << thread_id_;
}
template <typename T>
void InMemoryDataFeed<T>::LocalShuffle() {
std::random_shuffle(memory_data_.begin(), memory_data_.end());
VLOG(3) << "InMemoryDataFeed<T>::LocalShuffle() begin, thread_id=" << thread_id_;
FillMemoryDataToChannel();
VLOG(3) << "InMemoryDataFeed<T>::LocalShuffle() end, thread_id=" << thread_id_;
}
// todo global shuffle
/*
template <typename T>
void InMemoryDataFeed<T>::GlobalShuffle(int trainer_num) {
std::random_shuffle(memory_data_.begin(), memory_data_.end());
for (int64_t i = 0; i < memory_data_.size(); ++i) {
void InMemoryDataFeed<T>::GlobalShuffle() {
auto fleet_ptr = FleetWrapper::GetInstance();
std::vector<std::string> send_str_vec(trainer_num_);
for (int64_t i = 0; i < memory_data_->size(); ++i) {
// todo get ins id
//std::string ins_id = memory_data_[i].ins_id;
// todo hash
int64_t hash_id = paddle::ps::local_random_engine()();
//int64_t hash_id = hash(ins_id);
//int64_t hash_id = paddle::ps::local_random_engine()();
int64_t hash_id = 0;
int64_t node_id = hash_id % trainer_num_;
std::string str;
SerializeIns(memory_data_[i], str);
auto fleet_ptr = FleetWrapper::GetInstance();
auto ret = fleet_ptr->send_client2client_msg(0, node_id, str);
SerializeIns((*memory_data_)[i], str);
send_str_vec[node_id] += str;
if (i % fleet_send_batch_size_ == 0 && i != 0) {
for (int j = 0; j < send_str_vec.size(); ++j) {
fleet_ptr->send_client2client_msg(0, j, send_str_vec[j]);
send_str_vec[j] = "";
}
}
}
for (int j = 0; j < send_str_vec.size(); ++j) {
if (send_str_vec[j].length() != 0) {
fleet_ptr->send_client2client_msg(0, j, send_str_vec[j]);
}
}
}
*/
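The commented-out GlobalShuffle batches outgoing instances per destination trainer instead of sending one message per instance. A self-contained sketch of that fan-out pattern; `send` stands in for FleetWrapper::send_client2client_msg, which is still a stub in this commit, and the string hash replaces the not-yet-available ins_id hash:

#include <cstdint>
#include <functional>
#include <string>
#include <vector>

// Shard serialized instances across trainers and flush each trainer's
// buffer every `batch_size` instances, plus a final flush of remainders.
void ShardAndSend(const std::vector<std::string>& serialized_ins,
                  int trainer_num, int64_t batch_size,
                  const std::function<void(int, const std::string&)>& send) {
  std::vector<std::string> buffers(trainer_num);
  for (int64_t i = 0; i < static_cast<int64_t>(serialized_ins.size()); ++i) {
    int node_id = static_cast<int>(
        std::hash<std::string>()(serialized_ins[i]) % trainer_num);
    buffers[node_id] += serialized_ins[i];
    if (i != 0 && i % batch_size == 0) {
      for (int j = 0; j < trainer_num; ++j) {
        send(j, buffers[j]);
        buffers[j].clear();
      }
    }
  }
  for (int j = 0; j < trainer_num; ++j) {
    if (!buffers[j].empty()) send(j, buffers[j]);
  }
}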
// explicit instantiation
template class InMemoryDataFeed<std::vector<MultiSlotType>>;
......@@ -646,6 +732,7 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstance(
if (getline(file_, line)) {
int use_slots_num = use_slots_.size();
instance->resize(use_slots_num);
VLOG(3) << line;
// parse line
const char* str = line.c_str();
char* endptr = const_cast<char*>(str);
......@@ -735,12 +822,14 @@ void MultiSlotInMemoryDataFeed::PutToFeedVec(
// todo serialize ins in global shuffle
void MultiSlotInMemoryDataFeed::SerializeIns(
const std::vector<MultiSlotType>& ins, std::string& str) {
return;
auto fleet_ptr = FleetWrapper::GetInstance();
fleet_ptr->Serialize(ins, str);
}
// todo deserialize ins in global shuffle
void MultiSlotInMemoryDataFeed::DeserializeIns(std::vector<MultiSlotType>& ins,
const std::string& str) {
return;
auto fleet_ptr = FleetWrapper::GetInstance();
fleet_ptr->Deserialize(ins, str);
}
} // namespace framework
......
......@@ -20,6 +20,7 @@ limitations under the License. */
#include <string>
#include <thread> // NOLINT
#include <vector>
#include <sstream>
#include "paddle/fluid/framework/data_feed.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
......@@ -78,17 +79,33 @@ class DataFeed {
// This function is used for binding feed_vec memory
virtual void AddFeedVar(Variable* var, const std::string& name);
// This function will do nothing at default
virtual void SetMemoryData(void* memory_data) { }
// This function will do nothing at default
virtual void SetMemoryDataMutex(std::mutex* mutex) { }
// This function will do nothing at default
virtual void SetThreadId(int thread_id) { }
// This function will do nothing at default
virtual void SetThreadNum(int thread_num) { }
// This function will do nothing at default
virtual void SetTrainerNum(int trainer_num) { }
virtual void LoadIntoMemory() {
PADDLE_THROW("This function(LoadIntoMemory) is not implemented.");
}
virtual void LocalShuffle() {
PADDLE_THROW("This function(LocalShuffle) is not implemented.");
}
virtual void GlobalShuffle(int trainer_num) {
virtual void GlobalShuffle() {
PADDLE_THROW("This function(GlobalShuffle) is not implemented.");
}
virtual void FillMemoryDataToChannel() {
PADDLE_THROW("This function(FillMemoryDataToChannel) is not implemented.");
}
virtual void FillChannelToMemoryData() {
PADDLE_THROW("This function(FillChannelToMemoryData) is not implemented.");
}
virtual void PutInsToChannel(const std::string& ins_str) {
PADDLE_THROW("This function(PutToChannel) is not implemented.");
PADDLE_THROW("This function(PutInsToChannel) is not implemented.");
}
protected:
......@@ -181,13 +198,20 @@ class InMemoryDataFeed : public PrivateQueueDataFeed<T> {
public:
InMemoryDataFeed();
virtual ~InMemoryDataFeed() {}
virtual void Init(const paddle::framework::DataFeedDesc& data_feed_desc) = 0;
virtual bool Start();
virtual int Next();
virtual void SetMemoryData(void* memory_data);
virtual void SetMemoryDataMutex(std::mutex* mutex);
virtual void SetThreadId(int thread_id);
virtual void SetThreadNum(int thread_num);
virtual void SetTrainerNum(int trainer_num);
virtual void PutInsToChannel(const std::string& ins_str);
virtual void FillMemoryDataToChannel();
virtual void FillChannelToMemoryData();
virtual void LoadIntoMemory();
virtual void LocalShuffle();
// todo global shuffle
//virtual void GlobalShuffle(int trainer_num);
virtual void GlobalShuffle();
protected:
virtual void AddInstanceToInsVec(T* vec_ins, const T& instance, int index) = 0;
virtual bool ParseOneInstance(T* instance) = 0;
......@@ -196,13 +220,18 @@ class InMemoryDataFeed : public PrivateQueueDataFeed<T> {
virtual void SerializeIns(const T& ins, std::string& str) = 0;
virtual void DeserializeIns(T& ins, const std::string& str) = 0;
std::vector<T> memory_data_;
int thread_id_;
int thread_num_;
int trainer_num_;
std::vector<T>* memory_data_;
std::mutex* mutex_for_update_memory_data_;
// when read ins, we put ins from one channel to the other,
// and when finish reading, we set cur_channel = 1 - cur_channel,
// so if cur_channel=0, all data are in shuffled_ins_, else shuffled_ins_out_
int cur_channel_;
std::shared_ptr<paddle::framework::BlockingQueue<T>> shuffled_ins_;
std::shared_ptr<paddle::framework::BlockingQueue<T>> shuffled_ins_out_;
int64_t fleet_send_batch_size_;
};
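The cur_channel_ comment above describes a ping-pong read: each pass pops from the current queue and pushes into the other, and a finished pass flips which queue is current. A toy illustration of that invariant, using a plain std::deque rather than the real BlockingQueue or Next():

#include <deque>
#include <string>

// Two queues trade roles each epoch: draining chan[cur] refills
// chan[1 - cur], so a full epoch of data is never lost.
struct PingPongChannel {
  std::deque<std::string> chan[2];
  int cur = 0;  // mirrors cur_channel_
  bool Next(std::string* ins) {
    if (chan[cur].empty()) {  // pass finished
      cur = 1 - cur;          // the other queue now holds all the data
      return false;
    }
    *ins = chan[cur].front();
    chan[cur].pop_front();
    chan[1 - cur].push_back(*ins);
    return true;
  }
};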
// This class define the data type of instance(ins_vec) in MultiSlotDataFeed
......@@ -226,6 +255,7 @@ class MultiSlotType {
offset_[0] = 0;
}
const std::vector<size_t>& GetOffset() const { return offset_; }
std::vector<size_t>& MutableOffset() { return offset_; }
void AddValue(const float v) {
CheckFloat();
float_feasign_.push_back(v);
......@@ -248,8 +278,11 @@ class MultiSlotType {
}
}
const std::vector<float>& GetFloatData() const { return float_feasign_; }
std::vector<float>& MutableFloatData() { return float_feasign_; }
const std::vector<uint64_t>& GetUint64Data() const { return uint64_feasign_; }
std::vector<uint64_t>& MutableUint64Data() { return uint64_feasign_; }
const std::string& GetType() const { return type_; }
std::string& MutableType() { return type_; }
private:
void CheckType(const std::string& type) const {
......
......@@ -12,6 +12,7 @@
* See the License for the specific language governing permissions and
* limitations under the License. */
#include <random>
#include "paddle/fluid/framework/data_set.h"
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/message.h"
......@@ -21,23 +22,27 @@
namespace paddle {
namespace framework {
Dataset::Dataset() { thread_num_ = 1; }
template <typename T>
DatasetImpl<T>::DatasetImpl() { thread_num_ = 1; }
void Dataset::SetFileList(const std::vector<std::string>& filelist) {
template <typename T>
void DatasetImpl<T>::SetFileList(const std::vector<std::string>& filelist) {
VLOG(3) << "filelist size: " << filelist.size();
filelist_ = filelist;
/*
int file_cnt = filelist_.size();
if (thread_num_ > file_cnt) {
VLOG(1) << "DataSet thread num = " << thread_num_
<< ", file num = " << file_cnt
<< ". Changing DataSet thread num = " << file_cnt;
thread_num_ = file_cnt;
}
}*/
}
// buggy here: the user must call SetFileList before this function,
// otherwise the thread-count cap against the file count cannot apply;
// not user friendly
void Dataset::SetThreadNum(int thread_num) {
template <typename T>
void DatasetImpl<T>::SetThreadNum(int thread_num) {
int file_cnt = filelist_.size();
if (file_cnt != 0 && thread_num > file_cnt) {
VLOG(1) << "DataSet thread num = " << thread_num
......@@ -48,19 +53,24 @@ void Dataset::SetThreadNum(int thread_num) {
thread_num_ = thread_num;
}
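As the comment above warns, the cap only takes effect if the file list is already set. A hypothetical call order, assuming SetThreadNum clamps to the file count as the VLOG message indicates:

#include <memory>
#include "paddle/fluid/framework/data_set.h"

int main() {
  auto dataset = std::make_shared<paddle::framework::MultiSlotDataset>();
  dataset->SetFileList({"part-000", "part-001"});  // placeholder file names
  dataset->SetThreadNum(4);  // expected to be capped at 2, the file count
  return 0;
}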
void Dataset::SetTrainerNum(int trainer_num) { trainer_num_ = trainer_num; }
template <typename T>
void DatasetImpl<T>::SetTrainerNum(int trainer_num) { trainer_num_ = trainer_num; }
void Dataset::SetDataFeedDesc(const std::string& data_feed_desc_str) {
template <typename T>
void DatasetImpl<T>::SetDataFeedDesc(const std::string& data_feed_desc_str) {
google::protobuf::TextFormat::ParseFromString(data_feed_desc_str,
&data_feed_desc_);
}
const std::vector<std::shared_ptr<paddle::framework::DataFeed>>&
Dataset::GetReaders() {
template <typename T>
std::vector<std::shared_ptr<paddle::framework::DataFeed>>&
DatasetImpl<T>::GetReaders() {
return readers_;
}
void Dataset::LoadIntoMemory() {
template <typename T>
void DatasetImpl<T>::LoadIntoMemory() {
VLOG(3) << "DatasetImpl<T>::LoadIntoMemory() begin";
if (readers_.size() == 0) {
CreateReaders();
}
......@@ -72,12 +82,18 @@ void Dataset::LoadIntoMemory() {
for (std::thread& t : load_threads) {
t.join();
}
VLOG(3) << "DatasetImpl<T>::LoadIntoMemory() end";
}
void Dataset::LocalShuffle() {
template <typename T>
void DatasetImpl<T>::LocalShuffle() {
VLOG(3) << "DatasetImpl<T>::LocalShuffle() begin";
if (readers_.size() == 0) {
CreateReaders();
}
// if it is not InMemory, memory_data_ is empty
std::random_shuffle(memory_data_.begin(), memory_data_.end());
std::vector<std::thread> local_shuffle_threads;
for (int64_t i = 0; i < thread_num_; ++i) {
local_shuffle_threads.push_back(std::thread(
......@@ -86,30 +102,37 @@ void Dataset::LocalShuffle() {
for (std::thread& t : local_shuffle_threads) {
t.join();
}
std::vector<T>().swap(memory_data_);
VLOG(3) << "DatasetImpl<T>::LocalShuffle() end";
}
// todo global shuffle
void Dataset::GlobalShuffle() {
/*
template <typename T>
void DatasetImpl<T>::GlobalShuffle() {
VLOG(3) << "DatasetImpl<T>::GlobalShuffle() begin";
if (readers_.size() == 0) {
CreateReaders();
}
// if it is not InMemory, memory_data_ is empty
std::random_shuffle(memory_data_.begin(), memory_data_.end());
auto fleet_ptr = FleetWrapper::GetInstance();
fleet_ptr->registe_client2client_msg_handler(0,
[this](int msg_type, int client_id, const std::string& msg) -> int {
return this->ReceiveFromClient(msg_type, client_id, msg);
});
if (readers_.size() == 0) {
CreateReaders();
}
std::vector<std::thread> global_shuffle_threads;
for (int64_t i = 0; i < thread_num_; ++i) {
global_shuffle_threads.push_back(std::thread(&paddle::framework::DataFeed::GlobalShuffle,
readers_[i].get(), trainer_num_));
for (int i = 0; i < thread_num_; ++i) {
global_shuffle_threads.push_back(
std::thread(&paddle::framework::DataFeed::GlobalShuffle,
readers_[i].get()));
}
for (std::thread& t : global_shuffle_threads) {
t.join();
}*/
}
VLOG(3) << "DatasetImpl<T>::GlobalShuffle() end";
}
void Dataset::CreateReaders() {
template <typename T>
void DatasetImpl<T>::CreateReaders() {
VLOG(3) << "Calling CreateReaders()";
CHECK(thread_num_ > 0) << "thread_num should > 0";
VLOG(3) << "thread_num in Readers: " << thread_num_;
......@@ -118,22 +141,53 @@ void Dataset::CreateReaders() {
return;
}
VLOG(3) << "data feed class name: " << data_feed_desc_.name();
for (int64_t i = 0; i < thread_num_; ++i) {
for (int i = 0; i < thread_num_; ++i) {
readers_.push_back(DataFeedFactory::CreateDataFeed(data_feed_desc_.name()));
readers_.back()->Init(data_feed_desc_);
readers_.back()->SetMemoryData(&memory_data_);
readers_.back()->SetMemoryDataMutex(&mutex_for_update_memory_data_);
readers_.back()->SetThreadId(i);
readers_.back()->SetThreadNum(thread_num_);
readers_.back()->SetTrainerNum(trainer_num_);
}
VLOG(3) << "Filelist size in readers: " << filelist_.size();
readers_[0]->SetFileList(filelist_);
}
int Dataset::ReceiveFromClient(int msg_type, int client_id,
template <typename T>
void DatasetImpl<T>::DestroyReaders() {
VLOG(3) << "Calling DestroyReaders()";
// Clear memory_data_ before refilling it: if LoadIntoMemory ran without a
// shuffle, memory_data_ still holds moved-from (empty) elements whose
// contents were std::move'd into the channel.
if (memory_data_.size() != 0) {
std::vector<T>().swap(memory_data_);
}
std::vector<std::thread> fill_threads;
for (int i = 0; i < thread_num_; ++i) {
fill_threads.push_back(std::thread(
&paddle::framework::DataFeed::FillChannelToMemoryData,
readers_[i].get()));
}
for (std::thread& t : fill_threads) {
t.join();
}
std::vector<std::string>().swap(filelist_);
std::vector<std::shared_ptr<paddle::framework::DataFeed>>().swap(readers_);
}
template <typename T>
int DatasetImpl<T>::ReceiveFromClient(int msg_type, int client_id,
const std::string& msg) {
// can also use hash
// todo random
// int64_t index = paddle::ps::local_random_engine()() % thread_num_;
int64_t index = 0;
readers_[index]->PutInsToChannel(msg);
return 0;
}
// explicit instantiation
template class DatasetImpl<std::vector<MultiSlotType>>;
} // end namespace framework
} // end namespace paddle
......@@ -28,8 +28,33 @@ namespace framework {
class Dataset {
public:
Dataset();
virtual ~Dataset() {}
Dataset() {};
virtual ~Dataset() {};
virtual void SetFileList(const std::vector<std::string>& filelist) = 0;
virtual void SetThreadNum(int thread_num) = 0;
virtual void SetTrainerNum(int trainer_num) = 0;
virtual void SetDataFeedDesc(const std::string& data_feed_desc_str) = 0;
virtual const std::vector<std::string>& GetFileList() = 0;
virtual int GetThreadNum() = 0;
virtual int GetTrainerNum() = 0;
virtual const paddle::framework::DataFeedDesc& GetDataFeedDesc() = 0;
virtual std::vector<std::shared_ptr<paddle::framework::DataFeed>>&
GetReaders() = 0;
virtual void LoadIntoMemory() = 0;
virtual void LocalShuffle() = 0;
virtual void GlobalShuffle() = 0;
virtual void CreateReaders() = 0;
virtual void DestroyReaders() = 0;
protected:
virtual int ReceiveFromClient(int msg_type, int client_id,
const std::string& msg) = 0;
};
template<typename T>
class DatasetImpl : public Dataset {
public:
DatasetImpl();
virtual ~DatasetImpl() {}
virtual void SetFileList(const std::vector<std::string>& filelist);
virtual void SetThreadNum(int thread_num);
......@@ -43,25 +68,34 @@ class Dataset {
return data_feed_desc_;
}
virtual const std::vector<std::shared_ptr<paddle::framework::DataFeed>>&
GetReaders();
virtual std::vector<std::shared_ptr<paddle::framework::DataFeed>>&
GetReaders();
virtual void LoadIntoMemory();
virtual void LocalShuffle();
// todo global shuffle
virtual void GlobalShuffle();
virtual void CreateReaders();
virtual void DestroyReaders();
protected:
virtual int ReceiveFromClient(int msg_type, int client_id,
const std::string& msg);
std::vector<std::shared_ptr<paddle::framework::DataFeed>> readers_;
std::vector<T> memory_data_;
std::mutex mutex_for_update_memory_data_;
std::vector<std::shared_ptr<paddle::framework::BlockingQueue<T>>> shuffled_ins_vec_;
std::vector<std::shared_ptr<paddle::framework::BlockingQueue<T>>> shuffled_ins_out_vec_;
int thread_num_;
std::string fs_name_;
std::string fs_ugi_;
paddle::framework::DataFeedDesc data_feed_desc_;
std::vector<std::string> filelist_;
int trainer_num_;
};
class MultiSlotDataset : public DatasetImpl<std::vector<MultiSlotType>> {
public:
MultiSlotDataset() {}
virtual ~MultiSlotDataset() {}
};
} // end namespace framework
} // end namespace paddle
......@@ -27,6 +27,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include "paddle/fluid/framework/data_feed.h"
namespace paddle {
namespace framework {
......@@ -35,6 +36,30 @@ const uint32_t MAX_FEASIGN_NUM = 1024 * 100 * 100;
std::shared_ptr<FleetWrapper> FleetWrapper::s_instance_ = NULL;
bool FleetWrapper::is_initialized_ = false;
#ifdef PADDLE_WITH_PSLIB
template<class AR>
paddle::ps::Archive<AR>& operator << (
paddle::ps::Archive<AR>& ar,
const MultiSlotType& ins) {
ar << ins.GetType();
ar << ins.GetOffset();
ar << ins.GetFloatData();
ar << ins.GetUint64Data();
return ar;
}
template<class AR>
paddle::ps::Archive<AR>& operator >> (
paddle::ps::Archive<AR>& ar,
MultiSlotType& ins) {
ar >> ins.MutableType();
ar >> ins.MutableOffset();
ar >> ins.MutableFloatData();
ar >> ins.MutableUint64Data();
return ar;
}
#endif
#ifdef PADDLE_WITH_PSLIB
std::shared_ptr<paddle::distributed::PSlib> FleetWrapper::pslib_ptr_ = NULL;
#endif
......@@ -266,5 +291,42 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
#endif
}
// todo registe_client2client_msg_handler
int FleetWrapper::registe_client2client_msg_handler(int msg_type, MsgHandlerFunc handler) {
return 0;
}
// todo send_client2client_msg
int FleetWrapper::send_client2client_msg(int msg_type, int to_client_id, const std::string& msg) {
return 0;
}
template<typename T>
void FleetWrapper::Serialize(const T& t, std::string& str) {
#ifdef PADDLE_WITH_PSLIB
paddle::ps::BinaryArchive ar;
ar << t;
str = std::string(ar.buffer(), ar.length());
#else
VLOG(0) << "FleetWrapper::Serialize do nothing when no pslib";
#endif
}
template<typename T>
void FleetWrapper::Deserialize(T& t, const std::string& str) {
#ifdef PADDLE_WITH_PSLIB
paddle::ps::BinaryArchive ar;
ar.set_read_buffer(const_cast<char*>(str.c_str()), str.length(), nullptr);
t = ar.get<T>();
#else
VLOG(0) << "FleetWrapper::Deserialize do nothing when no pslib";
#endif
}
template void FleetWrapper::Serialize<std::vector<MultiSlotType>>(
const std::vector<MultiSlotType>&, std::string&);
template void FleetWrapper::Deserialize(
std::vector<MultiSlotType>&, const std::string&);
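A round-trip through these helpers, using only calls visible in this patch (meaningful only under PADDLE_WITH_PSLIB; without pslib both calls just log and leave their outputs untouched):

#include <string>
#include <vector>
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"

// Serialize one instance to a string (as GlobalShuffle would before sending),
// then deserialize it back on the receiving side.
void RoundTrip(const std::vector<paddle::framework::MultiSlotType>& ins) {
  auto fleet_ptr = paddle::framework::FleetWrapper::GetInstance();
  std::string buf;
  fleet_ptr->Serialize(ins, buf);    // uses the operator<< overloads above
  std::vector<paddle::framework::MultiSlotType> out;
  fleet_ptr->Deserialize(out, buf);  // restored via operator>>
}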
} // end namespace framework
} // end namespace paddle
......@@ -17,7 +17,11 @@ limitations under the License. */
#include <memory>
#ifdef PADDLE_WITH_PSLIB
#include <pslib.h>
#include <archive.h>
#endif
#include <random>
#include <atomic>
#include <time.h>
#include <string>
#include <vector>
#include "paddle/fluid/framework/scope.h"
......@@ -110,6 +114,16 @@ class FleetWrapper {
uint64_t RunServer();
void GatherServers(const std::vector<uint64_t>& host_sign_list, int node_num);
typedef std::function<int32_t (int, int, const std::string&)> MsgHandlerFunc;
int registe_client2client_msg_handler(int msg_type, MsgHandlerFunc handler);
int send_client2client_msg(int msg_type, int to_client_id, const std::string& msg);
std::default_random_engine& local_random_engine();
template<typename T>
void Serialize(const T& t, std::string& str);
template<typename T>
void Deserialize(T& t, const std::string& str);
static std::shared_ptr<FleetWrapper> GetInstance() {
if (NULL == s_instance_) {
s_instance_.reset(new paddle::framework::FleetWrapper());
......
......@@ -41,17 +41,17 @@ namespace paddle {
namespace pybind {
void BindDataset(py::module* m) {
py::class_<framework::Dataset>(*m, "Dataset")
py::class_<framework::MultiSlotDataset>(*m, "MultiSlotDataset")
.def(py::init([]() {
return std::unique_ptr<framework::Dataset>(new framework::Dataset());
return std::unique_ptr<framework::MultiSlotDataset>(new framework::MultiSlotDataset());
}))
.def("set_filelist", &framework::Dataset::SetFileList)
.def("set_thread_num", &framework::Dataset::SetThreadNum)
.def("set_trainer_num", &framework::Dataset::SetTrainerNum)
.def("set_data_feed_desc", &framework::Dataset::SetDataFeedDesc)
.def("load_into_memory", &framework::Dataset::LoadIntoMemory)
.def("local_shuffle", &framework::Dataset::LocalShuffle)
.def("global_shuffle", &framework::Dataset::GlobalShuffle);
.def("set_filelist", &framework::MultiSlotDataset::SetFileList)
.def("set_thread_num", &framework::MultiSlotDataset::SetThreadNum)
.def("set_trainer_num", &framework::MultiSlotDataset::SetTrainerNum)
.def("set_data_feed_desc", &framework::MultiSlotDataset::SetDataFeedDesc)
.def("load_into_memory", &framework::MultiSlotDataset::LoadIntoMemory)
.def("local_shuffle", &framework::MultiSlotDataset::LocalShuffle)
.def("global_shuffle", &framework::MultiSlotDataset::GlobalShuffle);
}
} // end namespace pybind
......
......@@ -30,7 +30,7 @@ from .dataset import *
from . import async_executor
from .async_executor import *
from . import trainer
from . import trainer_desc
from . import inferencer
from . import io
......@@ -67,7 +67,7 @@ from . import install_check
Tensor = LoDTensor
__all__ = framework.__all__ + executor.__all__ + \
trainer.__all__ + inferencer.__all__ + transpiler.__all__ + \
trainer_desc.__all__ + inferencer.__all__ + transpiler.__all__ + \
parallel_executor.__all__ + lod_tensor.__all__ + \
data_feed_desc.__all__ + async_executor.__all__ + compiler.__all__ + [
'io',
......
......@@ -37,7 +37,7 @@ class DatasetBase(object):
# to decide whether we need create in memory instance
self.proto_desc = data_feed_pb2.DataFeedDesc()
self.proto_desc.pipe_command = "cat"
self.dataset = core.Dataset()
self.dataset = core.MultiSlotDataset()
self.thread_num = 0
def set_pipe_command(self, pipe_command):
......@@ -109,7 +109,7 @@ class InMemoryDataset(DatasetBase):
self.proto_desc.name = "MultiSlotInMemoryDataFeed"
def load_into_memory(self):
_prepare_to_run()
self._prepare_to_run()
self.dataset.load_into_memory()
def local_shuffle(self):
......