未验证 提交 576236a0 编写于 作者: S Sing_chan 提交者: GitHub

format all files in fluid using new config (#43776)

上级 c40858c3

要显示的变更太多。

To preserve performance only 1000 of 1000+ files are displayed.
...@@ -41,10 +41,10 @@ std::string GetKeyFromPlaces(const std::vector<Place>& places) { ...@@ -41,10 +41,10 @@ std::string GetKeyFromPlaces(const std::vector<Place>& places) {
} }
bool CheckTensorsInCudaPlace(const std::vector<phi::DenseTensor>& tensors) { bool CheckTensorsInCudaPlace(const std::vector<phi::DenseTensor>& tensors) {
return std::all_of(tensors.cbegin(), tensors.cend(), return std::all_of(
[&](const phi::DenseTensor& t) { tensors.cbegin(), tensors.cend(), [&](const phi::DenseTensor& t) {
return platform::is_gpu_place(t.place()); return platform::is_gpu_place(t.place());
}); });
} }
} // namespace distributed } // namespace distributed
......
...@@ -28,7 +28,8 @@ HcclReduceOp ToHCCLRedType(ReduceOp reduction) { ...@@ -28,7 +28,8 @@ HcclReduceOp ToHCCLRedType(ReduceOp reduction) {
}; };
auto it = red_type.find(reduction); auto it = red_type.find(reduction);
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
it != red_type.end(), true, it != red_type.end(),
true,
platform::errors::InvalidArgument("Invalid hccl reduction. " platform::errors::InvalidArgument("Invalid hccl reduction. "
"Must be Min | Max | Prod | Sum")); "Must be Min | Max | Prod | Sum"));
return it->second; return it->second;
......
...@@ -67,11 +67,13 @@ class NPUEventManager { ...@@ -67,11 +67,13 @@ class NPUEventManager {
if (!is_created_) { if (!is_created_) {
CreateEvent(device_index); CreateEvent(device_index);
} }
PADDLE_ENFORCE_EQ(device_index, device_index_, PADDLE_ENFORCE_EQ(device_index,
device_index_,
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"NPUDeviceContext's device %d does not match" "NPUDeviceContext's device %d does not match"
"Event's device %d", "Event's device %d",
device_index, device_index_)); device_index,
device_index_));
platform::NPUDeviceGuard guard(device_index_); platform::NPUDeviceGuard guard(device_index_);
platform::NPUEventRecord(event_, ctx.stream()); platform::NPUEventRecord(event_, ctx.stream());
...@@ -89,11 +91,13 @@ class NPUEventManager { ...@@ -89,11 +91,13 @@ class NPUEventManager {
void Block(const paddle::platform::NPUDeviceContext& ctx) const { void Block(const paddle::platform::NPUDeviceContext& ctx) const {
if (is_created_) { if (is_created_) {
auto device_index = ctx.GetPlace().device; auto device_index = ctx.GetPlace().device;
PADDLE_ENFORCE_EQ(device_index, device_index_, PADDLE_ENFORCE_EQ(device_index,
device_index_,
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"CUDADeviceContext's device %d does not match" "CUDADeviceContext's device %d does not match"
"Event's device %d", "Event's device %d",
device_index, device_index_)); device_index,
device_index_));
platform::NPUDeviceGuard guard(device_index_); platform::NPUDeviceGuard guard(device_index_);
platform::NPUStreamWaitEvent(ctx.stream(), event_); platform::NPUStreamWaitEvent(ctx.stream(), event_);
} }
...@@ -126,12 +130,13 @@ class HCCLCommManager { ...@@ -126,12 +130,13 @@ class HCCLCommManager {
} }
} }
static std::shared_ptr<HCCLCommManager> Create(int num_ranks, int rank, static std::shared_ptr<HCCLCommManager> Create(int num_ranks,
int rank,
HcclRootInfo* comm_id, HcclRootInfo* comm_id,
HcclComm hccl_comm) { HcclComm hccl_comm) {
auto hccl_manager = std::make_shared<HCCLCommManager>(); auto hccl_manager = std::make_shared<HCCLCommManager>();
auto ret = platform::dynload::HcclCommInitRootInfo(num_ranks, comm_id, rank, auto ret = platform::dynload::HcclCommInitRootInfo(
&hccl_comm); num_ranks, comm_id, rank, &hccl_comm);
using __NPU_STATUS_TYPE__ = decltype(ret); using __NPU_STATUS_TYPE__ = decltype(ret);
constexpr auto __success_type__ = constexpr auto __success_type__ =
platform::details::NPUStatusType<__NPU_STATUS_TYPE__>::kSuccess; platform::details::NPUStatusType<__NPU_STATUS_TYPE__>::kSuccess;
......
...@@ -27,7 +27,8 @@ ncclRedOp_t ToNCCLRedType(ReduceOp reduction) { ...@@ -27,7 +27,8 @@ ncclRedOp_t ToNCCLRedType(ReduceOp reduction) {
{ReduceOp::PRODUCT, ncclProd}, {ReduceOp::PRODUCT, ncclProd},
}; };
auto it = red_type.find(reduction); auto it = red_type.find(reduction);
PADDLE_ENFORCE_EQ(it != red_type.end(), true, PADDLE_ENFORCE_EQ(it != red_type.end(),
true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Invalid nccl reduction. Must be ncclMin | ncclMax | " "Invalid nccl reduction. Must be ncclMin | ncclMax | "
"ncclProd | ncclSum")); "ncclProd | ncclSum"));
......
...@@ -47,14 +47,16 @@ ...@@ -47,14 +47,16 @@
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
#define NCCLCHECK(cmd) \ #define NCCLCHECK(cmd) \
do { \ do { \
ncclResult_t r = cmd; \ ncclResult_t r = cmd; \
if (r != ncclSuccess) { \ if (r != ncclSuccess) { \
printf("Failed, NCCL error %s:%d '%s'\n", __FILE__, __LINE__, \ printf("Failed, NCCL error %s:%d '%s'\n", \
platform::dynload::ncclGetErrorString(r)); \ __FILE__, \
exit(EXIT_FAILURE); \ __LINE__, \
} \ platform::dynload::ncclGetErrorString(r)); \
exit(EXIT_FAILURE); \
} \
} while (0) } while (0)
// NOTE(shenliang03): EventManager are movable not copyable CudaEvent wrapper. // NOTE(shenliang03): EventManager are movable not copyable CudaEvent wrapper.
...@@ -107,11 +109,13 @@ class EventManager { ...@@ -107,11 +109,13 @@ class EventManager {
if (!is_created_) { if (!is_created_) {
CreateEvent(device_index); CreateEvent(device_index);
} }
PADDLE_ENFORCE_EQ(device_index, device_index_, PADDLE_ENFORCE_EQ(device_index,
device_index_,
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"CUDADeviceContext's device %d does not match" "CUDADeviceContext's device %d does not match"
"Event's device %d", "Event's device %d",
device_index, device_index_)); device_index,
device_index_));
platform::CUDADeviceGuard guard(device_index_); platform::CUDADeviceGuard guard(device_index_);
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
...@@ -156,11 +160,13 @@ class EventManager { ...@@ -156,11 +160,13 @@ class EventManager {
void Block(const paddle::platform::CUDADeviceContext& ctx) const { void Block(const paddle::platform::CUDADeviceContext& ctx) const {
if (is_created_) { if (is_created_) {
auto device_index = ctx.GetPlace().device; auto device_index = ctx.GetPlace().device;
PADDLE_ENFORCE_EQ(device_index, device_index_, PADDLE_ENFORCE_EQ(device_index,
device_index_,
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"CUDADeviceContext's device %d does not match" "CUDADeviceContext's device %d does not match"
"Event's device %d", "Event's device %d",
device_index, device_index_)); device_index,
device_index_));
platform::CUDADeviceGuard guard(device_index_); platform::CUDADeviceGuard guard(device_index_);
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
...@@ -213,11 +219,12 @@ class NCCLCommManager { ...@@ -213,11 +219,12 @@ class NCCLCommManager {
} }
} }
static std::shared_ptr<NCCLCommManager> Create(int num_ranks, int rank, static std::shared_ptr<NCCLCommManager> Create(int num_ranks,
int rank,
ncclUniqueId comm_id) { ncclUniqueId comm_id) {
auto nccl_manager = std::make_shared<NCCLCommManager>(); auto nccl_manager = std::make_shared<NCCLCommManager>();
NCCLCHECK(platform::dynload::ncclCommInitRank(&(nccl_manager->nccl_comm_), NCCLCHECK(platform::dynload::ncclCommInitRank(
num_ranks, comm_id, rank)); &(nccl_manager->nccl_comm_), num_ranks, comm_id, rank));
nccl_manager->nccl_id_ = comm_id; nccl_manager->nccl_id_ = comm_id;
nccl_manager->rank_ = rank; nccl_manager->rank_ = rank;
......
...@@ -35,7 +35,9 @@ bool ProcessGroup::Task::Wait(std::chrono::milliseconds timeout) { ...@@ -35,7 +35,9 @@ bool ProcessGroup::Task::Wait(std::chrono::milliseconds timeout) {
void ProcessGroup::Task::Synchronize() {} void ProcessGroup::Task::Synchronize() {}
ProcessGroup::ProcessGroup(int rank, int size, const platform::Place& place, ProcessGroup::ProcessGroup(int rank,
int size,
const platform::Place& place,
int gid) int gid)
: rank_(rank), size_(size), place_(place), gid_(gid) { : rank_(rank), size_(size), place_(place), gid_(gid) {
if (gid != IGNORE_ID) { if (gid != IGNORE_ID) {
......
...@@ -53,7 +53,8 @@ class ProcessGroup { ...@@ -53,7 +53,8 @@ class ProcessGroup {
public: public:
class Task { class Task {
public: public:
Task(int rank, const std::vector<phi::DenseTensor>& inputTensors, Task(int rank,
const std::vector<phi::DenseTensor>& inputTensors,
CommType opType = CommType::UNKNOWN); CommType opType = CommType::UNKNOWN);
virtual ~Task(); virtual ~Task();
...@@ -68,7 +69,9 @@ class ProcessGroup { ...@@ -68,7 +69,9 @@ class ProcessGroup {
bool is_completed_ = false; bool is_completed_ = false;
}; };
explicit ProcessGroup(int rank, int size, const platform::Place& place, explicit ProcessGroup(int rank,
int size,
const platform::Place& place,
int gid); int gid);
virtual ~ProcessGroup() {} virtual ~ProcessGroup() {}
...@@ -113,7 +116,8 @@ class ProcessGroup { ...@@ -113,7 +116,8 @@ class ProcessGroup {
} }
virtual std::shared_ptr<ProcessGroup::Task> Send_Partial(phi::DenseTensor&, virtual std::shared_ptr<ProcessGroup::Task> Send_Partial(phi::DenseTensor&,
int, int, int,
int,
int) { // NOLINT int) { // NOLINT
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support send", GetBackendName())); "ProcessGroup%s does not support send", GetBackendName()));
......
...@@ -165,8 +165,11 @@ ProcessGroupGloo::GlooTask::GlooTask( ...@@ -165,8 +165,11 @@ ProcessGroupGloo::GlooTask::GlooTask(
: ProcessGroup::Task(rank, inputs, comm_type) {} : ProcessGroup::Task(rank, inputs, comm_type) {}
ProcessGroupGloo::ProcessGroupGloo( ProcessGroupGloo::ProcessGroupGloo(
const std::shared_ptr<distributed::Store>& store, int rank, int world_size, const std::shared_ptr<distributed::Store>& store,
const platform::Place& place, int gid, int rank,
int world_size,
const platform::Place& place,
int gid,
const std::shared_ptr<GlooOptions> options) const std::shared_ptr<GlooOptions> options)
: ProcessGroup(rank, world_size, place, gid), : ProcessGroup(rank, world_size, place, gid),
_tag(0), _tag(0),
...@@ -182,7 +185,9 @@ class BroadcastGlooTask : public ProcessGroupGloo::GlooTask { ...@@ -182,7 +185,9 @@ class BroadcastGlooTask : public ProcessGroupGloo::GlooTask {
BroadcastGlooTask(const std::shared_ptr<gloo::Context>& context, BroadcastGlooTask(const std::shared_ptr<gloo::Context>& context,
std::vector<phi::DenseTensor>& inputs, // NOLINT std::vector<phi::DenseTensor>& inputs, // NOLINT
std::vector<phi::DenseTensor>& outputs, // NOLINT std::vector<phi::DenseTensor>& outputs, // NOLINT
int rank, int root, uint32_t tag) int rank,
int root,
uint32_t tag)
: ProcessGroupGloo::GlooTask(rank, inputs, CommType::BROADCAST), : ProcessGroupGloo::GlooTask(rank, inputs, CommType::BROADCAST),
_context(context), _context(context),
_root(root), _root(root),
...@@ -214,23 +219,26 @@ class BroadcastGlooTask : public ProcessGroupGloo::GlooTask { ...@@ -214,23 +219,26 @@ class BroadcastGlooTask : public ProcessGroupGloo::GlooTask {
std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Broadcast( std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Broadcast(
std::vector<phi::DenseTensor>& inputs, std::vector<phi::DenseTensor>& inputs,
std::vector<phi::DenseTensor>& outputs, const BroadcastOptions& opts) { std::vector<phi::DenseTensor>& outputs,
const BroadcastOptions& opts) {
auto root = opts.source_rank; auto root = opts.source_rank;
std::unique_ptr<BroadcastGlooTask> task; std::unique_ptr<BroadcastGlooTask> task;
auto tag = next_tag(); auto tag = next_tag();
auto context = get_context(); auto context = get_context();
task = std::make_unique<BroadcastGlooTask>(context, inputs, outputs, rank_, task = std::make_unique<BroadcastGlooTask>(
root, tag); context, inputs, outputs, rank_, root, tag);
task->Run(); task->Run();
return task; return task;
} }
class AllreduceGlooTask : public ProcessGroupGloo::GlooTask { class AllreduceGlooTask : public ProcessGroupGloo::GlooTask {
public: public:
AllreduceGlooTask(int rank, const std::shared_ptr<gloo::Context>& context, AllreduceGlooTask(int rank,
const std::shared_ptr<gloo::Context>& context,
std::vector<phi::DenseTensor>& inputs, // NOLINT std::vector<phi::DenseTensor>& inputs, // NOLINT
std::vector<phi::DenseTensor>& outputs, // NOLINT std::vector<phi::DenseTensor>& outputs, // NOLINT
ReduceOp reduce_op, uint32_t tag) ReduceOp reduce_op,
uint32_t tag)
: ProcessGroupGloo::GlooTask(rank, inputs, CommType::ALLREDUCE), : ProcessGroupGloo::GlooTask(rank, inputs, CommType::ALLREDUCE),
_context(context), _context(context),
_inputs(inputs), _inputs(inputs),
...@@ -274,12 +282,13 @@ class AllreduceGlooTask : public ProcessGroupGloo::GlooTask { ...@@ -274,12 +282,13 @@ class AllreduceGlooTask : public ProcessGroupGloo::GlooTask {
std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::AllReduce( std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::AllReduce(
std::vector<phi::DenseTensor>& inputs, std::vector<phi::DenseTensor>& inputs,
std::vector<phi::DenseTensor>& outputs, const AllreduceOptions& opts) { std::vector<phi::DenseTensor>& outputs,
const AllreduceOptions& opts) {
auto tag = next_tag(); auto tag = next_tag();
std::shared_ptr<GlooTask> task; std::shared_ptr<GlooTask> task;
auto context = get_context(); auto context = get_context();
task = std::make_shared<AllreduceGlooTask>(rank_, context, inputs, outputs, task = std::make_shared<AllreduceGlooTask>(
opts.reduce_op, tag); rank_, context, inputs, outputs, opts.reduce_op, tag);
task->Run(); task->Run();
return task; return task;
} }
...@@ -287,8 +296,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::AllReduce( ...@@ -287,8 +296,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::AllReduce(
class BarrierGlooTask : public ProcessGroupGloo::GlooTask { class BarrierGlooTask : public ProcessGroupGloo::GlooTask {
public: public:
BarrierGlooTask(int rank, const std::shared_ptr<gloo::Context>& context) BarrierGlooTask(int rank, const std::shared_ptr<gloo::Context>& context)
: ProcessGroupGloo::GlooTask(rank, std::vector<phi::DenseTensor>{}, : ProcessGroupGloo::GlooTask(
CommType::BARRIER), rank, std::vector<phi::DenseTensor>{}, CommType::BARRIER),
_context(context) {} _context(context) {}
void Run() override { _do_barrier(); } void Run() override { _do_barrier(); }
...@@ -313,7 +322,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Barrier( ...@@ -313,7 +322,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Barrier(
class AllgatherGlooTask : public ProcessGroupGloo::GlooTask { class AllgatherGlooTask : public ProcessGroupGloo::GlooTask {
public: public:
AllgatherGlooTask(int rank, const std::shared_ptr<gloo::Context>& context, AllgatherGlooTask(int rank,
const std::shared_ptr<gloo::Context>& context,
std::vector<phi::DenseTensor>& inputs, // NOLINT std::vector<phi::DenseTensor>& inputs, // NOLINT
std::vector<phi::DenseTensor>& outputs, // NOLINT std::vector<phi::DenseTensor>& outputs, // NOLINT
uint32_t tag) uint32_t tag)
...@@ -348,18 +358,21 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::AllGather( ...@@ -348,18 +358,21 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::AllGather(
std::shared_ptr<AllgatherGlooTask> task; std::shared_ptr<AllgatherGlooTask> task;
auto tag = next_tag(); auto tag = next_tag();
auto context = get_context(); auto context = get_context();
task = std::make_shared<AllgatherGlooTask>(rank_, context, in_tensors, task = std::make_shared<AllgatherGlooTask>(
out_tensors, tag); rank_, context, in_tensors, out_tensors, tag);
task->Run(); task->Run();
return task; return task;
} }
class ReduceGlooTask : public ProcessGroupGloo::GlooTask { class ReduceGlooTask : public ProcessGroupGloo::GlooTask {
public: public:
ReduceGlooTask(int rank, const std::shared_ptr<gloo::Context>& context, ReduceGlooTask(int rank,
const std::shared_ptr<gloo::Context>& context,
std::vector<phi::DenseTensor>& inputs, // NOLINT std::vector<phi::DenseTensor>& inputs, // NOLINT
std::vector<phi::DenseTensor>& outputs, // NOLINT std::vector<phi::DenseTensor>& outputs, // NOLINT
ReduceOp reduce_op, int dst, uint32_t tag) ReduceOp reduce_op,
int dst,
uint32_t tag)
: ProcessGroupGloo::GlooTask(rank, inputs, CommType::REDUCE), : ProcessGroupGloo::GlooTask(rank, inputs, CommType::REDUCE),
_context(context), _context(context),
_inputs(inputs), _inputs(inputs),
...@@ -407,22 +420,26 @@ class ReduceGlooTask : public ProcessGroupGloo::GlooTask { ...@@ -407,22 +420,26 @@ class ReduceGlooTask : public ProcessGroupGloo::GlooTask {
std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Reduce( std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Reduce(
std::vector<phi::DenseTensor>& inputs, std::vector<phi::DenseTensor>& inputs,
std::vector<phi::DenseTensor>& outputs, const ReduceOptions& opts) { std::vector<phi::DenseTensor>& outputs,
const ReduceOptions& opts) {
std::shared_ptr<ReduceGlooTask> task; std::shared_ptr<ReduceGlooTask> task;
auto tag = next_tag(); auto tag = next_tag();
auto context = get_context(); auto context = get_context();
task = std::make_shared<ReduceGlooTask>(rank_, context, inputs, outputs, task = std::make_shared<ReduceGlooTask>(
opts.reduce_op, opts.root_rank, tag); rank_, context, inputs, outputs, opts.reduce_op, opts.root_rank, tag);
task->Run(); task->Run();
return task; return task;
} }
class ScatterGlooTask : public ProcessGroupGloo::GlooTask { class ScatterGlooTask : public ProcessGroupGloo::GlooTask {
public: public:
ScatterGlooTask(int rank, const std::shared_ptr<gloo::Context>& context, ScatterGlooTask(int rank,
const std::shared_ptr<gloo::Context>& context,
std::vector<phi::DenseTensor>& inputs, // NOLINT std::vector<phi::DenseTensor>& inputs, // NOLINT
std::vector<phi::DenseTensor>& outputs, // NOLINT std::vector<phi::DenseTensor>& outputs, // NOLINT
int src, int size, uint32_t tag) int src,
int size,
uint32_t tag)
: ProcessGroupGloo::GlooTask(rank, inputs, CommType::SCATTER), : ProcessGroupGloo::GlooTask(rank, inputs, CommType::SCATTER),
_context(context), _context(context),
_inputs(inputs), _inputs(inputs),
...@@ -458,7 +475,8 @@ class ScatterGlooTask : public ProcessGroupGloo::GlooTask { ...@@ -458,7 +475,8 @@ class ScatterGlooTask : public ProcessGroupGloo::GlooTask {
std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Scatter( std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Scatter(
std::vector<phi::DenseTensor>& in_tensors, std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors, const ScatterOptions& opts) { std::vector<phi::DenseTensor>& out_tensors,
const ScatterOptions& opts) {
std::shared_ptr<ScatterGlooTask> task; std::shared_ptr<ScatterGlooTask> task;
auto tag = next_tag(); auto tag = next_tag();
auto context = get_context(); auto context = get_context();
...@@ -487,7 +505,8 @@ ProcessGroupGloo::createDefaultDevice() { ...@@ -487,7 +505,8 @@ ProcessGroupGloo::createDefaultDevice() {
std::array<char, HOST_NAME_MAX> hostname{}; std::array<char, HOST_NAME_MAX> hostname{};
auto ret = ::gethostname(hostname.data(), HOST_NAME_MAX); auto ret = ::gethostname(hostname.data(), HOST_NAME_MAX);
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
ret, 0, ret,
0,
platform::errors::Fatal("Get hostname error for createDefaultDevice.")); platform::errors::Fatal("Get hostname error for createDefaultDevice."));
::addrinfo* result; ::addrinfo* result;
result = tcputils::get_addr_info(hostname.data(), "", 0, AF_UNSPEC); result = tcputils::get_addr_info(hostname.data(), "", 0, AF_UNSPEC);
......
...@@ -101,8 +101,11 @@ class ProcessGroupGloo : public ProcessGroup { ...@@ -101,8 +101,11 @@ class ProcessGroupGloo : public ProcessGroup {
}; };
explicit ProcessGroupGloo( explicit ProcessGroupGloo(
const std::shared_ptr<paddle::distributed::Store>& store, int rank, const std::shared_ptr<paddle::distributed::Store>& store,
int world_size, const platform::Place& place, int gid, int rank,
int world_size,
const platform::Place& place,
int gid,
std::shared_ptr<GlooOptions> options); std::shared_ptr<GlooOptions> options);
~ProcessGroupGloo() = default; ~ProcessGroupGloo() = default;
......
...@@ -45,14 +45,18 @@ void SyncDefaultStream( ...@@ -45,14 +45,18 @@ void SyncDefaultStream(
} }
std::shared_ptr<ProcessGroupHCCL::HCCLTask> ProcessGroupHCCL::CreateTask( std::shared_ptr<ProcessGroupHCCL::HCCLTask> ProcessGroupHCCL::CreateTask(
std::vector<Place> places, int rank, CommType comm_type, std::vector<Place> places,
int rank,
CommType comm_type,
const std::vector<phi::DenseTensor>& inputs) { const std::vector<phi::DenseTensor>& inputs) {
return std::make_shared<ProcessGroupHCCL::HCCLTask>(places, rank, comm_type, return std::make_shared<ProcessGroupHCCL::HCCLTask>(
inputs); places, rank, comm_type, inputs);
} }
ProcessGroupHCCL::HCCLTask::HCCLTask( ProcessGroupHCCL::HCCLTask::HCCLTask(
const std::vector<Place>& places, int rank, CommType CommType, const std::vector<Place>& places,
int rank,
CommType CommType,
const std::vector<phi::DenseTensor>& inputs) const std::vector<phi::DenseTensor>& inputs)
: Task(rank, inputs, CommType), places_(places) { : Task(rank, inputs, CommType), places_(places) {
control_events_.resize(places.size()); control_events_.resize(places.size());
...@@ -99,8 +103,10 @@ bool ProcessGroupHCCL::HCCLTask::Wait(std::chrono::milliseconds timeout) { ...@@ -99,8 +103,10 @@ bool ProcessGroupHCCL::HCCLTask::Wait(std::chrono::milliseconds timeout) {
void ProcessGroupHCCL::HCCLTask::Synchronize() { Wait(kWaitTimeout); } void ProcessGroupHCCL::HCCLTask::Synchronize() { Wait(kWaitTimeout); }
ProcessGroupHCCL::ProcessGroupHCCL(const std::shared_ptr<Store>& store, ProcessGroupHCCL::ProcessGroupHCCL(const std::shared_ptr<Store>& store,
int rank, int size, int rank,
const platform::Place& place, int gid) int size,
const platform::Place& place,
int gid)
: ProcessGroup(rank, size, place, gid), store_(store) { : ProcessGroup(rank, size, place, gid), store_(store) {
platform::SetNPUDeviceId(place_.device); platform::SetNPUDeviceId(place_.device);
} }
...@@ -127,7 +133,8 @@ void ProcessGroupHCCL::BroadcastUniqueHCCLID( ...@@ -127,7 +133,8 @@ void ProcessGroupHCCL::BroadcastUniqueHCCLID(
// create HCCLManager cache for places_key // create HCCLManager cache for places_key
void ProcessGroupHCCL::CreateHCCLManagerCache( void ProcessGroupHCCL::CreateHCCLManagerCache(
const std::string& places_key, const std::vector<Place>& places) { const std::string& places_key, const std::vector<Place>& places) {
PADDLE_ENFORCE_EQ(places_key.empty(), false, PADDLE_ENFORCE_EQ(places_key.empty(),
false,
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"Not able to create/get the HCCL Communicator since " "Not able to create/get the HCCL Communicator since "
"the NPU place are not known")); "the NPU place are not known"));
...@@ -155,8 +162,8 @@ void ProcessGroupHCCL::CreateHCCLManagerCache( ...@@ -155,8 +162,8 @@ void ProcessGroupHCCL::CreateHCCLManagerCache(
std::unique_ptr<HcclComm[]> comms(new HcclComm[places.size()]); std::unique_ptr<HcclComm[]> comms(new HcclComm[places.size()]);
for (size_t i = 0; i < places.size(); ++i) { for (size_t i = 0; i < places.size(); ++i) {
platform::NPUDeviceGuard guard(places[i].GetDeviceId()); platform::NPUDeviceGuard guard(places[i].GetDeviceId());
hccl_comms[i] = HCCLCommManager::Create(GetSize(), GetRank(), &hccl_id, hccl_comms[i] = HCCLCommManager::Create(
comms.get() + i); GetSize(), GetRank(), &hccl_id, comms.get() + i);
dev_ctx[i].reset(new NPUDeviceContext(places[i])); dev_ctx[i].reset(new NPUDeviceContext(places[i]));
} }
...@@ -172,7 +179,9 @@ void ProcessGroupHCCL::CreateHCCLManagerCache( ...@@ -172,7 +179,9 @@ void ProcessGroupHCCL::CreateHCCLManagerCache(
template <typename Fn> template <typename Fn>
std::shared_ptr<ProcessGroup::Task> ProcessGroupHCCL::Collective( std::shared_ptr<ProcessGroup::Task> ProcessGroupHCCL::Collective(
std::vector<phi::DenseTensor>& inputs, std::vector<phi::DenseTensor>& inputs,
std::vector<phi::DenseTensor>& outputs, Fn fn, CommType op_type) { std::vector<phi::DenseTensor>& outputs,
Fn fn,
CommType op_type) {
const auto places = GetPlaceList(inputs); const auto places = GetPlaceList(inputs);
const auto key = GetKeyFromPlaces(places); const auto key = GetKeyFromPlaces(places);
...@@ -218,13 +227,20 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupHCCL::AllReduce( ...@@ -218,13 +227,20 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupHCCL::AllReduce(
std::vector<phi::DenseTensor>& out_tensors, // NOLINT std::vector<phi::DenseTensor>& out_tensors, // NOLINT
const AllreduceOptions& opts) { const AllreduceOptions& opts) {
return Collective( return Collective(
in_tensors, out_tensors, in_tensors,
[&](phi::DenseTensor& input, phi::DenseTensor& output, HcclComm comm, out_tensors,
[&](phi::DenseTensor& input,
phi::DenseTensor& output,
HcclComm comm,
const aclrtStream& stream) { const aclrtStream& stream) {
return platform::dynload::HcclAllReduce( return platform::dynload::HcclAllReduce(
input.data(), output.data(), input.numel(), input.data(),
output.data(),
input.numel(),
platform::ToHCCLDataType(input.dtype()), platform::ToHCCLDataType(input.dtype()),
ToHCCLRedType(opts.reduce_op), comm, stream); ToHCCLRedType(opts.reduce_op),
comm,
stream);
}, },
CommType::ALLREDUCE); CommType::ALLREDUCE);
} }
...@@ -239,18 +255,29 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupHCCL::Broadcast( ...@@ -239,18 +255,29 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupHCCL::Broadcast(
// CudaPlace.")); // CudaPlace."));
return Collective( return Collective(
in_tensors, out_tensors, in_tensors,
[&](phi::DenseTensor& input, phi::DenseTensor& output, HcclComm comm, out_tensors,
[&](phi::DenseTensor& input,
phi::DenseTensor& output,
HcclComm comm,
const aclrtStream& stream) { const aclrtStream& stream) {
int root = opts.source_rank * in_tensors.size() + opts.source_root; int root = opts.source_rank * in_tensors.size() + opts.source_root;
if (rank_ == root) { if (rank_ == root) {
return platform::dynload::HcclBroadcast( return platform::dynload::HcclBroadcast(
input.data(), input.numel(), input.data(),
platform::ToHCCLDataType(input.dtype()), root, comm, stream); input.numel(),
platform::ToHCCLDataType(input.dtype()),
root,
comm,
stream);
} else { } else {
return platform::dynload::HcclBroadcast( return platform::dynload::HcclBroadcast(
output.data(), output.numel(), output.data(),
platform::ToHCCLDataType(output.dtype()), root, comm, stream); output.numel(),
platform::ToHCCLDataType(output.dtype()),
root,
comm,
stream);
} }
}, },
CommType::BROADCAST); CommType::BROADCAST);
......
...@@ -44,7 +44,9 @@ class ProcessGroupHCCL : public ProcessGroup { ...@@ -44,7 +44,9 @@ class ProcessGroupHCCL : public ProcessGroup {
class HCCLTask : public ProcessGroup::Task, class HCCLTask : public ProcessGroup::Task,
public std::enable_shared_from_this<HCCLTask> { public std::enable_shared_from_this<HCCLTask> {
public: public:
HCCLTask(const std::vector<Place>& places, int rank, CommType CommType, HCCLTask(const std::vector<Place>& places,
int rank,
CommType CommType,
const std::vector<phi::DenseTensor>& inputs); const std::vector<phi::DenseTensor>& inputs);
bool IsCompleted(); bool IsCompleted();
...@@ -69,8 +71,11 @@ class ProcessGroupHCCL : public ProcessGroup { ...@@ -69,8 +71,11 @@ class ProcessGroupHCCL : public ProcessGroup {
private: private:
}; };
ProcessGroupHCCL(const std::shared_ptr<Store>& store, int rank, int size, ProcessGroupHCCL(const std::shared_ptr<Store>& store,
const platform::Place& place, int gid); int rank,
int size,
const platform::Place& place,
int gid);
const std::string GetBackendName() const override { const std::string GetBackendName() const override {
return std::string(HCCL_BACKEND_NAME); return std::string(HCCL_BACKEND_NAME);
...@@ -88,7 +93,9 @@ class ProcessGroupHCCL : public ProcessGroup { ...@@ -88,7 +93,9 @@ class ProcessGroupHCCL : public ProcessGroup {
protected: protected:
virtual std::shared_ptr<ProcessGroupHCCL::HCCLTask> CreateTask( virtual std::shared_ptr<ProcessGroupHCCL::HCCLTask> CreateTask(
std::vector<Place> places, int rank, CommType opType, std::vector<Place> places,
int rank,
CommType opType,
const std::vector<phi::DenseTensor>& inputs); const std::vector<phi::DenseTensor>& inputs);
std::shared_ptr<Store> store_; std::shared_ptr<Store> store_;
...@@ -107,7 +114,8 @@ class ProcessGroupHCCL : public ProcessGroup { ...@@ -107,7 +114,8 @@ class ProcessGroupHCCL : public ProcessGroup {
std::set<int> used_place_ids_; std::set<int> used_place_ids_;
private: private:
void BcastHCCLId(std::vector<HcclRootInfo>& hccl_ids, int root, // NOLINT void BcastHCCLId(std::vector<HcclRootInfo>& hccl_ids,
int root, // NOLINT
int server_fd); int server_fd);
void BroadcastUniqueHCCLID(std::vector<HcclRootInfo>& hccl_ids); // NOLINT void BroadcastUniqueHCCLID(std::vector<HcclRootInfo>& hccl_ids); // NOLINT
...@@ -116,7 +124,8 @@ class ProcessGroupHCCL : public ProcessGroup { ...@@ -116,7 +124,8 @@ class ProcessGroupHCCL : public ProcessGroup {
std::shared_ptr<ProcessGroup::Task> Collective( std::shared_ptr<ProcessGroup::Task> Collective(
std::vector<phi::DenseTensor>& inputs, // NOLINT std::vector<phi::DenseTensor>& inputs, // NOLINT
std::vector<phi::DenseTensor>& outputs, // NOLINT std::vector<phi::DenseTensor>& outputs, // NOLINT
Fn fn, CommType op_type); Fn fn,
CommType op_type);
void CreateHCCLManagerCache(const std::string& places_key, void CreateHCCLManagerCache(const std::string& places_key,
const std::vector<Place>& places); const std::vector<Place>& places);
......
...@@ -32,8 +32,8 @@ int ProcessGroupHeter::recv_count = 0; ...@@ -32,8 +32,8 @@ int ProcessGroupHeter::recv_count = 0;
std::shared_ptr<ProcessGroupHeter::HeterTask> ProcessGroupHeter::CreateTask( std::shared_ptr<ProcessGroupHeter::HeterTask> ProcessGroupHeter::CreateTask(
int rank, CommType comm_type, const std::vector<phi::DenseTensor>& inputs) { int rank, CommType comm_type, const std::vector<phi::DenseTensor>& inputs) {
return std::make_shared<ProcessGroupHeter::HeterTask>(rank, comm_type, return std::make_shared<ProcessGroupHeter::HeterTask>(
inputs); rank, comm_type, inputs);
} }
ProcessGroupHeter::HeterTask::HeterTask( ProcessGroupHeter::HeterTask::HeterTask(
...@@ -49,11 +49,19 @@ bool ProcessGroupHeter::HeterTask::Wait(std::chrono::milliseconds timeout) { ...@@ -49,11 +49,19 @@ bool ProcessGroupHeter::HeterTask::Wait(std::chrono::milliseconds timeout) {
return true; return true;
} }
ProcessGroupHeter::ProcessGroupHeter( ProcessGroupHeter::ProcessGroupHeter(const std::shared_ptr<Store>& store,
const std::shared_ptr<Store>& store, int rank, int size, int rank,
const platform::Place& place, int gid, int local_rank, int local_size, int size,
int gloo_rank, int gloo_size, bool with_switch, std::string switch_endpoint, const platform::Place& place,
int src_rank, int dst_rank) int gid,
int local_rank,
int local_size,
int gloo_rank,
int gloo_size,
bool with_switch,
std::string switch_endpoint,
int src_rank,
int dst_rank)
: ProcessGroup(rank, size, place, gid), : ProcessGroup(rank, size, place, gid),
store_(store), store_(store),
local_rank_(local_rank), local_rank_(local_rank),
...@@ -66,11 +74,11 @@ ProcessGroupHeter::ProcessGroupHeter( ...@@ -66,11 +74,11 @@ ProcessGroupHeter::ProcessGroupHeter(
dst_rank_(dst_rank) { dst_rank_(dst_rank) {
return; return;
#if defined(PADDLE_WITH_NCCL) #if defined(PADDLE_WITH_NCCL)
inner_pg_ = std::make_shared<ProcessGroupNCCL>(store, local_rank, local_size, inner_pg_ = std::make_shared<ProcessGroupNCCL>(
place_, IGNORE_ID); store, local_rank, local_size, place_, IGNORE_ID);
#elif defined(PADDLE_WITH_ASCEND_CL) #elif defined(PADDLE_WITH_ASCEND_CL)
inner_pg_ = std::make_shared<ProcessGroupHCCL>(store, local_rank, local_size, inner_pg_ = std::make_shared<ProcessGroupHCCL>(
place_, IGNORE_ID); store, local_rank, local_size, place_, IGNORE_ID);
#else #else
PADDLE_THROW(platform::errors::Fatal( PADDLE_THROW(platform::errors::Fatal(
"ProcessGroupHeter only supports NCCL and HCCL now."); "ProcessGroupHeter only supports NCCL and HCCL now.");
...@@ -94,13 +102,16 @@ static void _do_add(T* dst, T* src, size_t size) { ...@@ -94,13 +102,16 @@ static void _do_add(T* dst, T* src, size_t size) {
std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::AllReduce( std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::AllReduce(
std::vector<phi::DenseTensor>& in_tensors, std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors, const AllreduceOptions& opts) { std::vector<phi::DenseTensor>& out_tensors,
const AllreduceOptions& opts) {
#if defined(PADDLE_WITH_NCCL) #if defined(PADDLE_WITH_NCCL)
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(in_tensors), true, CheckTensorsInCudaPlace(in_tensors),
true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace.")); platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(out_tensors), true, CheckTensorsInCudaPlace(out_tensors),
true,
platform::errors::InvalidArgument("All outputs should be in CudaPlace.")); platform::errors::InvalidArgument("All outputs should be in CudaPlace."));
#endif #endif
...@@ -128,10 +139,14 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::AllReduce( ...@@ -128,10 +139,14 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::AllReduce(
std::vector<int64_t> send_size; std::vector<int64_t> send_size;
send_size.push_back(dense_cpu_tensor.numel()); send_size.push_back(dense_cpu_tensor.numel());
int ret = client_->Send( int ret = client_->Send(
gid_, {dense_cpu_tensor.name()}, send_size, dense_cpu_tensor.data(), gid_,
{dense_cpu_tensor.name()},
send_size,
dense_cpu_tensor.data(),
dense_cpu_tensor.numel() * dense_cpu_tensor.numel() *
framework::DataTypeSize(dense_cpu_tensor.dtype())); framework::DataTypeSize(dense_cpu_tensor.dtype()));
PADDLE_ENFORCE_EQ(ret, 0, PADDLE_ENFORCE_EQ(ret,
0,
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"Send to the switch module error.")); "Send to the switch module error."));
phi::DenseTensor cpu_tensor2; phi::DenseTensor cpu_tensor2;
...@@ -139,11 +154,15 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::AllReduce( ...@@ -139,11 +154,15 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::AllReduce(
std::make_unique<paddle::experimental::DefaultAllocator>( std::make_unique<paddle::experimental::DefaultAllocator>(
paddle::platform::CPUPlace()) paddle::platform::CPUPlace())
.get(), .get(),
dense_cpu_tensor.dtype(), dense_cpu_tensor.numel()); dense_cpu_tensor.dtype(),
dense_cpu_tensor.numel());
ret = client_->Recv( ret = client_->Recv(
gid_, {dense_cpu_tensor.name()}, cpu_tensor2.data(), gid_,
{dense_cpu_tensor.name()},
cpu_tensor2.data(),
cpu_tensor2.numel() * framework::DataTypeSize(cpu_tensor2.dtype())); cpu_tensor2.numel() * framework::DataTypeSize(cpu_tensor2.dtype()));
PADDLE_ENFORCE_EQ(ret, 0, PADDLE_ENFORCE_EQ(ret,
0,
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"Recv from the switch module error.")); "Recv from the switch module error."));
...@@ -192,13 +211,16 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::AllReduce( ...@@ -192,13 +211,16 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::AllReduce(
std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::Broadcast( std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::Broadcast(
std::vector<phi::DenseTensor>& in_tensors, std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors, const BroadcastOptions& opts) { std::vector<phi::DenseTensor>& out_tensors,
const BroadcastOptions& opts) {
#if defined(PADDLE_WITH_NCCL) #if defined(PADDLE_WITH_NCCL)
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(in_tensors), true, CheckTensorsInCudaPlace(in_tensors),
true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace.")); platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(out_tensors), true, CheckTensorsInCudaPlace(out_tensors),
true,
platform::errors::InvalidArgument("All outputs should be in CudaPlace.")); platform::errors::InvalidArgument("All outputs should be in CudaPlace."));
#endif #endif
...@@ -226,19 +248,25 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::Broadcast( ...@@ -226,19 +248,25 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::Broadcast(
std::vector<int64_t> send_size; std::vector<int64_t> send_size;
send_size.push_back(dense_cpu_tensor.numel()); send_size.push_back(dense_cpu_tensor.numel());
int ret = client_->Send( int ret = client_->Send(
gid_, {dense_cpu_tensor.name()}, send_size, gid_,
{dense_cpu_tensor.name()},
send_size,
dense_cpu_tensor.data(), dense_cpu_tensor.data(),
dense_cpu_tensor.numel() * dense_cpu_tensor.numel() *
framework::DataTypeSize(dense_cpu_tensor.dtype())); framework::DataTypeSize(dense_cpu_tensor.dtype()));
PADDLE_ENFORCE_EQ(ret, 0, PADDLE_ENFORCE_EQ(ret,
0,
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"Send to the switch module error.")); "Send to the switch module error."));
} else { } else {
int ret = client_->Recv( int ret = client_->Recv(
gid_, {dense_cpu_tensor.name()}, dense_cpu_tensor.data(), gid_,
{dense_cpu_tensor.name()},
dense_cpu_tensor.data(),
dense_cpu_tensor.numel() * dense_cpu_tensor.numel() *
framework::DataTypeSize(dense_cpu_tensor.dtype())); framework::DataTypeSize(dense_cpu_tensor.dtype()));
PADDLE_ENFORCE_EQ(ret, 0, PADDLE_ENFORCE_EQ(ret,
0,
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"Receive from the switch module error.")); "Receive from the switch module error."));
} }
...@@ -261,7 +289,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::Broadcast( ...@@ -261,7 +289,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::Broadcast(
std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::Send( std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::Send(
std::vector<phi::DenseTensor>& in_tensors, int peer) { std::vector<phi::DenseTensor>& in_tensors, int peer) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
in_tensors.size(), 1, in_tensors.size(),
1,
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"For each send operation, there can only be one tensor to send.")); "For each send operation, there can only be one tensor to send."));
// Copy Tensor to cpu // Copy Tensor to cpu
...@@ -269,7 +298,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::Send( ...@@ -269,7 +298,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::Send(
phi::DenseTensor cpu_tensor; phi::DenseTensor cpu_tensor;
auto& gpu_tensor = in_tensors[0]; auto& gpu_tensor = in_tensors[0];
framework::TensorCopySync(gpu_tensor, platform::CPUPlace(), &cpu_tensor); framework::TensorCopySync(gpu_tensor, platform::CPUPlace(), &cpu_tensor);
PADDLE_ENFORCE_EQ(with_switch_, true, PADDLE_ENFORCE_EQ(with_switch_,
true,
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"Gloo does not support the send operation.")); "Gloo does not support the send operation."));
auto end = std::chrono::high_resolution_clock::now(); auto end = std::chrono::high_resolution_clock::now();
...@@ -289,10 +319,11 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::Send( ...@@ -289,10 +319,11 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::Send(
std::string tensor_name = std::to_string(gid_) + "_id_" + std::to_string(id) + std::string tensor_name = std::to_string(gid_) + "_id_" + std::to_string(id) +
std::string("_") + std::to_string(send_count++); std::string("_") + std::to_string(send_count++);
VLOG(2) << "tensor_name:" << tensor_name; VLOG(2) << "tensor_name:" << tensor_name;
int ret = client_->Send(gid_, {tensor_name}, send_size, cpu_tensor.data(), int ret = client_->Send(
tensor_size); gid_, {tensor_name}, send_size, cpu_tensor.data(), tensor_size);
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
ret, 0, ret,
0,
platform::errors::PreconditionNotMet("Send to the switch module error.")); platform::errors::PreconditionNotMet("Send to the switch module error."));
return CreateTask(rank_, CommType::SEND, in_tensors); return CreateTask(rank_, CommType::SEND, in_tensors);
} }
...@@ -300,7 +331,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::Send( ...@@ -300,7 +331,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::Send(
std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::Recv( std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::Recv(
std::vector<phi::DenseTensor>& out_tensors, int peer) { std::vector<phi::DenseTensor>& out_tensors, int peer) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
out_tensors.size(), 1, out_tensors.size(),
1,
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"For each rece operation, there can only be one tensor to receive.")); "For each rece operation, there can only be one tensor to receive."));
...@@ -311,7 +343,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::Recv( ...@@ -311,7 +343,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::Recv(
cpu_tensor.set_layout(gpu_tensor.layout()); cpu_tensor.set_layout(gpu_tensor.layout());
cpu_tensor.mutable_data(platform::CPUPlace(), gpu_tensor.dtype()); cpu_tensor.mutable_data(platform::CPUPlace(), gpu_tensor.dtype());
PADDLE_ENFORCE_EQ(with_switch_, true, PADDLE_ENFORCE_EQ(with_switch_,
true,
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"Gloo does not support the send operation.")); "Gloo does not support the send operation."));
// recv from switch // recv from switch
...@@ -323,9 +356,12 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::Recv( ...@@ -323,9 +356,12 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::Recv(
VLOG(2) << "tensor_name: " << tensor_name; VLOG(2) << "tensor_name: " << tensor_name;
auto start = std::chrono::high_resolution_clock::now(); auto start = std::chrono::high_resolution_clock::now();
int ret = client_->Recv( int ret = client_->Recv(
gid_, {tensor_name}, cpu_tensor.data(), gid_,
{tensor_name},
cpu_tensor.data(),
cpu_tensor.numel() * framework::DataTypeSize(cpu_tensor.dtype())); cpu_tensor.numel() * framework::DataTypeSize(cpu_tensor.dtype()));
PADDLE_ENFORCE_EQ(ret, 0, PADDLE_ENFORCE_EQ(ret,
0,
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"receive to the switch module error.")); "receive to the switch module error."));
auto end = std::chrono::high_resolution_clock::now(); auto end = std::chrono::high_resolution_clock::now();
......
...@@ -66,7 +66,8 @@ class ProcessGroupHeter : public ProcessGroup { ...@@ -66,7 +66,8 @@ class ProcessGroupHeter : public ProcessGroup {
class HeterTask : public ProcessGroup::Task, class HeterTask : public ProcessGroup::Task,
public std::enable_shared_from_this<HeterTask> { public std::enable_shared_from_this<HeterTask> {
public: public:
HeterTask(int rank, CommType CommType, HeterTask(int rank,
CommType CommType,
const std::vector<phi::DenseTensor>&); const std::vector<phi::DenseTensor>&);
bool IsCompleted(); bool IsCompleted();
...@@ -80,22 +81,32 @@ class ProcessGroupHeter : public ProcessGroup { ...@@ -80,22 +81,32 @@ class ProcessGroupHeter : public ProcessGroup {
virtual ~HeterTask(); virtual ~HeterTask();
}; };
ProcessGroupHeter(const std::shared_ptr<Store>& store, int rank, int size, ProcessGroupHeter(const std::shared_ptr<Store>& store,
const platform::Place& place, int gid, int local_rank, int rank,
int local_size, int gloo_rank, int gloo_size, int size,
bool with_switch, std::string switch_endpoints, const platform::Place& place,
int src_rank, int dst_rank); int gid,
int local_rank,
int local_size,
int gloo_rank,
int gloo_size,
bool with_switch,
std::string switch_endpoints,
int src_rank,
int dst_rank);
const std::string GetBackendName() const override { const std::string GetBackendName() const override {
return std::string(HETER_BACKEND_NAME); return std::string(HETER_BACKEND_NAME);
} }
std::shared_ptr<ProcessGroup::Task> AllReduce( std::shared_ptr<ProcessGroup::Task> AllReduce(
std::vector<phi::DenseTensor>&, std::vector<phi::DenseTensor>&, std::vector<phi::DenseTensor>&,
std::vector<phi::DenseTensor>&,
const AllreduceOptions& = AllreduceOptions()) override; const AllreduceOptions& = AllreduceOptions()) override;
std::shared_ptr<ProcessGroup::Task> Broadcast( std::shared_ptr<ProcessGroup::Task> Broadcast(
std::vector<phi::DenseTensor>&, std::vector<phi::DenseTensor>&, std::vector<phi::DenseTensor>&,
std::vector<phi::DenseTensor>&,
const BroadcastOptions& = BroadcastOptions()) override; const BroadcastOptions& = BroadcastOptions()) override;
std::shared_ptr<ProcessGroup::Task> Send( std::shared_ptr<ProcessGroup::Task> Send(
......
...@@ -42,14 +42,18 @@ void SyncDefaultStream( ...@@ -42,14 +42,18 @@ void SyncDefaultStream(
} }
std::shared_ptr<ProcessGroupNCCL::NCCLTask> ProcessGroupNCCL::CreateTask( std::shared_ptr<ProcessGroupNCCL::NCCLTask> ProcessGroupNCCL::CreateTask(
std::vector<Place> places, int rank, CommType comm_type, std::vector<Place> places,
int rank,
CommType comm_type,
const std::vector<phi::DenseTensor>& inputs) { const std::vector<phi::DenseTensor>& inputs) {
return std::make_shared<ProcessGroupNCCL::NCCLTask>(places, rank, comm_type, return std::make_shared<ProcessGroupNCCL::NCCLTask>(
inputs); places, rank, comm_type, inputs);
} }
ProcessGroupNCCL::NCCLTask::NCCLTask( ProcessGroupNCCL::NCCLTask::NCCLTask(
const std::vector<Place>& places, int rank, CommType CommType, const std::vector<Place>& places,
int rank,
CommType CommType,
const std::vector<phi::DenseTensor>& inputs) const std::vector<phi::DenseTensor>& inputs)
: Task(rank, inputs, CommType), places_(places) { : Task(rank, inputs, CommType), places_(places) {
control_events_.resize(places.size()); control_events_.resize(places.size());
...@@ -109,8 +113,10 @@ bool ProcessGroupNCCL::NCCLTask::Wait(std::chrono::milliseconds timeout) { ...@@ -109,8 +113,10 @@ bool ProcessGroupNCCL::NCCLTask::Wait(std::chrono::milliseconds timeout) {
void ProcessGroupNCCL::NCCLTask::Synchronize() { Wait(kWaitTimeout); } void ProcessGroupNCCL::NCCLTask::Synchronize() { Wait(kWaitTimeout); }
ProcessGroupNCCL::ProcessGroupNCCL(const std::shared_ptr<Store>& store, ProcessGroupNCCL::ProcessGroupNCCL(const std::shared_ptr<Store>& store,
int rank, int size, int rank,
const platform::Place& place, int gid) int size,
const platform::Place& place,
int gid)
: ProcessGroup(rank, size, place, gid), store_(store) { : ProcessGroup(rank, size, place, gid), store_(store) {
platform::SetDeviceId(place_.device); platform::SetDeviceId(place_.device);
} }
...@@ -139,7 +145,8 @@ void ProcessGroupNCCL::BroadcastUniqueNCCLID( ...@@ -139,7 +145,8 @@ void ProcessGroupNCCL::BroadcastUniqueNCCLID(
// create NCCLManager cache for places_key // create NCCLManager cache for places_key
void ProcessGroupNCCL::CreateNCCLManagerCache( void ProcessGroupNCCL::CreateNCCLManagerCache(
const std::string& places_key, const std::vector<Place>& places) { const std::string& places_key, const std::vector<Place>& places) {
PADDLE_ENFORCE_EQ(places_key.empty(), false, PADDLE_ENFORCE_EQ(places_key.empty(),
false,
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"Not able to create/get the NCCL Communicator since " "Not able to create/get the NCCL Communicator since "
"the GPU place are not known")); "the GPU place are not known"));
...@@ -190,7 +197,9 @@ void ProcessGroupNCCL::CreateNCCLManagerCache( ...@@ -190,7 +197,9 @@ void ProcessGroupNCCL::CreateNCCLManagerCache(
template <typename Fn> template <typename Fn>
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective( std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
std::vector<phi::DenseTensor>& inputs, std::vector<phi::DenseTensor>& inputs,
std::vector<phi::DenseTensor>& outputs, Fn fn, CommType op_type) { std::vector<phi::DenseTensor>& outputs,
Fn fn,
CommType op_type) {
const auto places = GetPlaceList(inputs); const auto places = GetPlaceList(inputs);
const auto key = GetKeyFromPlaces(places); const auto key = GetKeyFromPlaces(places);
...@@ -237,7 +246,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective( ...@@ -237,7 +246,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
template <typename Fn> template <typename Fn>
void ProcessGroupNCCL::Collective(const phi::DenseTensor* in, void ProcessGroupNCCL::Collective(const phi::DenseTensor* in,
phi::DenseTensor* out, Fn fn, phi::DenseTensor* out,
Fn fn,
CommType op_type) { CommType op_type) {
std::vector<Place> places; std::vector<Place> places;
places.push_back(in->place()); places.push_back(in->place());
...@@ -274,7 +284,9 @@ void ProcessGroupNCCL::Collective(const phi::DenseTensor* in, ...@@ -274,7 +284,9 @@ void ProcessGroupNCCL::Collective(const phi::DenseTensor* in,
template <typename Fn> template <typename Fn>
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint( std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
std::vector<phi::DenseTensor>& tensors, Fn fn, int dst_rank, std::vector<phi::DenseTensor>& tensors,
Fn fn,
int dst_rank,
CommType op_type) { CommType op_type) {
const auto places = GetPlaceList(tensors); const auto places = GetPlaceList(tensors);
const auto key = GetKeyFromPlaces(places); const auto key = GetKeyFromPlaces(places);
...@@ -321,38 +333,57 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint( ...@@ -321,38 +333,57 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllReduce( std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllReduce(
std::vector<phi::DenseTensor>& in_tensors, std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors, const AllreduceOptions& opts) { std::vector<phi::DenseTensor>& out_tensors,
const AllreduceOptions& opts) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(in_tensors), true, CheckTensorsInCudaPlace(in_tensors),
true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace.")); platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
return Collective( return Collective(
in_tensors, out_tensors, in_tensors,
[&](const phi::DenseTensor& input, phi::DenseTensor& output, out_tensors,
ncclComm_t comm, const gpuStream_t& stream) { [&](const phi::DenseTensor& input,
phi::DenseTensor& output,
ncclComm_t comm,
const gpuStream_t& stream) {
return platform::dynload::ncclAllReduce( return platform::dynload::ncclAllReduce(
input.data(), output.data(), input.numel(), input.data(),
output.data(),
input.numel(),
platform::ToNCCLDataType(input.type()), platform::ToNCCLDataType(input.type()),
ToNCCLRedType(opts.reduce_op), comm, stream); ToNCCLRedType(opts.reduce_op),
comm,
stream);
}, },
CommType::ALLREDUCE); CommType::ALLREDUCE);
} }
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Broadcast( std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Broadcast(
std::vector<phi::DenseTensor>& in_tensors, std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors, const BroadcastOptions& opts) { std::vector<phi::DenseTensor>& out_tensors,
const BroadcastOptions& opts) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(in_tensors), true, CheckTensorsInCudaPlace(in_tensors),
true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace.")); platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
return Collective( return Collective(
in_tensors, out_tensors, in_tensors,
[&](phi::DenseTensor& input, phi::DenseTensor& output, ncclComm_t comm, out_tensors,
[&](phi::DenseTensor& input,
phi::DenseTensor& output,
ncclComm_t comm,
const gpuStream_t& stream) { const gpuStream_t& stream) {
const auto root = const auto root =
opts.source_rank * in_tensors.size() + opts.source_root; opts.source_rank * in_tensors.size() + opts.source_root;
return platform::dynload::ncclBroadcast( return platform::dynload::ncclBroadcast(
input.data(), output.data(), input.numel(), input.data(),
platform::ToNCCLDataType(input.type()), root, comm, stream); output.data(),
input.numel(),
platform::ToNCCLDataType(input.type()),
root,
comm,
stream);
}, },
CommType::BROADCAST); CommType::BROADCAST);
} }
...@@ -381,22 +412,26 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Barrier( ...@@ -381,22 +412,26 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Barrier(
void CheckTensorsInDifferentDevices( void CheckTensorsInDifferentDevices(
const std::vector<phi::DenseTensor>& tensors, const size_t num_devices) { const std::vector<phi::DenseTensor>& tensors, const size_t num_devices) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
tensors.size() == 0, false, tensors.size() == 0,
false,
platform::errors::InvalidArgument("Tensor list must be nonempty.")); platform::errors::InvalidArgument("Tensor list must be nonempty."));
PADDLE_ENFORCE_LE( PADDLE_ENFORCE_LE(
tensors.size(), num_devices, tensors.size(),
num_devices,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Tensor list mustn't be larger than the number of available GPUs.")); "Tensor list mustn't be larger than the number of available GPUs."));
std::set<Place> used_devices; std::set<Place> used_devices;
for (const auto& t : tensors) { for (const auto& t : tensors) {
PADDLE_ENFORCE_EQ(platform::is_gpu_place(t.place()), true, PADDLE_ENFORCE_EQ(platform::is_gpu_place(t.place()),
true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Tensors must be CUDA and dense tensor.")); "Tensors must be CUDA and dense tensor."));
const auto inserted = used_devices.insert(t.place()).second; const auto inserted = used_devices.insert(t.place()).second;
PADDLE_ENFORCE_EQ(inserted, true, PADDLE_ENFORCE_EQ(inserted,
true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Tensors must be on distinct GPU devices.")); "Tensors must be on distinct GPU devices."));
} }
...@@ -408,13 +443,20 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Send( ...@@ -408,13 +443,20 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Send(
auto task = PointToPoint( auto task = PointToPoint(
tensors, tensors,
[&](phi::DenseTensor& input, ncclComm_t comm, const gpuStream_t& stream, [&](phi::DenseTensor& input,
ncclComm_t comm,
const gpuStream_t& stream,
int dst_rank) { int dst_rank) {
return platform::dynload::ncclSend( return platform::dynload::ncclSend(
input.data(), input.numel(), input.data(),
platform::ToNCCLDataType(input.dtype()), dst_rank, comm, stream); input.numel(),
platform::ToNCCLDataType(input.dtype()),
dst_rank,
comm,
stream);
}, },
dst_rank, CommType::SEND); dst_rank,
CommType::SEND);
return task; return task;
} }
...@@ -424,13 +466,20 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv( ...@@ -424,13 +466,20 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv(
auto task = PointToPoint( auto task = PointToPoint(
tensors, tensors,
[&](phi::DenseTensor& output, ncclComm_t comm, const gpuStream_t& stream, [&](phi::DenseTensor& output,
ncclComm_t comm,
const gpuStream_t& stream,
int src_rank) { int src_rank) {
return platform::dynload::ncclRecv( return platform::dynload::ncclRecv(
output.data(), output.numel(), output.data(),
platform::ToNCCLDataType(output.dtype()), src_rank, comm, stream); output.numel(),
platform::ToNCCLDataType(output.dtype()),
src_rank,
comm,
stream);
}, },
src_rank, CommType::RECV); src_rank,
CommType::RECV);
return task; return task;
} }
...@@ -448,13 +497,20 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Send_Partial( ...@@ -448,13 +497,20 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Send_Partial(
auto task = PointToPoint( auto task = PointToPoint(
shared_tensors, shared_tensors,
[&](phi::DenseTensor& input, ncclComm_t comm, const gpuStream_t& stream, [&](phi::DenseTensor& input,
ncclComm_t comm,
const gpuStream_t& stream,
int dst_rank) { int dst_rank) {
return platform::dynload::ncclSend( return platform::dynload::ncclSend(
input.data(), input.numel(), input.data(),
platform::ToNCCLDataType(input.dtype()), dst_rank, comm, stream); input.numel(),
platform::ToNCCLDataType(input.dtype()),
dst_rank,
comm,
stream);
}, },
dst_rank, CommType::SEND); dst_rank,
CommType::SEND);
return task; return task;
} }
...@@ -471,13 +527,20 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv_Partial( ...@@ -471,13 +527,20 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv_Partial(
auto task = PointToPoint( auto task = PointToPoint(
shared_tensors, shared_tensors,
[&](phi::DenseTensor& output, ncclComm_t comm, const gpuStream_t& stream, [&](phi::DenseTensor& output,
ncclComm_t comm,
const gpuStream_t& stream,
int src_rank) { int src_rank) {
return platform::dynload::ncclRecv( return platform::dynload::ncclRecv(
output.data(), output.numel(), output.data(),
platform::ToNCCLDataType(output.dtype()), src_rank, comm, stream); output.numel(),
platform::ToNCCLDataType(output.dtype()),
src_rank,
comm,
stream);
}, },
src_rank, CommType::RECV); src_rank,
CommType::RECV);
return task; return task;
} }
...@@ -485,23 +548,33 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather( ...@@ -485,23 +548,33 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather(
std::vector<phi::DenseTensor>& in_tensors, std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors) { std::vector<phi::DenseTensor>& out_tensors) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(in_tensors), true, CheckTensorsInCudaPlace(in_tensors),
true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace.")); platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(out_tensors), true, CheckTensorsInCudaPlace(out_tensors),
true,
platform::errors::InvalidArgument("All outputs should be in CudaPlace.")); platform::errors::InvalidArgument("All outputs should be in CudaPlace."));
return Collective( return Collective(
in_tensors, out_tensors, in_tensors,
[&](const phi::DenseTensor& input, phi::DenseTensor& output, out_tensors,
ncclComm_t comm, const gpuStream_t& stream) { [&](const phi::DenseTensor& input,
phi::DenseTensor& output,
ncclComm_t comm,
const gpuStream_t& stream) {
return platform::dynload::ncclAllGather( return platform::dynload::ncclAllGather(
input.data(), output.data(), input.numel(), input.data(),
platform::ToNCCLDataType(input.dtype()), comm, stream); output.data(),
input.numel(),
platform::ToNCCLDataType(input.dtype()),
comm,
stream);
}, },
CommType::ALLGATHER); CommType::ALLGATHER);
} }
void* GetPointerByOffset(void* raw_pointer, size_t offset, void* GetPointerByOffset(void* raw_pointer,
size_t offset,
experimental::DataType type) { experimental::DataType type) {
if (type == experimental::DataType::FLOAT32) { if (type == experimental::DataType::FLOAT32) {
return reinterpret_cast<void*>(reinterpret_cast<float*>(raw_pointer) + return reinterpret_cast<void*>(reinterpret_cast<float*>(raw_pointer) +
...@@ -529,26 +602,37 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll( ...@@ -529,26 +602,37 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
std::vector<phi::DenseTensor>& in_tensors, std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors) { std::vector<phi::DenseTensor>& out_tensors) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(in_tensors), true, CheckTensorsInCudaPlace(in_tensors),
true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace.")); platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(out_tensors), true, CheckTensorsInCudaPlace(out_tensors),
true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace.")); platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
return Collective( return Collective(
in_tensors, out_tensors, in_tensors,
[&](phi::DenseTensor& input, phi::DenseTensor& output, ncclComm_t comm, out_tensors,
[&](phi::DenseTensor& input,
phi::DenseTensor& output,
ncclComm_t comm,
const gpuStream_t& stream) { const gpuStream_t& stream) {
size_t offset = 0; size_t offset = 0;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart()); PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart());
for (auto i = 0; i < size_; i++) { for (auto i = 0; i < size_; i++) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend( PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend(
GetPointerByOffset(input.data(), offset, input.dtype()), GetPointerByOffset(input.data(), offset, input.dtype()),
input.numel() / size_, platform::ToNCCLDataType(input.dtype()), i, input.numel() / size_,
comm, stream)); platform::ToNCCLDataType(input.dtype()),
i,
comm,
stream));
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv( PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
GetPointerByOffset(output.data(), offset, input.dtype()), GetPointerByOffset(output.data(), offset, input.dtype()),
input.numel() / size_, platform::ToNCCLDataType(input.dtype()), i, input.numel() / size_,
comm, stream)); platform::ToNCCLDataType(input.dtype()),
i,
comm,
stream));
offset += input.numel() / size_; offset += input.numel() / size_;
} }
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd()); PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
...@@ -558,34 +642,50 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll( ...@@ -558,34 +642,50 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Reduce( std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Reduce(
std::vector<phi::DenseTensor>& in_tensors, std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors, const ReduceOptions& opts) { std::vector<phi::DenseTensor>& out_tensors,
const ReduceOptions& opts) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(in_tensors), true, CheckTensorsInCudaPlace(in_tensors),
true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace.")); platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
return Collective( return Collective(
in_tensors, out_tensors, in_tensors,
[&](const phi::DenseTensor& input, phi::DenseTensor& output, out_tensors,
ncclComm_t comm, const gpuStream_t& stream) { [&](const phi::DenseTensor& input,
phi::DenseTensor& output,
ncclComm_t comm,
const gpuStream_t& stream) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclReduce( PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclReduce(
input.data(), output.data(), input.numel(), input.data(),
output.data(),
input.numel(),
platform::ToNCCLDataType(input.dtype()), platform::ToNCCLDataType(input.dtype()),
ToNCCLRedType(opts.reduce_op), opts.root_rank, comm, stream)); ToNCCLRedType(opts.reduce_op),
opts.root_rank,
comm,
stream));
}, },
CommType::REDUCE); CommType::REDUCE);
} }
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter( std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
std::vector<phi::DenseTensor>& in_tensors, std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors, const ScatterOptions& opts) { std::vector<phi::DenseTensor>& out_tensors,
const ScatterOptions& opts) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(in_tensors), true, CheckTensorsInCudaPlace(in_tensors),
true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace.")); platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(out_tensors), true, CheckTensorsInCudaPlace(out_tensors),
true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace.")); platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
return Collective( return Collective(
in_tensors, out_tensors, in_tensors,
[&](phi::DenseTensor& input, phi::DenseTensor& output, ncclComm_t comm, out_tensors,
[&](phi::DenseTensor& input,
phi::DenseTensor& output,
ncclComm_t comm,
const gpuStream_t& stream) { const gpuStream_t& stream) {
size_t offset = 0; size_t offset = 0;
if (rank_ == opts.root_rank) { if (rank_ == opts.root_rank) {
...@@ -593,19 +693,28 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter( ...@@ -593,19 +693,28 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
for (auto i = 0; i < size_; i++) { for (auto i = 0; i < size_; i++) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend( PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend(
GetPointerByOffset(input.data(), offset, input.dtype()), GetPointerByOffset(input.data(), offset, input.dtype()),
input.numel() / size_, platform::ToNCCLDataType(input.dtype()), input.numel() / size_,
i, comm, stream)); platform::ToNCCLDataType(input.dtype()),
i,
comm,
stream));
offset += input.numel() / size_; offset += input.numel() / size_;
} }
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv( PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
output.data(), input.numel() / size_, output.data(),
platform::ToNCCLDataType(input.dtype()), opts.root_rank, comm, input.numel() / size_,
platform::ToNCCLDataType(input.dtype()),
opts.root_rank,
comm,
stream)); stream));
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd()); PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
} else { } else {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv( PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
output.data(), input.numel() / size_, output.data(),
platform::ToNCCLDataType(input.dtype()), opts.root_rank, comm, input.numel() / size_,
platform::ToNCCLDataType(input.dtype()),
opts.root_rank,
comm,
stream)); stream));
} }
}, },
......
...@@ -54,7 +54,9 @@ class ProcessGroupNCCL : public ProcessGroup { ...@@ -54,7 +54,9 @@ class ProcessGroupNCCL : public ProcessGroup {
class NCCLTask : public ProcessGroup::Task, class NCCLTask : public ProcessGroup::Task,
public std::enable_shared_from_this<NCCLTask> { public std::enable_shared_from_this<NCCLTask> {
public: public:
NCCLTask(const std::vector<Place>& places, int rank, CommType CommType, NCCLTask(const std::vector<Place>& places,
int rank,
CommType CommType,
const std::vector<phi::DenseTensor>& inputs); const std::vector<phi::DenseTensor>& inputs);
bool IsCompleted(); bool IsCompleted();
...@@ -80,8 +82,11 @@ class ProcessGroupNCCL : public ProcessGroup { ...@@ -80,8 +82,11 @@ class ProcessGroupNCCL : public ProcessGroup {
private: private:
}; };
ProcessGroupNCCL(const std::shared_ptr<Store>& store, int rank, int size, ProcessGroupNCCL(const std::shared_ptr<Store>& store,
const platform::Place& place, int gid); int rank,
int size,
const platform::Place& place,
int gid);
const std::string GetBackendName() const override { const std::string GetBackendName() const override {
return std::string(NCCL_BACKEND_NAME); return std::string(NCCL_BACKEND_NAME);
...@@ -107,11 +112,13 @@ class ProcessGroupNCCL : public ProcessGroup { ...@@ -107,11 +112,13 @@ class ProcessGroupNCCL : public ProcessGroup {
std::vector<phi::DenseTensor>& tensors, int src_rank) override; std::vector<phi::DenseTensor>& tensors, int src_rank) override;
std::shared_ptr<ProcessGroup::Task> Send_Partial(phi::DenseTensor& tensors, std::shared_ptr<ProcessGroup::Task> Send_Partial(phi::DenseTensor& tensors,
int dst_rank, int offset, int dst_rank,
int offset,
int length) override; int length) override;
std::shared_ptr<ProcessGroup::Task> Recv_Partial(phi::DenseTensor& tensors, std::shared_ptr<ProcessGroup::Task> Recv_Partial(phi::DenseTensor& tensors,
int src_rank, int offset, int src_rank,
int offset,
int length) override; int length) override;
std::shared_ptr<ProcessGroup::Task> AllGather( std::shared_ptr<ProcessGroup::Task> AllGather(
...@@ -134,7 +141,9 @@ class ProcessGroupNCCL : public ProcessGroup { ...@@ -134,7 +141,9 @@ class ProcessGroupNCCL : public ProcessGroup {
protected: protected:
virtual std::shared_ptr<ProcessGroupNCCL::NCCLTask> CreateTask( virtual std::shared_ptr<ProcessGroupNCCL::NCCLTask> CreateTask(
std::vector<Place> places, int rank, CommType opType, std::vector<Place> places,
int rank,
CommType opType,
const std::vector<phi::DenseTensor>& inputs); const std::vector<phi::DenseTensor>& inputs);
protected: protected:
...@@ -153,7 +162,8 @@ class ProcessGroupNCCL : public ProcessGroup { ...@@ -153,7 +162,8 @@ class ProcessGroupNCCL : public ProcessGroup {
std::set<int> used_place_ids_; std::set<int> used_place_ids_;
private: private:
void BcastNCCLId(std::vector<ncclUniqueId>& nccl_ids, int root, // NOLINT void BcastNCCLId(std::vector<ncclUniqueId>& nccl_ids,
int root, // NOLINT
int server_fd); int server_fd);
void BroadcastUniqueNCCLID(std::vector<ncclUniqueId>& nccl_ids); // NOLINT void BroadcastUniqueNCCLID(std::vector<ncclUniqueId>& nccl_ids); // NOLINT
...@@ -162,16 +172,21 @@ class ProcessGroupNCCL : public ProcessGroup { ...@@ -162,16 +172,21 @@ class ProcessGroupNCCL : public ProcessGroup {
std::shared_ptr<ProcessGroup::Task> Collective( std::shared_ptr<ProcessGroup::Task> Collective(
std::vector<phi::DenseTensor>& inputs, // NOLINT std::vector<phi::DenseTensor>& inputs, // NOLINT
std::vector<phi::DenseTensor>& outputs, // NOLINT std::vector<phi::DenseTensor>& outputs, // NOLINT
Fn fn, CommType op_type); Fn fn,
CommType op_type);
template <typename Fn> template <typename Fn>
void Collective(const phi::DenseTensor*, phi::DenseTensor*, Fn fn, void Collective(const phi::DenseTensor*,
phi::DenseTensor*,
Fn fn,
CommType op_type); CommType op_type);
template <typename Fn> template <typename Fn>
std::shared_ptr<ProcessGroup::Task> PointToPoint( std::shared_ptr<ProcessGroup::Task> PointToPoint(
std::vector<phi::DenseTensor>& tensors, // NOLINT std::vector<phi::DenseTensor>& tensors, // NOLINT
Fn fn, int dst_rank, CommType op_type); Fn fn,
int dst_rank,
CommType op_type);
void CreateNCCLManagerCache(const std::string& places_key, void CreateNCCLManagerCache(const std::string& places_key,
const std::vector<Place>& places); const std::vector<Place>& places);
......
...@@ -40,7 +40,8 @@ using IntArray = ...@@ -40,7 +40,8 @@ using IntArray =
using Backend = paddle::experimental::Backend; using Backend = paddle::experimental::Backend;
std::vector<std::vector<size_t>> Eager_AssignGroupBySize( std::vector<std::vector<size_t>> Eager_AssignGroupBySize(
const std::vector<Tensor>, const std::vector<bool> &is_sparse_gradient, const std::vector<Tensor>,
const std::vector<bool> &is_sparse_gradient,
const std::vector<size_t> &group_size_limits, const std::vector<size_t> &group_size_limits,
const std::vector<int64_t> &tensor_indices = {}); const std::vector<int64_t> &tensor_indices = {});
......
...@@ -21,20 +21,27 @@ namespace distributed { ...@@ -21,20 +21,27 @@ namespace distributed {
// AfsClient impl // AfsClient impl
int AfsClient::initialize(const FsClientParameter& fs_client_param) { int AfsClient::initialize(const FsClientParameter& fs_client_param) {
// temporarily implemented with hdfs-client // temporarily implemented with hdfs-client
return initialize(fs_client_param.hadoop_bin(), fs_client_param.uri(), return initialize(fs_client_param.hadoop_bin(),
fs_client_param.user(), fs_client_param.passwd(), fs_client_param.uri(),
fs_client_param.user(),
fs_client_param.passwd(),
fs_client_param.buffer_size()); fs_client_param.buffer_size());
} }
int AfsClient::initialize(const std::string& hadoop_bin, const std::string& uri, int AfsClient::initialize(const std::string& hadoop_bin,
const std::string& user, const std::string& passwd, const std::string& uri,
const std::string& user,
const std::string& passwd,
int buffer_size_param) { int buffer_size_param) {
return initialize( return initialize(
hadoop_bin, uri, hadoop_bin,
uri,
paddle::string::format_string("%s,%s", user.c_str(), passwd.c_str()), paddle::string::format_string("%s,%s", user.c_str(), passwd.c_str()),
buffer_size_param); buffer_size_param);
} }
int AfsClient::initialize(const std::string& hadoop_bin, const std::string& uri, int AfsClient::initialize(const std::string& hadoop_bin,
const std::string& ugi, int buffer_size_param) { const std::string& uri,
const std::string& ugi,
int buffer_size_param) {
// temporarily implemented with hdfs-client // temporarily implemented with hdfs-client
size_t buffer_size = 1L << 25; // 32MB size_t buffer_size = 1L << 25; // 32MB
if (buffer_size_param > static_cast<int>(buffer_size)) { if (buffer_size_param > static_cast<int>(buffer_size)) {
...@@ -44,7 +51,9 @@ int AfsClient::initialize(const std::string& hadoop_bin, const std::string& uri, ...@@ -44,7 +51,9 @@ int AfsClient::initialize(const std::string& hadoop_bin, const std::string& uri,
paddle::framework::hdfs_set_command(paddle::string::format_string( paddle::framework::hdfs_set_command(paddle::string::format_string(
"2>>./hdfs_err.log %s fs -Dfs.default.name=%s -Dhadoop.job.ugi=%s " "2>>./hdfs_err.log %s fs -Dfs.default.name=%s -Dhadoop.job.ugi=%s "
"-Ddfs.client.block.write.retries=15 -Ddfs.rpc.timeout=300000", "-Ddfs.client.block.write.retries=15 -Ddfs.rpc.timeout=300000",
hadoop_bin.c_str(), uri.c_str(), ugi.c_str())); hadoop_bin.c_str(),
uri.c_str(),
ugi.c_str()));
return 0; return 0;
} }
......
...@@ -129,11 +129,15 @@ class AfsClient { ...@@ -129,11 +129,15 @@ class AfsClient {
AfsClient(const AfsClient&) = delete; AfsClient(const AfsClient&) = delete;
int initialize(const FsClientParameter& fs_client_param); int initialize(const FsClientParameter& fs_client_param);
int initialize(const std::string& hadoop_bin, const std::string& uri, int initialize(const std::string& hadoop_bin,
const std::string& user, const std::string& passwd, const std::string& uri,
const std::string& user,
const std::string& passwd,
int buffer_size_param = (1L << 25));
int initialize(const std::string& hadoop_bin,
const std::string& uri,
const std::string& ugi,
int buffer_size_param = (1L << 25)); int buffer_size_param = (1L << 25));
int initialize(const std::string& hadoop_bin, const std::string& uri,
const std::string& ugi, int buffer_size_param = (1L << 25));
// open file in 'w' or 'r' // open file in 'w' or 'r'
std::shared_ptr<FsReadChannel> open_r(const FsChannelConfig& config, std::shared_ptr<FsReadChannel> open_r(const FsChannelConfig& config,
......
...@@ -25,8 +25,10 @@ class TopkCalculator { ...@@ -25,8 +25,10 @@ class TopkCalculator {
_shard_max_size = _total_max_size / shard_num; _shard_max_size = _total_max_size / shard_num;
_shard_max_size = _shard_max_size > 1 ? _shard_max_size : 1; _shard_max_size = _shard_max_size > 1 ? _shard_max_size : 1;
for (int i = 0; i < shard_num; ++i) { for (int i = 0; i < shard_num; ++i) {
_mpq.emplace(i, std::priority_queue<double, std::vector<double>, _mpq.emplace(i,
std::greater<double>>()); std::priority_queue<double,
std::vector<double>,
std::greater<double>>());
} }
} }
~TopkCalculator() {} ~TopkCalculator() {}
...@@ -58,8 +60,9 @@ class TopkCalculator { ...@@ -58,8 +60,9 @@ class TopkCalculator {
} }
private: private:
std::unordered_map<int, std::priority_queue<double, std::vector<double>, std::unordered_map<
std::greater<double>>> int,
std::priority_queue<double, std::vector<double>, std::greater<double>>>
_mpq; _mpq;
int _shard_num; int _shard_num;
size_t _total_max_size; size_t _total_max_size;
......
...@@ -50,8 +50,10 @@ void Carrier::Init( ...@@ -50,8 +50,10 @@ void Carrier::Init(
int64_t rank, int64_t rank,
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank, const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank,
const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node, const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node,
const framework::ProgramDesc& program, framework::Scope* scope, const framework::ProgramDesc& program,
int64_t num_micro_batches, const platform::Place& place, framework::Scope* scope,
int64_t num_micro_batches,
const platform::Place& place,
const std::vector<std::string>& inference_root_scope_vars) { const std::vector<std::string>& inference_root_scope_vars) {
rank_ = rank; rank_ = rank;
interceptor_id_to_rank_ = interceptor_id_to_rank; interceptor_id_to_rank_ = interceptor_id_to_rank;
...@@ -60,8 +62,9 @@ void Carrier::Init( ...@@ -60,8 +62,9 @@ void Carrier::Init(
root_scope_ = scope; root_scope_ = scope;
dev_ctx_ = platform::DeviceContextPool::Instance().Get(place_); dev_ctx_ = platform::DeviceContextPool::Instance().Get(place_);
PADDLE_ENFORCE_NOT_NULL(root_scope_, platform::errors::InvalidArgument( PADDLE_ENFORCE_NOT_NULL(
"root_scope can not be nullptr")); root_scope_,
platform::errors::InvalidArgument("root_scope can not be nullptr"));
minibatch_scope_ = &root_scope_->NewScope(); minibatch_scope_ = &root_scope_->NewScope();
microbatch_scopes_.resize(num_micro_batches); microbatch_scopes_.resize(num_micro_batches);
for (int i = 0; i < num_micro_batches; ++i) { for (int i = 0; i < num_micro_batches; ++i) {
...@@ -87,7 +90,8 @@ void Carrier::Release() { ...@@ -87,7 +90,8 @@ void Carrier::Release() {
Carrier::~Carrier() { VLOG(3) << "Carrier's destructor."; } Carrier::~Carrier() { VLOG(3) << "Carrier's destructor."; }
void Carrier::CopyParameters( void Carrier::CopyParameters(
int microbatch_id, const framework::ProgramDesc& program, int microbatch_id,
const framework::ProgramDesc& program,
const std::vector<std::string>& inference_root_scope_vars) { const std::vector<std::string>& inference_root_scope_vars) {
auto& global_block = program.Block(0); auto& global_block = program.Block(0);
...@@ -119,7 +123,8 @@ void Carrier::CopyParameters( ...@@ -119,7 +123,8 @@ void Carrier::CopyParameters(
bool Carrier::EnqueueInterceptorMessage( bool Carrier::EnqueueInterceptorMessage(
const InterceptorMessage& interceptor_message) { const InterceptorMessage& interceptor_message) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
interceptor_message.ctrl_message(), false, interceptor_message.ctrl_message(),
false,
platform::errors::Fatal( platform::errors::Fatal(
"Control message should be only send inter rank using message bus.")); "Control message should be only send inter rank using message bus."));
int64_t dst_id = interceptor_message.dst_id(); int64_t dst_id = interceptor_message.dst_id();
...@@ -130,7 +135,8 @@ bool Carrier::EnqueueInterceptorMessage( ...@@ -130,7 +135,8 @@ bool Carrier::EnqueueInterceptorMessage(
Interceptor* Carrier::GetInterceptor(int64_t interceptor_id) { Interceptor* Carrier::GetInterceptor(int64_t interceptor_id) {
auto iter = interceptor_idx_to_interceptor_.find(interceptor_id); auto iter = interceptor_idx_to_interceptor_.find(interceptor_id);
PADDLE_ENFORCE_NE(iter, interceptor_idx_to_interceptor_.end(), PADDLE_ENFORCE_NE(iter,
interceptor_idx_to_interceptor_.end(),
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Cannot find interceptor instance for interceptor " "Cannot find interceptor instance for interceptor "
"id %lld. Wrong dst? Call before init?", "id %lld. Wrong dst? Call before init?",
...@@ -149,7 +155,8 @@ void Carrier::WakeUp() { ...@@ -149,7 +155,8 @@ void Carrier::WakeUp() {
} }
void Carrier::Start() { void Carrier::Start() {
PADDLE_ENFORCE_EQ(is_init_, true, PADDLE_ENFORCE_EQ(is_init_,
true,
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"Using carrier before initialized.")); "Using carrier before initialized."));
for (int64_t id : source_interceptor_ids_) { for (int64_t id : source_interceptor_ids_) {
...@@ -199,10 +206,12 @@ bool Carrier::Send(const InterceptorMessage& msg) { ...@@ -199,10 +206,12 @@ bool Carrier::Send(const InterceptorMessage& msg) {
int64_t src_rank = GetRank(src_id); int64_t src_rank = GetRank(src_id);
int64_t dst_rank = GetRank(dst_id); int64_t dst_rank = GetRank(dst_id);
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
src_rank, rank_, src_rank,
rank_,
platform::errors::Fatal("The source rank id %lld, which is not equal to " platform::errors::Fatal("The source rank id %lld, which is not equal to "
"the carrier rank id %lld.", "the carrier rank id %lld.",
src_rank, rank_)); src_rank,
rank_));
if (src_rank == dst_rank) { if (src_rank == dst_rank) {
VLOG(3) << "Send a message from interceptor " << src_id VLOG(3) << "Send a message from interceptor " << src_id
<< " to interceptor " << dst_id << ", which are in the same ranks."; << " to interceptor " << dst_id << ", which are in the same ranks.";
...@@ -218,7 +227,8 @@ bool Carrier::Send(const InterceptorMessage& msg) { ...@@ -218,7 +227,8 @@ bool Carrier::Send(const InterceptorMessage& msg) {
Interceptor* Carrier::SetInterceptor(int64_t interceptor_id, Interceptor* Carrier::SetInterceptor(int64_t interceptor_id,
std::unique_ptr<Interceptor> interceptor) { std::unique_ptr<Interceptor> interceptor) {
auto iter = interceptor_idx_to_interceptor_.find(interceptor_id); auto iter = interceptor_idx_to_interceptor_.find(interceptor_id);
PADDLE_ENFORCE_EQ(iter, interceptor_idx_to_interceptor_.end(), PADDLE_ENFORCE_EQ(iter,
interceptor_idx_to_interceptor_.end(),
platform::errors::AlreadyExists( platform::errors::AlreadyExists(
"The interceptor id %lld has already been created! " "The interceptor id %lld has already been created! "
"The interceptor id should be unique.", "The interceptor id should be unique.",
...@@ -267,19 +277,22 @@ void Carrier::CreateInterceptors() { ...@@ -267,19 +277,22 @@ void Carrier::CreateInterceptors() {
TaskNode* task_node = item.second; TaskNode* task_node = item.second;
PADDLE_ENFORCE_LT( PADDLE_ENFORCE_LT(
task_node->run_at_offset(), task_node->run_per_steps(), task_node->run_at_offset(),
task_node->run_per_steps(),
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Interceptor's run_at_offset must < run_per_steps, must now " "Interceptor's run_at_offset must < run_per_steps, must now "
"run_at_offset=%ld run_per_steps=%ld", "run_at_offset=%ld run_per_steps=%ld",
task_node->run_at_offset(), task_node->run_per_steps())); task_node->run_at_offset(),
task_node->run_per_steps()));
std::unique_ptr<Interceptor> interceptor; std::unique_ptr<Interceptor> interceptor;
PADDLE_ENFORCE_NE(task_node->type().empty(), true, PADDLE_ENFORCE_NE(task_node->type().empty(),
true,
platform::errors::NotFound( platform::errors::NotFound(
"Cannot found type for task node with id %lld", "Cannot found type for task node with id %lld",
task_node->task_id())); task_node->task_id()));
interceptor = InterceptorFactory::Create(task_node->type(), interceptor_id, interceptor = InterceptorFactory::Create(
task_node); task_node->type(), interceptor_id, task_node);
interceptor->SetPlace(place_); interceptor->SetPlace(place_);
interceptor->SetMiniBatchScope(minibatch_scope_); interceptor->SetMiniBatchScope(minibatch_scope_);
interceptor->SetMicroBatchScope(microbatch_scopes_); interceptor->SetMicroBatchScope(microbatch_scopes_);
......
...@@ -56,12 +56,15 @@ class Carrier final { ...@@ -56,12 +56,15 @@ class Carrier final {
int64_t rank, int64_t rank,
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank, const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank,
const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node, const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node,
const framework::ProgramDesc& program, framework::Scope* scope, const framework::ProgramDesc& program,
int64_t num_micro_batches, const platform::Place& place, framework::Scope* scope,
int64_t num_micro_batches,
const platform::Place& place,
const std::vector<std::string>& inference_root_scope_vars = {}); const std::vector<std::string>& inference_root_scope_vars = {});
void CopyParameters( void CopyParameters(
int microbatch_id, const framework::ProgramDesc& program, int microbatch_id,
const framework::ProgramDesc& program,
const std::vector<std::string>& inference_root_scope_vars); const std::vector<std::string>& inference_root_scope_vars);
void Release(); void Release();
......
...@@ -43,7 +43,8 @@ void ComputeInterceptor::PrepareDeps() { ...@@ -43,7 +43,8 @@ void ComputeInterceptor::PrepareDeps() {
// source compute node, should we add a new SourceInterceptor? // source compute node, should we add a new SourceInterceptor?
if (upstream.empty()) { if (upstream.empty()) {
is_source_ = true; is_source_ = true;
PADDLE_ENFORCE_GT(node_->max_run_times(), 0, PADDLE_ENFORCE_GT(node_->max_run_times(),
0,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Source ComputeInterceptor must run at least one " "Source ComputeInterceptor must run at least one "
"times, but now max_run_times=%ld", "times, but now max_run_times=%ld",
...@@ -60,7 +61,8 @@ void ComputeInterceptor::PrepareDeps() { ...@@ -60,7 +61,8 @@ void ComputeInterceptor::PrepareDeps() {
void ComputeInterceptor::IncreaseReady(int64_t up_id) { void ComputeInterceptor::IncreaseReady(int64_t up_id) {
auto it = in_readys_.find(up_id); auto it = in_readys_.find(up_id);
PADDLE_ENFORCE_NE(it, in_readys_.end(), PADDLE_ENFORCE_NE(it,
in_readys_.end(),
platform::errors::NotFound( platform::errors::NotFound(
"Cannot find upstream=%lld in in_readys.", up_id)); "Cannot find upstream=%lld in in_readys.", up_id));
...@@ -73,26 +75,32 @@ void ComputeInterceptor::IncreaseReady(int64_t up_id) { ...@@ -73,26 +75,32 @@ void ComputeInterceptor::IncreaseReady(int64_t up_id) {
auto max_ready_size = it->second.first; auto max_ready_size = it->second.first;
auto ready_size = it->second.second; auto ready_size = it->second.second;
ready_size += 1; ready_size += 1;
PADDLE_ENFORCE_LE(ready_size, max_ready_size, PADDLE_ENFORCE_LE(ready_size,
max_ready_size,
platform::errors::OutOfRange( platform::errors::OutOfRange(
"upstream=%lld ready_size must <= max_ready_size, but " "upstream=%lld ready_size must <= max_ready_size, but "
"now ready_size=%lld, max_ready_size=%lld", "now ready_size=%lld, max_ready_size=%lld",
up_id, ready_size, max_ready_size)); up_id,
ready_size,
max_ready_size));
it->second.second = ready_size; it->second.second = ready_size;
} }
void ComputeInterceptor::DecreaseBuff(int64_t down_id) { void ComputeInterceptor::DecreaseBuff(int64_t down_id) {
auto it = out_buffs_.find(down_id); auto it = out_buffs_.find(down_id);
PADDLE_ENFORCE_NE(it, out_buffs_.end(), PADDLE_ENFORCE_NE(it,
out_buffs_.end(),
platform::errors::NotFound( platform::errors::NotFound(
"Cannot find downstream=%lld in out_buffs.", down_id)); "Cannot find downstream=%lld in out_buffs.", down_id));
auto used_size = it->second.second; auto used_size = it->second.second;
used_size -= 1; used_size -= 1;
PADDLE_ENFORCE_GE( PADDLE_ENFORCE_GE(
used_size, 0, used_size,
0,
platform::errors::OutOfRange( platform::errors::OutOfRange(
"downstream=%lld used buff size must >= 0, but now equal %lld", "downstream=%lld used buff size must >= 0, but now equal %lld",
down_id, used_size)); down_id,
used_size));
it->second.second = used_size; it->second.second = used_size;
} }
...@@ -130,11 +138,14 @@ void ComputeInterceptor::SendDataReadyToDownStream() { ...@@ -130,11 +138,14 @@ void ComputeInterceptor::SendDataReadyToDownStream() {
auto used_size = outs.second.second; auto used_size = outs.second.second;
used_size += 1; used_size += 1;
PADDLE_ENFORCE_LE( PADDLE_ENFORCE_LE(
used_size, max_buff_size, used_size,
max_buff_size,
platform::errors::OutOfRange("downstream=%lld used buff size must <= " platform::errors::OutOfRange("downstream=%lld used buff size must <= "
"max_buff_size, but now used_size=%lld, " "max_buff_size, but now used_size=%lld, "
"max_buff_size=%lld", "max_buff_size=%lld",
down_id, used_size, max_buff_size)); down_id,
used_size,
max_buff_size));
outs.second.second = used_size; outs.second.second = used_size;
InterceptorMessage ready_msg; InterceptorMessage ready_msg;
...@@ -152,9 +163,11 @@ void ComputeInterceptor::ReplyCompletedToUpStream() { ...@@ -152,9 +163,11 @@ void ComputeInterceptor::ReplyCompletedToUpStream() {
auto ready_size = ins.second.second; auto ready_size = ins.second.second;
ready_size -= 1; ready_size -= 1;
PADDLE_ENFORCE_GE( PADDLE_ENFORCE_GE(
ready_size, 0, ready_size,
0,
platform::errors::OutOfRange( platform::errors::OutOfRange(
"upstream=%lld ready_size must >= 0, but now got %lld", up_id, "upstream=%lld ready_size must >= 0, but now got %lld",
up_id,
ready_size)); ready_size));
ins.second.second = ready_size; ins.second.second = ready_size;
...@@ -176,8 +189,10 @@ void ComputeInterceptor::RunOps() { ...@@ -176,8 +189,10 @@ void ComputeInterceptor::RunOps() {
op->Run(*microbatch_scopes_[step_ % node_->max_run_times()], place_); op->Run(*microbatch_scopes_[step_ % node_->max_run_times()], place_);
if (gc_) { if (gc_) {
framework::DeleteUnusedTensors( framework::DeleteUnusedTensors(
*microbatch_scopes_[step_ % node_->max_run_times()], op, *microbatch_scopes_[step_ % node_->max_run_times()],
node_->unused_vars(), gc_.get()); op,
node_->unused_vars(),
gc_.get());
} }
} }
} }
...@@ -210,11 +225,13 @@ void ComputeInterceptor::ReceivedStop(int64_t up_id) { ...@@ -210,11 +225,13 @@ void ComputeInterceptor::ReceivedStop(int64_t up_id) {
if (is_source_ && up_id == -1) return; if (is_source_ && up_id == -1) return;
auto it = in_stops_.find(up_id); auto it = in_stops_.find(up_id);
PADDLE_ENFORCE_NE(it, in_stops_.end(), PADDLE_ENFORCE_NE(it,
in_stops_.end(),
platform::errors::NotFound( platform::errors::NotFound(
"Cannot find upstream=%lld in in_stops.", up_id)); "Cannot find upstream=%lld in in_stops.", up_id));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
it->second, false, it->second,
false,
platform::errors::AlreadyExists("Already received stop from %lld, stop " platform::errors::AlreadyExists("Already received stop from %lld, stop "
"cannot be send more than once.")); "cannot be send more than once."));
it->second = true; it->second = true;
......
...@@ -71,7 +71,8 @@ bool LoadDataFromDistModelTensor(const DistModelTensor &input_data, ...@@ -71,7 +71,8 @@ bool LoadDataFromDistModelTensor(const DistModelTensor &input_data,
if (platform::is_cpu_place(place)) { if (platform::is_cpu_place(place)) {
VLOG(3) << "Loading data for CPU."; VLOG(3) << "Loading data for CPU.";
std::memcpy(static_cast<void *>(input_tensor_ptr), input_data.data.data(), std::memcpy(static_cast<void *>(input_tensor_ptr),
input_data.data.data(),
input_data.data.length()); input_data.data.length());
} else if (platform::is_gpu_place(place)) { } else if (platform::is_gpu_place(place)) {
VLOG(3) << "Loading data for GPU."; VLOG(3) << "Loading data for GPU.";
...@@ -80,9 +81,12 @@ bool LoadDataFromDistModelTensor(const DistModelTensor &input_data, ...@@ -80,9 +81,12 @@ bool LoadDataFromDistModelTensor(const DistModelTensor &input_data,
auto *dev_ctx = auto *dev_ctx =
dynamic_cast<const platform::CUDADeviceContext *>(pool.Get(place)); dynamic_cast<const platform::CUDADeviceContext *>(pool.Get(place));
auto gpu_place = place; auto gpu_place = place;
memory::Copy(gpu_place, static_cast<void *>(input_tensor_ptr), memory::Copy(gpu_place,
platform::CPUPlace(), input_data.data.data(), static_cast<void *>(input_tensor_ptr),
input_data.data.length(), dev_ctx->stream()); platform::CPUPlace(),
input_data.data.data(),
input_data.data.length(),
dev_ctx->stream());
#else #else
PADDLE_THROW(paddle::platform::errors::Fatal( PADDLE_THROW(paddle::platform::errors::Fatal(
"Paddle wasn't compiled with CUDA, but place is GPU.")); "Paddle wasn't compiled with CUDA, but place is GPU."));
...@@ -139,15 +143,17 @@ class DistModelTimer { ...@@ -139,15 +143,17 @@ class DistModelTimer {
bool DistModel::Init() { bool DistModel::Init() {
carrier_id_ = "inference"; carrier_id_ = "inference";
bool init_method = (!config_.model_dir.empty() || config_.program_desc); bool init_method = (!config_.model_dir.empty() || config_.program_desc);
PADDLE_ENFORCE_EQ(init_method, true, PADDLE_ENFORCE_EQ(init_method,
true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"One of model dir or program desc must be provided to " "One of model dir or program desc must be provided to "
"dist model inference.")); "dist model inference."));
if (config_.program_desc) { if (config_.program_desc) {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
config_.scope, platform::errors::InvalidArgument( config_.scope,
"Scope must be provided to dist model inference if " platform::errors::InvalidArgument(
"program desc has been provided.")); "Scope must be provided to dist model inference if "
"program desc has been provided."));
} }
if (!PreparePlace()) { if (!PreparePlace()) {
return false; return false;
...@@ -217,8 +223,12 @@ bool DistModel::CommInit() { ...@@ -217,8 +223,12 @@ bool DistModel::CommInit() {
} }
peer_endpoints.emplace_back(config_.trainer_endpoints[rank]); peer_endpoints.emplace_back(config_.trainer_endpoints[rank]);
} }
InsertCommOp(var_name_base + std::to_string(order), ranks_in_group, InsertCommOp(var_name_base + std::to_string(order),
rank_in_group, peer_endpoints, comm_init_block, ring_id); ranks_in_group,
rank_in_group,
peer_endpoints,
comm_init_block,
ring_id);
order += 1; order += 1;
} }
framework::NaiveExecutor e(place_); framework::NaiveExecutor e(place_);
...@@ -229,9 +239,12 @@ bool DistModel::CommInit() { ...@@ -229,9 +239,12 @@ bool DistModel::CommInit() {
return true; return true;
} }
void DistModel::InsertCommOp(std::string tmp_var_name, int nranks, int rank, void DistModel::InsertCommOp(std::string tmp_var_name,
int nranks,
int rank,
const std::vector<std::string> &peer_endpoints, const std::vector<std::string> &peer_endpoints,
framework::BlockDesc *block, int ring_id) { framework::BlockDesc *block,
int ring_id) {
/* /*
* tmp_var_name: the var name for var comm_id * tmp_var_name: the var name for var comm_id
* nranks: number of total ranks * nranks: number of total ranks
...@@ -297,7 +310,8 @@ bool DistModel::PrepareProgram() { ...@@ -297,7 +310,8 @@ bool DistModel::PrepareProgram() {
bool DistModel::LoadProgram() { bool DistModel::LoadProgram() {
VLOG(3) << "Loading program from " << config_.model_dir; VLOG(3) << "Loading program from " << config_.model_dir;
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
config_.model_dir, "", config_.model_dir,
"",
platform::errors::InvalidArgument("Model dir must be provided.")); platform::errors::InvalidArgument("Model dir must be provided."));
std::string model_path = config_.model_dir + ".pdmodel"; std::string model_path = config_.model_dir + ".pdmodel";
framework::proto::ProgramDesc program_proto; framework::proto::ProgramDesc program_proto;
...@@ -305,7 +319,8 @@ bool DistModel::LoadProgram() { ...@@ -305,7 +319,8 @@ bool DistModel::LoadProgram() {
// Read binary // Read binary
std::ifstream fin(model_path, std::ios::in | std::ios::binary); std::ifstream fin(model_path, std::ios::in | std::ios::binary);
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
static_cast<bool>(fin.is_open()), true, static_cast<bool>(fin.is_open()),
true,
platform::errors::NotFound( platform::errors::NotFound(
"Cannot open file %s, please confirm whether the file is normal.", "Cannot open file %s, please confirm whether the file is normal.",
model_path)); model_path));
...@@ -387,8 +402,13 @@ bool DistModel::PrepareFleetExe() { ...@@ -387,8 +402,13 @@ bool DistModel::PrepareFleetExe() {
id_to_rank.insert({i, i}); id_to_rank.insert({i, i});
} }
fleet_exe.reset(new FleetExecutor(executor_desc_)); fleet_exe.reset(new FleetExecutor(executor_desc_));
fleet_exe->Init(carrier_id_, *(program_.get()), scope_.get(), place_, 1, fleet_exe->Init(carrier_id_,
{task_node_.get()}, id_to_rank); *(program_.get()),
scope_.get(),
place_,
1,
{task_node_.get()},
id_to_rank);
return true; return true;
} }
...@@ -490,9 +510,11 @@ bool DistModel::FetchResults(std::vector<DistModelTensor> *output_data, ...@@ -490,9 +510,11 @@ bool DistModel::FetchResults(std::vector<DistModelTensor> *output_data,
int idx = BOOST_GET_CONST(int, fetches_[i]->GetAttr("col")); int idx = BOOST_GET_CONST(int, fetches_[i]->GetAttr("col"));
VLOG(3) << "Fetching data for [" << idx_to_fetches_[idx] << "]"; VLOG(3) << "Fetching data for [" << idx_to_fetches_[idx] << "]";
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
static_cast<size_t>(idx), i, static_cast<size_t>(idx),
i,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Fetch op's col attr(%d) should be equal to the index(%d)", idx, "Fetch op's col attr(%d) should be equal to the index(%d)",
idx,
i)); i));
framework::FetchType &fetch_var = framework::FetchType &fetch_var =
framework::GetFetchVariable(*scope, "fetch", idx); framework::GetFetchVariable(*scope, "fetch", idx);
......
...@@ -72,9 +72,12 @@ class DistModel { ...@@ -72,9 +72,12 @@ class DistModel {
bool CommInit(); bool CommInit();
bool PrepareFeedAndFetch(); bool PrepareFeedAndFetch();
bool PrepareFleetExe(); bool PrepareFleetExe();
void InsertCommOp(std::string tmp_var_name, int nranks, int rank, void InsertCommOp(std::string tmp_var_name,
int nranks,
int rank,
const std::vector<std::string>& peer_endpoints, const std::vector<std::string>& peer_endpoints,
framework::BlockDesc* block, int ring_id); framework::BlockDesc* block,
int ring_id);
bool FeedData(const std::vector<DistModelTensor>& input_data, bool FeedData(const std::vector<DistModelTensor>& input_data,
framework::Scope* scope); framework::Scope* scope);
bool FetchResults(std::vector<DistModelTensor>* output_data, bool FetchResults(std::vector<DistModelTensor>* output_data,
......
...@@ -28,7 +28,8 @@ void DistModelDataBuf::Reset(void* data, size_t length) { ...@@ -28,7 +28,8 @@ void DistModelDataBuf::Reset(void* data, size_t length) {
void DistModelDataBuf::Free() { void DistModelDataBuf::Free() {
if (memory_owned_ && data_) { if (memory_owned_ && data_) {
PADDLE_ENFORCE_GT(length_, 0UL, PADDLE_ENFORCE_GT(length_,
0UL,
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"Error occurred when deconstruct DistModelDataBuf: " "Error occurred when deconstruct DistModelDataBuf: "
"it contains no data!")); "it contains no data!"));
......
...@@ -30,8 +30,9 @@ namespace distributed { ...@@ -30,8 +30,9 @@ namespace distributed {
FleetExecutor::FleetExecutor(const std::string& exe_desc_str) { FleetExecutor::FleetExecutor(const std::string& exe_desc_str) {
bool parse_flag = exe_desc_.ParseFromString(exe_desc_str); bool parse_flag = exe_desc_.ParseFromString(exe_desc_str);
PADDLE_ENFORCE(parse_flag, platform::errors::PreconditionNotMet( PADDLE_ENFORCE(parse_flag,
"Error occurs while parsing string to proto")); platform::errors::PreconditionNotMet(
"Error occurs while parsing string to proto"));
// Message bus will be created and inited only once // Message bus will be created and inited only once
GlobalVal<MessageBus>::Create(); GlobalVal<MessageBus>::Create();
InitMessageBus(); InitMessageBus();
...@@ -51,12 +52,16 @@ FleetExecutor::~FleetExecutor() { ...@@ -51,12 +52,16 @@ FleetExecutor::~FleetExecutor() {
} }
void FleetExecutor::Init( void FleetExecutor::Init(
const std::string& carrier_id, const framework::ProgramDesc& program_desc, const std::string& carrier_id,
framework::Scope* scope, const platform::Place& place, const framework::ProgramDesc& program_desc,
int64_t num_micro_batches, const std::vector<TaskNode*>& task_nodes, framework::Scope* scope,
const platform::Place& place,
int64_t num_micro_batches,
const std::vector<TaskNode*>& task_nodes,
const std::unordered_map<int64_t, int64_t>& task_id_to_rank, const std::unordered_map<int64_t, int64_t>& task_id_to_rank,
const std::vector<std::string>& inference_root_scope_vars) { const std::vector<std::string>& inference_root_scope_vars) {
PADDLE_ENFORCE_GT(task_nodes.size(), 0, PADDLE_ENFORCE_GT(task_nodes.size(),
0,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Fleet executor is inited with empty task node")); "Fleet executor is inited with empty task node"));
// TODO(fleet_exe devs): the unused_vars should be got from run time graph // TODO(fleet_exe devs): the unused_vars should be got from run time graph
...@@ -116,18 +121,30 @@ void FleetExecutor::Init( ...@@ -116,18 +121,30 @@ void FleetExecutor::Init(
carrier_ids_.insert(carrier_id); carrier_ids_.insert(carrier_id);
// Set current running carrier // Set current running carrier
GlobalVal<std::string>::Set(new std::string(carrier_id)); GlobalVal<std::string>::Set(new std::string(carrier_id));
InitCarrier(carrier, scope, place, num_micro_batches, program_desc, InitCarrier(carrier,
scope,
place,
num_micro_batches,
program_desc,
inference_root_scope_vars); inference_root_scope_vars);
GlobalVal<MessageBus>::Get()->Barrier(); GlobalVal<MessageBus>::Get()->Barrier();
} }
void FleetExecutor::InitCarrier( void FleetExecutor::InitCarrier(
Carrier* carrier, framework::Scope* scope, const platform::Place& place, Carrier* carrier,
int64_t num_micro_batches, const framework::ProgramDesc& program_desc, framework::Scope* scope,
const platform::Place& place,
int64_t num_micro_batches,
const framework::ProgramDesc& program_desc,
const std::vector<std::string>& inference_root_scope_vars) { const std::vector<std::string>& inference_root_scope_vars) {
carrier->Init(exe_desc_.cur_rank(), runtime_graph_->interceptor_id_to_rank(), carrier->Init(exe_desc_.cur_rank(),
runtime_graph_->interceptor_id_to_node(), program_desc, scope, runtime_graph_->interceptor_id_to_rank(),
num_micro_batches, place, inference_root_scope_vars); runtime_graph_->interceptor_id_to_node(),
program_desc,
scope,
num_micro_batches,
place,
inference_root_scope_vars);
} }
void FleetExecutor::InitMessageBus() { void FleetExecutor::InitMessageBus() {
...@@ -148,11 +165,13 @@ void FleetExecutor::InitMessageBus() { ...@@ -148,11 +165,13 @@ void FleetExecutor::InitMessageBus() {
} }
if (addr == "") { if (addr == "") {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
rank_to_addr.size(), 1, rank_to_addr.size(),
1,
platform::errors::NotFound("Empty address is not valid for " platform::errors::NotFound("Empty address is not valid for "
"paddle.distributed.launch method.")); "paddle.distributed.launch method."));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
cur_rank, 0, cur_rank,
0,
platform::errors::NotFound("Address is empty but cur rank is not 0.")); platform::errors::NotFound("Address is empty but cur rank is not 0."));
} }
VLOG(3) << "Current rank is " << cur_rank << " and the ip_port is " VLOG(3) << "Current rank is " << cur_rank << " and the ip_port is "
......
...@@ -39,8 +39,10 @@ class FleetExecutor final { ...@@ -39,8 +39,10 @@ class FleetExecutor final {
explicit FleetExecutor(const FleetExecutorDesc& exe_desc); explicit FleetExecutor(const FleetExecutorDesc& exe_desc);
~FleetExecutor(); ~FleetExecutor();
void Init(const std::string& carrier_id, void Init(const std::string& carrier_id,
const framework::ProgramDesc& program_desc, framework::Scope* scope, const framework::ProgramDesc& program_desc,
const platform::Place& place, int64_t num_micro_batches, framework::Scope* scope,
const platform::Place& place,
int64_t num_micro_batches,
const std::vector<TaskNode*>& task_nodes, const std::vector<TaskNode*>& task_nodes,
const std::unordered_map<int64_t, int64_t>& task_id_to_rank, const std::unordered_map<int64_t, int64_t>& task_id_to_rank,
const std::vector<std::string>& inference_root_scope_vars = {}); const std::vector<std::string>& inference_root_scope_vars = {});
...@@ -50,8 +52,11 @@ class FleetExecutor final { ...@@ -50,8 +52,11 @@ class FleetExecutor final {
DISABLE_COPY_AND_ASSIGN(FleetExecutor); DISABLE_COPY_AND_ASSIGN(FleetExecutor);
void InitMessageBus(); void InitMessageBus();
void InitCarrier( void InitCarrier(
Carrier* carrier, framework::Scope* scope, const platform::Place& place, Carrier* carrier,
int64_t num_micro_batches, const framework::ProgramDesc& program_desc, framework::Scope* scope,
const platform::Place& place,
int64_t num_micro_batches,
const framework::ProgramDesc& program_desc,
const std::vector<std::string>& inference_root_scope_vars = {}); const std::vector<std::string>& inference_root_scope_vars = {});
FleetExecutorDesc exe_desc_; FleetExecutorDesc exe_desc_;
std::shared_ptr<RuntimeGraph> runtime_graph_; std::shared_ptr<RuntimeGraph> runtime_graph_;
......
...@@ -31,7 +31,8 @@ class GlobalVal final { ...@@ -31,7 +31,8 @@ class GlobalVal final {
template <typename... Args> template <typename... Args>
static T* Create(Args&&... args) { static T* Create(Args&&... args) {
auto* ptr = GetPPtr(); auto* ptr = GetPPtr();
PADDLE_ENFORCE_EQ(ptr->get(), nullptr, PADDLE_ENFORCE_EQ(ptr->get(),
nullptr,
platform::errors::AlreadyExists( platform::errors::AlreadyExists(
"This value is already a global value.")); "This value is already a global value."));
T* item = new T(std::forward<Args>(args)...); T* item = new T(std::forward<Args>(args)...);
...@@ -65,7 +66,8 @@ class GlobalMap final { ...@@ -65,7 +66,8 @@ class GlobalMap final {
template <typename... Args> template <typename... Args>
static ValueT* Create(KeyT id, Args&&... args) { static ValueT* Create(KeyT id, Args&&... args) {
auto* ptr = GetPPtr(id); auto* ptr = GetPPtr(id);
PADDLE_ENFORCE_EQ(ptr->get(), nullptr, PADDLE_ENFORCE_EQ(ptr->get(),
nullptr,
platform::errors::AlreadyExists( platform::errors::AlreadyExists(
"This value has already in global map.")); "This value has already in global map."));
ValueT* item = new ValueT(std::forward<Args>(args)...); ValueT* item = new ValueT(std::forward<Args>(args)...);
...@@ -86,14 +88,16 @@ class ThreadSafeGlobalMap final { ...@@ -86,14 +88,16 @@ class ThreadSafeGlobalMap final {
static ValueT* Get(KeyT id) { static ValueT* Get(KeyT id) {
ValueT* item = GetPPtr(id)->get(); ValueT* item = GetPPtr(id)->get();
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
item, platform::errors::NotFound( item,
"This value is not in thread safe global map.")); platform::errors::NotFound(
"This value is not in thread safe global map."));
return item; return item;
} }
template <typename... Args> template <typename... Args>
static ValueT* Create(KeyT id, Args&&... args) { static ValueT* Create(KeyT id, Args&&... args) {
auto* ptr = GetPPtr(id); auto* ptr = GetPPtr(id);
PADDLE_ENFORCE_EQ(ptr->get(), nullptr, PADDLE_ENFORCE_EQ(ptr->get(),
nullptr,
platform::errors::AlreadyExists( platform::errors::AlreadyExists(
"This value has already in thread safe global map.")); "This value has already in thread safe global map."));
ValueT* item = new ValueT(std::forward<Args>(args)...); ValueT* item = new ValueT(std::forward<Args>(args)...);
......
...@@ -35,8 +35,9 @@ Interceptor::~Interceptor() { ...@@ -35,8 +35,9 @@ Interceptor::~Interceptor() {
void Interceptor::RegisterMsgHandle(MsgHandle handle) { handle_ = handle; } void Interceptor::RegisterMsgHandle(MsgHandle handle) { handle_ = handle; }
void Interceptor::Handle(const InterceptorMessage& msg) { void Interceptor::Handle(const InterceptorMessage& msg) {
PADDLE_ENFORCE_NOT_NULL(handle_, platform::errors::PreconditionNotMet( PADDLE_ENFORCE_NOT_NULL(handle_,
"Message handle is not registered.")); platform::errors::PreconditionNotMet(
"Message handle is not registered."));
handle_(msg); handle_(msg);
} }
...@@ -46,7 +47,8 @@ void Interceptor::LoopOnce() { ...@@ -46,7 +47,8 @@ void Interceptor::LoopOnce() {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
messages_.swap(tmp_messages); messages_.swap(tmp_messages);
} }
PADDLE_ENFORCE_EQ(tmp_messages.empty(), false, PADDLE_ENFORCE_EQ(tmp_messages.empty(),
false,
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"tmp_messages must not empty in task loop")); "tmp_messages must not empty in task loop"));
...@@ -61,8 +63,9 @@ void Interceptor::LoopOnce() { ...@@ -61,8 +63,9 @@ void Interceptor::LoopOnce() {
} }
void Interceptor::StopCarrier() { void Interceptor::StopCarrier() {
PADDLE_ENFORCE_NOT_NULL(carrier_, platform::errors::PreconditionNotMet( PADDLE_ENFORCE_NOT_NULL(
"Carrier is not registered.")); carrier_,
platform::errors::PreconditionNotMet("Carrier is not registered."));
carrier_->WakeUp(); carrier_->WakeUp();
} }
...@@ -84,8 +87,9 @@ void Interceptor::EnqueueRemoteInterceptorMessage( ...@@ -84,8 +87,9 @@ void Interceptor::EnqueueRemoteInterceptorMessage(
} }
bool Interceptor::Send(int64_t dst_id, InterceptorMessage& msg) { bool Interceptor::Send(int64_t dst_id, InterceptorMessage& msg) {
PADDLE_ENFORCE_NOT_NULL(carrier_, platform::errors::PreconditionNotMet( PADDLE_ENFORCE_NOT_NULL(
"Carrier is not registered.")); carrier_,
platform::errors::PreconditionNotMet("Carrier is not registered."));
msg.set_src_id(interceptor_id_); msg.set_src_id(interceptor_id_);
msg.set_dst_id(dst_id); msg.set_dst_id(dst_id);
return carrier_->Send(msg); return carrier_->Send(msg);
...@@ -102,7 +106,8 @@ std::unique_ptr<Interceptor> InterceptorFactory::Create(const std::string& type, ...@@ -102,7 +106,8 @@ std::unique_ptr<Interceptor> InterceptorFactory::Create(const std::string& type,
auto& interceptor_map = GetInterceptorMap(); auto& interceptor_map = GetInterceptorMap();
auto iter = interceptor_map.find(type); auto iter = interceptor_map.find(type);
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
iter, interceptor_map.end(), iter,
interceptor_map.end(),
platform::errors::NotFound("interceptor %s is not register", type)); platform::errors::NotFound("interceptor %s is not register", type));
return iter->second(id, node); return iter->second(id, node);
} }
......
...@@ -129,7 +129,8 @@ class InterceptorFactory { ...@@ -129,7 +129,8 @@ class InterceptorFactory {
static void Register(const std::string& type, CreateInterceptorFunc func); static void Register(const std::string& type, CreateInterceptorFunc func);
static std::unique_ptr<Interceptor> Create(const std::string& type, static std::unique_ptr<Interceptor> Create(const std::string& type,
int64_t id, TaskNode* node); int64_t id,
TaskNode* node);
}; };
template <typename InterceptorClass> template <typename InterceptorClass>
......
...@@ -27,10 +27,12 @@ namespace paddle { ...@@ -27,10 +27,12 @@ namespace paddle {
namespace distributed { namespace distributed {
void MessageBus::Init( void MessageBus::Init(
int64_t rank, const std::unordered_map<int64_t, std::string>& rank_to_addr, int64_t rank,
const std::unordered_map<int64_t, std::string>& rank_to_addr,
const std::string& addr) { const std::string& addr) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
is_init_, false, is_init_,
false,
platform::errors::AlreadyExists("MessageBus is already init.")); platform::errors::AlreadyExists("MessageBus is already init."));
rank_ = rank; rank_ = rank;
is_init_ = true; is_init_ = true;
...@@ -39,12 +41,14 @@ void MessageBus::Init( ...@@ -39,12 +41,14 @@ void MessageBus::Init(
if (addr_ != "") { if (addr_ != "") {
const auto& addr = GetAddr(rank_); const auto& addr = GetAddr(rank_);
PADDLE_ENFORCE_EQ(addr, addr_, PADDLE_ENFORCE_EQ(addr,
addr_,
platform::errors::Fatal( platform::errors::Fatal(
"The current rank's addr is %s, while the " "The current rank's addr is %s, while the "
"message bus's addr is %s, which are different. " "message bus's addr is %s, which are different. "
"Init error.", "Init error.",
addr, addr_)); addr,
addr_));
} }
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \ #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
...@@ -77,7 +81,8 @@ MessageBus::~MessageBus() { ...@@ -77,7 +81,8 @@ MessageBus::~MessageBus() {
const std::string& MessageBus::GetAddr(int64_t rank) const { const std::string& MessageBus::GetAddr(int64_t rank) const {
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
rank_to_addr_.find(rank), rank_to_addr_.end(), rank_to_addr_.find(rank),
rank_to_addr_.end(),
platform::errors::NotFound("Cannot find addr rank id %lld.", rank)); platform::errors::NotFound("Cannot find addr rank id %lld.", rank));
return rank_to_addr_.at(rank); return rank_to_addr_.at(rank);
} }
...@@ -85,7 +90,8 @@ const std::string& MessageBus::GetAddr(int64_t rank) const { ...@@ -85,7 +90,8 @@ const std::string& MessageBus::GetAddr(int64_t rank) const {
bool MessageBus::Send(int64_t dst_rank, bool MessageBus::Send(int64_t dst_rank,
const InterceptorMessage& interceptor_message) { const InterceptorMessage& interceptor_message) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
IsInit(), true, IsInit(),
true,
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"Using message bus since it has not been initialized.")); "Using message bus since it has not been initialized."));
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
...@@ -176,7 +182,8 @@ void MessageBus::ListenPort() { ...@@ -176,7 +182,8 @@ void MessageBus::ListenPort() {
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
// function keep listen the port and handle the message // function keep listen the port and handle the message
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
server_.AddService(&message_service_, brpc::SERVER_DOESNT_OWN_SERVICE), 0, server_.AddService(&message_service_, brpc::SERVER_DOESNT_OWN_SERVICE),
0,
platform::errors::Unavailable("Message bus: init brpc service error.")); platform::errors::Unavailable("Message bus: init brpc service error."));
// start the server // start the server
...@@ -215,7 +222,8 @@ bool MessageBus::SendInterRank(int64_t dst_rank, ...@@ -215,7 +222,8 @@ bool MessageBus::SendInterRank(int64_t dst_rank,
options.timeout_ms = 1000; options.timeout_ms = 1000;
options.max_retry = 5; options.max_retry = 5;
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
channel.Init(dst_addr_for_brpc, &options), 0, channel.Init(dst_addr_for_brpc, &options),
0,
platform::errors::Unavailable("Message bus: init brpc channel error.")); platform::errors::Unavailable("Message bus: init brpc channel error."));
MessageService_Stub stub(&channel); MessageService_Stub stub(&channel);
InterceptorResponse response; InterceptorResponse response;
...@@ -224,8 +232,8 @@ bool MessageBus::SendInterRank(int64_t dst_rank, ...@@ -224,8 +232,8 @@ bool MessageBus::SendInterRank(int64_t dst_rank,
if (interceptor_message.ctrl_message()) { if (interceptor_message.ctrl_message()) {
stub.IncreaseBarrierCount(&ctrl, &interceptor_message, &response, NULL); stub.IncreaseBarrierCount(&ctrl, &interceptor_message, &response, NULL);
} else { } else {
stub.ReceiveInterceptorMessage(&ctrl, &interceptor_message, &response, stub.ReceiveInterceptorMessage(
NULL); &ctrl, &interceptor_message, &response, NULL);
} }
if (!ctrl.Failed()) { if (!ctrl.Failed()) {
if (response.rst()) { if (response.rst()) {
......
...@@ -23,7 +23,8 @@ namespace distributed { ...@@ -23,7 +23,8 @@ namespace distributed {
void MessageServiceImpl::ReceiveInterceptorMessage( void MessageServiceImpl::ReceiveInterceptorMessage(
google::protobuf::RpcController* control_base, google::protobuf::RpcController* control_base,
const InterceptorMessage* request, InterceptorResponse* response, const InterceptorMessage* request,
InterceptorResponse* response,
google::protobuf::Closure* done) { google::protobuf::Closure* done) {
brpc::ClosureGuard done_guard(done); brpc::ClosureGuard done_guard(done);
VLOG(3) << "Message Service receives a message from interceptor " VLOG(3) << "Message Service receives a message from interceptor "
...@@ -35,7 +36,8 @@ void MessageServiceImpl::ReceiveInterceptorMessage( ...@@ -35,7 +36,8 @@ void MessageServiceImpl::ReceiveInterceptorMessage(
void MessageServiceImpl::IncreaseBarrierCount( void MessageServiceImpl::IncreaseBarrierCount(
google::protobuf::RpcController* control_base, google::protobuf::RpcController* control_base,
const InterceptorMessage* request, InterceptorResponse* response, const InterceptorMessage* request,
InterceptorResponse* response,
google::protobuf::Closure* done) { google::protobuf::Closure* done) {
brpc::ClosureGuard done_guard(done); brpc::ClosureGuard done_guard(done);
VLOG(3) << "Barrier Service receives a message from rank " VLOG(3) << "Barrier Service receives a message from rank "
......
...@@ -26,11 +26,13 @@ class MessageServiceImpl : public MessageService { ...@@ -26,11 +26,13 @@ class MessageServiceImpl : public MessageService {
virtual ~MessageServiceImpl() {} virtual ~MessageServiceImpl() {}
virtual void ReceiveInterceptorMessage( virtual void ReceiveInterceptorMessage(
google::protobuf::RpcController* control_base, google::protobuf::RpcController* control_base,
const InterceptorMessage* request, InterceptorResponse* response, const InterceptorMessage* request,
InterceptorResponse* response,
google::protobuf::Closure* done); google::protobuf::Closure* done);
virtual void IncreaseBarrierCount( virtual void IncreaseBarrierCount(
google::protobuf::RpcController* control_base, google::protobuf::RpcController* control_base,
const InterceptorMessage* request, InterceptorResponse* response, const InterceptorMessage* request,
InterceptorResponse* response,
google::protobuf::Closure* done); google::protobuf::Closure* done);
}; };
......
...@@ -27,7 +27,8 @@ TaskLoop* TaskLoop::GetTaskLoopOfCurrentThread() { return thread_local_loop_; } ...@@ -27,7 +27,8 @@ TaskLoop* TaskLoop::GetTaskLoopOfCurrentThread() { return thread_local_loop_; }
TaskLoop::TaskLoop() TaskLoop::TaskLoop()
: looping_(false), quit_(false), thread_id_(std::this_thread::get_id()) { : looping_(false), quit_(false), thread_id_(std::this_thread::get_id()) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
thread_local_loop_, nullptr, thread_local_loop_,
nullptr,
platform::errors::AlreadyExists("Another TaskLoop is already init.")); platform::errors::AlreadyExists("Another TaskLoop is already init."));
thread_local_loop_ = this; thread_local_loop_ = this;
} }
...@@ -35,7 +36,8 @@ TaskLoop::TaskLoop() ...@@ -35,7 +36,8 @@ TaskLoop::TaskLoop()
TaskLoop::~TaskLoop() { thread_local_loop_ = nullptr; } TaskLoop::~TaskLoop() { thread_local_loop_ = nullptr; }
void TaskLoop::Loop() { void TaskLoop::Loop() {
PADDLE_ENFORCE_EQ(looping_, false, PADDLE_ENFORCE_EQ(looping_,
false,
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"Loop can only execute in one loop thread")); "Loop can only execute in one loop thread"));
AssertInLoopThread(); AssertInLoopThread();
...@@ -75,7 +77,8 @@ void TaskLoop::WakeUp() { ...@@ -75,7 +77,8 @@ void TaskLoop::WakeUp() {
void TaskLoop::AbortNotInLoopThread() { void TaskLoop::AbortNotInLoopThread() {
PADDLE_THROW(platform::errors::PreconditionNotMet( PADDLE_THROW(platform::errors::PreconditionNotMet(
"This TaskLoop was created in thread %d, but current thread is %d", "This TaskLoop was created in thread %d, but current thread is %d",
thread_id_, std::this_thread::get_id())); thread_id_,
std::this_thread::get_id()));
} }
} // namespace distributed } // namespace distributed
......
...@@ -32,7 +32,8 @@ TaskLoopThread::~TaskLoopThread() { ...@@ -32,7 +32,8 @@ TaskLoopThread::~TaskLoopThread() {
TaskLoop* TaskLoopThread::StartLoop() { TaskLoop* TaskLoopThread::StartLoop() {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
start_, false, start_,
false,
platform::errors::PreconditionNotMet("thread is already running.")); platform::errors::PreconditionNotMet("thread is already running."));
start_ = true; start_ = true;
thread_ = std::thread([this]() { Loop(); }); thread_ = std::thread([this]() { Loop(); });
......
...@@ -31,10 +31,12 @@ TaskLoopThreadPool::~TaskLoopThreadPool() = default; ...@@ -31,10 +31,12 @@ TaskLoopThreadPool::~TaskLoopThreadPool() = default;
void TaskLoopThreadPool::Start() { void TaskLoopThreadPool::Start() {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
start_, false, start_,
false,
platform::errors::PreconditionNotMet("thread pool is already start.")); platform::errors::PreconditionNotMet("thread pool is already start."));
PADDLE_ENFORCE_GT( PADDLE_ENFORCE_GT(
thread_num_, 0, thread_num_,
0,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"thread num must greater than 0, but now is %d", thread_num_)); "thread num must greater than 0, but now is %d", thread_num_));
...@@ -47,21 +49,26 @@ void TaskLoopThreadPool::Start() { ...@@ -47,21 +49,26 @@ void TaskLoopThreadPool::Start() {
TaskLoop* TaskLoopThreadPool::GetLoop(int tid) { TaskLoop* TaskLoopThreadPool::GetLoop(int tid) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
start_, true, start_,
true,
platform::errors::PreconditionNotMet("thread pool must start first.")); platform::errors::PreconditionNotMet("thread pool must start first."));
PADDLE_ENFORCE_GE( PADDLE_ENFORCE_GE(
tid, 0, tid,
0,
platform::errors::OutOfRange("tid must >= 0, but now is %d", tid)); platform::errors::OutOfRange("tid must >= 0, but now is %d", tid));
PADDLE_ENFORCE_LT(tid, thread_num_, PADDLE_ENFORCE_LT(tid,
thread_num_,
platform::errors::OutOfRange( platform::errors::OutOfRange(
"tid must < thread_num, but now tid=%d thread_num=%d", "tid must < thread_num, but now tid=%d thread_num=%d",
tid, thread_num_)); tid,
thread_num_));
return loops_[tid]; return loops_[tid];
} }
std::vector<TaskLoop*> TaskLoopThreadPool::GetAllLoops() { std::vector<TaskLoop*> TaskLoopThreadPool::GetAllLoops() {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
start_, true, start_,
true,
platform::errors::PreconditionNotMet("thread pool must start first.")); platform::errors::PreconditionNotMet("thread pool must start first."));
return loops_; return loops_;
} }
......
...@@ -24,8 +24,10 @@ namespace { ...@@ -24,8 +24,10 @@ namespace {
using OperatorBase = TaskNode::OperatorBase; using OperatorBase = TaskNode::OperatorBase;
} }
TaskNode::TaskNode(paddle::framework::ProgramDesc* program, int64_t rank, TaskNode::TaskNode(paddle::framework::ProgramDesc* program,
int64_t max_run_times, int64_t max_slot_nums) int64_t rank,
int64_t max_run_times,
int64_t max_slot_nums)
: program_(program), : program_(program),
rank_(rank), rank_(rank),
max_run_times_(max_run_times), max_run_times_(max_run_times),
...@@ -80,7 +82,9 @@ TaskNode::TaskNode(int64_t rank, int64_t task_id, int64_t max_run_times) ...@@ -80,7 +82,9 @@ TaskNode::TaskNode(int64_t rank, int64_t task_id, int64_t max_run_times)
TaskNode::TaskNode(int32_t role, TaskNode::TaskNode(int32_t role,
const std::vector<framework::OpDesc*>& op_descs, const std::vector<framework::OpDesc*>& op_descs,
int64_t rank, int64_t task_id, int64_t max_run_times, int64_t rank,
int64_t task_id,
int64_t max_run_times,
int64_t max_slot_nums) int64_t max_slot_nums)
: role_(role), : role_(role),
rank_(rank), rank_(rank),
...@@ -101,7 +105,9 @@ TaskNode::TaskNode(int32_t role, ...@@ -101,7 +105,9 @@ TaskNode::TaskNode(int32_t role,
TaskNode::TaskNode(int32_t role, TaskNode::TaskNode(int32_t role,
const std::vector<framework::OperatorBase*>& ops, const std::vector<framework::OperatorBase*>& ops,
int64_t rank, int64_t task_id, int64_t max_run_times, int64_t rank,
int64_t task_id,
int64_t max_run_times,
int64_t max_slot_nums) int64_t max_slot_nums)
: ops_(ops), : ops_(ops),
role_(role), role_(role),
...@@ -110,8 +116,11 @@ TaskNode::TaskNode(int32_t role, ...@@ -110,8 +116,11 @@ TaskNode::TaskNode(int32_t role,
max_run_times_(max_run_times), max_run_times_(max_run_times),
max_slot_nums_(max_slot_nums) {} max_slot_nums_(max_slot_nums) {}
TaskNode::TaskNode(int32_t role, int64_t rank, int64_t task_id, TaskNode::TaskNode(int32_t role,
int64_t max_run_times, int64_t max_slot_nums) int64_t rank,
int64_t task_id,
int64_t max_run_times,
int64_t max_slot_nums)
: role_(role), : role_(role),
rank_(rank), rank_(rank),
task_id_(task_id), task_id_(task_id),
...@@ -139,14 +148,16 @@ std::string TaskNode::DebugString() const { ...@@ -139,14 +148,16 @@ std::string TaskNode::DebugString() const {
} }
void TaskNode::SetRunPerSteps(int64_t value) { void TaskNode::SetRunPerSteps(int64_t value) {
PADDLE_ENFORCE_GE(value, 1, PADDLE_ENFORCE_GE(value,
1,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"run_per_steps must >= 1, but received %ld", value)); "run_per_steps must >= 1, but received %ld", value));
run_per_steps_ = value; run_per_steps_ = value;
} }
void TaskNode::SetRunAtOffset(int64_t value) { void TaskNode::SetRunAtOffset(int64_t value) {
PADDLE_ENFORCE_GE(value, 0, PADDLE_ENFORCE_GE(value,
0,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"run_at_offset must >= 0, but received %ld", value)); "run_at_offset must >= 0, but received %ld", value));
run_at_offset_ = value; run_at_offset_ = value;
...@@ -154,7 +165,8 @@ void TaskNode::SetRunAtOffset(int64_t value) { ...@@ -154,7 +165,8 @@ void TaskNode::SetRunAtOffset(int64_t value) {
void TaskNode::SetReplyUpPerSteps(int64_t value) { void TaskNode::SetReplyUpPerSteps(int64_t value) {
PADDLE_ENFORCE_GE( PADDLE_ENFORCE_GE(
value, 1, value,
1,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"reply_up_per_steps must >= 1, but received %ld", value)); "reply_up_per_steps must >= 1, but received %ld", value));
reply_up_per_steps_ = value; reply_up_per_steps_ = value;
...@@ -162,7 +174,8 @@ void TaskNode::SetReplyUpPerSteps(int64_t value) { ...@@ -162,7 +174,8 @@ void TaskNode::SetReplyUpPerSteps(int64_t value) {
void TaskNode::SetSendDownPerSteps(int64_t value) { void TaskNode::SetSendDownPerSteps(int64_t value) {
PADDLE_ENFORCE_GE( PADDLE_ENFORCE_GE(
value, 1, value,
1,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"send_down_per_steps must >= 1, but received %ld", value)); "send_down_per_steps must >= 1, but received %ld", value));
send_down_per_steps_ = value; send_down_per_steps_ = value;
......
...@@ -33,16 +33,27 @@ class TaskNode final { ...@@ -33,16 +33,27 @@ class TaskNode final {
public: public:
using OperatorBase = paddle::framework::OperatorBase; using OperatorBase = paddle::framework::OperatorBase;
TaskNode(int64_t rank, int64_t task_id, int64_t max_run_times); TaskNode(int64_t rank, int64_t task_id, int64_t max_run_times);
TaskNode(int32_t role, int64_t rank, int64_t task_id, int64_t max_run_times, TaskNode(int32_t role,
int64_t rank,
int64_t task_id,
int64_t max_run_times,
int64_t max_slot_nums); int64_t max_slot_nums);
TaskNode(int32_t role, const std::vector<framework::OpDesc*>& op_descs, TaskNode(int32_t role,
int64_t rank, int64_t task_id, int64_t max_run_times, const std::vector<framework::OpDesc*>& op_descs,
int64_t rank,
int64_t task_id,
int64_t max_run_times,
int64_t max_slot_nums); int64_t max_slot_nums);
TaskNode(int32_t role, const std::vector<framework::OperatorBase*>& ops, TaskNode(int32_t role,
int64_t rank, int64_t task_id, int64_t max_run_times, const std::vector<framework::OperatorBase*>& ops,
int64_t rank,
int64_t task_id,
int64_t max_run_times,
int64_t max_slot_nums);
TaskNode(paddle::framework::ProgramDesc* program,
int64_t rank,
int64_t max_run_times,
int64_t max_slot_nums); int64_t max_slot_nums);
TaskNode(paddle::framework::ProgramDesc* program, int64_t rank,
int64_t max_run_times, int64_t max_slot_nums);
TaskNode(paddle::framework::ProgramDesc* program, int64_t rank); TaskNode(paddle::framework::ProgramDesc* program, int64_t rank);
~TaskNode() = default; ~TaskNode() = default;
......
...@@ -40,12 +40,13 @@ std::vector<framework::OperatorBase*> GetOps() { ...@@ -40,12 +40,13 @@ std::vector<framework::OperatorBase*> GetOps() {
attrs["shape"] = phi::vectorize<int>({2, 3}); attrs["shape"] = phi::vectorize<int>({2, 3});
attrs["value"] = 1.0f; attrs["value"] = 1.0f;
auto zero_op = framework::OpRegistry::CreateOp("fill_constant", {}, auto zero_op = framework::OpRegistry::CreateOp(
{{"Out", {"x"}}}, attrs); "fill_constant", {}, {{"Out", {"x"}}}, attrs);
auto op = framework::OpRegistry::CreateOp( auto op = framework::OpRegistry::CreateOp("elementwise_add",
"elementwise_add", {{"X", {"x"}}, {"Y", {"x"}}}, {{"Out", {"out"}}}, {{"X", {"x"}}, {"Y", {"x"}}},
framework::AttributeMap()); {{"Out", {"out"}}},
framework::AttributeMap());
// NOTE: don't delete // NOTE: don't delete
return {zero_op.release(), op.release()}; return {zero_op.release(), op.release()};
......
...@@ -54,14 +54,15 @@ TEST(AmplifierInterceptor, Amplifier) { ...@@ -54,14 +54,15 @@ TEST(AmplifierInterceptor, Amplifier) {
std::string carrier_id = "0"; std::string carrier_id = "0";
Carrier* carrier = Carrier* carrier =
GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id); GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
carrier->Init(0, {{SOURCE_ID, 0}, carrier->Init(0,
{0, 0}, {{SOURCE_ID, 0},
{1, 0}, {0, 0},
{2, 0}, {1, 0},
{3, 0}, {2, 0},
{4, 0}, {3, 0},
{5, 0}, {4, 0},
{SINK_ID, 0}}); {5, 0},
{SINK_ID, 0}});
MessageBus* msg_bus = GlobalVal<MessageBus>::Create(); MessageBus* msg_bus = GlobalVal<MessageBus>::Create();
msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "127.0.0.0:0"); msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "127.0.0.0:0");
......
...@@ -27,7 +27,8 @@ namespace distributed { ...@@ -27,7 +27,8 @@ namespace distributed {
int64_t GetBuffSize( int64_t GetBuffSize(
const std::map<std::pair<TaskNode*, TaskNode*>, int64_t> buffs, const std::map<std::pair<TaskNode*, TaskNode*>, int64_t> buffs,
TaskNode* from, TaskNode* to) { TaskNode* from,
TaskNode* to) {
if (buffs.find({from, to}) != buffs.end()) { if (buffs.find({from, to}) != buffs.end()) {
return buffs.at({from, to}); return buffs.at({from, to});
} }
......
...@@ -21,7 +21,8 @@ namespace distributed { ...@@ -21,7 +21,8 @@ namespace distributed {
std::vector<std::vector<uint64_t>> LayerWiseSampler::sample( std::vector<std::vector<uint64_t>> LayerWiseSampler::sample(
const std::vector<std::vector<uint64_t>>& user_inputs, const std::vector<std::vector<uint64_t>>& user_inputs,
const std::vector<uint64_t>& target_ids, bool with_hierarchy) { const std::vector<uint64_t>& target_ids,
bool with_hierarchy) {
auto input_num = target_ids.size(); auto input_num = target_ids.size();
auto user_feature_num = user_inputs[0].size(); auto user_feature_num = user_inputs[0].size();
std::vector<std::vector<uint64_t>> outputs( std::vector<std::vector<uint64_t>> outputs(
......
...@@ -38,7 +38,8 @@ class IndexSampler { ...@@ -38,7 +38,8 @@ class IndexSampler {
virtual void init_layerwise_conf( virtual void init_layerwise_conf(
const std::vector<uint16_t>& layer_sample_counts, const std::vector<uint16_t>& layer_sample_counts,
uint16_t start_sample_layer = 1, uint16_t seed = 0) {} uint16_t start_sample_layer = 1,
uint16_t seed = 0) {}
virtual void init_beamsearch_conf(const int64_t k) {} virtual void init_beamsearch_conf(const int64_t k) {}
virtual std::vector<std::vector<uint64_t>> sample( virtual std::vector<std::vector<uint64_t>> sample(
const std::vector<std::vector<uint64_t>>& user_inputs, const std::vector<std::vector<uint64_t>>& user_inputs,
...@@ -65,15 +66,18 @@ class LayerWiseSampler : public IndexSampler { ...@@ -65,15 +66,18 @@ class LayerWiseSampler : public IndexSampler {
start_sample_layer_ = start_sample_layer; start_sample_layer_ = start_sample_layer;
PADDLE_ENFORCE_GT( PADDLE_ENFORCE_GT(
start_sample_layer_, 0, start_sample_layer_,
0,
paddle::platform::errors::InvalidArgument( paddle::platform::errors::InvalidArgument(
"start sampler layer = [%d], it should greater than 0.", "start sampler layer = [%d], it should greater than 0.",
start_sample_layer_)); start_sample_layer_));
PADDLE_ENFORCE_LT(start_sample_layer_, tree_->Height(), PADDLE_ENFORCE_LT(start_sample_layer_,
tree_->Height(),
paddle::platform::errors::InvalidArgument( paddle::platform::errors::InvalidArgument(
"start sampler layer = [%d], it should less than " "start sampler layer = [%d], it should less than "
"max_layer, which is [%d].", "max_layer, which is [%d].",
start_sample_layer_, tree_->Height())); start_sample_layer_,
tree_->Height()));
size_t i = 0; size_t i = 0;
layer_counts_sum_ = 0; layer_counts_sum_ = 0;
...@@ -113,7 +117,8 @@ class LayerWiseSampler : public IndexSampler { ...@@ -113,7 +117,8 @@ class LayerWiseSampler : public IndexSampler {
} }
std::vector<std::vector<uint64_t>> sample( std::vector<std::vector<uint64_t>> sample(
const std::vector<std::vector<uint64_t>>& user_inputs, const std::vector<std::vector<uint64_t>>& user_inputs,
const std::vector<uint64_t>& target_ids, bool with_hierarchy) override; const std::vector<uint64_t>& target_ids,
bool with_hierarchy) override;
void sample_from_dataset( void sample_from_dataset(
const uint16_t sample_slot, const uint16_t sample_slot,
......
...@@ -29,7 +29,8 @@ int TreeIndex::Load(const std::string filename) { ...@@ -29,7 +29,8 @@ int TreeIndex::Load(const std::string filename) {
int err_no; int err_no;
auto fp = paddle::framework::fs_open_read(filename, &err_no, ""); auto fp = paddle::framework::fs_open_read(filename, &err_no, "");
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
fp, nullptr, fp,
nullptr,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Open file %s failed. Please check whether the file exists.", "Open file %s failed. Please check whether the file exists.",
filename)); filename));
...@@ -46,17 +47,21 @@ int TreeIndex::Load(const std::string filename) { ...@@ -46,17 +47,21 @@ int TreeIndex::Load(const std::string filename) {
size_t read_num = size_t read_num =
fread(const_cast<char*>(content.data()), 1, num, fp.get()); fread(const_cast<char*>(content.data()), 1, num, fp.get());
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
read_num, static_cast<size_t>(num), read_num,
static_cast<size_t>(num),
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Read from file: %s failed. Valid Format is " "Read from file: %s failed. Valid Format is "
"an integer representing the length of the following string, " "an integer representing the length of the following string, "
"and the string itself.We got an iteger[% d], " "and the string itself.We got an iteger[% d], "
"but the following string's length is [%d].", "but the following string's length is [%d].",
filename, num, read_num)); filename,
num,
read_num));
KVItem item; KVItem item;
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
item.ParseFromString(content), true, item.ParseFromString(content),
true,
platform::errors::InvalidArgument("Parse from file: %s failed. It's " platform::errors::InvalidArgument("Parse from file: %s failed. It's "
"content can't be parsed by KVItem.", "content can't be parsed by KVItem.",
filename)); filename));
...@@ -168,7 +173,8 @@ std::vector<uint64_t> TreeIndex::GetChildrenCodes(uint64_t ancestor, ...@@ -168,7 +173,8 @@ std::vector<uint64_t> TreeIndex::GetChildrenCodes(uint64_t ancestor,
std::vector<uint64_t> TreeIndex::GetTravelCodes(uint64_t id, int start_level) { std::vector<uint64_t> TreeIndex::GetTravelCodes(uint64_t id, int start_level) {
std::vector<uint64_t> res; std::vector<uint64_t> res;
PADDLE_ENFORCE_NE(id_codes_map_.find(id), id_codes_map_.end(), PADDLE_ENFORCE_NE(id_codes_map_.find(id),
id_codes_map_.end(),
paddle::platform::errors::InvalidArgument( paddle::platform::errors::InvalidArgument(
"id = %d doesn't exist in Tree.", id)); "id = %d doesn't exist in Tree.", id));
auto code = id_codes_map_.at(id); auto code = id_codes_map_.at(id);
......
...@@ -76,7 +76,8 @@ class IndexWrapper { ...@@ -76,7 +76,8 @@ class IndexWrapper {
void clear_tree() { tree_map.clear(); } void clear_tree() { tree_map.clear(); }
TreePtr get_tree_index(const std::string name) { TreePtr get_tree_index(const std::string name) {
PADDLE_ENFORCE_NE(tree_map.find(name), tree_map.end(), PADDLE_ENFORCE_NE(tree_map.find(name),
tree_map.end(),
paddle::platform::errors::InvalidArgument( paddle::platform::errors::InvalidArgument(
"tree [%s] doesn't exist. Please insert it firstly " "tree [%s] doesn't exist. Please insert it firstly "
"by API[\' insert_tree_index \'].", "by API[\' insert_tree_index \'].",
...@@ -91,11 +92,13 @@ class IndexWrapper { ...@@ -91,11 +92,13 @@ class IndexWrapper {
} }
TreePtr tree = std::make_shared<TreeIndex>(); TreePtr tree = std::make_shared<TreeIndex>();
int ret = tree->Load(tree_path); int ret = tree->Load(tree_path);
PADDLE_ENFORCE_EQ(ret, 0, PADDLE_ENFORCE_EQ(ret,
0,
paddle::platform::errors::InvalidArgument( paddle::platform::errors::InvalidArgument(
"Load tree[%s] from path[%s] failed. Please " "Load tree[%s] from path[%s] failed. Please "
"check whether the file exists.", "check whether the file exists.",
name, tree_path)); name,
tree_path));
tree_map.insert(std::pair<std::string, TreePtr>{name, tree}); tree_map.insert(std::pair<std::string, TreePtr>{name, tree});
} }
......
...@@ -57,7 +57,8 @@ class DownpourPsClientService : public PsService { ...@@ -57,7 +57,8 @@ class DownpourPsClientService : public PsService {
return 0; return 0;
} }
void service(::google::protobuf::RpcController *controller, void service(::google::protobuf::RpcController *controller,
const PsRequestMessage *request, PsResponseMessage *response, const PsRequestMessage *request,
PsResponseMessage *response,
::google::protobuf::Closure *done) override; ::google::protobuf::Closure *done) override;
protected: protected:
...@@ -162,13 +163,15 @@ class BrpcPsClient : public PSClient { ...@@ -162,13 +163,15 @@ class BrpcPsClient : public PSClient {
const std::string threshold) override; const std::string threshold) override;
std::future<int32_t> Load(const std::string &epoch, std::future<int32_t> Load(const std::string &epoch,
const std::string &mode) override; const std::string &mode) override;
std::future<int32_t> Load(uint32_t table_id, const std::string &epoch, std::future<int32_t> Load(uint32_t table_id,
const std::string &epoch,
const std::string &mode) override; const std::string &mode) override;
std::future<int32_t> Save(const std::string &epoch, std::future<int32_t> Save(const std::string &epoch,
const std::string &mode) override; const std::string &mode) override;
std::future<int32_t> Save(uint32_t table_id, const std::string &epoch, std::future<int32_t> Save(uint32_t table_id,
const std::string &epoch,
const std::string &mode) override; const std::string &mode) override;
std::future<int32_t> Clear() override; std::future<int32_t> Clear() override;
...@@ -182,7 +185,8 @@ class BrpcPsClient : public PSClient { ...@@ -182,7 +185,8 @@ class BrpcPsClient : public PSClient {
void FinalizeWorker() override; void FinalizeWorker() override;
virtual std::future<int32_t> PullDense(Region *regions, size_t region_num, virtual std::future<int32_t> PullDense(Region *regions,
size_t region_num,
size_t table_id); size_t table_id);
virtual std::future<int32_t> PushDenseParam(const Region *regions, virtual std::future<int32_t> PushDenseParam(const Region *regions,
...@@ -190,14 +194,18 @@ class BrpcPsClient : public PSClient { ...@@ -190,14 +194,18 @@ class BrpcPsClient : public PSClient {
size_t table_id); size_t table_id);
virtual std::future<int32_t> PushDense(const Region *regions, virtual std::future<int32_t> PushDense(const Region *regions,
size_t region_num, size_t table_id); size_t region_num,
size_t table_id);
void PushDenseTaskConsume(); void PushDenseTaskConsume();
virtual std::future<int32_t> PullSparse(float **select_values, virtual std::future<int32_t> PullSparse(float **select_values,
size_t table_id, const uint64_t *keys, size_t table_id,
size_t num, bool is_training); const uint64_t *keys,
size_t num,
bool is_training);
virtual std::future<int32_t> PullSparseParam(float **select_values, virtual std::future<int32_t> PullSparseParam(float **select_values,
size_t table_id, size_t table_id,
const uint64_t *keys, size_t num, const uint64_t *keys,
size_t num,
bool is_training); bool is_training);
virtual std::future<int32_t> PrintTableStat(uint32_t table_id); virtual std::future<int32_t> PrintTableStat(uint32_t table_id);
...@@ -213,7 +221,8 @@ class BrpcPsClient : public PSClient { ...@@ -213,7 +221,8 @@ class BrpcPsClient : public PSClient {
void *done); void *done);
virtual std::future<int32_t> Flush(); virtual std::future<int32_t> Flush();
std::future<int32_t> SendClient2ClientMsg(int msg_type, int to_client_id, std::future<int32_t> SendClient2ClientMsg(int msg_type,
int to_client_id,
const std::string &msg) override; const std::string &msg) override;
// for local save sparse // for local save sparse
...@@ -221,14 +230,19 @@ class BrpcPsClient : public PSClient { ...@@ -221,14 +230,19 @@ class BrpcPsClient : public PSClient {
const std::string &path); const std::string &path);
std::future<int32_t> CacheShuffle( std::future<int32_t> CacheShuffle(
uint32_t table_id, const std::string &path, const std::string &mode, uint32_t table_id,
const std::string &path,
const std::string &mode,
const std::string &cache_threshold) override; const std::string &cache_threshold) override;
std::future<int32_t> CacheShuffleMultiTable( std::future<int32_t> CacheShuffleMultiTable(
std::vector<int> tables, const std::string &path, const std::string &mode, std::vector<int> tables,
const std::string &path,
const std::string &mode,
const std::string &cache_threshold); const std::string &cache_threshold);
std::future<int32_t> SaveCache(uint32_t table_id, const std::string &path, std::future<int32_t> SaveCache(uint32_t table_id,
const std::string &path,
const std::string &mode) override; const std::string &mode) override;
std::future<int32_t> GetCacheThreshold(uint32_t table_id, std::future<int32_t> GetCacheThreshold(uint32_t table_id,
...@@ -256,10 +270,12 @@ class BrpcPsClient : public PSClient { ...@@ -256,10 +270,12 @@ class BrpcPsClient : public PSClient {
return dense_dim_total / shard_num + 1; return dense_dim_total / shard_num + 1;
} }
std::future<int32_t> SendCmd(uint32_t table_id, int cmd_id, std::future<int32_t> SendCmd(uint32_t table_id,
int cmd_id,
const std::vector<std::string> &param); const std::vector<std::string> &param);
std::future<int32_t> SendSaveCmd(uint32_t table_id, int cmd_id, std::future<int32_t> SendSaveCmd(uint32_t table_id,
int cmd_id,
const std::vector<std::string> &param); const std::vector<std::string> &param);
bool _running = false; bool _running = false;
...@@ -281,14 +297,19 @@ class BrpcPsClient : public PSClient { ...@@ -281,14 +297,19 @@ class BrpcPsClient : public PSClient {
std::thread _print_thread; std::thread _print_thread;
int PushSparseAsyncShardMerge( int PushSparseAsyncShardMerge(
std::vector<std::shared_ptr<SparseAsyncTask>> &task_list, // NOLINT std::vector<std::shared_ptr<SparseAsyncTask>> &task_list, // NOLINT
std::vector<int> &request_kv_num, int table_id, int shard_idx, // NOLINT std::vector<int> &request_kv_num,
int table_id,
int shard_idx, // NOLINT
ValueAccessor *accessor); ValueAccessor *accessor);
int PushSparseAsyncShardPush( int PushSparseAsyncShardPush(
std::vector<std::shared_ptr<SparseAsyncTask>> &task_list, // NOLINT std::vector<std::shared_ptr<SparseAsyncTask>> &task_list, // NOLINT
std::vector<int> &request_kv_num, int table_id, int shard_idx, // NOLINT std::vector<int> &request_kv_num,
DownpourBrpcClosure *closure, ValueAccessor *accessor); int table_id,
int shard_idx, // NOLINT
DownpourBrpcClosure *closure,
ValueAccessor *accessor);
SparseTaskPool _sparse_task_pool; SparseTaskPool _sparse_task_pool;
...@@ -304,18 +325,23 @@ class BrpcPsClient : public PSClient { ...@@ -304,18 +325,23 @@ class BrpcPsClient : public PSClient {
std::future<int32_t> PushSparseRawGradient(size_t table_id, std::future<int32_t> PushSparseRawGradient(size_t table_id,
const uint64_t *keys, const uint64_t *keys,
const float **update_values, const float **update_values,
size_t num, void *done) override; size_t num,
void *done) override;
std::future<int32_t> PushSparseRawGradientPartial(size_t table_id, std::future<int32_t> PushSparseRawGradientPartial(size_t table_id,
const uint64_t *keys, const uint64_t *keys,
const float **update_values, const float **update_values,
uint32_t num, void *done, uint32_t num,
void *done,
int pserver_idx) override; int pserver_idx) override;
std::future<int32_t> PushSparseParam(size_t table_id, const uint64_t *keys, std::future<int32_t> PushSparseParam(size_t table_id,
const float **update_values, size_t num, const uint64_t *keys,
const float **update_values,
size_t num,
void *done) override; void *done) override;
std::future<int32_t> PushSparse(size_t table_id, const uint64_t *keys, std::future<int32_t> PushSparse(size_t table_id,
const uint64_t *keys,
const float **update_values, const float **update_values,
size_t num) override; size_t num) override;
void PushSparseTaskConsume(); void PushSparseTaskConsume();
...@@ -324,7 +350,8 @@ class BrpcPsClient : public PSClient { ...@@ -324,7 +350,8 @@ class BrpcPsClient : public PSClient {
int32_t StartClientService(); int32_t StartClientService();
void PushDenseRawGradient(std::shared_ptr<DenseAsyncTask> &task, // NOLINT void PushDenseRawGradient(std::shared_ptr<DenseAsyncTask> &task, // NOLINT
float *total_send_data, size_t total_send_data_size, float *total_send_data,
size_t total_send_data_size,
DownpourBrpcClosure *closure); DownpourBrpcClosure *closure);
float _mae = 0; float _mae = 0;
float _mse = 0; float _mse = 0;
......
...@@ -30,11 +30,14 @@ class RpcController; ...@@ -30,11 +30,14 @@ class RpcController;
} // namespace protobuf } // namespace protobuf
} // namespace google } // namespace google
DEFINE_int32(pserver_timeout_ms_s2s, 10000, DEFINE_int32(pserver_timeout_ms_s2s,
10000,
"pserver request server timeout_ms"); "pserver request server timeout_ms");
DEFINE_int32(pserver_connect_timeout_ms_s2s, 10000, DEFINE_int32(pserver_connect_timeout_ms_s2s,
10000,
"pserver connect server timeout_ms"); "pserver connect server timeout_ms");
DEFINE_string(pserver_connection_type_s2s, "pooled", DEFINE_string(pserver_connection_type_s2s,
"pooled",
"pserver connection_type[pooled:single]"); "pserver connection_type[pooled:single]");
namespace paddle { namespace paddle {
...@@ -154,12 +157,13 @@ std::future<int32_t> BrpcPsServer::SendPServer2PServerMsg( ...@@ -154,12 +157,13 @@ std::future<int32_t> BrpcPsServer::SendPServer2PServerMsg(
closure->request(0)->set_table_id(0); closure->request(0)->set_table_id(0);
closure->request(0)->set_data(msg); closure->request(0)->set_data(msg);
PsService_Stub rpc_stub(_pserver_channels[to_pserver_id].get()); PsService_Stub rpc_stub(_pserver_channels[to_pserver_id].get());
rpc_stub.service(closure->cntl(0), closure->request(0), closure->response(0), rpc_stub.service(
closure); closure->cntl(0), closure->request(0), closure->response(0), closure);
return fut; return fut;
} }
int32_t BrpcPsServer::ReceiveFromPServer(int msg_type, int pserver_id, int32_t BrpcPsServer::ReceiveFromPServer(int msg_type,
int pserver_id,
const std::string &msg) { const std::string &msg) {
if (msg.length() == 0) { if (msg.length() == 0) {
LOG(WARNING) << "SERVER>>RESPONSE>>msg = 0 Finish S2S Response"; LOG(WARNING) << "SERVER>>RESPONSE>>msg = 0 Finish S2S Response";
...@@ -289,7 +293,8 @@ void BrpcPsService::service(google::protobuf::RpcController *cntl_base, ...@@ -289,7 +293,8 @@ void BrpcPsService::service(google::protobuf::RpcController *cntl_base,
} }
} }
int32_t BrpcPsService::PullDense(Table *table, const PsRequestMessage &request, int32_t BrpcPsService::PullDense(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response,
brpc::Controller *cntl) { brpc::Controller *cntl) {
platform::RecordEvent record_event( platform::RecordEvent record_event(
...@@ -297,7 +302,8 @@ int32_t BrpcPsService::PullDense(Table *table, const PsRequestMessage &request, ...@@ -297,7 +302,8 @@ int32_t BrpcPsService::PullDense(Table *table, const PsRequestMessage &request,
CHECK_TABLE_EXIST(table, request, response) CHECK_TABLE_EXIST(table, request, response)
if (request.params_size() < 1) { if (request.params_size() < 1) {
set_response_code( set_response_code(
response, -1, response,
-1,
"PsRequestMessage.datas is requeired at least 1 for num of dense"); "PsRequestMessage.datas is requeired at least 1 for num of dense");
return 0; return 0;
} }
...@@ -357,7 +363,8 @@ int32_t BrpcPsService::PushDenseParam(Table *table, ...@@ -357,7 +363,8 @@ int32_t BrpcPsService::PushDenseParam(Table *table,
return 0; return 0;
} }
int32_t BrpcPsService::PushDense(Table *table, const PsRequestMessage &request, int32_t BrpcPsService::PushDense(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response,
brpc::Controller *cntl) { brpc::Controller *cntl) {
platform::RecordEvent record_event( platform::RecordEvent record_event(
...@@ -391,13 +398,15 @@ int32_t BrpcPsService::PushDense(Table *table, const PsRequestMessage &request, ...@@ -391,13 +398,15 @@ int32_t BrpcPsService::PushDense(Table *table, const PsRequestMessage &request,
return 0; return 0;
} }
int32_t BrpcPsService::Barrier(Table *table, const PsRequestMessage &request, int32_t BrpcPsService::Barrier(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response,
brpc::Controller *cntl) { brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response) CHECK_TABLE_EXIST(table, request, response)
if (request.params_size() < 1) { if (request.params_size() < 1) {
set_response_code(response, -1, set_response_code(response,
-1,
"PsRequestMessage.params is requeired at " "PsRequestMessage.params is requeired at "
"least 1 for num of sparse_key"); "least 1 for num of sparse_key");
return 0; return 0;
...@@ -423,7 +432,8 @@ int32_t BrpcPsService::PushSparseParam(Table *table, ...@@ -423,7 +432,8 @@ int32_t BrpcPsService::PushSparseParam(Table *table,
return 0; return 0;
} }
if (request.params_size() < 1) { if (request.params_size() < 1) {
set_response_code(response, -1, set_response_code(response,
-1,
"PsRequestMessage.params is requeired at " "PsRequestMessage.params is requeired at "
"least 1 for num of sparse_key"); "least 1 for num of sparse_key");
return 0; return 0;
...@@ -482,7 +492,8 @@ int32_t BrpcPsService::PullGeoParam(Table *table, ...@@ -482,7 +492,8 @@ int32_t BrpcPsService::PullGeoParam(Table *table,
return 0; return 0;
} }
int32_t BrpcPsService::PullSparse(Table *table, const PsRequestMessage &request, int32_t BrpcPsService::PullSparse(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response,
brpc::Controller *cntl) { brpc::Controller *cntl) {
platform::RecordEvent record_event( platform::RecordEvent record_event(
...@@ -498,7 +509,8 @@ int32_t BrpcPsService::PullSparse(Table *table, const PsRequestMessage &request, ...@@ -498,7 +509,8 @@ int32_t BrpcPsService::PullSparse(Table *table, const PsRequestMessage &request,
} }
if (request.params_size() < 1) { if (request.params_size() < 1) {
set_response_code(response, -1, set_response_code(response,
-1,
"PsRequestMessage.params is requeired at " "PsRequestMessage.params is requeired at "
"least 1 for num of sparse_key"); "least 1 for num of sparse_key");
return 0; return 0;
...@@ -533,7 +545,8 @@ int32_t BrpcPsService::PullSparse(Table *table, const PsRequestMessage &request, ...@@ -533,7 +545,8 @@ int32_t BrpcPsService::PullSparse(Table *table, const PsRequestMessage &request,
return 0; return 0;
} }
int32_t BrpcPsService::PushSparse(Table *table, const PsRequestMessage &request, int32_t BrpcPsService::PushSparse(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response,
brpc::Controller *cntl) { brpc::Controller *cntl) {
platform::RecordEvent record_event( platform::RecordEvent record_event(
...@@ -545,7 +558,8 @@ int32_t BrpcPsService::PushSparse(Table *table, const PsRequestMessage &request, ...@@ -545,7 +558,8 @@ int32_t BrpcPsService::PushSparse(Table *table, const PsRequestMessage &request,
return 0; return 0;
} }
if (request.params_size() < 1) { if (request.params_size() < 1) {
set_response_code(response, -1, set_response_code(response,
-1,
"PsRequestMessage.params is requeired at " "PsRequestMessage.params is requeired at "
"least 1 for num of sparse_key"); "least 1 for num of sparse_key");
return 0; return 0;
...@@ -594,7 +608,8 @@ int32_t BrpcPsService::LoadOneTable(Table *table, ...@@ -594,7 +608,8 @@ int32_t BrpcPsService::LoadOneTable(Table *table,
CHECK_TABLE_EXIST(table, request, response) CHECK_TABLE_EXIST(table, request, response)
if (request.params_size() < 2) { if (request.params_size() < 2) {
set_response_code( set_response_code(
response, -1, response,
-1,
"PsRequestMessage.datas is requeired at least 2 for path & load_param"); "PsRequestMessage.datas is requeired at least 2 for path & load_param");
return -1; return -1;
} }
...@@ -626,7 +641,8 @@ int32_t BrpcPsService::SaveOneTable(Table *table, ...@@ -626,7 +641,8 @@ int32_t BrpcPsService::SaveOneTable(Table *table,
CHECK_TABLE_EXIST(table, request, response) CHECK_TABLE_EXIST(table, request, response)
if (request.params_size() < 2) { if (request.params_size() < 2) {
set_response_code( set_response_code(
response, -1, response,
-1,
"PsRequestMessage.datas is requeired at least 2, path&mode"); "PsRequestMessage.datas is requeired at least 2, path&mode");
return -1; return -1;
} }
...@@ -667,7 +683,8 @@ int32_t BrpcPsService::SaveCacheTable(Table *table, ...@@ -667,7 +683,8 @@ int32_t BrpcPsService::SaveCacheTable(Table *table,
CHECK_TABLE_EXIST(table, request, response) CHECK_TABLE_EXIST(table, request, response)
if (request.params_size() < 2) { if (request.params_size() < 2) {
set_response_code( set_response_code(
response, -1, response,
-1,
"PsRequestMessage.datas is requeired at least 3, path&mode"); "PsRequestMessage.datas is requeired at least 3, path&mode");
return -1; return -1;
} }
...@@ -676,8 +693,8 @@ int32_t BrpcPsService::SaveCacheTable(Table *table, ...@@ -676,8 +693,8 @@ int32_t BrpcPsService::SaveCacheTable(Table *table,
// if (_server->_shuffled_ins->size() <= 0) { // if (_server->_shuffled_ins->size() <= 0) {
// LOG(WARNING) << "shuffled ins size <= 0"; // LOG(WARNING) << "shuffled ins size <= 0";
//} //}
feasign_size = table->SaveCache(request.params(0), request.params(1), feasign_size = table->SaveCache(
_server->_shuffled_ins); request.params(0), request.params(1), _server->_shuffled_ins);
if (feasign_size < 0) { if (feasign_size < 0) {
set_response_code(response, -1, "table save failed"); set_response_code(response, -1, "table save failed");
return -1; return -1;
...@@ -692,7 +709,8 @@ int32_t BrpcPsService::CacheShuffle(Table *table, ...@@ -692,7 +709,8 @@ int32_t BrpcPsService::CacheShuffle(Table *table,
// start cache shuffle // start cache shuffle
CHECK_TABLE_EXIST(table, request, response) CHECK_TABLE_EXIST(table, request, response)
if (request.params_size() < 3) { if (request.params_size() < 3) {
set_response_code(response, -1, set_response_code(response,
-1,
"PsRequestMessage.datas is requeired at least 3, " "PsRequestMessage.datas is requeired at least 3, "
"path&mode&cache_threshold"); "path&mode&cache_threshold");
return -1; return -1;
...@@ -704,9 +722,10 @@ int32_t BrpcPsService::CacheShuffle(Table *table, ...@@ -704,9 +722,10 @@ int32_t BrpcPsService::CacheShuffle(Table *table,
// std::string>>(); // std::string>>();
// shuffled_ins->set_block_size(80000); // shuffled_ins->set_block_size(80000);
_server->StartS2S(); _server->StartS2S();
std::function<std::future<int32_t>(int msg_type, int to_pserver_id, std::function<std::future<int32_t>(
const std::string &msg)> int msg_type, int to_pserver_id, const std::string &msg)>
send_msg_func = [this](int msg_type, int to_pserver_id, send_msg_func = [this](int msg_type,
int to_pserver_id,
const std::string &msg) -> std::future<int32_t> { const std::string &msg) -> std::future<int32_t> {
return this->_server->SendPServer2PServerMsg(msg_type, to_pserver_id, msg); return this->_server->SendPServer2PServerMsg(msg_type, to_pserver_id, msg);
}; };
...@@ -721,8 +740,12 @@ int32_t BrpcPsService::CacheShuffle(Table *table, ...@@ -721,8 +740,12 @@ int32_t BrpcPsService::CacheShuffle(Table *table,
table_ptrs.push_back(table); table_ptrs.push_back(table);
} }
table->CacheShuffle(request.params(0), request.params(1), cache_threshold, table->CacheShuffle(request.params(0),
send_msg_func, _server->_shuffled_ins, table_ptrs); request.params(1),
cache_threshold,
send_msg_func,
_server->_shuffled_ins,
table_ptrs);
return 0; return 0;
} }
...@@ -751,7 +774,8 @@ int32_t BrpcPsService::ShrinkTable(Table *table, ...@@ -751,7 +774,8 @@ int32_t BrpcPsService::ShrinkTable(Table *table,
CHECK_TABLE_EXIST(table, request, response) CHECK_TABLE_EXIST(table, request, response)
if (request.params_size() < 1) { if (request.params_size() < 1) {
set_response_code( set_response_code(
response, -1, response,
-1,
"PsRequestMessage.datas is requeired at least 1, threshold"); "PsRequestMessage.datas is requeired at least 1, threshold");
return -1; return -1;
} }
...@@ -787,7 +811,8 @@ int32_t BrpcPsService::ClearAllTable(Table *table, ...@@ -787,7 +811,8 @@ int32_t BrpcPsService::ClearAllTable(Table *table,
return 0; return 0;
} }
int32_t BrpcPsService::StopServer(Table *table, const PsRequestMessage &request, int32_t BrpcPsService::StopServer(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response,
brpc::Controller *cntl) { brpc::Controller *cntl) {
auto *p_server = _server; auto *p_server = _server;
......
...@@ -56,7 +56,8 @@ class BrpcPsServer : public PSServer { ...@@ -56,7 +56,8 @@ class BrpcPsServer : public PSServer {
virtual int32_t StartS2S() override; virtual int32_t StartS2S() override;
virtual ::std::future<int32_t> SendPServer2PServerMsg( virtual ::std::future<int32_t> SendPServer2PServerMsg(
int msg_type, int to_pserver_id, const std::string &msg) override; int msg_type, int to_pserver_id, const std::string &msg) override;
virtual int32_t ReceiveFromPServer(int msg_type, int pserver_id, virtual int32_t ReceiveFromPServer(int msg_type,
int pserver_id,
const std::string &msg) override; const std::string &msg) override;
private: private:
...@@ -72,7 +73,9 @@ class BrpcPsServer : public PSServer { ...@@ -72,7 +73,9 @@ class BrpcPsServer : public PSServer {
class BrpcPsService; class BrpcPsService;
typedef int32_t (BrpcPsService::*serviceHandlerFunc)( typedef int32_t (BrpcPsService::*serviceHandlerFunc)(
Table *table, const PsRequestMessage &request, PsResponseMessage &response, Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl); brpc::Controller *cntl);
class BrpcPsService : public PsBaseService { class BrpcPsService : public PsBaseService {
...@@ -86,56 +89,101 @@ class BrpcPsService : public PsBaseService { ...@@ -86,56 +89,101 @@ class BrpcPsService : public PsBaseService {
private: private:
int32_t InitializeShardInfo(); int32_t InitializeShardInfo();
int32_t PullDense(Table *table, const PsRequestMessage &request, int32_t PullDense(Table *table,
PsResponseMessage &response, brpc::Controller *cntl); const PsRequestMessage &request,
int32_t PushDense(Table *table, const PsRequestMessage &request, PsResponseMessage &response,
PsResponseMessage &response, brpc::Controller *cntl); brpc::Controller *cntl);
int32_t PushDenseParam(Table *table, const PsRequestMessage &request, int32_t PushDense(Table *table,
PsResponseMessage &response, brpc::Controller *cntl); const PsRequestMessage &request,
int32_t PushSparseParam(Table *table, const PsRequestMessage &request, PsResponseMessage &response,
PsResponseMessage &response, brpc::Controller *cntl); brpc::Controller *cntl);
int32_t PullSparse(Table *table, const PsRequestMessage &request, int32_t PushDenseParam(Table *table,
PsResponseMessage &response, brpc::Controller *cntl); const PsRequestMessage &request,
int32_t PullGeoParam(Table *table, const PsRequestMessage &request, PsResponseMessage &response,
PsResponseMessage &response, brpc::Controller *cntl); brpc::Controller *cntl);
int32_t Barrier(Table *table, const PsRequestMessage &request, int32_t PushSparseParam(Table *table,
PsResponseMessage &response, brpc::Controller *cntl); const PsRequestMessage &request,
int32_t PushSparse(Table *table, const PsRequestMessage &request, PsResponseMessage &response,
PsResponseMessage &response, brpc::Controller *cntl); brpc::Controller *cntl);
int32_t LoadOneTable(Table *table, const PsRequestMessage &request, int32_t PullSparse(Table *table,
PsResponseMessage &response, brpc::Controller *cntl); const PsRequestMessage &request,
int32_t LoadAllTable(Table *table, const PsRequestMessage &request, PsResponseMessage &response,
PsResponseMessage &response, brpc::Controller *cntl); brpc::Controller *cntl);
int32_t SaveOneTable(Table *table, const PsRequestMessage &request, int32_t PullGeoParam(Table *table,
PsResponseMessage &response, brpc::Controller *cntl); const PsRequestMessage &request,
int32_t SaveAllTable(Table *table, const PsRequestMessage &request, PsResponseMessage &response,
PsResponseMessage &response, brpc::Controller *cntl); brpc::Controller *cntl);
int32_t ShrinkTable(Table *table, const PsRequestMessage &request, int32_t Barrier(Table *table,
PsResponseMessage &response, brpc::Controller *cntl); const PsRequestMessage &request,
int32_t ClearOneTable(Table *table, const PsRequestMessage &request, PsResponseMessage &response,
PsResponseMessage &response, brpc::Controller *cntl); brpc::Controller *cntl);
int32_t ClearAllTable(Table *table, const PsRequestMessage &request, int32_t PushSparse(Table *table,
PsResponseMessage &response, brpc::Controller *cntl); const PsRequestMessage &request,
int32_t StopServer(Table *table, const PsRequestMessage &request, PsResponseMessage &response,
PsResponseMessage &response, brpc::Controller *cntl); brpc::Controller *cntl);
int32_t StartProfiler(Table *table, const PsRequestMessage &request, int32_t LoadOneTable(Table *table,
PsResponseMessage &response, brpc::Controller *cntl); const PsRequestMessage &request,
int32_t StopProfiler(Table *table, const PsRequestMessage &request, PsResponseMessage &response,
PsResponseMessage &response, brpc::Controller *cntl); brpc::Controller *cntl);
int32_t LoadAllTable(Table *table,
int32_t PrintTableStat(Table *table, const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl); PsResponseMessage &response,
brpc::Controller *cntl);
int32_t PushGlobalStep(Table *table, const PsRequestMessage &request, int32_t SaveOneTable(Table *table,
PsResponseMessage &response, brpc::Controller *cntl); const PsRequestMessage &request,
PsResponseMessage &response,
int32_t CacheShuffle(Table *table, const PsRequestMessage &request, brpc::Controller *cntl);
PsResponseMessage &response, brpc::Controller *cntl); int32_t SaveAllTable(Table *table,
const PsRequestMessage &request,
int32_t SaveCacheTable(Table *table, const PsRequestMessage &request, PsResponseMessage &response,
PsResponseMessage &response, brpc::Controller *cntl); brpc::Controller *cntl);
int32_t ShrinkTable(Table *table,
int32_t GetCacheThreshold(Table *table, const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
int32_t ClearOneTable(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
int32_t ClearAllTable(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
int32_t StopServer(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
int32_t StartProfiler(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
int32_t StopProfiler(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
int32_t PrintTableStat(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
int32_t PushGlobalStep(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
int32_t CacheShuffle(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
int32_t SaveCacheTable(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
int32_t GetCacheThreshold(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response,
brpc::Controller *cntl); brpc::Controller *cntl);
......
...@@ -56,8 +56,10 @@ void SerializeToMultiVarMsgAndIOBuf( ...@@ -56,8 +56,10 @@ void SerializeToMultiVarMsgAndIOBuf(
const std::string& message_name, const std::string& message_name,
const std::vector<std::string>& send_var_name_val, const std::vector<std::string>& send_var_name_val,
const std::vector<std::string>& recv_var_name_val, const std::vector<std::string>& recv_var_name_val,
const platform::DeviceContext& ctx, const framework::Scope* scope, const platform::DeviceContext& ctx,
MultiVarMsg* request, butil::IOBuf* iobuf) { const framework::Scope* scope,
MultiVarMsg* request,
butil::IOBuf* iobuf) {
// 1. message_name // 1. message_name
request->set_message_name(message_name); request->set_message_name(message_name);
...@@ -87,7 +89,8 @@ void SerializeToMultiVarMsgAndIOBuf( ...@@ -87,7 +89,8 @@ void SerializeToMultiVarMsgAndIOBuf(
} }
void SerializeLodTensor(framework::Variable* var, void SerializeLodTensor(framework::Variable* var,
const platform::DeviceContext& ctx, VarMsg* var_msg, const platform::DeviceContext& ctx,
VarMsg* var_msg,
butil::IOBuf* iobuf) { butil::IOBuf* iobuf) {
auto* tensor = var->GetMutable<framework::LoDTensor>(); auto* tensor = var->GetMutable<framework::LoDTensor>();
var_msg->set_type(::paddle::distributed::LOD_TENSOR); var_msg->set_type(::paddle::distributed::LOD_TENSOR);
...@@ -119,7 +122,10 @@ void SerializeLodTensor(framework::Variable* var, ...@@ -119,7 +122,10 @@ void SerializeLodTensor(framework::Variable* var,
auto stream = auto stream =
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream(); reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream();
memory::Copy( memory::Copy(
platform::CPUPlace(), temp_ptr, tensor->place(), tensor->data(), platform::CPUPlace(),
temp_ptr,
tensor->place(),
tensor->data(),
tensor->numel() * framework::SizeOfType( tensor->numel() * framework::SizeOfType(
framework::TransToProtoVarType(tensor->dtype())), framework::TransToProtoVarType(tensor->dtype())),
stream); stream);
...@@ -132,7 +138,8 @@ void SerializeLodTensor(framework::Variable* var, ...@@ -132,7 +138,8 @@ void SerializeLodTensor(framework::Variable* var,
} }
void SerializeSelectedRows(framework::Variable* var, void SerializeSelectedRows(framework::Variable* var,
const platform::DeviceContext& ctx, VarMsg* var_msg, const platform::DeviceContext& ctx,
VarMsg* var_msg,
butil::IOBuf* iobuf) { butil::IOBuf* iobuf) {
phi::SelectedRows* slr = var->GetMutable<phi::SelectedRows>(); phi::SelectedRows* slr = var->GetMutable<phi::SelectedRows>();
auto* tensor = slr->mutable_value(); auto* tensor = slr->mutable_value();
...@@ -164,7 +171,10 @@ void SerializeSelectedRows(framework::Variable* var, ...@@ -164,7 +171,10 @@ void SerializeSelectedRows(framework::Variable* var,
auto stream = auto stream =
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream(); reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream();
memory::Copy( memory::Copy(
platform::CPUPlace(), temp_ptr, tensor->place(), tensor->data(), platform::CPUPlace(),
temp_ptr,
tensor->place(),
tensor->data(),
tensor->numel() * framework::SizeOfType( tensor->numel() * framework::SizeOfType(
framework::TransToProtoVarType(tensor->dtype())), framework::TransToProtoVarType(tensor->dtype())),
stream); stream);
...@@ -204,7 +214,8 @@ void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg& multi_msg, ...@@ -204,7 +214,8 @@ void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg& multi_msg,
++recv_var_index) { ++recv_var_index) {
const auto& msg = multi_msg.var_messages(recv_var_index); const auto& msg = multi_msg.var_messages(recv_var_index);
auto* var = scope->FindVar(msg.varname()); auto* var = scope->FindVar(msg.varname());
PADDLE_ENFORCE_NE(var, nullptr, PADDLE_ENFORCE_NE(var,
nullptr,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Not find variable %s in scope.", msg.varname())); "Not find variable %s in scope.", msg.varname()));
if (msg.type() == ::paddle::distributed::LOD_TENSOR) { if (msg.type() == ::paddle::distributed::LOD_TENSOR) {
...@@ -215,7 +226,8 @@ void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg& multi_msg, ...@@ -215,7 +226,8 @@ void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg& multi_msg,
} }
} }
void DeserializeLodTensor(framework::Variable* var, const VarMsg& msg, void DeserializeLodTensor(framework::Variable* var,
const VarMsg& msg,
butil::IOBufBytesIterator& io_buffer_itr, // NOLINT butil::IOBufBytesIterator& io_buffer_itr, // NOLINT
const platform::DeviceContext& ctx) { const platform::DeviceContext& ctx) {
const auto place = ctx.GetPlace(); const auto place = ctx.GetPlace();
...@@ -255,16 +267,20 @@ void DeserializeLodTensor(framework::Variable* var, const VarMsg& msg, ...@@ -255,16 +267,20 @@ void DeserializeLodTensor(framework::Variable* var, const VarMsg& msg,
io_buffer_itr.copy_and_forward((void*)temp_ptr, data_len); // NOLINT io_buffer_itr.copy_and_forward((void*)temp_ptr, data_len); // NOLINT
auto stream = auto stream =
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream(); reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream();
memory::Copy( memory::Copy(place,
place, tensor_data, platform::CPUPlace(), (void*)temp_ptr, // NOLINT tensor_data,
tensor->numel() * framework::DataTypeSize(tensor->dtype()), stream); platform::CPUPlace(),
(void*)temp_ptr, // NOLINT
tensor->numel() * framework::DataTypeSize(tensor->dtype()),
stream);
delete[] temp_ptr; delete[] temp_ptr;
#endif #endif
} }
} }
void DeserializeSelectedRows( void DeserializeSelectedRows(
framework::Variable* var, const VarMsg& msg, framework::Variable* var,
const VarMsg& msg,
butil::IOBufBytesIterator& io_buffer_itr, // NOLINT butil::IOBufBytesIterator& io_buffer_itr, // NOLINT
const platform::DeviceContext& ctx) { const platform::DeviceContext& ctx) {
const auto place = ctx.GetPlace(); const auto place = ctx.GetPlace();
...@@ -297,7 +313,10 @@ void DeserializeSelectedRows( ...@@ -297,7 +313,10 @@ void DeserializeSelectedRows(
io_buffer_itr.copy_and_forward(temp_ptr, data_len); io_buffer_itr.copy_and_forward(temp_ptr, data_len);
auto stream = auto stream =
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream(); reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream();
memory::Copy(place, tensor_data, platform::CPUPlace(), temp_ptr, memory::Copy(place,
tensor_data,
platform::CPUPlace(),
temp_ptr,
tensor->numel() * framework::DataTypeSize(tensor->dtype()), tensor->numel() * framework::DataTypeSize(tensor->dtype()),
stream); stream);
delete[] temp_ptr; delete[] temp_ptr;
......
...@@ -55,15 +55,19 @@ void SerializeToMultiVarMsgAndIOBuf( ...@@ -55,15 +55,19 @@ void SerializeToMultiVarMsgAndIOBuf(
const std::string& message_name, const std::string& message_name,
const std::vector<std::string>& send_var_name_val, const std::vector<std::string>& send_var_name_val,
const std::vector<std::string>& recv_var_name_val, const std::vector<std::string>& recv_var_name_val,
const platform::DeviceContext& ctx, const framework::Scope* scope, const platform::DeviceContext& ctx,
MultiVarMsg* var_msg, butil::IOBuf* iobuf); const framework::Scope* scope,
MultiVarMsg* var_msg,
butil::IOBuf* iobuf);
void SerializeLodTensor(framework::Variable* var, void SerializeLodTensor(framework::Variable* var,
const platform::DeviceContext& ctx, VarMsg* var_msg, const platform::DeviceContext& ctx,
VarMsg* var_msg,
butil::IOBuf* iobuf); butil::IOBuf* iobuf);
void SerializeSelectedRows(framework::Variable* var, void SerializeSelectedRows(framework::Variable* var,
const platform::DeviceContext& ctx, VarMsg* request, const platform::DeviceContext& ctx,
VarMsg* request,
butil::IOBuf* iobuf); butil::IOBuf* iobuf);
// Deserialize for Server // Deserialize for Server
...@@ -78,11 +82,13 @@ void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg& multi_msg, ...@@ -78,11 +82,13 @@ void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg& multi_msg,
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx,
const framework::Scope* scope); const framework::Scope* scope);
void DeserializeLodTensor(framework::Variable* var, const VarMsg& msg, void DeserializeLodTensor(framework::Variable* var,
const VarMsg& msg,
butil::IOBufBytesIterator& iobuf, // NOLINT butil::IOBufBytesIterator& iobuf, // NOLINT
const platform::DeviceContext& ctx); const platform::DeviceContext& ctx);
void DeserializeSelectedRows(framework::Variable* var, const VarMsg& msg, void DeserializeSelectedRows(framework::Variable* var,
const VarMsg& msg,
butil::IOBufBytesIterator& iobuf, // NOLINT butil::IOBufBytesIterator& iobuf, // NOLINT
const platform::DeviceContext& ctx); const platform::DeviceContext& ctx);
......
...@@ -88,7 +88,8 @@ int Communicator::SetClients(std::vector<uint64_t> &host_sign_list) { ...@@ -88,7 +88,8 @@ int Communicator::SetClients(std::vector<uint64_t> &host_sign_list) {
} }
void Communicator::RpcRecvDense(const std::vector<std::string> &varnames, void Communicator::RpcRecvDense(const std::vector<std::string> &varnames,
int table_id, Scope *scope) { int table_id,
Scope *scope) {
platform::RecordEvent record_event("Communicator->RpcRecvDense", platform::RecordEvent record_event("Communicator->RpcRecvDense",
platform::TracerEventType::Communication, platform::TracerEventType::Communication,
1); 1);
...@@ -146,7 +147,8 @@ void Communicator::RpcRecvDense(const std::vector<std::string> &varnames, ...@@ -146,7 +147,8 @@ void Communicator::RpcRecvDense(const std::vector<std::string> &varnames,
} }
void Communicator::RpcSendDenseParam(const std::vector<std::string> &varnames, void Communicator::RpcSendDenseParam(const std::vector<std::string> &varnames,
int table_id, const Scope &scope) { int table_id,
const Scope &scope) {
platform::RecordEvent record_event("Communicator->RpcSendDenseParam", platform::RecordEvent record_event("Communicator->RpcSendDenseParam",
platform::TracerEventType::Communication, platform::TracerEventType::Communication,
1); 1);
...@@ -224,13 +226,14 @@ void Communicator::RpcSendDense(const CommContext &ctx, const Scope &scope) { ...@@ -224,13 +226,14 @@ void Communicator::RpcSendDense(const CommContext &ctx, const Scope &scope) {
closure->set_promise_value(ret); closure->set_promise_value(ret);
--_async_call_num; --_async_call_num;
}); });
auto status = _worker_ptr->PushDenseRawGradient(table_id, data, auto status = _worker_ptr->PushDenseRawGradient(
dense_data->size(), closure); table_id, data, dense_data->size(), closure);
status.wait(); status.wait();
return; return;
} }
void Communicator::RpcSendSparseParam(const std::string &varname, int table_id, void Communicator::RpcSendSparseParam(const std::string &varname,
int table_id,
const Scope &scope) { const Scope &scope) {
platform::RecordEvent record_event("Communicator->RpcSendSparseParam", platform::RecordEvent record_event("Communicator->RpcSendSparseParam",
platform::TracerEventType::Communication, platform::TracerEventType::Communication,
...@@ -262,14 +265,17 @@ void Communicator::RpcSendSparseParam(const std::string &varname, int table_id, ...@@ -262,14 +265,17 @@ void Communicator::RpcSendSparseParam(const std::string &varname, int table_id,
} }
closure->set_promise_value(ret); closure->set_promise_value(ret);
}); });
auto status = _worker_ptr->PushSparseParam(table_id, sparse_push_keys.data(), auto status = _worker_ptr->PushSparseParam(table_id,
sparse_push_keys.data(),
(const float **)push_g_vec.data(), (const float **)push_g_vec.data(),
sparse_push_keys.size(), closure); sparse_push_keys.size(),
closure);
status.wait(); status.wait();
return; return;
} }
void Communicator::RpcSendSparse(const std::string &var_name, int table_id, void Communicator::RpcSendSparse(const std::string &var_name,
int table_id,
const Scope &scope) { const Scope &scope) {
platform::RecordEvent record_event("Communicator->RpcSendSparse", platform::RecordEvent record_event("Communicator->RpcSendSparse",
platform::TracerEventType::Communication, platform::TracerEventType::Communication,
...@@ -281,7 +287,8 @@ void Communicator::RpcSendSparse(const std::string &var_name, int table_id, ...@@ -281,7 +287,8 @@ void Communicator::RpcSendSparse(const std::string &var_name, int table_id,
auto *send_var = scope.FindVar(var_name); auto *send_var = scope.FindVar(var_name);
auto *tensor = send_var->GetMutable<phi::SelectedRows>(); auto *tensor = send_var->GetMutable<phi::SelectedRows>();
auto dim = tensor->value().dims()[1]; auto dim = tensor->value().dims()[1];
std::transform(tensor->rows().begin(), tensor->rows().end(), std::transform(tensor->rows().begin(),
tensor->rows().end(),
std::back_inserter(sparse_push_keys), std::back_inserter(sparse_push_keys),
[&](int64_t id) { return static_cast<uint64_t>(id); }); [&](int64_t id) { return static_cast<uint64_t>(id); });
...@@ -315,14 +322,18 @@ void Communicator::RpcSendSparse(const std::string &var_name, int table_id, ...@@ -315,14 +322,18 @@ void Communicator::RpcSendSparse(const std::string &var_name, int table_id,
closure->set_promise_value(ret); closure->set_promise_value(ret);
--_async_call_num; --_async_call_num;
}); });
auto status = _worker_ptr->PushSparseRawGradient( auto status =
table_id, sparse_push_keys.data(), (const float **)push_g_vec.data(), _worker_ptr->PushSparseRawGradient(table_id,
sparse_push_keys.size(), closure); sparse_push_keys.data(),
(const float **)push_g_vec.data(),
sparse_push_keys.size(),
closure);
status.wait(); status.wait();
return; return;
} }
void Communicator::RpcRecvSparse(const std::string &varname, int table_id, void Communicator::RpcRecvSparse(const std::string &varname,
int table_id,
Scope *scope) { Scope *scope) {
platform::RecordEvent record_event("Communicator->RpcRecvSparse", platform::RecordEvent record_event("Communicator->RpcRecvSparse",
platform::TracerEventType::Communication, platform::TracerEventType::Communication,
...@@ -342,9 +353,11 @@ void Communicator::RpcRecvSparse(const std::string &varname, int table_id, ...@@ -342,9 +353,11 @@ void Communicator::RpcRecvSparse(const std::string &varname, int table_id,
bool training = true; bool training = true;
auto status = _worker_ptr->PullSparseParam( auto status = _worker_ptr->PullSparseParam((float **)push_g_vec.data(),
(float **)push_g_vec.data(), table_id, // NOLINT table_id, // NOLINT
sparse_push_keys.data(), sparse_push_keys.size(), training); sparse_push_keys.data(),
sparse_push_keys.size(),
training);
status.wait(); status.wait();
return; return;
} }
...@@ -389,7 +402,8 @@ void Communicator::RpcProfilerControl() { ...@@ -389,7 +402,8 @@ void Communicator::RpcProfilerControl() {
} }
} }
void Communicator::SendGlobalStep(const CommContext &ctx, int batches, void Communicator::SendGlobalStep(const CommContext &ctx,
int batches,
Scope *send_scope) { Scope *send_scope) {
if (batches == 0) { if (batches == 0) {
return; return;
...@@ -522,7 +536,8 @@ void AsyncCommunicator::SendByCommunicator() { ...@@ -522,7 +536,8 @@ void AsyncCommunicator::SendByCommunicator() {
SendGlobalStep(ctx, merged_var_num, send_scope_.get()); SendGlobalStep(ctx, merged_var_num, send_scope_.get());
} else if (ctx.is_sparse) { } else if (ctx.is_sparse) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
varnames.size(), 1, varnames.size(),
1,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"sparse variables can only be merged by one variables")); "sparse variables can only be merged by one variables"));
RpcSendSparse(varnames[0], table_id, *send_scope_); RpcSendSparse(varnames[0], table_id, *send_scope_);
...@@ -569,9 +584,13 @@ void AsyncCommunicator::MainThread() { ...@@ -569,9 +584,13 @@ void AsyncCommunicator::MainThread() {
} }
void AsyncCommunicator::PullSparseToTensorSync( void AsyncCommunicator::PullSparseToTensorSync(
const uint64_t table_id, int fea_dim, uint64_t padding_id, const uint64_t table_id,
platform::Place place, bool is_training, int fea_dim,
std::vector<const LoDTensor *> *inputs, std::vector<LoDTensor *> *outputs) { uint64_t padding_id,
platform::Place place,
bool is_training,
std::vector<const LoDTensor *> *inputs,
std::vector<LoDTensor *> *outputs) {
std::vector<uint64_t> fea_keys; std::vector<uint64_t> fea_keys;
std::vector<float *> pull_result_ptr; std::vector<float *> pull_result_ptr;
fea_keys.reserve(MAX_FEASIGN_NUM / 100); fea_keys.reserve(MAX_FEASIGN_NUM / 100);
...@@ -598,7 +617,8 @@ void AsyncCommunicator::PullSparseToTensorSync( ...@@ -598,7 +617,8 @@ void AsyncCommunicator::PullSparseToTensorSync(
} }
uint64_t real_id = static_cast<uint64_t>(ids[i]); uint64_t real_id = static_cast<uint64_t>(ids[i]);
if (real_id == padding_id) { if (real_id == padding_id) {
memcpy(output_data + output_len, init_value.data(), memcpy(output_data + output_len,
init_value.data(),
sizeof(float) * fea_dim); sizeof(float) * fea_dim);
continue; continue;
} }
...@@ -606,9 +626,11 @@ void AsyncCommunicator::PullSparseToTensorSync( ...@@ -606,9 +626,11 @@ void AsyncCommunicator::PullSparseToTensorSync(
pull_result_ptr.push_back(output_data + output_len); pull_result_ptr.push_back(output_data + output_len);
} }
} }
auto status = auto status = _worker_ptr->PullSparse(pull_result_ptr.data(),
_worker_ptr->PullSparse(pull_result_ptr.data(), table_id, fea_keys.data(), table_id,
fea_keys.size(), is_training); fea_keys.data(),
fea_keys.size(),
is_training);
status.wait(); status.wait();
auto ret = status.get(); auto ret = status.get();
if (ret != 0) { if (ret != 0) {
...@@ -618,9 +640,13 @@ void AsyncCommunicator::PullSparseToTensorSync( ...@@ -618,9 +640,13 @@ void AsyncCommunicator::PullSparseToTensorSync(
} }
void AsyncCommunicator::PushSparseFromTensorAsync( void AsyncCommunicator::PushSparseFromTensorAsync(
const uint64_t table_id, int fea_dim, uint64_t padding_id, const uint64_t table_id,
platform::Place place, std::vector<const framework::LoDTensor *> *inputs, int fea_dim,
const framework::LoDTensor *shows, const framework::LoDTensor *clks, uint64_t padding_id,
platform::Place place,
std::vector<const framework::LoDTensor *> *inputs,
const framework::LoDTensor *shows,
const framework::LoDTensor *clks,
std::vector<framework::LoDTensor *> *outputs) { std::vector<framework::LoDTensor *> *outputs) {
int batch_size = -1; int batch_size = -1;
bool batch_size_consist = true; bool batch_size_consist = true;
...@@ -735,10 +761,12 @@ void AsyncCommunicator::PushSparseFromTensorAsync( ...@@ -735,10 +761,12 @@ void AsyncCommunicator::PushSparseFromTensorAsync(
} }
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
this->Check(table_id), true, this->Check(table_id),
true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"can not find table: %s, please check your config", table_id)); "can not find table: %s, please check your config", table_id));
auto status = _worker_ptr->PushSparse(table_id, push_keys.data(), auto status = _worker_ptr->PushSparse(table_id,
push_keys.data(),
(const float **)push_g_vec.data(), (const float **)push_g_vec.data(),
push_keys.size()); push_keys.size());
} }
...@@ -831,7 +859,8 @@ void AsyncCommunicator::Stop() { ...@@ -831,7 +859,8 @@ void AsyncCommunicator::Stop() {
bool AsyncCommunicator::Check(const std::vector<std::string> &var_tables) { bool AsyncCommunicator::Check(const std::vector<std::string> &var_tables) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
var_tables.size(), 1, var_tables.size(),
1,
platform::errors::InvalidArgument("var_tables.size() == 1 is permitted")); platform::errors::InvalidArgument("var_tables.size() == 1 is permitted"));
auto table_name = var_tables[0]; auto table_name = var_tables[0];
...@@ -948,7 +977,8 @@ void HalfAsyncCommunicator::SendByCommunicator() { ...@@ -948,7 +977,8 @@ void HalfAsyncCommunicator::SendByCommunicator() {
if (ctx.is_sparse) { if (ctx.is_sparse) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
varnames.size(), 1, varnames.size(),
1,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"sparse variables can only be merged by one variables")); "sparse variables can only be merged by one variables"));
RpcSendSparse(varnames[0], table_id, *send_scope_); RpcSendSparse(varnames[0], table_id, *send_scope_);
...@@ -1003,7 +1033,8 @@ void GeoCommunicator::Send(const std::vector<std::string> &var_names, ...@@ -1003,7 +1033,8 @@ void GeoCommunicator::Send(const std::vector<std::string> &var_names,
auto *var = scope.FindVar(table_name); auto *var = scope.FindVar(table_name);
PADDLE_ENFORCE_EQ(var->IsType<phi::SelectedRows>(), true, PADDLE_ENFORCE_EQ(var->IsType<phi::SelectedRows>(),
true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Only need to send Sparse Grad in Geo mode.")); "Only need to send Sparse Grad in Geo mode."));
auto &rows = var->Get<phi::SelectedRows>().rows(); auto &rows = var->Get<phi::SelectedRows>().rows();
...@@ -1037,7 +1068,8 @@ void GeoCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx, ...@@ -1037,7 +1068,8 @@ void GeoCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
recv_scope_ = std::move(recv_scope); recv_scope_ = std::move(recv_scope);
PADDLE_ENFORCE_GT( PADDLE_ENFORCE_GT(
send_varname_to_ctx.size(), 0, send_varname_to_ctx.size(),
0,
platform::errors::InvalidArgument("send var contexts can not be zero")); platform::errors::InvalidArgument("send var contexts can not be zero"));
for (auto &iter : send_varname_to_ctx_) { for (auto &iter : send_varname_to_ctx_) {
...@@ -1048,14 +1080,16 @@ void GeoCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx, ...@@ -1048,14 +1080,16 @@ void GeoCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
} }
auto &varnames = ctx.origin_varnames; auto &varnames = ctx.origin_varnames;
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
varnames.size(), 1, varnames.size(),
1,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"sparse variables can only be merged by one variables")); "sparse variables can only be merged by one variables"));
for (auto &splited_var : ctx.splited_varnames) { for (auto &splited_var : ctx.splited_varnames) {
parallel_task_nums_ += 1; parallel_task_nums_ += 1;
sparse_id_queues_.insert( sparse_id_queues_.insert(
std::pair<std::string, paddle::framework::Channel< std::pair<std::string,
std::shared_ptr<std::vector<int64_t>>>>( paddle::framework::Channel<
std::shared_ptr<std::vector<int64_t>>>>(
splited_var, splited_var,
paddle::framework::MakeChannel< paddle::framework::MakeChannel<
std::shared_ptr<std::vector<int64_t>>>(send_queue_size_))); std::shared_ptr<std::vector<int64_t>>>(send_queue_size_)));
...@@ -1138,10 +1172,12 @@ void GeoCommunicator::SendDense(const CommContext &send_ctx) { ...@@ -1138,10 +1172,12 @@ void GeoCommunicator::SendDense(const CommContext &send_ctx) {
auto *var_latest = recv_scope_->FindVar(param_name); auto *var_latest = recv_scope_->FindVar(param_name);
auto *var_timestamp = old_scope_->FindVar(param_name); auto *var_timestamp = old_scope_->FindVar(param_name);
PADDLE_ENFORCE_EQ(var_latest->IsInitialized(), true, PADDLE_ENFORCE_EQ(var_latest->IsInitialized(),
true,
platform::errors::Unavailable( platform::errors::Unavailable(
"%s is not initialized, please check", param_name)); "%s is not initialized, please check", param_name));
PADDLE_ENFORCE_EQ(var_timestamp->IsInitialized(), true, PADDLE_ENFORCE_EQ(var_timestamp->IsInitialized(),
true,
platform::errors::Unavailable( platform::errors::Unavailable(
"%s is not initialized, please check", param_name)); "%s is not initialized, please check", param_name));
...@@ -1154,14 +1190,18 @@ void GeoCommunicator::SendDense(const CommContext &send_ctx) { ...@@ -1154,14 +1190,18 @@ void GeoCommunicator::SendDense(const CommContext &send_ctx) {
t_delta->mutable_data<float>(t_latest.dims(), cpu_ctx.GetPlace()); t_delta->mutable_data<float>(t_latest.dims(), cpu_ctx.GetPlace());
auto blas = phi::funcs::GetBlas<platform::CPUDeviceContext, float>(cpu_ctx); auto blas = phi::funcs::GetBlas<platform::CPUDeviceContext, float>(cpu_ctx);
blas.VSUB(t_latest.numel(), t_latest.data<float>(), blas.VSUB(t_latest.numel(),
t_timestamp->data<float>(), t_delta->data<float>()); t_latest.data<float>(),
t_timestamp->data<float>(),
t_delta->data<float>());
float coefficient = 1.0 / static_cast<float>(trainers_); float coefficient = 1.0 / static_cast<float>(trainers_);
blas.SCAL(t_latest.numel(), coefficient, t_delta->data<float>()); blas.SCAL(t_latest.numel(), coefficient, t_delta->data<float>());
blas.VADD(t_latest.numel(), t_timestamp->data<float>(), blas.VADD(t_latest.numel(),
t_delta->data<float>(), t_timestamp->data<float>()); t_timestamp->data<float>(),
t_delta->data<float>(),
t_timestamp->data<float>());
} }
RpcSendDense(send_ctx, *delta_scope_); RpcSendDense(send_ctx, *delta_scope_);
VLOG(1) << "Finish Send Dense " << var_names[0] << ", table_id: " << table_id; VLOG(1) << "Finish Send Dense " << var_names[0] << ", table_id: " << table_id;
...@@ -1194,12 +1234,16 @@ void GeoCommunicator::RecvDense(const CommContext &send_ctx) { ...@@ -1194,12 +1234,16 @@ void GeoCommunicator::RecvDense(const CommContext &send_ctx) {
t_delta->mutable_data<float>(t_latest->dims(), cpu_ctx.GetPlace()); t_delta->mutable_data<float>(t_latest->dims(), cpu_ctx.GetPlace());
auto blas = phi::funcs::GetBlas<platform::CPUDeviceContext, float>(cpu_ctx); auto blas = phi::funcs::GetBlas<platform::CPUDeviceContext, float>(cpu_ctx);
blas.VSUB(t_latest->numel(), t_pserver.data<float>(), t_old->data<float>(), blas.VSUB(t_latest->numel(),
t_pserver.data<float>(),
t_old->data<float>(),
t_delta->data<float>()); t_delta->data<float>());
blas.VADD(t_latest->numel(), t_latest->data<float>(), blas.VADD(t_latest->numel(),
t_delta->data<float>(), t_latest->data<float>()); t_latest->data<float>(),
blas.VCOPY(t_latest->numel(), t_pserver.data<float>(), t_delta->data<float>(),
t_old->data<float>()); t_latest->data<float>());
blas.VCOPY(
t_latest->numel(), t_pserver.data<float>(), t_old->data<float>());
} }
VLOG(1) << "Finish Recv Dense " << varnames[0] << ", table_id: " << table_id; VLOG(1) << "Finish Recv Dense " << varnames[0] << ", table_id: " << table_id;
return; return;
...@@ -1260,7 +1304,8 @@ std::vector<int64_t> GeoCommunicator::MergeSparseIds( ...@@ -1260,7 +1304,8 @@ std::vector<int64_t> GeoCommunicator::MergeSparseIds(
} }
void GeoCommunicator::SendSparse(const std::string &varname, void GeoCommunicator::SendSparse(const std::string &varname,
std::vector<int64_t> &sparse_ids, int table_id, std::vector<int64_t> &sparse_ids,
int table_id,
int ep_idx) { int ep_idx) {
platform::RecordEvent record_event("GeoCommunicator->SendSparse", platform::RecordEvent record_event("GeoCommunicator->SendSparse",
platform::TracerEventType::Communication, platform::TracerEventType::Communication,
...@@ -1276,10 +1321,12 @@ void GeoCommunicator::SendSparse(const std::string &varname, ...@@ -1276,10 +1321,12 @@ void GeoCommunicator::SendSparse(const std::string &varname,
auto *var_latest = recv_scope_->FindVar(param_name); auto *var_latest = recv_scope_->FindVar(param_name);
auto *var_old = old_scope_->FindVar(param_name); auto *var_old = old_scope_->FindVar(param_name);
PADDLE_ENFORCE_EQ(var_latest->IsInitialized(), true, PADDLE_ENFORCE_EQ(var_latest->IsInitialized(),
true,
platform::errors::Unavailable( platform::errors::Unavailable(
"%s is not initialized, please check", param_name)); "%s is not initialized, please check", param_name));
PADDLE_ENFORCE_EQ(var_old->IsInitialized(), true, PADDLE_ENFORCE_EQ(var_old->IsInitialized(),
true,
platform::errors::Unavailable( platform::errors::Unavailable(
"%s is not initialized, please check", param_name)); "%s is not initialized, please check", param_name));
...@@ -1303,11 +1350,13 @@ void GeoCommunicator::SendSparse(const std::string &varname, ...@@ -1303,11 +1350,13 @@ void GeoCommunicator::SendSparse(const std::string &varname,
std::vector<float *> push_g_vec; std::vector<float *> push_g_vec;
for (auto j = 0; j < static_cast<int>(sparse_ids.size()); ++j) { for (auto j = 0; j < static_cast<int>(sparse_ids.size()); ++j) {
blas.VSUB(dims1, t_latest.data<float>() + sparse_ids[j] * dims1, blas.VSUB(dims1,
t_latest.data<float>() + sparse_ids[j] * dims1,
t_old->data<float>() + sparse_ids[j] * dims1, t_old->data<float>() + sparse_ids[j] * dims1,
t_value + j * dims1); t_value + j * dims1);
blas.SCAL(dims1, coefficient, t_value + j * dims1); blas.SCAL(dims1, coefficient, t_value + j * dims1);
blas.VADD(dims1, t_old->data<float>() + sparse_ids[j] * dims1, blas.VADD(dims1,
t_old->data<float>() + sparse_ids[j] * dims1,
t_value + j * dims1, t_value + j * dims1,
t_old->data<float>() + sparse_ids[j] * dims1); t_old->data<float>() + sparse_ids[j] * dims1);
push_g_vec.push_back(t_value + j * dims1); push_g_vec.push_back(t_value + j * dims1);
...@@ -1328,8 +1377,12 @@ void GeoCommunicator::SendSparse(const std::string &varname, ...@@ -1328,8 +1377,12 @@ void GeoCommunicator::SendSparse(const std::string &varname,
--_async_call_num; --_async_call_num;
}); });
auto status = _worker_ptr->PushSparseRawGradientPartial( auto status = _worker_ptr->PushSparseRawGradientPartial(
table_id, (const uint64_t *)sparse_ids.data(), table_id,
(const float **)push_g_vec.data(), sparse_ids.size(), closure, ep_idx); (const uint64_t *)sparse_ids.data(),
(const float **)push_g_vec.data(),
sparse_ids.size(),
closure,
ep_idx);
status.wait(); status.wait();
VLOG(1) << "Finish Send Sparse " << varname VLOG(1) << "Finish Send Sparse " << varname
...@@ -1337,7 +1390,8 @@ void GeoCommunicator::SendSparse(const std::string &varname, ...@@ -1337,7 +1390,8 @@ void GeoCommunicator::SendSparse(const std::string &varname,
return; return;
} }
void GeoCommunicator::RecvSparse(const std::string &varname, int table_id, void GeoCommunicator::RecvSparse(const std::string &varname,
int table_id,
int ep_idx) { int ep_idx) {
platform::RecordEvent record_event("GeoCommunicator->RecvSparse", platform::RecordEvent record_event("GeoCommunicator->RecvSparse",
platform::TracerEventType::Communication, platform::TracerEventType::Communication,
...@@ -1375,8 +1429,8 @@ void GeoCommunicator::RecvSparse(const std::string &varname, int table_id, ...@@ -1375,8 +1429,8 @@ void GeoCommunicator::RecvSparse(const std::string &varname, int table_id,
float *latest_data = t_latest->data<float>() + keys[j] * dims1; float *latest_data = t_latest->data<float>() + keys[j] * dims1;
float *old_data = t_old->data<float>() + keys[j] * dims1; float *old_data = t_old->data<float>() + keys[j] * dims1;
// pserver - old => delta // pserver - old => delta
blas.VSUB(dims1, values.data() + j * dims1, old_data, blas.VSUB(
v_delta.data() + j * dims1); dims1, values.data() + j * dims1, old_data, v_delta.data() + j * dims1);
// latest + delta => latest // latest + delta => latest
blas.VADD(dims1, latest_data, v_delta.data() + j * dims1, latest_data); blas.VADD(dims1, latest_data, v_delta.data() + j * dims1, latest_data);
// pserver => old // pserver => old
...@@ -1404,7 +1458,8 @@ void GeoCommunicator::MainThread() { ...@@ -1404,7 +1458,8 @@ void GeoCommunicator::MainThread() {
if (ctx.is_sparse) { if (ctx.is_sparse) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
varnames.size(), 1, varnames.size(),
1,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"sparse variables can only be merged by one variables")); "sparse variables can only be merged by one variables"));
int pserver_num = static_cast<int>(ctx.epmap.size()); int pserver_num = static_cast<int>(ctx.epmap.size());
......
...@@ -63,7 +63,8 @@ template <typename T> ...@@ -63,7 +63,8 @@ template <typename T>
class BlockingQueue { class BlockingQueue {
public: public:
explicit BlockingQueue(size_t capacity) : capacity_(capacity) { explicit BlockingQueue(size_t capacity) : capacity_(capacity) {
PADDLE_ENFORCE_GT(capacity_, 0, PADDLE_ENFORCE_GT(capacity_,
0,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The capacity must be greater than 0.")); "The capacity must be greater than 0."));
} }
...@@ -149,16 +150,19 @@ class BlockingQueue { ...@@ -149,16 +150,19 @@ class BlockingQueue {
mutable std::mutex mutex_; mutable std::mutex mutex_;
}; };
template <typename T, int MajorType = Eigen::RowMajor, template <typename T,
int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex> typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>; using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename T> template <typename T>
inline void MergeVars(const std::string &var_name, inline void MergeVars(const std::string &var_name,
const std::vector<std::shared_ptr<Variable>> &vars, const std::vector<std::shared_ptr<Variable>> &vars,
Scope *scope, bool merge_add = true) { Scope *scope,
bool merge_add = true) {
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
vars.empty(), true, vars.empty(),
true,
platform::errors::InvalidArgument("vector vars are empty.")); platform::errors::InvalidArgument("vector vars are empty."));
auto cpu_place = platform::CPUPlace(); auto cpu_place = platform::CPUPlace();
auto &var0 = vars[0]; auto &var0 = vars[0];
...@@ -175,7 +179,8 @@ inline void MergeVars(const std::string &var_name, ...@@ -175,7 +179,8 @@ inline void MergeVars(const std::string &var_name,
for (auto &var : vars) { for (auto &var : vars) {
auto &var_t = var->Get<framework::LoDTensor>(); auto &var_t = var->Get<framework::LoDTensor>();
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
var_t.dims(), dims, var_t.dims(),
dims,
platform::errors::InvalidArgument("vars should have the same dims.")); platform::errors::InvalidArgument("vars should have the same dims."));
} }
...@@ -207,14 +212,14 @@ inline void MergeVars(const std::string &var_name, ...@@ -207,14 +212,14 @@ inline void MergeVars(const std::string &var_name,
} }
paddle::platform::CPUDeviceContext dev_ctx; paddle::platform::CPUDeviceContext dev_ctx;
if (merge_add) { if (merge_add) {
paddle::operators::math::scatter::MergeAdd< paddle::operators::math::scatter::
paddle::platform::CPUDeviceContext, T> MergeAdd<paddle::platform::CPUDeviceContext, T>
merge_add; merge_add;
merge_add(dev_ctx, inputs, out_slr); merge_add(dev_ctx, inputs, out_slr);
} else { } else {
paddle::operators::math::scatter::MergeAverage< paddle::operators::math::scatter::
paddle::platform::CPUDeviceContext, T> MergeAverage<paddle::platform::CPUDeviceContext, T>
merge_average; merge_average;
merge_average(dev_ctx, inputs, out_slr); merge_average(dev_ctx, inputs, out_slr);
} }
...@@ -254,23 +259,29 @@ class Communicator { ...@@ -254,23 +259,29 @@ class Communicator {
// 1. recv dense param // 1. recv dense param
virtual void RpcRecvDense(const std::vector<std::string> &varnames, virtual void RpcRecvDense(const std::vector<std::string> &varnames,
int table_id, Scope *scope); int table_id,
Scope *scope);
// 2. send dense param // 2. send dense param
virtual void RpcSendDenseParam(const std::vector<std::string> &varnames, virtual void RpcSendDenseParam(const std::vector<std::string> &varnames,
int table_id, const Scope &scope); int table_id,
const Scope &scope);
// 3. send dense grad // 3. send dense grad
virtual void RpcSendDense(const CommContext &ctx, const Scope &scope); virtual void RpcSendDense(const CommContext &ctx, const Scope &scope);
// 4. send sparse grad // 4. send sparse grad
virtual void RpcSendSparse(const std::string &var_name, int table_id, virtual void RpcSendSparse(const std::string &var_name,
int table_id,
const Scope &scope); const Scope &scope);
// 5. send sparse param // 5. send sparse param
virtual void RpcSendSparseParam(const std::string &varname, int table_id, virtual void RpcSendSparseParam(const std::string &varname,
int table_id,
const Scope &scope); const Scope &scope);
// 6. recv sparse param // 6. recv sparse param
virtual void RpcRecvSparse(const std::string &varname, int table_id, virtual void RpcRecvSparse(const std::string &varname,
int table_id,
Scope *scope); Scope *scope);
// 7. send gloabl step // 7. send gloabl step
virtual void SendGlobalStep(const CommContext &ctx, int batches, virtual void SendGlobalStep(const CommContext &ctx,
int batches,
Scope *send_scope); Scope *send_scope);
virtual ~Communicator() {} virtual ~Communicator() {}
...@@ -303,7 +314,8 @@ class Communicator { ...@@ -303,7 +314,8 @@ class Communicator {
auto rets = _worker_ptr->Barrier(barrier_table_id_, barrier_type); auto rets = _worker_ptr->Barrier(barrier_table_id_, barrier_type);
rets.wait(); rets.wait();
int status = rets.get(); int status = rets.get();
PADDLE_ENFORCE_EQ(status, 0, PADDLE_ENFORCE_EQ(status,
0,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The ret status must be 0 when barrier with table")); "The ret status must be 0 when barrier with table"));
} }
...@@ -333,12 +345,19 @@ class Communicator { ...@@ -333,12 +345,19 @@ class Communicator {
template <typename T> template <typename T>
static Communicator *InitInstance( static Communicator *InitInstance(
const RpcCtxMap &send_ctx, const RecvCtxMap &recv_ctx, const RpcCtxMap &send_ctx,
const RecvCtxMap &recv_ctx,
const std::string &dist_desc, const std::string &dist_desc,
const std::vector<std::string> &host_sign_list, Scope *recv_scope, const std::vector<std::string> &host_sign_list,
Scope *recv_scope,
const std::map<std::string, std::string> &envs) { const std::map<std::string, std::string> &envs) {
std::call_once(init_flag_, &Communicator::InitWithRpcCtx<T>, send_ctx, std::call_once(init_flag_,
recv_ctx, dist_desc, host_sign_list, recv_scope, &Communicator::InitWithRpcCtx<T>,
send_ctx,
recv_ctx,
dist_desc,
host_sign_list,
recv_scope,
std::ref(envs)); std::ref(envs));
return communicator_.get(); return communicator_.get();
} }
...@@ -456,15 +475,22 @@ class AsyncCommunicator : public Communicator { ...@@ -456,15 +475,22 @@ class AsyncCommunicator : public Communicator {
void PushDensePostProcessing(); void PushDensePostProcessing();
void PullSparseToTensorSync( void PullSparseToTensorSync(
const uint64_t table_id, int fea_dim, uint64_t padding_id, const uint64_t table_id,
platform::Place place, bool is_training, int fea_dim,
uint64_t padding_id,
platform::Place place,
bool is_training,
std::vector<const framework::LoDTensor *> *inputs, // NOLINT std::vector<const framework::LoDTensor *> *inputs, // NOLINT
std::vector<framework::LoDTensor *> *outputs); // NOLINT std::vector<framework::LoDTensor *> *outputs); // NOLINT
void PushSparseFromTensorAsync( void PushSparseFromTensorAsync(
const uint64_t table_id, int fea_dim, uint64_t padding_id, const uint64_t table_id,
platform::Place place, std::vector<const framework::LoDTensor *> *inputs, int fea_dim,
const framework::LoDTensor *shows, const framework::LoDTensor *clicks, uint64_t padding_id,
platform::Place place,
std::vector<const framework::LoDTensor *> *inputs,
const framework::LoDTensor *shows,
const framework::LoDTensor *clicks,
std::vector<framework::LoDTensor *> *outputs); std::vector<framework::LoDTensor *> *outputs);
protected: protected:
...@@ -585,7 +611,8 @@ class GeoCommunicator : public AsyncCommunicator { ...@@ -585,7 +611,8 @@ class GeoCommunicator : public AsyncCommunicator {
std::vector<int64_t> MergeSparseIds(const std::string &varname); std::vector<int64_t> MergeSparseIds(const std::string &varname);
void SendSparse(const std::string &varname, void SendSparse(const std::string &varname,
std::vector<int64_t> &sparse_ids, // NOLINT std::vector<int64_t> &sparse_ids, // NOLINT
int table_id, int ep_idx); int table_id,
int ep_idx);
void RecvSparse(const std::string &varname, int table_id, int ep_idx); void RecvSparse(const std::string &varname, int table_id, int ep_idx);
void MainThread() override; void MainThread() override;
...@@ -628,8 +655,9 @@ class GeoCommunicator : public AsyncCommunicator { ...@@ -628,8 +655,9 @@ class GeoCommunicator : public AsyncCommunicator {
// parameter on pserver // parameter on pserver
std::shared_ptr<Scope> pserver_scope_; std::shared_ptr<Scope> pserver_scope_;
std::unordered_map<std::string, paddle::framework::Channel< std::unordered_map<
std::shared_ptr<std::vector<int64_t>>>> std::string,
paddle::framework::Channel<std::shared_ptr<std::vector<int64_t>>>>
sparse_id_queues_; sparse_id_queues_;
}; };
......
...@@ -25,13 +25,18 @@ namespace distributed { ...@@ -25,13 +25,18 @@ namespace distributed {
struct CommContext { struct CommContext {
CommContext() = default; CommContext() = default;
CommContext(const std::string &name, const std::vector<std::string> &names, CommContext(const std::string &name,
const std::vector<std::string> &names,
const std::vector<std::string> &emap, const std::vector<std::string> &emap,
const std::vector<int64_t> &sections, const std::vector<int64_t> &sections,
const std::vector<std::string> &origin_names, int id, const std::vector<std::string> &origin_names,
bool merge_add_ = true, bool is_sparse_ = true, int id,
bool is_distributed_ = false, int table_id_ = -1, bool merge_add_ = true,
bool is_tensor_table_ = false, bool is_datanorm_table_ = false, bool is_sparse_ = true,
bool is_distributed_ = false,
int table_id_ = -1,
bool is_tensor_table_ = false,
bool is_datanorm_table_ = false,
int64_t program_id_ = -1) int64_t program_id_ = -1)
: var_name(name), : var_name(name),
splited_varnames(names), splited_varnames(names),
......
...@@ -86,8 +86,10 @@ struct PSHost { ...@@ -86,8 +86,10 @@ struct PSHost {
rank = std::stoi(endpoint_info[2]); rank = std::stoi(endpoint_info[2]);
} }
void StringSplit(const std::string &str, char sep, void StringSplit(const std::string &str,
std::vector<std::string> *pieces, bool ignore_null = true) { char sep,
std::vector<std::string> *pieces,
bool ignore_null = true) {
pieces->clear(); pieces->clear();
if (str.empty()) { if (str.empty()) {
if (!ignore_null) { if (!ignore_null) {
...@@ -130,13 +132,15 @@ class PSEnvironment { ...@@ -130,13 +132,15 @@ class PSEnvironment {
} }
virtual uint64_t GetLocalHostSign() { return 0; } virtual uint64_t GetLocalHostSign() { return 0; }
virtual std::vector<PSHost> GetPsServers() const { return _ps_server_list; } virtual std::vector<PSHost> GetPsServers() const { return _ps_server_list; }
virtual int32_t RegistePsServer(const std::string &ip, uint32_t port, virtual int32_t RegistePsServer(const std::string &ip,
uint32_t port,
int32_t rank) { int32_t rank) {
return RegistePsHost(ip, port, rank, _ps_server_list, _ps_server_sign_set); return RegistePsHost(ip, port, rank, _ps_server_list, _ps_server_sign_set);
} }
virtual std::vector<PSHost> GetPsClients() const { return _ps_client_list; } virtual std::vector<PSHost> GetPsClients() const { return _ps_client_list; }
virtual int32_t RegistePsClient(const std::string &ip, uint32_t port, virtual int32_t RegistePsClient(const std::string &ip,
uint32_t port,
int32_t rank) { int32_t rank) {
return RegistePsHost(ip, port, rank, _ps_client_list, _ps_client_sign_set); return RegistePsHost(ip, port, rank, _ps_client_list, _ps_client_sign_set);
} }
...@@ -167,7 +171,9 @@ class PSEnvironment { ...@@ -167,7 +171,9 @@ class PSEnvironment {
protected: protected:
//注册一个host // NOLINT //注册一个host // NOLINT
virtual int32_t RegistePsHost( virtual int32_t RegistePsHost(
const std::string &ip, uint32_t port, int32_t rank, const std::string &ip,
uint32_t port,
int32_t rank,
std::vector<PSHost> &host_list, // NOLINT std::vector<PSHost> &host_list, // NOLINT
std::unordered_set<uint64_t> &sign_set) { // NOLINT std::unordered_set<uint64_t> &sign_set) { // NOLINT
PSHost host; PSHost host;
...@@ -209,7 +215,8 @@ class PaddlePSEnvironment : public PSEnvironment { ...@@ -209,7 +215,8 @@ class PaddlePSEnvironment : public PSEnvironment {
} }
} }
std::sort( std::sort(
_ps_server_list.begin(), _ps_server_list.end(), _ps_server_list.begin(),
_ps_server_list.end(),
[](const PSHost &h1, const PSHost &h2) { return h1.rank < h2.rank; }); [](const PSHost &h1, const PSHost &h2) { return h1.rank < h2.rank; });
return 0; return 0;
} }
...@@ -227,7 +234,8 @@ class PaddlePSEnvironment : public PSEnvironment { ...@@ -227,7 +234,8 @@ class PaddlePSEnvironment : public PSEnvironment {
} }
} }
std::sort( std::sort(
_ps_server_list.begin(), _ps_server_list.end(), _ps_server_list.begin(),
_ps_server_list.end(),
[](const PSHost &h1, const PSHost &h2) { return h1.rank < h2.rank; }); [](const PSHost &h1, const PSHost &h2) { return h1.rank < h2.rank; });
return 0; return 0;
} }
...@@ -244,7 +252,8 @@ class PaddlePSEnvironment : public PSEnvironment { ...@@ -244,7 +252,8 @@ class PaddlePSEnvironment : public PSEnvironment {
} }
} }
std::sort( std::sort(
_ps_client_list.begin(), _ps_client_list.end(), _ps_client_list.begin(),
_ps_client_list.end(),
[](const PSHost &h1, const PSHost &h2) { return h1.rank < h2.rank; }); [](const PSHost &h1, const PSHost &h2) { return h1.rank < h2.rank; });
return 0; return 0;
} }
...@@ -262,7 +271,8 @@ class PaddlePSEnvironment : public PSEnvironment { ...@@ -262,7 +271,8 @@ class PaddlePSEnvironment : public PSEnvironment {
} }
} }
std::sort( std::sort(
_ps_client_list.begin(), _ps_client_list.end(), _ps_client_list.begin(),
_ps_client_list.end(),
[](const PSHost &h1, const PSHost &h2) { return h1.rank < h2.rank; }); [](const PSHost &h1, const PSHost &h2) { return h1.rank < h2.rank; });
VLOG(1) << "env.set_ps_clients done\n"; VLOG(1) << "env.set_ps_clients done\n";
return 0; return 0;
......
...@@ -40,7 +40,8 @@ class GraphPsService_Stub : public PsService_Stub { ...@@ -40,7 +40,8 @@ class GraphPsService_Stub : public PsService_Stub {
public: public:
GraphPsService_Stub(::google::protobuf::RpcChannel* channel, GraphPsService_Stub(::google::protobuf::RpcChannel* channel,
::google::protobuf::RpcChannel* local_channel = NULL, ::google::protobuf::RpcChannel* local_channel = NULL,
GraphBrpcService* service = NULL, int thread_num = 1) GraphBrpcService* service = NULL,
int thread_num = 1)
: PsService_Stub(channel) { : PsService_Stub(channel) {
this->local_channel = local_channel; this->local_channel = local_channel;
this->graph_service = service; this->graph_service = service;
...@@ -64,34 +65,50 @@ class GraphBrpcClient : public BrpcPsClient { ...@@ -64,34 +65,50 @@ class GraphBrpcClient : public BrpcPsClient {
virtual ~GraphBrpcClient() {} virtual ~GraphBrpcClient() {}
// given a batch of nodes, sample graph_neighbors for each of them // given a batch of nodes, sample graph_neighbors for each of them
virtual std::future<int32_t> batch_sample_neighbors( virtual std::future<int32_t> batch_sample_neighbors(
uint32_t table_id, int idx, std::vector<int64_t> node_ids, uint32_t table_id,
int sample_size, std::vector<std::vector<int64_t>>& res, int idx,
std::vector<std::vector<float>>& res_weight, bool need_weight, std::vector<int64_t> node_ids,
int sample_size,
std::vector<std::vector<int64_t>>& res,
std::vector<std::vector<float>>& res_weight,
bool need_weight,
int server_index = -1); int server_index = -1);
virtual std::future<int32_t> pull_graph_list(uint32_t table_id, int type_id, virtual std::future<int32_t> pull_graph_list(uint32_t table_id,
int idx, int server_index, int type_id,
int start, int size, int step, int idx,
int server_index,
int start,
int size,
int step,
std::vector<FeatureNode>& res); std::vector<FeatureNode>& res);
virtual std::future<int32_t> random_sample_nodes(uint32_t table_id, virtual std::future<int32_t> random_sample_nodes(uint32_t table_id,
int type_id, int idx, int type_id,
int idx,
int server_index, int server_index,
int sample_size, int sample_size,
std::vector<int64_t>& ids); std::vector<int64_t>& ids);
virtual std::future<int32_t> get_node_feat( virtual std::future<int32_t> get_node_feat(
const uint32_t& table_id, int idx, const std::vector<int64_t>& node_ids, const uint32_t& table_id,
int idx,
const std::vector<int64_t>& node_ids,
const std::vector<std::string>& feature_names, const std::vector<std::string>& feature_names,
std::vector<std::vector<std::string>>& res); std::vector<std::vector<std::string>>& res);
virtual std::future<int32_t> set_node_feat( virtual std::future<int32_t> set_node_feat(
const uint32_t& table_id, int idx, const std::vector<int64_t>& node_ids, const uint32_t& table_id,
int idx,
const std::vector<int64_t>& node_ids,
const std::vector<std::string>& feature_names, const std::vector<std::string>& feature_names,
const std::vector<std::vector<std::string>>& features); const std::vector<std::vector<std::string>>& features);
virtual std::future<int32_t> clear_nodes(uint32_t table_id, int type_id, virtual std::future<int32_t> clear_nodes(uint32_t table_id,
int type_id,
int idx); int idx);
virtual std::future<int32_t> add_graph_node( virtual std::future<int32_t> add_graph_node(
uint32_t table_id, int idx, std::vector<int64_t>& node_id_list, uint32_t table_id,
int idx,
std::vector<int64_t>& node_id_list,
std::vector<bool>& is_weighted_list); std::vector<bool>& is_weighted_list);
virtual std::future<int32_t> remove_graph_node( virtual std::future<int32_t> remove_graph_node(
uint32_t table_id, int idx_, std::vector<int64_t>& node_id_list); uint32_t table_id, int idx_, std::vector<int64_t>& node_id_list);
...@@ -107,8 +124,8 @@ class GraphBrpcClient : public BrpcPsClient { ...@@ -107,8 +124,8 @@ class GraphBrpcClient : public BrpcPsClient {
} }
GraphPsService_Stub getServiceStub(::google::protobuf::RpcChannel* channel, GraphPsService_Stub getServiceStub(::google::protobuf::RpcChannel* channel,
int thread_num = 1) { int thread_num = 1) {
return GraphPsService_Stub(channel, local_channel, graph_service, return GraphPsService_Stub(
thread_num); channel, local_channel, graph_service, thread_num);
} }
private: private:
......
此差异已折叠。
此差异已折叠。
...@@ -117,7 +117,8 @@ void HeterServer::WaitServerReady() { ...@@ -117,7 +117,8 @@ void HeterServer::WaitServerReady() {
} }
int SendAndRecvVariableHandler::SaveInSwitchWithShard( int SendAndRecvVariableHandler::SaveInSwitchWithShard(
const MultiVarMsg* request, PsResponseMessage* response, const MultiVarMsg* request,
PsResponseMessage* response,
brpc::Controller* cntl) { brpc::Controller* cntl) {
VLOG(4) << "entering SaveInSwitchWithShard"; VLOG(4) << "entering SaveInSwitchWithShard";
int32_t group_id = request->group_id(); int32_t group_id = request->group_id();
...@@ -174,7 +175,8 @@ int SendAndRecvVariableHandler::QueryInSwitchWithShard( ...@@ -174,7 +175,8 @@ int SendAndRecvVariableHandler::QueryInSwitchWithShard(
} }
int SendAndRecvVariableHandler::SaveInSwitchWithScope( int SendAndRecvVariableHandler::SaveInSwitchWithScope(
const MultiVarMsg* request, PsResponseMessage* response, const MultiVarMsg* request,
PsResponseMessage* response,
brpc::Controller* cntl) { brpc::Controller* cntl) {
VLOG(4) << "entering SaveInSwitchWithScope"; VLOG(4) << "entering SaveInSwitchWithScope";
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
...@@ -201,8 +203,8 @@ int SendAndRecvVariableHandler::SaveInSwitchWithScope( ...@@ -201,8 +203,8 @@ int SendAndRecvVariableHandler::SaveInSwitchWithScope(
WaitForVarsConsumed(0, var_name); WaitForVarsConsumed(0, var_name);
} }
auto& request_io_buffer = cntl->request_attachment(); auto& request_io_buffer = cntl->request_attachment();
distributed::DeserializeFromMultiVarMsgAndIOBuf(*request, &request_io_buffer, distributed::DeserializeFromMultiVarMsgAndIOBuf(
cpu_dev_ctx, local_scope); *request, &request_io_buffer, cpu_dev_ctx, local_scope);
lk.unlock(); lk.unlock();
for (auto var_name : send_var_names) { for (auto var_name : send_var_names) {
std::unique_lock<std::mutex> lk(scope_mutex_); std::unique_lock<std::mutex> lk(scope_mutex_);
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册