Unverified · Commit 1d1e5484 authored by Xing-lil, committed by GitHub

Update gloo in dygraph (#55537)

* update broadcast gloo in dygraph

* update

* update reduce gloo in dygraph

* update reduce gloo in dygraph

* update

* update allreduce allgather

* update all

* update

* update

* update
Parent 982e0a9d
...@@ -24,16 +24,12 @@ ...@@ -24,16 +24,12 @@
#include <unistd.h> #include <unistd.h>
#endif #endif
#include <gloo/broadcast.h>
#include <gloo/gather.h>
#include <gloo/reduce.h> #include <gloo/reduce.h>
#include <gloo/scatter.h>
#include "paddle/fluid/distributed/collective/common.h" #include "paddle/fluid/distributed/collective/common.h"
#include "paddle/fluid/distributed/collective/gloo_send_recv.h"
#include "paddle/fluid/distributed/collective/process_group_gloo.h" #include "paddle/fluid/distributed/collective/process_group_gloo.h"
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
...@@ -99,28 +95,6 @@ namespace distributed { ...@@ -99,28 +95,6 @@ namespace distributed {
} }
#endif #endif
typedef void (*reduce_func)(void*, const void*, const void*, size_t);
template <typename T>
reduce_func get_function(const ReduceOp& r) {
switch (r) {
case ReduceOp::SUM:
return reduce_func(&::gloo::sum<T>);
case ReduceOp::PRODUCT:
return reduce_func(&::gloo::product<T>);
case ReduceOp::MIN:
return reduce_func(&::gloo::min<T>);
case ReduceOp::MAX:
return reduce_func(&::gloo::max<T>);
case ReduceOp::AVG:
VLOG(0) << "Error: Unsupported ReduceOp::AVG.";
exit(-1);
}
VLOG(0) << "Error: Unknown ReduceOp.";
exit(-1);
}
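The per-dtype reduce-function lookup above is deleted because the equivalent dispatch now lives on the phi side (see SetReduceFunc in gloo_utils.h further down in this diff). A minimal sketch of that kind of dispatcher, assuming the new ReduceOp ordering SUM=0, MAX=1, MIN=2, PRODUCT=3 and using only the gloo reduction helpers already referenced above; this is not the verbatim phi implementation:

```cpp
// Sketch only: pick a typed gloo reduction for a reduce_type integer.
// Works for any options type with setReduceFunction (Allreduce/ReduceOptions).
template <typename T, typename P>
void SetReduceFuncSketch(P* opts, int reduce_type) {
  using Fn = void (*)(void*, const void*, const void*, size_t);
  switch (reduce_type) {
    case 0:  // ReduceOp::SUM
      opts->setReduceFunction(Fn(&::gloo::sum<T>));
      break;
    case 1:  // ReduceOp::MAX
      opts->setReduceFunction(Fn(&::gloo::max<T>));
      break;
    case 2:  // ReduceOp::MIN
      opts->setReduceFunction(Fn(&::gloo::min<T>));
      break;
    case 3:  // ReduceOp::PRODUCT
      opts->setReduceFunction(Fn(&::gloo::product<T>));
      break;
    default:  // AVG is still unsupported on the gloo backend
      PADDLE_THROW(phi::errors::Unimplemented("Unsupported reduce type."));
  }
}
```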
template <typename T> template <typename T>
T* get_data(phi::DenseTensor& tensor) { // NOLINT T* get_data(phi::DenseTensor& tensor) { // NOLINT
return reinterpret_cast<T*>(tensor.data()); return reinterpret_cast<T*>(tensor.data());
...@@ -188,21 +162,19 @@ ProcessGroupGloo::ProcessGroupGloo( ...@@ -188,21 +162,19 @@ ProcessGroupGloo::ProcessGroupGloo(
_tag(0), _tag(0),
_store(new GlooStore(store)) { _store(new GlooStore(store)) {
_context = std::make_shared<gloo::rendezvous::Context>(rank, world_size); _context = std::make_shared<gloo::rendezvous::Context>(rank, world_size);
auto prefix_store = _context->connectFullMesh(*_store, options->device);
::gloo::rendezvous::PrefixStore(std::to_string(gid), *_store);
_context->connectFullMesh(prefix_store, options->device);
} }
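The rendezvous now goes through a PrefixStore keyed by the group id, so several process groups can share one underlying store without their connectFullMesh keys colliding. A standalone sketch of the same pattern against gloo's rendezvous API; the function name and argument list are illustrative, not part of the patch:

```cpp
#include <gloo/rendezvous/context.h>
#include <gloo/rendezvous/prefix_store.h>

// Sketch: build a per-group context whose rendezvous keys are namespaced by
// the group id, so two groups can reuse the same TCP store.
std::shared_ptr<gloo::rendezvous::Context> MakeGroupContext(
    int rank, int world_size, int gid,
    gloo::rendezvous::Store& base_store,
    const std::shared_ptr<gloo::transport::Device>& device) {
  auto ctx = std::make_shared<gloo::rendezvous::Context>(rank, world_size);
  gloo::rendezvous::PrefixStore prefix_store(std::to_string(gid), base_store);
  ctx->connectFullMesh(prefix_store, device);  // keys are now "<gid>/..."
  return ctx;
}
```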
class BroadcastGlooTask : public ProcessGroupGloo::GlooTask { class BroadcastGlooTask : public ProcessGroupGloo::GlooTask {
public: public:
BroadcastGlooTask(const std::shared_ptr<gloo::Context>& context, BroadcastGlooTask(phi::distributed::GlooCommContext* comm_context,
std::vector<phi::DenseTensor>& inputs, // NOLINT std::vector<phi::DenseTensor>& inputs, // NOLINT
std::vector<phi::DenseTensor>& outputs, // NOLINT std::vector<phi::DenseTensor>& outputs, // NOLINT
int rank, int rank,
int root, int root,
uint32_t tag) uint32_t tag)
: ProcessGroupGloo::GlooTask(rank, inputs, CommType::BROADCAST), : ProcessGroupGloo::GlooTask(rank, inputs, CommType::BROADCAST),
_context(context), _comm_context(comm_context),
_root(root), _root(root),
_inputs(inputs), _inputs(inputs),
_outputs(outputs), _outputs(outputs),
...@@ -211,22 +183,14 @@ class BroadcastGlooTask : public ProcessGroupGloo::GlooTask { ...@@ -211,22 +183,14 @@ class BroadcastGlooTask : public ProcessGroupGloo::GlooTask {
void Run() override { _do_broadcast(_inputs[0], _outputs[0]); } void Run() override { _do_broadcast(_inputs[0], _outputs[0]); }
private: private:
std::shared_ptr<gloo::Context> _context; phi::distributed::GlooCommContext* _comm_context;
const int _root; const int _root;
std::vector<phi::DenseTensor> _inputs{}; std::vector<phi::DenseTensor> _inputs{};
std::vector<phi::DenseTensor> _outputs{}; std::vector<phi::DenseTensor> _outputs{};
const uint32_t _tag; const uint32_t _tag;
void _do_broadcast(phi::DenseTensor& in, phi::DenseTensor& out) { // NOLINT void _do_broadcast(phi::DenseTensor& in, phi::DenseTensor& out) { // NOLINT
gloo::BroadcastOptions opts(_context); _comm_context->Broadcast(&(out), in, _root, _tag);
const auto& dtype = in.dtype();
if (rank_ == _root) {
GENERATE_FUNC(dtype, set_input, opts, in);
}
GENERATE_FUNC(dtype, set_output, opts, out);
opts.setRoot(_root);
opts.setTag(_tag);
gloo::broadcast(opts);
} }
}; };
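Every task in this file now follows the same shape: it keeps a raw, non-owning GlooCommContext pointer (the context is assumed to be owned by the CommContextManager singleton and to outlive the task), and Run() is a one-line forward to the matching comm-context method. A condensed sketch of that shared shape, with illustrative names:

```cpp
// Sketch of the post-refactor task pattern used by the classes below.
class ExampleGlooTask : public ProcessGroupGloo::GlooTask {
 public:
  ExampleGlooTask(phi::distributed::GlooCommContext* comm_context,
                  std::vector<phi::DenseTensor>& inputs,   // NOLINT
                  std::vector<phi::DenseTensor>& outputs,  // NOLINT
                  int rank,
                  uint32_t tag)
      : ProcessGroupGloo::GlooTask(rank, inputs, CommType::ALLREDUCE),
        _comm_context(comm_context),
        _inputs(inputs),
        _outputs(outputs),
        _tag(tag) {}

  // Dtype dispatch and gloo option plumbing live in GlooCommContext now;
  // the task only selects which collective to call.
  void Run() override {
    _comm_context->AllReduce(&_outputs[0], _inputs[0], /*reduce_type=*/0, _tag);
  }

 private:
  phi::distributed::GlooCommContext* _comm_context;  // not owned
  std::vector<phi::DenseTensor> _inputs;
  std::vector<phi::DenseTensor> _outputs;
  const uint32_t _tag;
};
```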
...@@ -256,22 +220,22 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Broadcast( ...@@ -256,22 +220,22 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Broadcast(
auto root = opts.source_rank; auto root = opts.source_rank;
std::unique_ptr<BroadcastGlooTask> task; std::unique_ptr<BroadcastGlooTask> task;
auto tag = next_tag(); auto tag = next_tag();
auto context = get_context(); auto comm_context = this->GetCommContext();
task = std::make_unique<BroadcastGlooTask>( task = std::make_unique<BroadcastGlooTask>(
context, inputs, outputs, rank_, root, tag); comm_context, inputs, outputs, rank_, root, tag);
task->Run(); task->Run();
return task; return task;
} }
class SendGlooTask : public ProcessGroupGloo::GlooTask { class SendGlooTask : public ProcessGroupGloo::GlooTask {
public: public:
SendGlooTask(const std::shared_ptr<gloo::Context>& context, SendGlooTask(phi::distributed::GlooCommContext* comm_context,
std::vector<phi::DenseTensor>* inputs, std::vector<phi::DenseTensor>* inputs,
int rank, int rank,
int dst_rank, int dst_rank,
uint32_t tag) uint32_t tag)
: ProcessGroupGloo::GlooTask(rank, *inputs, CommType::SEND), : ProcessGroupGloo::GlooTask(rank, *inputs, CommType::SEND),
_context(context), _comm_context(comm_context),
_inputs(*inputs), _inputs(*inputs),
_dst(dst_rank), _dst(dst_rank),
_tag(tag) {} _tag(tag) {}
...@@ -279,20 +243,13 @@ class SendGlooTask : public ProcessGroupGloo::GlooTask { ...@@ -279,20 +243,13 @@ class SendGlooTask : public ProcessGroupGloo::GlooTask {
void Run() override { _do_send(_inputs); } void Run() override { _do_send(_inputs); }
private: private:
std::shared_ptr<gloo::Context> _context; phi::distributed::GlooCommContext* _comm_context;
std::vector<phi::DenseTensor> _inputs; std::vector<phi::DenseTensor> _inputs;
int _dst; int _dst;
uint32_t _tag; uint32_t _tag;
void _do_send(std::vector<phi::DenseTensor>& in) { // NOLINT void _do_send(std::vector<phi::DenseTensor>& in) { // NOLINT
SendRecvOptions opts(_context); _comm_context->Send(in[0], _dst, _tag);
const auto& dtype = in[0].dtype();
GENERATE_FUNC(dtype, set_input, opts, in[0]);
opts.setSrc(_context.get()->rank);
opts.setDst(_dst);
opts.setTag(_tag);
send_recv(&opts);
} }
}; };
...@@ -306,8 +263,9 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Send( ...@@ -306,8 +263,9 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Send(
std::vector<phi::DenseTensor>& inputs, int dst_rank) { std::vector<phi::DenseTensor>& inputs, int dst_rank) {
std::unique_ptr<SendGlooTask> task; std::unique_ptr<SendGlooTask> task;
auto tag = next_tag(); auto tag = next_tag();
auto context = get_context(); auto comm_context = this->GetCommContext();
task = std::make_unique<SendGlooTask>(context, &inputs, rank_, dst_rank, tag); task = std::make_unique<SendGlooTask>(
comm_context, &inputs, rank_, dst_rank, tag);
task->Run(); task->Run();
return task; return task;
...@@ -315,13 +273,13 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Send( ...@@ -315,13 +273,13 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Send(
class RecvGlooTask : public ProcessGroupGloo::GlooTask { class RecvGlooTask : public ProcessGroupGloo::GlooTask {
public: public:
RecvGlooTask(const std::shared_ptr<gloo::Context>& context, RecvGlooTask(phi::distributed::GlooCommContext* comm_context,
std::vector<phi::DenseTensor>* outputs, std::vector<phi::DenseTensor>* outputs,
int rank, int rank,
int src_rank, int src_rank,
uint32_t tag) uint32_t tag)
: ProcessGroupGloo::GlooTask(rank, *outputs, CommType::RECV), : ProcessGroupGloo::GlooTask(rank, *outputs, CommType::RECV),
_context(context), _comm_context(comm_context),
_outputs(*outputs), _outputs(*outputs),
_src(src_rank), _src(src_rank),
_tag(tag) {} _tag(tag) {}
...@@ -329,20 +287,13 @@ class RecvGlooTask : public ProcessGroupGloo::GlooTask { ...@@ -329,20 +287,13 @@ class RecvGlooTask : public ProcessGroupGloo::GlooTask {
void Run() override { _do_recv(_outputs); } void Run() override { _do_recv(_outputs); }
private: private:
std::shared_ptr<gloo::Context> _context; phi::distributed::GlooCommContext* _comm_context;
std::vector<phi::DenseTensor> _outputs; std::vector<phi::DenseTensor> _outputs;
const int _src; const int _src;
const uint32_t _tag; const uint32_t _tag;
void _do_recv(std::vector<phi::DenseTensor>& out) { // NOLINT void _do_recv(std::vector<phi::DenseTensor>& out) { // NOLINT
SendRecvOptions opts(_context); _comm_context->Recv(&(out[0]), _src, _tag);
const auto& dtype = out[0].dtype();
GENERATE_FUNC(dtype, set_output, opts, out[0]);
opts.setSrc(_src);
opts.setDst(_context.get()->rank);
opts.setTag(_tag);
send_recv(&opts);
} }
}; };
...@@ -356,10 +307,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Recv( ...@@ -356,10 +307,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Recv(
std::vector<phi::DenseTensor>& outputs, int src_rank) { std::vector<phi::DenseTensor>& outputs, int src_rank) {
std::unique_ptr<RecvGlooTask> task; std::unique_ptr<RecvGlooTask> task;
auto tag = next_tag(); auto tag = next_tag();
auto context = get_context(); auto comm_context = this->GetCommContext();
task = task = std::make_unique<RecvGlooTask>(
std::make_unique<RecvGlooTask>(context, &outputs, rank_, src_rank, tag); comm_context, &outputs, rank_, src_rank, tag);
task->Run(); task->Run();
return task; return task;
} }
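Send and Recv are point-to-point, so the two ranks pair their calls themselves and a rank that is neither source nor destination simply does not call them. A minimal sketch of the pairing at the process-group level, assuming both ranks issue the same sequence of operations so that next_tag() produces matching tags; pg, rank and the tensor are illustrative:

```cpp
// Sketch: rank 0 sends one CPU tensor to rank 1 through the refactored path.
std::vector<phi::DenseTensor> tensors = {cpu_tensor};
if (rank == 0) {
  auto task = pg->Send(tensors, /*dst_rank=*/1);
  task->Wait();
} else if (rank == 1) {
  auto task = pg->Recv(tensors, /*src_rank=*/0);
  task->Wait();
}
```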
...@@ -367,13 +318,13 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Recv( ...@@ -367,13 +318,13 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Recv(
class AllreduceGlooTask : public ProcessGroupGloo::GlooTask { class AllreduceGlooTask : public ProcessGroupGloo::GlooTask {
public: public:
AllreduceGlooTask(int rank, AllreduceGlooTask(int rank,
const std::shared_ptr<gloo::Context>& context, phi::distributed::GlooCommContext* comm_context,
std::vector<phi::DenseTensor>& inputs, // NOLINT std::vector<phi::DenseTensor>& inputs, // NOLINT
std::vector<phi::DenseTensor>& outputs, // NOLINT std::vector<phi::DenseTensor>& outputs, // NOLINT
ReduceOp reduce_op, ReduceOp reduce_op,
uint32_t tag) uint32_t tag)
: ProcessGroupGloo::GlooTask(rank, inputs, CommType::ALLREDUCE), : ProcessGroupGloo::GlooTask(rank, inputs, CommType::ALLREDUCE),
_context(context), _comm_context(comm_context),
_inputs(inputs), _inputs(inputs),
_outputs(outputs), _outputs(outputs),
_reduce_op(reduce_op), _reduce_op(reduce_op),
...@@ -382,34 +333,16 @@ class AllreduceGlooTask : public ProcessGroupGloo::GlooTask { ...@@ -382,34 +333,16 @@ class AllreduceGlooTask : public ProcessGroupGloo::GlooTask {
void Run() override { _do_allreduce(_inputs, _outputs); } void Run() override { _do_allreduce(_inputs, _outputs); }
private: private:
std::shared_ptr<gloo::Context> _context; phi::distributed::GlooCommContext* _comm_context;
std::vector<phi::DenseTensor> _inputs; std::vector<phi::DenseTensor> _inputs;
std::vector<phi::DenseTensor> _outputs; std::vector<phi::DenseTensor> _outputs;
const ReduceOp _reduce_op; const ReduceOp _reduce_op;
uint32_t _tag; uint32_t _tag;
gloo::AllreduceOptions::Func _get_function(const phi::DataType type,
const ReduceOp op) {
gloo::AllreduceOptions::Func fn;
GENERATE_FUNC(type, _get_function_impl, fn, op);
return fn;
}
template <typename T>
void _get_function_impl(gloo::AllreduceOptions::Func& fn, // NOLINT
const ReduceOp op) {
fn = get_function<T>(op);
}
void _do_allreduce(std::vector<phi::DenseTensor>& ins, // NOLINT void _do_allreduce(std::vector<phi::DenseTensor>& ins, // NOLINT
std::vector<phi::DenseTensor>& outs) { // NOLINT std::vector<phi::DenseTensor>& outs) { // NOLINT
const auto& dtype = ins[0].dtype(); _comm_context->AllReduce(
gloo::AllreduceOptions opts(_context); &(outs[0]), ins[0], static_cast<int>(_reduce_op), _tag);
GENERATE_FUNC(dtype, set_inputs, opts, ins);
GENERATE_FUNC(dtype, set_outputs, opts, outs);
opts.setReduceFunction(_get_function(dtype, _reduce_op));
opts.setTag(_tag);
gloo::allreduce(opts);
} }
}; };
...@@ -437,36 +370,33 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::AllReduce( ...@@ -437,36 +370,33 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::AllReduce(
bool sync_op) { bool sync_op) {
auto tag = next_tag(); auto tag = next_tag();
std::shared_ptr<GlooTask> task; std::shared_ptr<GlooTask> task;
auto context = get_context(); auto comm_context = this->GetCommContext();
task = std::make_shared<AllreduceGlooTask>( task = std::make_shared<AllreduceGlooTask>(
rank_, context, inputs, outputs, opts.reduce_op, tag); rank_, comm_context, inputs, outputs, opts.reduce_op, tag);
task->Run(); task->Run();
return task; return task;
} }
class BarrierGlooTask : public ProcessGroupGloo::GlooTask { class BarrierGlooTask : public ProcessGroupGloo::GlooTask {
public: public:
BarrierGlooTask(int rank, const std::shared_ptr<gloo::Context>& context) BarrierGlooTask(int rank, phi::distributed::GlooCommContext* comm_context)
: ProcessGroupGloo::GlooTask( : ProcessGroupGloo::GlooTask(
rank, std::vector<phi::DenseTensor>{}, CommType::BARRIER), rank, std::vector<phi::DenseTensor>{}, CommType::BARRIER),
_context(context) {} _comm_context(comm_context) {}
void Run() override { _do_barrier(); } void Run() override { _do_barrier(); }
private: private:
std::shared_ptr<gloo::Context> _context; phi::distributed::GlooCommContext* _comm_context;
void _do_barrier() { void _do_barrier() { _comm_context->Barrier(); }
gloo::BarrierOptions opts(_context);
gloo::barrier(opts);
}
}; };
std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Barrier( std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Barrier(
const BarrierOptions& opts) { const BarrierOptions& opts) {
std::shared_ptr<BarrierGlooTask> task; std::shared_ptr<BarrierGlooTask> task;
auto context = get_context(); auto comm_context = this->GetCommContext();
task = std::make_shared<BarrierGlooTask>(rank_, context); task = std::make_shared<BarrierGlooTask>(rank_, comm_context);
task->Run(); task->Run();
return task; return task;
} }
...@@ -474,12 +404,12 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Barrier( ...@@ -474,12 +404,12 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Barrier(
class AllgatherGlooTask : public ProcessGroupGloo::GlooTask { class AllgatherGlooTask : public ProcessGroupGloo::GlooTask {
public: public:
AllgatherGlooTask(int rank, AllgatherGlooTask(int rank,
const std::shared_ptr<gloo::Context>& context, phi::distributed::GlooCommContext* comm_context,
std::vector<phi::DenseTensor>& inputs, // NOLINT std::vector<phi::DenseTensor>& inputs, // NOLINT
std::vector<phi::DenseTensor>& outputs, // NOLINT std::vector<phi::DenseTensor>& outputs, // NOLINT
uint32_t tag) uint32_t tag)
: ProcessGroupGloo::GlooTask(rank, inputs, CommType::ALLGATHER), : ProcessGroupGloo::GlooTask(rank, inputs, CommType::ALLGATHER),
_context(context), _comm_context(comm_context),
_inputs(inputs), _inputs(inputs),
_outputs(outputs), _outputs(outputs),
_tag(tag) {} _tag(tag) {}
...@@ -487,19 +417,14 @@ class AllgatherGlooTask : public ProcessGroupGloo::GlooTask { ...@@ -487,19 +417,14 @@ class AllgatherGlooTask : public ProcessGroupGloo::GlooTask {
void Run() override { _do_allgather(_inputs, _outputs); } void Run() override { _do_allgather(_inputs, _outputs); }
private: private:
std::shared_ptr<gloo::Context> _context; phi::distributed::GlooCommContext* _comm_context;
std::vector<phi::DenseTensor> _inputs; std::vector<phi::DenseTensor> _inputs;
std::vector<phi::DenseTensor> _outputs; std::vector<phi::DenseTensor> _outputs;
uint32_t _tag; uint32_t _tag;
void _do_allgather(std::vector<phi::DenseTensor>& in, // NOLINT void _do_allgather(std::vector<phi::DenseTensor>& in, // NOLINT
std::vector<phi::DenseTensor>& out) { // NOLINT std::vector<phi::DenseTensor>& out) { // NOLINT
const auto& dtype = in[0].dtype(); _comm_context->AllGather(&(out[0]), in[0], _tag);
gloo::AllgatherOptions opts(_context);
GENERATE_FUNC(dtype, set_input, opts, in[0]);
GENERATE_FUNC(dtype, set_output, opts, out[0]);
opts.setTag(_tag);
gloo::allgather(opts);
} }
}; };
...@@ -526,9 +451,9 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::AllGather( ...@@ -526,9 +451,9 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::AllGather(
bool sync_op) { bool sync_op) {
std::shared_ptr<AllgatherGlooTask> task; std::shared_ptr<AllgatherGlooTask> task;
auto tag = next_tag(); auto tag = next_tag();
auto context = get_context(); auto comm_context = this->GetCommContext();
task = std::make_shared<AllgatherGlooTask>( task = std::make_shared<AllgatherGlooTask>(
rank_, context, in_tensors, out_tensors, tag); rank_, comm_context, in_tensors, out_tensors, tag);
task->Run(); task->Run();
return task; return task;
} }
...@@ -536,14 +461,14 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::AllGather( ...@@ -536,14 +461,14 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::AllGather(
class ReduceGlooTask : public ProcessGroupGloo::GlooTask { class ReduceGlooTask : public ProcessGroupGloo::GlooTask {
public: public:
ReduceGlooTask(int rank, ReduceGlooTask(int rank,
const std::shared_ptr<gloo::Context>& context, phi::distributed::GlooCommContext* comm_context,
std::vector<phi::DenseTensor>& inputs, // NOLINT std::vector<phi::DenseTensor>& inputs, // NOLINT
std::vector<phi::DenseTensor>& outputs, // NOLINT std::vector<phi::DenseTensor>& outputs, // NOLINT
ReduceOp reduce_op, ReduceOp reduce_op,
int dst, int dst,
uint32_t tag) uint32_t tag)
: ProcessGroupGloo::GlooTask(rank, inputs, CommType::REDUCE), : ProcessGroupGloo::GlooTask(rank, inputs, CommType::REDUCE),
_context(context), _comm_context(comm_context),
_inputs(inputs), _inputs(inputs),
_outputs(outputs), _outputs(outputs),
_reduce_op(reduce_op), _reduce_op(reduce_op),
...@@ -553,37 +478,18 @@ class ReduceGlooTask : public ProcessGroupGloo::GlooTask { ...@@ -553,37 +478,18 @@ class ReduceGlooTask : public ProcessGroupGloo::GlooTask {
void Run() override { _do_reduce(_inputs, _outputs, _dst); } void Run() override { _do_reduce(_inputs, _outputs, _dst); }
private: private:
std::shared_ptr<gloo::Context> _context; phi::distributed::GlooCommContext* _comm_context;
std::vector<phi::DenseTensor> _inputs; std::vector<phi::DenseTensor> _inputs;
std::vector<phi::DenseTensor> _outputs; std::vector<phi::DenseTensor> _outputs;
const ReduceOp _reduce_op; const ReduceOp _reduce_op;
int _dst; int _dst;
uint32_t _tag; uint32_t _tag;
gloo::ReduceOptions::Func _get_function(const phi::DataType type,
const ReduceOp op) {
gloo::ReduceOptions::Func fn;
GENERATE_FUNC(type, _get_function_impl, fn, op);
return fn;
}
template <typename T>
void _get_function_impl(gloo::ReduceOptions::Func& fn, // NOLINT
const ReduceOp op) {
fn = get_function<T>(op);
}
void _do_reduce(std::vector<phi::DenseTensor>& inputs, // NOLINT void _do_reduce(std::vector<phi::DenseTensor>& inputs, // NOLINT
std::vector<phi::DenseTensor>& outputs, // NOLINT std::vector<phi::DenseTensor>& outputs, // NOLINT
int dst) { int dst) {
const auto& dtype = inputs[0].dtype(); _comm_context->Reduce(
gloo::ReduceOptions opts(_context); &(outputs[0]), inputs[0], static_cast<int>(_reduce_op), _dst, _tag);
GENERATE_FUNC(dtype, set_input, opts, inputs[0]);
GENERATE_FUNC(dtype, set_output, opts, outputs[0]);
opts.setReduceFunction(_get_function(dtype, _reduce_op));
opts.setTag(_tag);
opts.setRoot(dst);
gloo::reduce(opts);
} }
}; };
...@@ -595,11 +501,11 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Reduce( ...@@ -595,11 +501,11 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Reduce(
) { ) {
std::shared_ptr<ReduceGlooTask> task; std::shared_ptr<ReduceGlooTask> task;
auto tag = next_tag(); auto tag = next_tag();
auto context = get_context(); auto comm_context = this->GetCommContext();
std::vector<phi::DenseTensor> in_wrapper{in_tensor}; std::vector<phi::DenseTensor> in_wrapper{in_tensor};
std::vector<phi::DenseTensor> out_wrapper{*out_tensor}; std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
task = std::make_shared<ReduceGlooTask>(rank_, task = std::make_shared<ReduceGlooTask>(rank_,
context, comm_context,
in_wrapper, in_wrapper,
out_wrapper, out_wrapper,
opts.reduce_op, opts.reduce_op,
...@@ -619,14 +525,14 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Reduce( ...@@ -619,14 +525,14 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Reduce(
class ScatterGlooTask : public ProcessGroupGloo::GlooTask { class ScatterGlooTask : public ProcessGroupGloo::GlooTask {
public: public:
ScatterGlooTask(int rank, ScatterGlooTask(int rank,
const std::shared_ptr<gloo::Context>& context, phi::distributed::GlooCommContext* comm_context,
std::vector<phi::DenseTensor>& inputs, // NOLINT std::vector<phi::DenseTensor>& inputs, // NOLINT
std::vector<phi::DenseTensor>& outputs, // NOLINT std::vector<phi::DenseTensor>& outputs, // NOLINT
int src, int src,
int size, int size,
uint32_t tag) uint32_t tag)
: ProcessGroupGloo::GlooTask(rank, inputs, CommType::SCATTER), : ProcessGroupGloo::GlooTask(rank, inputs, CommType::SCATTER),
_context(context), _comm_context(comm_context),
_inputs(inputs), _inputs(inputs),
_outputs(outputs), _outputs(outputs),
_src(src), _src(src),
...@@ -636,7 +542,7 @@ class ScatterGlooTask : public ProcessGroupGloo::GlooTask { ...@@ -636,7 +542,7 @@ class ScatterGlooTask : public ProcessGroupGloo::GlooTask {
void Run() override { _do_scatter(_inputs, _outputs, _src); } void Run() override { _do_scatter(_inputs, _outputs, _src); }
private: private:
std::shared_ptr<gloo::Context> _context; phi::distributed::GlooCommContext* _comm_context;
std::vector<phi::DenseTensor> _inputs; std::vector<phi::DenseTensor> _inputs;
std::vector<phi::DenseTensor> _outputs; std::vector<phi::DenseTensor> _outputs;
int _src; int _src;
...@@ -646,15 +552,7 @@ class ScatterGlooTask : public ProcessGroupGloo::GlooTask { ...@@ -646,15 +552,7 @@ class ScatterGlooTask : public ProcessGroupGloo::GlooTask {
void _do_scatter(std::vector<phi::DenseTensor>& in, // NOLINT void _do_scatter(std::vector<phi::DenseTensor>& in, // NOLINT
std::vector<phi::DenseTensor>& out, // NOLINT std::vector<phi::DenseTensor>& out, // NOLINT
int src) { int src) {
const auto& dtype = in[0].dtype(); _comm_context->Scatter(&(out[0]), in[0], _src, _size, _tag);
gloo::ScatterOptions opts(_context);
if (rank_ == src) {
GENERATE_FUNC(dtype, set_inputs_for_scatter, opts, in[0], _size);
}
GENERATE_FUNC(dtype, set_output, opts, out[0]);
opts.setRoot(src);
opts.setTag(_tag);
gloo::scatter(opts);
} }
}; };
...@@ -665,11 +563,11 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Scatter( ...@@ -665,11 +563,11 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Scatter(
bool sync_op) { bool sync_op) {
std::shared_ptr<ScatterGlooTask> task; std::shared_ptr<ScatterGlooTask> task;
auto tag = next_tag(); auto tag = next_tag();
auto context = get_context(); auto comm_context = this->GetCommContext();
std::vector<phi::DenseTensor> in_wrapper{in_tensor}; std::vector<phi::DenseTensor> in_wrapper{in_tensor};
std::vector<phi::DenseTensor> out_wrapper{*out_tensor}; std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
task = std::make_shared<ScatterGlooTask>( task = std::make_shared<ScatterGlooTask>(
rank_, context, in_wrapper, out_wrapper, opts.root_rank, size_, tag); rank_, comm_context, in_wrapper, out_wrapper, opts.root_rank, size_, tag);
task->Run(); task->Run();
return task; return task;
} }
...@@ -684,13 +582,13 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Scatter( ...@@ -684,13 +582,13 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Scatter(
class GatherGlooTask : public ProcessGroupGloo::GlooTask { class GatherGlooTask : public ProcessGroupGloo::GlooTask {
public: public:
GatherGlooTask(int rank, GatherGlooTask(int rank,
const std::shared_ptr<gloo::Context>& context, phi::distributed::GlooCommContext* comm_context,
const phi::DenseTensor& input, // NOLINT const phi::DenseTensor& input, // NOLINT
phi::DenseTensor* output, // NOLINT phi::DenseTensor* output, // NOLINT
int src, int src,
uint32_t tag) uint32_t tag)
: ProcessGroupGloo::GlooTask(rank, {input}, CommType::GATHER), : ProcessGroupGloo::GlooTask(rank, {input}, CommType::GATHER),
_context(context), _comm_context(comm_context),
_input(input), _input(input),
_output(*output), _output(*output),
_src(src), _src(src),
...@@ -699,7 +597,7 @@ class GatherGlooTask : public ProcessGroupGloo::GlooTask { ...@@ -699,7 +597,7 @@ class GatherGlooTask : public ProcessGroupGloo::GlooTask {
void Run() override { _do_gather(_input, _output, _src); } void Run() override { _do_gather(_input, _output, _src); }
private: private:
std::shared_ptr<gloo::Context> _context; phi::distributed::GlooCommContext* _comm_context;
phi::DenseTensor _input; phi::DenseTensor _input;
phi::DenseTensor _output; phi::DenseTensor _output;
int _src; int _src;
...@@ -708,16 +606,7 @@ class GatherGlooTask : public ProcessGroupGloo::GlooTask { ...@@ -708,16 +606,7 @@ class GatherGlooTask : public ProcessGroupGloo::GlooTask {
void _do_gather(phi::DenseTensor& in, // NOLINT void _do_gather(phi::DenseTensor& in, // NOLINT
phi::DenseTensor& out, // NOLINT phi::DenseTensor& out, // NOLINT
int src) { int src) {
const auto& dtype = in.dtype(); _comm_context->Gather(&(out), in, src, _tag);
gloo::GatherOptions opts(_context);
if (rank_ == src) {
GENERATE_FUNC(dtype, set_output, opts, out);
}
GENERATE_FUNC(dtype, set_input, opts, in);
opts.setRoot(src);
opts.setTag(_tag);
gloo::gather(opts);
} }
}; };
...@@ -733,9 +622,9 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Gather( ...@@ -733,9 +622,9 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Gather(
platform::errors::InvalidArgument("Gloo cannot use use_calc_stream.")); platform::errors::InvalidArgument("Gloo cannot use use_calc_stream."));
std::shared_ptr<GatherGlooTask> task; std::shared_ptr<GatherGlooTask> task;
auto tag = next_tag(); auto tag = next_tag();
auto context = get_context(); auto comm_context = this->GetCommContext();
task = std::make_shared<GatherGlooTask>( task = std::make_shared<GatherGlooTask>(
rank_, context, in_tensor, out_tensor, opts.root_rank, tag); rank_, comm_context, in_tensor, out_tensor, opts.root_rank, tag);
task->Run(); task->Run();
return task; return task;
} }
...@@ -804,11 +693,24 @@ std::shared_ptr<ProcessGroupGloo> ProcessGroupGloo::CreateProcessGroupGloo( ...@@ -804,11 +693,24 @@ std::shared_ptr<ProcessGroupGloo> ProcessGroupGloo::CreateProcessGroupGloo(
} else { } else {
opts->device = ProcessGroupGloo::createDefaultDevice(); opts->device = ProcessGroupGloo::createDefaultDevice();
} }
phi::distributed::CommContextManager::CreateGlooCommContext(
store, gid, rank, size);
auto process_group = auto process_group =
std::make_shared<ProcessGroupGloo>(store, rank, size, gid, opts); std::make_shared<ProcessGroupGloo>(store, rank, size, gid, opts);
ProcessGroupIdMap::GetInstance().emplace(gid, process_group); ProcessGroupIdMap::GetInstance().emplace(gid, process_group);
return process_group; return process_group;
} }
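The ordering above matters: the Gloo comm context is registered with the CommContextManager before the process group is constructed, because GetCommContext() below only looks the context up by group id and never creates it. A sketch of the pairing, using the same arguments as the calls above:

```cpp
// Sketch: register first, then construct. A collective issued on a group
// whose context was never created is expected to fail the check in
// GetCommContext().
phi::distributed::CommContextManager::CreateGlooCommContext(store, gid, rank, size);
auto pg = std::make_shared<ProcessGroupGloo>(store, rank, size, gid, opts);
// Later, inside any collective:
//   auto* comm_ctx = pg->GetCommContext();  // fetched by gid
```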
phi::distributed::GlooCommContext* ProcessGroupGloo::GetCommContext() {
const auto& comm_context_manager =
phi::distributed::CommContextManager::GetInstance();
auto comm_context = static_cast<phi::distributed::GlooCommContext*>(
comm_context_manager.Get(this->gid_));
PADDLE_ENFORCE_NE(comm_context,
nullptr,
phi::errors::Unavailable("GlooCommContext is nullptr"));
return comm_context;
}
} // namespace distributed } // namespace distributed
} // namespace paddle } // namespace paddle
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include "paddle/fluid/distributed/collective/process_group.h" #include "paddle/fluid/distributed/collective/process_group.h"
#include "paddle/fluid/distributed/collective/process_group_without_stream.h" #include "paddle/fluid/distributed/collective/process_group_without_stream.h"
#include "paddle/phi/core/distributed/gloo_comm_context.h"
#include "paddle/phi/core/distributed/store/store.h" #include "paddle/phi/core/distributed/store/store.h"
#include "paddle/phi/core/distributed/store/tcp_store.h" #include "paddle/phi/core/distributed/store/tcp_store.h"
...@@ -225,6 +226,8 @@ class ProcessGroupGloo : public ProcessGroupWithoutStream { ...@@ -225,6 +226,8 @@ class ProcessGroupGloo : public ProcessGroupWithoutStream {
return GetDeviceContext(place); return GetDeviceContext(place);
} }
phi::distributed::GlooCommContext* GetCommContext();
// Helper functions for Gloo. // Helper functions for Gloo.
static std::shared_ptr<::gloo::transport::Device> createDeviceForHostname( static std::shared_ptr<::gloo::transport::Device> createDeviceForHostname(
const std::string& hostname); const std::string& hostname);
......
...@@ -24,7 +24,7 @@ namespace paddle { ...@@ -24,7 +24,7 @@ namespace paddle {
namespace distributed { namespace distributed {
// TODO(shenliang03): To support AVG for reduce // TODO(shenliang03): To support AVG for reduce
enum class ReduceOp : std::uint8_t { SUM = 0, AVG, MAX, MIN, PRODUCT }; enum class ReduceOp : std::uint8_t { SUM = 0, MAX, MIN, PRODUCT, AVG };
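The enumerator order changes because ProcessGroupGloo now forwards the op as a plain integer (static_cast<int>(_reduce_op)) into GlooCommContext::AllReduce/Reduce, and the new order is assumed to line up with the reduce-type values expected on the phi side (SUM=0, MAX=1, MIN=2, PRODUCT=3, AVG=4). A compile-time check of the new layout, as a sketch:

```cpp
// Sketch: pin down the integer values the static_cast in AllreduceGlooTask /
// ReduceGlooTask relies on (agreement with phi's reduce type is an
// assumption, not shown in this diff).
static_assert(static_cast<int>(ReduceOp::SUM) == 0, "SUM must stay 0");
static_assert(static_cast<int>(ReduceOp::MAX) == 1, "MAX must be 1");
static_assert(static_cast<int>(ReduceOp::MIN) == 2, "MIN must be 2");
static_assert(static_cast<int>(ReduceOp::PRODUCT) == 3, "PRODUCT must be 3");
static_assert(static_cast<int>(ReduceOp::AVG) == 4, "AVG must be 4");
```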
struct AllreduceOptions { struct AllreduceOptions {
ReduceOp reduce_op = ReduceOp::SUM; ReduceOp reduce_op = ReduceOp::SUM;
......
...@@ -17,8 +17,11 @@ ...@@ -17,8 +17,11 @@
#include <gloo/allgather.h> #include <gloo/allgather.h>
#include <gloo/allreduce.h> #include <gloo/allreduce.h>
#include <gloo/barrier.h>
#include <gloo/broadcast.h> #include <gloo/broadcast.h>
#include <gloo/gather.h>
#include <gloo/reduce.h> #include <gloo/reduce.h>
#include <gloo/scatter.h>
#include <gloo/types.h> #include <gloo/types.h>
#include "paddle/phi/common/data_type.h" #include "paddle/phi/common/data_type.h"
...@@ -41,7 +44,8 @@ GlooCommContext::GlooCommContext( ...@@ -41,7 +44,8 @@ GlooCommContext::GlooCommContext(
void GlooCommContext::Broadcast(phi::DenseTensor* out_tensor, void GlooCommContext::Broadcast(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor, const phi::DenseTensor& in_tensor,
int root) { int root,
uint32_t tag) {
// gloo only uses CPU now // gloo only uses CPU now
CommStaticCheck::SameShape(*out_tensor, CommStaticCheck::SameShape(*out_tensor,
in_tensor, in_tensor,
...@@ -56,15 +60,18 @@ void GlooCommContext::Broadcast(phi::DenseTensor* out_tensor, ...@@ -56,15 +60,18 @@ void GlooCommContext::Broadcast(phi::DenseTensor* out_tensor,
GENERATE_FUNC(dtype, SetInput, &opts, in_tensor); GENERATE_FUNC(dtype, SetInput, &opts, in_tensor);
} }
opts.setRoot(root); opts.setRoot(root);
opts.setTag(tag);
gloo::broadcast(opts); gloo::broadcast(opts);
} }
void GlooCommContext::AllGather(phi::DenseTensor* out_tensor, void GlooCommContext::AllGather(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor) { const phi::DenseTensor& in_tensor,
uint32_t tag) {
// gloo only uses CPU now // gloo only uses CPU now
gloo::AllgatherOptions opts(gloo_context_); gloo::AllgatherOptions opts(gloo_context_);
const auto& dtype = in_tensor.dtype(); const auto& dtype = in_tensor.dtype();
opts.setTag(tag);
GENERATE_FUNC(dtype, SetInput, &opts, in_tensor); GENERATE_FUNC(dtype, SetInput, &opts, in_tensor);
GENERATE_FUNC(dtype, SetOutput, &opts, out_tensor); GENERATE_FUNC(dtype, SetOutput, &opts, out_tensor);
gloo::allgather(opts); gloo::allgather(opts);
...@@ -72,8 +79,10 @@ void GlooCommContext::AllGather(phi::DenseTensor* out_tensor, ...@@ -72,8 +79,10 @@ void GlooCommContext::AllGather(phi::DenseTensor* out_tensor,
void GlooCommContext::AllReduce(phi::DenseTensor* out_tensor, void GlooCommContext::AllReduce(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor, const phi::DenseTensor& in_tensor,
int reduce_type) { int reduce_type,
uint32_t tag) {
gloo::AllreduceOptions opts(gloo_context_); gloo::AllreduceOptions opts(gloo_context_);
opts.setTag(tag);
const auto& dtype = in_tensor.dtype(); const auto& dtype = in_tensor.dtype();
GENERATE_FUNC(dtype, SetInput, &opts, in_tensor); GENERATE_FUNC(dtype, SetInput, &opts, in_tensor);
GENERATE_FUNC(dtype, SetOutput, &opts, out_tensor); GENERATE_FUNC(dtype, SetOutput, &opts, out_tensor);
...@@ -84,9 +93,11 @@ void GlooCommContext::AllReduce(phi::DenseTensor* out_tensor, ...@@ -84,9 +93,11 @@ void GlooCommContext::AllReduce(phi::DenseTensor* out_tensor,
void GlooCommContext::Reduce(phi::DenseTensor* out_tensor, void GlooCommContext::Reduce(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor, const phi::DenseTensor& in_tensor,
int reduce_type, int reduce_type,
int root) { int root,
uint32_t tag) {
gloo::ReduceOptions opts(gloo_context_); gloo::ReduceOptions opts(gloo_context_);
opts.setRoot(root); opts.setRoot(root);
opts.setTag(tag);
const auto& dtype = in_tensor.dtype(); const auto& dtype = in_tensor.dtype();
GENERATE_FUNC(dtype, SetInput, &opts, in_tensor); GENERATE_FUNC(dtype, SetInput, &opts, in_tensor);
GENERATE_FUNC(dtype, SetOutput, &opts, out_tensor); GENERATE_FUNC(dtype, SetOutput, &opts, out_tensor);
...@@ -94,5 +105,65 @@ void GlooCommContext::Reduce(phi::DenseTensor* out_tensor, ...@@ -94,5 +105,65 @@ void GlooCommContext::Reduce(phi::DenseTensor* out_tensor,
gloo::reduce(opts); gloo::reduce(opts);
} }
void GlooCommContext::Gather(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
int src,
uint32_t tag) {
gloo::GatherOptions opts(gloo_context_);
const auto& dtype = in_tensor.dtype();
opts.setTag(tag);
opts.setRoot(src);
GENERATE_FUNC(dtype, SetInput, &opts, in_tensor);
if (rank_ == src) {
GENERATE_FUNC(dtype, SetOutput, &opts, out_tensor);
}
gloo::gather(opts);
}
void GlooCommContext::Scatter(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
int src,
int size,
uint32_t tag) {
gloo::ScatterOptions opts(gloo_context_);
const auto& dtype = in_tensor.dtype();
if (rank_ == src) {
GENERATE_FUNC(dtype, SetInputForScatter, &opts, in_tensor, size);
}
GENERATE_FUNC(dtype, SetOutput, &opts, out_tensor);
opts.setRoot(src);
opts.setTag(tag);
gloo::scatter(opts);
}
void GlooCommContext::Barrier() {
gloo::BarrierOptions opts(gloo_context_);
gloo::barrier(opts);
}
void GlooCommContext::Send(const phi::DenseTensor& in_tensor,
int dst,
uint32_t tag) {
SendRecvOptions opts(gloo_context_);
const auto& dtype = in_tensor.dtype();
GENERATE_FUNC(dtype, SetInput, &opts, in_tensor);
opts.setSrc(gloo_context_.get()->rank);
opts.setDst(dst);
opts.setTag(tag);
send_recv(&opts);
}
void GlooCommContext::Recv(phi::DenseTensor* out_tensor,
int src,
uint32_t tag) {
SendRecvOptions opts(gloo_context_);
const auto& dtype = out_tensor->dtype();
GENERATE_FUNC(dtype, SetOutput, &opts, out_tensor);
opts.setSrc(src);
opts.setDst(gloo_context_.get()->rank);
opts.setTag(tag);
send_recv(&opts);
}
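Together with the existing Broadcast/AllReduce/Reduce/AllGather, these additions make GlooCommContext the single entry point that ProcessGroupGloo drives. A condensed usage sketch; each line stands on its own rather than forming one runnable sequence, the pointer and tensors are illustrative, and reduce_type 0 is assumed to mean SUM:

```cpp
// Sketch: the calls ProcessGroupGloo issues, one per collective.
void DriveGlooSketch(phi::distributed::GlooCommContext* comm_ctx,
                     phi::DenseTensor* out, const phi::DenseTensor& in,
                     int rank, int nranks, uint32_t tag) {
  comm_ctx->Broadcast(out, in, /*root=*/0, tag);
  comm_ctx->AllReduce(out, in, /*reduce_type=*/0, tag);  // 0 assumed to be SUM
  comm_ctx->AllGather(out, in, tag);                     // out holds nranks chunks
  comm_ctx->Gather(out, in, /*src=*/0, tag);             // only the root fills out
  comm_ctx->Scatter(out, in, /*src=*/0, nranks, tag);    // only the root provides in
  if (rank == 0) {
    comm_ctx->Send(in, /*dst=*/1, tag);
  } else if (rank == 1) {
    comm_ctx->Recv(out, /*src=*/0, tag);
  }
  comm_ctx->Barrier();
}
```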
} // namespace distributed } // namespace distributed
} // namespace phi } // namespace phi
...@@ -35,17 +35,38 @@ class GlooCommContext final : public CommContext { ...@@ -35,17 +35,38 @@ class GlooCommContext final : public CommContext {
void Broadcast(phi::DenseTensor* out_tensor, void Broadcast(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor, const phi::DenseTensor& in_tensor,
int root); int root,
uint32_t tag = 0);
void AllReduce(phi::DenseTensor* out_tensor, void AllReduce(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor, const phi::DenseTensor& in_tensor,
int reduce_type); int reduce_type,
uint32_t tag = 0);
void Reduce(phi::DenseTensor* out_tensor, void Reduce(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor, const phi::DenseTensor& in_tensor,
int reduce_type, int reduce_type,
int root); int root,
uint32_t tag = 0);
void AllGather(phi::DenseTensor* out_tensor, void AllGather(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor); const phi::DenseTensor& in_tensor,
uint32_t tag = 0);
void Gather(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
int src,
uint32_t tag = 0);
void Scatter(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
int src,
int size,
uint32_t tag = 0);
void Barrier();
void Send(const phi::DenseTensor& in_tensor, int dst, uint32_t tag = 0);
void Recv(phi::DenseTensor* out_tensor, int src, uint32_t tag = 0);
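Every new tag parameter defaults to 0, so pre-existing callers of Broadcast/AllReduce/Reduce/AllGather elsewhere in phi (assumed untouched by this patch) keep compiling, while ProcessGroupGloo passes its per-operation tag explicitly:

```cpp
// Sketch: both call styles coexist after this change.
comm_ctx->AllReduce(&out, in, /*reduce_type=*/0);       // old style, tag defaults to 0
comm_ctx->AllReduce(&out, in, /*reduce_type=*/0, tag);  // new tagged style
```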
private: private:
DISABLE_COPY_AND_ASSIGN(GlooCommContext); DISABLE_COPY_AND_ASSIGN(GlooCommContext);
......
...@@ -91,5 +91,20 @@ std::shared_ptr<gloo::transport::Device> CreateGlooDevice() { ...@@ -91,5 +91,20 @@ std::shared_ptr<gloo::transport::Device> CreateGlooDevice() {
} }
} }
void send_recv(SendRecvOptions* opts) {
const auto& context = opts->context;
gloo::transport::UnboundBuffer* in = opts->in.get();
gloo::transport::UnboundBuffer* out = opts->out.get();
const auto slot = gloo::Slot::build(kSendRecvSlotPrefix, opts->tag);
if (context->rank == opts->src) {
in->send(opts->dst, slot);
in->waitSend(opts->timeout);
} else if (context->rank == opts->dst) {
out->recv(opts->src, slot);
out->waitRecv(opts->timeout);
}
}
} // namespace distributed } // namespace distributed
} // namespace phi } // namespace phi
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#pragma once #pragma once
#include <gloo/allreduce.h>
#include <gloo/math.h> #include <gloo/math.h>
#include <gloo/transport/tcp/device.h> #include <gloo/transport/tcp/device.h>
#include <gloo/types.h> #include <gloo/types.h>
...@@ -103,6 +104,19 @@ void SetInput(P* opts, const phi::DenseTensor& tensor) { ...@@ -103,6 +104,19 @@ void SetInput(P* opts, const phi::DenseTensor& tensor) {
tensor.numel()); tensor.numel());
} }
template <typename T, typename P>
void SetInputForScatter(P* opts, const phi::DenseTensor& tensor, int nranks) {
std::vector<T*> ret;
ret.reserve(nranks);
T* raw_pointer = reinterpret_cast<T*>(const_cast<void*>(tensor.data()));
size_t offset = 0;
for (int i = 0; i < nranks; i++) {
ret.push_back(raw_pointer + offset);
offset += tensor.numel() / nranks;
}
opts->setInputs(ret, tensor.numel() / nranks);
}
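SetInputForScatter slices the root's tensor into nranks equal chunks by pointer arithmetic, so tensor.numel() is implicitly assumed to be divisible by nranks; otherwise the trailing elements would be dropped silently. A minimal usage sketch with illustrative names:

```cpp
// Sketch: an 8-element float tensor scattered across 4 ranks becomes four
// 2-element inputs on the root.
gloo::ScatterOptions opts(gloo_context);  // gloo_context: shared_ptr<gloo::Context>
SetInputForScatter<float>(&opts, in_tensor, /*nranks=*/4);
```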
template <typename T, typename P> template <typename T, typename P>
void SetReduceFunc(P* opts, int reduce_type) { void SetReduceFunc(P* opts, int reduce_type) {
// gloo only support mutable data input // gloo only support mutable data input
...@@ -136,5 +150,55 @@ void SetReduceFunc(P* opts, int reduce_type) { ...@@ -136,5 +150,55 @@ void SetReduceFunc(P* opts, int reduce_type) {
// env preparation // env preparation
std::shared_ptr<gloo::transport::Device> CreateGlooDevice(); std::shared_ptr<gloo::transport::Device> CreateGlooDevice();
constexpr uint8_t kSendRecvSlotPrefix = 0x08;
class SendRecvOptions {
public:
explicit SendRecvOptions(const std::shared_ptr<gloo::Context>& context)
: context(context), timeout(context->getTimeout()) {}
template <typename T>
void setInput(T* ptr, size_t elements) {
this->in = context->createUnboundBuffer(ptr, elements * sizeof(T));
}
template <typename T>
void setOutput(T* ptr, size_t elements) {
this->out = context->createUnboundBuffer(ptr, elements * sizeof(T));
}
void setSrc(int src) { this->src = src; }
void setDst(int dst) { this->dst = dst; }
void setTag(uint32_t tag) { this->tag = tag; }
void setTimeout(std::chrono::milliseconds timeout) {
this->timeout = timeout;
}
protected:
std::shared_ptr<gloo::Context> context;
std::unique_ptr<gloo::transport::UnboundBuffer> in;
std::unique_ptr<gloo::transport::UnboundBuffer> out;
// Rank of process to send_recv from.
int src = -1;
// Rank of process to send_recv to.
int dst = -1;
// Tag for this operation.
// Must be unique across operations executing in parallel.
uint32_t tag = 0;
// End-to-end timeout for this operation.
std::chrono::milliseconds timeout;
friend void send_recv(SendRecvOptions*);
};
void send_recv(SendRecvOptions* opts);
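SendRecvOptions plus send_recv implement point-to-point transfer on top of gloo's unbound buffers; the slot built from kSendRecvSlotPrefix and the tag is what pairs a send with its matching recv, so both ranks must agree on src, dst and tag. A minimal two-rank sketch with illustrative buffer names:

```cpp
// Sketch: rank 0 sends n floats to rank 1 through the new helper.
void P2PSketch(const std::shared_ptr<gloo::Context>& ctx,
               float* send_buf, float* recv_buf, size_t n, uint32_t tag) {
  SendRecvOptions opts(ctx);
  if (ctx->rank == 0) {
    opts.setInput(send_buf, n);   // sender registers an unbound send buffer
  } else if (ctx->rank == 1) {
    opts.setOutput(recv_buf, n);  // receiver registers an unbound recv buffer
  }
  opts.setSrc(0);
  opts.setDst(1);
  opts.setTag(tag);
  send_recv(&opts);  // ranks other than src and dst fall through as a no-op
}
```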
} // namespace distributed } // namespace distributed
} // namespace phi } // namespace phi
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
import os import os
import test_collective_api_base as test_base import legacy_test.test_collective_api_base as test_base
import paddle import paddle
import paddle.distributed as dist import paddle.distributed as dist
......
...@@ -12,7 +12,10 @@ ...@@ -12,7 +12,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from test_collective_api_base import TestCollectiveAPIRunnerBase, runtime_main from legacy_test.test_collective_api_base import (
TestCollectiveAPIRunnerBase,
runtime_main,
)
import paddle import paddle
import paddle.distributed as dist import paddle.distributed as dist
......