Unverified commit 66d51206, authored by jiaqi, committed by GitHub

add save/load model, shrink table, cvm, config file & fix pull dense bug (#17118)

* add save/load model, shrink table, cvm, config file & fix pull dense bug
test=develop

* fix global shuffle bug, fix pull dense bug, fix release memory bug, fix shrink error
add client flush, add get data size
test=develop

* fix global shuffle bug
test=develop

* fix global shuffle bug
test=develop

* fix code style
test=develop

* fix code style & modify pslib cmake
test=develop

* fix error of _role_maker
test=develop

* fix code style
test=develop

* fix code style
test=develop

* fix code style
test=develop

* fix code style
test=develop

* fix code style
test=develop

* fix windows compile error of fleet
test=develop

* fix global shuffle bug

* add comment
test=develop

* update pslib.cmake
test=develop

* fix fill sparse bug
test=develop

* fix push sparse bug
test=develop
Parent 266444b8
@@ -29,9 +29,9 @@ INCLUDE(ExternalProject)
 SET(PSLIB_PROJECT "extern_pslib")
 IF((NOT DEFINED PSLIB_VER) OR (NOT DEFINED PSLIB_URL))
   MESSAGE(STATUS "use pre defined download url")
-  SET(PSLIB_VER "0.1.0" CACHE STRING "" FORCE)
+  SET(PSLIB_VER "0.1.1" CACHE STRING "" FORCE)
   SET(PSLIB_NAME "pslib" CACHE STRING "" FORCE)
-  SET(PSLIB_URL "https://raw.githubusercontent.com/PaddlePaddle/Fleet/release/${PSLIB_VER}/${PSLIB_NAME}.tar.gz" CACHE STRING "" FORCE)
+  SET(PSLIB_URL "https://raw.githubusercontent.com/PaddlePaddle/Fleet/release/${PSLIB_VER}/ps/${PSLIB_NAME}.tar.gz" CACHE STRING "" FORCE)
 ENDIF()
 MESSAGE(STATUS "PSLIB_NAME: ${PSLIB_NAME}, PSLIB_URL: ${PSLIB_URL}")
 SET(PSLIB_SOURCE_DIR "${THIRD_PARTY_PATH}/pslib")
...
@@ -95,6 +95,11 @@ class BlockingQueue {
     return q_.size();
   }
 
+  void Clear() {
+    std::lock_guard<std::mutex> lock(mutex_);
+    std::deque<T>().swap(q_);
+  }
+
  private:
   std::mutex mutex_;
   std::condition_variable cv_;
...
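For context: swapping with a default-constructed container is the standard C++ idiom for actually releasing a deque's backing storage, which a plain `clear()` does not guarantee, and it is what `ReleaseChannelData()` later relies on to free the shuffle channels. A minimal standalone sketch of the idiom, with a simplified queue and hypothetical names:

```cpp
#include <deque>
#include <mutex>

// Sketch of the Clear() idiom added to BlockingQueue: swapping with a
// default-constructed deque hands the storage to a temporary that frees
// it, while q_.clear() could keep the capacity allocated.
template <typename T>
class TinyQueue {
 public:
  void Push(const T& v) {
    std::lock_guard<std::mutex> lock(mutex_);
    q_.push_back(v);
  }
  void Clear() {
    std::lock_guard<std::mutex> lock(mutex_);
    std::deque<T>().swap(q_);  // temporary takes the storage and frees it
  }

 private:
  std::mutex mutex_;
  std::deque<T> q_;
};

int main() {
  TinyQueue<int> q;
  for (int i = 0; i < 1000; ++i) q.Push(i);
  q.Clear();  // memory is returned, matching ReleaseChannelData()'s intent
  return 0;
}
```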
@@ -158,6 +158,7 @@ InMemoryDataFeed<T>::InMemoryDataFeed() {
   mutex_for_update_memory_data_ = nullptr;
   this->file_idx_ = nullptr;
   this->mutex_for_pick_file_ = nullptr;
+  fleet_send_sleep_seconds_ = 2;
 }
 
 template <typename T>
@@ -366,7 +367,7 @@ void InMemoryDataFeed<T>::GlobalShuffle() {
   auto fleet_ptr = FleetWrapper::GetInstance();
   std::vector<std::vector<T*>> send_vec(trainer_num_);
   std::vector<int> send_index(trainer_num_);
-  uint64_t reserve_len = fleet_send_batch_size_ / trainer_num_;
+  uint64_t reserve_len = fleet_send_batch_size_ / trainer_num_ + 1;
   for (auto& vec : send_vec) {
     vec.reserve(reserve_len);
   }
@@ -377,47 +378,34 @@ void InMemoryDataFeed<T>::GlobalShuffle() {
   auto interval = GetMemoryDataInterval();
   VLOG(3) << "global shuffle data from [" << interval.first << ", "
           << interval.second << "), thread_id=" << thread_id_;
-  for (int64_t i = interval.first; i < interval.second; ++i) {
-    // if get ins id, can also use hash
-    // std::string ins_id = memory_data_[i].ins_id;
-    int64_t random_num = rand_r(&rand_seed);
-    int64_t node_id = random_num % trainer_num_;
-    send_vec[node_id].push_back(&((*memory_data_)[i]));
-    if (i % fleet_send_batch_size_ == 0 && i != 0) {
-      // shuffle the sequence of sending to avoid network timeout error
-      std::random_shuffle(send_index.begin(), send_index.end());
-      for (int index = 0; index < send_index.size(); ++index) {
-        int j = send_index[index];
-        std::string send_str;
-        SerializeIns(send_vec[j], &send_str);
-        VLOG(3) << "send str_length=" << send_str.length()
-                << ", ins num=" << send_vec[j].size() << " to node_id=" << j
-                << ", thread_id=" << thread_id_;
-        auto ret = fleet_ptr->SendClientToClientMsg(0, j, send_str);
-        VLOG(3) << "end send, thread_id=" << thread_id_;
-        send_vec[j].clear();
-        total_status.push_back(std::move(ret));
-      }
-    }
-  }
-  // shuffle the sequence of sending to avoid network timeout error
-  std::random_shuffle(send_index.begin(), send_index.end());
-  for (int index = 0; index < send_index.size(); ++index) {
-    int j = send_index[index];
-    if (send_vec[j].size() != 0) {
-      std::string send_str;
-      SerializeIns(send_vec[j], &send_str);
-      VLOG(3) << "send str_length=" << send_str.length() << " to node_id=" << j
-              << ", thread_id=" << thread_id_;
-      auto ret = fleet_ptr->SendClientToClientMsg(0, j, send_str);
-      VLOG(3) << "end send, thread_id=" << thread_id_;
-      total_status.push_back(std::move(ret));
-    }
-    std::vector<T*>().swap(send_vec[j]);
-  }
-  for (auto& t : total_status) {
-    t.wait();
+  for (int64_t i = interval.first; i < interval.second;
+       i += fleet_send_batch_size_) {
+    for (int64_t j = 0; j < fleet_send_batch_size_ && i + j < interval.second;
+         ++j) {
+      int64_t random_num = fleet_ptr->LocalRandomEngine()();
+      int64_t node_id = random_num % trainer_num_;
+      send_vec[node_id].push_back(&((*memory_data_)[i + j]));
+    }
+    total_status.clear();
+    std::shuffle(send_index.begin(), send_index.end(),
+                 fleet_ptr->LocalRandomEngine());
+    for (int index = 0; index < send_index.size(); ++index) {
+      int j = send_index[index];
+      if (send_vec[j].size() == 0) {
+        continue;
+      }
+      std::string send_str;
+      SerializeIns(send_vec[j], &send_str);
+      auto ret = fleet_ptr->SendClientToClientMsg(0, j, send_str);
+      total_status.push_back(std::move(ret));
+      send_vec[j].clear();
+    }
+    for (auto& t : total_status) {
+      t.wait();
+    }
+    sleep(fleet_send_sleep_seconds_);
   }
   VLOG(3) << "GlobalShuffle() end, thread_id=" << thread_id_;
 #endif
 }
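The rewritten loop walks the data in `fleet_send_batch_size_` chunks, picks a random destination trainer per instance, sends each batch in a shuffled node order, waits on that batch's futures, and then sleeps to throttle the network. A self-contained sketch of the same control flow, with made-up sizes and the real serialization/RPC replaced by a print:

```cpp
#include <unistd.h>

#include <algorithm>
#include <iostream>
#include <random>
#include <vector>

// Sketch of the batched send pattern in the new GlobalShuffle().
int main() {
  const int trainer_num = 4;
  const size_t batch_size = 8;   // stands in for fleet_send_batch_size_
  const int sleep_seconds = 0;   // stands in for fleet_send_sleep_seconds_
  std::vector<int> data(20);     // stands in for (*memory_data_)
  std::default_random_engine engine(2019);

  std::vector<std::vector<int*>> send_vec(trainer_num);
  std::vector<int> send_index(trainer_num);
  for (int i = 0; i < trainer_num; ++i) send_index[i] = i;

  for (size_t i = 0; i < data.size(); i += batch_size) {
    // assign each instance in this batch a random destination trainer
    for (size_t j = 0; j < batch_size && i + j < data.size(); ++j) {
      int node_id = static_cast<int>(engine() % trainer_num);
      send_vec[node_id].push_back(&data[i + j]);
    }
    // shuffle destination order so all workers do not hit node 0 first
    std::shuffle(send_index.begin(), send_index.end(), engine);
    for (int j : send_index) {
      if (send_vec[j].empty()) continue;
      std::cout << "send " << send_vec[j].size() << " ins to node " << j
                << "\n";  // real code serializes and sends here, then waits
      send_vec[j].clear();
    }
    sleep(sleep_seconds);  // slow down sending (the temporary trick)
  }
  return 0;
}
```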
@@ -436,6 +424,24 @@ std::pair<int64_t, int64_t> InMemoryDataFeed<T>::GetMemoryDataInterval() {
   return std::make_pair(start, end);
 }
 
+template <typename T>
+int64_t InMemoryDataFeed<T>::GetChannelDataSize() {
+  if (cur_channel_ == 0) {
+    return shuffled_ins_->Size();
+  } else {
+    return shuffled_ins_out_->Size();
+  }
+}
+
+template <typename T>
+void InMemoryDataFeed<T>::ReleaseChannelData() {
+  if (cur_channel_ == 0) {
+    shuffled_ins_->Clear();
+  } else {
+    shuffled_ins_out_->Clear();
+  }
+}
+
 // explicit instantiation
 template class InMemoryDataFeed<std::vector<MultiSlotType>>;
...
@@ -115,6 +115,9 @@ class DataFeed {
   virtual void FillChannelToMemoryData() {}
   // This function will do nothing at default
   virtual void PutInsToChannel(const std::string& ins_str) {}
+  virtual int64_t GetChannelDataSize() { return 0; }
+  // This function will do nothing at default
+  virtual void ReleaseChannelData() {}
 
  protected:
   // The following three functions are used to check if it is executed in this
@@ -224,6 +227,8 @@ class InMemoryDataFeed : public PrivateQueueDataFeed<T> {
   virtual void LoadIntoMemory();
   virtual void LocalShuffle();
   virtual void GlobalShuffle();
+  virtual int64_t GetChannelDataSize();
+  virtual void ReleaseChannelData();
 
  protected:
   virtual void AddInstanceToInsVec(T* vec_ins, const T& instance,
@@ -248,6 +253,9 @@ class InMemoryDataFeed : public PrivateQueueDataFeed<T> {
   std::shared_ptr<paddle::framework::BlockingQueue<T>> shuffled_ins_;
   std::shared_ptr<paddle::framework::BlockingQueue<T>> shuffled_ins_out_;
   int64_t fleet_send_batch_size_;
+  // sleeping after each send slows down sending; this is a temporary trick
+  // and should be removed later
+  int64_t fleet_send_sleep_seconds_;
 };
 
 // This class defines the data type of instance(ins_vec) in MultiSlotDataFeed
...
@@ -141,6 +141,9 @@ template <typename T>
 void DatasetImpl<T>::ReleaseMemory() {
   VLOG(3) << "DatasetImpl<T>::ReleaseMemory() begin";
   std::vector<T>().swap(memory_data_);
+  for (int i = 0; i < readers_.size(); ++i) {
+    readers_[i]->ReleaseChannelData();
+  }
   VLOG(3) << "DatasetImpl<T>::ReleaseMemory() end";
 }
@@ -178,8 +181,10 @@ void DatasetImpl<T>::GlobalShuffle() {
   if (readers_.size() == 0) {
     CreateReaders();
   }
-  // if it is not InMemory, memory_data_ is empty
-  std::random_shuffle(memory_data_.begin(), memory_data_.end());
+  auto fleet_ptr = FleetWrapper::GetInstance();
+  // local shuffle all data before global shuffle
+  std::shuffle(memory_data_.begin(), memory_data_.end(),
+               fleet_ptr->LocalRandomEngine());
   VLOG(3) << "start global shuffle threads";
   std::vector<std::thread> global_shuffle_threads;
   for (int i = 0; i < thread_num_; ++i) {
@@ -260,6 +265,20 @@ void DatasetImpl<T>::DestroyReaders() {
   }
 }
 
+template <typename T>
+int64_t DatasetImpl<T>::GetMemoryDataSize() {
+  return memory_data_.size();
+}
+
+template <typename T>
+int64_t DatasetImpl<T>::GetShuffleDataSize() {
+  int64_t sum = 0;
+  for (int i = 0; i < readers_.size(); ++i) {
+    sum += readers_[i]->GetChannelDataSize();
+  }
+  return sum;
+}
+
 template <typename T>
 int DatasetImpl<T>::ReceiveFromClient(int msg_type, int client_id,
                                       const std::string& msg) {
@@ -267,7 +286,7 @@ int DatasetImpl<T>::ReceiveFromClient(int msg_type, int client_id,
   VLOG(3) << "ReceiveFromClient msg_type=" << msg_type
           << ", client_id=" << client_id << ", msg length=" << msg.length();
   auto fleet_ptr = FleetWrapper::GetInstance();
-  int64_t index = rand_r(&rand_seed) % thread_num_;
+  int64_t index = fleet_ptr->LocalRandomEngine()() % thread_num_;
   VLOG(3) << "random index=" << index;
   readers_[index]->PutInsToChannel(msg);
 #endif
...
@@ -85,6 +85,10 @@ class Dataset {
   virtual void CreateReaders() = 0;
   // destroy readers
   virtual void DestroyReaders() = 0;
+  // get memory data size
+  virtual int64_t GetMemoryDataSize() = 0;
+  // get shuffle data size
+  virtual int64_t GetShuffleDataSize() = 0;
 
  protected:
   virtual int ReceiveFromClient(int msg_type, int client_id,
@@ -127,6 +131,8 @@ class DatasetImpl : public Dataset {
   virtual void GlobalShuffle();
   virtual void CreateReaders();
   virtual void DestroyReaders();
+  virtual int64_t GetMemoryDataSize();
+  virtual int64_t GetShuffleDataSize();
 
  protected:
   virtual int ReceiveFromClient(int msg_type, int client_id,
...
@@ -48,6 +48,7 @@ class PullDenseWorker {
   void IncreaseThreadVersion(int thread_id, uint64_t table_id);
   void ResetThreadVersion(uint64_t table_id);
   void Wait(std::vector<::std::future<int32_t>>* status_vec);
+  void PullDense(bool force_update = false);
   static std::shared_ptr<PullDenseWorker> GetInstance() {
     if (NULL == s_instance_) {
       s_instance_.reset(new paddle::framework::PullDenseWorker());
@@ -92,7 +93,7 @@ class PullDenseWorker {
 // should incorporate different type of device
 class DeviceWorker {
  public:
-  DeviceWorker() {}
+  DeviceWorker() { use_cvm_ = false; }
   virtual ~DeviceWorker() {}
   virtual void Initialize(const TrainerDesc& desc) = 0;
   virtual void SetDeviceIndex(int tid) = 0;
@@ -114,6 +115,7 @@ class DeviceWorker {
   std::shared_ptr<DataFeed> device_reader_;
   int64_t batch_num_;
   FetchConfig fetch_config_;
+  bool use_cvm_;
 };
 
 class CPUWorkerBase : public DeviceWorker {
...
@@ -63,6 +63,7 @@ void DownpourWorker::Initialize(const TrainerDesc& desc) {
 
   fleet_ptr_ = FleetWrapper::GetInstance();
   fetch_config_ = desc.fetch_config();
+  use_cvm_ = desc.use_cvm();
 }
 
 void DownpourWorker::CollectLabelInfo(size_t table_idx) {
@@ -139,6 +140,16 @@ void DownpourWorker::FillSparseValue(size_t table_idx) {
   LoD data_lod{tensor_lod};
   tensor_emb->set_lod(data_lod);
   for (int index = 0; index < len; ++index) {
+    if (use_cvm_) {
+      if (ids[index] == 0u) {
+        memcpy(ptr + table.emb_dim() * index, init_value.data(),
+               sizeof(float) * table.emb_dim());
+        continue;
+      }
+      memcpy(ptr + table.emb_dim() * index, fea_value[fea_idx].data(),
+             sizeof(float) * table.emb_dim());
+      fea_idx++;
+    } else {
     if (ids[index] == 0u) {
       memcpy(ptr + table.emb_dim() * index, init_value.data() + 2,
              sizeof(float) * table.emb_dim());
@@ -149,6 +160,7 @@ void DownpourWorker::FillSparseValue(size_t table_idx) {
       fea_idx++;
     }
   }
+  }
 }
 
 void DownpourWorker::TrainFilesWithProfiler() {
@@ -197,9 +209,9 @@ void DownpourWorker::TrainFilesWithProfiler() {
     uint64_t tid = static_cast<uint64_t>(
         param_.program_config(0).pull_sparse_table_id(i));
     TableParameter table;
-    for (auto i : param_.sparse_table()) {
-      if (i.table_id() == tid) {
-        table = i;
+    for (auto j : param_.sparse_table()) {
+      if (j.table_id() == tid) {
+        table = j;
         break;
       }
     }
@@ -259,7 +271,7 @@ void DownpourWorker::TrainFilesWithProfiler() {
         fleet_ptr_->PushSparseVarsWithLabelAsync(
             *thread_scope_, tid, features_[tid], feature_labels_[tid],
             sparse_key_names_[tid], sparse_grad_names_[tid], table.emb_dim(),
-            &feature_grads_[tid], &push_sparse_status_);
+            &feature_grads_[tid], &push_sparse_status_, cur_batch, use_cvm_);
         timeline.Pause();
         push_sparse_time += timeline.ElapsedSec();
         total_time += timeline.ElapsedSec();
@@ -367,9 +379,9 @@ void DownpourWorker::TrainFiles() {
     uint64_t tid = static_cast<uint64_t>(
         param_.program_config(0).pull_sparse_table_id(i));
     TableParameter table;
-    for (auto i : param_.sparse_table()) {
-      if (i.table_id() == tid) {
-        table = i;
+    for (auto j : param_.sparse_table()) {
+      if (j.table_id() == tid) {
+        table = j;
         break;
       }
     }
@@ -411,7 +423,7 @@ void DownpourWorker::TrainFiles() {
       fleet_ptr_->PushSparseVarsWithLabelAsync(
           *thread_scope_, tid, features_[tid], feature_labels_[tid],
           sparse_key_names_[tid], sparse_grad_names_[tid], table.emb_dim(),
-          &feature_grads_[tid], &push_sparse_status_);
+          &feature_grads_[tid], &push_sparse_status_, cur_batch, use_cvm_);
     }
   }
...
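The `use_cvm_` branch assumes the pslib CTR-accessor layout in which each pulled sparse value is `[show, click, embedding...]`: without CVM the copy skips the first two floats (`init_value.data() + 2`), while with CVM the show/click counters are fed to the model as two extra input dims. A toy illustration of the two copy modes (dimensions made up):

```cpp
#include <cstring>
#include <iostream>
#include <vector>

// Sketch of the two fill modes in FillSparseValue, assuming each pulled
// value is laid out as [show, click, emb...].
int main() {
  const int emb_dim_no_cvm = 3;                // embedding only
  const int emb_dim_cvm = emb_dim_no_cvm + 2;  // show + click + embedding
  std::vector<float> fea_value = {10.f, 2.f, 0.1f, 0.2f, 0.3f};

  std::vector<float> no_cvm(emb_dim_no_cvm);
  std::memcpy(no_cvm.data(), fea_value.data() + 2,  // skip show/click
              sizeof(float) * emb_dim_no_cvm);

  std::vector<float> with_cvm(emb_dim_cvm);
  std::memcpy(with_cvm.data(), fea_value.data(),    // keep show/click
              sizeof(float) * emb_dim_cvm);

  std::cout << "no_cvm starts at " << no_cvm[0]            // 0.1
            << ", with_cvm starts at " << with_cvm[0]      // 10
            << "\n";
  return 0;
}
```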
@@ -281,9 +281,16 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
     const std::vector<std::string>& sparse_key_names,
     const std::vector<std::string>& sparse_grad_names, const int emb_dim,
     std::vector<std::vector<float>>* push_values,
-    std::vector<::std::future<int32_t>>* push_sparse_status) {
+    std::vector<::std::future<int32_t>>* push_sparse_status,
+    const int batch_size, const bool use_cvm) {
 #ifdef PADDLE_WITH_PSLIB
   int offset = 2;
+  int grad_dim = emb_dim;
+  if (use_cvm) {
+    offset = 0;
+    grad_dim = emb_dim - 2;
+  }
+  CHECK_GE(grad_dim, 0);
   uint64_t fea_idx = 0u;
   for (size_t i = 0; i < sparse_key_names.size(); ++i) {
     Variable* g_var = scope.FindVar(sparse_grad_names[i]);
@@ -307,7 +314,13 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
     for (auto& t : *push_values) {
       t.resize(emb_dim + offset);
     }
+    if (scale_sparse_gradient_with_batch_size_ && grad_dim > 0) {
+      int dim = emb_dim + offset;
+      Eigen::Map<
+          Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
+          g_mat(g, g_tensor->numel() / dim, dim);
+      g_mat.rightCols(grad_dim) *= batch_size;
+    }
     for (auto id_idx = 0u; id_idx < len; ++id_idx) {
       if (ids[id_idx] == 0) {
         g += emb_dim;
@@ -315,10 +328,15 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
       }
       CHECK(fea_idx < (*push_values).size());
       CHECK(fea_idx < fea_labels.size());
+      if (use_cvm) {
+        memcpy((*push_values)[fea_idx].data() + offset, g,
+               sizeof(float) * emb_dim);
+      } else {
         memcpy((*push_values)[fea_idx].data() + offset, g,
                sizeof(float) * emb_dim);
         (*push_values)[fea_idx][0] = 1.0f;
         (*push_values)[fea_idx][1] = static_cast<float>(fea_labels[fea_idx]);
+      }
       g += emb_dim;
       fea_idx++;
     }
@@ -337,6 +355,89 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
 #endif
 }
 
+void FleetWrapper::LoadModel(const std::string& path, const int mode) {
+#ifdef PADDLE_WITH_PSLIB
+  auto ret = pslib_ptr_->_worker_ptr->load(path, std::to_string(mode));
+  ret.wait();
+  if (ret.get() != 0) {
+    LOG(ERROR) << "load model from path:" << path << " failed";
+    exit(-1);
+  }
+#else
+  VLOG(0) << "FleetWrapper::LoadModel does nothing when no pslib";
+#endif
+}
+
+void FleetWrapper::SaveModel(const std::string& path, const int mode) {
+#ifdef PADDLE_WITH_PSLIB
+  auto ret = pslib_ptr_->_worker_ptr->save(path, std::to_string(mode));
+  ret.wait();
+  int32_t feasign_cnt = ret.get();
+  if (feasign_cnt == -1) {
+    LOG(ERROR) << "save model failed";
+    exit(-1);
+  }
+#else
+  VLOG(0) << "FleetWrapper::SaveModel does nothing when no pslib";
+#endif
+}
+
+void FleetWrapper::ShrinkSparseTable(int table_id) {
+#ifdef PADDLE_WITH_PSLIB
+  auto ret = pslib_ptr_->_worker_ptr->shrink(table_id);
+  ret.wait();
+#else
+  VLOG(0) << "FleetWrapper::ShrinkSparseTable does nothing when no pslib";
+#endif
+}
+
+void FleetWrapper::ShrinkDenseTable(int table_id, Scope* scope,
+                                    std::vector<std::string> var_list,
+                                    float decay) {
+#ifdef PADDLE_WITH_PSLIB
+  std::vector<paddle::ps::Region> regions;
+  for (std::string& name : var_list) {
+    if (name.find("batch_sum") != std::string::npos) {
+      Variable* var = scope->FindVar(name);
+      CHECK(var != nullptr) << "var[" << name << "] not found";
+      VLOG(3) << "prepare shrink dense batch_sum";
+      LoDTensor* tensor = var->GetMutable<LoDTensor>();
+      float* g = tensor->data<float>();
+      Eigen::Map<Eigen::MatrixXf> mat(g, 1, tensor->numel());
+      mat *= decay;
+      paddle::ps::Region reg(g, tensor->numel());
+      regions.emplace_back(std::move(reg));
+    } else {
+      Variable* var = scope->FindVar(name);
+      CHECK(var != nullptr) << "var[" << name << "] not found";
+      LoDTensor* tensor = var->GetMutable<LoDTensor>();
+      float* g = tensor->data<float>();
+      paddle::ps::Region reg(g, tensor->numel());
+      regions.emplace_back(std::move(reg));
+    }
+  }
+  auto push_status = pslib_ptr_->_worker_ptr->push_dense_param(
+      regions.data(), regions.size(), table_id);
+  push_status.wait();
+  auto status = push_status.get();
+  if (status != 0) {
+    LOG(FATAL) << "push shrink dense param failed, status[" << status << "]";
+    exit(-1);
+  }
+#else
+  VLOG(0) << "FleetWrapper::ShrinkDenseTable does nothing when no pslib";
+#endif
+}
+
+void FleetWrapper::ClientFlush() {
+#ifdef PADDLE_WITH_PSLIB
+  auto ret = pslib_ptr_->_worker_ptr->flush();
+  ret.wait();
+#else
+  VLOG(0) << "FleetWrapper::ClientFlush does nothing when no pslib";
+#endif
+}
+
 int FleetWrapper::RegisterClientToClientMsgHandler(int msg_type,
                                                    MsgHandlerFunc handler) {
 #ifdef PADDLE_WITH_PSLIB
@@ -398,6 +499,24 @@ void FleetWrapper::Deserialize(std::vector<T>* t, const std::string& str) {
 #endif
 }
 
+std::default_random_engine& FleetWrapper::LocalRandomEngine() {
+  struct engine_wrapper_t {
+    std::default_random_engine engine;
+#ifdef PADDLE_WITH_PSLIB
+    engine_wrapper_t() {
+      struct timespec tp;
+      clock_gettime(CLOCK_REALTIME, &tp);
+      double cur_time = tp.tv_sec + tp.tv_nsec * 1e-9;
+      static std::atomic<uint64_t> x(0);
+      std::seed_seq sseq = {x++, x++, x++, (uint64_t)(cur_time * 1000)};
+      engine.seed(sseq);
+    }
+#endif
+  };
+  thread_local engine_wrapper_t r;
+  return r.engine;
+}
+
 template void FleetWrapper::Serialize<std::vector<MultiSlotType>>(
     const std::vector<std::vector<MultiSlotType>*>&, std::string*);
 template void FleetWrapper::Deserialize<std::vector<MultiSlotType>>(
...
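The new scaling step in PushSparseVarsWithLabelAsync views the flat gradient buffer as a row-major (rows x dim) matrix and multiplies only the rightmost `grad_dim` columns by the batch size, leaving any leading show/click slots untouched. A standalone sketch of that Eigen idiom (requires the Eigen headers; dimensions are made up):

```cpp
#include <Eigen/Dense>

#include <iostream>
#include <vector>

// Sketch of the gradient scaling: map a flat float buffer as a
// (rows x dim) row-major matrix and scale only the rightmost grad_dim
// columns, i.e. the actual gradients, by the batch size.
int main() {
  const int dim = 4;        // emb_dim + offset
  const int grad_dim = 2;   // columns that actually hold gradients
  const int batch_size = 32;
  std::vector<float> g = {1, 1, 0.5f, 0.5f,     // row 0
                          1, 1, 0.25f, 0.25f};  // row 1

  Eigen::Map<
      Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
      g_mat(g.data(), g.size() / dim, dim);
  g_mat.rightCols(grad_dim) *= batch_size;  // scale gradients only

  std::cout << g_mat << "\n";  // the first two columns are unchanged
  return 0;
}
```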
@@ -55,7 +55,7 @@ namespace framework {
 class FleetWrapper {
  public:
   virtual ~FleetWrapper() {}
-  FleetWrapper() {}
+  FleetWrapper() { scale_sparse_gradient_with_batch_size_ = true; }
   // Pull sparse variables from server in Sync mode
   // Param<in>: scope, table_id, var_names, fea_keys
   // Param<out>: fea_values
@@ -99,7 +99,8 @@ class FleetWrapper {
       const std::vector<std::string>& sparse_key_names,
       const std::vector<std::string>& sparse_grad_names, const int emb_dim,
       std::vector<std::vector<float>>* push_values,
-      std::vector<::std::future<int32_t>>* push_sparse_status);
+      std::vector<::std::future<int32_t>>* push_sparse_status,
+      const int batch_size, const bool use_cvm);
 
   // Push sparse variables to server in Async mode
   // Param<In>: scope, table_id, fea_keys, sparse_grad_names
@@ -128,6 +129,19 @@ class FleetWrapper {
   // create client to client connection
   void CreateClient2ClientConnection();
 
+  // flush all push requests
+  void ClientFlush();
+  // mode = 0, load all feature
+  // mode = 1, load delta feature, which means load diff
+  void LoadModel(const std::string& path, const int mode);
+  // mode = 0, save all feature
+  // mode = 1, save delta feature, which means save diff
+  void SaveModel(const std::string& path, const int mode);
+  void ShrinkSparseTable(int table_id);
+  void ShrinkDenseTable(int table_id, Scope* scope,
+                        std::vector<std::string> var_list, float decay);
+
   // register client to client communication
   typedef std::function<int32_t(int, int, const std::string&)> MsgHandlerFunc;
   int RegisterClientToClientMsgHandler(int msg_type, MsgHandlerFunc handler);
@@ -146,6 +160,9 @@ class FleetWrapper {
     return s_instance_;
   }
 
+  // this performs better than rand_r, especially on large data
+  std::default_random_engine& LocalRandomEngine();
+
 #ifdef PADDLE_WITH_PSLIB
   static std::shared_ptr<paddle::distributed::PSlib> pslib_ptr_;
 #endif
@@ -158,6 +175,7 @@ class FleetWrapper {
 
  protected:
   static bool is_initialized_;
+  bool scale_sparse_gradient_with_batch_size_;
   DISABLE_COPY_AND_ASSIGN(FleetWrapper);
 };
...
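LocalRandomEngine() replaces `rand_r` with a thread-local `std::default_random_engine`, seeded once per thread from the wall clock plus a process-wide atomic counter, so concurrent threads draw from distinct streams without locking. A minimal sketch of the same pattern, with hypothetical names:

```cpp
#include <atomic>
#include <ctime>
#include <random>

// Sketch of the thread-local engine behind LocalRandomEngine(): seed once
// per thread from clock + a shared atomic counter, so threads get
// distinct, cheap random streams.
std::default_random_engine& LocalEngine() {
  struct EngineWrapper {
    std::default_random_engine engine;
    EngineWrapper() {
      struct timespec tp;
      clock_gettime(CLOCK_REALTIME, &tp);
      double cur_time = tp.tv_sec + tp.tv_nsec * 1e-9;
      static std::atomic<uint64_t> x(0);
      std::seed_seq sseq = {x++, x++, x++,
                            static_cast<uint64_t>(cur_time * 1000)};
      engine.seed(sseq);
    }
  };
  thread_local EngineWrapper wrapper;  // constructed once per thread
  return wrapper.engine;
}

int main() {
  // e.g. pick a destination trainer, as GlobalShuffle() does
  int trainer_num = 4;
  int node_id = static_cast<int>(LocalEngine()() % trainer_num);
  return node_id >= 0 ? 0 : 1;
}
```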
@@ -27,6 +27,7 @@ void HogwildWorker::Initialize(const TrainerDesc& desc) {
   for (size_t i = 0; i < param_.skip_ops_size(); ++i) {
     skip_ops_[i] = param_.skip_ops(i);
   }
+  use_cvm_ = desc.use_cvm();
 }
 
 void HogwildWorker::CreateThreadOperators(const ProgramDesc& program) {
...
@@ -83,28 +83,34 @@ void PullDenseWorker::Stop() {
   }
 }
 
-int PullDenseWorker::Start() {
-  running_ = true;
-  t_ = std::thread(&PullDenseWorker::Run, this);
-  return 0;
-}
-
-void PullDenseWorker::Run() {
-  while (running_) {
-    pull_dense_status_.resize(0);
-    for (size_t i = 0;
-         i < dwp_param_.program_config(0).pull_dense_table_id_size(); ++i) {
-      uint64_t tid = static_cast<uint64_t>(
-          dwp_param_.program_config(0).pull_dense_table_id(i));
-      if (CheckUpdateParam(tid)) {
-        fleet_ptr_->PullDenseVarsAsync(
-            *root_scope_, tid, dense_value_names_[tid], &pull_dense_status_);
-        ResetThreadVersion(tid);
-      }
-    }
-    if (pull_dense_status_.size() != 0) {
-      Wait(&pull_dense_status_);
-    }
+void PullDenseWorker::PullDense(bool force_update) {
+  pull_dense_status_.resize(0);
+  for (size_t i = 0;
+       i < dwp_param_.program_config(0).pull_dense_table_id_size(); ++i) {
+    uint64_t tid = static_cast<uint64_t>(
+        dwp_param_.program_config(0).pull_dense_table_id(i));
+    if (force_update || CheckUpdateParam(tid)) {
+      fleet_ptr_->PullDenseVarsAsync(*root_scope_, tid, dense_value_names_[tid],
+                                     &pull_dense_status_);
+      ResetThreadVersion(tid);
+    }
+  }
+  if (pull_dense_status_.size() != 0) {
+    Wait(&pull_dense_status_);
+  }
+}
+
+int PullDenseWorker::Start() {
+  running_ = true;
+  // before training, we can pull dense from pserver first.
+  PullDense(true);
+  t_ = std::thread(&PullDenseWorker::Run, this);
+  return 0;
+}
+
+void PullDenseWorker::Run() {
+  while (running_) {
+    PullDense(false);
 #ifndef _WIN32
     usleep(sleep_time_ms_ * 1000);
 #endif
...
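The pull-dense fix comes down to one forced, synchronous pull before the background refresher thread starts, so the first minibatch never trains on parameters that were not yet fetched. A toy sketch of that ordering, where `PullParams` stands in for `PullDenseWorker::PullDense`:

```cpp
#include <atomic>
#include <chrono>
#include <iostream>
#include <thread>

std::atomic<bool> running{true};

// Stands in for PullDenseWorker::PullDense(): real code pulls the dense
// tables whose version check fires, or all of them when forced.
void PullParams(bool force_update) {
  std::cout << (force_update ? "forced pull\n" : "periodic pull\n");
}

int main() {
  PullParams(true);  // the fix: pull once before training starts
  std::thread refresher([] {
    while (running) {
      PullParams(false);
      std::this_thread::sleep_for(std::chrono::milliseconds(10));
    }
  });
  running = false;
  refresher.join();
  return 0;
}
```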
@@ -30,6 +30,7 @@ message TrainerDesc {
   repeated string filelist = 5;
   optional bool debug = 6 [ default = false ];
   optional FetchConfig fetch_config = 7;
+  optional bool use_cvm = 8 [ default = false ];
 
   // device worker parameters
   optional HogwildWorkerParameter hogwild_param = 101;
...
@@ -66,7 +66,9 @@ void BindDataset(py::module* m) {
       .def("load_into_memory", &framework::Dataset::LoadIntoMemory)
       .def("release_memory", &framework::Dataset::ReleaseMemory)
       .def("local_shuffle", &framework::Dataset::LocalShuffle)
-      .def("global_shuffle", &framework::Dataset::GlobalShuffle);
+      .def("global_shuffle", &framework::Dataset::GlobalShuffle)
+      .def("get_memory_data_size", &framework::Dataset::GetMemoryDataSize)
+      .def("get_shuffle_data_size", &framework::Dataset::GetShuffleDataSize);
 }
 
 }  // end namespace pybind
...
@@ -47,12 +47,17 @@ void BindFleetWrapper(py::module* m) {
       .def("run_server", &framework::FleetWrapper::RunServer)
       .def("init_worker", &framework::FleetWrapper::InitWorker)
       .def("init_model", &framework::FleetWrapper::PushDenseParamSync)
+      .def("save_model", &framework::FleetWrapper::SaveModel)
+      .def("load_model", &framework::FleetWrapper::LoadModel)
       .def("stop_server", &framework::FleetWrapper::StopServer)
       .def("gather_servers", &framework::FleetWrapper::GatherServers)
       .def("gather_clients", &framework::FleetWrapper::GatherClients)
       .def("get_clients_info", &framework::FleetWrapper::GetClientsInfo)
       .def("create_client2client_connection",
-           &framework::FleetWrapper::CreateClient2ClientConnection);
+           &framework::FleetWrapper::CreateClient2ClientConnection)
+      .def("shrink_sparse_table", &framework::FleetWrapper::ShrinkSparseTable)
+      .def("shrink_dense_table", &framework::FleetWrapper::ShrinkDenseTable)
+      .def("client_flush", &framework::FleetWrapper::ClientFlush);
 }  // end FleetWrapper
 }  // end namespace pybind
 }  // end namespace paddle
@@ -29,9 +29,7 @@ class DatasetFactory(object):
     """
 
     def __init__(self):
-        """
-        Init
-        """
+        """ Init. """
         pass
 
     def create_dataset(self, datafeed_class="QueueDataset"):
@@ -39,6 +37,10 @@ class DatasetFactory(object):
         Create "QueueDataset" or "InMemoryDataset",
         the default is "QueueDataset".
 
+        Args:
+            datafeed_class(str): datafeed class name, QueueDataset or InMemoryDataset.
+                                 Default is QueueDataset.
+
         Examples:
             import paddle.fluid as fluid
             dataset = fluid.DatasetFactory().create_dataset()
@@ -52,14 +54,10 @@ class DatasetFactory(object):
 
 class DatasetBase(object):
-    """
-    Base dataset class
-    """
+    """ Base dataset class. """
 
     def __init__(self):
-        """
-        Init
-        """
+        """ Init. """
         # define class name here
         # to decide whether we need create in memory instance
         self.proto_desc = data_feed_pb2.DataFeedDesc()
@@ -76,7 +74,7 @@ class DatasetBase(object):
             >>> dataset.set_pipe_command("python my_script.py")
 
         Args:
-            pipe_command: pipe command
+            pipe_command(str): pipe command
 
         """
         self.proto_desc.pipe_command = pipe_command
@@ -89,7 +87,7 @@ class DatasetBase(object):
             >>> dataset.set_batch_size(128)
 
         Args:
-            batch_size: batch size
+            batch_size(int): batch size
 
         """
         self.proto_desc.batch_size = batch_size
@@ -102,7 +100,7 @@ class DatasetBase(object):
             >>> dataset.set_thread(12)
 
         Args:
-            thread_num: thread num
+            thread_num(int): thread num
 
         """
         self.dataset.set_thread_num(thread_num)
         self.thread_num = thread_num
@@ -115,7 +113,7 @@ class DatasetBase(object):
             >>> dataset.set_filelist(['a.txt', 'b.txt'])
 
         Args:
-            filelist: file list
+            filelist(list): file list
 
         """
         self.dataset.set_filelist(filelist)
@@ -127,7 +125,7 @@ class DatasetBase(object):
             >>> dataset.set_use_var([data, label])
 
         Args:
-            var_list: variable list
+            var_list(list): variable list
 
         """
         multi_slot = self.proto_desc.multi_slot_desc
         for var in var_list:
@@ -154,8 +152,8 @@ class DatasetBase(object):
             >>> dataset.set_hdfs_config("my_fs_name", "my_fs_ugi")
 
         Args:
-            fs_name: fs name
-            fs_ugi: fs ugi
+            fs_name(str): fs name
+            fs_ugi(str): fs ugi
 
         """
         self.dataset.set_hdfs_config(fs_name, fs_ugi)
@@ -190,9 +188,7 @@ class InMemoryDataset(DatasetBase):
     """
 
     def __init__(self):
-        """
-        Init
-        """
+        """ Init. """
         super(InMemoryDataset, self).__init__()
         self.proto_desc.name = "MultiSlotInMemoryDataFeed"
@@ -233,7 +229,7 @@ class InMemoryDataset(DatasetBase):
 
         Examples:
             >>> import paddle.fluid as fluid
-            >>> from paddle.fluid.incubate.fleet.pslib import fleet
+            >>> from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet
             >>> dataset = fluid.DatasetFactory.create_dataset("InMemoryDataset")
             >>> filelist = ["a.txt", "b.txt"]
             >>> dataset.set_filelist(filelist)
@@ -241,21 +237,22 @@ class InMemoryDataset(DatasetBase):
             >>> dataset.global_shuffle(fleet)
 
         Args:
-            fleet: fleet singleton. Default None.
+            fleet(Fleet): fleet singleton. Default None.
+
         """
         trainer_num = 1
         fleet_send_batch_size = 80000
         if fleet is not None:
-            fleet.fleet_instance.role_maker_._barrier_worker()
+            fleet._role_maker._barrier_worker()
             trainer_num = fleet.worker_num()
         self.dataset.register_client2client_msg_handler()
         self.dataset.set_trainer_num(trainer_num)
         self.dataset.set_fleet_send_batch_size(fleet_send_batch_size)
         if fleet is not None:
-            fleet.fleet_instance.role_maker_._barrier_worker()
+            fleet._role_maker._barrier_worker()
         self.dataset.global_shuffle()
         if fleet is not None:
-            fleet.fleet_instance.role_maker_._barrier_worker()
+            fleet._role_maker._barrier_worker()
 
     def release_memory(self):
         """
@@ -263,7 +260,7 @@ class InMemoryDataset(DatasetBase):
 
         Example:
            >>> import paddle.fluid as fluid
-            >>> import paddle.fluid.incubate.fleet.parameter_server as fleet
+            >>> from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet
            >>> dataset = fluid.DatasetFactory.create_dataset("InMemoryDataset")
            >>> filelist = ["a.txt", "b.txt"]
            >>> dataset.set_filelist(filelist)
@@ -276,6 +273,76 @@ class InMemoryDataset(DatasetBase):
         """
         self.dataset.release_memory()
+    def get_memory_data_size(self, fleet=None):
+        """
+        Get memory data size; users can call this to learn the number of
+        ins across all workers after load_into_memory.
+
+        Note:
+            This function may cause bad performance because it has a barrier.
+
+        Args:
+            fleet(Fleet): Fleet Object.
+
+        Returns:
+            The size of memory data.
+
+        Example:
+            >>> import paddle.fluid as fluid
+            >>> from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet
+            >>> dataset = fluid.DatasetFactory.create_dataset("InMemoryDataset")
+            >>> filelist = ["a.txt", "b.txt"]
+            >>> dataset.set_filelist(filelist)
+            >>> dataset.load_into_memory()
+            >>> print dataset.get_memory_data_size(fleet)
+
+        """
+        import numpy as np
+        local_data_size = self.dataset.get_memory_data_size()
+        local_data_size = np.array([local_data_size])
+        if fleet is not None:
+            global_data_size = local_data_size * 0
+            fleet._role_maker._node_type_comm.Allreduce(local_data_size,
+                                                        global_data_size)
+            return global_data_size[0]
+        return local_data_size[0]
+
+    def get_shuffle_data_size(self, fleet=None):
+        """
+        Get shuffle data size; users can call this to learn the number of
+        ins across all workers after local/global shuffle.
+
+        Note:
+            This function may cause bad performance for local shuffle
+            because it has a barrier. It does not affect global shuffle.
+
+        Args:
+            fleet(Fleet): Fleet Object.
+
+        Returns:
+            The size of shuffle data.
+
+        Example:
+            >>> import paddle.fluid as fluid
+            >>> from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet
+            >>> dataset = fluid.DatasetFactory.create_dataset("InMemoryDataset")
+            >>> filelist = ["a.txt", "b.txt"]
+            >>> dataset.set_filelist(filelist)
+            >>> dataset.load_into_memory()
+            >>> dataset.global_shuffle(fleet)
+            >>> print dataset.get_shuffle_data_size(fleet)
+
+        """
+        import numpy as np
+        local_data_size = self.dataset.get_shuffle_data_size()
+        local_data_size = np.array([local_data_size])
+        if fleet is not None:
+            global_data_size = local_data_size * 0
+            fleet._role_maker._node_type_comm.Allreduce(local_data_size,
+                                                        global_data_size)
+            return global_data_size[0]
+        return local_data_size[0]
 
 class QueueDataset(DatasetBase):
     """
...
@@ -155,6 +155,12 @@ class DownpourSGD(DeviceWorker):
             self._fleet_desc.trainer_param.sparse_table[0].slot_value)
         sparse_table.sparse_grad_name.extend(
             self._fleet_desc.trainer_param.sparse_table[0].slot_gradient)
-        sparse_table.emb_dim = \
-            self._fleet_desc.server_param.downpour_server_param.downpour_table_param[
-                0].accessor.fea_dim - 2
+        if opt_info["use_cvm"]:
+            sparse_table.emb_dim = \
+                self._fleet_desc.server_param.downpour_server_param.downpour_table_param[
+                    0].accessor.fea_dim
+            sparse_table.fea_dim = sparse_table.emb_dim
+        else:
+            sparse_table.emb_dim = \
+                self._fleet_desc.server_param.downpour_server_param.downpour_table_param[
+                    0].accessor.fea_dim - 2
...
@@ -822,7 +822,6 @@ class Executor(object):
         trainer._set_infer(True)
         trainer._gen_trainer_desc()
         dataset._prepare_to_run()
-        if debug:
-            self._dump_debug_info(program=program, trainer=trainer)
+        self._dump_debug_info(program=program, trainer=trainer)
         self._default_executor.run_from_dataset(program.desc, scope,
                                                 dataset.dataset,
@@ -902,7 +901,6 @@ class Executor(object):
                                       print_period=print_period)
         trainer._gen_trainer_desc()
         dataset._prepare_to_run()
-        if debug:
-            self._dump_debug_info(program=program, trainer=trainer)
+        self._dump_debug_info(program=program, trainer=trainer)
         self._default_executor.run_from_dataset(program.desc, scope,
                                                 dataset.dataset,
...
@@ -61,6 +61,9 @@ class TrainerDesc(object):
     def _set_program(self, program):
         self._program = program
 
+    def _set_use_cvm(self, use_cvm=False):
+        self.proto_desc.use_cvm = use_cvm
+
     def _desc(self):
         from google.protobuf import text_format
         return text_format.MessageToString(self.proto_desc)
...
@@ -38,4 +38,5 @@ class TrainerFactory(object):
             device_worker._set_fleet_desc(opt_info["fleet_desc"])
             trainer._set_device_worker(device_worker)
             trainer._set_fleet_desc(opt_info["fleet_desc"])
+            trainer._set_use_cvm(opt_info["use_cvm"])
         return trainer