Unverified commit db41b742, authored by lilong12, committed by GitHub

add alltoall api (#32507)

* add alltoall api, test=develop
Parent 31326950
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/alltoall_op.h"
namespace paddle {
namespace operators {
class AllToAllOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "AllToAll");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "AllToAll");
int ring_id = ctx->Attrs().Get<int>("ring_id");
PADDLE_ENFORCE_GE(
ring_id, 0,
platform::errors::InvalidArgument(
"The ring_id (%d) for alltoall op must be non-negative.", ring_id));
framework::DDim dim = ctx->GetInputDim("X");
if (dim[0] < 0) dim[0] = -1;
ctx->SetOutputDim("Out", dim);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
};
class AllToAllOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "(Tensor) tensor to be sent.");
AddOutput("Out", "(Tensor) the result of alltoall.");
AddAttr<int>("ring_id", "(int default 0) nccl communication ring id.")
.SetDefault(0);
AddAttr<bool>(
    "use_calc_stream",
    "(bool default false) whether to run CUDA operations on the calculation stream.")
    .SetDefault(false);
AddComment(R"DOC(
AllToAll Operator
Scatter tensors from all participants to all participants.
)DOC");
}
};
template <typename T>
class AllToAllOpGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> retv) const override {
retv->SetType("alltoall");
retv->SetInput("X", this->OutputGrad("Out"));
retv->SetOutput("Out", this->InputGrad("X"));
retv->SetAttrMap(this->Attrs());
}
};
DECLARE_INPLACE_OP_INFERER(AllToAllInplaceInferer, {"X", "Out"});
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OPERATOR(alltoall, ops::AllToAllOp, ops::AllToAllOpMaker,
ops::AllToAllOpGradMaker<paddle::framework::OpDesc>,
ops::AllToAllOpGradMaker<paddle::imperative::OpBase>,
ops::AllToAllInplaceInferer)
REGISTER_OP_CPU_KERNEL(alltoall, ops::AllToAllOpCPUKernel<float>,
ops::AllToAllOpCPUKernel<double>,
ops::AllToAllOpCPUKernel<int>,
ops::AllToAllOpCPUKernel<int64_t>,
ops::AllToAllOpCPUKernel<plat::float16>);
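For reference, the exchange this operator performs can be sketched in plain NumPy. This is a hedged, single-process simulation of the semantics only (the names nranks, world, and simulate_alltoall are illustrative, not part of the operator): slice i of rank r's input becomes slice r of rank i's output, so the output shape equals the input shape, which is exactly what InferShape above preserves.

import numpy as np

# Single-process sketch of alltoall semantics among nranks participants.
# Each rank holds a tensor whose first dimension is divisible by nranks.
def simulate_alltoall(world, nranks):
    outputs = []
    for r in range(nranks):
        # Rank r receives slice r from every rank, in rank order.
        chunks = [np.split(world[src], nranks, axis=0)[r]
                  for src in range(nranks)]
        outputs.append(np.concatenate(chunks, axis=0))
    return outputs

nranks = 2
world = [np.arange(6).reshape(2, 3) + 100 * r for r in range(nranks)]
outs = simulate_alltoall(world, nranks)
assert all(o.shape == world[0].shape for o in outs)  # shape is preserved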
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/alltoall_op.h"
#if defined(PADDLE_WITH_NCCL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif
namespace paddle {
namespace operators {
template <typename T>
class AllToAllOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
#if defined(PADDLE_WITH_NCCL)
#if NCCL_VERSION_CODE >= 2703
auto x = ctx.Input<framework::LoDTensor>("X");
auto out = ctx.Output<framework::LoDTensor>("Out");
int send_numel = x->numel();
ncclDataType_t dtype = platform::ToNCCLDataType(x->type());
int ring_id = ctx.Attr<int>("ring_id");
PADDLE_ENFORCE_GE(
ring_id, 0,
platform::errors::InvalidArgument(
"The ring_id (%d) for alltoall op must be non-negative.", ring_id));
auto place = ctx.GetPlace();
auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
int nranks = comm->nranks();
cudaStream_t stream = nullptr;
if (ctx.Attr<bool>("use_calc_stream")) {
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
} else {
stream = comm->stream();
}
framework::DDim x_dims = x->dims();
framework::DDim out_dims(x_dims);
PADDLE_ENFORCE_EQ(
x_dims[0] % nranks, 0,
platform::errors::InvalidArgument(
"The first dimension size (%d) of the input tensor must be "
"divisible by the number of ranks (%d).",
x_dims[0], nranks));
auto send_buf = x->data<T>();
auto recv_buf = out->mutable_data<T>(out_dims, place);
size_t offset = 0;
send_numel /= nranks;
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupStart());
for (auto i = 0; i < nranks; ++i) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclSend(
send_buf + offset, send_numel, dtype, i, comm->comm(), stream));
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclRecv(
recv_buf + offset, send_numel, dtype, i, comm->comm(), stream));
offset += send_numel;
}
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupEnd());
#else
PADDLE_THROW(
platform::errors::Unavailable("NCCL version >= 2.7.3 is needed."));
#endif
#else
PADDLE_THROW(platform::errors::Unavailable(
    "PaddlePaddle must be compiled with GPU support."));
#endif
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(alltoall, ops::AllToAllOpCUDAKernel<float>,
ops::AllToAllOpCUDAKernel<double>,
ops::AllToAllOpCUDAKernel<int>,
ops::AllToAllOpCUDAKernel<int64_t>,
ops::AllToAllOpCUDAKernel<plat::float16>);
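The grouped ncclSend/ncclRecv loop above realizes alltoall as nranks paired point-to-point transfers: the flat send buffer is cut into nranks equal pieces, piece i goes to peer i, and the piece received from peer i is written at the same offset of the receive buffer. Below is a rough host-side sketch of that buffer arithmetic in pure NumPy (the exchange function and its names are illustrative assumptions, not kernel code):

import numpy as np

def exchange(send_bufs, nranks):
    """Mimic the kernel's offset walk; send_bufs[r] is rank r's flat buffer."""
    recv_bufs = [np.empty_like(b) for b in send_bufs]
    chunk = send_bufs[0].size // nranks  # send_numel /= nranks in the kernel
    for r in range(nranks):              # every rank runs the same loop
        offset = 0
        for peer in range(nranks):
            # ncclSend(send_buf + offset, ...) on rank r pairs with the
            # matching ncclRecv(recv_buf + offset, ...) on the peer, all
            # inside one ncclGroupStart()/ncclGroupEnd() region.
            recv_bufs[r][offset:offset + chunk] = \
                send_bufs[peer][r * chunk:(r + 1) * chunk]
            offset += chunk
    return recv_bufs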
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#if defined(PADDLE_WITH_GLOO)
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#endif
namespace paddle {
namespace operators {
template <typename T>
class AllToAllOpCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_THROW(platform::errors::Unavailable(
    "The alltoall CPU kernel is not supported yet."));
}
};
} // namespace operators
} // namespace paddle
@@ -36,6 +36,7 @@ __all__ = [
'scatter',
'barrier',
'split',
'alltoall',
'ReduceOp',
'send',
'recv',
@@ -1178,6 +1179,77 @@ def split(x,
return linear_out
def alltoall(in_tensor_list, out_tensor_list, group=None, use_calc_stream=True):
"""
Scatter tensors in in_tensor_list to all participants and gather the resulting tensors into out_tensor_list.
Args:
in_tensor_list (list): A list of input Tensors. Every element in the list must be a Tensor whose data type
should be float16, float32, float64, int32 or int64.
out_tensor_list (list): A list of output Tensors. The data type of its elements should be the same as the
data type of the input Tensors.
group (Group, optional): The group instance returned by new_group, or None for the global default group. Default: None.
use_calc_stream (bool, optional): Whether to use the calculation stream (True) or the communication stream (False). Default: True.
Returns:
None.
Examples:
.. code-block:: python
# required: distributed
import numpy as np
import paddle
from paddle.distributed import init_parallel_env
init_parallel_env()
out_tensor_list = []
if paddle.distributed.ParallelEnv().rank == 0:
np_data1 = np.array([[1, 2, 3], [4, 5, 6]])
np_data2 = np.array([[7, 8, 9], [10, 11, 12]])
else:
np_data1 = np.array([[13, 14, 15], [16, 17, 18]])
np_data2 = np.array([[19, 20, 21], [22, 23, 24]])
data1 = paddle.to_tensor(np_data1)
data2 = paddle.to_tensor(np_data2)
paddle.distributed.alltoall([data1, data2], out_tensor_list)
# out for rank 0: [[[1, 2, 3], [4, 5, 6]], [[13, 14, 15], [16, 17, 18]]]
# out for rank 1: [[[7, 8, 9], [10, 11, 12]], [[19, 20, 21], [22, 23, 24]]]
"""
if group is not None and not group.is_member():
return
ring_id = 0 if group is None else group.id
op_type = 'alltoall'
temp = paddle.concat(in_tensor_list, axis=0)
helper = LayerHelper(op_type, **locals())
nranks = len(in_tensor_list)
out = helper.create_variable_for_type_inference(
dtype=in_tensor_list[0].dtype)
if in_dygraph_mode():
    core.ops.alltoall_(temp, 'use_calc_stream', use_calc_stream, 'ring_id',
                       ring_id)
    out = temp
else:
if not isinstance(in_tensor_list, list):
raise ValueError("The type of 'in_tensor_list' for all_to_all "
"should be list.")
for elem in in_tensor_list:
check_variable_and_dtype(
elem, 'in_tensor_list',
['float16', 'float32', 'float64', 'int32', 'int64'],
'all_to_all')
if not isinstance(out_tensor_list, list):
raise ValueError("The type of 'out_tensor_list' for all_to_all "
"should be list.")
if len(out_tensor_list) != 0:
raise ValueError("The 'out_tensor_list' for all_to_all "
"must be an empty list.")
helper.append_op(
type=op_type,
inputs={'X': [temp]},
outputs={'Out': [out]},
attrs={
'ring_id': ring_id,
'use_calc_stream': use_calc_stream,
})
out_tensor_list.extend(paddle.split(out, nranks, 0))
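Design note: rather than issuing one transfer per tensor, the wrapper concatenates in_tensor_list along dim 0, runs a single alltoall op, and splits the result back into nranks pieces; this is why every input tensor must have the same shape. A hedged usage sketch follows, assuming the script runs under a multi-GPU paddle.distributed.launch job (tensor contents and shapes are illustrative):

import paddle
from paddle.distributed import init_parallel_env

init_parallel_env()
rank = paddle.distributed.get_rank()
nranks = paddle.distributed.get_world_size()
# One equally-shaped slice addressed to each destination rank.
in_list = [paddle.full([2, 3], float(rank * nranks + i)) for i in range(nranks)]
out_list = []  # must be empty; the API extends it in place
paddle.distributed.alltoall(in_list, out_list)
# out_list[i] now holds the slice that rank i addressed to this rank.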
def send(tensor, dst=0, group=None, use_calc_stream=True):
"""
Send a tensor to the receiver.
......
@@ -96,6 +96,7 @@ if(((NOT WITH_ROCM) AND (NOT WITH_GPU)) OR WIN32)
LIST(REMOVE_ITEM TEST_OPS test_new_group_api)
LIST(REMOVE_ITEM TEST_OPS test_collective_broadcast_api)
LIST(REMOVE_ITEM TEST_OPS test_collective_allgather_api)
LIST(REMOVE_ITEM TEST_OPS test_collective_alltoall_api)
LIST(REMOVE_ITEM TEST_OPS test_collective_sendrecv_api)
LIST(REMOVE_ITEM TEST_OPS test_collective_wait)
LIST(REMOVE_ITEM TEST_OPS test_memcpy_op)
@@ -872,6 +873,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL)
endif()
if((WITH_ROCM OR WITH_GPU) AND NOT WIN32)
set_tests_properties(test_collective_allgather_api PROPERTIES TIMEOUT 120)
set_tests_properties(test_collective_alltoall_api PROPERTIES TIMEOUT 120)
set_tests_properties(test_collective_sendrecv_api PROPERTIES TIMEOUT 120)
set_tests_properties(test_collective_broadcast_api PROPERTIES TIMEOUT 120)
set_tests_properties(test_collective_allreduce_api PROPERTIES TIMEOUT 120)
@@ -907,6 +909,7 @@ if((WITH_ROCM OR WITH_GPU) AND NOT WIN32)
test_new_group_api
test_collective_broadcast_api
test_collective_allgather_api
test_collective_alltoall_api
PROPERTIES LABELS "RUN_TYPE=DIST")
endif()
if(WITH_GPU OR WITH_ROCM)
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import argparse
import os
import sys
import signal
import time
import socket
from contextlib import closing
from six import string_types
import math
import paddle
import paddle.fluid as fluid
import paddle.fluid.profiler as profiler
import paddle.fluid.unique_name as nameGen
from paddle.fluid import core
import unittest
from multiprocessing import Process
import paddle.fluid.layers as layers
from functools import reduce
from test_collective_api_base import TestCollectiveAPIRunnerBase, runtime_main
paddle.enable_static()
class TestCollectiveAllToAllAPI(TestCollectiveAPIRunnerBase):
def __init__(self):
self.global_ring_id = 0
def get_model(self, main_prog, startup_program, rank):
with fluid.program_guard(main_prog, startup_program):
tindata = layers.data(
name="tindata", shape=[10, 1000], dtype='float32')
tindata = paddle.split(tindata, 2, axis=0)
tout_data = []
paddle.distributed.alltoall(tindata, tout_data)
return tout_data
if __name__ == "__main__":
runtime_main(TestCollectiveAllToAllAPI, "alltoall")
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
from test_collective_api_base import TestDistBase
paddle.enable_static()
class TestCollectiveAllToAllAPI(TestDistBase):
def _setup_config(self):
pass
def test_alltoall_nccl(self):
self.check_with_place("collective_alltoall_api.py", "alltoall", "nccl")
if __name__ == '__main__':
unittest.main()
@@ -277,6 +277,19 @@ class TestDistBase(unittest.TestCase):
self.assertTrue(
np.allclose(
result_data, need_result, rtol=1e-05, atol=1e-05))
elif col_type == "alltoall":
need_result1 = np.vstack((input1[0:input1.shape[0] // 2, :],
input2[0:input2.shape[0] // 2, :]))
need_result2 = np.vstack((input1[input1.shape[0] // 2:, :],
input2[input2.shape[0] // 2:, :]))
tr0_out = np.vstack(tr0_out)
tr1_out = np.vstack(tr1_out)
self.assertTrue(
np.allclose(
tr0_out, need_result1, rtol=1e-05, atol=1e-05))
self.assertTrue(
np.allclose(
tr1_out, need_result2, rtol=1e-05, atol=1e-05))
elif col_type == "sendrecv":
result_data = tr1_out[0]
self.assertTrue(
......