Unverified commit 576236a0, authored by Sing_chan, committed by GitHub

format all files in fluid using new config (#43776)

Parent c40858c3

Too many changes to display. To preserve performance, only 1000 of 1000+ files are shown.
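Every hunk below is formatting-only: each changed region shows the previous layout followed by the clang-format output under the new config, which breaks a call or declaration so that each argument or parameter sits on its own line once the line no longer fits within the column limit. A minimal .clang-format sketch that reproduces this layout (assumed from the diff's appearance; the options actually landed in #43776 may differ):

# Assumed sketch only -- not the verbatim config introduced by this commit.
BasedOnStyle: Google
ColumnLimit: 80
BinPackArguments: false    # when a call wraps, put each argument on its own line
BinPackParameters: false   # likewise, one parameter per line in declarations

Running clang-format -i over a file with such a config reflows, for example, PADDLE_ENFORCE_EQ(ret, 0, ...) into the stacked one-argument-per-line form that appears in hunk after hunk below.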
......@@ -41,10 +41,10 @@ std::string GetKeyFromPlaces(const std::vector<Place>& places) {
}
bool CheckTensorsInCudaPlace(const std::vector<phi::DenseTensor>& tensors) {
return std::all_of(tensors.cbegin(), tensors.cend(),
[&](const phi::DenseTensor& t) {
return platform::is_gpu_place(t.place());
});
return std::all_of(
tensors.cbegin(), tensors.cend(), [&](const phi::DenseTensor& t) {
return platform::is_gpu_place(t.place());
});
}
} // namespace distributed
......
......@@ -28,7 +28,8 @@ HcclReduceOp ToHCCLRedType(ReduceOp reduction) {
};
auto it = red_type.find(reduction);
PADDLE_ENFORCE_EQ(
it != red_type.end(), true,
it != red_type.end(),
true,
platform::errors::InvalidArgument("Invalid hccl reduction. "
"Must be Min | Max | Prod | Sum"));
return it->second;
......
......@@ -67,11 +67,13 @@ class NPUEventManager {
if (!is_created_) {
CreateEvent(device_index);
}
PADDLE_ENFORCE_EQ(device_index, device_index_,
PADDLE_ENFORCE_EQ(device_index,
device_index_,
platform::errors::PreconditionNotMet(
"NPUDeviceContext's device %d does not match"
"Event's device %d",
device_index, device_index_));
device_index,
device_index_));
platform::NPUDeviceGuard guard(device_index_);
platform::NPUEventRecord(event_, ctx.stream());
......@@ -89,11 +91,13 @@ class NPUEventManager {
void Block(const paddle::platform::NPUDeviceContext& ctx) const {
if (is_created_) {
auto device_index = ctx.GetPlace().device;
PADDLE_ENFORCE_EQ(device_index, device_index_,
PADDLE_ENFORCE_EQ(device_index,
device_index_,
platform::errors::PreconditionNotMet(
"CUDADeviceContext's device %d does not match"
"Event's device %d",
device_index, device_index_));
device_index,
device_index_));
platform::NPUDeviceGuard guard(device_index_);
platform::NPUStreamWaitEvent(ctx.stream(), event_);
}
......@@ -126,12 +130,13 @@ class HCCLCommManager {
}
}
static std::shared_ptr<HCCLCommManager> Create(int num_ranks, int rank,
static std::shared_ptr<HCCLCommManager> Create(int num_ranks,
int rank,
HcclRootInfo* comm_id,
HcclComm hccl_comm) {
auto hccl_manager = std::make_shared<HCCLCommManager>();
auto ret = platform::dynload::HcclCommInitRootInfo(num_ranks, comm_id, rank,
&hccl_comm);
auto ret = platform::dynload::HcclCommInitRootInfo(
num_ranks, comm_id, rank, &hccl_comm);
using __NPU_STATUS_TYPE__ = decltype(ret);
constexpr auto __success_type__ =
platform::details::NPUStatusType<__NPU_STATUS_TYPE__>::kSuccess;
......
......@@ -27,7 +27,8 @@ ncclRedOp_t ToNCCLRedType(ReduceOp reduction) {
{ReduceOp::PRODUCT, ncclProd},
};
auto it = red_type.find(reduction);
PADDLE_ENFORCE_EQ(it != red_type.end(), true,
PADDLE_ENFORCE_EQ(it != red_type.end(),
true,
platform::errors::InvalidArgument(
"Invalid nccl reduction. Must be ncclMin | ncclMax | "
"ncclProd | ncclSum"));
......
......@@ -47,14 +47,16 @@
namespace paddle {
namespace distributed {
#define NCCLCHECK(cmd) \
do { \
ncclResult_t r = cmd; \
if (r != ncclSuccess) { \
printf("Failed, NCCL error %s:%d '%s'\n", __FILE__, __LINE__, \
platform::dynload::ncclGetErrorString(r)); \
exit(EXIT_FAILURE); \
} \
#define NCCLCHECK(cmd) \
do { \
ncclResult_t r = cmd; \
if (r != ncclSuccess) { \
printf("Failed, NCCL error %s:%d '%s'\n", \
__FILE__, \
__LINE__, \
platform::dynload::ncclGetErrorString(r)); \
exit(EXIT_FAILURE); \
} \
} while (0)
// NOTE(shenliang03): EventManager are movable not copyable CudaEvent wrapper.
......@@ -107,11 +109,13 @@ class EventManager {
if (!is_created_) {
CreateEvent(device_index);
}
PADDLE_ENFORCE_EQ(device_index, device_index_,
PADDLE_ENFORCE_EQ(device_index,
device_index_,
platform::errors::PreconditionNotMet(
"CUDADeviceContext's device %d does not match"
"Event's device %d",
device_index, device_index_));
device_index,
device_index_));
platform::CUDADeviceGuard guard(device_index_);
#ifdef PADDLE_WITH_CUDA
......@@ -156,11 +160,13 @@ class EventManager {
void Block(const paddle::platform::CUDADeviceContext& ctx) const {
if (is_created_) {
auto device_index = ctx.GetPlace().device;
PADDLE_ENFORCE_EQ(device_index, device_index_,
PADDLE_ENFORCE_EQ(device_index,
device_index_,
platform::errors::PreconditionNotMet(
"CUDADeviceContext's device %d does not match"
"Event's device %d",
device_index, device_index_));
device_index,
device_index_));
platform::CUDADeviceGuard guard(device_index_);
#ifdef PADDLE_WITH_HIP
......@@ -213,11 +219,12 @@ class NCCLCommManager {
}
}
static std::shared_ptr<NCCLCommManager> Create(int num_ranks, int rank,
static std::shared_ptr<NCCLCommManager> Create(int num_ranks,
int rank,
ncclUniqueId comm_id) {
auto nccl_manager = std::make_shared<NCCLCommManager>();
NCCLCHECK(platform::dynload::ncclCommInitRank(&(nccl_manager->nccl_comm_),
num_ranks, comm_id, rank));
NCCLCHECK(platform::dynload::ncclCommInitRank(
&(nccl_manager->nccl_comm_), num_ranks, comm_id, rank));
nccl_manager->nccl_id_ = comm_id;
nccl_manager->rank_ = rank;
......
......@@ -35,7 +35,9 @@ bool ProcessGroup::Task::Wait(std::chrono::milliseconds timeout) {
void ProcessGroup::Task::Synchronize() {}
ProcessGroup::ProcessGroup(int rank, int size, const platform::Place& place,
ProcessGroup::ProcessGroup(int rank,
int size,
const platform::Place& place,
int gid)
: rank_(rank), size_(size), place_(place), gid_(gid) {
if (gid != IGNORE_ID) {
......
......@@ -53,7 +53,8 @@ class ProcessGroup {
public:
class Task {
public:
Task(int rank, const std::vector<phi::DenseTensor>& inputTensors,
Task(int rank,
const std::vector<phi::DenseTensor>& inputTensors,
CommType opType = CommType::UNKNOWN);
virtual ~Task();
......@@ -68,7 +69,9 @@ class ProcessGroup {
bool is_completed_ = false;
};
explicit ProcessGroup(int rank, int size, const platform::Place& place,
explicit ProcessGroup(int rank,
int size,
const platform::Place& place,
int gid);
virtual ~ProcessGroup() {}
......@@ -113,7 +116,8 @@ class ProcessGroup {
}
virtual std::shared_ptr<ProcessGroup::Task> Send_Partial(phi::DenseTensor&,
int, int,
int,
int,
int) { // NOLINT
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support send", GetBackendName()));
......
......@@ -165,8 +165,11 @@ ProcessGroupGloo::GlooTask::GlooTask(
: ProcessGroup::Task(rank, inputs, comm_type) {}
ProcessGroupGloo::ProcessGroupGloo(
const std::shared_ptr<distributed::Store>& store, int rank, int world_size,
const platform::Place& place, int gid,
const std::shared_ptr<distributed::Store>& store,
int rank,
int world_size,
const platform::Place& place,
int gid,
const std::shared_ptr<GlooOptions> options)
: ProcessGroup(rank, world_size, place, gid),
_tag(0),
......@@ -182,7 +185,9 @@ class BroadcastGlooTask : public ProcessGroupGloo::GlooTask {
BroadcastGlooTask(const std::shared_ptr<gloo::Context>& context,
std::vector<phi::DenseTensor>& inputs, // NOLINT
std::vector<phi::DenseTensor>& outputs, // NOLINT
int rank, int root, uint32_t tag)
int rank,
int root,
uint32_t tag)
: ProcessGroupGloo::GlooTask(rank, inputs, CommType::BROADCAST),
_context(context),
_root(root),
......@@ -214,23 +219,26 @@ class BroadcastGlooTask : public ProcessGroupGloo::GlooTask {
std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Broadcast(
std::vector<phi::DenseTensor>& inputs,
std::vector<phi::DenseTensor>& outputs, const BroadcastOptions& opts) {
std::vector<phi::DenseTensor>& outputs,
const BroadcastOptions& opts) {
auto root = opts.source_rank;
std::unique_ptr<BroadcastGlooTask> task;
auto tag = next_tag();
auto context = get_context();
task = std::make_unique<BroadcastGlooTask>(context, inputs, outputs, rank_,
root, tag);
task = std::make_unique<BroadcastGlooTask>(
context, inputs, outputs, rank_, root, tag);
task->Run();
return task;
}
class AllreduceGlooTask : public ProcessGroupGloo::GlooTask {
public:
AllreduceGlooTask(int rank, const std::shared_ptr<gloo::Context>& context,
AllreduceGlooTask(int rank,
const std::shared_ptr<gloo::Context>& context,
std::vector<phi::DenseTensor>& inputs, // NOLINT
std::vector<phi::DenseTensor>& outputs, // NOLINT
ReduceOp reduce_op, uint32_t tag)
ReduceOp reduce_op,
uint32_t tag)
: ProcessGroupGloo::GlooTask(rank, inputs, CommType::ALLREDUCE),
_context(context),
_inputs(inputs),
......@@ -274,12 +282,13 @@ class AllreduceGlooTask : public ProcessGroupGloo::GlooTask {
std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::AllReduce(
std::vector<phi::DenseTensor>& inputs,
std::vector<phi::DenseTensor>& outputs, const AllreduceOptions& opts) {
std::vector<phi::DenseTensor>& outputs,
const AllreduceOptions& opts) {
auto tag = next_tag();
std::shared_ptr<GlooTask> task;
auto context = get_context();
task = std::make_shared<AllreduceGlooTask>(rank_, context, inputs, outputs,
opts.reduce_op, tag);
task = std::make_shared<AllreduceGlooTask>(
rank_, context, inputs, outputs, opts.reduce_op, tag);
task->Run();
return task;
}
......@@ -287,8 +296,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::AllReduce(
class BarrierGlooTask : public ProcessGroupGloo::GlooTask {
public:
BarrierGlooTask(int rank, const std::shared_ptr<gloo::Context>& context)
: ProcessGroupGloo::GlooTask(rank, std::vector<phi::DenseTensor>{},
CommType::BARRIER),
: ProcessGroupGloo::GlooTask(
rank, std::vector<phi::DenseTensor>{}, CommType::BARRIER),
_context(context) {}
void Run() override { _do_barrier(); }
......@@ -313,7 +322,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Barrier(
class AllgatherGlooTask : public ProcessGroupGloo::GlooTask {
public:
AllgatherGlooTask(int rank, const std::shared_ptr<gloo::Context>& context,
AllgatherGlooTask(int rank,
const std::shared_ptr<gloo::Context>& context,
std::vector<phi::DenseTensor>& inputs, // NOLINT
std::vector<phi::DenseTensor>& outputs, // NOLINT
uint32_t tag)
......@@ -348,18 +358,21 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::AllGather(
std::shared_ptr<AllgatherGlooTask> task;
auto tag = next_tag();
auto context = get_context();
task = std::make_shared<AllgatherGlooTask>(rank_, context, in_tensors,
out_tensors, tag);
task = std::make_shared<AllgatherGlooTask>(
rank_, context, in_tensors, out_tensors, tag);
task->Run();
return task;
}
class ReduceGlooTask : public ProcessGroupGloo::GlooTask {
public:
ReduceGlooTask(int rank, const std::shared_ptr<gloo::Context>& context,
ReduceGlooTask(int rank,
const std::shared_ptr<gloo::Context>& context,
std::vector<phi::DenseTensor>& inputs, // NOLINT
std::vector<phi::DenseTensor>& outputs, // NOLINT
ReduceOp reduce_op, int dst, uint32_t tag)
ReduceOp reduce_op,
int dst,
uint32_t tag)
: ProcessGroupGloo::GlooTask(rank, inputs, CommType::REDUCE),
_context(context),
_inputs(inputs),
......@@ -407,22 +420,26 @@ class ReduceGlooTask : public ProcessGroupGloo::GlooTask {
std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Reduce(
std::vector<phi::DenseTensor>& inputs,
std::vector<phi::DenseTensor>& outputs, const ReduceOptions& opts) {
std::vector<phi::DenseTensor>& outputs,
const ReduceOptions& opts) {
std::shared_ptr<ReduceGlooTask> task;
auto tag = next_tag();
auto context = get_context();
task = std::make_shared<ReduceGlooTask>(rank_, context, inputs, outputs,
opts.reduce_op, opts.root_rank, tag);
task = std::make_shared<ReduceGlooTask>(
rank_, context, inputs, outputs, opts.reduce_op, opts.root_rank, tag);
task->Run();
return task;
}
class ScatterGlooTask : public ProcessGroupGloo::GlooTask {
public:
ScatterGlooTask(int rank, const std::shared_ptr<gloo::Context>& context,
ScatterGlooTask(int rank,
const std::shared_ptr<gloo::Context>& context,
std::vector<phi::DenseTensor>& inputs, // NOLINT
std::vector<phi::DenseTensor>& outputs, // NOLINT
int src, int size, uint32_t tag)
int src,
int size,
uint32_t tag)
: ProcessGroupGloo::GlooTask(rank, inputs, CommType::SCATTER),
_context(context),
_inputs(inputs),
......@@ -458,7 +475,8 @@ class ScatterGlooTask : public ProcessGroupGloo::GlooTask {
std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Scatter(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors, const ScatterOptions& opts) {
std::vector<phi::DenseTensor>& out_tensors,
const ScatterOptions& opts) {
std::shared_ptr<ScatterGlooTask> task;
auto tag = next_tag();
auto context = get_context();
......@@ -487,7 +505,8 @@ ProcessGroupGloo::createDefaultDevice() {
std::array<char, HOST_NAME_MAX> hostname{};
auto ret = ::gethostname(hostname.data(), HOST_NAME_MAX);
PADDLE_ENFORCE_EQ(
ret, 0,
ret,
0,
platform::errors::Fatal("Get hostname error for createDefaultDevice."));
::addrinfo* result;
result = tcputils::get_addr_info(hostname.data(), "", 0, AF_UNSPEC);
......
......@@ -101,8 +101,11 @@ class ProcessGroupGloo : public ProcessGroup {
};
explicit ProcessGroupGloo(
const std::shared_ptr<paddle::distributed::Store>& store, int rank,
int world_size, const platform::Place& place, int gid,
const std::shared_ptr<paddle::distributed::Store>& store,
int rank,
int world_size,
const platform::Place& place,
int gid,
std::shared_ptr<GlooOptions> options);
~ProcessGroupGloo() = default;
......
......@@ -45,14 +45,18 @@ void SyncDefaultStream(
}
std::shared_ptr<ProcessGroupHCCL::HCCLTask> ProcessGroupHCCL::CreateTask(
std::vector<Place> places, int rank, CommType comm_type,
std::vector<Place> places,
int rank,
CommType comm_type,
const std::vector<phi::DenseTensor>& inputs) {
return std::make_shared<ProcessGroupHCCL::HCCLTask>(places, rank, comm_type,
inputs);
return std::make_shared<ProcessGroupHCCL::HCCLTask>(
places, rank, comm_type, inputs);
}
ProcessGroupHCCL::HCCLTask::HCCLTask(
const std::vector<Place>& places, int rank, CommType CommType,
const std::vector<Place>& places,
int rank,
CommType CommType,
const std::vector<phi::DenseTensor>& inputs)
: Task(rank, inputs, CommType), places_(places) {
control_events_.resize(places.size());
......@@ -99,8 +103,10 @@ bool ProcessGroupHCCL::HCCLTask::Wait(std::chrono::milliseconds timeout) {
void ProcessGroupHCCL::HCCLTask::Synchronize() { Wait(kWaitTimeout); }
ProcessGroupHCCL::ProcessGroupHCCL(const std::shared_ptr<Store>& store,
int rank, int size,
const platform::Place& place, int gid)
int rank,
int size,
const platform::Place& place,
int gid)
: ProcessGroup(rank, size, place, gid), store_(store) {
platform::SetNPUDeviceId(place_.device);
}
......@@ -127,7 +133,8 @@ void ProcessGroupHCCL::BroadcastUniqueHCCLID(
// create HCCLManager cache for places_key
void ProcessGroupHCCL::CreateHCCLManagerCache(
const std::string& places_key, const std::vector<Place>& places) {
PADDLE_ENFORCE_EQ(places_key.empty(), false,
PADDLE_ENFORCE_EQ(places_key.empty(),
false,
platform::errors::PreconditionNotMet(
"Not able to create/get the HCCL Communicator since "
"the NPU place are not known"));
......@@ -155,8 +162,8 @@ void ProcessGroupHCCL::CreateHCCLManagerCache(
std::unique_ptr<HcclComm[]> comms(new HcclComm[places.size()]);
for (size_t i = 0; i < places.size(); ++i) {
platform::NPUDeviceGuard guard(places[i].GetDeviceId());
hccl_comms[i] = HCCLCommManager::Create(GetSize(), GetRank(), &hccl_id,
comms.get() + i);
hccl_comms[i] = HCCLCommManager::Create(
GetSize(), GetRank(), &hccl_id, comms.get() + i);
dev_ctx[i].reset(new NPUDeviceContext(places[i]));
}
......@@ -172,7 +179,9 @@ void ProcessGroupHCCL::CreateHCCLManagerCache(
template <typename Fn>
std::shared_ptr<ProcessGroup::Task> ProcessGroupHCCL::Collective(
std::vector<phi::DenseTensor>& inputs,
std::vector<phi::DenseTensor>& outputs, Fn fn, CommType op_type) {
std::vector<phi::DenseTensor>& outputs,
Fn fn,
CommType op_type) {
const auto places = GetPlaceList(inputs);
const auto key = GetKeyFromPlaces(places);
......@@ -218,13 +227,20 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupHCCL::AllReduce(
std::vector<phi::DenseTensor>& out_tensors, // NOLINT
const AllreduceOptions& opts) {
return Collective(
in_tensors, out_tensors,
[&](phi::DenseTensor& input, phi::DenseTensor& output, HcclComm comm,
in_tensors,
out_tensors,
[&](phi::DenseTensor& input,
phi::DenseTensor& output,
HcclComm comm,
const aclrtStream& stream) {
return platform::dynload::HcclAllReduce(
input.data(), output.data(), input.numel(),
input.data(),
output.data(),
input.numel(),
platform::ToHCCLDataType(input.dtype()),
ToHCCLRedType(opts.reduce_op), comm, stream);
ToHCCLRedType(opts.reduce_op),
comm,
stream);
},
CommType::ALLREDUCE);
}
......@@ -239,18 +255,29 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupHCCL::Broadcast(
// CudaPlace."));
return Collective(
in_tensors, out_tensors,
[&](phi::DenseTensor& input, phi::DenseTensor& output, HcclComm comm,
in_tensors,
out_tensors,
[&](phi::DenseTensor& input,
phi::DenseTensor& output,
HcclComm comm,
const aclrtStream& stream) {
int root = opts.source_rank * in_tensors.size() + opts.source_root;
if (rank_ == root) {
return platform::dynload::HcclBroadcast(
input.data(), input.numel(),
platform::ToHCCLDataType(input.dtype()), root, comm, stream);
input.data(),
input.numel(),
platform::ToHCCLDataType(input.dtype()),
root,
comm,
stream);
} else {
return platform::dynload::HcclBroadcast(
output.data(), output.numel(),
platform::ToHCCLDataType(output.dtype()), root, comm, stream);
output.data(),
output.numel(),
platform::ToHCCLDataType(output.dtype()),
root,
comm,
stream);
}
},
CommType::BROADCAST);
......
......@@ -44,7 +44,9 @@ class ProcessGroupHCCL : public ProcessGroup {
class HCCLTask : public ProcessGroup::Task,
public std::enable_shared_from_this<HCCLTask> {
public:
HCCLTask(const std::vector<Place>& places, int rank, CommType CommType,
HCCLTask(const std::vector<Place>& places,
int rank,
CommType CommType,
const std::vector<phi::DenseTensor>& inputs);
bool IsCompleted();
......@@ -69,8 +71,11 @@ class ProcessGroupHCCL : public ProcessGroup {
private:
};
ProcessGroupHCCL(const std::shared_ptr<Store>& store, int rank, int size,
const platform::Place& place, int gid);
ProcessGroupHCCL(const std::shared_ptr<Store>& store,
int rank,
int size,
const platform::Place& place,
int gid);
const std::string GetBackendName() const override {
return std::string(HCCL_BACKEND_NAME);
......@@ -88,7 +93,9 @@ class ProcessGroupHCCL : public ProcessGroup {
protected:
virtual std::shared_ptr<ProcessGroupHCCL::HCCLTask> CreateTask(
std::vector<Place> places, int rank, CommType opType,
std::vector<Place> places,
int rank,
CommType opType,
const std::vector<phi::DenseTensor>& inputs);
std::shared_ptr<Store> store_;
......@@ -107,7 +114,8 @@ class ProcessGroupHCCL : public ProcessGroup {
std::set<int> used_place_ids_;
private:
void BcastHCCLId(std::vector<HcclRootInfo>& hccl_ids, int root, // NOLINT
void BcastHCCLId(std::vector<HcclRootInfo>& hccl_ids,
int root, // NOLINT
int server_fd);
void BroadcastUniqueHCCLID(std::vector<HcclRootInfo>& hccl_ids); // NOLINT
......@@ -116,7 +124,8 @@ class ProcessGroupHCCL : public ProcessGroup {
std::shared_ptr<ProcessGroup::Task> Collective(
std::vector<phi::DenseTensor>& inputs, // NOLINT
std::vector<phi::DenseTensor>& outputs, // NOLINT
Fn fn, CommType op_type);
Fn fn,
CommType op_type);
void CreateHCCLManagerCache(const std::string& places_key,
const std::vector<Place>& places);
......
......@@ -32,8 +32,8 @@ int ProcessGroupHeter::recv_count = 0;
std::shared_ptr<ProcessGroupHeter::HeterTask> ProcessGroupHeter::CreateTask(
int rank, CommType comm_type, const std::vector<phi::DenseTensor>& inputs) {
return std::make_shared<ProcessGroupHeter::HeterTask>(rank, comm_type,
inputs);
return std::make_shared<ProcessGroupHeter::HeterTask>(
rank, comm_type, inputs);
}
ProcessGroupHeter::HeterTask::HeterTask(
......@@ -49,11 +49,19 @@ bool ProcessGroupHeter::HeterTask::Wait(std::chrono::milliseconds timeout) {
return true;
}
ProcessGroupHeter::ProcessGroupHeter(
const std::shared_ptr<Store>& store, int rank, int size,
const platform::Place& place, int gid, int local_rank, int local_size,
int gloo_rank, int gloo_size, bool with_switch, std::string switch_endpoint,
int src_rank, int dst_rank)
ProcessGroupHeter::ProcessGroupHeter(const std::shared_ptr<Store>& store,
int rank,
int size,
const platform::Place& place,
int gid,
int local_rank,
int local_size,
int gloo_rank,
int gloo_size,
bool with_switch,
std::string switch_endpoint,
int src_rank,
int dst_rank)
: ProcessGroup(rank, size, place, gid),
store_(store),
local_rank_(local_rank),
......@@ -66,11 +74,11 @@ ProcessGroupHeter::ProcessGroupHeter(
dst_rank_(dst_rank) {
return;
#if defined(PADDLE_WITH_NCCL)
inner_pg_ = std::make_shared<ProcessGroupNCCL>(store, local_rank, local_size,
place_, IGNORE_ID);
inner_pg_ = std::make_shared<ProcessGroupNCCL>(
store, local_rank, local_size, place_, IGNORE_ID);
#elif defined(PADDLE_WITH_ASCEND_CL)
inner_pg_ = std::make_shared<ProcessGroupHCCL>(store, local_rank, local_size,
place_, IGNORE_ID);
inner_pg_ = std::make_shared<ProcessGroupHCCL>(
store, local_rank, local_size, place_, IGNORE_ID);
#else
PADDLE_THROW(platform::errors::Fatal(
"ProcessGroupHeter only supports NCCL and HCCL now.");
......@@ -94,13 +102,16 @@ static void _do_add(T* dst, T* src, size_t size) {
std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::AllReduce(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors, const AllreduceOptions& opts) {
std::vector<phi::DenseTensor>& out_tensors,
const AllreduceOptions& opts) {
#if defined(PADDLE_WITH_NCCL)
PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(in_tensors), true,
CheckTensorsInCudaPlace(in_tensors),
true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(out_tensors), true,
CheckTensorsInCudaPlace(out_tensors),
true,
platform::errors::InvalidArgument("All outputs should be in CudaPlace."));
#endif
......@@ -128,10 +139,14 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::AllReduce(
std::vector<int64_t> send_size;
send_size.push_back(dense_cpu_tensor.numel());
int ret = client_->Send(
gid_, {dense_cpu_tensor.name()}, send_size, dense_cpu_tensor.data(),
gid_,
{dense_cpu_tensor.name()},
send_size,
dense_cpu_tensor.data(),
dense_cpu_tensor.numel() *
framework::DataTypeSize(dense_cpu_tensor.dtype()));
PADDLE_ENFORCE_EQ(ret, 0,
PADDLE_ENFORCE_EQ(ret,
0,
platform::errors::PreconditionNotMet(
"Send to the switch module error."));
phi::DenseTensor cpu_tensor2;
......@@ -139,11 +154,15 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::AllReduce(
std::make_unique<paddle::experimental::DefaultAllocator>(
paddle::platform::CPUPlace())
.get(),
dense_cpu_tensor.dtype(), dense_cpu_tensor.numel());
dense_cpu_tensor.dtype(),
dense_cpu_tensor.numel());
ret = client_->Recv(
gid_, {dense_cpu_tensor.name()}, cpu_tensor2.data(),
gid_,
{dense_cpu_tensor.name()},
cpu_tensor2.data(),
cpu_tensor2.numel() * framework::DataTypeSize(cpu_tensor2.dtype()));
PADDLE_ENFORCE_EQ(ret, 0,
PADDLE_ENFORCE_EQ(ret,
0,
platform::errors::PreconditionNotMet(
"Recv from the switch module error."));
......@@ -192,13 +211,16 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::AllReduce(
std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::Broadcast(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors, const BroadcastOptions& opts) {
std::vector<phi::DenseTensor>& out_tensors,
const BroadcastOptions& opts) {
#if defined(PADDLE_WITH_NCCL)
PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(in_tensors), true,
CheckTensorsInCudaPlace(in_tensors),
true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(out_tensors), true,
CheckTensorsInCudaPlace(out_tensors),
true,
platform::errors::InvalidArgument("All outputs should be in CudaPlace."));
#endif
......@@ -226,19 +248,25 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::Broadcast(
std::vector<int64_t> send_size;
send_size.push_back(dense_cpu_tensor.numel());
int ret = client_->Send(
gid_, {dense_cpu_tensor.name()}, send_size,
gid_,
{dense_cpu_tensor.name()},
send_size,
dense_cpu_tensor.data(),
dense_cpu_tensor.numel() *
framework::DataTypeSize(dense_cpu_tensor.dtype()));
PADDLE_ENFORCE_EQ(ret, 0,
PADDLE_ENFORCE_EQ(ret,
0,
platform::errors::PreconditionNotMet(
"Send to the switch module error."));
} else {
int ret = client_->Recv(
gid_, {dense_cpu_tensor.name()}, dense_cpu_tensor.data(),
gid_,
{dense_cpu_tensor.name()},
dense_cpu_tensor.data(),
dense_cpu_tensor.numel() *
framework::DataTypeSize(dense_cpu_tensor.dtype()));
PADDLE_ENFORCE_EQ(ret, 0,
PADDLE_ENFORCE_EQ(ret,
0,
platform::errors::PreconditionNotMet(
"Receive from the switch module error."));
}
......@@ -261,7 +289,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::Broadcast(
std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::Send(
std::vector<phi::DenseTensor>& in_tensors, int peer) {
PADDLE_ENFORCE_EQ(
in_tensors.size(), 1,
in_tensors.size(),
1,
platform::errors::PreconditionNotMet(
"For each send operation, there can only be one tensor to send."));
// Copy Tensor to cpu
......@@ -269,7 +298,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::Send(
phi::DenseTensor cpu_tensor;
auto& gpu_tensor = in_tensors[0];
framework::TensorCopySync(gpu_tensor, platform::CPUPlace(), &cpu_tensor);
PADDLE_ENFORCE_EQ(with_switch_, true,
PADDLE_ENFORCE_EQ(with_switch_,
true,
platform::errors::PreconditionNotMet(
"Gloo does not support the send operation."));
auto end = std::chrono::high_resolution_clock::now();
......@@ -289,10 +319,11 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::Send(
std::string tensor_name = std::to_string(gid_) + "_id_" + std::to_string(id) +
std::string("_") + std::to_string(send_count++);
VLOG(2) << "tensor_name:" << tensor_name;
int ret = client_->Send(gid_, {tensor_name}, send_size, cpu_tensor.data(),
tensor_size);
int ret = client_->Send(
gid_, {tensor_name}, send_size, cpu_tensor.data(), tensor_size);
PADDLE_ENFORCE_EQ(
ret, 0,
ret,
0,
platform::errors::PreconditionNotMet("Send to the switch module error."));
return CreateTask(rank_, CommType::SEND, in_tensors);
}
......@@ -300,7 +331,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::Send(
std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::Recv(
std::vector<phi::DenseTensor>& out_tensors, int peer) {
PADDLE_ENFORCE_EQ(
out_tensors.size(), 1,
out_tensors.size(),
1,
platform::errors::PreconditionNotMet(
"For each rece operation, there can only be one tensor to receive."));
......@@ -311,7 +343,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::Recv(
cpu_tensor.set_layout(gpu_tensor.layout());
cpu_tensor.mutable_data(platform::CPUPlace(), gpu_tensor.dtype());
PADDLE_ENFORCE_EQ(with_switch_, true,
PADDLE_ENFORCE_EQ(with_switch_,
true,
platform::errors::PreconditionNotMet(
"Gloo does not support the send operation."));
// recv from switch
......@@ -323,9 +356,12 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::Recv(
VLOG(2) << "tensor_name: " << tensor_name;
auto start = std::chrono::high_resolution_clock::now();
int ret = client_->Recv(
gid_, {tensor_name}, cpu_tensor.data(),
gid_,
{tensor_name},
cpu_tensor.data(),
cpu_tensor.numel() * framework::DataTypeSize(cpu_tensor.dtype()));
PADDLE_ENFORCE_EQ(ret, 0,
PADDLE_ENFORCE_EQ(ret,
0,
platform::errors::PreconditionNotMet(
"receive to the switch module error."));
auto end = std::chrono::high_resolution_clock::now();
......
......@@ -66,7 +66,8 @@ class ProcessGroupHeter : public ProcessGroup {
class HeterTask : public ProcessGroup::Task,
public std::enable_shared_from_this<HeterTask> {
public:
HeterTask(int rank, CommType CommType,
HeterTask(int rank,
CommType CommType,
const std::vector<phi::DenseTensor>&);
bool IsCompleted();
......@@ -80,22 +81,32 @@ class ProcessGroupHeter : public ProcessGroup {
virtual ~HeterTask();
};
ProcessGroupHeter(const std::shared_ptr<Store>& store, int rank, int size,
const platform::Place& place, int gid, int local_rank,
int local_size, int gloo_rank, int gloo_size,
bool with_switch, std::string switch_endpoints,
int src_rank, int dst_rank);
ProcessGroupHeter(const std::shared_ptr<Store>& store,
int rank,
int size,
const platform::Place& place,
int gid,
int local_rank,
int local_size,
int gloo_rank,
int gloo_size,
bool with_switch,
std::string switch_endpoints,
int src_rank,
int dst_rank);
const std::string GetBackendName() const override {
return std::string(HETER_BACKEND_NAME);
}
std::shared_ptr<ProcessGroup::Task> AllReduce(
std::vector<phi::DenseTensor>&, std::vector<phi::DenseTensor>&,
std::vector<phi::DenseTensor>&,
std::vector<phi::DenseTensor>&,
const AllreduceOptions& = AllreduceOptions()) override;
std::shared_ptr<ProcessGroup::Task> Broadcast(
std::vector<phi::DenseTensor>&, std::vector<phi::DenseTensor>&,
std::vector<phi::DenseTensor>&,
std::vector<phi::DenseTensor>&,
const BroadcastOptions& = BroadcastOptions()) override;
std::shared_ptr<ProcessGroup::Task> Send(
......
......@@ -42,14 +42,18 @@ void SyncDefaultStream(
}
std::shared_ptr<ProcessGroupNCCL::NCCLTask> ProcessGroupNCCL::CreateTask(
std::vector<Place> places, int rank, CommType comm_type,
std::vector<Place> places,
int rank,
CommType comm_type,
const std::vector<phi::DenseTensor>& inputs) {
return std::make_shared<ProcessGroupNCCL::NCCLTask>(places, rank, comm_type,
inputs);
return std::make_shared<ProcessGroupNCCL::NCCLTask>(
places, rank, comm_type, inputs);
}
ProcessGroupNCCL::NCCLTask::NCCLTask(
const std::vector<Place>& places, int rank, CommType CommType,
const std::vector<Place>& places,
int rank,
CommType CommType,
const std::vector<phi::DenseTensor>& inputs)
: Task(rank, inputs, CommType), places_(places) {
control_events_.resize(places.size());
......@@ -109,8 +113,10 @@ bool ProcessGroupNCCL::NCCLTask::Wait(std::chrono::milliseconds timeout) {
void ProcessGroupNCCL::NCCLTask::Synchronize() { Wait(kWaitTimeout); }
ProcessGroupNCCL::ProcessGroupNCCL(const std::shared_ptr<Store>& store,
int rank, int size,
const platform::Place& place, int gid)
int rank,
int size,
const platform::Place& place,
int gid)
: ProcessGroup(rank, size, place, gid), store_(store) {
platform::SetDeviceId(place_.device);
}
......@@ -139,7 +145,8 @@ void ProcessGroupNCCL::BroadcastUniqueNCCLID(
// create NCCLManager cache for places_key
void ProcessGroupNCCL::CreateNCCLManagerCache(
const std::string& places_key, const std::vector<Place>& places) {
PADDLE_ENFORCE_EQ(places_key.empty(), false,
PADDLE_ENFORCE_EQ(places_key.empty(),
false,
platform::errors::PreconditionNotMet(
"Not able to create/get the NCCL Communicator since "
"the GPU place are not known"));
......@@ -190,7 +197,9 @@ void ProcessGroupNCCL::CreateNCCLManagerCache(
template <typename Fn>
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
std::vector<phi::DenseTensor>& inputs,
std::vector<phi::DenseTensor>& outputs, Fn fn, CommType op_type) {
std::vector<phi::DenseTensor>& outputs,
Fn fn,
CommType op_type) {
const auto places = GetPlaceList(inputs);
const auto key = GetKeyFromPlaces(places);
......@@ -237,7 +246,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
template <typename Fn>
void ProcessGroupNCCL::Collective(const phi::DenseTensor* in,
phi::DenseTensor* out, Fn fn,
phi::DenseTensor* out,
Fn fn,
CommType op_type) {
std::vector<Place> places;
places.push_back(in->place());
......@@ -274,7 +284,9 @@ void ProcessGroupNCCL::Collective(const phi::DenseTensor* in,
template <typename Fn>
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
std::vector<phi::DenseTensor>& tensors, Fn fn, int dst_rank,
std::vector<phi::DenseTensor>& tensors,
Fn fn,
int dst_rank,
CommType op_type) {
const auto places = GetPlaceList(tensors);
const auto key = GetKeyFromPlaces(places);
......@@ -321,38 +333,57 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllReduce(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors, const AllreduceOptions& opts) {
std::vector<phi::DenseTensor>& out_tensors,
const AllreduceOptions& opts) {
PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(in_tensors), true,
CheckTensorsInCudaPlace(in_tensors),
true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
return Collective(
in_tensors, out_tensors,
[&](const phi::DenseTensor& input, phi::DenseTensor& output,
ncclComm_t comm, const gpuStream_t& stream) {
in_tensors,
out_tensors,
[&](const phi::DenseTensor& input,
phi::DenseTensor& output,
ncclComm_t comm,
const gpuStream_t& stream) {
return platform::dynload::ncclAllReduce(
input.data(), output.data(), input.numel(),
input.data(),
output.data(),
input.numel(),
platform::ToNCCLDataType(input.type()),
ToNCCLRedType(opts.reduce_op), comm, stream);
ToNCCLRedType(opts.reduce_op),
comm,
stream);
},
CommType::ALLREDUCE);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Broadcast(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors, const BroadcastOptions& opts) {
std::vector<phi::DenseTensor>& out_tensors,
const BroadcastOptions& opts) {
PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(in_tensors), true,
CheckTensorsInCudaPlace(in_tensors),
true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
return Collective(
in_tensors, out_tensors,
[&](phi::DenseTensor& input, phi::DenseTensor& output, ncclComm_t comm,
in_tensors,
out_tensors,
[&](phi::DenseTensor& input,
phi::DenseTensor& output,
ncclComm_t comm,
const gpuStream_t& stream) {
const auto root =
opts.source_rank * in_tensors.size() + opts.source_root;
return platform::dynload::ncclBroadcast(
input.data(), output.data(), input.numel(),
platform::ToNCCLDataType(input.type()), root, comm, stream);
input.data(),
output.data(),
input.numel(),
platform::ToNCCLDataType(input.type()),
root,
comm,
stream);
},
CommType::BROADCAST);
}
......@@ -381,22 +412,26 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Barrier(
void CheckTensorsInDifferentDevices(
const std::vector<phi::DenseTensor>& tensors, const size_t num_devices) {
PADDLE_ENFORCE_EQ(
tensors.size() == 0, false,
tensors.size() == 0,
false,
platform::errors::InvalidArgument("Tensor list must be nonempty."));
PADDLE_ENFORCE_LE(
tensors.size(), num_devices,
tensors.size(),
num_devices,
platform::errors::InvalidArgument(
"Tensor list mustn't be larger than the number of available GPUs."));
std::set<Place> used_devices;
for (const auto& t : tensors) {
PADDLE_ENFORCE_EQ(platform::is_gpu_place(t.place()), true,
PADDLE_ENFORCE_EQ(platform::is_gpu_place(t.place()),
true,
platform::errors::InvalidArgument(
"Tensors must be CUDA and dense tensor."));
const auto inserted = used_devices.insert(t.place()).second;
PADDLE_ENFORCE_EQ(inserted, true,
PADDLE_ENFORCE_EQ(inserted,
true,
platform::errors::InvalidArgument(
"Tensors must be on distinct GPU devices."));
}
......@@ -408,13 +443,20 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Send(
auto task = PointToPoint(
tensors,
[&](phi::DenseTensor& input, ncclComm_t comm, const gpuStream_t& stream,
[&](phi::DenseTensor& input,
ncclComm_t comm,
const gpuStream_t& stream,
int dst_rank) {
return platform::dynload::ncclSend(
input.data(), input.numel(),
platform::ToNCCLDataType(input.dtype()), dst_rank, comm, stream);
input.data(),
input.numel(),
platform::ToNCCLDataType(input.dtype()),
dst_rank,
comm,
stream);
},
dst_rank, CommType::SEND);
dst_rank,
CommType::SEND);
return task;
}
......@@ -424,13 +466,20 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv(
auto task = PointToPoint(
tensors,
[&](phi::DenseTensor& output, ncclComm_t comm, const gpuStream_t& stream,
[&](phi::DenseTensor& output,
ncclComm_t comm,
const gpuStream_t& stream,
int src_rank) {
return platform::dynload::ncclRecv(
output.data(), output.numel(),
platform::ToNCCLDataType(output.dtype()), src_rank, comm, stream);
output.data(),
output.numel(),
platform::ToNCCLDataType(output.dtype()),
src_rank,
comm,
stream);
},
src_rank, CommType::RECV);
src_rank,
CommType::RECV);
return task;
}
......@@ -448,13 +497,20 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Send_Partial(
auto task = PointToPoint(
shared_tensors,
[&](phi::DenseTensor& input, ncclComm_t comm, const gpuStream_t& stream,
[&](phi::DenseTensor& input,
ncclComm_t comm,
const gpuStream_t& stream,
int dst_rank) {
return platform::dynload::ncclSend(
input.data(), input.numel(),
platform::ToNCCLDataType(input.dtype()), dst_rank, comm, stream);
input.data(),
input.numel(),
platform::ToNCCLDataType(input.dtype()),
dst_rank,
comm,
stream);
},
dst_rank, CommType::SEND);
dst_rank,
CommType::SEND);
return task;
}
......@@ -471,13 +527,20 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv_Partial(
auto task = PointToPoint(
shared_tensors,
[&](phi::DenseTensor& output, ncclComm_t comm, const gpuStream_t& stream,
[&](phi::DenseTensor& output,
ncclComm_t comm,
const gpuStream_t& stream,
int src_rank) {
return platform::dynload::ncclRecv(
output.data(), output.numel(),
platform::ToNCCLDataType(output.dtype()), src_rank, comm, stream);
output.data(),
output.numel(),
platform::ToNCCLDataType(output.dtype()),
src_rank,
comm,
stream);
},
src_rank, CommType::RECV);
src_rank,
CommType::RECV);
return task;
}
......@@ -485,23 +548,33 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors) {
PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(in_tensors), true,
CheckTensorsInCudaPlace(in_tensors),
true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(out_tensors), true,
CheckTensorsInCudaPlace(out_tensors),
true,
platform::errors::InvalidArgument("All outputs should be in CudaPlace."));
return Collective(
in_tensors, out_tensors,
[&](const phi::DenseTensor& input, phi::DenseTensor& output,
ncclComm_t comm, const gpuStream_t& stream) {
in_tensors,
out_tensors,
[&](const phi::DenseTensor& input,
phi::DenseTensor& output,
ncclComm_t comm,
const gpuStream_t& stream) {
return platform::dynload::ncclAllGather(
input.data(), output.data(), input.numel(),
platform::ToNCCLDataType(input.dtype()), comm, stream);
input.data(),
output.data(),
input.numel(),
platform::ToNCCLDataType(input.dtype()),
comm,
stream);
},
CommType::ALLGATHER);
}
void* GetPointerByOffset(void* raw_pointer, size_t offset,
void* GetPointerByOffset(void* raw_pointer,
size_t offset,
experimental::DataType type) {
if (type == experimental::DataType::FLOAT32) {
return reinterpret_cast<void*>(reinterpret_cast<float*>(raw_pointer) +
......@@ -529,26 +602,37 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors) {
PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(in_tensors), true,
CheckTensorsInCudaPlace(in_tensors),
true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(out_tensors), true,
CheckTensorsInCudaPlace(out_tensors),
true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
return Collective(
in_tensors, out_tensors,
[&](phi::DenseTensor& input, phi::DenseTensor& output, ncclComm_t comm,
in_tensors,
out_tensors,
[&](phi::DenseTensor& input,
phi::DenseTensor& output,
ncclComm_t comm,
const gpuStream_t& stream) {
size_t offset = 0;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart());
for (auto i = 0; i < size_; i++) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend(
GetPointerByOffset(input.data(), offset, input.dtype()),
input.numel() / size_, platform::ToNCCLDataType(input.dtype()), i,
comm, stream));
input.numel() / size_,
platform::ToNCCLDataType(input.dtype()),
i,
comm,
stream));
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
GetPointerByOffset(output.data(), offset, input.dtype()),
input.numel() / size_, platform::ToNCCLDataType(input.dtype()), i,
comm, stream));
input.numel() / size_,
platform::ToNCCLDataType(input.dtype()),
i,
comm,
stream));
offset += input.numel() / size_;
}
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
......@@ -558,34 +642,50 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Reduce(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors, const ReduceOptions& opts) {
std::vector<phi::DenseTensor>& out_tensors,
const ReduceOptions& opts) {
PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(in_tensors), true,
CheckTensorsInCudaPlace(in_tensors),
true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
return Collective(
in_tensors, out_tensors,
[&](const phi::DenseTensor& input, phi::DenseTensor& output,
ncclComm_t comm, const gpuStream_t& stream) {
in_tensors,
out_tensors,
[&](const phi::DenseTensor& input,
phi::DenseTensor& output,
ncclComm_t comm,
const gpuStream_t& stream) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclReduce(
input.data(), output.data(), input.numel(),
input.data(),
output.data(),
input.numel(),
platform::ToNCCLDataType(input.dtype()),
ToNCCLRedType(opts.reduce_op), opts.root_rank, comm, stream));
ToNCCLRedType(opts.reduce_op),
opts.root_rank,
comm,
stream));
},
CommType::REDUCE);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors, const ScatterOptions& opts) {
std::vector<phi::DenseTensor>& out_tensors,
const ScatterOptions& opts) {
PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(in_tensors), true,
CheckTensorsInCudaPlace(in_tensors),
true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(out_tensors), true,
CheckTensorsInCudaPlace(out_tensors),
true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
return Collective(
in_tensors, out_tensors,
[&](phi::DenseTensor& input, phi::DenseTensor& output, ncclComm_t comm,
in_tensors,
out_tensors,
[&](phi::DenseTensor& input,
phi::DenseTensor& output,
ncclComm_t comm,
const gpuStream_t& stream) {
size_t offset = 0;
if (rank_ == opts.root_rank) {
......@@ -593,19 +693,28 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
for (auto i = 0; i < size_; i++) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend(
GetPointerByOffset(input.data(), offset, input.dtype()),
input.numel() / size_, platform::ToNCCLDataType(input.dtype()),
i, comm, stream));
input.numel() / size_,
platform::ToNCCLDataType(input.dtype()),
i,
comm,
stream));
offset += input.numel() / size_;
}
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
output.data(), input.numel() / size_,
platform::ToNCCLDataType(input.dtype()), opts.root_rank, comm,
output.data(),
input.numel() / size_,
platform::ToNCCLDataType(input.dtype()),
opts.root_rank,
comm,
stream));
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
} else {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
output.data(), input.numel() / size_,
platform::ToNCCLDataType(input.dtype()), opts.root_rank, comm,
output.data(),
input.numel() / size_,
platform::ToNCCLDataType(input.dtype()),
opts.root_rank,
comm,
stream));
}
},
......
......@@ -54,7 +54,9 @@ class ProcessGroupNCCL : public ProcessGroup {
class NCCLTask : public ProcessGroup::Task,
public std::enable_shared_from_this<NCCLTask> {
public:
NCCLTask(const std::vector<Place>& places, int rank, CommType CommType,
NCCLTask(const std::vector<Place>& places,
int rank,
CommType CommType,
const std::vector<phi::DenseTensor>& inputs);
bool IsCompleted();
......@@ -80,8 +82,11 @@ class ProcessGroupNCCL : public ProcessGroup {
private:
};
ProcessGroupNCCL(const std::shared_ptr<Store>& store, int rank, int size,
const platform::Place& place, int gid);
ProcessGroupNCCL(const std::shared_ptr<Store>& store,
int rank,
int size,
const platform::Place& place,
int gid);
const std::string GetBackendName() const override {
return std::string(NCCL_BACKEND_NAME);
......@@ -107,11 +112,13 @@ class ProcessGroupNCCL : public ProcessGroup {
std::vector<phi::DenseTensor>& tensors, int src_rank) override;
std::shared_ptr<ProcessGroup::Task> Send_Partial(phi::DenseTensor& tensors,
int dst_rank, int offset,
int dst_rank,
int offset,
int length) override;
std::shared_ptr<ProcessGroup::Task> Recv_Partial(phi::DenseTensor& tensors,
int src_rank, int offset,
int src_rank,
int offset,
int length) override;
std::shared_ptr<ProcessGroup::Task> AllGather(
......@@ -134,7 +141,9 @@ class ProcessGroupNCCL : public ProcessGroup {
protected:
virtual std::shared_ptr<ProcessGroupNCCL::NCCLTask> CreateTask(
std::vector<Place> places, int rank, CommType opType,
std::vector<Place> places,
int rank,
CommType opType,
const std::vector<phi::DenseTensor>& inputs);
protected:
......@@ -153,7 +162,8 @@ class ProcessGroupNCCL : public ProcessGroup {
std::set<int> used_place_ids_;
private:
void BcastNCCLId(std::vector<ncclUniqueId>& nccl_ids, int root, // NOLINT
void BcastNCCLId(std::vector<ncclUniqueId>& nccl_ids,
int root, // NOLINT
int server_fd);
void BroadcastUniqueNCCLID(std::vector<ncclUniqueId>& nccl_ids); // NOLINT
......@@ -162,16 +172,21 @@ class ProcessGroupNCCL : public ProcessGroup {
std::shared_ptr<ProcessGroup::Task> Collective(
std::vector<phi::DenseTensor>& inputs, // NOLINT
std::vector<phi::DenseTensor>& outputs, // NOLINT
Fn fn, CommType op_type);
Fn fn,
CommType op_type);
template <typename Fn>
void Collective(const phi::DenseTensor*, phi::DenseTensor*, Fn fn,
void Collective(const phi::DenseTensor*,
phi::DenseTensor*,
Fn fn,
CommType op_type);
template <typename Fn>
std::shared_ptr<ProcessGroup::Task> PointToPoint(
std::vector<phi::DenseTensor>& tensors, // NOLINT
Fn fn, int dst_rank, CommType op_type);
Fn fn,
int dst_rank,
CommType op_type);
void CreateNCCLManagerCache(const std::string& places_key,
const std::vector<Place>& places);
......
......@@ -40,7 +40,8 @@ using IntArray =
using Backend = paddle::experimental::Backend;
std::vector<std::vector<size_t>> Eager_AssignGroupBySize(
const std::vector<Tensor>, const std::vector<bool> &is_sparse_gradient,
const std::vector<Tensor>,
const std::vector<bool> &is_sparse_gradient,
const std::vector<size_t> &group_size_limits,
const std::vector<int64_t> &tensor_indices = {});
......
......@@ -21,20 +21,27 @@ namespace distributed {
// AfsClient impl
int AfsClient::initialize(const FsClientParameter& fs_client_param) {
// temporarily implemented with hdfs-client
return initialize(fs_client_param.hadoop_bin(), fs_client_param.uri(),
fs_client_param.user(), fs_client_param.passwd(),
return initialize(fs_client_param.hadoop_bin(),
fs_client_param.uri(),
fs_client_param.user(),
fs_client_param.passwd(),
fs_client_param.buffer_size());
}
int AfsClient::initialize(const std::string& hadoop_bin, const std::string& uri,
const std::string& user, const std::string& passwd,
int AfsClient::initialize(const std::string& hadoop_bin,
const std::string& uri,
const std::string& user,
const std::string& passwd,
int buffer_size_param) {
return initialize(
hadoop_bin, uri,
hadoop_bin,
uri,
paddle::string::format_string("%s,%s", user.c_str(), passwd.c_str()),
buffer_size_param);
}
int AfsClient::initialize(const std::string& hadoop_bin, const std::string& uri,
const std::string& ugi, int buffer_size_param) {
int AfsClient::initialize(const std::string& hadoop_bin,
const std::string& uri,
const std::string& ugi,
int buffer_size_param) {
// temporarily implemented with hdfs-client
size_t buffer_size = 1L << 25; // 32MB
if (buffer_size_param > static_cast<int>(buffer_size)) {
......@@ -44,7 +51,9 @@ int AfsClient::initialize(const std::string& hadoop_bin, const std::string& uri,
paddle::framework::hdfs_set_command(paddle::string::format_string(
"2>>./hdfs_err.log %s fs -Dfs.default.name=%s -Dhadoop.job.ugi=%s "
"-Ddfs.client.block.write.retries=15 -Ddfs.rpc.timeout=300000",
hadoop_bin.c_str(), uri.c_str(), ugi.c_str()));
hadoop_bin.c_str(),
uri.c_str(),
ugi.c_str()));
return 0;
}
......
......@@ -129,11 +129,15 @@ class AfsClient {
AfsClient(const AfsClient&) = delete;
int initialize(const FsClientParameter& fs_client_param);
int initialize(const std::string& hadoop_bin, const std::string& uri,
const std::string& user, const std::string& passwd,
int initialize(const std::string& hadoop_bin,
const std::string& uri,
const std::string& user,
const std::string& passwd,
int buffer_size_param = (1L << 25));
int initialize(const std::string& hadoop_bin,
const std::string& uri,
const std::string& ugi,
int buffer_size_param = (1L << 25));
int initialize(const std::string& hadoop_bin, const std::string& uri,
const std::string& ugi, int buffer_size_param = (1L << 25));
// open file in 'w' or 'r'
std::shared_ptr<FsReadChannel> open_r(const FsChannelConfig& config,
......
......@@ -25,8 +25,10 @@ class TopkCalculator {
_shard_max_size = _total_max_size / shard_num;
_shard_max_size = _shard_max_size > 1 ? _shard_max_size : 1;
for (int i = 0; i < shard_num; ++i) {
_mpq.emplace(i, std::priority_queue<double, std::vector<double>,
std::greater<double>>());
_mpq.emplace(i,
std::priority_queue<double,
std::vector<double>,
std::greater<double>>());
}
}
~TopkCalculator() {}
......@@ -58,8 +60,9 @@ class TopkCalculator {
}
private:
std::unordered_map<int, std::priority_queue<double, std::vector<double>,
std::greater<double>>>
std::unordered_map<
int,
std::priority_queue<double, std::vector<double>, std::greater<double>>>
_mpq;
int _shard_num;
size_t _total_max_size;
......
......@@ -50,8 +50,10 @@ void Carrier::Init(
int64_t rank,
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank,
const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node,
const framework::ProgramDesc& program, framework::Scope* scope,
int64_t num_micro_batches, const platform::Place& place,
const framework::ProgramDesc& program,
framework::Scope* scope,
int64_t num_micro_batches,
const platform::Place& place,
const std::vector<std::string>& inference_root_scope_vars) {
rank_ = rank;
interceptor_id_to_rank_ = interceptor_id_to_rank;
......@@ -60,8 +62,9 @@ void Carrier::Init(
root_scope_ = scope;
dev_ctx_ = platform::DeviceContextPool::Instance().Get(place_);
PADDLE_ENFORCE_NOT_NULL(root_scope_, platform::errors::InvalidArgument(
"root_scope can not be nullptr"));
PADDLE_ENFORCE_NOT_NULL(
root_scope_,
platform::errors::InvalidArgument("root_scope can not be nullptr"));
minibatch_scope_ = &root_scope_->NewScope();
microbatch_scopes_.resize(num_micro_batches);
for (int i = 0; i < num_micro_batches; ++i) {
......@@ -87,7 +90,8 @@ void Carrier::Release() {
Carrier::~Carrier() { VLOG(3) << "Carrier's destructor."; }
void Carrier::CopyParameters(
int microbatch_id, const framework::ProgramDesc& program,
int microbatch_id,
const framework::ProgramDesc& program,
const std::vector<std::string>& inference_root_scope_vars) {
auto& global_block = program.Block(0);
......@@ -119,7 +123,8 @@ void Carrier::CopyParameters(
bool Carrier::EnqueueInterceptorMessage(
const InterceptorMessage& interceptor_message) {
PADDLE_ENFORCE_EQ(
interceptor_message.ctrl_message(), false,
interceptor_message.ctrl_message(),
false,
platform::errors::Fatal(
"Control message should be only send inter rank using message bus."));
int64_t dst_id = interceptor_message.dst_id();
......@@ -130,7 +135,8 @@ bool Carrier::EnqueueInterceptorMessage(
Interceptor* Carrier::GetInterceptor(int64_t interceptor_id) {
auto iter = interceptor_idx_to_interceptor_.find(interceptor_id);
PADDLE_ENFORCE_NE(iter, interceptor_idx_to_interceptor_.end(),
PADDLE_ENFORCE_NE(iter,
interceptor_idx_to_interceptor_.end(),
platform::errors::InvalidArgument(
"Cannot find interceptor instance for interceptor "
"id %lld. Wrong dst? Call before init?",
......@@ -149,7 +155,8 @@ void Carrier::WakeUp() {
}
void Carrier::Start() {
PADDLE_ENFORCE_EQ(is_init_, true,
PADDLE_ENFORCE_EQ(is_init_,
true,
platform::errors::PreconditionNotMet(
"Using carrier before initialized."));
for (int64_t id : source_interceptor_ids_) {
......@@ -199,10 +206,12 @@ bool Carrier::Send(const InterceptorMessage& msg) {
int64_t src_rank = GetRank(src_id);
int64_t dst_rank = GetRank(dst_id);
PADDLE_ENFORCE_EQ(
src_rank, rank_,
src_rank,
rank_,
platform::errors::Fatal("The source rank id %lld, which is not equal to "
"the carrier rank id %lld.",
src_rank, rank_));
src_rank,
rank_));
if (src_rank == dst_rank) {
VLOG(3) << "Send a message from interceptor " << src_id
<< " to interceptor " << dst_id << ", which are in the same ranks.";
......@@ -218,7 +227,8 @@ bool Carrier::Send(const InterceptorMessage& msg) {
Interceptor* Carrier::SetInterceptor(int64_t interceptor_id,
std::unique_ptr<Interceptor> interceptor) {
auto iter = interceptor_idx_to_interceptor_.find(interceptor_id);
PADDLE_ENFORCE_EQ(iter, interceptor_idx_to_interceptor_.end(),
PADDLE_ENFORCE_EQ(iter,
interceptor_idx_to_interceptor_.end(),
platform::errors::AlreadyExists(
"The interceptor id %lld has already been created! "
"The interceptor id should be unique.",
......@@ -267,19 +277,22 @@ void Carrier::CreateInterceptors() {
TaskNode* task_node = item.second;
PADDLE_ENFORCE_LT(
task_node->run_at_offset(), task_node->run_per_steps(),
task_node->run_at_offset(),
task_node->run_per_steps(),
platform::errors::InvalidArgument(
"Interceptor's run_at_offset must < run_per_steps, must now "
"run_at_offset=%ld run_per_steps=%ld",
task_node->run_at_offset(), task_node->run_per_steps()));
task_node->run_at_offset(),
task_node->run_per_steps()));
std::unique_ptr<Interceptor> interceptor;
PADDLE_ENFORCE_NE(task_node->type().empty(), true,
PADDLE_ENFORCE_NE(task_node->type().empty(),
true,
platform::errors::NotFound(
"Cannot found type for task node with id %lld",
task_node->task_id()));
interceptor = InterceptorFactory::Create(task_node->type(), interceptor_id,
task_node);
interceptor = InterceptorFactory::Create(
task_node->type(), interceptor_id, task_node);
interceptor->SetPlace(place_);
interceptor->SetMiniBatchScope(minibatch_scope_);
interceptor->SetMicroBatchScope(microbatch_scopes_);
......
......@@ -56,12 +56,15 @@ class Carrier final {
int64_t rank,
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank,
const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node,
const framework::ProgramDesc& program, framework::Scope* scope,
int64_t num_micro_batches, const platform::Place& place,
const framework::ProgramDesc& program,
framework::Scope* scope,
int64_t num_micro_batches,
const platform::Place& place,
const std::vector<std::string>& inference_root_scope_vars = {});
void CopyParameters(
int microbatch_id, const framework::ProgramDesc& program,
int microbatch_id,
const framework::ProgramDesc& program,
const std::vector<std::string>& inference_root_scope_vars);
void Release();
......
......@@ -43,7 +43,8 @@ void ComputeInterceptor::PrepareDeps() {
// source compute node, should we add a new SourceInterceptor?
if (upstream.empty()) {
is_source_ = true;
PADDLE_ENFORCE_GT(node_->max_run_times(), 0,
PADDLE_ENFORCE_GT(node_->max_run_times(),
0,
platform::errors::InvalidArgument(
"Source ComputeInterceptor must run at least one "
"times, but now max_run_times=%ld",
......@@ -60,7 +61,8 @@ void ComputeInterceptor::PrepareDeps() {
void ComputeInterceptor::IncreaseReady(int64_t up_id) {
auto it = in_readys_.find(up_id);
PADDLE_ENFORCE_NE(it, in_readys_.end(),
PADDLE_ENFORCE_NE(it,
in_readys_.end(),
platform::errors::NotFound(
"Cannot find upstream=%lld in in_readys.", up_id));
......@@ -73,26 +75,32 @@ void ComputeInterceptor::IncreaseReady(int64_t up_id) {
auto max_ready_size = it->second.first;
auto ready_size = it->second.second;
ready_size += 1;
PADDLE_ENFORCE_LE(ready_size, max_ready_size,
PADDLE_ENFORCE_LE(ready_size,
max_ready_size,
platform::errors::OutOfRange(
"upstream=%lld ready_size must <= max_ready_size, but "
"now ready_size=%lld, max_ready_size=%lld",
up_id, ready_size, max_ready_size));
up_id,
ready_size,
max_ready_size));
it->second.second = ready_size;
}
void ComputeInterceptor::DecreaseBuff(int64_t down_id) {
auto it = out_buffs_.find(down_id);
PADDLE_ENFORCE_NE(it, out_buffs_.end(),
PADDLE_ENFORCE_NE(it,
out_buffs_.end(),
platform::errors::NotFound(
"Cannot find downstream=%lld in out_buffs.", down_id));
auto used_size = it->second.second;
used_size -= 1;
PADDLE_ENFORCE_GE(
used_size, 0,
used_size,
0,
platform::errors::OutOfRange(
"downstream=%lld used buff size must >= 0, but now equal %lld",
down_id, used_size));
down_id,
used_size));
it->second.second = used_size;
}
......@@ -130,11 +138,14 @@ void ComputeInterceptor::SendDataReadyToDownStream() {
auto used_size = outs.second.second;
used_size += 1;
PADDLE_ENFORCE_LE(
used_size, max_buff_size,
used_size,
max_buff_size,
platform::errors::OutOfRange("downstream=%lld used buff size must <= "
"max_buff_size, but now used_size=%lld, "
"max_buff_size=%lld",
down_id, used_size, max_buff_size));
down_id,
used_size,
max_buff_size));
outs.second.second = used_size;
InterceptorMessage ready_msg;
......@@ -152,9 +163,11 @@ void ComputeInterceptor::ReplyCompletedToUpStream() {
auto ready_size = ins.second.second;
ready_size -= 1;
PADDLE_ENFORCE_GE(
ready_size, 0,
ready_size,
0,
platform::errors::OutOfRange(
"upstream=%lld ready_size must >= 0, but now got %lld", up_id,
"upstream=%lld ready_size must >= 0, but now got %lld",
up_id,
ready_size));
ins.second.second = ready_size;
......@@ -176,8 +189,10 @@ void ComputeInterceptor::RunOps() {
op->Run(*microbatch_scopes_[step_ % node_->max_run_times()], place_);
if (gc_) {
framework::DeleteUnusedTensors(
*microbatch_scopes_[step_ % node_->max_run_times()], op,
node_->unused_vars(), gc_.get());
*microbatch_scopes_[step_ % node_->max_run_times()],
op,
node_->unused_vars(),
gc_.get());
}
}
}
......@@ -210,11 +225,13 @@ void ComputeInterceptor::ReceivedStop(int64_t up_id) {
if (is_source_ && up_id == -1) return;
auto it = in_stops_.find(up_id);
PADDLE_ENFORCE_NE(it, in_stops_.end(),
PADDLE_ENFORCE_NE(it,
in_stops_.end(),
platform::errors::NotFound(
"Cannot find upstream=%lld in in_stops.", up_id));
PADDLE_ENFORCE_EQ(
it->second, false,
it->second,
false,
platform::errors::AlreadyExists("Already received stop from %lld, stop "
"cannot be send more than once."));
it->second = true;
......
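The IncreaseReady, DecreaseBuff, SendDataReadyToDownStream and ReplyCompletedToUpStream hunks above all maintain one credit protocol: each upstream edge carries a ready counter bounded by max_ready_size, and each downstream edge a used-buffer counter bounded by max_buff_size. A self-contained sketch of that invariant with illustrative names, where plain asserts stand in for the PADDLE_ENFORCE checks:

#include <cassert>
#include <cstdint>
#include <map>
#include <utility>

// One (max, current) pair per edge, keyed by the peer id, mirroring the
// in_readys_ / out_buffs_ members of ComputeInterceptor.
using EdgeCounters = std::map<int64_t, std::pair<int64_t, int64_t>>;

void IncreaseReady(EdgeCounters& in_readys, int64_t up_id) {
  auto it = in_readys.find(up_id);
  assert(it != in_readys.end() && "cannot find upstream in in_readys");
  int64_t max_ready_size = it->second.first;
  int64_t& ready_size = it->second.second;
  ready_size += 1;
  assert(ready_size <= max_ready_size && "upstream overran its credit");
  (void)max_ready_size;  // referenced only by the assert
}

void DecreaseBuff(EdgeCounters& out_buffs, int64_t down_id) {
  auto it = out_buffs.find(down_id);
  assert(it != out_buffs.end() && "cannot find downstream in out_buffs");
  int64_t& used_size = it->second.second;
  used_size -= 1;
  assert(used_size >= 0 && "reply without a matching send");
}

int main() {
  EdgeCounters in_readys{{7, {2, 0}}};  // upstream 7, credit of 2
  EdgeCounters out_buffs{{9, {1, 1}}};  // downstream 9, one buffer in flight
  IncreaseReady(in_readys, 7);
  DecreaseBuff(out_buffs, 9);
  return 0;
}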
......@@ -71,7 +71,8 @@ bool LoadDataFromDistModelTensor(const DistModelTensor &input_data,
if (platform::is_cpu_place(place)) {
VLOG(3) << "Loading data for CPU.";
std::memcpy(static_cast<void *>(input_tensor_ptr), input_data.data.data(),
std::memcpy(static_cast<void *>(input_tensor_ptr),
input_data.data.data(),
input_data.data.length());
} else if (platform::is_gpu_place(place)) {
VLOG(3) << "Loading data for GPU.";
......@@ -80,9 +81,12 @@ bool LoadDataFromDistModelTensor(const DistModelTensor &input_data,
auto *dev_ctx =
dynamic_cast<const platform::CUDADeviceContext *>(pool.Get(place));
auto gpu_place = place;
memory::Copy(gpu_place, static_cast<void *>(input_tensor_ptr),
platform::CPUPlace(), input_data.data.data(),
input_data.data.length(), dev_ctx->stream());
memory::Copy(gpu_place,
static_cast<void *>(input_tensor_ptr),
platform::CPUPlace(),
input_data.data.data(),
input_data.data.length(),
dev_ctx->stream());
#else
PADDLE_THROW(paddle::platform::errors::Fatal(
"Paddle wasn't compiled with CUDA, but place is GPU."));
......@@ -139,15 +143,17 @@ class DistModelTimer {
bool DistModel::Init() {
carrier_id_ = "inference";
bool init_method = (!config_.model_dir.empty() || config_.program_desc);
PADDLE_ENFORCE_EQ(init_method, true,
PADDLE_ENFORCE_EQ(init_method,
true,
platform::errors::InvalidArgument(
"One of model dir or program desc must be provided to "
"dist model inference."));
if (config_.program_desc) {
PADDLE_ENFORCE_NOT_NULL(
config_.scope, platform::errors::InvalidArgument(
"Scope must be provided to dist model inference if "
"program desc has been provided."));
config_.scope,
platform::errors::InvalidArgument(
"Scope must be provided to dist model inference if "
"program desc has been provided."));
}
if (!PreparePlace()) {
return false;
......@@ -217,8 +223,12 @@ bool DistModel::CommInit() {
}
peer_endpoints.emplace_back(config_.trainer_endpoints[rank]);
}
InsertCommOp(var_name_base + std::to_string(order), ranks_in_group,
rank_in_group, peer_endpoints, comm_init_block, ring_id);
InsertCommOp(var_name_base + std::to_string(order),
ranks_in_group,
rank_in_group,
peer_endpoints,
comm_init_block,
ring_id);
order += 1;
}
framework::NaiveExecutor e(place_);
......@@ -229,9 +239,12 @@ bool DistModel::CommInit() {
return true;
}
void DistModel::InsertCommOp(std::string tmp_var_name, int nranks, int rank,
void DistModel::InsertCommOp(std::string tmp_var_name,
int nranks,
int rank,
const std::vector<std::string> &peer_endpoints,
framework::BlockDesc *block, int ring_id) {
framework::BlockDesc *block,
int ring_id) {
/*
* tmp_var_name: the var name for var comm_id
* nranks: number of total ranks
......@@ -297,7 +310,8 @@ bool DistModel::PrepareProgram() {
bool DistModel::LoadProgram() {
VLOG(3) << "Loading program from " << config_.model_dir;
PADDLE_ENFORCE_NE(
config_.model_dir, "",
config_.model_dir,
"",
platform::errors::InvalidArgument("Model dir must be provided."));
std::string model_path = config_.model_dir + ".pdmodel";
framework::proto::ProgramDesc program_proto;
......@@ -305,7 +319,8 @@ bool DistModel::LoadProgram() {
// Read binary
std::ifstream fin(model_path, std::ios::in | std::ios::binary);
PADDLE_ENFORCE_EQ(
static_cast<bool>(fin.is_open()), true,
static_cast<bool>(fin.is_open()),
true,
platform::errors::NotFound(
"Cannot open file %s, please confirm whether the file is normal.",
model_path));
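The file handling in LoadProgram reduces to: open <model_dir>.pdmodel in binary mode, fail loudly when the file cannot be opened, and read the whole stream for ProgramDesc::ParseFromString. A minimal sketch of that pattern, with a standard exception standing in for PADDLE_ENFORCE_EQ; ReadBinaryFile is an illustrative name and the protobuf parse step is omitted:

#include <fstream>
#include <sstream>
#include <stdexcept>
#include <string>

std::string ReadBinaryFile(const std::string& path) {
  std::ifstream fin(path, std::ios::in | std::ios::binary);
  if (!fin.is_open()) {
    throw std::runtime_error("Cannot open file " + path +
                             ", please confirm whether the file is normal.");
  }
  std::ostringstream buffer;
  buffer << fin.rdbuf();  // slurp all bytes
  return buffer.str();    // hand to ParseFromString afterwards
}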
......@@ -387,8 +402,13 @@ bool DistModel::PrepareFleetExe() {
id_to_rank.insert({i, i});
}
fleet_exe.reset(new FleetExecutor(executor_desc_));
fleet_exe->Init(carrier_id_, *(program_.get()), scope_.get(), place_, 1,
{task_node_.get()}, id_to_rank);
fleet_exe->Init(carrier_id_,
*(program_.get()),
scope_.get(),
place_,
1,
{task_node_.get()},
id_to_rank);
return true;
}
......@@ -490,9 +510,11 @@ bool DistModel::FetchResults(std::vector<DistModelTensor> *output_data,
int idx = BOOST_GET_CONST(int, fetches_[i]->GetAttr("col"));
VLOG(3) << "Fetching data for [" << idx_to_fetches_[idx] << "]";
PADDLE_ENFORCE_EQ(
static_cast<size_t>(idx), i,
static_cast<size_t>(idx),
i,
platform::errors::InvalidArgument(
"Fetch op's col attr(%d) should be equal to the index(%d)", idx,
"Fetch op's col attr(%d) should be equal to the index(%d)",
idx,
i));
framework::FetchType &fetch_var =
framework::GetFetchVariable(*scope, "fetch", idx);
......
......@@ -72,9 +72,12 @@ class DistModel {
bool CommInit();
bool PrepareFeedAndFetch();
bool PrepareFleetExe();
void InsertCommOp(std::string tmp_var_name, int nranks, int rank,
void InsertCommOp(std::string tmp_var_name,
int nranks,
int rank,
const std::vector<std::string>& peer_endpoints,
framework::BlockDesc* block, int ring_id);
framework::BlockDesc* block,
int ring_id);
bool FeedData(const std::vector<DistModelTensor>& input_data,
framework::Scope* scope);
bool FetchResults(std::vector<DistModelTensor>* output_data,
......
......@@ -28,7 +28,8 @@ void DistModelDataBuf::Reset(void* data, size_t length) {
void DistModelDataBuf::Free() {
if (memory_owned_ && data_) {
PADDLE_ENFORCE_GT(length_, 0UL,
PADDLE_ENFORCE_GT(length_,
0UL,
platform::errors::PreconditionNotMet(
"Error occurred when deconstruct DistModelDataBuf: "
"it contains no data!"));
......
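Free above only releases memory the buffer owns, and treats an owned buffer with zero length as a precondition failure. A compact sketch of that ownership rule; OwnedBuf is an illustrative class, not the real DistModelDataBuf:

#include <cassert>
#include <cstddef>

class OwnedBuf {
 public:
  void Resize(size_t length) {  // allocate and take ownership
    Free();
    data_ = new char[length];
    length_ = length;
    memory_owned_ = true;
  }
  void Reset(void* data, size_t length) {  // borrow the caller's memory
    Free();
    data_ = data;
    length_ = length;
    memory_owned_ = false;
  }
  ~OwnedBuf() { Free(); }

 private:
  void Free() {
    if (memory_owned_ && data_) {
      assert(length_ > 0 && "owned buffer contains no data");
      delete[] static_cast<char*>(data_);
    }
    data_ = nullptr;
    length_ = 0;
  }
  void* data_ = nullptr;
  size_t length_ = 0;
  bool memory_owned_ = false;
};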
......@@ -30,8 +30,9 @@ namespace distributed {
FleetExecutor::FleetExecutor(const std::string& exe_desc_str) {
bool parse_flag = exe_desc_.ParseFromString(exe_desc_str);
PADDLE_ENFORCE(parse_flag, platform::errors::PreconditionNotMet(
"Error occurs while parsing string to proto"));
PADDLE_ENFORCE(parse_flag,
platform::errors::PreconditionNotMet(
"Error occurs while parsing string to proto"));
// Message bus will be created and inited only once
GlobalVal<MessageBus>::Create();
InitMessageBus();
......@@ -51,12 +52,16 @@ FleetExecutor::~FleetExecutor() {
}
void FleetExecutor::Init(
const std::string& carrier_id, const framework::ProgramDesc& program_desc,
framework::Scope* scope, const platform::Place& place,
int64_t num_micro_batches, const std::vector<TaskNode*>& task_nodes,
const std::string& carrier_id,
const framework::ProgramDesc& program_desc,
framework::Scope* scope,
const platform::Place& place,
int64_t num_micro_batches,
const std::vector<TaskNode*>& task_nodes,
const std::unordered_map<int64_t, int64_t>& task_id_to_rank,
const std::vector<std::string>& inference_root_scope_vars) {
PADDLE_ENFORCE_GT(task_nodes.size(), 0,
PADDLE_ENFORCE_GT(task_nodes.size(),
0,
platform::errors::InvalidArgument(
"Fleet executor is inited with empty task node"));
// TODO(fleet_exe devs): the unused_vars should be got from run time graph
......@@ -116,18 +121,30 @@ void FleetExecutor::Init(
carrier_ids_.insert(carrier_id);
// Set current running carrier
GlobalVal<std::string>::Set(new std::string(carrier_id));
InitCarrier(carrier, scope, place, num_micro_batches, program_desc,
InitCarrier(carrier,
scope,
place,
num_micro_batches,
program_desc,
inference_root_scope_vars);
GlobalVal<MessageBus>::Get()->Barrier();
}
void FleetExecutor::InitCarrier(
Carrier* carrier, framework::Scope* scope, const platform::Place& place,
int64_t num_micro_batches, const framework::ProgramDesc& program_desc,
Carrier* carrier,
framework::Scope* scope,
const platform::Place& place,
int64_t num_micro_batches,
const framework::ProgramDesc& program_desc,
const std::vector<std::string>& inference_root_scope_vars) {
carrier->Init(exe_desc_.cur_rank(), runtime_graph_->interceptor_id_to_rank(),
runtime_graph_->interceptor_id_to_node(), program_desc, scope,
num_micro_batches, place, inference_root_scope_vars);
carrier->Init(exe_desc_.cur_rank(),
runtime_graph_->interceptor_id_to_rank(),
runtime_graph_->interceptor_id_to_node(),
program_desc,
scope,
num_micro_batches,
place,
inference_root_scope_vars);
}
void FleetExecutor::InitMessageBus() {
......@@ -148,11 +165,13 @@ void FleetExecutor::InitMessageBus() {
}
if (addr == "") {
PADDLE_ENFORCE_EQ(
rank_to_addr.size(), 1,
rank_to_addr.size(),
1,
platform::errors::NotFound("Empty address is not valid for "
"paddle.distributed.launch method."));
PADDLE_ENFORCE_EQ(
cur_rank, 0,
cur_rank,
0,
platform::errors::NotFound("Address is empty but cur rank is not 0."));
}
VLOG(3) << "Current rank is " << cur_rank << " and the ip_port is "
......
......@@ -39,8 +39,10 @@ class FleetExecutor final {
explicit FleetExecutor(const FleetExecutorDesc& exe_desc);
~FleetExecutor();
void Init(const std::string& carrier_id,
const framework::ProgramDesc& program_desc, framework::Scope* scope,
const platform::Place& place, int64_t num_micro_batches,
const framework::ProgramDesc& program_desc,
framework::Scope* scope,
const platform::Place& place,
int64_t num_micro_batches,
const std::vector<TaskNode*>& task_nodes,
const std::unordered_map<int64_t, int64_t>& task_id_to_rank,
const std::vector<std::string>& inference_root_scope_vars = {});
......@@ -50,8 +52,11 @@ class FleetExecutor final {
DISABLE_COPY_AND_ASSIGN(FleetExecutor);
void InitMessageBus();
void InitCarrier(
Carrier* carrier, framework::Scope* scope, const platform::Place& place,
int64_t num_micro_batches, const framework::ProgramDesc& program_desc,
Carrier* carrier,
framework::Scope* scope,
const platform::Place& place,
int64_t num_micro_batches,
const framework::ProgramDesc& program_desc,
const std::vector<std::string>& inference_root_scope_vars = {});
FleetExecutorDesc exe_desc_;
std::shared_ptr<RuntimeGraph> runtime_graph_;
......
......@@ -31,7 +31,8 @@ class GlobalVal final {
template <typename... Args>
static T* Create(Args&&... args) {
auto* ptr = GetPPtr();
PADDLE_ENFORCE_EQ(ptr->get(), nullptr,
PADDLE_ENFORCE_EQ(ptr->get(),
nullptr,
platform::errors::AlreadyExists(
"This value is already a global value."));
T* item = new T(std::forward<Args>(args)...);
......@@ -65,7 +66,8 @@ class GlobalMap final {
template <typename... Args>
static ValueT* Create(KeyT id, Args&&... args) {
auto* ptr = GetPPtr(id);
PADDLE_ENFORCE_EQ(ptr->get(), nullptr,
PADDLE_ENFORCE_EQ(ptr->get(),
nullptr,
platform::errors::AlreadyExists(
"This value has already in global map."));
ValueT* item = new ValueT(std::forward<Args>(args)...);
......@@ -86,14 +88,16 @@ class ThreadSafeGlobalMap final {
static ValueT* Get(KeyT id) {
ValueT* item = GetPPtr(id)->get();
PADDLE_ENFORCE_NOT_NULL(
item, platform::errors::NotFound(
"This value is not in thread safe global map."));
item,
platform::errors::NotFound(
"This value is not in thread safe global map."));
return item;
}
template <typename... Args>
static ValueT* Create(KeyT id, Args&&... args) {
auto* ptr = GetPPtr(id);
PADDLE_ENFORCE_EQ(ptr->get(), nullptr,
PADDLE_ENFORCE_EQ(ptr->get(),
nullptr,
platform::errors::AlreadyExists(
"This value has already in thread safe global map."));
ValueT* item = new ValueT(std::forward<Args>(args)...);
......
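GlobalVal, GlobalMap and ThreadSafeGlobalMap above share one contract: Create fails if the slot is already occupied and Get fails if it is empty. A compact sketch of the thread-safe variant, with a mutex plus asserts in place of PADDLE_ENFORCE; the real class stores one pointer per slot, so this is a simplified model:

#include <cassert>
#include <memory>
#include <mutex>
#include <unordered_map>
#include <utility>

template <typename KeyT, typename ValueT>
class TsGlobalMap {
 public:
  template <typename... Args>
  static ValueT* Create(KeyT id, Args&&... args) {
    std::lock_guard<std::mutex> lock(Mutex());
    auto& slot = Map()[id];
    assert(slot == nullptr && "value already in thread safe global map");
    slot = std::make_unique<ValueT>(std::forward<Args>(args)...);
    return slot.get();
  }
  static ValueT* Get(KeyT id) {
    std::lock_guard<std::mutex> lock(Mutex());
    auto it = Map().find(id);
    assert(it != Map().end() && it->second != nullptr &&
           "value not in thread safe global map");
    return it->second.get();
  }

 private:
  static std::unordered_map<KeyT, std::unique_ptr<ValueT>>& Map() {
    static std::unordered_map<KeyT, std::unique_ptr<ValueT>> map;
    return map;
  }
  static std::mutex& Mutex() {
    static std::mutex mu;
    return mu;
  }
};

int main() {
  int* created = TsGlobalMap<int, int>::Create(1, 42);
  return (*TsGlobalMap<int, int>::Get(1) == *created) ? 0 : 1;
}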
......@@ -35,8 +35,9 @@ Interceptor::~Interceptor() {
void Interceptor::RegisterMsgHandle(MsgHandle handle) { handle_ = handle; }
void Interceptor::Handle(const InterceptorMessage& msg) {
PADDLE_ENFORCE_NOT_NULL(handle_, platform::errors::PreconditionNotMet(
"Message handle is not registered."));
PADDLE_ENFORCE_NOT_NULL(handle_,
platform::errors::PreconditionNotMet(
"Message handle is not registered."));
handle_(msg);
}
......@@ -46,7 +47,8 @@ void Interceptor::LoopOnce() {
std::lock_guard<std::mutex> lock(mutex_);
messages_.swap(tmp_messages);
}
PADDLE_ENFORCE_EQ(tmp_messages.empty(), false,
PADDLE_ENFORCE_EQ(tmp_messages.empty(),
false,
platform::errors::PreconditionNotMet(
"tmp_messages must not empty in task loop"));
......@@ -61,8 +63,9 @@ void Interceptor::LoopOnce() {
}
void Interceptor::StopCarrier() {
PADDLE_ENFORCE_NOT_NULL(carrier_, platform::errors::PreconditionNotMet(
"Carrier is not registered."));
PADDLE_ENFORCE_NOT_NULL(
carrier_,
platform::errors::PreconditionNotMet("Carrier is not registered."));
carrier_->WakeUp();
}
......@@ -84,8 +87,9 @@ void Interceptor::EnqueueRemoteInterceptorMessage(
}
bool Interceptor::Send(int64_t dst_id, InterceptorMessage& msg) {
PADDLE_ENFORCE_NOT_NULL(carrier_, platform::errors::PreconditionNotMet(
"Carrier is not registered."));
PADDLE_ENFORCE_NOT_NULL(
carrier_,
platform::errors::PreconditionNotMet("Carrier is not registered."));
msg.set_src_id(interceptor_id_);
msg.set_dst_id(dst_id);
return carrier_->Send(msg);
......@@ -102,7 +106,8 @@ std::unique_ptr<Interceptor> InterceptorFactory::Create(const std::string& type,
auto& interceptor_map = GetInterceptorMap();
auto iter = interceptor_map.find(type);
PADDLE_ENFORCE_NE(
iter, interceptor_map.end(),
iter,
interceptor_map.end(),
platform::errors::NotFound("interceptor %s is not register", type));
return iter->second(id, node);
}
......
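InterceptorFactory::Create above resolves the type string through a registration map and rejects unknown types, which is why Carrier can build interceptors from task_node->type() alone. A generic sketch of that registry; Node and Worker are placeholders, not Paddle's TaskNode and Interceptor:

#include <cassert>
#include <cstdint>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>

struct Node {};
struct Worker {
  virtual ~Worker() = default;
};
using CreateFunc = std::function<std::unique_ptr<Worker>(int64_t, Node*)>;

std::map<std::string, CreateFunc>& Registry() {
  static std::map<std::string, CreateFunc> map;
  return map;
}

void Register(const std::string& type, CreateFunc func) {
  Registry()[type] = std::move(func);
}

std::unique_ptr<Worker> Create(const std::string& type,
                               int64_t id,
                               Node* node) {
  auto iter = Registry().find(type);
  assert(iter != Registry().end() && "type is not registered");
  return iter->second(id, node);
}

struct Compute : Worker {
  Compute(int64_t /*id*/, Node* /*node*/) {}
};

int main() {
  Register("Compute", [](int64_t id, Node* node) {
    return std::unique_ptr<Worker>(new Compute(id, node));
  });
  Node node;
  auto worker = Create("Compute", 0, &node);
  return worker ? 0 : 1;
}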
......@@ -129,7 +129,8 @@ class InterceptorFactory {
static void Register(const std::string& type, CreateInterceptorFunc func);
static std::unique_ptr<Interceptor> Create(const std::string& type,
int64_t id, TaskNode* node);
int64_t id,
TaskNode* node);
};
template <typename InterceptorClass>
......
......@@ -27,10 +27,12 @@ namespace paddle {
namespace distributed {
void MessageBus::Init(
int64_t rank, const std::unordered_map<int64_t, std::string>& rank_to_addr,
int64_t rank,
const std::unordered_map<int64_t, std::string>& rank_to_addr,
const std::string& addr) {
PADDLE_ENFORCE_EQ(
is_init_, false,
is_init_,
false,
platform::errors::AlreadyExists("MessageBus is already init."));
rank_ = rank;
is_init_ = true;
......@@ -39,12 +41,14 @@ void MessageBus::Init(
if (addr_ != "") {
const auto& addr = GetAddr(rank_);
PADDLE_ENFORCE_EQ(addr, addr_,
PADDLE_ENFORCE_EQ(addr,
addr_,
platform::errors::Fatal(
"The current rank's addr is %s, while the "
"message bus's addr is %s, which are different. "
"Init error.",
addr, addr_));
addr,
addr_));
}
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
......@@ -77,7 +81,8 @@ MessageBus::~MessageBus() {
const std::string& MessageBus::GetAddr(int64_t rank) const {
PADDLE_ENFORCE_NE(
rank_to_addr_.find(rank), rank_to_addr_.end(),
rank_to_addr_.find(rank),
rank_to_addr_.end(),
platform::errors::NotFound("Cannot find addr rank id %lld.", rank));
return rank_to_addr_.at(rank);
}
......@@ -85,7 +90,8 @@ const std::string& MessageBus::GetAddr(int64_t rank) const {
bool MessageBus::Send(int64_t dst_rank,
const InterceptorMessage& interceptor_message) {
PADDLE_ENFORCE_EQ(
IsInit(), true,
IsInit(),
true,
platform::errors::PreconditionNotMet(
"Using message bus since it has not been initialized."));
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
......@@ -176,7 +182,8 @@ void MessageBus::ListenPort() {
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
// keeps listening on the port and handles incoming messages
PADDLE_ENFORCE_EQ(
server_.AddService(&message_service_, brpc::SERVER_DOESNT_OWN_SERVICE), 0,
server_.AddService(&message_service_, brpc::SERVER_DOESNT_OWN_SERVICE),
0,
platform::errors::Unavailable("Message bus: init brpc service error."));
// start the server
......@@ -215,7 +222,8 @@ bool MessageBus::SendInterRank(int64_t dst_rank,
options.timeout_ms = 1000;
options.max_retry = 5;
PADDLE_ENFORCE_EQ(
channel.Init(dst_addr_for_brpc, &options), 0,
channel.Init(dst_addr_for_brpc, &options),
0,
platform::errors::Unavailable("Message bus: init brpc channel error."));
MessageService_Stub stub(&channel);
InterceptorResponse response;
......@@ -224,8 +232,8 @@ bool MessageBus::SendInterRank(int64_t dst_rank,
if (interceptor_message.ctrl_message()) {
stub.IncreaseBarrierCount(&ctrl, &interceptor_message, &response, NULL);
} else {
stub.ReceiveInterceptorMessage(&ctrl, &interceptor_message, &response,
NULL);
stub.ReceiveInterceptorMessage(
&ctrl, &interceptor_message, &response, NULL);
}
if (!ctrl.Failed()) {
if (response.rst()) {
......
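SendInterRank above is the standard brpc client sequence: fill ChannelOptions, Init the channel against the peer address (0 means success), then invoke the generated stub and check the controller. A trimmed sketch of just that sequence; InterceptorMessage, InterceptorResponse and MessageService_Stub are Paddle's generated proto types and are assumed to be in scope here:

#include <brpc/channel.h>
#include <brpc/controller.h>
#include <string>

// Returns true only if the RPC reached the peer and the peer acknowledged.
bool SendTo(const std::string& dst_addr,
            const InterceptorMessage& interceptor_message) {
  brpc::Channel channel;
  brpc::ChannelOptions options;
  options.timeout_ms = 1000;  // same values as the hunk above
  options.max_retry = 5;
  if (channel.Init(dst_addr.c_str(), &options) != 0) {
    return false;  // "init brpc channel error" in the source
  }
  MessageService_Stub stub(&channel);
  InterceptorResponse response;
  brpc::Controller ctrl;
  stub.ReceiveInterceptorMessage(&ctrl, &interceptor_message, &response, NULL);
  return !ctrl.Failed() && response.rst();
}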
......@@ -23,7 +23,8 @@ namespace distributed {
void MessageServiceImpl::ReceiveInterceptorMessage(
google::protobuf::RpcController* control_base,
const InterceptorMessage* request, InterceptorResponse* response,
const InterceptorMessage* request,
InterceptorResponse* response,
google::protobuf::Closure* done) {
brpc::ClosureGuard done_guard(done);
VLOG(3) << "Message Service receives a message from interceptor "
......@@ -35,7 +36,8 @@ void MessageServiceImpl::ReceiveInterceptorMessage(
void MessageServiceImpl::IncreaseBarrierCount(
google::protobuf::RpcController* control_base,
const InterceptorMessage* request, InterceptorResponse* response,
const InterceptorMessage* request,
InterceptorResponse* response,
google::protobuf::Closure* done) {
brpc::ClosureGuard done_guard(done);
VLOG(3) << "Barrier Service receives a message from rank "
......
......@@ -26,11 +26,13 @@ class MessageServiceImpl : public MessageService {
virtual ~MessageServiceImpl() {}
virtual void ReceiveInterceptorMessage(
google::protobuf::RpcController* control_base,
const InterceptorMessage* request, InterceptorResponse* response,
const InterceptorMessage* request,
InterceptorResponse* response,
google::protobuf::Closure* done);
virtual void IncreaseBarrierCount(
google::protobuf::RpcController* control_base,
const InterceptorMessage* request, InterceptorResponse* response,
const InterceptorMessage* request,
InterceptorResponse* response,
google::protobuf::Closure* done);
};
......
......@@ -27,7 +27,8 @@ TaskLoop* TaskLoop::GetTaskLoopOfCurrentThread() { return thread_local_loop_; }
TaskLoop::TaskLoop()
: looping_(false), quit_(false), thread_id_(std::this_thread::get_id()) {
PADDLE_ENFORCE_EQ(
thread_local_loop_, nullptr,
thread_local_loop_,
nullptr,
platform::errors::AlreadyExists("Another TaskLoop is already init."));
thread_local_loop_ = this;
}
......@@ -35,7 +36,8 @@ TaskLoop::TaskLoop()
TaskLoop::~TaskLoop() { thread_local_loop_ = nullptr; }
void TaskLoop::Loop() {
PADDLE_ENFORCE_EQ(looping_, false,
PADDLE_ENFORCE_EQ(looping_,
false,
platform::errors::PreconditionNotMet(
"Loop can only execute in one loop thread"));
AssertInLoopThread();
......@@ -75,7 +77,8 @@ void TaskLoop::WakeUp() {
void TaskLoop::AbortNotInLoopThread() {
PADDLE_THROW(platform::errors::PreconditionNotMet(
"This TaskLoop was created in thread %d, but current thread is %d",
thread_id_, std::this_thread::get_id()));
thread_id_,
std::this_thread::get_id()));
}
} // namespace distributed
......
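TaskLoop pins itself to its creating thread through a thread_local pointer set in the constructor and checks membership by comparing thread ids, which is exactly what AbortNotInLoopThread reports above. A standalone model of that ownership check:

#include <cassert>
#include <thread>

class Loop {
 public:
  Loop() : thread_id_(std::this_thread::get_id()) {
    assert(tls_loop_ == nullptr && "another Loop already owns this thread");
    tls_loop_ = this;
  }
  ~Loop() { tls_loop_ = nullptr; }

  bool IsInLoopThread() const {
    return thread_id_ == std::this_thread::get_id();
  }
  static Loop* GetLoopOfCurrentThread() { return tls_loop_; }

 private:
  static thread_local Loop* tls_loop_;
  std::thread::id thread_id_;
};
thread_local Loop* Loop::tls_loop_ = nullptr;

int main() {
  Loop main_loop;  // owns the main thread
  assert(main_loop.IsInLoopThread());
  std::thread t([] {
    Loop worker_loop;  // a fresh thread can own its own loop
    assert(Loop::GetLoopOfCurrentThread() == &worker_loop);
  });
  t.join();
  return 0;
}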
......@@ -32,7 +32,8 @@ TaskLoopThread::~TaskLoopThread() {
TaskLoop* TaskLoopThread::StartLoop() {
PADDLE_ENFORCE_EQ(
start_, false,
start_,
false,
platform::errors::PreconditionNotMet("thread is already running."));
start_ = true;
thread_ = std::thread([this]() { Loop(); });
......
......@@ -31,10 +31,12 @@ TaskLoopThreadPool::~TaskLoopThreadPool() = default;
void TaskLoopThreadPool::Start() {
PADDLE_ENFORCE_EQ(
start_, false,
start_,
false,
platform::errors::PreconditionNotMet("thread pool is already start."));
PADDLE_ENFORCE_GT(
thread_num_, 0,
thread_num_,
0,
platform::errors::InvalidArgument(
"thread num must greater than 0, but now is %d", thread_num_));
......@@ -47,21 +49,26 @@ void TaskLoopThreadPool::Start() {
TaskLoop* TaskLoopThreadPool::GetLoop(int tid) {
PADDLE_ENFORCE_EQ(
start_, true,
start_,
true,
platform::errors::PreconditionNotMet("thread pool must start first."));
PADDLE_ENFORCE_GE(
tid, 0,
tid,
0,
platform::errors::OutOfRange("tid must >= 0, but now is %d", tid));
PADDLE_ENFORCE_LT(tid, thread_num_,
PADDLE_ENFORCE_LT(tid,
thread_num_,
platform::errors::OutOfRange(
"tid must < thread_num, but now tid=%d thread_num=%d",
tid, thread_num_));
tid,
thread_num_));
return loops_[tid];
}
std::vector<TaskLoop*> TaskLoopThreadPool::GetAllLoops() {
PADDLE_ENFORCE_EQ(
start_, true,
start_,
true,
platform::errors::PreconditionNotMet("thread pool must start first."));
return loops_;
}
......
......@@ -24,8 +24,10 @@ namespace {
using OperatorBase = TaskNode::OperatorBase;
}
TaskNode::TaskNode(paddle::framework::ProgramDesc* program, int64_t rank,
int64_t max_run_times, int64_t max_slot_nums)
TaskNode::TaskNode(paddle::framework::ProgramDesc* program,
int64_t rank,
int64_t max_run_times,
int64_t max_slot_nums)
: program_(program),
rank_(rank),
max_run_times_(max_run_times),
......@@ -80,7 +82,9 @@ TaskNode::TaskNode(int64_t rank, int64_t task_id, int64_t max_run_times)
TaskNode::TaskNode(int32_t role,
const std::vector<framework::OpDesc*>& op_descs,
int64_t rank, int64_t task_id, int64_t max_run_times,
int64_t rank,
int64_t task_id,
int64_t max_run_times,
int64_t max_slot_nums)
: role_(role),
rank_(rank),
......@@ -101,7 +105,9 @@ TaskNode::TaskNode(int32_t role,
TaskNode::TaskNode(int32_t role,
const std::vector<framework::OperatorBase*>& ops,
int64_t rank, int64_t task_id, int64_t max_run_times,
int64_t rank,
int64_t task_id,
int64_t max_run_times,
int64_t max_slot_nums)
: ops_(ops),
role_(role),
......@@ -110,8 +116,11 @@ TaskNode::TaskNode(int32_t role,
max_run_times_(max_run_times),
max_slot_nums_(max_slot_nums) {}
TaskNode::TaskNode(int32_t role, int64_t rank, int64_t task_id,
int64_t max_run_times, int64_t max_slot_nums)
TaskNode::TaskNode(int32_t role,
int64_t rank,
int64_t task_id,
int64_t max_run_times,
int64_t max_slot_nums)
: role_(role),
rank_(rank),
task_id_(task_id),
......@@ -139,14 +148,16 @@ std::string TaskNode::DebugString() const {
}
void TaskNode::SetRunPerSteps(int64_t value) {
PADDLE_ENFORCE_GE(value, 1,
PADDLE_ENFORCE_GE(value,
1,
platform::errors::InvalidArgument(
"run_per_steps must >= 1, but received %ld", value));
run_per_steps_ = value;
}
void TaskNode::SetRunAtOffset(int64_t value) {
PADDLE_ENFORCE_GE(value, 0,
PADDLE_ENFORCE_GE(value,
0,
platform::errors::InvalidArgument(
"run_at_offset must >= 0, but received %ld", value));
run_at_offset_ = value;
......@@ -154,7 +165,8 @@ void TaskNode::SetRunAtOffset(int64_t value) {
void TaskNode::SetReplyUpPerSteps(int64_t value) {
PADDLE_ENFORCE_GE(
value, 1,
value,
1,
platform::errors::InvalidArgument(
"reply_up_per_steps must >= 1, but received %ld", value));
reply_up_per_steps_ = value;
......@@ -162,7 +174,8 @@ void TaskNode::SetReplyUpPerSteps(int64_t value) {
void TaskNode::SetSendDownPerSteps(int64_t value) {
PADDLE_ENFORCE_GE(
value, 1,
value,
1,
platform::errors::InvalidArgument(
"send_down_per_steps must >= 1, but received %ld", value));
send_down_per_steps_ = value;
......
......@@ -33,16 +33,27 @@ class TaskNode final {
public:
using OperatorBase = paddle::framework::OperatorBase;
TaskNode(int64_t rank, int64_t task_id, int64_t max_run_times);
TaskNode(int32_t role, int64_t rank, int64_t task_id, int64_t max_run_times,
TaskNode(int32_t role,
int64_t rank,
int64_t task_id,
int64_t max_run_times,
int64_t max_slot_nums);
TaskNode(int32_t role, const std::vector<framework::OpDesc*>& op_descs,
int64_t rank, int64_t task_id, int64_t max_run_times,
TaskNode(int32_t role,
const std::vector<framework::OpDesc*>& op_descs,
int64_t rank,
int64_t task_id,
int64_t max_run_times,
int64_t max_slot_nums);
TaskNode(int32_t role, const std::vector<framework::OperatorBase*>& ops,
int64_t rank, int64_t task_id, int64_t max_run_times,
TaskNode(int32_t role,
const std::vector<framework::OperatorBase*>& ops,
int64_t rank,
int64_t task_id,
int64_t max_run_times,
int64_t max_slot_nums);
TaskNode(paddle::framework::ProgramDesc* program,
int64_t rank,
int64_t max_run_times,
int64_t max_slot_nums);
TaskNode(paddle::framework::ProgramDesc* program, int64_t rank,
int64_t max_run_times, int64_t max_slot_nums);
TaskNode(paddle::framework::ProgramDesc* program, int64_t rank);
~TaskNode() = default;
......
......@@ -40,12 +40,13 @@ std::vector<framework::OperatorBase*> GetOps() {
attrs["shape"] = phi::vectorize<int>({2, 3});
attrs["value"] = 1.0f;
auto zero_op = framework::OpRegistry::CreateOp("fill_constant", {},
{{"Out", {"x"}}}, attrs);
auto zero_op = framework::OpRegistry::CreateOp(
"fill_constant", {}, {{"Out", {"x"}}}, attrs);
auto op = framework::OpRegistry::CreateOp(
"elementwise_add", {{"X", {"x"}}, {"Y", {"x"}}}, {{"Out", {"out"}}},
framework::AttributeMap());
auto op = framework::OpRegistry::CreateOp("elementwise_add",
{{"X", {"x"}}, {"Y", {"x"}}},
{{"Out", {"out"}}},
framework::AttributeMap());
// NOTE: don't delete
return {zero_op.release(), op.release()};
......
......@@ -54,14 +54,15 @@ TEST(AmplifierInterceptor, Amplifier) {
std::string carrier_id = "0";
Carrier* carrier =
GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
carrier->Init(0, {{SOURCE_ID, 0},
{0, 0},
{1, 0},
{2, 0},
{3, 0},
{4, 0},
{5, 0},
{SINK_ID, 0}});
carrier->Init(0,
{{SOURCE_ID, 0},
{0, 0},
{1, 0},
{2, 0},
{3, 0},
{4, 0},
{5, 0},
{SINK_ID, 0}});
MessageBus* msg_bus = GlobalVal<MessageBus>::Create();
msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "127.0.0.0:0");
......
......@@ -27,7 +27,8 @@ namespace distributed {
int64_t GetBuffSize(
const std::map<std::pair<TaskNode*, TaskNode*>, int64_t> buffs,
TaskNode* from, TaskNode* to) {
TaskNode* from,
TaskNode* to) {
if (buffs.find({from, to}) != buffs.end()) {
return buffs.at({from, to});
}
......
......@@ -21,7 +21,8 @@ namespace distributed {
std::vector<std::vector<uint64_t>> LayerWiseSampler::sample(
const std::vector<std::vector<uint64_t>>& user_inputs,
const std::vector<uint64_t>& target_ids, bool with_hierarchy) {
const std::vector<uint64_t>& target_ids,
bool with_hierarchy) {
auto input_num = target_ids.size();
auto user_feature_num = user_inputs[0].size();
std::vector<std::vector<uint64_t>> outputs(
......
......@@ -38,7 +38,8 @@ class IndexSampler {
virtual void init_layerwise_conf(
const std::vector<uint16_t>& layer_sample_counts,
uint16_t start_sample_layer = 1, uint16_t seed = 0) {}
uint16_t start_sample_layer = 1,
uint16_t seed = 0) {}
virtual void init_beamsearch_conf(const int64_t k) {}
virtual std::vector<std::vector<uint64_t>> sample(
const std::vector<std::vector<uint64_t>>& user_inputs,
......@@ -65,15 +66,18 @@ class LayerWiseSampler : public IndexSampler {
start_sample_layer_ = start_sample_layer;
PADDLE_ENFORCE_GT(
start_sample_layer_, 0,
start_sample_layer_,
0,
paddle::platform::errors::InvalidArgument(
"start sampler layer = [%d], it should greater than 0.",
start_sample_layer_));
PADDLE_ENFORCE_LT(start_sample_layer_, tree_->Height(),
PADDLE_ENFORCE_LT(start_sample_layer_,
tree_->Height(),
paddle::platform::errors::InvalidArgument(
"start sampler layer = [%d], it should less than "
"max_layer, which is [%d].",
start_sample_layer_, tree_->Height()));
start_sample_layer_,
tree_->Height()));
size_t i = 0;
layer_counts_sum_ = 0;
......@@ -113,7 +117,8 @@ class LayerWiseSampler : public IndexSampler {
}
std::vector<std::vector<uint64_t>> sample(
const std::vector<std::vector<uint64_t>>& user_inputs,
const std::vector<uint64_t>& target_ids, bool with_hierarchy) override;
const std::vector<uint64_t>& target_ids,
bool with_hierarchy) override;
void sample_from_dataset(
const uint16_t sample_slot,
......
......@@ -29,7 +29,8 @@ int TreeIndex::Load(const std::string filename) {
int err_no;
auto fp = paddle::framework::fs_open_read(filename, &err_no, "");
PADDLE_ENFORCE_NE(
fp, nullptr,
fp,
nullptr,
platform::errors::InvalidArgument(
"Open file %s failed. Please check whether the file exists.",
filename));
......@@ -46,17 +47,21 @@ int TreeIndex::Load(const std::string filename) {
size_t read_num =
fread(const_cast<char*>(content.data()), 1, num, fp.get());
PADDLE_ENFORCE_EQ(
read_num, static_cast<size_t>(num),
read_num,
static_cast<size_t>(num),
platform::errors::InvalidArgument(
"Read from file: %s failed. Valid Format is "
"an integer representing the length of the following string, "
"and the string itself.We got an iteger[% d], "
"but the following string's length is [%d].",
filename, num, read_num));
filename,
num,
read_num));
KVItem item;
PADDLE_ENFORCE_EQ(
item.ParseFromString(content), true,
item.ParseFromString(content),
true,
platform::errors::InvalidArgument("Parse from file: %s failed. It's "
"content can't be parsed by KVItem.",
filename));
......@@ -168,7 +173,8 @@ std::vector<uint64_t> TreeIndex::GetChildrenCodes(uint64_t ancestor,
std::vector<uint64_t> TreeIndex::GetTravelCodes(uint64_t id, int start_level) {
std::vector<uint64_t> res;
PADDLE_ENFORCE_NE(id_codes_map_.find(id), id_codes_map_.end(),
PADDLE_ENFORCE_NE(id_codes_map_.find(id),
id_codes_map_.end(),
paddle::platform::errors::InvalidArgument(
"id = %d doesn't exist in Tree.", id));
auto code = id_codes_map_.at(id);
......
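The error message in TreeIndex::Load spells out the on-disk framing: each record is an integer holding the length of the string that follows, then the serialized KVItem bytes themselves. A sketch of a reader for that framing using the same cstdio calls; parsing each payload into a KVItem proto is omitted:

#include <cstdio>
#include <string>
#include <utility>
#include <vector>

std::vector<std::string> ReadRecords(const std::string& filename) {
  std::vector<std::string> records;
  FILE* fp = std::fopen(filename.c_str(), "rb");
  if (fp == nullptr) return records;  // "Open file ... failed" in the source
  int num = 0;
  while (std::fread(&num, sizeof(num), 1, fp) == 1) {
    if (num <= 0) break;  // corrupt length prefix
    std::string content(num, '\0');
    size_t read_num = std::fread(&content[0], 1, num, fp);
    if (read_num != static_cast<size_t>(num)) break;  // truncated record
    records.push_back(std::move(content));
  }
  std::fclose(fp);
  return records;
}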
......@@ -76,7 +76,8 @@ class IndexWrapper {
void clear_tree() { tree_map.clear(); }
TreePtr get_tree_index(const std::string name) {
PADDLE_ENFORCE_NE(tree_map.find(name), tree_map.end(),
PADDLE_ENFORCE_NE(tree_map.find(name),
tree_map.end(),
paddle::platform::errors::InvalidArgument(
"tree [%s] doesn't exist. Please insert it firstly "
"by API[\' insert_tree_index \'].",
......@@ -91,11 +92,13 @@ class IndexWrapper {
}
TreePtr tree = std::make_shared<TreeIndex>();
int ret = tree->Load(tree_path);
PADDLE_ENFORCE_EQ(ret, 0,
PADDLE_ENFORCE_EQ(ret,
0,
paddle::platform::errors::InvalidArgument(
"Load tree[%s] from path[%s] failed. Please "
"check whether the file exists.",
name, tree_path));
name,
tree_path));
tree_map.insert(std::pair<std::string, TreePtr>{name, tree});
}
......
......@@ -57,7 +57,8 @@ class DownpourPsClientService : public PsService {
return 0;
}
void service(::google::protobuf::RpcController *controller,
const PsRequestMessage *request, PsResponseMessage *response,
const PsRequestMessage *request,
PsResponseMessage *response,
::google::protobuf::Closure *done) override;
protected:
......@@ -162,13 +163,15 @@ class BrpcPsClient : public PSClient {
const std::string threshold) override;
std::future<int32_t> Load(const std::string &epoch,
const std::string &mode) override;
std::future<int32_t> Load(uint32_t table_id, const std::string &epoch,
std::future<int32_t> Load(uint32_t table_id,
const std::string &epoch,
const std::string &mode) override;
std::future<int32_t> Save(const std::string &epoch,
const std::string &mode) override;
std::future<int32_t> Save(uint32_t table_id, const std::string &epoch,
std::future<int32_t> Save(uint32_t table_id,
const std::string &epoch,
const std::string &mode) override;
std::future<int32_t> Clear() override;
......@@ -182,7 +185,8 @@ class BrpcPsClient : public PSClient {
void FinalizeWorker() override;
virtual std::future<int32_t> PullDense(Region *regions, size_t region_num,
virtual std::future<int32_t> PullDense(Region *regions,
size_t region_num,
size_t table_id);
virtual std::future<int32_t> PushDenseParam(const Region *regions,
......@@ -190,14 +194,18 @@ class BrpcPsClient : public PSClient {
size_t table_id);
virtual std::future<int32_t> PushDense(const Region *regions,
size_t region_num, size_t table_id);
size_t region_num,
size_t table_id);
void PushDenseTaskConsume();
virtual std::future<int32_t> PullSparse(float **select_values,
size_t table_id, const uint64_t *keys,
size_t num, bool is_training);
size_t table_id,
const uint64_t *keys,
size_t num,
bool is_training);
virtual std::future<int32_t> PullSparseParam(float **select_values,
size_t table_id,
const uint64_t *keys, size_t num,
const uint64_t *keys,
size_t num,
bool is_training);
virtual std::future<int32_t> PrintTableStat(uint32_t table_id);
......@@ -213,7 +221,8 @@ class BrpcPsClient : public PSClient {
void *done);
virtual std::future<int32_t> Flush();
std::future<int32_t> SendClient2ClientMsg(int msg_type, int to_client_id,
std::future<int32_t> SendClient2ClientMsg(int msg_type,
int to_client_id,
const std::string &msg) override;
// for local save sparse
......@@ -221,14 +230,19 @@ class BrpcPsClient : public PSClient {
const std::string &path);
std::future<int32_t> CacheShuffle(
uint32_t table_id, const std::string &path, const std::string &mode,
uint32_t table_id,
const std::string &path,
const std::string &mode,
const std::string &cache_threshold) override;
std::future<int32_t> CacheShuffleMultiTable(
std::vector<int> tables, const std::string &path, const std::string &mode,
std::vector<int> tables,
const std::string &path,
const std::string &mode,
const std::string &cache_threshold);
std::future<int32_t> SaveCache(uint32_t table_id, const std::string &path,
std::future<int32_t> SaveCache(uint32_t table_id,
const std::string &path,
const std::string &mode) override;
std::future<int32_t> GetCacheThreshold(uint32_t table_id,
......@@ -256,10 +270,12 @@ class BrpcPsClient : public PSClient {
return dense_dim_total / shard_num + 1;
}
std::future<int32_t> SendCmd(uint32_t table_id, int cmd_id,
std::future<int32_t> SendCmd(uint32_t table_id,
int cmd_id,
const std::vector<std::string> &param);
std::future<int32_t> SendSaveCmd(uint32_t table_id, int cmd_id,
std::future<int32_t> SendSaveCmd(uint32_t table_id,
int cmd_id,
const std::vector<std::string> &param);
bool _running = false;
......@@ -281,14 +297,19 @@ class BrpcPsClient : public PSClient {
std::thread _print_thread;
int PushSparseAsyncShardMerge(
std::vector<std::shared_ptr<SparseAsyncTask>> &task_list, // NOLINT
std::vector<int> &request_kv_num, int table_id, int shard_idx, // NOLINT
std::vector<std::shared_ptr<SparseAsyncTask>> &task_list, // NOLINT
std::vector<int> &request_kv_num,
int table_id,
int shard_idx, // NOLINT
ValueAccessor *accessor);
int PushSparseAsyncShardPush(
std::vector<std::shared_ptr<SparseAsyncTask>> &task_list, // NOLINT
std::vector<int> &request_kv_num, int table_id, int shard_idx, // NOLINT
DownpourBrpcClosure *closure, ValueAccessor *accessor);
std::vector<std::shared_ptr<SparseAsyncTask>> &task_list, // NOLINT
std::vector<int> &request_kv_num,
int table_id,
int shard_idx, // NOLINT
DownpourBrpcClosure *closure,
ValueAccessor *accessor);
SparseTaskPool _sparse_task_pool;
......@@ -304,18 +325,23 @@ class BrpcPsClient : public PSClient {
std::future<int32_t> PushSparseRawGradient(size_t table_id,
const uint64_t *keys,
const float **update_values,
size_t num, void *done) override;
size_t num,
void *done) override;
std::future<int32_t> PushSparseRawGradientPartial(size_t table_id,
const uint64_t *keys,
const float **update_values,
uint32_t num, void *done,
uint32_t num,
void *done,
int pserver_idx) override;
std::future<int32_t> PushSparseParam(size_t table_id, const uint64_t *keys,
const float **update_values, size_t num,
std::future<int32_t> PushSparseParam(size_t table_id,
const uint64_t *keys,
const float **update_values,
size_t num,
void *done) override;
std::future<int32_t> PushSparse(size_t table_id, const uint64_t *keys,
std::future<int32_t> PushSparse(size_t table_id,
const uint64_t *keys,
const float **update_values,
size_t num) override;
void PushSparseTaskConsume();
......@@ -324,7 +350,8 @@ class BrpcPsClient : public PSClient {
int32_t StartClientService();
void PushDenseRawGradient(std::shared_ptr<DenseAsyncTask> &task, // NOLINT
float *total_send_data, size_t total_send_data_size,
float *total_send_data,
size_t total_send_data_size,
DownpourBrpcClosure *closure);
float _mae = 0;
float _mse = 0;
......
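Every std::future<int32_t> entry point declared above resolves the same way: the call returns a future immediately, and a completion closure sets the value once the RPC finishes (the closure->set_promise_value(ret) pattern visible in communicator.cc below). A framework-free sketch of that mechanism:

#include <cstdint>
#include <functional>
#include <future>
#include <memory>
#include <thread>
#include <utility>

// The caller gets a future right away; a detached worker stands in for the
// brpc completion callback and fulfils the promise later.
std::future<int32_t> AsyncCall(std::function<int32_t()> rpc) {
  auto promise = std::make_shared<std::promise<int32_t>>();
  std::future<int32_t> fut = promise->get_future();
  std::thread([promise, rpc = std::move(rpc)]() {
    promise->set_value(rpc());  // closure->set_promise_value(ret)
  }).detach();
  return fut;
}

int main() {
  auto status = AsyncCall([] { return 0; });  // 0 means success, as above
  status.wait();
  return status.get();
}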
......@@ -30,11 +30,14 @@ class RpcController;
} // namespace protobuf
} // namespace google
DEFINE_int32(pserver_timeout_ms_s2s, 10000,
DEFINE_int32(pserver_timeout_ms_s2s,
10000,
"pserver request server timeout_ms");
DEFINE_int32(pserver_connect_timeout_ms_s2s, 10000,
DEFINE_int32(pserver_connect_timeout_ms_s2s,
10000,
"pserver connect server timeout_ms");
DEFINE_string(pserver_connection_type_s2s, "pooled",
DEFINE_string(pserver_connection_type_s2s,
"pooled",
"pserver connection_type[pooled:single]");
namespace paddle {
......@@ -154,12 +157,13 @@ std::future<int32_t> BrpcPsServer::SendPServer2PServerMsg(
closure->request(0)->set_table_id(0);
closure->request(0)->set_data(msg);
PsService_Stub rpc_stub(_pserver_channels[to_pserver_id].get());
rpc_stub.service(closure->cntl(0), closure->request(0), closure->response(0),
closure);
rpc_stub.service(
closure->cntl(0), closure->request(0), closure->response(0), closure);
return fut;
}
int32_t BrpcPsServer::ReceiveFromPServer(int msg_type, int pserver_id,
int32_t BrpcPsServer::ReceiveFromPServer(int msg_type,
int pserver_id,
const std::string &msg) {
if (msg.length() == 0) {
LOG(WARNING) << "SERVER>>RESPONSE>>msg = 0 Finish S2S Response";
......@@ -289,7 +293,8 @@ void BrpcPsService::service(google::protobuf::RpcController *cntl_base,
}
}
int32_t BrpcPsService::PullDense(Table *table, const PsRequestMessage &request,
int32_t BrpcPsService::PullDense(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
platform::RecordEvent record_event(
......@@ -297,7 +302,8 @@ int32_t BrpcPsService::PullDense(Table *table, const PsRequestMessage &request,
CHECK_TABLE_EXIST(table, request, response)
if (request.params_size() < 1) {
set_response_code(
response, -1,
response,
-1,
"PsRequestMessage.datas is requeired at least 1 for num of dense");
return 0;
}
......@@ -357,7 +363,8 @@ int32_t BrpcPsService::PushDenseParam(Table *table,
return 0;
}
int32_t BrpcPsService::PushDense(Table *table, const PsRequestMessage &request,
int32_t BrpcPsService::PushDense(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
platform::RecordEvent record_event(
......@@ -391,13 +398,15 @@ int32_t BrpcPsService::PushDense(Table *table, const PsRequestMessage &request,
return 0;
}
int32_t BrpcPsService::Barrier(Table *table, const PsRequestMessage &request,
int32_t BrpcPsService::Barrier(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response)
if (request.params_size() < 1) {
set_response_code(response, -1,
set_response_code(response,
-1,
"PsRequestMessage.params is requeired at "
"least 1 for num of sparse_key");
return 0;
......@@ -423,7 +432,8 @@ int32_t BrpcPsService::PushSparseParam(Table *table,
return 0;
}
if (request.params_size() < 1) {
set_response_code(response, -1,
set_response_code(response,
-1,
"PsRequestMessage.params is requeired at "
"least 1 for num of sparse_key");
return 0;
......@@ -482,7 +492,8 @@ int32_t BrpcPsService::PullGeoParam(Table *table,
return 0;
}
int32_t BrpcPsService::PullSparse(Table *table, const PsRequestMessage &request,
int32_t BrpcPsService::PullSparse(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
platform::RecordEvent record_event(
......@@ -498,7 +509,8 @@ int32_t BrpcPsService::PullSparse(Table *table, const PsRequestMessage &request,
}
if (request.params_size() < 1) {
set_response_code(response, -1,
set_response_code(response,
-1,
"PsRequestMessage.params is requeired at "
"least 1 for num of sparse_key");
return 0;
......@@ -533,7 +545,8 @@ int32_t BrpcPsService::PullSparse(Table *table, const PsRequestMessage &request,
return 0;
}
int32_t BrpcPsService::PushSparse(Table *table, const PsRequestMessage &request,
int32_t BrpcPsService::PushSparse(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
platform::RecordEvent record_event(
......@@ -545,7 +558,8 @@ int32_t BrpcPsService::PushSparse(Table *table, const PsRequestMessage &request,
return 0;
}
if (request.params_size() < 1) {
set_response_code(response, -1,
set_response_code(response,
-1,
"PsRequestMessage.params is requeired at "
"least 1 for num of sparse_key");
return 0;
......@@ -594,7 +608,8 @@ int32_t BrpcPsService::LoadOneTable(Table *table,
CHECK_TABLE_EXIST(table, request, response)
if (request.params_size() < 2) {
set_response_code(
response, -1,
response,
-1,
"PsRequestMessage.datas is requeired at least 2 for path & load_param");
return -1;
}
......@@ -626,7 +641,8 @@ int32_t BrpcPsService::SaveOneTable(Table *table,
CHECK_TABLE_EXIST(table, request, response)
if (request.params_size() < 2) {
set_response_code(
response, -1,
response,
-1,
"PsRequestMessage.datas is requeired at least 2, path&mode");
return -1;
}
......@@ -667,7 +683,8 @@ int32_t BrpcPsService::SaveCacheTable(Table *table,
CHECK_TABLE_EXIST(table, request, response)
if (request.params_size() < 2) {
set_response_code(
response, -1,
response,
-1,
"PsRequestMessage.datas is requeired at least 3, path&mode");
return -1;
}
......@@ -676,8 +693,8 @@ int32_t BrpcPsService::SaveCacheTable(Table *table,
// if (_server->_shuffled_ins->size() <= 0) {
// LOG(WARNING) << "shuffled ins size <= 0";
//}
feasign_size = table->SaveCache(request.params(0), request.params(1),
_server->_shuffled_ins);
feasign_size = table->SaveCache(
request.params(0), request.params(1), _server->_shuffled_ins);
if (feasign_size < 0) {
set_response_code(response, -1, "table save failed");
return -1;
......@@ -692,7 +709,8 @@ int32_t BrpcPsService::CacheShuffle(Table *table,
// start cache shuffle
CHECK_TABLE_EXIST(table, request, response)
if (request.params_size() < 3) {
set_response_code(response, -1,
set_response_code(response,
-1,
"PsRequestMessage.datas is requeired at least 3, "
"path&mode&cache_threshold");
return -1;
......@@ -704,9 +722,10 @@ int32_t BrpcPsService::CacheShuffle(Table *table,
// std::string>>();
// shuffled_ins->set_block_size(80000);
_server->StartS2S();
std::function<std::future<int32_t>(int msg_type, int to_pserver_id,
const std::string &msg)>
send_msg_func = [this](int msg_type, int to_pserver_id,
std::function<std::future<int32_t>(
int msg_type, int to_pserver_id, const std::string &msg)>
send_msg_func = [this](int msg_type,
int to_pserver_id,
const std::string &msg) -> std::future<int32_t> {
return this->_server->SendPServer2PServerMsg(msg_type, to_pserver_id, msg);
};
......@@ -721,8 +740,12 @@ int32_t BrpcPsService::CacheShuffle(Table *table,
table_ptrs.push_back(table);
}
table->CacheShuffle(request.params(0), request.params(1), cache_threshold,
send_msg_func, _server->_shuffled_ins, table_ptrs);
table->CacheShuffle(request.params(0),
request.params(1),
cache_threshold,
send_msg_func,
_server->_shuffled_ins,
table_ptrs);
return 0;
}
......@@ -751,7 +774,8 @@ int32_t BrpcPsService::ShrinkTable(Table *table,
CHECK_TABLE_EXIST(table, request, response)
if (request.params_size() < 1) {
set_response_code(
response, -1,
response,
-1,
"PsRequestMessage.datas is requeired at least 1, threshold");
return -1;
}
......@@ -787,7 +811,8 @@ int32_t BrpcPsService::ClearAllTable(Table *table,
return 0;
}
int32_t BrpcPsService::StopServer(Table *table, const PsRequestMessage &request,
int32_t BrpcPsService::StopServer(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
auto *p_server = _server;
......
......@@ -56,7 +56,8 @@ class BrpcPsServer : public PSServer {
virtual int32_t StartS2S() override;
virtual ::std::future<int32_t> SendPServer2PServerMsg(
int msg_type, int to_pserver_id, const std::string &msg) override;
virtual int32_t ReceiveFromPServer(int msg_type, int pserver_id,
virtual int32_t ReceiveFromPServer(int msg_type,
int pserver_id,
const std::string &msg) override;
private:
......@@ -72,7 +73,9 @@ class BrpcPsServer : public PSServer {
class BrpcPsService;
typedef int32_t (BrpcPsService::*serviceHandlerFunc)(
Table *table, const PsRequestMessage &request, PsResponseMessage &response,
Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
class BrpcPsService : public PsBaseService {
......@@ -86,56 +89,101 @@ class BrpcPsService : public PsBaseService {
private:
int32_t InitializeShardInfo();
int32_t PullDense(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t PushDense(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t PushDenseParam(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t PushSparseParam(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t PullSparse(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t PullGeoParam(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t Barrier(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t PushSparse(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t LoadOneTable(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t LoadAllTable(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t SaveOneTable(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t SaveAllTable(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t ShrinkTable(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t ClearOneTable(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t ClearAllTable(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t StopServer(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t StartProfiler(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t StopProfiler(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t PrintTableStat(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t PushGlobalStep(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t CacheShuffle(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t SaveCacheTable(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t GetCacheThreshold(Table *table, const PsRequestMessage &request,
int32_t PullDense(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
int32_t PushDense(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
int32_t PushDenseParam(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
int32_t PushSparseParam(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
int32_t PullSparse(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
int32_t PullGeoParam(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
int32_t Barrier(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
int32_t PushSparse(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
int32_t LoadOneTable(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
int32_t LoadAllTable(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
int32_t SaveOneTable(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
int32_t SaveAllTable(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
int32_t ShrinkTable(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
int32_t ClearOneTable(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
int32_t ClearAllTable(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
int32_t StopServer(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
int32_t StartProfiler(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
int32_t StopProfiler(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
int32_t PrintTableStat(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
int32_t PushGlobalStep(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
int32_t CacheShuffle(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
int32_t SaveCacheTable(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
int32_t GetCacheThreshold(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
......
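The serviceHandlerFunc typedef above exists so the service can route each request through a table of pointer-to-member functions sharing one signature. A standalone sketch of that dispatch style with simplified types:

#include <cstdint>
#include <string>
#include <unordered_map>

class Service {
 public:
  using Handler = int32_t (Service::*)(const std::string&);

  Service() {
    handlers_[0] = &Service::PullDense;
    handlers_[1] = &Service::PushDense;
  }

  int32_t Dispatch(int cmd_id, const std::string& request) {
    auto it = handlers_.find(cmd_id);
    if (it == handlers_.end()) return -1;   // unknown command
    return (this->*(it->second))(request);  // pointer-to-member call
  }

 private:
  int32_t PullDense(const std::string&) { return 0; }
  int32_t PushDense(const std::string&) { return 0; }
  std::unordered_map<int, Handler> handlers_;
};

int main() {
  Service service;
  return service.Dispatch(0, "req");  // returns 0
}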
......@@ -56,8 +56,10 @@ void SerializeToMultiVarMsgAndIOBuf(
const std::string& message_name,
const std::vector<std::string>& send_var_name_val,
const std::vector<std::string>& recv_var_name_val,
const platform::DeviceContext& ctx, const framework::Scope* scope,
MultiVarMsg* request, butil::IOBuf* iobuf) {
const platform::DeviceContext& ctx,
const framework::Scope* scope,
MultiVarMsg* request,
butil::IOBuf* iobuf) {
// 1. message_name
request->set_message_name(message_name);
......@@ -87,7 +89,8 @@ void SerializeToMultiVarMsgAndIOBuf(
}
void SerializeLodTensor(framework::Variable* var,
const platform::DeviceContext& ctx, VarMsg* var_msg,
const platform::DeviceContext& ctx,
VarMsg* var_msg,
butil::IOBuf* iobuf) {
auto* tensor = var->GetMutable<framework::LoDTensor>();
var_msg->set_type(::paddle::distributed::LOD_TENSOR);
......@@ -119,7 +122,10 @@ void SerializeLodTensor(framework::Variable* var,
auto stream =
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream();
memory::Copy(
platform::CPUPlace(), temp_ptr, tensor->place(), tensor->data(),
platform::CPUPlace(),
temp_ptr,
tensor->place(),
tensor->data(),
tensor->numel() * framework::SizeOfType(
framework::TransToProtoVarType(tensor->dtype())),
stream);
......@@ -132,7 +138,8 @@ void SerializeLodTensor(framework::Variable* var,
}
void SerializeSelectedRows(framework::Variable* var,
const platform::DeviceContext& ctx, VarMsg* var_msg,
const platform::DeviceContext& ctx,
VarMsg* var_msg,
butil::IOBuf* iobuf) {
phi::SelectedRows* slr = var->GetMutable<phi::SelectedRows>();
auto* tensor = slr->mutable_value();
......@@ -164,7 +171,10 @@ void SerializeSelectedRows(framework::Variable* var,
auto stream =
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream();
memory::Copy(
platform::CPUPlace(), temp_ptr, tensor->place(), tensor->data(),
platform::CPUPlace(),
temp_ptr,
tensor->place(),
tensor->data(),
tensor->numel() * framework::SizeOfType(
framework::TransToProtoVarType(tensor->dtype())),
stream);
......@@ -204,7 +214,8 @@ void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg& multi_msg,
++recv_var_index) {
const auto& msg = multi_msg.var_messages(recv_var_index);
auto* var = scope->FindVar(msg.varname());
PADDLE_ENFORCE_NE(var, nullptr,
PADDLE_ENFORCE_NE(var,
nullptr,
platform::errors::InvalidArgument(
"Not find variable %s in scope.", msg.varname()));
if (msg.type() == ::paddle::distributed::LOD_TENSOR) {
......@@ -215,7 +226,8 @@ void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg& multi_msg,
}
}
void DeserializeLodTensor(framework::Variable* var, const VarMsg& msg,
void DeserializeLodTensor(framework::Variable* var,
const VarMsg& msg,
butil::IOBufBytesIterator& io_buffer_itr, // NOLINT
const platform::DeviceContext& ctx) {
const auto place = ctx.GetPlace();
......@@ -255,16 +267,20 @@ void DeserializeLodTensor(framework::Variable* var, const VarMsg& msg,
io_buffer_itr.copy_and_forward((void*)temp_ptr, data_len); // NOLINT
auto stream =
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream();
memory::Copy(
place, tensor_data, platform::CPUPlace(), (void*)temp_ptr, // NOLINT
tensor->numel() * framework::DataTypeSize(tensor->dtype()), stream);
memory::Copy(place,
tensor_data,
platform::CPUPlace(),
(void*)temp_ptr, // NOLINT
tensor->numel() * framework::DataTypeSize(tensor->dtype()),
stream);
delete[] temp_ptr;
#endif
}
}
void DeserializeSelectedRows(
framework::Variable* var, const VarMsg& msg,
framework::Variable* var,
const VarMsg& msg,
butil::IOBufBytesIterator& io_buffer_itr, // NOLINT
const platform::DeviceContext& ctx) {
const auto place = ctx.GetPlace();
......@@ -297,7 +313,10 @@ void DeserializeSelectedRows(
io_buffer_itr.copy_and_forward(temp_ptr, data_len);
auto stream =
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream();
memory::Copy(place, tensor_data, platform::CPUPlace(), temp_ptr,
memory::Copy(place,
tensor_data,
platform::CPUPlace(),
temp_ptr,
tensor->numel() * framework::DataTypeSize(tensor->dtype()),
stream);
delete[] temp_ptr;
......
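SerializeLodTensor, SerializeSelectedRows and the deserializers all use the same staging trick on the GPU path: allocate a temporary host buffer, memory::Copy the tensor bytes through it on the context's stream, append to (or read from) the IOBuf, then delete[]. A device-free sketch of the serialize direction, with memcpy standing in for memory::Copy and a byte vector standing in for butil::IOBuf:

#include <cstddef>
#include <cstring>
#include <vector>

// Copies numel * elem_size bytes out of a (notionally device-side) tensor
// buffer via a host staging area, then appends them to the wire buffer.
void AppendTensorBytes(const void* tensor_data,
                       size_t numel,
                       size_t elem_size,
                       std::vector<char>* iobuf) {
  const size_t data_len = numel * elem_size;
  std::vector<char> temp(data_len);  // temp_ptr in the source
  std::memcpy(temp.data(), tensor_data, data_len);  // memory::Copy + sync
  iobuf->insert(iobuf->end(), temp.begin(), temp.end());
}  // staging buffer freed here (delete[] temp_ptr in the source)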
......@@ -55,15 +55,19 @@ void SerializeToMultiVarMsgAndIOBuf(
const std::string& message_name,
const std::vector<std::string>& send_var_name_val,
const std::vector<std::string>& recv_var_name_val,
const platform::DeviceContext& ctx, const framework::Scope* scope,
MultiVarMsg* var_msg, butil::IOBuf* iobuf);
const platform::DeviceContext& ctx,
const framework::Scope* scope,
MultiVarMsg* var_msg,
butil::IOBuf* iobuf);
void SerializeLodTensor(framework::Variable* var,
const platform::DeviceContext& ctx, VarMsg* var_msg,
const platform::DeviceContext& ctx,
VarMsg* var_msg,
butil::IOBuf* iobuf);
void SerializeSelectedRows(framework::Variable* var,
const platform::DeviceContext& ctx, VarMsg* request,
const platform::DeviceContext& ctx,
VarMsg* request,
butil::IOBuf* iobuf);
// Deserialize for Server
......@@ -78,11 +82,13 @@ void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg& multi_msg,
const platform::DeviceContext& ctx,
const framework::Scope* scope);
void DeserializeLodTensor(framework::Variable* var, const VarMsg& msg,
void DeserializeLodTensor(framework::Variable* var,
const VarMsg& msg,
butil::IOBufBytesIterator& iobuf, // NOLINT
const platform::DeviceContext& ctx);
void DeserializeSelectedRows(framework::Variable* var, const VarMsg& msg,
void DeserializeSelectedRows(framework::Variable* var,
const VarMsg& msg,
butil::IOBufBytesIterator& iobuf, // NOLINT
const platform::DeviceContext& ctx);
......
......@@ -88,7 +88,8 @@ int Communicator::SetClients(std::vector<uint64_t> &host_sign_list) {
}
void Communicator::RpcRecvDense(const std::vector<std::string> &varnames,
int table_id, Scope *scope) {
int table_id,
Scope *scope) {
platform::RecordEvent record_event("Communicator->RpcRecvDense",
platform::TracerEventType::Communication,
1);
......@@ -146,7 +147,8 @@ void Communicator::RpcRecvDense(const std::vector<std::string> &varnames,
}
void Communicator::RpcSendDenseParam(const std::vector<std::string> &varnames,
int table_id, const Scope &scope) {
int table_id,
const Scope &scope) {
platform::RecordEvent record_event("Communicator->RpcSendDenseParam",
platform::TracerEventType::Communication,
1);
......@@ -224,13 +226,14 @@ void Communicator::RpcSendDense(const CommContext &ctx, const Scope &scope) {
closure->set_promise_value(ret);
--_async_call_num;
});
auto status = _worker_ptr->PushDenseRawGradient(table_id, data,
dense_data->size(), closure);
auto status = _worker_ptr->PushDenseRawGradient(
table_id, data, dense_data->size(), closure);
status.wait();
return;
}
void Communicator::RpcSendSparseParam(const std::string &varname, int table_id,
void Communicator::RpcSendSparseParam(const std::string &varname,
int table_id,
const Scope &scope) {
platform::RecordEvent record_event("Communicator->RpcSendSparseParam",
platform::TracerEventType::Communication,
......@@ -262,14 +265,17 @@ void Communicator::RpcSendSparseParam(const std::string &varname, int table_id,
}
closure->set_promise_value(ret);
});
auto status = _worker_ptr->PushSparseParam(table_id, sparse_push_keys.data(),
auto status = _worker_ptr->PushSparseParam(table_id,
sparse_push_keys.data(),
(const float **)push_g_vec.data(),
sparse_push_keys.size(), closure);
sparse_push_keys.size(),
closure);
status.wait();
return;
}
void Communicator::RpcSendSparse(const std::string &var_name, int table_id,
void Communicator::RpcSendSparse(const std::string &var_name,
int table_id,
const Scope &scope) {
platform::RecordEvent record_event("Communicator->RpcSendSparse",
platform::TracerEventType::Communication,
......@@ -281,7 +287,8 @@ void Communicator::RpcSendSparse(const std::string &var_name, int table_id,
auto *send_var = scope.FindVar(var_name);
auto *tensor = send_var->GetMutable<phi::SelectedRows>();
auto dim = tensor->value().dims()[1];
std::transform(tensor->rows().begin(), tensor->rows().end(),
std::transform(tensor->rows().begin(),
tensor->rows().end(),
std::back_inserter(sparse_push_keys),
[&](int64_t id) { return static_cast<uint64_t>(id); });
......@@ -315,14 +322,18 @@ void Communicator::RpcSendSparse(const std::string &var_name, int table_id,
closure->set_promise_value(ret);
--_async_call_num;
});
auto status = _worker_ptr->PushSparseRawGradient(
table_id, sparse_push_keys.data(), (const float **)push_g_vec.data(),
sparse_push_keys.size(), closure);
auto status =
_worker_ptr->PushSparseRawGradient(table_id,
sparse_push_keys.data(),
(const float **)push_g_vec.data(),
sparse_push_keys.size(),
closure);
status.wait();
return;
}
void Communicator::RpcRecvSparse(const std::string &varname, int table_id,
void Communicator::RpcRecvSparse(const std::string &varname,
int table_id,
Scope *scope) {
platform::RecordEvent record_event("Communicator->RpcRecvSparse",
platform::TracerEventType::Communication,
......@@ -342,9 +353,11 @@ void Communicator::RpcRecvSparse(const std::string &varname, int table_id,
bool training = true;
auto status = _worker_ptr->PullSparseParam(
(float **)push_g_vec.data(), table_id, // NOLINT
sparse_push_keys.data(), sparse_push_keys.size(), training);
auto status = _worker_ptr->PullSparseParam((float **)push_g_vec.data(),
table_id, // NOLINT
sparse_push_keys.data(),
sparse_push_keys.size(),
training);
status.wait();
return;
}
......@@ -389,7 +402,8 @@ void Communicator::RpcProfilerControl() {
}
}
void Communicator::SendGlobalStep(const CommContext &ctx, int batches,
void Communicator::SendGlobalStep(const CommContext &ctx,
int batches,
Scope *send_scope) {
if (batches == 0) {
return;
......@@ -522,7 +536,8 @@ void AsyncCommunicator::SendByCommunicator() {
SendGlobalStep(ctx, merged_var_num, send_scope_.get());
} else if (ctx.is_sparse) {
PADDLE_ENFORCE_EQ(
varnames.size(), 1,
varnames.size(),
1,
platform::errors::InvalidArgument(
"sparse variables can only be merged by one variables"));
RpcSendSparse(varnames[0], table_id, *send_scope_);
......@@ -569,9 +584,13 @@ void AsyncCommunicator::MainThread() {
}
void AsyncCommunicator::PullSparseToTensorSync(
const uint64_t table_id, int fea_dim, uint64_t padding_id,
platform::Place place, bool is_training,
std::vector<const LoDTensor *> *inputs, std::vector<LoDTensor *> *outputs) {
const uint64_t table_id,
int fea_dim,
uint64_t padding_id,
platform::Place place,
bool is_training,
std::vector<const LoDTensor *> *inputs,
std::vector<LoDTensor *> *outputs) {
std::vector<uint64_t> fea_keys;
std::vector<float *> pull_result_ptr;
fea_keys.reserve(MAX_FEASIGN_NUM / 100);
......@@ -598,7 +617,8 @@ void AsyncCommunicator::PullSparseToTensorSync(
}
uint64_t real_id = static_cast<uint64_t>(ids[i]);
if (real_id == padding_id) {
memcpy(output_data + output_len, init_value.data(),
memcpy(output_data + output_len,
init_value.data(),
sizeof(float) * fea_dim);
continue;
}
......@@ -606,9 +626,11 @@ void AsyncCommunicator::PullSparseToTensorSync(
pull_result_ptr.push_back(output_data + output_len);
}
}
auto status =
_worker_ptr->PullSparse(pull_result_ptr.data(), table_id, fea_keys.data(),
fea_keys.size(), is_training);
auto status = _worker_ptr->PullSparse(pull_result_ptr.data(),
table_id,
fea_keys.data(),
fea_keys.size(),
is_training);
status.wait();
auto ret = status.get();
if (ret != 0) {
......@@ -618,9 +640,13 @@ void AsyncCommunicator::PullSparseToTensorSync(
}
void AsyncCommunicator::PushSparseFromTensorAsync(
const uint64_t table_id, int fea_dim, uint64_t padding_id,
platform::Place place, std::vector<const framework::LoDTensor *> *inputs,
const framework::LoDTensor *shows, const framework::LoDTensor *clks,
const uint64_t table_id,
int fea_dim,
uint64_t padding_id,
platform::Place place,
std::vector<const framework::LoDTensor *> *inputs,
const framework::LoDTensor *shows,
const framework::LoDTensor *clks,
std::vector<framework::LoDTensor *> *outputs) {
int batch_size = -1;
bool batch_size_consist = true;
......@@ -735,10 +761,12 @@ void AsyncCommunicator::PushSparseFromTensorAsync(
}
PADDLE_ENFORCE_EQ(
this->Check(table_id), true,
this->Check(table_id),
true,
platform::errors::InvalidArgument(
"can not find table: %s, please check your config", table_id));
auto status = _worker_ptr->PushSparse(table_id, push_keys.data(),
auto status = _worker_ptr->PushSparse(table_id,
push_keys.data(),
(const float **)push_g_vec.data(),
push_keys.size());
}
......@@ -831,7 +859,8 @@ void AsyncCommunicator::Stop() {
bool AsyncCommunicator::Check(const std::vector<std::string> &var_tables) {
PADDLE_ENFORCE_EQ(
var_tables.size(), 1,
var_tables.size(),
1,
platform::errors::InvalidArgument("var_tables.size() == 1 is permitted"));
auto table_name = var_tables[0];
......@@ -948,7 +977,8 @@ void HalfAsyncCommunicator::SendByCommunicator() {
if (ctx.is_sparse) {
PADDLE_ENFORCE_EQ(
varnames.size(), 1,
varnames.size(),
1,
platform::errors::InvalidArgument(
"sparse variables can only be merged by one variables"));
RpcSendSparse(varnames[0], table_id, *send_scope_);
......@@ -1003,7 +1033,8 @@ void GeoCommunicator::Send(const std::vector<std::string> &var_names,
auto *var = scope.FindVar(table_name);
PADDLE_ENFORCE_EQ(var->IsType<phi::SelectedRows>(), true,
PADDLE_ENFORCE_EQ(var->IsType<phi::SelectedRows>(),
true,
platform::errors::InvalidArgument(
"Only need to send Sparse Grad in Geo mode."));
auto &rows = var->Get<phi::SelectedRows>().rows();
......@@ -1037,7 +1068,8 @@ void GeoCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
recv_scope_ = std::move(recv_scope);
PADDLE_ENFORCE_GT(
send_varname_to_ctx.size(), 0,
send_varname_to_ctx.size(),
0,
platform::errors::InvalidArgument("send var contexts can not be zero"));
for (auto &iter : send_varname_to_ctx_) {
......@@ -1048,14 +1080,16 @@ void GeoCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
}
auto &varnames = ctx.origin_varnames;
PADDLE_ENFORCE_EQ(
varnames.size(), 1,
varnames.size(),
1,
platform::errors::InvalidArgument(
"sparse variables can only be merged by one variables"));
for (auto &splited_var : ctx.splited_varnames) {
parallel_task_nums_ += 1;
sparse_id_queues_.insert(
std::pair<std::string, paddle::framework::Channel<
std::shared_ptr<std::vector<int64_t>>>>(
std::pair<std::string,
paddle::framework::Channel<
std::shared_ptr<std::vector<int64_t>>>>(
splited_var,
paddle::framework::MakeChannel<
std::shared_ptr<std::vector<int64_t>>>(send_queue_size_)));
......@@ -1138,10 +1172,12 @@ void GeoCommunicator::SendDense(const CommContext &send_ctx) {
auto *var_latest = recv_scope_->FindVar(param_name);
auto *var_timestamp = old_scope_->FindVar(param_name);
PADDLE_ENFORCE_EQ(var_latest->IsInitialized(), true,
PADDLE_ENFORCE_EQ(var_latest->IsInitialized(),
true,
platform::errors::Unavailable(
"%s is not initialized, please check", param_name));
PADDLE_ENFORCE_EQ(var_timestamp->IsInitialized(), true,
PADDLE_ENFORCE_EQ(var_timestamp->IsInitialized(),
true,
platform::errors::Unavailable(
"%s is not initialized, please check", param_name));
......@@ -1154,14 +1190,18 @@ void GeoCommunicator::SendDense(const CommContext &send_ctx) {
t_delta->mutable_data<float>(t_latest.dims(), cpu_ctx.GetPlace());
auto blas = phi::funcs::GetBlas<platform::CPUDeviceContext, float>(cpu_ctx);
blas.VSUB(t_latest.numel(), t_latest.data<float>(),
t_timestamp->data<float>(), t_delta->data<float>());
blas.VSUB(t_latest.numel(),
t_latest.data<float>(),
t_timestamp->data<float>(),
t_delta->data<float>());
float coefficient = 1.0 / static_cast<float>(trainers_);
blas.SCAL(t_latest.numel(), coefficient, t_delta->data<float>());
blas.VADD(t_latest.numel(), t_timestamp->data<float>(),
t_delta->data<float>(), t_timestamp->data<float>());
blas.VADD(t_latest.numel(),
t_timestamp->data<float>(),
t_delta->data<float>(),
t_timestamp->data<float>());
}
RpcSendDense(send_ctx, *delta_scope_);
VLOG(1) << "Finish Send Dense " << var_names[0] << ", table_id: " << table_id;
......@@ -1194,12 +1234,16 @@ void GeoCommunicator::RecvDense(const CommContext &send_ctx) {
t_delta->mutable_data<float>(t_latest->dims(), cpu_ctx.GetPlace());
auto blas = phi::funcs::GetBlas<platform::CPUDeviceContext, float>(cpu_ctx);
blas.VSUB(t_latest->numel(), t_pserver.data<float>(), t_old->data<float>(),
blas.VSUB(t_latest->numel(),
t_pserver.data<float>(),
t_old->data<float>(),
t_delta->data<float>());
blas.VADD(t_latest->numel(), t_latest->data<float>(),
t_delta->data<float>(), t_latest->data<float>());
blas.VCOPY(t_latest->numel(), t_pserver.data<float>(),
t_old->data<float>());
blas.VADD(t_latest->numel(),
t_latest->data<float>(),
t_delta->data<float>(),
t_latest->data<float>());
blas.VCOPY(
t_latest->numel(), t_pserver.data<float>(), t_old->data<float>());
}
VLOG(1) << "Finish Recv Dense " << varnames[0] << ", table_id: " << table_id;
return;
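RecvDense mirrors that rule: the delta is what the pserver copy has gained relative to the cached old copy, it is folded into the latest parameter, and the cache is refreshed. A matching sketch:

#include <cstdint>

// delta   = pserver - old   (VSUB)
// latest += delta           (VADD)
// old     = pserver         (VCOPY)
void GeoDenseRecvUpdate(const float* pserver,
                        float* latest,
                        float* old_copy,
                        int64_t n) {
  for (int64_t i = 0; i < n; ++i) {
    const float delta = pserver[i] - old_copy[i];
    latest[i] += delta;
    old_copy[i] = pserver[i];
  }
}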
......@@ -1260,7 +1304,8 @@ std::vector<int64_t> GeoCommunicator::MergeSparseIds(
}
void GeoCommunicator::SendSparse(const std::string &varname,
std::vector<int64_t> &sparse_ids, int table_id,
std::vector<int64_t> &sparse_ids,
int table_id,
int ep_idx) {
platform::RecordEvent record_event("GeoCommunicator->SendSparse",
platform::TracerEventType::Communication,
......@@ -1276,10 +1321,12 @@ void GeoCommunicator::SendSparse(const std::string &varname,
auto *var_latest = recv_scope_->FindVar(param_name);
auto *var_old = old_scope_->FindVar(param_name);
PADDLE_ENFORCE_EQ(var_latest->IsInitialized(), true,
PADDLE_ENFORCE_EQ(var_latest->IsInitialized(),
true,
platform::errors::Unavailable(
"%s is not initialized, please check", param_name));
PADDLE_ENFORCE_EQ(var_old->IsInitialized(), true,
PADDLE_ENFORCE_EQ(var_old->IsInitialized(),
true,
platform::errors::Unavailable(
"%s is not initialized, please check", param_name));
......@@ -1303,11 +1350,13 @@ void GeoCommunicator::SendSparse(const std::string &varname,
std::vector<float *> push_g_vec;
for (auto j = 0; j < static_cast<int>(sparse_ids.size()); ++j) {
blas.VSUB(dims1, t_latest.data<float>() + sparse_ids[j] * dims1,
blas.VSUB(dims1,
t_latest.data<float>() + sparse_ids[j] * dims1,
t_old->data<float>() + sparse_ids[j] * dims1,
t_value + j * dims1);
blas.SCAL(dims1, coefficient, t_value + j * dims1);
blas.VADD(dims1, t_old->data<float>() + sparse_ids[j] * dims1,
blas.VADD(dims1,
t_old->data<float>() + sparse_ids[j] * dims1,
t_value + j * dims1,
t_old->data<float>() + sparse_ids[j] * dims1);
push_g_vec.push_back(t_value + j * dims1);
......@@ -1328,8 +1377,12 @@ void GeoCommunicator::SendSparse(const std::string &varname,
--_async_call_num;
});
auto status = _worker_ptr->PushSparseRawGradientPartial(
table_id, (const uint64_t *)sparse_ids.data(),
(const float **)push_g_vec.data(), sparse_ids.size(), closure, ep_idx);
table_id,
(const uint64_t *)sparse_ids.data(),
(const float **)push_g_vec.data(),
sparse_ids.size(),
closure,
ep_idx);
status.wait();
VLOG(1) << "Finish Send Sparse " << varname
......@@ -1337,7 +1390,8 @@ void GeoCommunicator::SendSparse(const std::string &varname,
return;
}
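SendSparse applies the same delta rule only to the rows named by sparse_ids, with dims1 as the row width. A sketch of the per-row loop, with raw pointers standing in for the tensors above:

#include <cstddef>
#include <cstdint>

void GeoSparseSendUpdate(const float* latest,
                         float* old_copy,
                         float* deltas,  // scratch of num_ids * dims1 floats
                         const int64_t* sparse_ids,
                         size_t num_ids,
                         int dims1,
                         int trainers) {
  const float coeff = 1.0f / static_cast<float>(trainers);
  for (size_t j = 0; j < num_ids; ++j) {
    const float* lrow = latest + sparse_ids[j] * dims1;
    float* orow = old_copy + sparse_ids[j] * dims1;
    float* drow = deltas + j * dims1;
    for (int k = 0; k < dims1; ++k) {
      drow[k] = (lrow[k] - orow[k]) * coeff;  // VSUB + SCAL
      orow[k] += drow[k];                     // VADD
    }
  }
}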
void GeoCommunicator::RecvSparse(const std::string &varname, int table_id,
void GeoCommunicator::RecvSparse(const std::string &varname,
int table_id,
int ep_idx) {
platform::RecordEvent record_event("GeoCommunicator->RecvSparse",
platform::TracerEventType::Communication,
......@@ -1375,8 +1429,8 @@ void GeoCommunicator::RecvSparse(const std::string &varname, int table_id,
float *latest_data = t_latest->data<float>() + keys[j] * dims1;
float *old_data = t_old->data<float>() + keys[j] * dims1;
// pserver - old => delta
blas.VSUB(dims1, values.data() + j * dims1, old_data,
v_delta.data() + j * dims1);
blas.VSUB(
dims1, values.data() + j * dims1, old_data, v_delta.data() + j * dims1);
// latest + delta => latest
blas.VADD(dims1, latest_data, v_delta.data() + j * dims1, latest_data);
// pserver => old
......@@ -1404,7 +1458,8 @@ void GeoCommunicator::MainThread() {
if (ctx.is_sparse) {
PADDLE_ENFORCE_EQ(
varnames.size(), 1,
varnames.size(),
1,
platform::errors::InvalidArgument(
"sparse variables can only be merged by one variables"));
int pserver_num = static_cast<int>(ctx.epmap.size());
......
......@@ -63,7 +63,8 @@ template <typename T>
class BlockingQueue {
public:
explicit BlockingQueue(size_t capacity) : capacity_(capacity) {
PADDLE_ENFORCE_GT(capacity_, 0,
PADDLE_ENFORCE_GT(capacity_,
0,
platform::errors::InvalidArgument(
"The capacity must be greater than 0."));
}
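BlockingQueue is a bounded producer/consumer queue; the hunk above only shows the capacity check in its constructor. Its push/pop machinery is outside this diff, so the following is a self-contained sketch of the standard pattern such a class wraps, not the real member functions:

#include <condition_variable>
#include <cstddef>
#include <deque>
#include <mutex>

template <typename T>
class BoundedQueue {
 public:
  explicit BoundedQueue(size_t capacity) : capacity_(capacity) {}

  void Push(T v) {
    std::unique_lock<std::mutex> lk(mu_);
    not_full_.wait(lk, [&] { return q_.size() < capacity_; });
    q_.push_back(std::move(v));
    not_empty_.notify_one();
  }

  T Pop() {
    std::unique_lock<std::mutex> lk(mu_);
    not_empty_.wait(lk, [&] { return !q_.empty(); });
    T v = std::move(q_.front());
    q_.pop_front();
    not_full_.notify_one();
    return v;
  }

 private:
  size_t capacity_;
  std::deque<T> q_;
  std::mutex mu_;
  std::condition_variable not_empty_, not_full_;
};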
......@@ -149,16 +150,19 @@ class BlockingQueue {
mutable std::mutex mutex_;
};
template <typename T, int MajorType = Eigen::RowMajor,
template <typename T,
int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename T>
inline void MergeVars(const std::string &var_name,
const std::vector<std::shared_ptr<Variable>> &vars,
Scope *scope, bool merge_add = true) {
Scope *scope,
bool merge_add = true) {
PADDLE_ENFORCE_NE(
vars.empty(), true,
vars.empty(),
true,
platform::errors::InvalidArgument("vector vars are empty."));
auto cpu_place = platform::CPUPlace();
auto &var0 = vars[0];
......@@ -175,7 +179,8 @@ inline void MergeVars(const std::string &var_name,
for (auto &var : vars) {
auto &var_t = var->Get<framework::LoDTensor>();
PADDLE_ENFORCE_EQ(
var_t.dims(), dims,
var_t.dims(),
dims,
platform::errors::InvalidArgument("vars should have the same dims."));
}
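MergeVars collapses per-thread copies of one variable into a single value: for dense tensors an element-wise sum, divided by the number of copies when merge_add is false. A sketch of that dense branch, assuming the sum-then-average semantics visible in the surrounding code:

#include <cstdint>
#include <vector>

void MergeDense(const std::vector<const float*>& srcs,
                float* dst,
                int64_t n,
                bool merge_add) {
  for (int64_t i = 0; i < n; ++i) {
    float acc = 0.0f;
    for (const float* s : srcs) acc += s[i];
    dst[i] = merge_add ? acc : acc / static_cast<float>(srcs.size());
  }
}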
......@@ -207,14 +212,14 @@ inline void MergeVars(const std::string &var_name,
}
paddle::platform::CPUDeviceContext dev_ctx;
if (merge_add) {
paddle::operators::math::scatter::MergeAdd<
paddle::platform::CPUDeviceContext, T>
merge_add;
paddle::operators::math::scatter::
MergeAdd<paddle::platform::CPUDeviceContext, T>
merge_add;
merge_add(dev_ctx, inputs, out_slr);
} else {
paddle::operators::math::scatter::MergeAverage<
paddle::platform::CPUDeviceContext, T>
merge_average;
paddle::operators::math::scatter::
MergeAverage<paddle::platform::CPUDeviceContext, T>
merge_average;
merge_average(dev_ctx, inputs, out_slr);
}
......@@ -254,23 +259,29 @@ class Communicator {
// 1. recv dense param
virtual void RpcRecvDense(const std::vector<std::string> &varnames,
int table_id, Scope *scope);
int table_id,
Scope *scope);
// 2. send dense param
virtual void RpcSendDenseParam(const std::vector<std::string> &varnames,
int table_id, const Scope &scope);
int table_id,
const Scope &scope);
// 3. send dense grad
virtual void RpcSendDense(const CommContext &ctx, const Scope &scope);
// 4. send sparse grad
virtual void RpcSendSparse(const std::string &var_name, int table_id,
virtual void RpcSendSparse(const std::string &var_name,
int table_id,
const Scope &scope);
// 5. send sparse param
virtual void RpcSendSparseParam(const std::string &varname, int table_id,
virtual void RpcSendSparseParam(const std::string &varname,
int table_id,
const Scope &scope);
// 6. recv sparse param
virtual void RpcRecvSparse(const std::string &varname, int table_id,
virtual void RpcRecvSparse(const std::string &varname,
int table_id,
Scope *scope);
// 7. send global step
virtual void SendGlobalStep(const CommContext &ctx, int batches,
virtual void SendGlobalStep(const CommContext &ctx,
int batches,
Scope *send_scope);
virtual ~Communicator() {}
......@@ -303,7 +314,8 @@ class Communicator {
auto rets = _worker_ptr->Barrier(barrier_table_id_, barrier_type);
rets.wait();
int status = rets.get();
PADDLE_ENFORCE_EQ(status, 0,
PADDLE_ENFORCE_EQ(status,
0,
platform::errors::InvalidArgument(
"The ret status must be 0 when barrier with table"));
}
......@@ -333,12 +345,19 @@ class Communicator {
template <typename T>
static Communicator *InitInstance(
const RpcCtxMap &send_ctx, const RecvCtxMap &recv_ctx,
const RpcCtxMap &send_ctx,
const RecvCtxMap &recv_ctx,
const std::string &dist_desc,
const std::vector<std::string> &host_sign_list, Scope *recv_scope,
const std::vector<std::string> &host_sign_list,
Scope *recv_scope,
const std::map<std::string, std::string> &envs) {
std::call_once(init_flag_, &Communicator::InitWithRpcCtx<T>, send_ctx,
recv_ctx, dist_desc, host_sign_list, recv_scope,
std::call_once(init_flag_,
&Communicator::InitWithRpcCtx<T>,
send_ctx,
recv_ctx,
dist_desc,
host_sign_list,
recv_scope,
std::ref(envs));
return communicator_.get();
}
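InitInstance is the classic call_once singleton: however many threads race into it, exactly one executes InitWithRpcCtx and every caller receives the same pointer. The pattern in isolation, as a minimal sketch rather than the Communicator itself:

#include <memory>
#include <mutex>

class Singleton {
 public:
  static Singleton* InitInstance(int arg) {
    std::call_once(flag_, [&] { instance_.reset(new Singleton(arg)); });
    return instance_.get();
  }

 private:
  explicit Singleton(int) {}
  static std::once_flag flag_;
  static std::unique_ptr<Singleton> instance_;
};

std::once_flag Singleton::flag_;
std::unique_ptr<Singleton> Singleton::instance_;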
......@@ -456,15 +475,22 @@ class AsyncCommunicator : public Communicator {
void PushDensePostProcessing();
void PullSparseToTensorSync(
const uint64_t table_id, int fea_dim, uint64_t padding_id,
platform::Place place, bool is_training,
const uint64_t table_id,
int fea_dim,
uint64_t padding_id,
platform::Place place,
bool is_training,
std::vector<const framework::LoDTensor *> *inputs, // NOLINT
std::vector<framework::LoDTensor *> *outputs); // NOLINT
void PushSparseFromTensorAsync(
const uint64_t table_id, int fea_dim, uint64_t padding_id,
platform::Place place, std::vector<const framework::LoDTensor *> *inputs,
const framework::LoDTensor *shows, const framework::LoDTensor *clicks,
const uint64_t table_id,
int fea_dim,
uint64_t padding_id,
platform::Place place,
std::vector<const framework::LoDTensor *> *inputs,
const framework::LoDTensor *shows,
const framework::LoDTensor *clicks,
std::vector<framework::LoDTensor *> *outputs);
protected:
......@@ -585,7 +611,8 @@ class GeoCommunicator : public AsyncCommunicator {
std::vector<int64_t> MergeSparseIds(const std::string &varname);
void SendSparse(const std::string &varname,
std::vector<int64_t> &sparse_ids, // NOLINT
int table_id, int ep_idx);
int table_id,
int ep_idx);
void RecvSparse(const std::string &varname, int table_id, int ep_idx);
void MainThread() override;
......@@ -628,8 +655,9 @@ class GeoCommunicator : public AsyncCommunicator {
// parameter on pserver
std::shared_ptr<Scope> pserver_scope_;
std::unordered_map<std::string, paddle::framework::Channel<
std::shared_ptr<std::vector<int64_t>>>>
std::unordered_map<
std::string,
paddle::framework::Channel<std::shared_ptr<std::vector<int64_t>>>>
sparse_id_queues_;
};
......
......@@ -25,13 +25,18 @@ namespace distributed {
struct CommContext {
CommContext() = default;
CommContext(const std::string &name, const std::vector<std::string> &names,
CommContext(const std::string &name,
const std::vector<std::string> &names,
const std::vector<std::string> &emap,
const std::vector<int64_t> &sections,
const std::vector<std::string> &origin_names, int id,
bool merge_add_ = true, bool is_sparse_ = true,
bool is_distributed_ = false, int table_id_ = -1,
bool is_tensor_table_ = false, bool is_datanorm_table_ = false,
const std::vector<std::string> &origin_names,
int id,
bool merge_add_ = true,
bool is_sparse_ = true,
bool is_distributed_ = false,
int table_id_ = -1,
bool is_tensor_table_ = false,
bool is_datanorm_table_ = false,
int64_t program_id_ = -1)
: var_name(name),
splited_varnames(names),
......
......@@ -86,8 +86,10 @@ struct PSHost {
rank = std::stoi(endpoint_info[2]);
}
void StringSplit(const std::string &str, char sep,
std::vector<std::string> *pieces, bool ignore_null = true) {
void StringSplit(const std::string &str,
char sep,
std::vector<std::string> *pieces,
bool ignore_null = true) {
pieces->clear();
if (str.empty()) {
if (!ignore_null) {
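The hunk cuts off inside StringSplit; for orientation, here is a complete, hedged reconstruction of such a splitter with the same signature (the real body may differ in details such as how trailing separators are treated):

#include <string>
#include <vector>

void StringSplit(const std::string& str,
                 char sep,
                 std::vector<std::string>* pieces,
                 bool ignore_null = true) {
  pieces->clear();
  if (str.empty()) {
    if (!ignore_null) {
      pieces->push_back(str);
    }
    return;
  }
  size_t pos = 0;
  size_t next = str.find(sep, pos);
  while (next != std::string::npos) {
    pieces->push_back(str.substr(pos, next - pos));
    pos = next + 1;
    next = str.find(sep, pos);
  }
  if (pos < str.size()) {
    pieces->push_back(str.substr(pos));
  }
}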
......@@ -130,13 +132,15 @@ class PSEnvironment {
}
virtual uint64_t GetLocalHostSign() { return 0; }
virtual std::vector<PSHost> GetPsServers() const { return _ps_server_list; }
virtual int32_t RegistePsServer(const std::string &ip, uint32_t port,
virtual int32_t RegistePsServer(const std::string &ip,
uint32_t port,
int32_t rank) {
return RegistePsHost(ip, port, rank, _ps_server_list, _ps_server_sign_set);
}
virtual std::vector<PSHost> GetPsClients() const { return _ps_client_list; }
virtual int32_t RegistePsClient(const std::string &ip, uint32_t port,
virtual int32_t RegistePsClient(const std::string &ip,
uint32_t port,
int32_t rank) {
return RegistePsHost(ip, port, rank, _ps_client_list, _ps_client_sign_set);
}
......@@ -167,7 +171,9 @@ class PSEnvironment {
protected:
// Register a host  // NOLINT
virtual int32_t RegistePsHost(
const std::string &ip, uint32_t port, int32_t rank,
const std::string &ip,
uint32_t port,
int32_t rank,
std::vector<PSHost> &host_list, // NOLINT
std::unordered_set<uint64_t> &sign_set) { // NOLINT
PSHost host;
......@@ -209,7 +215,8 @@ class PaddlePSEnvironment : public PSEnvironment {
}
}
std::sort(
_ps_server_list.begin(), _ps_server_list.end(),
_ps_server_list.begin(),
_ps_server_list.end(),
[](const PSHost &h1, const PSHost &h2) { return h1.rank < h2.rank; });
return 0;
}
......@@ -227,7 +234,8 @@ class PaddlePSEnvironment : public PSEnvironment {
}
}
std::sort(
_ps_server_list.begin(), _ps_server_list.end(),
_ps_server_list.begin(),
_ps_server_list.end(),
[](const PSHost &h1, const PSHost &h2) { return h1.rank < h2.rank; });
return 0;
}
......@@ -244,7 +252,8 @@ class PaddlePSEnvironment : public PSEnvironment {
}
}
std::sort(
_ps_client_list.begin(), _ps_client_list.end(),
_ps_client_list.begin(),
_ps_client_list.end(),
[](const PSHost &h1, const PSHost &h2) { return h1.rank < h2.rank; });
return 0;
}
......@@ -262,7 +271,8 @@ class PaddlePSEnvironment : public PSEnvironment {
}
}
std::sort(
_ps_client_list.begin(), _ps_client_list.end(),
_ps_client_list.begin(),
_ps_client_list.end(),
[](const PSHost &h1, const PSHost &h2) { return h1.rank < h2.rank; });
VLOG(1) << "env.set_ps_clients done\n";
return 0;
......
......@@ -40,7 +40,8 @@ class GraphPsService_Stub : public PsService_Stub {
public:
GraphPsService_Stub(::google::protobuf::RpcChannel* channel,
::google::protobuf::RpcChannel* local_channel = NULL,
GraphBrpcService* service = NULL, int thread_num = 1)
GraphBrpcService* service = NULL,
int thread_num = 1)
: PsService_Stub(channel) {
this->local_channel = local_channel;
this->graph_service = service;
......@@ -64,34 +65,50 @@ class GraphBrpcClient : public BrpcPsClient {
virtual ~GraphBrpcClient() {}
// given a batch of nodes, sample graph_neighbors for each of them
virtual std::future<int32_t> batch_sample_neighbors(
uint32_t table_id, int idx, std::vector<int64_t> node_ids,
int sample_size, std::vector<std::vector<int64_t>>& res,
std::vector<std::vector<float>>& res_weight, bool need_weight,
uint32_t table_id,
int idx,
std::vector<int64_t> node_ids,
int sample_size,
std::vector<std::vector<int64_t>>& res,
std::vector<std::vector<float>>& res_weight,
bool need_weight,
int server_index = -1);
virtual std::future<int32_t> pull_graph_list(uint32_t table_id, int type_id,
int idx, int server_index,
int start, int size, int step,
virtual std::future<int32_t> pull_graph_list(uint32_t table_id,
int type_id,
int idx,
int server_index,
int start,
int size,
int step,
std::vector<FeatureNode>& res);
virtual std::future<int32_t> random_sample_nodes(uint32_t table_id,
int type_id, int idx,
int type_id,
int idx,
int server_index,
int sample_size,
std::vector<int64_t>& ids);
virtual std::future<int32_t> get_node_feat(
const uint32_t& table_id, int idx, const std::vector<int64_t>& node_ids,
const uint32_t& table_id,
int idx,
const std::vector<int64_t>& node_ids,
const std::vector<std::string>& feature_names,
std::vector<std::vector<std::string>>& res);
virtual std::future<int32_t> set_node_feat(
const uint32_t& table_id, int idx, const std::vector<int64_t>& node_ids,
const uint32_t& table_id,
int idx,
const std::vector<int64_t>& node_ids,
const std::vector<std::string>& feature_names,
const std::vector<std::vector<std::string>>& features);
virtual std::future<int32_t> clear_nodes(uint32_t table_id, int type_id,
virtual std::future<int32_t> clear_nodes(uint32_t table_id,
int type_id,
int idx);
virtual std::future<int32_t> add_graph_node(
uint32_t table_id, int idx, std::vector<int64_t>& node_id_list,
uint32_t table_id,
int idx,
std::vector<int64_t>& node_id_list,
std::vector<bool>& is_weighted_list);
virtual std::future<int32_t> remove_graph_node(
uint32_t table_id, int idx_, std::vector<int64_t>& node_id_list);
......@@ -107,8 +124,8 @@ class GraphBrpcClient : public BrpcPsClient {
}
GraphPsService_Stub getServiceStub(::google::protobuf::RpcChannel* channel,
int thread_num = 1) {
return GraphPsService_Stub(channel, local_channel, graph_service,
thread_num);
return GraphPsService_Stub(
channel, local_channel, graph_service, thread_num);
}
private:
......
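Every graph-client method above is asynchronous and returns std::future<int32_t>, with zero signaling success, the same status convention checked elsewhere in these files. A hedged call-site sketch; the client object, ids, and table numbers are illustrative:

// Illustrative: sample up to 10 neighbors (with weights) per node id.
std::vector<int64_t> node_ids = {1, 2, 3};
std::vector<std::vector<int64_t>> neighbors;
std::vector<std::vector<float>> weights;
auto fut = client.batch_sample_neighbors(/*table_id=*/0,
                                         /*idx=*/0,
                                         node_ids,
                                         /*sample_size=*/10,
                                         neighbors,
                                         weights,
                                         /*need_weight=*/true);
if (fut.get() != 0) {
  // Handle the RPC failure.
}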
This diff is collapsed.
This diff is collapsed.
......@@ -117,7 +117,8 @@ void HeterServer::WaitServerReady() {
}
int SendAndRecvVariableHandler::SaveInSwitchWithShard(
const MultiVarMsg* request, PsResponseMessage* response,
const MultiVarMsg* request,
PsResponseMessage* response,
brpc::Controller* cntl) {
VLOG(4) << "entering SaveInSwitchWithShard";
int32_t group_id = request->group_id();
......@@ -174,7 +175,8 @@ int SendAndRecvVariableHandler::QueryInSwitchWithShard(
}
int SendAndRecvVariableHandler::SaveInSwitchWithScope(
const MultiVarMsg* request, PsResponseMessage* response,
const MultiVarMsg* request,
PsResponseMessage* response,
brpc::Controller* cntl) {
VLOG(4) << "entering SaveInSwitchWithScope";
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
......@@ -201,8 +203,8 @@ int SendAndRecvVariableHandler::SaveInSwitchWithScope(
WaitForVarsConsumed(0, var_name);
}
auto& request_io_buffer = cntl->request_attachment();
distributed::DeserializeFromMultiVarMsgAndIOBuf(*request, &request_io_buffer,
cpu_dev_ctx, local_scope);
distributed::DeserializeFromMultiVarMsgAndIOBuf(
*request, &request_io_buffer, cpu_dev_ctx, local_scope);
lk.unlock();
for (auto var_name : send_var_names) {
std::unique_lock<std::mutex> lk(scope_mutex_);
......
This diff is collapsed. (65 collapsed file diffs follow here, one notice per file.)