Unverified commit ca8c4f3e authored by zhaocaibei123, committed by GitHub

update dataset (#37194)

Parent 54d2626a
......@@ -710,8 +710,15 @@ int FleetWrapper::RegisterClientToClientMsgHandler(int msg_type,
MsgHandlerFunc handler) {
VLOG(1) << "calling FleetWrapper::RegisterClientToClientMsgHandler";
auto* communicator = Communicator::GetInstance();
return communicator->_worker_ptr->registe_client2client_msg_handler(msg_type,
handler);
// for unittest which does not call fleet.init_worker() first
if (communicator == nullptr) {
VLOG(0) << "FleetWrapper::RegisterClientToClientMsgHandler communicator is "
"null";
return -1;
} else {
return communicator->_worker_ptr->registe_client2client_msg_handler(
msg_type, handler);
}
}
std::future<int32_t> FleetWrapper::SendClientToClientMsg(
......
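The fleet.cc hunk above adds a null check so that unit tests which never call `fleet.init_worker()` get a `-1` return code instead of dereferencing a null `Communicator`. Below is a minimal, standalone sketch of that guard pattern; `Registry` and `register_handler` are hypothetical stand-ins for illustration, not Paddle APIs.

```cpp
#include <functional>
#include <iostream>
#include <string>
#include <utility>

// Hypothetical stand-in for the Communicator singleton: GetInstance() returns
// nullptr until Init() has been called (i.e. until fleet.init_worker() ran).
class Registry {
 public:
  static Registry* GetInstance() { return instance_; }
  static void Init() {
    static Registry r;
    instance_ = &r;
  }
  int register_handler(int /*msg_type*/,
                       std::function<int(const std::string&)> /*handler*/) {
    return 0;  // the real code would store the handler; 0 means success
  }

 private:
  static Registry* instance_;
};
Registry* Registry::instance_ = nullptr;

// Mirrors the guarded registration in the hunk: fail soft with -1 instead of
// crashing when the singleton was never initialized.
int RegisterHandler(int msg_type,
                    std::function<int(const std::string&)> handler) {
  auto* registry = Registry::GetInstance();
  if (registry == nullptr) {
    std::cerr << "registry is null (init_worker() was not called)\n";
    return -1;
  }
  return registry->register_handler(msg_type, std::move(handler));
}

int main() {
  auto noop = [](const std::string&) { return 0; };
  std::cout << RegisterHandler(0, noop) << "\n";  // -1: not initialized yet
  Registry::Init();
  std::cout << RegisterHandler(0, noop) << "\n";  // 0: registered
  return 0;
}
```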
......@@ -368,20 +368,7 @@ void Communicator::InitParams(const RecvCtxMap &recv_varname_to_ctx) {
VLOG(1) << "push dense param to table " << table_id
<< " from 0' trainer done";
}
BarrierWithTable(1);
} else {
BarrierWithTable(1);
for (auto &iter : recv_varname_to_ctx) {
auto &table_id = iter.first;
auto &varnames = iter.second;
RpcRecvDense(varnames, table_id, recv_scope_);
VLOG(1) << "pull dense param to table " << table_id
<< " from 0' trainer done";
}
}
std::this_thread::sleep_for(
std::chrono::milliseconds(100 + trainer_id_ * 10));
BarrierWithTable(1);
return;
}
......
......@@ -19,6 +19,10 @@
#include "paddle/fluid/platform/monitor.h"
#include "paddle/fluid/platform/timer.h"
#ifdef PADDLE_WITH_PSCORE
#include "paddle/fluid/distributed/fleet.h"
#endif
#if defined _WIN32 || defined __APPLE__
#else
#define _LINUX
......@@ -208,13 +212,17 @@ void DatasetImpl<T>::CreateChannel() {
// if sent message between workers, should first call this function
template <typename T>
void DatasetImpl<T>::RegisterClientToClientMsgHandler() {
auto fleet_ptr = FleetWrapper::GetInstance();
VLOG(3) << "RegisterClientToClientMsgHandler";
#ifdef PADDLE_WITH_PSCORE
auto fleet_ptr = distributed::FleetWrapper::GetInstance();
#else
auto fleet_ptr = framework::FleetWrapper::GetInstance();
#endif
VLOG(1) << "RegisterClientToClientMsgHandler";
fleet_ptr->RegisterClientToClientMsgHandler(
0, [this](int msg_type, int client_id, const std::string& msg) -> int {
return this->ReceiveFromClient(msg_type, client_id, msg);
});
VLOG(3) << "RegisterClientToClientMsgHandler done";
VLOG(1) << "RegisterClientToClientMsgHandler done";
}
static void compute_left_batch_num(const int ins_num, const int thread_num,
std::vector<std::pair<int, int>>* offset,
......@@ -523,7 +531,7 @@ void DatasetImpl<T>::LocalShuffle() {
VLOG(3) << "DatasetImpl<T>::LocalShuffle() end, no data to shuffle";
return;
}
auto fleet_ptr = FleetWrapper::GetInstance();
auto fleet_ptr = framework::FleetWrapper::GetInstance();
input_channel_->Close();
std::vector<T> data;
input_channel_->ReadAll(data);
......@@ -540,11 +548,14 @@ void DatasetImpl<T>::LocalShuffle() {
}
void MultiSlotDataset::GlobalShuffle(int thread_num) {
#ifdef PADDLE_WITH_PSLIB
VLOG(3) << "MultiSlotDataset::GlobalShuffle() begin";
platform::Timer timeline;
timeline.Start();
auto fleet_ptr = FleetWrapper::GetInstance();
#ifdef PADDLE_WITH_PSCORE
auto fleet_ptr = distributed::FleetWrapper::GetInstance();
#else
auto fleet_ptr = framework::FleetWrapper::GetInstance();
#endif
if (!input_channel_ || input_channel_->Size() == 0) {
VLOG(3) << "MultiSlotDataset::GlobalShuffle() end, no data to shuffle";
......@@ -576,7 +587,12 @@ void MultiSlotDataset::GlobalShuffle(int thread_num) {
};
auto global_shuffle_func = [this, get_client_id]() {
auto fleet_ptr = FleetWrapper::GetInstance();
#ifdef PADDLE_WITH_PSCORE
auto fleet_ptr = distributed::FleetWrapper::GetInstance();
#else
auto fleet_ptr = framework::FleetWrapper::GetInstance();
#endif
// auto fleet_ptr = framework::FleetWrapper::GetInstance();
std::vector<Record> data;
while (this->input_channel_->Read(data)) {
std::vector<paddle::framework::BinaryArchive> ars(this->trainer_num_);
......@@ -633,7 +649,6 @@ void MultiSlotDataset::GlobalShuffle(int thread_num) {
timeline.Pause();
VLOG(3) << "DatasetImpl<T>::GlobalShuffle() end, cost time="
<< timeline.ElapsedSec() << " seconds";
#endif
}
template <typename T>
......@@ -936,7 +951,7 @@ int MultiSlotDataset::ReceiveFromClient(int msg_type, int client_id,
}
CHECK(ar.Cursor() == ar.Finish());
auto fleet_ptr = FleetWrapper::GetInstance();
auto fleet_ptr = framework::FleetWrapper::GetInstance();
// not use random because it doesn't perform well here.
// to make sure each channel get data equally, we just put data to
// channel one by one.
......@@ -976,7 +991,7 @@ void MultiSlotDataset::DynamicAdjustReadersNum(int thread_num) {
void MultiSlotDataset::PostprocessInstance() {
// divide pv instance, and merge to input_channel_
if (enable_pv_merge_) {
auto fleet_ptr = FleetWrapper::GetInstance();
auto fleet_ptr = framework::FleetWrapper::GetInstance();
std::shuffle(input_records_.begin(), input_records_.end(),
fleet_ptr->LocalRandomEngine());
input_channel_->Open();
......@@ -1014,7 +1029,7 @@ void MultiSlotDataset::PreprocessInstance() {
if (!enable_pv_merge_) { // means to use Record
this->LocalShuffle();
} else { // means to use Pv
auto fleet_ptr = FleetWrapper::GetInstance();
auto fleet_ptr = framework::FleetWrapper::GetInstance();
input_channel_->Close();
std::vector<PvInstance> pv_data;
input_channel_->ReadAll(input_records_);
......@@ -1073,7 +1088,7 @@ void MultiSlotDataset::GenerateLocalTablesUnlock(int table_id, int feadim,
}
CHECK(multi_output_channel_.size() != 0); // NOLINT
auto fleet_ptr_ = FleetWrapper::GetInstance();
auto fleet_ptr_ = framework::FleetWrapper::GetInstance();
std::vector<std::unordered_map<uint64_t, std::vector<float>>>&
local_map_tables = fleet_ptr_->GetLocalTable();
local_map_tables.resize(shard_num);
......@@ -1315,7 +1330,7 @@ void MultiSlotDataset::MergeByInsId() {
LOG(WARNING) << "total drop ins num: " << drop_ins_num;
results.shrink_to_fit();
auto fleet_ptr = FleetWrapper::GetInstance();
auto fleet_ptr = framework::FleetWrapper::GetInstance();
std::shuffle(results.begin(), results.end(), fleet_ptr->LocalRandomEngine());
channel_data->Open();
channel_data->Write(std::move(results));
......
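Across the data_set.cc hunks above, every bare `FleetWrapper::GetInstance()` call is now qualified explicitly: `distributed::FleetWrapper` when `PADDLE_WITH_PSCORE` is defined, `framework::FleetWrapper` otherwise. A minimal, standalone sketch of that compile-time selection pattern follows; the two namespaces and the `Name()` method here are placeholders for illustration only.

```cpp
#include <iostream>
#include <memory>

// Placeholder backends standing in for framework::FleetWrapper and
// distributed::FleetWrapper; both expose the same GetInstance() entry point,
// so call sites stay identical regardless of which backend is compiled in.
namespace framework {
struct FleetWrapper {
  static std::shared_ptr<FleetWrapper> GetInstance() {
    static auto ptr = std::make_shared<FleetWrapper>();
    return ptr;
  }
  const char* Name() const { return "framework::FleetWrapper"; }
};
}  // namespace framework

namespace distributed {
struct FleetWrapper {
  static std::shared_ptr<FleetWrapper> GetInstance() {
    static auto ptr = std::make_shared<FleetWrapper>();
    return ptr;
  }
  const char* Name() const { return "distributed::FleetWrapper"; }
};
}  // namespace distributed

int main() {
  // The same few lines the diff repeats at each call site.
#ifdef PADDLE_WITH_PSCORE
  auto fleet_ptr = distributed::FleetWrapper::GetInstance();
#else
  auto fleet_ptr = framework::FleetWrapper::GetInstance();
#endif
  std::cout << fleet_ptr->Name() << std::endl;
  return 0;
}
```

Compiling this sketch with `-DPADDLE_WITH_PSCORE` switches it to the distributed backend without touching the call site, which is the point of keeping both wrappers behind an identical `GetInstance()` interface.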
......@@ -15,6 +15,7 @@ limitations under the License. */
#include <string>
#include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/framework/trainer.h"
#include "paddle/fluid/platform/lodtensor_printer.h"
#if defined PADDLE_WITH_PSCORE
#include "paddle/fluid/distributed/service/communicator.h"
......@@ -153,7 +154,20 @@ void MultiTrainer::InitOtherEnv(const ProgramDesc& main_program) {
if (need_dump_field_ || need_dump_param_) {
InitDumpEnv();
}
VLOG(3) << "init other env done.";
#ifdef PADDLE_WITH_PSCORE
// pull dense param first
auto communicator = paddle::distributed::Communicator::GetInstance();
// for unittest which call train_from_dataset but does not call
// fleet.init_worker() first
if (communicator == nullptr) {
VLOG(0) << "MultiTrainer::InitOtherEnv Communicator is null!";
} else {
auto& recv_ctx = communicator->GetRecvCtxMap();
communicator->PullDense(recv_ctx);
VLOG(3) << "init other env done.";
}
#endif
}
Scope* MultiTrainer::GetWorkerScope(int thread_id) {
......@@ -253,6 +267,17 @@ void MultiTrainer::Finalize() {
#ifdef PADDLE_WITH_HETERPS
MergeDenseParam();
#endif
#if defined PADDLE_WITH_PSCORE
auto communicator = paddle::distributed::Communicator::GetInstance();
// for unittest which does not call fleet.init_worker() first
if (communicator == nullptr) {
VLOG(0) << "MultiTrainer::Finalize communicator is null!";
} else {
communicator->_worker_ptr->flush();
VLOG(1) << "MultiTrainer::Finalize ps client flush done";
}
#endif
root_scope_->DropKids();
}
......
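The multi_trainer.cc hunks above add two guarded interactions with the communicator: pulling the dense parameters in `InitOtherEnv` (via `GetRecvCtxMap()` / `PullDense()`) and flushing the PS client in `Finalize`, both skipped with a log line when the communicator is null. Below is a standalone sketch of that flow with a hypothetical `DenseClient` in place of the real `paddle::distributed::Communicator`; names and variable strings are illustrative only.

```cpp
#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for the communicator's recv context: table id -> the
// dense variable names that should be pulled from the parameter server.
using RecvCtxMap = std::map<uint64_t, std::vector<std::string>>;

class DenseClient {
 public:
  static DenseClient* GetInstance() { return instance_; }
  static void Init(RecvCtxMap ctx) {
    static DenseClient c;
    c.recv_ctx_ = std::move(ctx);
    instance_ = &c;
  }
  const RecvCtxMap& GetRecvCtxMap() const { return recv_ctx_; }
  void PullDense(const RecvCtxMap& ctx) {
    for (const auto& kv : ctx) {
      // the real code issues RPCs to the parameter server here
      std::cout << "pull " << kv.second.size() << " dense vars from table "
                << kv.first << "\n";
    }
  }
  void Flush() { std::cout << "ps client flush done\n"; }

 private:
  RecvCtxMap recv_ctx_;
  static DenseClient* instance_;
};
DenseClient* DenseClient::instance_ = nullptr;

// Mirrors MultiTrainer::InitOtherEnv: skip quietly when the client was never
// initialized (e.g. a unittest that skips fleet.init_worker()).
void InitOtherEnv() {
  auto* client = DenseClient::GetInstance();
  if (client == nullptr) {
    std::cerr << "client is null, skip dense pull\n";
    return;
  }
  client->PullDense(client->GetRecvCtxMap());
}

// Mirrors MultiTrainer::Finalize: flush only when a client exists.
void Finalize() {
  auto* client = DenseClient::GetInstance();
  if (client == nullptr) {
    std::cerr << "client is null, skip flush\n";
    return;
  }
  client->Flush();
}

int main() {
  InitOtherEnv();  // skipped: no client yet
  DenseClient::Init({{0, {"fc_0.w_0", "fc_0.b_0"}}});
  InitOtherEnv();  // pulls the two dense vars of table 0
  Finalize();      // flushes
  return 0;
}
```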
......@@ -577,8 +577,12 @@ class TheOnePSRuntime(RuntimeBase):
else:
init_params = dense_map
import paddle.distributed.fleet as fleet
if not is_test:
self._communicator.init_params(init_params)
fleet.util.barrier()
self._communicator.pull_dense(init_params)
fleet.util.barrier()
if not self._communicator.is_running():
self._communicator.start()
......