diff --git a/paddle/phi/api/yaml/static_ops.yaml b/paddle/phi/api/yaml/static_ops.yaml
index 4ef12d57433281c72e8d4163a14e4da69cbbc326..92514691d03e5fc0eb0e1c81fb28e7e440d2462d 100755
--- a/paddle/phi/api/yaml/static_ops.yaml
+++ b/paddle/phi/api/yaml/static_ops.yaml
@@ -156,6 +156,16 @@
   optional : bias
   backward : depthwise_conv2d_transpose_grad
 
+- op : dist_concat
+  args : (Tensor x, int ring_id = 0, int nranks = 1)
+  output : Tensor(out)
+  infer_meta :
+    func : DistConcatInferMeta
+    param : [x, nranks]
+  kernel :
+    func : dist_concat
+    param : [x, nranks]
+
 - op : einsum
   args : (Tensor[] x, str equation)
   output : Tensor(out), Tensor[](inner_cache){x.size()}, Tensor[](xshape){x.size()}
diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc
index 9b69dcfdd8008645a096345bc73862a0c2fca451..0d2a7ca8d26c0ebaea2be10bc24dd0f4ea038cbe 100644
--- a/paddle/phi/infermeta/unary.cc
+++ b/paddle/phi/infermeta/unary.cc
@@ -844,6 +844,14 @@ void DirichletInferMeta(const MetaTensor& alpha, MetaTensor* out) {
   out->set_dtype(alpha.dtype());
 }
 
+void DistConcatInferMeta(const MetaTensor& x, int nranks, MetaTensor* out) {
+  auto dim = x.dims();
+  dim[dim.size() - 1] = dim[dim.size() - 1] * nranks;
+  if (dim[dim.size() - 1] < 0) dim[dim.size() - 1] = -1;
+  out->set_dtype(x.dtype());
+  out->set_dims(dim);
+}
+
 void DistReduceInferMeta(const MetaTensor& x, MetaTensor* out) {
   out->set_dtype(x.dtype());
   out->set_dims(x.dims());
diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h
index d03caa048b591c5d4be86a0e399c4d1da87fc208..af2531af6dae2afe38ac48dd05305f7ad3855cff 100644
--- a/paddle/phi/infermeta/unary.h
+++ b/paddle/phi/infermeta/unary.h
@@ -147,6 +147,8 @@ void DirichletInferMeta(const MetaTensor& alpha, MetaTensor* out);
 
 void DistBroadcastInferMeta(const MetaTensor& x, MetaTensor* out);
 
+void DistConcatInferMeta(const MetaTensor& x, int nranks, MetaTensor* out);
+
 void DistReduceInferMeta(const MetaTensor& x, MetaTensor* out);
 
 void EmbeddingGradSparseInferMeta(const MetaTensor& x, MetaTensor* out);
diff --git a/paddle/phi/kernels/dist_concat_kernel.h b/paddle/phi/kernels/dist_concat_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..b1bdb0d3262812bd5c66b6ddb8451d66943ad814
--- /dev/null
+++ b/paddle/phi/kernels/dist_concat_kernel.h
@@ -0,0 +1,27 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void DistConcatKernel(const Context& dev_ctx,
+                      const DenseTensor& x,
+                      int nranks,
+                      DenseTensor* out);
+
+} // namespace phi
diff --git a/paddle/phi/kernels/gpu/dist_concat_kernel.cu b/paddle/phi/kernels/gpu/dist_concat_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..75500f06299b36dc5da623fc0b3399f6bab8632e
--- /dev/null
+++ b/paddle/phi/kernels/gpu/dist_concat_kernel.cu
@@ -0,0 +1,108 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/dist_concat_kernel.h"
+
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
+
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
+#include "paddle/phi/core/distributed/nccl_comm_context.h"
+#endif
+
+namespace phi {
+
+template <typename T, typename Context>
+void DistConcatKernel(const Context& dev_ctx,
+                      const DenseTensor& x,
+                      int nranks,
+                      DenseTensor* out) {
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
+  DenseTensor temp_out;
+  auto temp_out_dims = x.dims();
+  temp_out_dims[0] *= nranks;
+  temp_out.Resize(temp_out_dims);
+  dev_ctx.template Alloc<T>(&temp_out);
+
+  auto comm_ctx =
+      static_cast<distributed::NCCLCommContext*>(dev_ctx.GetCommContext());
+  PADDLE_ENFORCE_NE(
+      comm_ctx,
+      nullptr,
+      errors::Unavailable("NCCLCommContext is nullptr, collective op should "
+                          "have a ring_id attr."));
+  PADDLE_ENFORCE_EQ(
+      nranks,
+      comm_ctx->GetSize(),
+      errors::InvalidArgument(
+          "nranks: %s should be equal to %s", nranks, comm_ctx->GetSize()));
+
+  // All-gather x from every rank into temp_out, stacked along dim 0.
+  gpuStream_t stream = dev_ctx.stream();
+  comm_ctx->AllGather(&temp_out, x, stream);
+
+  // Re-concatenate the gathered per-rank slices along the last dimension.
+  std::vector<DenseTensor> inputs;
+  int axis = x.dims().size() - 1;
+  auto out_dims = x.dims();
+  out_dims[out_dims.size() - 1] *= nranks;
+  int rows_per_tensor = x.dims()[0];
+  int offset = 0;
+  for (int i = 0; i < nranks; i++) {
+    DenseTensor temp =
+        temp_out.Slice(static_cast<int64_t>(offset),
+                       static_cast<int64_t>(offset + rows_per_tensor));
+    inputs.emplace_back(temp);
+    offset += rows_per_tensor;
+  }
+  phi::funcs::ConcatFunctor<Context, T> functor;
+  out->Resize(out_dims);
+  dev_ctx.template Alloc<T>(out);
+  functor(dev_ctx, inputs, axis, out);
+#else
+  PADDLE_THROW(
+      errors::PreconditionNotMet("PaddlePaddle should be compiled with GPU."));
+#endif
+}
+
+} // namespace phi
+
+#if NCCL_VERSION_CODE >= 21000
+PD_REGISTER_KERNEL(dist_concat,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::DistConcatKernel,
+                   float,
+                   double,
+                   int,
+                   uint8_t,
+                   int8_t,
+                   int64_t,
+                   bool,
+                   phi::dtype::bfloat16,
+                   phi::dtype::float16) {}
+#else
+PD_REGISTER_KERNEL(dist_concat,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::DistConcatKernel,
+                   float,
+                   double,
+                   int,
+                   uint8_t,
+                   int8_t,
+                   int64_t,
+                   bool,
+                   phi::dtype::float16) {}
+#endif
diff --git a/python/paddle/distributed/collective.py b/python/paddle/distributed/collective.py
index 14fa116c874d1d162be5e09529129f6750d0a5e7..604aeeb9c0b83bb51f1218d67d1b7e3a851ce18b 100644
--- a/python/paddle/distributed/collective.py
+++ b/python/paddle/distributed/collective.py
@@ -323,6 +323,12 @@ def is_available():
 
 def _init_parallel_env(backend):
     master_endpoint = os.getenv("PADDLE_MASTER", None)
+    if master_endpoint is None:
+        trainer_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS")
+        assert (
+            trainer_endpoints is not None
+        ), "Please set the PADDLE_MASTER environment variable."
+        master_endpoint = trainer_endpoints.split(',')[0]
     if master_endpoint:
         master_addr = master_endpoint.split(":")[0]
         master_port = int(master_endpoint.split(":")[1])
diff --git a/test/collective/CMakeLists.txt b/test/collective/CMakeLists.txt
index 3f2aed73b29eadf51c0ae737068d51f4da00850e..05732504bf9684c73ab497ce5d281c7510b95aab 100644
--- a/test/collective/CMakeLists.txt
+++ b/test/collective/CMakeLists.txt
@@ -151,6 +151,13 @@
   set_tests_properties(test_collective_broadcast_object_list_api
                        PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST")
 endif()
+if((WITH_GPU OR WITH_ROCM) AND (LINUX))
+  py_test_modules(
+    test_collective_concat_api MODULES test_collective_concat_api ENVS
+    "http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python")
+  set_tests_properties(test_collective_concat_api
+                       PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST")
+endif()
 if((WITH_GPU OR WITH_ROCM) AND (LINUX))
   py_test_modules(
     test_collective_cpu_barrier_with_gloo MODULES
diff --git a/test/collective/collective_concat_api.py b/test/collective/collective_concat_api.py
new file mode 100644
index 0000000000000000000000000000000000000000..fab9711a918672b9cc04bae2d8d3f8540c87e920
--- /dev/null
+++ b/test/collective/collective_concat_api.py
@@ -0,0 +1,83 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from legacy_test.test_collective_api_base import (
+    TestCollectiveAPIRunnerBase,
+    runtime_main,
+)
+
+import paddle
+from paddle import fluid, framework
+from paddle.fluid import data_feeder
+
+paddle.enable_static()
+
+
+def concat_new(tensor, group=None):
+    op_type = 'dist_concat'
+    data_feeder.check_variable_and_dtype(
+        tensor,
+        'tensor',
+        [
+            'float16',
+            'float32',
+            'float64',
+            'int32',
+            'int64',
+            'int8',
+            'uint8',
+            'bool',
+            'uint16',
+        ],
+        op_type,
+    )
+
+    helper = framework.LayerHelper(op_type, **locals())
+    ring_id = 0 if group is None else group.id
+    nranks = 2
+
+    out = helper.create_variable_for_type_inference(dtype=tensor.dtype)
+    helper.append_op(
+        type=op_type,
+        inputs={'x': [tensor]},
+        outputs={'out': [out]},
+        attrs={
+            'ring_id': ring_id,
+            'nranks': nranks,
+        },
+    )
+    return out
+
+
+class TestCollectiveConcatAPI(TestCollectiveAPIRunnerBase):
+    def __init__(self):
+        self.global_ring_id = 0
+
+    def get_model(self, main_prog, startup_program):
+        pass
+
+    def get_model_new(
+        self, main_prog, startup_program, rank, dtype=None, reduce_type=None
+    ):
+        with fluid.program_guard(main_prog, startup_program):
+            tindata = paddle.static.data(
+                name="tindata", shape=[10, 1000], dtype=dtype
+            )
+            tindata.desc.set_need_check_feed(False)
+            toutdata = concat_new(tindata)
+            return [toutdata]
+
+
+if __name__ == "__main__":
+    runtime_main(TestCollectiveConcatAPI, "concat")
diff --git a/test/collective/test_collective_concat_api.py b/test/collective/test_collective_concat_api.py
new file mode 100644
index 0000000000000000000000000000000000000000..94cbc3ea2228157bc91c5e107b43293e82c9ebe1
--- /dev/null
+++ b/test/collective/test_collective_concat_api.py
@@ -0,0 +1,52 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from legacy_test.test_collective_api_base import TestDistBase
+
+import paddle
+
+paddle.enable_static()
+
+
+class TestCollectiveConcatAPI(TestDistBase):
+    def _setup_config(self):
+        pass
+
+    def test_concat_with_comm_context(self):
+        dtypes_to_test = [
+            "float16",
+            "float32",
+            "float64",
+            "int32",
+            "int64",
+            "int8",
+            "uint8",
+            "bool",
+        ]
+        if self._nccl_version >= 21000:
+            dtypes_to_test.append("bfloat16")
+        for dtype in dtypes_to_test:
+            self.check_with_place(
+                "collective_concat_api.py",
+                "dist_concat",
+                "nccl",
+                dtype=dtype,
+                need_envs={"USE_COMM_CONTEXT": "1"},
+            )
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/test/legacy_test/test_collective_api_base.py b/test/legacy_test/test_collective_api_base.py
index 383ef168bfce8eb4da12ecb2c9c15d25a9ce2d0e..f99cf9378eafdcef4938ad7b2d8089dcea40d2c3 100644
--- a/test/legacy_test/test_collective_api_base.py
+++ b/test/legacy_test/test_collective_api_base.py
@@ -494,6 +494,12 @@ class TestDistBase(unittest.TestCase):
             np.testing.assert_allclose(
                 result_data, need_result, rtol=1e-05, atol=1e-05
             )
+        elif col_type == "dist_concat":
+            result_data = tr0_out[0]
+            need_result = np.concatenate((input1, input2), axis=1)
+            np.testing.assert_allclose(
+                result_data, need_result, rtol=1e-05, atol=1e-05
+            )
 
        elif col_type == "alltoall":
             need_result1 = np.vstack(
                 (
diff --git a/tools/parallel_UT_rule.py b/tools/parallel_UT_rule.py
index b2082c3e7eb088c89b3768065f7625e081f41b71..c22938e27d150589e6cbce6cb1f09d23076be9da 100755
--- a/tools/parallel_UT_rule.py
+++ b/tools/parallel_UT_rule.py
@@ -528,6 +528,7 @@ HIGH_PARALLEL_JOB_NEW = [
     'test_collective_reduce_api',
     'test_multiprocess_dataloader_exception',
     'test_collective_allgather_api',
+    'test_collective_concat_api',
     'test_dist_fleet_ps10',
     'test_dist_sparse_tensor_load_rmsprop',
     'test_collective_split_embedding_none_divisible',