Unverified · Commit ca8c4f3e · authored by zhaocaibei123, committed by GitHub

update dataset (#37194)

Parent 54d2626a
@@ -710,8 +710,15 @@ int FleetWrapper::RegisterClientToClientMsgHandler(int msg_type,
                                                    MsgHandlerFunc handler) {
   VLOG(1) << "calling FleetWrapper::RegisterClientToClientMsgHandler";
   auto* communicator = Communicator::GetInstance();
-  return communicator->_worker_ptr->registe_client2client_msg_handler(msg_type,
-                                                                       handler);
+  // for unittest which does not call fleet.init_worker() first
+  if (communicator == nullptr) {
+    VLOG(0) << "FleetWrapper::RegisterClientToClientMsgHandler communicator is "
+               "null";
+    return -1;
+  } else {
+    return communicator->_worker_ptr->registe_client2client_msg_handler(
+        msg_type, handler);
+  }
 }
 
 std::future<int32_t> FleetWrapper::SendClientToClientMsg(
......
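The guard above covers the case named in its comment: client-to-client handler registration can run before the parameter-server worker is brought up, in which case Communicator::GetInstance() is still null and the call now returns -1 instead of dereferencing a null pointer. A rough sketch of the normal ordering on a trainer that avoids this path, assuming the public paddle.distributed.fleet API (not part of this commit):

    import paddle.distributed.fleet as fleet

    fleet.init()          # parameter-server mode by default; role comes from the cluster env
    fleet.init_worker()   # brings up the Communicator singleton used by FleetWrapper
    # Client-to-client handler registration (e.g. triggered by a global shuffle)
    # only succeeds after init_worker(); before it, the call above returns -1.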
@@ -368,20 +368,7 @@ void Communicator::InitParams(const RecvCtxMap &recv_varname_to_ctx) {
       VLOG(1) << "push dense param to table " << table_id
               << " from 0' trainer done";
     }
-    BarrierWithTable(1);
-  } else {
-    BarrierWithTable(1);
-    for (auto &iter : recv_varname_to_ctx) {
-      auto &table_id = iter.first;
-      auto &varnames = iter.second;
-      RpcRecvDense(varnames, table_id, recv_scope_);
-      VLOG(1) << "pull dense param to table " << table_id
-              << " from 0' trainer done";
-    }
   }
-  std::this_thread::sleep_for(
-      std::chrono::milliseconds(100 + trainer_id_ * 10));
-  BarrierWithTable(1);
   return;
 }
......
@@ -19,6 +19,10 @@
 #include "paddle/fluid/platform/monitor.h"
 #include "paddle/fluid/platform/timer.h"
 
+#ifdef PADDLE_WITH_PSCORE
+#include "paddle/fluid/distributed/fleet.h"
+#endif
+
 #if defined _WIN32 || defined __APPLE__
 #else
 #define _LINUX
@@ -208,13 +212,17 @@ void DatasetImpl<T>::CreateChannel() {
 // if sent message between workers, should first call this function
 template <typename T>
 void DatasetImpl<T>::RegisterClientToClientMsgHandler() {
-  auto fleet_ptr = FleetWrapper::GetInstance();
-  VLOG(3) << "RegisterClientToClientMsgHandler";
+#ifdef PADDLE_WITH_PSCORE
+  auto fleet_ptr = distributed::FleetWrapper::GetInstance();
+#else
+  auto fleet_ptr = framework::FleetWrapper::GetInstance();
+#endif
+  VLOG(1) << "RegisterClientToClientMsgHandler";
   fleet_ptr->RegisterClientToClientMsgHandler(
       0, [this](int msg_type, int client_id, const std::string& msg) -> int {
         return this->ReceiveFromClient(msg_type, client_id, msg);
       });
-  VLOG(3) << "RegisterClientToClientMsgHandler done";
+  VLOG(1) << "RegisterClientToClientMsgHandler done";
 }
 
 static void compute_left_batch_num(const int ins_num, const int thread_num,
                                    std::vector<std::pair<int, int>>* offset,
@@ -523,7 +531,7 @@ void DatasetImpl<T>::LocalShuffle() {
     VLOG(3) << "DatasetImpl<T>::LocalShuffle() end, no data to shuffle";
     return;
   }
-  auto fleet_ptr = FleetWrapper::GetInstance();
+  auto fleet_ptr = framework::FleetWrapper::GetInstance();
   input_channel_->Close();
   std::vector<T> data;
   input_channel_->ReadAll(data);
@@ -540,11 +548,14 @@ void DatasetImpl<T>::LocalShuffle() {
 }
 
 void MultiSlotDataset::GlobalShuffle(int thread_num) {
-#ifdef PADDLE_WITH_PSLIB
   VLOG(3) << "MultiSlotDataset::GlobalShuffle() begin";
   platform::Timer timeline;
   timeline.Start();
-  auto fleet_ptr = FleetWrapper::GetInstance();
+#ifdef PADDLE_WITH_PSCORE
+  auto fleet_ptr = distributed::FleetWrapper::GetInstance();
+#else
+  auto fleet_ptr = framework::FleetWrapper::GetInstance();
+#endif
   if (!input_channel_ || input_channel_->Size() == 0) {
     VLOG(3) << "MultiSlotDataset::GlobalShuffle() end, no data to shuffle";
@@ -576,7 +587,12 @@ void MultiSlotDataset::GlobalShuffle(int thread_num) {
   };
 
   auto global_shuffle_func = [this, get_client_id]() {
-    auto fleet_ptr = FleetWrapper::GetInstance();
+#ifdef PADDLE_WITH_PSCORE
+    auto fleet_ptr = distributed::FleetWrapper::GetInstance();
+#else
+    auto fleet_ptr = framework::FleetWrapper::GetInstance();
+#endif
+    // auto fleet_ptr = framework::FleetWrapper::GetInstance();
     std::vector<Record> data;
     while (this->input_channel_->Read(data)) {
       std::vector<paddle::framework::BinaryArchive> ars(this->trainer_num_);
@@ -633,7 +649,6 @@ void MultiSlotDataset::GlobalShuffle(int thread_num) {
   timeline.Pause();
   VLOG(3) << "DatasetImpl<T>::GlobalShuffle() end, cost time="
           << timeline.ElapsedSec() << " seconds";
-#endif
 }
 
 template <typename T>
@@ -936,7 +951,7 @@ int MultiSlotDataset::ReceiveFromClient(int msg_type, int client_id,
   }
   CHECK(ar.Cursor() == ar.Finish());
-  auto fleet_ptr = FleetWrapper::GetInstance();
+  auto fleet_ptr = framework::FleetWrapper::GetInstance();
   // not use random because it doesn't perform well here.
   // to make sure each channel get data equally, we just put data to
   // channel one by one.
@@ -976,7 +991,7 @@ void MultiSlotDataset::DynamicAdjustReadersNum(int thread_num) {
 void MultiSlotDataset::PostprocessInstance() {
   // divide pv instance, and merge to input_channel_
   if (enable_pv_merge_) {
-    auto fleet_ptr = FleetWrapper::GetInstance();
+    auto fleet_ptr = framework::FleetWrapper::GetInstance();
     std::shuffle(input_records_.begin(), input_records_.end(),
                  fleet_ptr->LocalRandomEngine());
     input_channel_->Open();
@@ -1014,7 +1029,7 @@ void MultiSlotDataset::PreprocessInstance() {
   if (!enable_pv_merge_) {  // means to use Record
     this->LocalShuffle();
   } else {  // means to use Pv
-    auto fleet_ptr = FleetWrapper::GetInstance();
+    auto fleet_ptr = framework::FleetWrapper::GetInstance();
     input_channel_->Close();
     std::vector<PvInstance> pv_data;
     input_channel_->ReadAll(input_records_);
@@ -1073,7 +1088,7 @@ void MultiSlotDataset::GenerateLocalTablesUnlock(int table_id, int feadim,
   }
   CHECK(multi_output_channel_.size() != 0);  // NOLINT
-  auto fleet_ptr_ = FleetWrapper::GetInstance();
+  auto fleet_ptr_ = framework::FleetWrapper::GetInstance();
   std::vector<std::unordered_map<uint64_t, std::vector<float>>>&
       local_map_tables = fleet_ptr_->GetLocalTable();
   local_map_tables.resize(shard_num);
@@ -1315,7 +1330,7 @@ void MultiSlotDataset::MergeByInsId() {
   LOG(WARNING) << "total drop ins num: " << drop_ins_num;
   results.shrink_to_fit();
-  auto fleet_ptr = FleetWrapper::GetInstance();
+  auto fleet_ptr = framework::FleetWrapper::GetInstance();
   std::shuffle(results.begin(), results.end(), fleet_ptr->LocalRandomEngine());
   channel_data->Open();
   channel_data->Write(std::move(results));
......
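For reference, the dataset code above is normally reached from the Python in-memory dataset API: GlobalShuffle picks whichever FleetWrapper is compiled in (paddle::distributed under PADDLE_WITH_PSCORE, paddle::framework otherwise), registers the client-to-client handler, and then exchanges records between trainers. A minimal usage sketch, assuming the paddle.distributed.InMemoryDataset API and a hypothetical data file; in a real parameter-server job fleet.init()/fleet.init_worker() run first, as in the sketch above:

    import paddle

    paddle.enable_static()
    x = paddle.static.data(name="x", shape=[-1, 1], dtype="int64", lod_level=1)

    dataset = paddle.distributed.InMemoryDataset()
    dataset.init(batch_size=32, thread_num=2, pipe_command="cat", use_var=[x])
    dataset.set_filelist(["train_data/part-000"])  # hypothetical path
    dataset.load_into_memory()
    dataset.global_shuffle()  # registers the c2c handler, then shuffles records across trainers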
@@ -15,6 +15,7 @@ limitations under the License. */
 #include <string>
 #include "paddle/fluid/framework/device_worker_factory.h"
 #include "paddle/fluid/framework/trainer.h"
+#include "paddle/fluid/platform/lodtensor_printer.h"
 
 #if defined PADDLE_WITH_PSCORE
 #include "paddle/fluid/distributed/service/communicator.h"
@@ -153,7 +154,20 @@ void MultiTrainer::InitOtherEnv(const ProgramDesc& main_program) {
   if (need_dump_field_ || need_dump_param_) {
     InitDumpEnv();
   }
-  VLOG(3) << "init other env done.";
+#ifdef PADDLE_WITH_PSCORE
+  // pull dense param first
+  auto communicator = paddle::distributed::Communicator::GetInstance();
+  // for unittest which call train_from_dataset but does not call
+  // fleet.init_worker() first
+  if (communicator == nullptr) {
+    VLOG(0) << "MultiTrainer::InitOtherEnv Communicator is null!";
+  } else {
+    auto& recv_ctx = communicator->GetRecvCtxMap();
+    communicator->PullDense(recv_ctx);
+    VLOG(3) << "init other env done.";
+  }
+#endif
 }
 
 Scope* MultiTrainer::GetWorkerScope(int thread_id) {
@@ -253,6 +267,17 @@ void MultiTrainer::Finalize() {
 #ifdef PADDLE_WITH_HETERPS
   MergeDenseParam();
 #endif
+
+#if defined PADDLE_WITH_PSCORE
+  auto communicator = paddle::distributed::Communicator::GetInstance();
+  // for unittest which does not call fleet.init_worker() first
+  if (communicator == nullptr) {
+    VLOG(0) << "MultiTrainer::Finalize communicator is null!";
+  } else {
+    communicator->_worker_ptr->flush();
+    VLOG(1) << "MultiTrainer::Finalize ps client flush done";
+  }
+#endif
   root_scope_->DropKids();
 }
......
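The Finalize() change flushes the parameter-server client so that pending asynchronous pushes reach the servers before the trainer drops its scopes, and InitOtherEnv() now pulls the dense parameters once before training; both run as part of the trainer lifecycle driven by train_from_dataset. A rough end-to-end sketch, assuming the static-graph executor API and the dataset prepared as in the previous sketch (illustrative only, not part of this commit):

    import paddle
    import paddle.distributed.fleet as fleet

    paddle.enable_static()
    exe = paddle.static.Executor(paddle.CPUPlace())
    exe.run(paddle.static.default_startup_program())
    # InitOtherEnv() pulls dense params when training starts; Finalize() flushes
    # the PS client when this call finishes.
    exe.train_from_dataset(program=paddle.static.default_main_program(),
                           dataset=dataset)
    fleet.stop_worker()   # stop the Communicator after training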
@@ -577,8 +577,12 @@ class TheOnePSRuntime(RuntimeBase):
         else:
             init_params = dense_map
 
+        import paddle.distributed.fleet as fleet
         if not is_test:
             self._communicator.init_params(init_params)
+            fleet.util.barrier()
+        self._communicator.pull_dense(init_params)
+        fleet.util.barrier()
 
         if not self._communicator.is_running():
             self._communicator.start()
......