Unverified commit 07741593, authored by kuizhiqing, committed by GitHub

new group (#31682)

* new group

* ci compatible fix

* assert nccl
Parent: dbeb3ea4
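As a rough orientation (not part of the commit itself), the docstrings further down suggest the new API surface is used along these lines; the two-process launch command, tensor shape, and ranks are illustrative assumptions:

# Sketch of the API added by this commit, pieced together from the docstrings below.
# Assumes a two-process launch, e.g.: python -m paddle.distributed.launch --gpus=0,1 demo.py
import numpy as np
import paddle

paddle.distributed.init_parallel_env()

# Create a communication group from a list of global ranks.
gp = paddle.distributed.new_group([0, 1])

tindata = paddle.to_tensor(np.random.random([10, 1000]).astype('float32'))

# Collectives now accept a Group instance and a use_calc_stream flag.
paddle.distributed.all_reduce(tindata, group=gp, use_calc_stream=False)

# wait() synchronizes either the calculation or the communication stream.
paddle.distributed.wait(tindata, group=gp, use_calc_stream=False)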
@@ -19,12 +19,11 @@
#include <utility>
#include <vector>

#include "paddle/fluid/platform/bkcl_helper.h"
#include "paddle/fluid/platform/collective_helper.h"

#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/gen_comm_id_helper.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/split.h"
#include "paddle/fluid/string/string_helper.h"
@@ -77,7 +76,7 @@ void BKCLParallelContext::Init() {
  bkcl_ids.resize(strategy_.nrings_);
  if (strategy_.local_rank_ == 0) {
    // generate the unique bkclid on the root worker
    for (size_t i = 0; i < bkcl_ids.size(); ++i) {
      auto ret = bkcl_get_unique_id(&bkcl_ids[i]);
      PADDLE_ENFORCE_EQ(BKCL_SUCCESS, ret,
@@ -99,6 +98,28 @@ void BKCLParallelContext::Init() {
  }
}
void BKCLParallelContext::InitWithRingID(int ring_id) {
std::vector<BKCLUniqueId> bkcl_ids;
bkcl_ids.resize(1);
if (strategy_.local_rank_ == 0) {
// generate the unique bkclid on the root worker
auto ret = bkcl_get_unique_id(&bkcl_ids[0]);
PADDLE_ENFORCE_EQ(BKCL_SUCCESS, ret,
platform::errors::PreconditionNotMet(
"BKCL get unique id failed [%d]", ret));
}
BcastBKCLId(bkcl_ids, 0);
int xpu_id = BOOST_GET_CONST(platform::XPUPlace, place_).device;
VLOG(0) << "init BKCL context nranks: " << strategy_.nranks_
<< " local rank: " << strategy_.local_rank_ << " xpu id: " << xpu_id
<< " ring id: " << ring_id;
// it will assign bkcl_comm in XPUDeviceContext within ring_id
platform::BKCLCommContext::Instance().CreateBKCLComm(
&bkcl_ids[0], strategy_.nranks_, strategy_.local_rank_, xpu_id, ring_id);
}
void BKCLParallelContext::AllReduceByStream(const framework::Variable &src,
                                            framework::Variable *dst,
                                            int ring_id, bool use_calc_stream) {
...
@@ -36,6 +36,8 @@ class BKCLParallelContext : public ParallelContext {
  void Init() override;

  void InitWithRingID(int ring_id) override;

  void AllReduceByStream(const framework::Variable& src,
                         framework::Variable* dst, int ring_id,
                         bool use_calc_stream) override;
...
@@ -79,6 +79,30 @@ void NCCLParallelContext::Init() {
  }
}
void NCCLParallelContext::InitWithRingID(int ring_id) {
std::vector<ncclUniqueId> nccl_ids;
nccl_ids.resize(1);
if (strategy_.local_rank_ == 0) {
// generate the unique ncclid on the root worker
platform::dynload::ncclGetUniqueId(&nccl_ids[0]);
}
BcastNCCLId(nccl_ids, 0);
int gpu_id = BOOST_GET_CONST(platform::CUDAPlace, place_).device;
VLOG(0) << "init nccl context nranks: " << strategy_.nranks_
<< " local rank: " << strategy_.local_rank_ << " gpu id: " << gpu_id
<< " ring id: " << ring_id;
// it will assign nccl_comm in CUDADeviceContext within ring_id
platform::NCCLCommContext::Instance().CreateNCCLComm(
&nccl_ids[0], strategy_.nranks_, strategy_.local_rank_, gpu_id, ring_id);
compute_events_.emplace_back(platform::CudaEventResourcePool::Instance().New(
BOOST_GET_CONST(platform::CUDAPlace, place_).device));
comm_events_.emplace_back(platform::CudaEventResourcePool::Instance().New(
BOOST_GET_CONST(platform::CUDAPlace, place_).device));
}
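For reference, a minimal sketch of how the Python layer added later in this diff is expected to drive InitWithRingID through the new pybind binding; the group size and endpoint strings are illustrative assumptions, not values from the commit:

# Sketch (mirrors the new_group() code further down in this diff).
import paddle.fluid.core as core

strategy = core.ParallelStrategy()
strategy.nranks = 2                      # size of the new group (assumed)
strategy.local_rank = 0                  # this process's rank inside the group
strategy.trainer_endpoints = ["127.0.0.1:6170", "127.0.0.1:6171"]  # assumed endpoints
strategy.current_endpoint = "127.0.0.1:6170"
strategy.nrings = 1

place = core.CUDAPlace(0)
# Creates the NCCL communicator for this one ring only, unlike Init(),
# which sets up all strategy_.nrings_ rings.
core.NCCLParallelContext(strategy, place).init_with_ring_id(10)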
void NCCLParallelContext::AllReduceByStream(const framework::Variable &src,
                                            framework::Variable *dst,
                                            int ring_id, bool use_calc_stream) {
...
@@ -53,6 +53,8 @@ class NCCLParallelContext : public ParallelContext {
  void Init() override;

  void InitWithRingID(int ring_id) override;

  void AllReduceByStream(const framework::Variable& src,
                         framework::Variable* dst, int ring_id,
                         bool use_calc_stream) override;
...
@@ -50,6 +50,8 @@ class ParallelContext {
  virtual void Init() = 0;

  virtual void InitWithRingID(int ring_id) = 0;

  virtual void AllReduceByStream(const framework::Variable& src,
                                 framework::Variable* dst, int ring_id,
                                 bool use_calc_stream) = 0;
...
@@ -15,40 +15,20 @@ limitations under the License. */

#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

class CSyncCalcStreamOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {}

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(framework::proto::VarType::FP32,
                                   ctx.GetPlace());
  }
};
@@ -65,10 +45,36 @@ Call calculation stream synchronization.
  }
};
template <typename T>
class CSyncCalcStreamCudaKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
#if (defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)) && !defined(_WIN32)
auto place = ctx.GetPlace();
auto dev_ctx = static_cast<platform::CUDADeviceContext*>(
platform::DeviceContextPool::Instance().Get(place));
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamSynchronize(dev_ctx->stream()));
#else
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(dev_ctx->stream()));
#endif
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with GPU."));
#endif
}
};
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OP_WITHOUT_GRADIENT(c_sync_calc_stream, ops::CSyncCalcStreamOp,
                             ops::CSyncCalcStreamOpMaker);
REGISTER_OP_CUDA_KERNEL(c_sync_calc_stream,
ops::CSyncCalcStreamCudaKernel<float>);
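Presumably the point of converting these sync ops from OperatorBase to OperatorWithKernel (here and for c_sync_comm_stream below) is that they can then be dispatched through the dygraph core.ops path; a minimal sketch of how the Python helpers later in this diff invoke them:

# Sketch, mirroring _sync_calc_stream / _sync_comm_stream further down in this diff.
import paddle
import paddle.fluid.core as core

t = paddle.ones([2, 2])
core.ops.c_sync_calc_stream(t, t)                    # sync the calculation stream
core.ops.c_sync_comm_stream([t], [t], 'ring_id', 0)  # sync the communication stream of ring 0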
@@ -14,45 +14,25 @@ limitations under the License. */

#include <string>

#include "paddle/fluid/framework/op_registry.h"

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/nccl_helper.h"
#endif

namespace paddle {
namespace operators {

class CSyncCommStreamOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {}

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(framework::proto::VarType::FP32,
                                   ctx.GetPlace());
  }
};
@@ -72,10 +52,38 @@ Call communication stream synchronization.
  }
};
template <typename T>
class CSyncCommStreamCudaKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
auto place = ctx.GetPlace();
int ring_id = ctx.Attr<int>("ring_id");
auto stream =
platform::NCCLCommContext::Instance().Get(ring_id, place)->stream();
#ifdef PADDLE_WITH_RCCL
PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamSynchronize(stream));
#else
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));
#endif
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with GPU."));
#endif
}
};
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OP_WITHOUT_GRADIENT(c_sync_comm_stream, ops::CSyncCommStreamOp,
                             ops::CSyncCommStreamOpMaker);
REGISTER_OP_CUDA_KERNEL(c_sync_comm_stream,
ops::CSyncCommStreamCudaKernel<float>);
@@ -1578,7 +1578,10 @@ void BindImperative(py::module *m_ptr) {
          m, "NCCLParallelContext")
      .def(py::init<const imperative::ParallelStrategy &,
                    const platform::CUDAPlace &>())
      .def("init", [](imperative::NCCLParallelContext &self) { self.Init(); })
      .def("init_with_ring_id",
           &imperative::NCCLParallelContext::InitWithRingID,
           py::arg("ring_id"));
#endif
#if defined(PADDLE_WITH_XPU_BKCL)
@@ -1587,7 +1590,10 @@ void BindImperative(py::module *m_ptr) {
          m, "BKCLParallelContext")
      .def(py::init<const imperative::ParallelStrategy &,
                    const platform::XPUPlace &>())
      .def("init", [](imperative::BKCLParallelContext &self) { self.Init(); })
      .def("init_with_ring_id",
           &imperative::BKCLParallelContext::InitWithRingID,
           py::arg("ring_id"));
#endif
}
...
@@ -119,6 +119,8 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
    {"fill_constant", {"Out"}},
    {"matmul", {"Out"}},
    {"c_broadcast", {"Out"}},
    {"c_sync_calc_stream", {"Out"}},
    {"c_sync_comm_stream", {"Out"}},
    {"c_allreduce_sum", {"Out"}},
    {"c_allreduce_max", {"Out"}},
    {"c_allreduce_min", {"Out"}},
...
@@ -26,6 +26,9 @@ import paddle.fluid as fluid
import paddle.fluid.core as core

__all__ = [
    'wait',
    'new_group',
    'get_group',
    'broadcast',
    'all_reduce',
    'reduce',
@@ -75,30 +78,225 @@ class ReduceOp:
    PROD = 3
class Group():
    """
    The abstract representation of group.
    """

    def __init__(self, rank, rank_num, id=0, ranks=[]):
        self.rank = rank
        self.nranks = rank_num
        self.id = id
        self.ranks = ranks
def is_member(self):
if self.rank < 0:
return False
if self.nranks < 2:
return False
return True
def get_group_rank(self, rank):
if self.id == 0:
return rank
if self.is_member() and rank in self.ranks:
return self.ranks.index(rank)
else:
return -1
_global_env = None
def _get_global_env():
global _global_env
if not _global_env:
_global_env = paddle.distributed.ParallelEnv()
return _global_env
# group map: the map of all groups, 0 for the global group
# Dict[int, Group]
_group_map = {}
def _get_group_map():
global _group_map
if not _group_map:
genv = _get_global_env()
_group_map[0] = Group(genv.rank, genv.world_size, 0)
return _group_map
def _get_global_group():
return _get_group_map()[0]
def _new_ring_id():
return len(_get_group_map()) + max(_get_global_env().nrings, 9)
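A quick worked example of the ring id scheme above, assuming only the global group is registered and nrings is 1 (both assumptions, matching the defaults used elsewhere in this file):

# len(_group_map) == 1 (just the global group 0), genv.nrings == 1,
# so the first user-created group gets ring id 1 + max(1, 9) == 10;
# ids below that stay reserved for the rings of the global group.
first_new_ring_id = 1 + max(1, 9)   # == 10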
def get_group(id=0):
"""
Get group instance by group id.
Args:
id (int): the group id
Returns:
Group: the group instance.
Examples:
.. code-block:: python
...
gid = paddle.distributed.new_group([2,4,6])
paddle.distributed.get_group(gid.id)
"""
gm = _get_group_map()
    return gm[id] if id in gm else None
def new_group(ranks=None, backend=None):
"""
    Creates a new distributed communication group.

    Args:
        ranks (list): The global ranks of the group members; the list will be sorted.
        backend (str): The backend used to create the group; only nccl is supported for now.

    Returns:
        Group: The group instance. Never returns None.
Examples:
.. code-block:: python
import numpy as np
import paddle
paddle.distributed.init_parallel_env()
tindata = np.random.random([10, 1000]).astype('float32')
tindata = paddle.to_tensor(tindata)
gid = paddle.distributed.new_group([2,4,6])
paddle.distributed.all_reduce(tindata, group=gid, use_calc_stream=False)
"""
if not backend:
backend = 'nccl'
assert backend == 'nccl', ("backend other than nccl is not supported yet")
genv = _get_global_env()
global_rank = genv.rank
ring_id = _new_ring_id()
global _group_map
if global_rank not in ranks:
gp = Group(-1, -1, ring_id, ranks)
_group_map[ring_id] = gp
return gp
ranks = sorted(ranks)
group_rank = ranks.index(global_rank)
group_size = len(ranks)
gp = Group(group_rank, group_size, ring_id, ranks)
_group_map[ring_id] = gp
if group_size < 2:
return gp
strategy = core.ParallelStrategy()
strategy.nranks = group_size
strategy.local_rank = group_rank
strategy.trainer_endpoints = [genv.trainer_endpoints[i] for i in ranks]
strategy.current_endpoint = genv.current_endpoint
strategy.nrings = 1
if core.is_compiled_with_cuda():
place = core.CUDAPlace(genv.device_id)
core.NCCLParallelContext(strategy, place).init_with_ring_id(ring_id)
else:
assert False
return gp
def wait(tensor, group=None, use_calc_stream=True):
"""
    Wait to synchronize the stream for the group.
Args:
tensor (Tensor): The Tensor used before sync.
group (Group): The Group instance to perform sync.
        use_calc_stream (bool): Whether to use the calculation stream (True) or the communication stream (False).
            Default to True.
Returns:
None.
Examples:
.. code-block:: python
import numpy as np
import paddle
paddle.distributed.init_parallel_env()
tindata = np.random.random([10, 1000]).astype('float32')
tindata = paddle.to_tensor(tindata)
paddle.distributed.all_reduce(tindata, use_calc_stream=True)
paddle.distributed.wait(tindata)
"""
if group is not None and not group.is_member():
return
ring_id = 0 if group is None else group.id
if use_calc_stream:
_sync_calc_stream(tensor)
else:
_sync_comm_stream(tensor, ring_id)
def _sync_calc_stream(tensor):
if in_dygraph_mode():
return core.ops.c_sync_calc_stream(tensor, tensor)
op_type = 'c_sync_calc_stream'
helper = LayerHelper(op_type, **locals())
helper.append_op(
type=op_type,
inputs={'X': [tensor]},
outputs={'Out': [tensor]}, )

def _sync_comm_stream(tensor, ring_id=0):
    if in_dygraph_mode():
        return core.ops.c_sync_comm_stream([tensor], [tensor], 'ring_id',
                                           ring_id)

    op_type = 'c_sync_comm_stream'
    helper = LayerHelper(op_type, **locals())
    helper.append_op(
        type=op_type,
        inputs={'X': [tensor]},
        outputs={'Out': [tensor]},
        attrs={'ring_id': ring_id}, )
def broadcast(tensor, src, group=None, use_calc_stream=True):
    """
    Broadcast a tensor from the source to all others.
@@ -107,7 +305,9 @@ def broadcast(tensor, src, group=0):
        tensor (Tensor): The Tensor to send if current rank is the source, or the tensor to receive otherwise. Its data type
            should be float16, float32, float64, int32 or int64.
        src (int): The source rank.
        group (Group): The group instance returned by new_group, or None for the global default group.
        use_calc_stream (bool): Whether to use the calculation stream (True) or the communication stream (False).
            Default to True.

    Returns:
        None.
@@ -130,17 +330,25 @@ def broadcast(tensor, src, group=0):
        out = data.numpy()
        # [[1, 2, 3], [1, 2, 3]]
    """

    if group is not None and not group.is_member():
        return

    if not isinstance(src, int):
        raise ValueError("src should be int.")

    ring_id = 0 if group is None else group.id
    gsrc = src if group is None else group.get_group_rank(src)

    if in_dygraph_mode():
        return core.ops.c_broadcast(tensor, tensor, 'root', gsrc,
                                    'use_calc_stream', use_calc_stream,
                                    'ring_id', ring_id)
    op_type = 'c_broadcast'
    check_variable_and_dtype(
        tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
        'broadcast')

    helper = LayerHelper(op_type, **locals())
    helper.append_op(
@@ -148,13 +356,13 @@ def broadcast(tensor, src, group=0):
        inputs={'X': [tensor]},
        outputs={'Out': [tensor]},
        attrs={
            'root': gsrc,
            'use_calc_stream': use_calc_stream,
            'ring_id': ring_id,
        })

def all_reduce(tensor, op=ReduceOp.SUM, group=None, use_calc_stream=True):
    """
    Reduce a tensor over all ranks so that all get the result.
@@ -163,7 +371,9 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=0):
        tensor (Tensor): The input Tensor. It also works as the output Tensor. Its data type
            should be float16, float32, float64, int32 or int64.
        op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD): Optional. The operation used.
        group (Group): The group instance returned by new_group, or None for the global default group.
        use_calc_stream (bool): Whether to use the calculation stream (True) or the communication stream (False).
            Default to True.

    Returns:
        None.
@@ -187,19 +397,25 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=0):
        out = data.numpy()
        # [[5, 7, 9], [5, 7, 9]]
    """

    if group is not None and not group.is_member():
        return

    ring_id = 0 if group is None else group.id

    if in_dygraph_mode():
        if op == ReduceOp.SUM:
            return core.ops.c_allreduce_sum(tensor, tensor, 'use_calc_stream',
                                            use_calc_stream, 'ring_id', ring_id)
        elif op == ReduceOp.MAX:
            return core.ops.c_allreduce_max(tensor, tensor, 'use_calc_stream',
                                            use_calc_stream, 'ring_id', ring_id)
        elif op == ReduceOp.MIN:
            return core.ops.c_allreduce_min(tensor, tensor, 'use_calc_stream',
                                            use_calc_stream, 'ring_id', ring_id)
        elif op == ReduceOp.PROD:
            return core.ops.c_allreduce_prod(tensor, tensor, 'use_calc_stream',
                                             use_calc_stream, 'ring_id',
                                             ring_id)
        else:
            raise ValueError("Unknown parameter: {}.".format(op))
@@ -217,18 +433,18 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=0):
        op_type = 'c_allreduce_min'
    elif op == ReduceOp.PROD:
        op_type = 'c_allreduce_prod'
    if not isinstance(ring_id, int):
        raise ValueError("The type of 'ring_id' for all_reduce should be int.")
    helper = LayerHelper(op_type, **locals())
    helper.append_op(
        type=op_type,
        inputs={'X': [tensor]},
        outputs={'Out': [tensor]},
        attrs={'ring_id': ring_id,
               'use_calc_stream': use_calc_stream})

def reduce(tensor, dst, op=ReduceOp.SUM, group=None, use_calc_stream=True):
    """
    Reduce a tensor to the destination from all others.
@@ -238,7 +454,9 @@ def reduce(tensor, dst, op=ReduceOp.SUM, group=0):
            should be float16, float32, float64, int32 or int64.
        dst (int): The destination rank id.
        op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD): Optional. The operation used.
        group (Group): The group instance returned by new_group, or None for the global default group.
        use_calc_stream (bool): Whether to use the calculation stream (True) or the communication stream (False).
            Default to True.

    Returns:
        None.
@@ -261,20 +479,32 @@ def reduce(tensor, dst, op=ReduceOp.SUM, group=0):
        out = data.numpy()
        # [[5, 7, 9], [5, 7, 9]]
    """

    if group is not None and not group.is_member():
        return

    if not isinstance(dst, int):
        raise ValueError("dst should be int.")

    ring_id = 0 if group is None else group.id
    gdst = dst if group is None else group.get_group_rank(dst)

    if in_dygraph_mode():
        if op == ReduceOp.SUM:
            return core.ops.c_reduce_sum(tensor, tensor, 'use_calc_stream',
                                         use_calc_stream, 'ring_id', ring_id,
                                         'root_id', gdst)
        elif op == ReduceOp.MAX:
            return core.ops.c_reduce_max(tensor, tensor, 'use_calc_stream',
                                         use_calc_stream, 'ring_id', ring_id,
                                         'root_id', gdst)
        elif op == ReduceOp.MIN:
            return core.ops.c_reduce_min(tensor, tensor, 'use_calc_stream',
                                         use_calc_stream, 'ring_id', ring_id,
                                         'root_id', gdst)
        elif op == ReduceOp.PROD:
            return core.ops.c_reduce_prod(tensor, tensor, 'use_calc_stream',
                                          use_calc_stream, 'ring_id', ring_id,
                                          'root_id', gdst)
        else:
            raise ValueError("Unknown parameter: {}.".format(op))
@@ -295,22 +525,19 @@ def reduce(tensor, dst, op=ReduceOp.SUM, group=0):
    elif op == ReduceOp.PROD:
        op_type = 'c_reduce_prod'

    helper = LayerHelper(op_type, **locals())
    helper.append_op(
        type=op_type,
        inputs={'X': [tensor]},
        outputs={'Out': [tensor]},
        attrs={
            'ring_id': ring_id,
            'use_calc_stream': use_calc_stream,
            'root_id': gdst,
        })

def all_gather(tensor_list, tensor, group=None, use_calc_stream=True):
    """
    Gather tensors from all participators and all get the result.
@@ -320,7 +547,9 @@ def all_gather(tensor_list, tensor, group=0):
            should be float16, float32, float64, int32 or int64.
        tensor (Tensor): The Tensor to send. Its data type
            should be float16, float32, float64, int32 or int64.
        group (Group): The group instance returned by new_group, or None for the global default group.
        use_calc_stream (bool): Whether to use the calculation stream (True) or the communication stream (False).
            Default to True.

    Returns:
        None.
@@ -348,13 +577,19 @@ def all_gather(tensor_list, tensor, group=0):
        data2 = paddle.to_tensor(np_data2)
        paddle.distributed.all_gather(tensor_list, data2)
    """

    if group is not None and not group.is_member():
        return

    ring_id = 0 if group is None else group.id
    nranks = _get_global_group().nranks if group is None else group.nranks

    op_type = 'c_allgather'
    helper = LayerHelper(op_type, **locals())
    out = helper.create_variable_for_type_inference(dtype=tensor.dtype)
    if in_dygraph_mode():
        core.ops.c_allgather(tensor, out, 'use_calc_stream', use_calc_stream,
                             'ring_id', ring_id, 'nranks', nranks)
    else:
        if not isinstance(tensor_list, list):
            raise ValueError("The type of 'tensor_list' for all_gather "
@@ -367,23 +602,20 @@ def all_gather(tensor_list, tensor, group=0):
        check_variable_and_dtype(
            tensor, 'tensor',
            ['float16', 'float32', 'float64', 'int32', 'int64'], 'all_gather')
        helper.append_op(
            type=op_type,
            inputs={'X': [tensor]},
            outputs={'Out': [out]},
            attrs={
                'ring_id': ring_id,
                'use_calc_stream': use_calc_stream,
                'nranks': nranks
            })

    tensor_list.extend(paddle.split(out, nranks, 0))

def scatter(tensor, tensor_list=None, src=0, group=None, use_calc_stream=True):
    """
    Scatter a tensor to all participators.
@@ -394,7 +626,9 @@ def scatter(tensor, tensor_list=None, src=0, group=0):
        tensor_list (list): A list of Tensors to scatter. Every element in the list must be a Tensor whose data type
            should be float16, float32, float64, int32 or int64.
        src (int): The source rank id.
        group (Group): The group instance returned by new_group, or None for the global default group.
        use_calc_stream (bool): Whether to use the calculation stream (True) or the communication stream (False).
            Default to True.

    Returns:
        None.
@@ -422,45 +656,51 @@ def scatter(tensor, tensor_list=None, src=0, group=0):
        paddle.distributed.scatter(data1, tensor_list=[data1, data2], src=1)
        out = data1.numpy()
    """

    if group is not None and not group.is_member():
        return

    if not isinstance(src, int):
        raise ValueError("src should be int.")

    ring_id = 0 if group is None else group.id
    gsrc = src if group is None else group.get_group_rank(src)
    rank = _get_global_group().rank if group is None else group.rank
    nranks = _get_global_group().nranks if group is None else group.nranks

    op_type = 'c_scatter'

    if rank != gsrc:
        tensor_list = []
        for _ in range(nranks):
            tensor_list.append(tensor)
    temp = paddle.concat(tensor_list, axis=0)
    if in_dygraph_mode():
        return core.ops.c_scatter(temp, tensor, 'use_calc_stream',
                                  use_calc_stream, 'ring_id', ring_id, 'nranks',
                                  nranks, 'root', gsrc)
    check_variable_and_dtype(
        tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
        'scatter')
    helper = LayerHelper(op_type, **locals())
    helper.append_op(
        type=op_type,
        inputs={'X': [temp]},
        outputs={'Out': [tensor]},
        attrs={
            'ring_id': ring_id,
            'root': gsrc,
            'use_calc_stream': use_calc_stream,
            'nranks': nranks,
        })

def barrier(group=None):
    """
    Barrier among all participators in the group.

    Args:
        group (Group): The group instance returned by new_group, or None for the global default group.

    Returns:
        None.
@@ -475,18 +715,23 @@ def barrier(group=0):
        init_parallel_env()
        paddle.distributed.barrier()
    """

    if group is not None and not group.is_member():
        return

    ring_id = 0 if group is None else group.id

    op_type = 'barrier'
    temp = fill_constant([1], dtype="int32", value="1")
    if in_dygraph_mode():
        return core.ops.barrier(temp, temp, 'ring_id', ring_id)
    if not isinstance(ring_id, int):
        raise ValueError("The type of 'ring_id' for barrier must be int.")
    helper = LayerHelper(op_type, **locals())
    helper.append_op(
        type=op_type,
        inputs={'X': [temp]},
        outputs={'Out': [temp]},
        attrs={'ring_id': ring_id})

def _parallel_linear(x, num_rows, num_cols, axis, param_attr, bias_attr,
@@ -515,10 +760,10 @@ def _parallel_linear(x, num_rows, num_cols, axis, param_attr, bias_attr,
    if gather_out:
        if axis == 0:
            paddle.distributed.all_reduce(linear_out)
        else:
            output = []
            paddle.distributed.all_gather(output, linear_out)
            linear_out = paddle.concat(output, axis=len(linear_out.shape) - 1)
    return linear_out
@@ -559,7 +804,7 @@ def _parallel_embedding(x, per_part_embeddings, origin_size, param_attr,
    main_block = paddle.static.default_main_program().global_block()
    startup_block.vars[embedding.weight.name].is_distributed = True
    main_block.vars[embedding.weight.name].is_distributed = True
    paddle.distributed.all_reduce(emb_out, group=None)
    return emb_out
@@ -584,7 +829,7 @@ def split(x,
    With parallel embedding, the weight is split into num_partitions partitions, each
    of which is a matrix with (N/num_partitions + 1) rows and M columns, where the last
    row serves as the padding idx.

    Suppose we split the NxM weight into two partitions on device_0 and device_1
    respectively. Then, on each device, the final weight has (N/2 + 1) rows with the
    index range from 0 to N/2. On device_0, all values in the input within [0, N/2 -1]
...
@@ -82,6 +82,7 @@ if(((NOT WITH_ROCM) AND (NOT WITH_GPU)) OR WIN32)
    LIST(REMOVE_ITEM TEST_OPS test_collective_scatter_api)
    LIST(REMOVE_ITEM TEST_OPS test_collective_barrier_api)
    LIST(REMOVE_ITEM TEST_OPS test_collective_allreduce_api)
    LIST(REMOVE_ITEM TEST_OPS test_new_group_api)
    LIST(REMOVE_ITEM TEST_OPS test_collective_broadcast_api)
    LIST(REMOVE_ITEM TEST_OPS test_collective_allgather_api)
    LIST(REMOVE_ITEM TEST_OPS test_collective_wait)
@@ -177,6 +178,7 @@ endif()

if ((NOT WITH_NCCL) AND (NOT WITH_RCCL))
    list(REMOVE_ITEM TEST_OPS test_imperative_group)
    LIST(REMOVE_ITEM TEST_OPS test_new_group_api)
endif()

if(((NOT WITH_ROCM) AND (NOT WITH_GPU)) OR WIN32)
@@ -518,6 +520,7 @@ if(WITH_DISTRIBUTE)
    if(WITH_GPU OR WITH_ROCM)
        bash_test_modules(test_c_comm_init_op START_BASH test_c_comm_init_op.sh ENVS PADDLE_BINARY_DIR=${PADDLE_BINARY_DIR})
        py_test_modules(test_launch_coverage MODULES test_launch_coverage)
        bash_test_modules(test_new_group START_BASH test_new_group.sh ENVS PADDLE_BINARY_DIR=${PADDLE_BINARY_DIR})
    endif()
    bash_test_modules(test_fleetrun START_BASH test_fleetrun.sh ENVS PADDLE_BINARY_DIR=${PADDLE_BINARY_DIR})
@@ -831,6 +834,7 @@ if((WITH_ROCM OR WITH_GPU) AND NOT WIN32)
    set_tests_properties(test_collective_allgather_api PROPERTIES TIMEOUT 120)
    set_tests_properties(test_collective_broadcast_api PROPERTIES TIMEOUT 120)
    set_tests_properties(test_collective_allreduce_api PROPERTIES TIMEOUT 120)
    set_tests_properties(test_new_group_api PROPERTIES TIMEOUT 120)
    if(WITH_DISTRIBUTE)
        set_tests_properties(test_pipeline PROPERTIES TIMEOUT 120)
    endif()
@@ -853,6 +857,7 @@ if((WITH_ROCM OR WITH_GPU) AND NOT WIN32)
        test_collective_barrier_api
        test_collective_reduce_api
        test_collective_allreduce_api
        test_new_group_api
        test_collective_broadcast_api
        test_collective_allgather_api
        PROPERTIES LABELS "RUN_TYPE=DIST")
...
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import argparse
import os
import sys
import signal
import time
import socket
from contextlib import closing
from six import string_types
import math
import paddle
import paddle.fluid as fluid
import paddle.fluid.profiler as profiler
import paddle.fluid.unique_name as nameGen
from paddle.fluid import core
import unittest
from multiprocessing import Process
import paddle.fluid.layers as layers
from functools import reduce
from test_collective_api_base import TestCollectiveAPIRunnerBase, runtime_main
paddle.enable_static()
class TestCollectiveAllreduceNewGroupAPI(TestCollectiveAPIRunnerBase):
def __init__(self):
self.global_ring_id = 0
def get_model(self, main_prog, startup_program, rank):
with fluid.program_guard(main_prog, startup_program):
tindata = layers.data(
name="tindata", shape=[10, 1000], dtype='float32')
gp = paddle.distributed.new_group([0, 1])
paddle.distributed.all_reduce(
tindata, group=gp, use_calc_stream=False)
return [tindata]
if __name__ == "__main__":
runtime_main(TestCollectiveAllreduceNewGroupAPI, "allreduce")
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import os
import paddle
class TestNewGroupAPI(object):
def __init__(self):
paddle.distributed.init_parallel_env()
d1 = np.array([1, 2, 3])
d2 = np.array([2, 3, 4])
self.tensor1 = paddle.to_tensor(d1)
self.tensor2 = paddle.to_tensor(d2)
def test_all(self):
gp = paddle.distributed.new_group([0, 1])
print("test new group api ok")
tmp = np.array([0, 0, 0])
result = paddle.to_tensor(tmp)
paddle.distributed.scatter(
result, [self.tensor2, self.tensor1],
src=0,
group=gp,
use_calc_stream=True)
if gp.rank == 0:
assert np.array_equal(result, self.tensor2)
elif gp.rank == 1:
assert np.array_equal(result, self.tensor1)
print("test scatter api ok")
paddle.distributed.broadcast(
result, src=1, group=gp, use_calc_stream=True)
assert np.array_equal(result, self.tensor1)
print("test broadcast api ok")
paddle.distributed.reduce(result, dst=0, group=gp, use_calc_stream=True)
if gp.rank == 0:
assert np.array_equal(result,
paddle.add(self.tensor1, self.tensor1))
elif gp.rank == 1:
assert np.array_equal(result, self.tensor1)
print("test reduce api ok")
paddle.distributed.all_reduce(result, use_calc_stream=True)
assert np.array_equal(
result,
paddle.add(paddle.add(self.tensor1, self.tensor1), self.tensor1))
print("test all_reduce api ok")
paddle.distributed.wait(result, gp, use_calc_stream=True)
paddle.distributed.wait(result, gp, use_calc_stream=False)
print("test wait api ok")
result = []
paddle.distributed.all_gather(
result, self.tensor1, group=gp, use_calc_stream=True)
assert np.array_equal(result[0], self.tensor1)
assert np.array_equal(result[1], self.tensor1)
print("test all_gather api ok")
paddle.distributed.barrier(group=gp)
print("test barrier api ok")
return
if __name__ == "__main__":
gpt = TestNewGroupAPI()
gpt.test_all()
#!/bin/bash
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -e
CUDA_VISIBLE_DEVICES=0,1 python -m paddle.distributed.launch --gpus=0,1 new_group.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
from test_collective_api_base import TestDistBase
paddle.enable_static()
class TestCollectiveAllreduceAPI(TestDistBase):
def _setup_config(self):
pass
def test_allreduce_nccl(self):
self.check_with_place("collective_allreduce_new_group_api.py",
"allreduce", "nccl")
if __name__ == '__main__':
unittest.main()