Unverified commit 1e56ca8a authored by lilong12, committed by GitHub

Use DenseTensor instead of Tensor for ProcessGroup (#41403)

Parent 1cdd88f6
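
This commit migrates the ProcessGroup collective interface from paddle::experimental::Tensor to phi::DenseTensor and splits the single in-place tensor list into explicit input and output lists. The sketch below is not part of the patch; the names allreduce_example, pg, in and out are illustrative assumptions. It shows how a caller of AllReduce might look under the new signature:

// Minimal sketch, assuming a constructed ProcessGroup `pg` and two
// pre-allocated phi::DenseTensor objects; illustrative only, not from the diff.
#include <memory>
#include <vector>
#include "paddle/fluid/distributed/collective/ProcessGroup.h"
#include "paddle/phi/core/dense_tensor.h"

void allreduce_example(paddle::distributed::ProcessGroup& pg,
                       phi::DenseTensor& in, phi::DenseTensor& out) {
  // Before this commit: a single std::vector<Tensor> was passed and reduced
  // in place, e.g. pg.AllReduce(tensors, opts).
  // After this commit: separate DenseTensor input/output vectors.
  std::vector<phi::DenseTensor> inputs = {in};
  std::vector<phi::DenseTensor> outputs = {out};
  paddle::distributed::AllreduceOptions opts;
  auto task = pg.AllReduce(inputs, outputs, opts);
  task->Wait();  // block until the collective completes
}

Broadcast-style collectives follow the same pattern in the diff below: the root rank reads from the input list while every rank writes into the output list.
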
......@@ -17,11 +17,11 @@
namespace paddle {
namespace distributed {
std::vector<Place> GetPlaceList(const std::vector<Tensor>& tensors) {
std::vector<Place> GetPlaceList(const std::vector<phi::DenseTensor>& tensors) {
std::vector<Place> places;
places.reserve(tensors.size());
for (auto& tensor : tensors) {
places.push_back(tensor.inner_place());
places.push_back(tensor.place());
}
return places;
}
......@@ -40,15 +40,11 @@ std::string GetKeyFromPlaces(const std::vector<Place>& places) {
return placeList;
}
static bool CheckTensorsInPlace(const std::vector<Tensor>& tensors,
phi::AllocationType type) {
return std::all_of(tensors.cbegin(), tensors.cend(), [&](const Tensor& t) {
return t.place().GetType() == type;
});
}
bool CheckTensorsInCudaPlace(const std::vector<Tensor>& tensors) {
return CheckTensorsInPlace(tensors, phi::AllocationType::GPU);
bool CheckTensorsInCudaPlace(const std::vector<phi::DenseTensor>& tensors) {
return std::all_of(tensors.cbegin(), tensors.cend(),
[&](const phi::DenseTensor& t) {
return platform::is_gpu_place(t.place());
});
}
} // namespace distributed
......
......@@ -16,18 +16,18 @@
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/api/include/api.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/dense_tensor.h"
namespace paddle {
namespace distributed {
using Tensor = paddle::experimental::Tensor;
using Place = paddle::platform::Place;
// Get the list of devices from list of tensors
std::vector<Place> GetPlaceList(const std::vector<Tensor>& tensors);
std::vector<Place> GetPlaceList(const std::vector<phi::DenseTensor>& tensors);
// Get the deviceList String from the list of devices
std::string GetKeyFromPlaces(const std::vector<Place>& places);
bool CheckTensorsInCudaPlace(const std::vector<Tensor>& tensors);
bool CheckTensorsInCudaPlace(const std::vector<phi::DenseTensor>& tensors);
} // namespace distributed
} // namespace paddle
......@@ -17,7 +17,8 @@
namespace paddle {
namespace distributed {
ProcessGroup::Task::Task(int rank, const std::vector<Tensor>& inputTensors,
ProcessGroup::Task::Task(int rank,
const std::vector<phi::DenseTensor>& inputTensors,
CommType comm_type)
: rank_(rank), comm_type_(comm_type) {}
......
......@@ -54,7 +54,7 @@ class ProcessGroup {
public:
class Task {
public:
Task(int rank, const std::vector<Tensor>& inputTensors,
Task(int rank, const std::vector<phi::DenseTensor>& inputTensors,
CommType opType = CommType::UNKNOWN);
virtual ~Task();
......@@ -79,25 +79,21 @@ class ProcessGroup {
virtual const std::string GetBackendName() const = 0;
virtual std::shared_ptr<ProcessGroup::Task> AllReduce(
std::vector<Tensor>& /* tensors */,
std::vector<phi::DenseTensor>& /* input tensors */, // NOLINT
std::vector<phi::DenseTensor>& /* output tensors */, // NOLINT
const AllreduceOptions& = AllreduceOptions()) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support allreduce", GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> Broadcast(
std::vector<Tensor>& /* tensors */,
std::vector<phi::DenseTensor>& /* input tensors */, // NOLINT
std::vector<phi::DenseTensor>& /* output tensors */, // NOLINT
const BroadcastOptions& = BroadcastOptions()) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support broadcast", GetBackendName()));
}
virtual void Broadcast(const phi::DenseTensor* in, phi::DenseTensor* out) {
PADDLE_THROW(platform::errors::Fatal(
"ProcessGroup%s does not support broadcast for static mode runtime",
GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> Barrier(
const BarrierOptions& = BarrierOptions()) {
PADDLE_THROW(platform::errors::InvalidArgument(
......@@ -105,42 +101,43 @@ class ProcessGroup {
}
virtual std::shared_ptr<ProcessGroup::Task> Send(
std::vector<Tensor>& tensors /* tensors */, int dst_rank) { // NOLINT
std::vector<phi::DenseTensor>&, int) { // NOLINT
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support send", GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> Recv(
std::vector<Tensor>& tensors /* tensors */, int src_rank) { // NOLINT
std::vector<phi::DenseTensor>& tensors, int) { // NOLINT
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support receive", GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> AllGather(
std::vector<Tensor>& in_tensors /* tensors */, // NOLINT
std::vector<Tensor>& out_tensors /* tensors */) { // NOLINT
std::vector<phi::DenseTensor>&, // NOLINT
std::vector<phi::DenseTensor>&) { // NOLINT
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support AllGather", GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> AllToAll(
std::vector<Tensor>& in /* tensors */, // NOLINT
std::vector<Tensor>& out /* tensors */) { // NOLINT
std::vector<phi::DenseTensor>&, // NOLINT
std::vector<phi::DenseTensor>&) { // NOLINT
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support AllToAll", GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> Reduce(
std::vector<Tensor>& tensors /* tensors */, // NOLINT
const ReduceOptions& opts) { // NOLINT
std::vector<phi::DenseTensor>&, // NOLINT
std::vector<phi::DenseTensor>&, // NOLINT
const ReduceOptions& opts) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support Reduce", GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> Scatter(
std::vector<Tensor>& in_tensors /* tensors */, // NOLINT
std::vector<Tensor>& out_tensors /* tensors */, // NOLINT
const ScatterOptions&) { // NOLINT
std::vector<phi::DenseTensor>&, // NOLINT
std::vector<phi::DenseTensor>&, // NOLINT
const ScatterOptions&) { // NOLINT
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support Scatter", GetBackendName()));
}
......
......@@ -27,6 +27,7 @@
#include <gloo/broadcast.h>
#include <gloo/reduce.h>
#include <gloo/scatter.h>
#include "paddle/fluid/distributed/collective/Common.h"
#include "paddle/fluid/distributed/collective/ProcessGroupGloo.h"
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#include "paddle/fluid/platform/enforce.h"
......@@ -105,107 +106,104 @@ reduce_func get_function(const ReduceOp& r) {
exit(-1);
}
bool CheckTensorsInCPUPlace(const std::vector<Tensor>& tensors) {
return std::all_of(tensors.cbegin(), tensors.cend(), [&](const Tensor& t) {
return t.place() == PlaceType::kCPU;
});
}
template <typename T>
T* get_data(const Tensor& tensor) {
auto raw_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
return static_cast<T*>(raw_tensor->data());
T* get_data(phi::DenseTensor& tensor) { // NOLINT
return reinterpret_cast<T*>(tensor.data());
}
template <typename T>
std::vector<T*> get_multi_data(const std::vector<Tensor>& tensors) {
std::vector<T*> ret(tensors.size());
std::vector<T*> get_multi_data(
std::vector<phi::DenseTensor>& tensors) { // NOLINT
std::vector<T*> ret;
ret.reserve(tensors.size());
for (size_t i = 0; i < tensors.size(); i++) {
ret[i] = get_data<T>(tensors[i]);
ret.push_back(get_data<T>(tensors[i]));
}
return ret;
}
template <typename T, typename P>
void set_output(P& opts, const Tensor& tensor) { // NOLINT
void set_output(P& opts, phi::DenseTensor& tensor) { // NOLINT
opts.setOutput(get_data<T>(tensor), tensor.numel());
}
template <typename T, typename P>
void set_input(P& opts, const Tensor& tensor) { // NOLINT
void set_input(P& opts, phi::DenseTensor& tensor) { // NOLINT
opts.setInput(get_data<T>(tensor), tensor.numel());
}
template <typename T, typename P>
void set_outputs(P& opts, const std::vector<Tensor>& tensors) { // NOLINT
void set_outputs(P& opts, // NOLINT
std::vector<phi::DenseTensor>& tensors) { // NOLINT
opts.setOutputs(get_multi_data<T>(tensors), tensors[0].numel());
}
template <typename T, typename P>
void set_inputs(P& opts, const std::vector<Tensor>& tensors) { // NOLINT
void set_inputs(P& opts, // NOLINT
std::vector<phi::DenseTensor>& tensors) { // NOLINT
opts.setInputs(get_multi_data<T>(tensors), tensors[0].numel());
}
template <typename T, typename P>
void set_inputs_for_scatter(P& opts, // NOLINT
const std::vector<Tensor>& tensors, // NOLINT
void set_inputs_for_scatter(P& opts, // NOLINT
phi::DenseTensor& tensor, // NOLINT
int nranks) {
std::vector<T*> ret(nranks);
auto raw_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(tensors[0].impl());
T* raw_pointer = reinterpret_cast<T*>(raw_tensor->data());
std::vector<T*> ret;
ret.reserve(nranks);
T* raw_pointer = reinterpret_cast<T*>(tensor.data());
size_t offset = 0;
for (int i = 0; i < nranks; i++) {
ret[i] = raw_pointer + offset;
offset += tensors[0].numel() / nranks;
ret.push_back(raw_pointer + offset);
offset += tensor.numel() / nranks;
}
opts.setInputs(ret, tensors[0].numel() / nranks);
opts.setInputs(ret, tensor.numel() / nranks);
}
ProcessGroupGloo::GlooTask::GlooTask(int rank,
const std::vector<Tensor>& inputs,
CommType comm_type)
: ProcessGroup::Task(rank, inputs, comm_type) {
PADDLE_ENFORCE_EQ(CheckTensorsInCPUPlace(inputs), true,
platform::errors::Fatal(
"Only CPU place is supported for ProcessGroupGloo."));
}
ProcessGroupGloo::GlooTask::GlooTask(
int rank, const std::vector<phi::DenseTensor>& inputs, CommType comm_type)
: ProcessGroup::Task(rank, inputs, comm_type) {}
ProcessGroupGloo::ProcessGroupGloo(
const std::shared_ptr<paddle::distributed::Store>& store, int rank,
int world_size, int gid, const std::shared_ptr<GlooOptions> options)
const std::shared_ptr<distributed::Store>& store, int rank, int world_size,
int gid, const std::shared_ptr<GlooOptions> options)
: ProcessGroup(rank, world_size, gid),
_tag(0),
_store(new GlooStore(store)) {
_context = std::make_shared<gloo::rendezvous::Context>(rank, world_size);
auto prefix_store =
::gloo::rendezvous::PrefixStore(std::to_string(0), *_store);
::gloo::rendezvous::PrefixStore(std::to_string(gid), *_store);
_context->connectFullMesh(prefix_store, options->device);
}
class BroadcastGlooTask : public ProcessGroupGloo::GlooTask {
public:
BroadcastGlooTask(const std::shared_ptr<gloo::Context>& context,
const std::vector<Tensor>& inputs, int rank, int root,
uint32_t tag)
std::vector<phi::DenseTensor>& inputs, // NOLINT
std::vector<phi::DenseTensor>& outputs, // NOLINT
int rank, int root, uint32_t tag)
: ProcessGroupGloo::GlooTask(rank, inputs, CommType::BROADCAST),
_context(context),
_root(root),
_inputs(inputs),
_outputs(outputs),
_tag(tag) {}
void Run() override { _do_broadcast(_inputs[0]); }
void Run() override { _do_broadcast(_inputs[0], _outputs[0]); }
private:
std::shared_ptr<gloo::Context> _context;
const int _root;
std::vector<Tensor> _inputs{};
std::vector<phi::DenseTensor> _inputs{};
std::vector<phi::DenseTensor> _outputs{};
const uint32_t _tag;
void _do_broadcast(const Tensor& tensor) {
void _do_broadcast(phi::DenseTensor& in, phi::DenseTensor& out) { // NOLINT
gloo::BroadcastOptions opts(_context);
const auto& dtype = tensor.type();
GENERATE_FUNC(dtype, set_output, opts, tensor);
const auto& dtype = in.dtype();
if (rank_ == _root) {
GENERATE_FUNC(dtype, set_input, opts, in);
}
GENERATE_FUNC(dtype, set_output, opts, out);
opts.setRoot(_root);
opts.setTag(_tag);
gloo::broadcast(opts);
......@@ -213,12 +211,14 @@ class BroadcastGlooTask : public ProcessGroupGloo::GlooTask {
};
std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Broadcast(
std::vector<Tensor>& inputs, const BroadcastOptions& opts) {
std::vector<phi::DenseTensor>& inputs,
std::vector<phi::DenseTensor>& outputs, const BroadcastOptions& opts) {
auto root = opts.source_rank;
std::unique_ptr<BroadcastGlooTask> task;
auto tag = next_tag();
auto context = get_context();
task = std::make_unique<BroadcastGlooTask>(context, inputs, rank_, root, tag);
task = std::make_unique<BroadcastGlooTask>(context, inputs, outputs, rank_,
root, tag);
task->Run();
return task;
}
......@@ -226,19 +226,22 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Broadcast(
class AllreduceGlooTask : public ProcessGroupGloo::GlooTask {
public:
AllreduceGlooTask(int rank, const std::shared_ptr<gloo::Context>& context,
std::vector<Tensor>& inputs, ReduceOp reduce_op, // NOLINT
uint32_t tag)
std::vector<phi::DenseTensor>& inputs, // NOLINT
std::vector<phi::DenseTensor>& outputs, // NOLINT
ReduceOp reduce_op, uint32_t tag)
: ProcessGroupGloo::GlooTask(rank, inputs, CommType::ALLREDUCE),
_context(context),
_inputs(inputs),
_outputs(outputs),
_reduce_op(reduce_op),
_tag(tag) {}
void Run() override { _do_allreduce(_inputs); }
void Run() override { _do_allreduce(_inputs, _outputs); }
private:
std::shared_ptr<gloo::Context> _context;
std::vector<Tensor> _inputs;
std::vector<phi::DenseTensor> _inputs;
std::vector<phi::DenseTensor> _outputs;
const ReduceOp _reduce_op;
uint32_t _tag;
......@@ -255,11 +258,12 @@ class AllreduceGlooTask : public ProcessGroupGloo::GlooTask {
fn = get_function<T>(op);
}
void _do_allreduce(std::vector<Tensor>& tensors) { // NOLINT
const auto& dtype = tensors[0].type();
void _do_allreduce(std::vector<phi::DenseTensor>& ins, // NOLINT
std::vector<phi::DenseTensor>& outs) { // NOLINT
const auto& dtype = ins[0].dtype();
gloo::AllreduceOptions opts(_context);
GENERATE_FUNC(dtype, set_inputs, opts, tensors);
GENERATE_FUNC(dtype, set_outputs, opts, tensors);
GENERATE_FUNC(dtype, set_inputs, opts, ins);
GENERATE_FUNC(dtype, set_outputs, opts, outs);
opts.setReduceFunction(_get_function(dtype, _reduce_op));
opts.setTag(_tag);
gloo::allreduce(opts);
......@@ -267,11 +271,12 @@ class AllreduceGlooTask : public ProcessGroupGloo::GlooTask {
};
std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::AllReduce(
std::vector<Tensor>& inputs, const AllreduceOptions& opts) {
std::vector<phi::DenseTensor>& inputs,
std::vector<phi::DenseTensor>& outputs, const AllreduceOptions& opts) {
auto tag = next_tag();
std::shared_ptr<GlooTask> task;
auto context = get_context();
task = std::make_shared<AllreduceGlooTask>(rank_, context, inputs,
task = std::make_shared<AllreduceGlooTask>(rank_, context, inputs, outputs,
opts.reduce_op, tag);
task->Run();
return task;
......@@ -280,7 +285,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::AllReduce(
class BarrierGlooTask : public ProcessGroupGloo::GlooTask {
public:
BarrierGlooTask(int rank, const std::shared_ptr<gloo::Context>& context)
: ProcessGroupGloo::GlooTask(rank, std::vector<Tensor>{},
: ProcessGroupGloo::GlooTask(rank, std::vector<phi::DenseTensor>{},
CommType::BARRIER),
_context(context) {}
......@@ -307,8 +312,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Barrier(
class AllgatherGlooTask : public ProcessGroupGloo::GlooTask {
public:
AllgatherGlooTask(int rank, const std::shared_ptr<gloo::Context>& context,
std::vector<Tensor>& inputs, // NOLINT
std::vector<Tensor>& outputs, // NOLINT
std::vector<phi::DenseTensor>& inputs, // NOLINT
std::vector<phi::DenseTensor>& outputs, // NOLINT
uint32_t tag)
: ProcessGroupGloo::GlooTask(rank, inputs, CommType::ALLGATHER),
_context(context),
......@@ -320,13 +325,13 @@ class AllgatherGlooTask : public ProcessGroupGloo::GlooTask {
private:
std::shared_ptr<gloo::Context> _context;
std::vector<Tensor> _inputs;
std::vector<Tensor> _outputs;
std::vector<phi::DenseTensor> _inputs;
std::vector<phi::DenseTensor> _outputs;
uint32_t _tag;
void _do_allgather(std::vector<Tensor>& in, // NOLINT
std::vector<Tensor>& out) { // NOLINT
const auto& dtype = in[0].type();
void _do_allgather(std::vector<phi::DenseTensor>& in, // NOLINT
std::vector<phi::DenseTensor>& out) { // NOLINT
const auto& dtype = in[0].dtype();
gloo::AllgatherOptions opts(_context);
GENERATE_FUNC(dtype, set_input, opts, in[0]);
GENERATE_FUNC(dtype, set_output, opts, out[0]);
......@@ -336,7 +341,8 @@ class AllgatherGlooTask : public ProcessGroupGloo::GlooTask {
};
std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::AllGather(
std::vector<Tensor>& in_tensors, std::vector<Tensor>& out_tensors) {
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors) {
std::shared_ptr<AllgatherGlooTask> task;
auto tag = next_tag();
auto context = get_context();
......@@ -349,20 +355,23 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::AllGather(
class ReduceGlooTask : public ProcessGroupGloo::GlooTask {
public:
ReduceGlooTask(int rank, const std::shared_ptr<gloo::Context>& context,
std::vector<Tensor>& in, ReduceOp reduce_op, // NOLINT
int dst, uint32_t tag)
: ProcessGroupGloo::GlooTask(rank, in, CommType::REDUCE),
std::vector<phi::DenseTensor>& inputs, // NOLINT
std::vector<phi::DenseTensor>& outputs, // NOLINT
ReduceOp reduce_op, int dst, uint32_t tag)
: ProcessGroupGloo::GlooTask(rank, inputs, CommType::REDUCE),
_context(context),
_inputs(in),
_inputs(inputs),
_outputs(outputs),
_reduce_op(reduce_op),
_dst(dst),
_tag(tag) {}
void Run() override { _do_reduce(_inputs, _dst); }
void Run() override { _do_reduce(_inputs, _outputs, _dst); }
private:
std::shared_ptr<gloo::Context> _context;
std::vector<Tensor> _inputs;
std::vector<phi::DenseTensor> _inputs;
std::vector<phi::DenseTensor> _outputs;
const ReduceOp _reduce_op;
int _dst;
uint32_t _tag;
......@@ -380,11 +389,13 @@ class ReduceGlooTask : public ProcessGroupGloo::GlooTask {
fn = get_function<T>(op);
}
void _do_reduce(std::vector<Tensor>& tensors, int dst) { // NOLINT
const auto& dtype = tensors[0].type();
void _do_reduce(std::vector<phi::DenseTensor>& inputs, // NOLINT
std::vector<phi::DenseTensor>& outputs, // NOLINT
int dst) {
const auto& dtype = inputs[0].dtype();
gloo::ReduceOptions opts(_context);
GENERATE_FUNC(dtype, set_input, opts, tensors[0]);
GENERATE_FUNC(dtype, set_output, opts, tensors[0]);
GENERATE_FUNC(dtype, set_input, opts, inputs[0]);
GENERATE_FUNC(dtype, set_output, opts, outputs[0]);
opts.setReduceFunction(_get_function(dtype, _reduce_op));
opts.setTag(_tag);
opts.setRoot(dst);
......@@ -393,11 +404,12 @@ class ReduceGlooTask : public ProcessGroupGloo::GlooTask {
};
std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Reduce(
std::vector<Tensor>& tensors, const ReduceOptions& opts) {
std::vector<phi::DenseTensor>& inputs,
std::vector<phi::DenseTensor>& outputs, const ReduceOptions& opts) {
std::shared_ptr<ReduceGlooTask> task;
auto tag = next_tag();
auto context = get_context();
task = std::make_shared<ReduceGlooTask>(rank_, context, tensors,
task = std::make_shared<ReduceGlooTask>(rank_, context, inputs, outputs,
opts.reduce_op, opts.root_rank, tag);
task->Run();
return task;
......@@ -406,8 +418,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Reduce(
class ScatterGlooTask : public ProcessGroupGloo::GlooTask {
public:
ScatterGlooTask(int rank, const std::shared_ptr<gloo::Context>& context,
std::vector<Tensor>& inputs, // NOLINT
std::vector<Tensor>& outputs, // NOLINT
std::vector<phi::DenseTensor>& inputs, // NOLINT
std::vector<phi::DenseTensor>& outputs, // NOLINT
int src, int size, uint32_t tag)
: ProcessGroupGloo::GlooTask(rank, inputs, CommType::SCATTER),
_context(context),
......@@ -421,18 +433,19 @@ class ScatterGlooTask : public ProcessGroupGloo::GlooTask {
private:
std::shared_ptr<gloo::Context> _context;
std::vector<Tensor> _inputs;
std::vector<Tensor> _outputs;
std::vector<phi::DenseTensor> _inputs;
std::vector<phi::DenseTensor> _outputs;
int _src;
int _size;
uint32_t _tag;
void _do_scatter(std::vector<Tensor>& in, std::vector<Tensor>& out, // NOLINT
void _do_scatter(std::vector<phi::DenseTensor>& in, // NOLINT
std::vector<phi::DenseTensor>& out, // NOLINT
int src) {
const auto& dtype = in[0].type();
const auto& dtype = in[0].dtype();
gloo::ScatterOptions opts(_context);
if (rank_ == src) {
GENERATE_FUNC(dtype, set_inputs_for_scatter, opts, in, _size);
GENERATE_FUNC(dtype, set_inputs_for_scatter, opts, in[0], _size);
}
GENERATE_FUNC(dtype, set_output, opts, out[0]);
opts.setRoot(src);
......@@ -442,8 +455,8 @@ class ScatterGlooTask : public ProcessGroupGloo::GlooTask {
};
std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Scatter(
std::vector<Tensor>& in_tensors, std::vector<Tensor>& out_tensors,
const ScatterOptions& opts) {
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors, const ScatterOptions& opts) {
std::shared_ptr<ScatterGlooTask> task;
auto tag = next_tag();
auto context = get_context();
......
......@@ -36,7 +36,8 @@ class ProcessGroupGloo : public ProcessGroup {
class GlooTask : public ProcessGroup::Task,
public std::enable_shared_from_this<GlooTask> {
public:
explicit GlooTask(int rank, const std::vector<Tensor>& input_tensors,
explicit GlooTask(int rank,
const std::vector<phi::DenseTensor>& input_tensors,
CommType comm_type);
~GlooTask() = default;
......@@ -106,26 +107,31 @@ class ProcessGroupGloo : public ProcessGroup {
~ProcessGroupGloo() = default;
std::shared_ptr<ProcessGroup::Task> Broadcast(
std::vector<Tensor>& inputs,
std::vector<phi::DenseTensor>& inputs,
std::vector<phi::DenseTensor>& outputs,
const BroadcastOptions& = BroadcastOptions()) override;
std::shared_ptr<ProcessGroup::Task> AllReduce(
std::vector<Tensor>& inputs,
std::vector<phi::DenseTensor>& inputs,
std::vector<phi::DenseTensor>& outputs,
const AllreduceOptions& opts = AllreduceOptions()) override;
std::shared_ptr<ProcessGroup::Task> Barrier(
const BarrierOptions& = BarrierOptions()) override;
std::shared_ptr<ProcessGroup::Task> AllGather(
std::vector<Tensor>& in_tensors,
std::vector<Tensor>& out_tensors) override;
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors) override;
std::shared_ptr<ProcessGroup::Task> Reduce(
std::vector<Tensor>& tensors, const ReduceOptions& opts) override;
std::shared_ptr<ProcessGroup::Task> Scatter(std::vector<Tensor>& in_tensors,
std::vector<Tensor>& out_tensors,
const ScatterOptions&) override;
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const ReduceOptions& opts) override;
std::shared_ptr<ProcessGroup::Task> Scatter(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const ScatterOptions&) override;
std::shared_ptr<::gloo::Context> get_context() { return _context; }
uint64_t next_tag() { return _tag++; }
......
......@@ -44,14 +44,14 @@ void SyncDefaultStream(
std::shared_ptr<ProcessGroupHCCL::HCCLTask> ProcessGroupHCCL::CreateTask(
std::vector<Place> places, int rank, CommType comm_type,
const std::vector<Tensor>& inputs) {
const std::vector<phi::DenseTensor>& inputs) {
return std::make_shared<ProcessGroupHCCL::HCCLTask>(places, rank, comm_type,
inputs);
}
ProcessGroupHCCL::HCCLTask::HCCLTask(const std::vector<Place>& places, int rank,
CommType CommType,
const std::vector<Tensor>& inputs)
ProcessGroupHCCL::HCCLTask::HCCLTask(
const std::vector<Place>& places, int rank, CommType CommType,
const std::vector<phi::DenseTensor>& inputs)
: Task(rank, inputs, CommType), places_(places) {
control_events_.resize(places.size());
hcclComms_.resize(places.size());
......@@ -60,8 +60,8 @@ ProcessGroupHCCL::HCCLTask::HCCLTask(const std::vector<Place>& places, int rank,
ProcessGroupHCCL::HCCLTask::~HCCLTask() {}
void ProcessGroupHCCL::HCCLTask::SetOutputs(
std::vector<Tensor>& outputs) { // NOLINT
outputs_ = std::make_shared<std::vector<Tensor>>(outputs);
std::vector<phi::DenseTensor>& outputs) { // NOLINT
outputs_ = std::make_shared<std::vector<phi::DenseTensor>>(outputs);
}
void ProcessGroupHCCL::HCCLTask::SynchronizeStreams() {
......@@ -166,8 +166,8 @@ void ProcessGroupHCCL::CreateHCCLManagerCache(
template <typename Fn>
std::shared_ptr<ProcessGroup::Task> ProcessGroupHCCL::Collective(
std::vector<Tensor>& inputs, std::vector<Tensor>& outputs, Fn fn,
CommType op_type) {
std::vector<phi::DenseTensor>& inputs,
std::vector<phi::DenseTensor>& outputs, Fn fn, CommType op_type) {
const auto places = GetPlaceList(inputs);
const auto key = GetKeyFromPlaces(places);
......@@ -208,91 +208,44 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupHCCL::Collective(
return task;
}
template <typename Fn>
std::shared_ptr<ProcessGroup::Task> ProcessGroupHCCL::PointToPoint(
std::vector<Tensor>& tensors, Fn fn, int dst_rank, CommType op_type) {
const auto places = GetPlaceList(tensors);
const auto key = GetKeyFromPlaces(places);
{
std::lock_guard<std::mutex> lock(mutex_);
if (places_to_hcclcomm_.find(key) == places_to_hcclcomm_.end()) {
CreateHCCLManagerCache(key, places);
}
}
auto& hccl_comms = places_to_hcclcomm_[key];
SyncDefaultStream(places, places_to_events_[key], places_to_ctx_[key]);
auto task = CreateTask(places, rank_, op_type, tensors);
// construct uninitialize guard for device
// if (FLAGS_use_stream_safe_npu_allocator) {
// for (size_t i = 0; i < tensors.size(); ++i) {
// platform::NPUDeviceGuard guard(places[i].GetDeviceId());
// auto dense_tensor =
// std::dynamic_pointer_cast<phi::DenseTensor>(tensors[i].impl());
// memory::RecordStream(dense_tensor->Holder(),
// places_to_ctx_[key][i]->stream());
// }
// }
for (size_t i = 0; i < tensors.size(); ++i) {
platform::NPUDeviceGuard guard(places[i].GetDeviceId());
const auto& hccl_stream = places_to_ctx_[key][i]->stream();
fn(tensors[i], hccl_comms[i]->GetHcclComm(), hccl_stream, dst_rank);
}
for (size_t i = 0; i < tensors.size(); ++i) {
platform::NPUDeviceGuard guard(places[i].GetDeviceId());
task->control_events_[i].Record(*places_to_ctx_[key][i]);
}
return task;
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupHCCL::AllReduce(
std::vector<Tensor>& tensors, const AllreduceOptions& opts) {
// PADDLE_ENFORCE_EQ(
// CheckTensorsInNPUPlace(tensors), true,
// platform::errors::InvalidArgument("All inputs should be in
// NPUPlace."));
return Collective(
tensors, tensors,
[&](const Tensor& input, Tensor& output, HcclComm comm,
const aclrtStream& stream) {
auto input_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(input.impl());
auto output_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(output.impl());
return platform::dynload::HcclAllReduce(
input_tensor->data(), output_tensor->data(), input_tensor->numel(),
platform::ToHCCLDataType(input.type()),
ToHCCLRedType(opts.reduce_op), comm, stream);
},
CommType::ALLREDUCE);
std::vector<phi::DenseTensor>& in_tensors, // NOLINT
std::vector<phi::DenseTensor>& out_tensors, // NOLINT
const AllreduceOptions& opts) {
return Collective(in_tensors, out_tensors,
[&](phi::DenseTensor& input, phi::DenseTensor& output,
HcclComm comm, const aclrtStream& stream) {
return platform::dynload::HcclAllReduce(
input.data(), output.data(), input.numel(),
platform::ToHCCLDataType(input.dtype()),
ToHCCLRedType(opts.reduce_op), comm, stream);
},
CommType::ALLREDUCE);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupHCCL::Broadcast(
std::vector<Tensor>& tensors, const BroadcastOptions& opts) {
std::vector<phi::DenseTensor>& in_tensors, // NOLINT
std::vector<phi::DenseTensor>& out_tensors, // NOLINT
const BroadcastOptions& opts) {
// PADDLE_ENFORCE_EQ(
// CheckTensorsInNPUPlace(tensors), true,
// platform::errors::InvalidArgument("All inputs should be in
// CudaPlace."));
return Collective(
tensors, tensors,
[&](Tensor& input, Tensor& output, HcclComm comm,
in_tensors, out_tensors,
[&](phi::DenseTensor& input, phi::DenseTensor& output, HcclComm comm,
const aclrtStream& stream) {
const auto root = opts.source_rank * tensors.size() + opts.source_root;
auto input_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(input.impl());
auto output_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(output.impl());
return platform::dynload::HcclBroadcast(
input_tensor->data(), input_tensor->numel(),
platform::ToHCCLDataType(input.type()), root, comm, stream);
int root = opts.source_rank * in_tensors.size() + opts.source_root;
if (rank_ == root) {
return platform::dynload::HcclBroadcast(
input.data(), input.numel(),
platform::ToHCCLDataType(input.dtype()), root, comm, stream);
} else {
return platform::dynload::HcclBroadcast(
output.data(), output.numel(),
platform::ToHCCLDataType(output.dtype()), root, comm, stream);
}
},
CommType::BROADCAST);
}
......
......@@ -46,7 +46,7 @@ class ProcessGroupHCCL : public ProcessGroup {
public std::enable_shared_from_this<HCCLTask> {
public:
HCCLTask(const std::vector<Place>& places, int rank, CommType CommType,
const std::vector<Tensor>& inputs);
const std::vector<phi::DenseTensor>& inputs);
bool IsCompleted();
......@@ -56,7 +56,7 @@ class ProcessGroupHCCL : public ProcessGroup {
void Synchronize();
void SetOutputs(std::vector<Tensor>& outputs); // NOLINT
void SetOutputs(std::vector<phi::DenseTensor>& outputs); // NOLINT
virtual ~HCCLTask();
......@@ -65,7 +65,7 @@ class ProcessGroupHCCL : public ProcessGroup {
protected:
std::vector<Place> places_;
std::vector<std::shared_ptr<HCCLCommManager>> hcclComms_;
std::shared_ptr<std::vector<Tensor>> outputs_;
std::shared_ptr<std::vector<phi::DenseTensor>> outputs_;
private:
};
......@@ -78,17 +78,19 @@ class ProcessGroupHCCL : public ProcessGroup {
}
std::shared_ptr<ProcessGroup::Task> AllReduce(
std::vector<Tensor>& tensors,
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const AllreduceOptions& = AllreduceOptions()) override;
std::shared_ptr<ProcessGroup::Task> Broadcast(
std::vector<Tensor>& tensors,
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const BroadcastOptions& = BroadcastOptions()) override;
protected:
virtual std::shared_ptr<ProcessGroupHCCL::HCCLTask> CreateTask(
std::vector<Place> places, int rank, CommType opType,
const std::vector<Tensor>& inputs);
const std::vector<phi::DenseTensor>& inputs);
std::shared_ptr<Store> store_;
std::shared_ptr<HCCLCommManager> hccl_comm_;
......@@ -113,15 +115,10 @@ class ProcessGroupHCCL : public ProcessGroup {
template <typename Fn>
std::shared_ptr<ProcessGroup::Task> Collective(
std::vector<Tensor>& inputs, // NOLINT
std::vector<Tensor>& outputs, // NOLINT
std::vector<phi::DenseTensor>& inputs, // NOLINT
std::vector<phi::DenseTensor>& outputs, // NOLINT
Fn fn, CommType op_type);
template <typename Fn>
std::shared_ptr<ProcessGroup::Task> PointToPoint(
std::vector<Tensor>& tensors, // NOLINT
Fn fn, int dst_rank, CommType op_type);
void CreateHCCLManagerCache(const std::string& places_key,
const std::vector<Place>& places);
};
......
......@@ -26,13 +26,13 @@ namespace distributed {
using Place = paddle::platform::Place;
std::shared_ptr<ProcessGroupHeter::HeterTask> ProcessGroupHeter::CreateTask(
int rank, CommType comm_type, const std::vector<Tensor>& inputs) {
int rank, CommType comm_type, const std::vector<phi::DenseTensor>& inputs) {
return std::make_shared<ProcessGroupHeter::HeterTask>(rank, comm_type,
inputs);
}
ProcessGroupHeter::HeterTask::HeterTask(int rank, CommType CommType,
const std::vector<Tensor>& inputs)
ProcessGroupHeter::HeterTask::HeterTask(
int rank, CommType CommType, const std::vector<phi::DenseTensor>& inputs)
: Task(rank, inputs, CommType) {}
ProcessGroupHeter::HeterTask::~HeterTask() {}
......@@ -86,248 +86,177 @@ static void _do_add(T* dst, T* src, size_t size) {
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::AllReduce(
std::vector<Tensor>& tensors, const AllreduceOptions& opts) {
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors, const AllreduceOptions& opts) {
#if defined(PADDLE_WITH_NCCL)
PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(tensors), true,
CheckTensorsInCudaPlace(in_tensors), true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(out_tensors), true,
platform::errors::InvalidArgument("All outputs should be in CudaPlace."));
#endif
// Step1: do allreduce in inner cluster
auto task = inner_pg_->AllReduce(tensors, opts);
auto task = inner_pg_->AllReduce(in_tensors, in_tensors, opts);
task->Wait();
// Step2: copy tensors to CPU
if (local_rank_ == 0) {
std::vector<Tensor> cpu_tensors;
cpu_tensors.reserve(tensors.size());
for (size_t i = 0; i < tensors.size(); i++) {
auto dense_gpu_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(tensors[i].impl());
phi::DenseTensorMeta meta = phi::DenseTensorMeta(
dense_gpu_tensor->dtype(), dense_gpu_tensor->dims());
std::shared_ptr<phi::DenseTensor> dense_cpu_tensor =
std::make_shared<phi::DenseTensor>(
std::make_unique<paddle::experimental::DefaultAllocator>(
paddle::platform::CPUPlace())
.get(),
meta);
dense_cpu_tensor->ResizeAndAllocate(dense_gpu_tensor->dims());
cpu_tensors[i] = paddle::experimental::Tensor(dense_cpu_tensor);
framework::TensorCopySync(*dense_gpu_tensor, platform::CPUPlace(),
dense_cpu_tensor.get());
std::vector<phi::DenseTensor> cpu_tensors;
cpu_tensors.reserve(in_tensors.size());
for (size_t i = 0; i < in_tensors.size(); i++) {
auto gpu_tensor = in_tensors[i];
auto cpu_tensor = cpu_tensors[i];
cpu_tensor.Resize(gpu_tensor.dims());
framework::TensorCopySync(gpu_tensor, platform::CPUPlace(), &cpu_tensor);
}
// Step3: do inter cluster allreduce
if (with_switch_) {
if (local_rank_ == 0) {
HeterClient* client_ =
HeterClient::GetInstance({switch_endpoint_}, {}, 0).get();
auto dense_cpu_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(cpu_tensors[0].impl());
auto dense_cpu_tensor = cpu_tensors[0];
std::vector<int> send_size;
send_size.push_back(dense_cpu_tensor->numel());
send_size.push_back(dense_cpu_tensor.numel());
int ret = client_->Send(
gid_, {dense_cpu_tensor->name()}, send_size,
dense_cpu_tensor->data(),
dense_cpu_tensor->numel() *
framework::DataTypeSize(dense_cpu_tensor->dtype()));
gid_, {dense_cpu_tensor.name()}, send_size, dense_cpu_tensor.data(),
dense_cpu_tensor.numel() *
framework::DataTypeSize(dense_cpu_tensor.dtype()));
PADDLE_ENFORCE_EQ(ret, 0, platform::errors::PreconditionNotMet(
"Send to the switch module error."));
phi::DenseTensorMeta meta = phi::DenseTensorMeta(
dense_cpu_tensor->dtype(), dense_cpu_tensor->dims());
dense_cpu_tensor.dtype(), dense_cpu_tensor.dims());
std::shared_ptr<phi::DenseTensor> dense_cpu_tensor2 =
std::make_shared<phi::DenseTensor>(
std::make_unique<paddle::experimental::DefaultAllocator>(
paddle::platform::CPUPlace())
.get(),
meta);
dense_cpu_tensor2->ResizeAndAllocate(dense_cpu_tensor->dims());
Tensor cpu_tensor_temp =
paddle::experimental::Tensor(dense_cpu_tensor2);
dense_cpu_tensor2->ResizeAndAllocate(dense_cpu_tensor.dims());
ret = client_->Recv(
gid_, {dense_cpu_tensor->name()}, dense_cpu_tensor2->data(),
gid_, {dense_cpu_tensor.name()}, dense_cpu_tensor2->data(),
dense_cpu_tensor2->numel() *
framework::DataTypeSize(dense_cpu_tensor2->dtype()));
PADDLE_ENFORCE_EQ(ret, 0, platform::errors::PreconditionNotMet(
"Recv from the switch module error."));
switch (dense_cpu_tensor->dtype()) {
switch (dense_cpu_tensor.dtype()) {
case DataType::FLOAT32:
_do_add<float>(reinterpret_cast<float*>(dense_cpu_tensor->data()),
_do_add<float>(reinterpret_cast<float*>(dense_cpu_tensor.data()),
reinterpret_cast<float*>(dense_cpu_tensor2->data()),
dense_cpu_tensor->numel());
dense_cpu_tensor.numel());
break;
case DataType::FLOAT64:
_do_add<double>(
reinterpret_cast<double*>(dense_cpu_tensor->data()),
reinterpret_cast<double*>(dense_cpu_tensor.data()),
reinterpret_cast<double*>(dense_cpu_tensor2->data()),
dense_cpu_tensor->numel());
dense_cpu_tensor.numel());
break;
case DataType::INT32:
_do_add<int>(reinterpret_cast<int*>(dense_cpu_tensor->data()),
_do_add<int>(reinterpret_cast<int*>(dense_cpu_tensor.data()),
reinterpret_cast<int*>(dense_cpu_tensor2->data()),
dense_cpu_tensor->numel());
dense_cpu_tensor.numel());
break;
default:
PADDLE_THROW(platform::errors::PreconditionNotMet(
"Unsupported data type (%s) to do add.",
framework::DataType2String(dense_cpu_tensor->dtype())));
framework::DataType2String(dense_cpu_tensor.dtype())));
}
}
} else {
auto gloo_task = inter_pg_->AllReduce(cpu_tensors, opts);
auto gloo_task = inter_pg_->AllReduce(cpu_tensors, cpu_tensors, opts);
gloo_task->Wait();
}
// Step4: copy cpu tensors to gpu
// copy cpu tensors to gpu
for (size_t i = 0; i < tensors.size(); i++) {
auto dense_gpu_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(tensors[i].impl());
auto dense_cpu_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(cpu_tensors[i].impl());
framework::TensorCopySync(*dense_cpu_tensor, dense_cpu_tensor->place(),
dense_gpu_tensor.get());
for (size_t i = 0; i < in_tensors.size(); i++) {
auto gpu_tensor = out_tensors[i];
auto cpu_tensor = cpu_tensors[i];
framework::TensorCopySync(cpu_tensor, cpu_tensor.place(), &gpu_tensor);
}
}
// Step5: broadcast among inner cluster
auto b_opts = BroadcastOptions();
b_opts.source_root = 0;
auto broadcast_task = inner_pg_->Broadcast(tensors, b_opts);
b_opts.source_rank = 0;
auto broadcast_task = inner_pg_->Broadcast(out_tensors, out_tensors, b_opts);
broadcast_task->Wait();
return CreateTask(rank_, CommType::ALLREDUCE, tensors);
return CreateTask(rank_, CommType::ALLREDUCE, in_tensors);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::Broadcast(
std::vector<Tensor>& tensors, const BroadcastOptions& opts) {
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors, const BroadcastOptions& opts) {
#if defined(PADDLE_WITH_NCCL)
PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(tensors), true,
CheckTensorsInCudaPlace(in_tensors), true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(out_tensors), true,
platform::errors::InvalidArgument("All outputs should be in CudaPlace."));
#endif
// Step1: do broadcast in inner cluster
auto b_opts = BroadcastOptions();
b_opts.source_root = 0;
inner_pg_->Broadcast(tensors, b_opts);
b_opts.source_rank = 0;
inner_pg_->Broadcast(in_tensors, out_tensors, b_opts);
if (local_rank_ == 0) {
std::vector<Tensor> cpu_tensors;
cpu_tensors.reserve(tensors.size());
for (size_t i = 0; i < tensors.size(); i++) {
auto dense_gpu_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(tensors[i].impl());
phi::DenseTensorMeta meta = phi::DenseTensorMeta(
dense_gpu_tensor->dtype(), dense_gpu_tensor->dims());
std::shared_ptr<phi::DenseTensor> dense_cpu_tensor =
std::make_shared<phi::DenseTensor>(
std::make_unique<paddle::experimental::DefaultAllocator>(
paddle::platform::CPUPlace())
.get(),
meta);
dense_cpu_tensor->ResizeAndAllocate(dense_gpu_tensor->dims());
cpu_tensors[i] = paddle::experimental::Tensor(dense_cpu_tensor);
framework::TensorCopySync(*dense_gpu_tensor, platform::CPUPlace(),
dense_cpu_tensor.get());
std::vector<phi::DenseTensor> cpu_tensors;
cpu_tensors.reserve(in_tensors.size());
for (size_t i = 0; i < in_tensors.size(); i++) {
auto gpu_tensor = in_tensors[i];
auto cpu_tensor = cpu_tensors[i];
cpu_tensor.Resize(gpu_tensor.dims());
framework::TensorCopySync(gpu_tensor, platform::CPUPlace(), &cpu_tensor);
}
if (with_switch_) {
if (local_rank_ == 0) {
HeterClient* client_ =
HeterClient::GetInstance({switch_endpoint_}, {}, 0).get();
auto dense_cpu_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(cpu_tensors[0].impl());
auto dense_cpu_tensor = cpu_tensors[0];
if (gloo_rank_ == 0) {
std::vector<int> send_size;
send_size.push_back(dense_cpu_tensor->numel());
send_size.push_back(dense_cpu_tensor.numel());
int ret = client_->Send(
gid_, {dense_cpu_tensor->name()}, send_size,
dense_cpu_tensor->data(),
dense_cpu_tensor->numel() *
framework::DataTypeSize(dense_cpu_tensor->dtype()));
gid_, {dense_cpu_tensor.name()}, send_size,
dense_cpu_tensor.data(),
dense_cpu_tensor.numel() *
framework::DataTypeSize(dense_cpu_tensor.dtype()));
PADDLE_ENFORCE_EQ(ret, 0, platform::errors::PreconditionNotMet(
"Send to the switch module error."));
} else {
int ret = client_->Recv(
gid_, {dense_cpu_tensor->name()}, dense_cpu_tensor->data(),
dense_cpu_tensor->numel() *
framework::DataTypeSize(dense_cpu_tensor->dtype()));
gid_, {dense_cpu_tensor.name()}, dense_cpu_tensor.data(),
dense_cpu_tensor.numel() *
framework::DataTypeSize(dense_cpu_tensor.dtype()));
PADDLE_ENFORCE_EQ(ret, 0,
platform::errors::PreconditionNotMet(
"Receive from the switch module error."));
ret = client_->Recv(
gid_, {dense_cpu_tensor->name()}, dense_cpu_tensor->data(),
dense_cpu_tensor->numel() *
framework::DataTypeSize(dense_cpu_tensor->dtype()));
gid_, {dense_cpu_tensor.name()}, dense_cpu_tensor.data(),
dense_cpu_tensor.numel() *
framework::DataTypeSize(dense_cpu_tensor.dtype()));
PADDLE_ENFORCE_EQ(ret, 0,
platform::errors::PreconditionNotMet(
"Receive from the switch module error."));
}
}
} else {
auto gloo_task = inter_pg_->Broadcast(cpu_tensors, opts);
auto gloo_task = inter_pg_->Broadcast(cpu_tensors, cpu_tensors, opts);
gloo_task->Wait();
}
for (size_t i = 0; i < tensors.size(); i++) {
auto dense_gpu_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(tensors[i].impl());
auto dense_cpu_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(cpu_tensors[i].impl());
framework::TensorCopySync(*dense_cpu_tensor, dense_cpu_tensor->place(),
dense_gpu_tensor.get());
for (size_t i = 0; i < in_tensors.size(); i++) {
auto gpu_tensor = out_tensors[i];
auto cpu_tensor = cpu_tensors[i];
framework::TensorCopySync(cpu_tensor, gpu_tensor.place(), &gpu_tensor);
}
}
auto broadcast_task = inner_pg_->Broadcast(tensors, b_opts);
auto broadcast_task = inner_pg_->Broadcast(out_tensors, out_tensors, b_opts);
broadcast_task->Wait();
return CreateTask(rank_, CommType::BROADCAST, tensors);
}
void ProcessGroupHeter::Broadcast(const phi::DenseTensor* in,
phi::DenseTensor* out) {
// Step1: do broadcast in inner cluster
inner_pg_->Broadcast(in, out);
if (local_rank_ == 0) {
phi::DenseTensorMeta meta = phi::DenseTensorMeta(in->dtype(), in->dims());
std::shared_ptr<phi::DenseTensor> dense_cpu_tensor =
std::make_shared<phi::DenseTensor>(
std::make_unique<paddle::experimental::DefaultAllocator>(
paddle::platform::CPUPlace())
.get(),
meta);
dense_cpu_tensor->ResizeAndAllocate(in->dims());
Tensor cpu_tensor = paddle::experimental::Tensor(dense_cpu_tensor);
framework::TensorCopySync(*in, platform::CPUPlace(),
dense_cpu_tensor.get());
if (with_switch_) {
if (local_rank_ == 0) {
HeterClient* client_ =
HeterClient::GetInstance({switch_endpoint_}, {}, 0).get();
if (gloo_rank_ == 0) {
std::vector<int> send_size;
send_size.push_back(in->numel());
int ret = client_->Send(
gid_, {in->name()}, send_size, dense_cpu_tensor->data(),
in->numel() * framework::DataTypeSize(in->dtype()));
PADDLE_ENFORCE_EQ(ret, 0, platform::errors::PreconditionNotMet(
"Send to the switch module error."));
} else {
int ret =
client_->Recv(gid_, {in->name()}, dense_cpu_tensor->data(),
in->numel() * framework::DataTypeSize(in->dtype()));
PADDLE_ENFORCE_EQ(ret, 0,
platform::errors::PreconditionNotMet(
"Receive from the switch module error."));
}
}
} else {
std::vector<Tensor> cpu_tensors = {cpu_tensor};
auto gloo_task = inter_pg_->Broadcast(cpu_tensors);
gloo_task->Wait();
}
framework::TensorCopySync(*dense_cpu_tensor, out->place(), out);
}
inner_pg_->Broadcast(out, out);
return CreateTask(rank_, CommType::BROADCAST, in_tensors);
}
} // namespace distributed
} // namespace paddle
} // namespace distributed
} // namespace paddle
......@@ -66,7 +66,8 @@ class ProcessGroupHeter : public ProcessGroup {
class HeterTask : public ProcessGroup::Task,
public std::enable_shared_from_this<HeterTask> {
public:
HeterTask(int rank, CommType CommType, const std::vector<Tensor>& inputs);
HeterTask(int rank, CommType CommType,
const std::vector<phi::DenseTensor>&);
bool IsCompleted();
......@@ -89,18 +90,16 @@ class ProcessGroupHeter : public ProcessGroup {
}
std::shared_ptr<ProcessGroup::Task> AllReduce(
std::vector<Tensor>& tensors,
std::vector<phi::DenseTensor>&, std::vector<phi::DenseTensor>&,
const AllreduceOptions& = AllreduceOptions()) override;
std::shared_ptr<ProcessGroup::Task> Broadcast(
std::vector<Tensor>& tensors,
std::vector<phi::DenseTensor>&, std::vector<phi::DenseTensor>&,
const BroadcastOptions& = BroadcastOptions()) override;
void Broadcast(const phi::DenseTensor* in, phi::DenseTensor* out) override;
protected:
virtual std::shared_ptr<ProcessGroupHeter::HeterTask> CreateTask(
int rank, CommType opType, const std::vector<Tensor>& inputs);
int rank, CommType opType, const std::vector<phi::DenseTensor>& inputs);
private:
std::shared_ptr<Store> store_;
......
......@@ -41,14 +41,14 @@ void SyncDefaultStream(
std::shared_ptr<ProcessGroupNCCL::NCCLTask> ProcessGroupNCCL::CreateTask(
std::vector<Place> places, int rank, CommType comm_type,
const std::vector<Tensor>& inputs) {
const std::vector<phi::DenseTensor>& inputs) {
return std::make_shared<ProcessGroupNCCL::NCCLTask>(places, rank, comm_type,
inputs);
}
ProcessGroupNCCL::NCCLTask::NCCLTask(const std::vector<Place>& places, int rank,
CommType CommType,
const std::vector<Tensor>& inputs)
ProcessGroupNCCL::NCCLTask::NCCLTask(
const std::vector<Place>& places, int rank, CommType CommType,
const std::vector<phi::DenseTensor>& inputs)
: Task(rank, inputs, CommType), places_(places) {
control_events_.resize(places.size());
ncclComms_.resize(places.size());
......@@ -57,8 +57,8 @@ ProcessGroupNCCL::NCCLTask::NCCLTask(const std::vector<Place>& places, int rank,
ProcessGroupNCCL::NCCLTask::~NCCLTask() {}
void ProcessGroupNCCL::NCCLTask::SetOutputs(
std::vector<Tensor>& outputs) { // NOLINT
outputs_ = std::make_shared<std::vector<Tensor>>(outputs);
std::vector<phi::DenseTensor>& outputs) { // NOLINT
outputs_ = std::make_shared<std::vector<phi::DenseTensor>>(outputs);
}
void ProcessGroupNCCL::NCCLTask::SynchronizeStreams() {
......@@ -180,8 +180,8 @@ void ProcessGroupNCCL::CreateNCCLManagerCache(
template <typename Fn>
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
std::vector<Tensor>& inputs, std::vector<Tensor>& outputs, Fn fn,
CommType op_type) {
std::vector<phi::DenseTensor>& inputs,
std::vector<phi::DenseTensor>& outputs, Fn fn, CommType op_type) {
const auto places = GetPlaceList(inputs);
const auto key = GetKeyFromPlaces(places);
......@@ -205,9 +205,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
if (FLAGS_use_stream_safe_cuda_allocator) {
for (size_t i = 0; i < inputs.size(); ++i) {
cuda_guard.SetDevice(places[i]);
auto dense_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(inputs[i].impl());
memory::RecordStream(dense_tensor->Holder(),
memory::RecordStream(inputs[i].Holder(),
places_to_ctx_[key][i]->stream());
}
}
......@@ -267,7 +265,8 @@ void ProcessGroupNCCL::Collective(const phi::DenseTensor* in,
template <typename Fn>
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
std::vector<Tensor>& tensors, Fn fn, int dst_rank, CommType op_type) {
std::vector<phi::DenseTensor>& tensors, Fn fn, int dst_rank,
CommType op_type) {
const auto places = GetPlaceList(tensors);
const auto key = GetKeyFromPlaces(places);
......@@ -290,9 +289,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
if (FLAGS_use_stream_safe_cuda_allocator) {
for (size_t i = 0; i < tensors.size(); ++i) {
cuda_guard.SetDevice(places[i]);
auto dense_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(tensors[i].impl());
memory::RecordStream(dense_tensor->Holder(),
memory::RecordStream(tensors[i].Holder(),
places_to_ctx_[key][i]->stream());
}
}
......@@ -314,46 +311,40 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllReduce(
std::vector<Tensor>& tensors, const AllreduceOptions& opts) {
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors, const AllreduceOptions& opts) {
PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(tensors), true,
CheckTensorsInCudaPlace(in_tensors), true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
return Collective(
tensors, tensors,
[&](const Tensor& input, Tensor& output, ncclComm_t comm,
const gpuStream_t& stream) {
auto input_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(input.impl());
auto output_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(output.impl());
return platform::dynload::ncclAllReduce(
input_tensor->data(), output_tensor->data(), input_tensor->numel(),
platform::ToNCCLDataType(input.type()),
ToNCCLRedType(opts.reduce_op), comm, stream);
},
CommType::ALLREDUCE);
return Collective(in_tensors, out_tensors,
[&](const phi::DenseTensor& input, phi::DenseTensor& output,
ncclComm_t comm, const gpuStream_t& stream) {
return platform::dynload::ncclAllReduce(
input.data(), output.data(), input.numel(),
platform::ToNCCLDataType(input.type()),
ToNCCLRedType(opts.reduce_op), comm, stream);
},
CommType::ALLREDUCE);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Broadcast(
std::vector<Tensor>& tensors, const BroadcastOptions& opts) {
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors, const BroadcastOptions& opts) {
PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(tensors), true,
CheckTensorsInCudaPlace(in_tensors), true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
return Collective(
tensors, tensors,
[&](Tensor& input, Tensor& output, ncclComm_t comm,
const gpuStream_t& stream) {
const auto root = opts.source_rank * tensors.size() + opts.source_root;
auto input_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(input.impl());
auto output_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(output.impl());
return platform::dynload::ncclBcast(
input_tensor->data(), input_tensor->numel(),
platform::ToNCCLDataType(input.type()), root, comm, stream);
},
CommType::BROADCAST);
return Collective(in_tensors, out_tensors,
[&](phi::DenseTensor& input, phi::DenseTensor& output,
ncclComm_t comm, const gpuStream_t& stream) {
const auto root = opts.source_rank * in_tensors.size() +
opts.source_root;
return platform::dynload::ncclBroadcast(
input.data(), output.data(), input.numel(),
platform::ToNCCLDataType(input.type()), root, comm,
stream);
},
CommType::BROADCAST);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Barrier(
......@@ -374,23 +365,24 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Barrier(
places.emplace_back(place_id);
}
std::vector<Tensor> barrierTensors;
std::vector<phi::DenseTensor> barrierTensors;
barrierTensors.reserve(places.size());
platform::CUDADeviceGuard gpuGuard;
for (auto& place : places) {
gpuGuard.SetDeviceIndex(place.GetDeviceId());
auto dt = full({1}, 0, phi::DataType::FLOAT32, phi::GPUPlace());
barrierTensors.push_back(dt);
barrierTensors.push_back(
*std::dynamic_pointer_cast<phi::DenseTensor>(dt.impl()));
}
auto task = ProcessGroupNCCL::AllReduce(barrierTensors);
auto task = ProcessGroupNCCL::AllReduce(barrierTensors, barrierTensors);
auto nccl_task = dynamic_cast<ProcessGroupNCCL::NCCLTask*>(task.get());
nccl_task->barrierTensors_ = std::move(barrierTensors);
return task;
}
void CheckTensorsInDifferentDevices(const std::vector<Tensor>& tensors,
const size_t num_devices) {
void CheckTensorsInDifferentDevices(
const std::vector<phi::DenseTensor>& tensors, const size_t num_devices) {
PADDLE_ENFORCE_EQ(
tensors.size() == 0, false,
platform::errors::InvalidArgument("Tensor list must be nonempty."));
......@@ -402,11 +394,11 @@ void CheckTensorsInDifferentDevices(const std::vector<Tensor>& tensors,
std::set<Place> used_devices;
for (const auto& t : tensors) {
PADDLE_ENFORCE_EQ(t.is_gpu() && t.is_dense_tensor(), true,
PADDLE_ENFORCE_EQ(platform::is_gpu_place(t.place()), true,
platform::errors::InvalidArgument(
"Tensors must be CUDA and dense tensor."));
const auto inserted = used_devices.insert(t.inner_place()).second;
const auto inserted = used_devices.insert(t.place()).second;
PADDLE_ENFORCE_EQ(inserted, true,
platform::errors::InvalidArgument(
"Tensors must be on distinct GPU devices."));
......@@ -414,62 +406,55 @@ void CheckTensorsInDifferentDevices(const std::vector<Tensor>& tensors,
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Send(
std::vector<Tensor>& tensors, int dst_rank) {
std::vector<phi::DenseTensor>& tensors, int dst_rank) {
CheckTensorsInDifferentDevices(tensors, static_cast<size_t>(GetSize()));
auto task = PointToPoint(
tensors,
[&](Tensor& input, ncclComm_t comm, const gpuStream_t& stream,
int dst_rank) {
auto input_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(input.impl());
return platform::dynload::ncclSend(
input_tensor->data(), input_tensor->numel(),
platform::ToNCCLDataType(input.type()), dst_rank, comm, stream);
},
dst_rank, CommType::SEND);
auto task = PointToPoint(tensors,
[&](phi::DenseTensor& input, ncclComm_t comm,
const gpuStream_t& stream, int dst_rank) {
return platform::dynload::ncclSend(
input.data(), input.numel(),
platform::ToNCCLDataType(input.dtype()),
dst_rank, comm, stream);
},
dst_rank, CommType::SEND);
return task;
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv(
std::vector<Tensor>& tensors, int src_rank) {
std::vector<phi::DenseTensor>& tensors, int src_rank) {
CheckTensorsInDifferentDevices(tensors, static_cast<size_t>(GetSize()));
auto task = PointToPoint(
tensors,
[&](Tensor& output, ncclComm_t comm, const gpuStream_t& stream,
int src_rank) {
auto output_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(output.impl());
return platform::dynload::ncclRecv(
output_tensor->data(), output_tensor->numel(),
platform::ToNCCLDataType(output.type()), src_rank, comm, stream);
},
src_rank, CommType::RECV);
auto task = PointToPoint(tensors,
[&](phi::DenseTensor& output, ncclComm_t comm,
const gpuStream_t& stream, int src_rank) {
return platform::dynload::ncclRecv(
output.data(), output.numel(),
platform::ToNCCLDataType(output.dtype()),
src_rank, comm, stream);
},
src_rank, CommType::RECV);
return task;
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather(
std::vector<Tensor>& in_tensors, std::vector<Tensor>& out_tensors) {
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors) {
PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(in_tensors), true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(out_tensors), true,
platform::errors::InvalidArgument("All outputs should be in CudaPlace."));
return Collective(
in_tensors, out_tensors,
[&](const Tensor& input, Tensor& output, ncclComm_t comm,
const gpuStream_t& stream) {
auto input_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(input.impl());
auto output_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(output.impl());
return platform::dynload::ncclAllGather(
input_tensor->data(), output_tensor->data(), input_tensor->numel(),
platform::ToNCCLDataType(input.type()), comm, stream);
},
CommType::ALLGATHER);
return Collective(in_tensors, out_tensors,
[&](const phi::DenseTensor& input, phi::DenseTensor& output,
ncclComm_t comm, const gpuStream_t& stream) {
return platform::dynload::ncclAllGather(
input.data(), output.data(), input.numel(),
platform::ToNCCLDataType(input.dtype()), comm,
stream);
},
CommType::ALLGATHER);
}
void* GetPointerByOffset(void* raw_pointer, size_t offset,
......@@ -493,10 +478,12 @@ void* GetPointerByOffset(void* raw_pointer, size_t offset,
PADDLE_THROW(platform::errors::Unimplemented(
"This datatype in nccl is not supported."));
}
return nullptr;
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
std::vector<Tensor>& in_tensors, std::vector<Tensor>& out_tensors) {
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors) {
PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(in_tensors), true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
......@@ -505,24 +492,20 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
return Collective(
in_tensors, out_tensors,
[&](const Tensor& input, Tensor& output, ncclComm_t comm,
[&](phi::DenseTensor& input, phi::DenseTensor& output, ncclComm_t comm,
const gpuStream_t& stream) {
auto input_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(input.impl());
auto output_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(output.impl());
size_t offset = 0;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart());
for (auto i = 0; i < size_; i++) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend(
GetPointerByOffset(input_tensor->data(), offset, input.type()),
input_tensor->numel() / size_,
platform::ToNCCLDataType(input.type()), i, comm, stream));
GetPointerByOffset(input.data(), offset, input.dtype()),
input.numel() / size_, platform::ToNCCLDataType(input.dtype()), i,
comm, stream));
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
GetPointerByOffset(output_tensor->data(), offset, input.type()),
input_tensor->numel() / size_,
platform::ToNCCLDataType(input.type()), i, comm, stream));
offset += input_tensor->numel() / size_;
GetPointerByOffset(output.data(), offset, input.dtype()),
input.numel() / size_, platform::ToNCCLDataType(input.dtype()), i,
comm, stream));
offset += input.numel() / size_;
}
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
},
......@@ -530,29 +513,26 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
}
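The AllToAll lambda exchanges an equal slice of `numel / size_` elements with every peer inside one ncclGroupStart/ncclGroupEnd, so the running offset sweeps the buffer exactly once (even splitting is assumed). A minimal standalone check of that bookkeeping, with world size and element count chosen only for the example:
  // Standalone illustration of the offset arithmetic used above.
  #include <cassert>
  #include <cstddef>
  #include <cstdint>

  int main() {
    const int size_ = 4;      // assumed number of ranks
    const int64_t numel = 8;  // assumed per-rank element count (divisible by size_)
    size_t offset = 0;
    for (int i = 0; i < size_; ++i) {
      // chunk [offset, offset + numel / size_) is exchanged with peer i
      offset += numel / size_;
    }
    assert(offset == static_cast<size_t>(numel));  // buffer covered exactly once
    return 0;
  }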
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Reduce(
std::vector<Tensor>& tensors, const ReduceOptions& opts) {
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors, const ReduceOptions& opts) {
PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(tensors), true,
CheckTensorsInCudaPlace(in_tensors), true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
return Collective(
tensors, tensors,
[&](const Tensor& input, Tensor& output, ncclComm_t comm,
const gpuStream_t& stream) {
auto input_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(input.impl());
auto output_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(output.impl());
in_tensors, out_tensors,
[&](const phi::DenseTensor& input, phi::DenseTensor& output,
ncclComm_t comm, const gpuStream_t& stream) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclReduce(
input_tensor->data(), output_tensor->data(), input.numel(),
platform::ToNCCLDataType(input.type()),
input.data(), output.data(), input.numel(),
platform::ToNCCLDataType(input.dtype()),
ToNCCLRedType(opts.reduce_op), opts.root_rank, comm, stream));
},
CommType::REDUCE);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
std::vector<Tensor>& in_tensors, std::vector<Tensor>& out_tensors,
const ScatterOptions& opts) {
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors, const ScatterOptions& opts) {
PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(in_tensors), true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
......@@ -561,31 +541,27 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
return Collective(
in_tensors, out_tensors,
[&](const Tensor& input, Tensor& output, ncclComm_t comm,
[&](phi::DenseTensor& input, phi::DenseTensor& output, ncclComm_t comm,
const gpuStream_t& stream) {
auto input_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(input.impl());
auto output_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(output.impl());
size_t offset = 0;
if (rank_ == opts.root_rank) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart());
for (auto i = 0; i < size_; i++) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend(
GetPointerByOffset(input_tensor->data(), offset, input.type()),
input_tensor->numel() / size_,
platform::ToNCCLDataType(input.type()), i, comm, stream));
offset += input_tensor->numel() / size_;
GetPointerByOffset(input.data(), offset, input.dtype()),
input.numel() / size_, platform::ToNCCLDataType(input.dtype()),
i, comm, stream));
offset += input.numel() / size_;
}
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
output_tensor->data(), input_tensor->numel() / size_,
platform::ToNCCLDataType(input.type()), opts.root_rank, comm,
output.data(), input.numel() / size_,
platform::ToNCCLDataType(input.dtype()), opts.root_rank, comm,
stream));
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
} else {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
output_tensor->data(), input_tensor->numel() / size_,
platform::ToNCCLDataType(input.type()), opts.root_rank, comm,
output.data(), input.numel() / size_,
platform::ToNCCLDataType(input.dtype()), opts.root_rank, comm,
stream));
}
},
......
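The net effect of the .cc changes above is one calling convention: every collective takes separate input and output std::vector<phi::DenseTensor>, and in-place use simply passes the same vector twice, as the reducer and Python bindings later in this diff do. A hedged caller-side sketch, where `pg` and `dense` are assumed to exist in the surrounding code:
  // In-place all-reduce under the new DenseTensor interface (sketch).
  distributed::AllreduceOptions opts;
  opts.reduce_op = distributed::ReduceOp::SUM;
  std::vector<phi::DenseTensor> in_out = {dense};
  pg->AllReduce(in_out, in_out, opts)->Synchronize();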
......@@ -51,7 +51,7 @@ class ProcessGroupNCCL : public ProcessGroup {
public std::enable_shared_from_this<NCCLTask> {
public:
NCCLTask(const std::vector<Place>& places, int rank, CommType CommType,
const std::vector<Tensor>& inputs);
const std::vector<phi::DenseTensor>& inputs);
bool IsCompleted();
......@@ -61,17 +61,17 @@ class ProcessGroupNCCL : public ProcessGroup {
void Synchronize();
void SetOutputs(std::vector<Tensor>& outputs); // NOLINT
void SetOutputs(std::vector<phi::DenseTensor>& outputs); // NOLINT
virtual ~NCCLTask();
std::vector<EventManager> control_events_;
std::vector<Tensor> barrierTensors_;
std::vector<phi::DenseTensor> barrierTensors_;
protected:
std::vector<Place> places_;
std::vector<std::shared_ptr<NCCLCommManager>> ncclComms_;
std::shared_ptr<std::vector<Tensor>> outputs_;
std::shared_ptr<std::vector<phi::DenseTensor>> outputs_;
private:
};
......@@ -84,40 +84,46 @@ class ProcessGroupNCCL : public ProcessGroup {
}
std::shared_ptr<ProcessGroup::Task> AllReduce(
std::vector<Tensor>& tensors,
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const AllreduceOptions& = AllreduceOptions()) override;
std::shared_ptr<ProcessGroup::Task> Broadcast(
std::vector<Tensor>& tensors,
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const BroadcastOptions& = BroadcastOptions()) override;
std::shared_ptr<ProcessGroup::Task> Barrier(
const BarrierOptions& = BarrierOptions()) override;
std::shared_ptr<ProcessGroup::Task> Send(std::vector<Tensor>& tensors,
int dst_rank) override;
std::shared_ptr<ProcessGroup::Task> Send(
std::vector<phi::DenseTensor>& tensors, int dst_rank) override;
std::shared_ptr<ProcessGroup::Task> Recv(std::vector<Tensor>& tensors,
int src_rank) override;
std::shared_ptr<ProcessGroup::Task> Recv(
std::vector<phi::DenseTensor>& tensors, int src_rank) override;
std::shared_ptr<ProcessGroup::Task> AllGather(
std::vector<Tensor>& in_tensors,
std::vector<Tensor>& out_tensors) override;
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors) override;
std::shared_ptr<ProcessGroup::Task> AllToAll(
std::vector<Tensor>& in, std::vector<Tensor>& out) override;
std::vector<phi::DenseTensor>& in,
std::vector<phi::DenseTensor>& out) override;
std::shared_ptr<ProcessGroup::Task> Reduce(
std::vector<Tensor>& tensors, const ReduceOptions& opts) override;
std::vector<phi::DenseTensor>& tensors,
std::vector<phi::DenseTensor>& out_tensors,
const ReduceOptions& opts) override;
std::shared_ptr<ProcessGroup::Task> Scatter(std::vector<Tensor>& in_tensors,
std::vector<Tensor>& out_tensors,
const ScatterOptions&) override;
std::shared_ptr<ProcessGroup::Task> Scatter(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const ScatterOptions&) override;
protected:
virtual std::shared_ptr<ProcessGroupNCCL::NCCLTask> CreateTask(
std::vector<Place> places, int rank, CommType opType,
const std::vector<Tensor>& inputs);
const std::vector<phi::DenseTensor>& inputs);
protected:
std::shared_ptr<Store> store_;
......@@ -142,8 +148,8 @@ class ProcessGroupNCCL : public ProcessGroup {
template <typename Fn>
std::shared_ptr<ProcessGroup::Task> Collective(
std::vector<Tensor>& inputs, // NOLINT
std::vector<Tensor>& outputs, // NOLINT
std::vector<phi::DenseTensor>& inputs, // NOLINT
std::vector<phi::DenseTensor>& outputs, // NOLINT
Fn fn, CommType op_type);
template <typename Fn>
......@@ -152,7 +158,7 @@ class ProcessGroupNCCL : public ProcessGroup {
template <typename Fn>
std::shared_ptr<ProcessGroup::Task> PointToPoint(
std::vector<Tensor>& tensors, // NOLINT
std::vector<phi::DenseTensor>& tensors, // NOLINT
Fn fn, int dst_rank, CommType op_type);
void CreateNCCLManagerCache(const std::string& places_key,
......
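Every collective in ProcessGroupNCCL funnels through the Collective template declared above: callers hand it a lambda over one (input, output) DenseTensor pair plus the NCCL comm and stream. As a shape reference only — this is not a primitive added by the patch, and `opts` is assumed to be the enclosing AllreduceOptions — an all-reduce body routed through it would look like:
  // Illustrative Collective() body under the new DenseTensor interface;
  // mirrors the AllGather lambda shown earlier in this diff.
  return Collective(
      in_tensors, out_tensors,
      [&](const phi::DenseTensor& input, phi::DenseTensor& output,
          ncclComm_t comm, const gpuStream_t& stream) {
        return platform::dynload::ncclAllReduce(
            input.data(), output.data(), input.numel(),
            platform::ToNCCLDataType(input.dtype()),
            ToNCCLRedType(opts.reduce_op), comm, stream);
      },
      CommType::ALLREDUCE);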
......@@ -734,7 +734,11 @@ void EagerReducer::ProcessUnusedDenseVars() {
distributed::AllreduceOptions opts;
opts.reduce_op = ReduceOp::SUM;
std::vector<Tensor> reduce_tensors = {global_used_vars_};
process_group_->AllReduce(reduce_tensors, opts)->Synchronize();
std::vector<phi::DenseTensor> in_out;
for (auto &t : reduce_tensors) {
in_out.push_back(*std::dynamic_pointer_cast<phi::DenseTensor>(t.impl()));
}
process_group_->AllReduce(in_out, in_out, opts)->Synchronize();
framework::TensorToVector<int>(*global_used_tensor, *dev_ctx,
&local_used_vars_);
......@@ -820,7 +824,11 @@ void EagerReducer::FusedAllReduceSchedule(EagerGroup *group,
// all_reduce
std::vector<Tensor> reduce_tensors = {group->dense_contents_};
group->task = process_group_->AllReduce(reduce_tensors, opts);
std::vector<phi::DenseTensor> in_out;
for (auto &t : reduce_tensors) {
in_out.push_back(*std::dynamic_pointer_cast<phi::DenseTensor>(t.impl()));
}
group->task = process_group_->AllReduce(in_out, in_out, opts);
// split in FinalizeBackward()
}
......@@ -871,7 +879,11 @@ void EagerReducer::AllReduceSparse(EagerGroup *group,
distributed::AllreduceOptions opts;
opts.reduce_op = ReduceOp::SUM;
std::vector<Tensor> reduce_tensors = {rows_num_tensor};
process_group_->AllReduce(reduce_tensors, opts)->Synchronize();
std::vector<phi::DenseTensor> in_out;
for (auto &t : reduce_tensors) {
in_out.push_back(*std::dynamic_pointer_cast<phi::DenseTensor>(t.impl()));
}
process_group_->AllReduce(in_out, in_out, opts)->Synchronize();
framework::TensorToVector<int64_t>(*rows_num_dense_tensor, *dev_ctx,
&rows_num_vector);
......@@ -908,8 +920,15 @@ void EagerReducer::AllReduceSparse(EagerGroup *group,
std::vector<Tensor> src_rows_tensors = {src_rows_tensor};
std::vector<Tensor> dst_rows_tensors = {dst_rows_tensor};
process_group_->AllGather(src_rows_tensors, dst_rows_tensors)
->Synchronize();
std::vector<phi::DenseTensor> in;
std::vector<phi::DenseTensor> out;
for (auto &t : src_rows_tensors) {
in.push_back(*std::dynamic_pointer_cast<phi::DenseTensor>(t.impl()));
}
for (auto &t : dst_rows_tensors) {
out.push_back(*std::dynamic_pointer_cast<phi::DenseTensor>(t.impl()));
}
process_group_->AllGather(in, out)->Synchronize();
framework::Vector<int64_t> dst_rows_vector(rows_num, 0);
auto *dst_rows_dense_tensor =
......@@ -934,8 +953,17 @@ void EagerReducer::AllReduceSparse(EagerGroup *group,
std::vector<Tensor> src_value_tensors = {src_value_tensor};
std::vector<Tensor> dst_value_tensors = {dst_value_tensor};
process_group_->AllGather(src_value_tensors, dst_value_tensors)
->Synchronize();
std::vector<phi::DenseTensor> src_dense;
std::vector<phi::DenseTensor> dst_dense;
for (auto &t : src_value_tensors) {
src_dense.push_back(
*std::dynamic_pointer_cast<phi::DenseTensor>(t.impl()));
}
for (auto &t : dst_value_tensors) {
dst_dense.push_back(
*std::dynamic_pointer_cast<phi::DenseTensor>(t.impl()));
}
process_group_->AllGather(src_dense, dst_dense)->Synchronize();
src->set_rows(dst_rows_vector);
*(src->mutable_value()) =
......
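The reducer hunks above repeat the same unwrap loop before every collective call; a small helper — hypothetical, not part of this patch — captures the pattern once:
  // Hypothetical convenience helper: collect the phi::DenseTensor payloads
  // behind a list of eager Tensors.
  static std::vector<phi::DenseTensor> ToDenseTensors(
      const std::vector<paddle::experimental::Tensor>& tensors) {
    std::vector<phi::DenseTensor> dense;
    dense.reserve(tensors.size());
    for (auto& t : tensors) {
      dense.push_back(*std::dynamic_pointer_cast<phi::DenseTensor>(t.impl()));
    }
    return dense;
  }
  // Usage sketch:
  //   auto in_out = ToDenseTensors(reduce_tensors);
  //   process_group_->AllReduce(in_out, in_out, opts)->Synchronize();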
......@@ -18,7 +18,9 @@ limitations under the License. */
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#endif
#include "paddle/fluid/distributed/collective/ProcessGroup.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/phi/api/include/tensor.h"
namespace paddle {
namespace operators {
......@@ -35,6 +37,18 @@ class CAllGatherOpCUDAKernel : public framework::OpKernel<T> {
int nranks = ctx.Attr<int>("nranks");
int rid = ctx.Attr<int>("ring_id");
auto map = distributed::ProcessGroupMapFromGid::getInstance();
if (map->has(rid)) {
// Use ProcessGroup
distributed::ProcessGroup* pg = map->get(rid);
std::vector<phi::DenseTensor> in_tensor;
std::vector<phi::DenseTensor> out_tensor;
in_tensor.push_back(*in);
out_tensor.push_back(*out);
auto task = pg->AllGather(in_tensor, out_tensor);
task->Wait();
return;
}
auto place = ctx.GetPlace();
auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
PADDLE_ENFORCE_EQ(
......
......@@ -41,7 +41,12 @@ class CBroadcastOpCUDAKernel : public framework::OpKernel<T> {
if (map->has(rid)) {
// Use ProcessGroup
distributed::ProcessGroup* pg = map->get(rid);
pg->Broadcast(x, out);
std::vector<phi::DenseTensor> in_tensor;
std::vector<phi::DenseTensor> out_tensor;
in_tensor.push_back(*x);
out_tensor.push_back(*out);
auto task = pg->Broadcast(in_tensor, out_tensor);
task->Wait();
return;
}
......
......@@ -115,8 +115,10 @@ void BindDistributed(py::module *m) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
distributed::AllreduceOptions opts;
opts.reduce_op = op;
std::vector<Tensor> tensors = {tensor};
return self.AllReduce(tensors, opts);
auto dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
std::vector<phi::DenseTensor> tensors = {*dense};
return self.AllReduce(tensors, tensors, opts);
},
py::arg("tensor"), py::arg("op") = distributed::ReduceOp::SUM,
py::call_guard<py::gil_scoped_release>())
......@@ -127,8 +129,10 @@ void BindDistributed(py::module *m) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
distributed::BroadcastOptions opts;
opts.source_rank = source_rank;
std::vector<Tensor> tensors = {tensor};
return self.Broadcast(tensors, opts);
auto dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
std::vector<phi::DenseTensor> tensors = {*dense};
return self.Broadcast(tensors, tensors, opts);
},
py::arg("tensor"), py::arg("source_rank"),
py::call_guard<py::gil_scoped_release>())
......@@ -146,7 +150,9 @@ void BindDistributed(py::module *m) {
[](distributed::ProcessGroup &self, py::handle py_tensor,
int dst) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
std::vector<Tensor> tensors = {tensor};
auto dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
std::vector<phi::DenseTensor> tensors = {*dense};
return self.Send(tensors, dst);
},
py::arg("tensor"), py::arg("dst"),
......@@ -156,7 +162,9 @@ void BindDistributed(py::module *m) {
[](distributed::ProcessGroup &self, py::handle py_tensor,
int src) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
std::vector<Tensor> tensors = {tensor};
auto dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
std::vector<phi::DenseTensor> tensors = {*dense};
return self.Recv(tensors, src);
},
py::arg("tensor"), py::arg("src"),
......@@ -167,8 +175,12 @@ void BindDistributed(py::module *m) {
py::handle py_out_tensor) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
std::vector<Tensor> in_tensors = {in_tensor};
std::vector<Tensor> out_tensors = {out_tensor};
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
out_tensor.impl());
std::vector<phi::DenseTensor> in_tensors = {*in_dense};
std::vector<phi::DenseTensor> out_tensors = {*out_dense};
return self.AllGather(in_tensors, out_tensors);
},
py::arg("in"), py::arg("out"),
......@@ -179,8 +191,12 @@ void BindDistributed(py::module *m) {
py::handle py_out_tensor) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
std::vector<Tensor> in_tensors = {in_tensor};
std::vector<Tensor> out_tensors = {out_tensor};
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
out_tensor.impl());
std::vector<phi::DenseTensor> in_tensors = {*in_dense};
std::vector<phi::DenseTensor> out_tensors = {*out_dense};
return self.AllToAll(in_tensors, out_tensors);
},
py::arg("in"), py::arg("out"),
......@@ -193,8 +209,10 @@ void BindDistributed(py::module *m) {
distributed::ReduceOptions opts;
opts.reduce_op = op;
opts.root_rank = dst;
std::vector<Tensor> tensors = {in_tensor};
return self.Reduce(tensors, opts);
auto dense = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
std::vector<phi::DenseTensor> tensors = {*dense};
return self.Reduce(tensors, tensors, opts);
},
py::arg("tensor"), py::arg("dst"),
py::arg("op") = distributed::ReduceOp::SUM,
......@@ -207,8 +225,12 @@ void BindDistributed(py::module *m) {
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
distributed::ScatterOptions opts;
opts.root_rank = src;
std::vector<Tensor> in_tensors = {in_tensor};
std::vector<Tensor> out_tensors = {out_tensor};
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
out_tensor.impl());
std::vector<phi::DenseTensor> in_tensors = {*in_dense};
std::vector<phi::DenseTensor> out_tensors = {*out_dense};
return self.Scatter(in_tensors, out_tensors, opts);
},
py::arg("in"), py::arg("out"), py::arg("src"),
......
......@@ -46,6 +46,11 @@ class TestProcessGroupFp32(unittest.TestCase):
group = paddle.distributed.collective.Group(-1, 2, 0, [-1, -2])
ret = paddle.distributed.barrier(group)
assert ret == None
paddle.enable_static()
in_tensor = paddle.empty((1, 2))
in_tensor2 = paddle.empty((1, 2))
paddle.distributed.broadcast(in_tensor, src=0)
paddle.distributed.all_gather([in_tensor, in_tensor2], in_tensor)
print("test ok\n")
......