diff --git a/paddle/fluid/pybind/auto_parallel_py.cc b/paddle/fluid/pybind/auto_parallel_py.cc index bdb8a763a91fd7e19eb9ca63f6c838920c1c8875..96c49b4170519e09eb375038aadfc30d83b59bf3 100644 --- a/paddle/fluid/pybind/auto_parallel_py.cc +++ b/paddle/fluid/pybind/auto_parallel_py.cc @@ -26,6 +26,9 @@ #include "paddle/fluid/distributed/auto_parallel/spmd_rules/common.h" #include "paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.h" +#ifdef PADDLE_WITH_DISTRIBUTE +#include "paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h" +#endif namespace py = pybind11; @@ -107,6 +110,11 @@ static inline void reset_operator_dist_attr(OperatorDistAttr *dist_attr) { } void BindAutoParallel(py::module *m) { +#ifdef PADDLE_WITH_DISTRIBUTE + py::class_(*m, "RToSReshardFunction") + .def(py::init<>()); +#endif + py::class_(*m, "ProcessMesh") .def(py::init<>()) .def(py::init &, diff --git a/paddle/phi/core/distributed/auto_parallel/CMakeLists.txt b/paddle/phi/core/distributed/auto_parallel/CMakeLists.txt index db639bba5f400a1f298c6f081881a803c1f0207e..d4af259a5906cd2b297f8caadea2cfcd5e154b3a 100644 --- a/paddle/phi/core/distributed/auto_parallel/CMakeLists.txt +++ b/paddle/phi/core/distributed/auto_parallel/CMakeLists.txt @@ -1,5 +1,18 @@ proto_library(auto_parallel_proto SRCS auto_parallel.proto) +set(DISTRIBUTED_SRCS "") + +if(WITH_DISTRIBUTE) + list( + APPEND + DISTRIBUTED_SRCS + dist_tensor.cc + reshard_function.cc + reshard_split_functor.cc + reshard_utils.cc + r_to_s_reshard_function.cc) +endif() + collect_srcs( core_srcs SRCS @@ -7,4 +20,4 @@ collect_srcs( process_mesh.cc dist_attr.cc dist_mapper.cc - dist_tensor.cc) + ${DISTRIBUTED_SRCS}) diff --git a/paddle/phi/core/distributed/auto_parallel/dist_tensor.h b/paddle/phi/core/distributed/auto_parallel/dist_tensor.h index eb3a6dbbe3e665965075297716d679eafe3d9b52..63a7438a6ae7ab21b0957a7e0313d190f5a87195 100644 --- a/paddle/phi/core/distributed/auto_parallel/dist_tensor.h +++ b/paddle/phi/core/distributed/auto_parallel/dist_tensor.h @@ -31,7 +31,7 @@ class DistTensor final public: /// \brief Construct a dist tensor and allocate space. /// \param a The allocator used to allocate space. - /// \param meta The meta data of dense tensor. + /// \param meta The meta data of dist tensor. DistTensor(Allocator* a, const DenseTensorMeta& meta, const std::shared_ptr& dist_attr) diff --git a/paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.cc new file mode 100644 index 0000000000000000000000000000000000000000..a9db48e631cff0ce5f52c631143d0641d364ee60 --- /dev/null +++ b/paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.cc @@ -0,0 +1,100 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h" + +#include "glog/logging.h" +#include "paddle/phi/api/lib/kernel_dispatch.h" +#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" +#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard_split_functor.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h" +#include "paddle/phi/core/kernel_factory.h" + +namespace phi { +namespace distributed { + +bool RToSReshardFunction::IsSuitable( + const DistTensor& in, + const std::shared_ptr& out_dist_attr) { + bool flag = true; + const auto& in_dist_attr = in.dist_attr(); + + const auto& in_dims_mapping = in_dist_attr->dims_mapping(); + const auto& out_dims_mapping = out_dist_attr->dims_mapping(); + + flag &= IsDimsMappingReplicated(in_dims_mapping); + flag &= IsDimsMappingShard(out_dims_mapping); + + const auto& in_process_mesh = in_dist_attr->process_mesh(); + const auto& out_process_mesh = out_dist_attr->process_mesh(); + + flag &= (in_process_mesh.ndim() == 1); + flag &= (out_process_mesh.ndim() == 1); + flag &= (in_process_mesh == out_process_mesh); + + return flag; +} + +std::shared_ptr RToSReshardFunction::Eval( + const phi::DeviceContext& dev_ctx, + const DistTensor& in, + const std::shared_ptr& out_dist_attr) { + const auto& out_dims_mapping = out_dist_attr->dims_mapping(); + const auto& out_process_mesh = out_dist_attr->process_mesh(); + const DenseTensor& in_physical_tensor_cur_rank = in.value(); + + DenseTensor out_physical_tensor_cur_rank; + + std::map split_axis_to_mesh_axis = + GetSplitAxisWithDimsMapping(out_dims_mapping); + std::vector coord_in_mesh = GetCurRankCoordInMesh(out_process_mesh); + + int64_t split_axis = split_axis_to_mesh_axis.begin()->first; + int64_t mesh_axis = split_axis_to_mesh_axis.begin()->second; + + PADDLE_ENFORCE_LT( + mesh_axis, + out_process_mesh.ndim(), + phi::errors::OutOfRange( + "The mesh axis %lld exceed the size of process mesh %lld.", + mesh_axis, + out_process_mesh.ndim())); + + int64_t num_of_process = out_process_mesh.shape()[mesh_axis]; + VLOG(3) << "RToSReshard: Tensor will be split on axis " << split_axis + << ". Split will use axis " << mesh_axis << " of process_mesh." + << " There will have " << num_of_process + << " process participate in."; + + // TODO(liyurui): Consider the tensor can not be balanced split, + // for example, the shape of tensor is {6} but want to split it by 4 + // process. + IntArray sections(std::vector( + num_of_process, in.dims()[split_axis] / num_of_process)); + + std::vector split_out_vec = ReshardSplitFunctor( + dev_ctx, in_physical_tensor_cur_rank, sections, split_axis); + + VLOG(3) << "The current process will remain the idx " + << coord_in_mesh[mesh_axis] << " piece of tensor"; + out_physical_tensor_cur_rank = split_out_vec[coord_in_mesh[mesh_axis]]; + + return std::make_shared( + std::make_shared(out_physical_tensor_cur_rank), + out_dist_attr); +} + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h b/paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h new file mode 100644 index 0000000000000000000000000000000000000000..61b77820297e4342c7b3d08a622df49fe20930fa --- /dev/null +++ b/paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h @@ -0,0 +1,38 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/distributed/auto_parallel/reshard_function.h" + +namespace phi { +namespace distributed { + +class RToSReshardFunction final : public ReshardFunction { + public: + RToSReshardFunction() = default; + ~RToSReshardFunction() = default; + + bool IsSuitable( + const DistTensor& in, + const std::shared_ptr& out_dist_attr) override; + + std::shared_ptr Eval( + const DeviceContext& dev_ctx, + const DistTensor& in, + const std::shared_ptr& out_dist_attr) override; +}; + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/core/distributed/auto_parallel/reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/reshard_function.cc new file mode 100644 index 0000000000000000000000000000000000000000..04bbc4a09fe1f90663b3c5ab96ef5bf464061bd5 --- /dev/null +++ b/paddle/phi/core/distributed/auto_parallel/reshard_function.cc @@ -0,0 +1,22 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/distributed/auto_parallel/reshard_function.h" +#include "paddle/phi/core/device_context.h" +#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" +#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h" + +namespace phi { +namespace distributed {} // namespace distributed +} // namespace phi diff --git a/paddle/phi/core/distributed/auto_parallel/reshard_function.h b/paddle/phi/core/distributed/auto_parallel/reshard_function.h new file mode 100644 index 0000000000000000000000000000000000000000..2c8574ca376ce8194e03d094c666fa8d9692d978 --- /dev/null +++ b/paddle/phi/core/distributed/auto_parallel/reshard_function.h @@ -0,0 +1,45 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include + +namespace phi { +class DeviceContext; + +namespace distributed { +namespace auto_parallel { +class TensorDistAttr; +} // namespace auto_parallel + +class DistTensor; +using auto_parallel::TensorDistAttr; + +class ReshardFunction { + public: + ReshardFunction() = default; + virtual ~ReshardFunction() = default; + + virtual bool IsSuitable( + const DistTensor& in, + const std::shared_ptr& out_dist_attr) = 0; + + virtual std::shared_ptr Eval( + const DeviceContext& dev_ctx, + const DistTensor& in, + const std::shared_ptr& out_dist_attr) = 0; +}; + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/core/distributed/auto_parallel/reshard_split_functor.cc b/paddle/phi/core/distributed/auto_parallel/reshard_split_functor.cc new file mode 100644 index 0000000000000000000000000000000000000000..189738b81367fa56f852dec75b5a76e8adec8c2b --- /dev/null +++ b/paddle/phi/core/distributed/auto_parallel/reshard_split_functor.cc @@ -0,0 +1,78 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/distributed/auto_parallel/reshard_split_functor.h" + +#include "paddle/phi/backends/all_context.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/device_context.h" +#include "paddle/phi/core/visit_type.h" +#include "paddle/phi/kernels/split_kernel.h" + +namespace phi { +namespace distributed { + +std::vector ReshardSplitFunctor(const DeviceContext& dev_ctx, + const DenseTensor& input, + const IntArray& sections, + int64_t axis) { + size_t out_number = sections.size(); + std::vector result(out_number); + + std::vector out_meta; + std::vector out_meta_ptr; + + out_meta.reserve(out_number); + out_meta_ptr.reserve(out_number); + for (size_t i = 0; i < out_number; ++i) { + out_meta.emplace_back(result[i]); + out_meta_ptr.emplace_back(&out_meta.back()); + } + SplitInferMeta(phi::MetaTensor(input), sections, axis, out_meta_ptr); + + std::vector outs; + for (size_t i = 0; i < out_number; ++i) { + outs.emplace_back(&result[i]); + } + + if (phi::CPUContext::classof(&dev_ctx)) { + PD_VISIT_ALL_TYPES(input.dtype(), "SplitKernel", ([&] { + SplitKernel( + static_cast(dev_ctx), + input, + sections, + axis, + outs); + })); + return result; + } +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + if (phi::GPUContext::classof(&dev_ctx)) { + PD_VISIT_ALL_TYPES(input.dtype(), "SplitKernel", ([&] { + SplitKernel( + static_cast(dev_ctx), + input, + sections, + axis, + outs); + })); + return result; + } +#endif + PADDLE_THROW(phi::errors::Unimplemented( + "The split in reshard only supported on CPU and GPU for now.")); +} + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/core/distributed/auto_parallel/reshard_split_functor.h b/paddle/phi/core/distributed/auto_parallel/reshard_split_functor.h new file mode 100644 index 0000000000000000000000000000000000000000..87b9f2301ad0baaccfa9aa8fe6283ff1afaff76a --- /dev/null +++ b/paddle/phi/core/distributed/auto_parallel/reshard_split_functor.h @@ -0,0 +1,32 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "paddle/phi/common/int_array.h" + +namespace phi { +class DeviceContext; +class DenseTensor; + +namespace distributed { +std::vector ReshardSplitFunctor(const DeviceContext& dev_ctx, + const DenseTensor& input, + const IntArray& sections, + int64_t axis); + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/core/distributed/auto_parallel/reshard_utils.cc b/paddle/phi/core/distributed/auto_parallel/reshard_utils.cc new file mode 100644 index 0000000000000000000000000000000000000000..b777b53c2304384e8eedb33d6206a638f86efce9 --- /dev/null +++ b/paddle/phi/core/distributed/auto_parallel/reshard_utils.cc @@ -0,0 +1,84 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h" + +#include +#include "glog/logging.h" +#include "paddle/phi/core/distributed/auto_parallel/process_mesh.h" + +namespace phi { +namespace distributed { + +bool IsDimsMappingShard(const std::vector& dims_mapping) { + return std::any_of(dims_mapping.begin(), + dims_mapping.end(), + [](int64_t value) { return value != -1; }); +} + +bool IsDimsMappingReplicated(const std::vector& dims_mapping) { + return std::all_of(dims_mapping.begin(), + dims_mapping.end(), + [](int64_t value) { return value == -1; }); +} + +int64_t GetCurGlobalRank() { + const char* cur_rank = std::getenv("PADDLE_TRAINER_ID"); + PADDLE_ENFORCE_NOT_NULL( + cur_rank, + phi::errors::NotFound( + "The environment variable 'PADDLE_TRAINER_ID' cannot be found.")); + return std::atoi(cur_rank); +} + +std::vector GetCurRankCoordInMesh(const ProcessMesh& process_mesh) { + const auto& process_shape = process_mesh.shape(); + const auto& process_ids = process_mesh.process_ids(); + int64_t ndims_mesh = process_shape.size(); + int64_t cur_global_rank = GetCurGlobalRank(); + + VLOG(3) << "Searching current global rank " << cur_global_rank + << " in process_mesh " << process_mesh; + + auto iter = + std::find(process_ids.begin(), process_ids.end(), cur_global_rank); + PADDLE_ENFORCE_NE( + iter, + process_ids.end(), + phi::errors::NotFound("Rank %lld cannot be found in process_mesh", + cur_global_rank)); + + int64_t flat_idx_in_mesh = iter - process_ids.begin(); + + std::vector coord(ndims_mesh, -1); + for (int64_t i = ndims_mesh - 1; i >= 0; --i) { + coord[i] = flat_idx_in_mesh % process_shape[i]; + flat_idx_in_mesh /= process_shape[i]; + } + return coord; +} + +std::map GetSplitAxisWithDimsMapping( + const std::vector& dims_mapping) { + std::map split_axis_to_mesh_axis; + for (size_t i = 0; i < dims_mapping.size(); ++i) { + if (dims_mapping[i] != -1) { + split_axis_to_mesh_axis.emplace(i, dims_mapping[i]); + } + } + return split_axis_to_mesh_axis; +} + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/core/distributed/auto_parallel/reshard_utils.h b/paddle/phi/core/distributed/auto_parallel/reshard_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..dceaa5150a6b0a4835939536c5e86882d44e3dca --- /dev/null +++ b/paddle/phi/core/distributed/auto_parallel/reshard_utils.h @@ -0,0 +1,50 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +namespace phi { +namespace distributed { +namespace auto_parallel { + +class ProcessMesh; +} // namespace auto_parallel + +using auto_parallel::ProcessMesh; + +bool IsDimsMappingShard(const std::vector& dims_mapping); + +bool IsDimsMappingReplicated(const std::vector& dims_mapping); + +int64_t GetCurGlobalRank(); + +// Get the coordinate of cur rank in process mesh. For example, the process mesh +// is [[0, 1], [2, 3], [4, 5], [6, 7]], if the current rank is 4, then will +// return [2, 0]; if the current rank is 3, then will return [1, 1]. +std::vector GetCurRankCoordInMesh(const ProcessMesh& process_mesh); + +// If the index i's value in dims_mapping is x ( x != -1), means the ith axis of +// tensor need be split by xth axis of process_mesh. The function analyze the +// input vector, return a key-value map of tensor_split_axis and +// process_mesh_split_axis. +// For example, if dims_mapping is [-1, 1, -1, 0], will return {1: 1, 3: 0}. +std::map GetSplitAxisWithDimsMapping( + const std::vector& dims_mapping); + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/kernels/cpu/split_kernel.cc b/paddle/phi/kernels/cpu/split_kernel.cc index f277e0c39f375371c428ee4b96978c01662c465c..13ac7eed3d5774da3304e8f1c2f16db23845608e 100644 --- a/paddle/phi/kernels/cpu/split_kernel.cc +++ b/paddle/phi/kernels/cpu/split_kernel.cc @@ -29,8 +29,11 @@ PD_REGISTER_KERNEL(split, bool, uint8_t, int8_t, + int16_t, phi::dtype::float16, - phi::dtype::bfloat16) {} + phi::dtype::bfloat16, + phi::complex64, + phi::complex128) {} PD_REGISTER_KERNEL(split_with_num, CPU, diff --git a/paddle/phi/kernels/gpu/split_kernel.cu b/paddle/phi/kernels/gpu/split_kernel.cu index 133734621360dbb8799ede75cbc97388d1425828..ea140b54eb170fb6204b4e88c0fe08a882f1d61d 100644 --- a/paddle/phi/kernels/gpu/split_kernel.cu +++ b/paddle/phi/kernels/gpu/split_kernel.cu @@ -29,8 +29,11 @@ PD_REGISTER_KERNEL(split, bool, uint8_t, int8_t, + int16_t, phi::dtype::float16, - phi::dtype::bfloat16) {} + phi::dtype::bfloat16, + phi::complex64, + phi::complex128) {} PD_REGISTER_KERNEL(split_with_num, GPU, diff --git a/test/cpp/auto_parallel/CMakeLists.txt b/test/cpp/auto_parallel/CMakeLists.txt index c5912a6fa102101d55ea912865a9bf0794ac34ec..6e0ea8db1e0452c60783457c02cb85e52397fdcf 100644 --- a/test/cpp/auto_parallel/CMakeLists.txt +++ b/test/cpp/auto_parallel/CMakeLists.txt @@ -9,6 +9,11 @@ if(WITH_DISTRIBUTE) dist_tensor_test SRCS dist_tensor_test.cc DEPS phi) + + cc_test( + test_reshard_r_to_s + SRCS test_reshard_r_to_s.cc + DEPS phi) endif() cc_test_old(dist_mapper_test SRCS dist_mapper_test.cc DEPS phi) diff --git a/test/cpp/auto_parallel/test_reshard_r_to_s.cc b/test/cpp/auto_parallel/test_reshard_r_to_s.cc new file mode 100644 index 0000000000000000000000000000000000000000..03bd8d247781a5fa4ca4d0e789055f4fd9603150 --- /dev/null +++ b/test/cpp/auto_parallel/test_reshard_r_to_s.cc @@ -0,0 +1,208 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "glog/logging.h" +#include "gtest/gtest.h" +#include "paddle/phi/backends/all_context.h" +#include "paddle/phi/backends/context_pool.h" +#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" +#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h" +#include "paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard_function.h" +#include "paddle/phi/core/tensor_utils.h" + +namespace phi { +namespace distributed { +namespace auto_parallel { +namespace tests { + +std::shared_ptr ConstructReplicatedDistCPU( + phi::CPUContext* dev_ctx, + const std::vector& shape, + const ProcessMesh& mesh) { + phi::CPUPlace cpu_place = dev_ctx->GetPlace(); + const DDim dims(shape.data(), shape.size()); + + int64_t num_of_elems = 1; + for (const auto& value : shape) { + num_of_elems *= value; + } + + phi::DenseTensor input_dense; + float* input_dense_ptr = input_dense.mutable_data(dims, cpu_place); + + std::vector vec(num_of_elems); + memcpy(input_dense_ptr, vec.data(), num_of_elems * sizeof(float)); + + std::shared_ptr dist_attr = + std::make_shared(shape); + + std::vector dims_mapping(shape.size(), -1); + dist_attr->set_dims_mapping(dims_mapping); + dist_attr->set_process_mesh(mesh); + + return std::make_shared( + std::make_shared(input_dense), dist_attr); +} + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +std::shared_ptr ConstructReplicatedDistGPU( + phi::GPUContext* dev_ctx, + const std::vector& shape, + const ProcessMesh& mesh) { + phi::GPUPlace gpu_place = dev_ctx->GetPlace(); + phi::CPUPlace cpu_place; + const DDim dims(shape.data(), shape.size()); + + int64_t num_of_elems = 1; + for (const auto& value : shape) { + num_of_elems *= value; + } + + phi::DenseTensor input_dense; + phi::DenseTensor input_dense_gpu; + float* input_dense_ptr = input_dense.mutable_data(dims, cpu_place); + + std::vector vec(num_of_elems); + memcpy(input_dense_ptr, vec.data(), num_of_elems * sizeof(float)); + phi::Copy(*dev_ctx, input_dense, gpu_place, true, &input_dense_gpu); + + std::shared_ptr dist_attr = + std::make_shared(shape); + + std::vector dims_mapping(shape.size(), -1); + dist_attr->set_dims_mapping(dims_mapping); + dist_attr->set_process_mesh(mesh); + + return std::make_shared( + std::make_shared(input_dense_gpu), dist_attr); +} +#endif + +TEST(reshard_r_to_s, r_to_s_same_placement_cpu_1d_mesh) { + setenv("PADDLE_TRAINER_ID", "1", 1); + + std::vector tensor_shape = {6, 8}; + phi::DeviceContextPool& pool = phi::DeviceContextPool::Instance(); + auto* context = reinterpret_cast(pool.Get(phi::CPUPlace())); + + std::vector mesh_shape = {4}; + std::vector process_ids = {0, 1, 2, 3}; + std::vector dim_names = {"x"}; + ProcessMesh mesh(mesh_shape, process_ids, dim_names); + + std::shared_ptr input = + ConstructReplicatedDistCPU(context, tensor_shape, mesh); + + std::shared_ptr out_dist_attr = + std::make_shared(tensor_shape); + std::vector out_dims_mapping = {-1, 0}; + out_dist_attr->set_dims_mapping(out_dims_mapping); + out_dist_attr->set_process_mesh(mesh); + + RToSReshardFunction r_to_s_func; + std::shared_ptr output = + r_to_s_func.Eval(*context, *input, out_dist_attr); + + CHECK_EQ(r_to_s_func.IsSuitable(*input, out_dist_attr), true); + CHECK_EQ(output->numel(), 12); + CHECK_EQ(output->dims(), DDim({6, 2})); +} + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +TEST(reshard_r_to_s, r_to_s_same_placement_gpu_1d_mesh) { + setenv("PADDLE_TRAINER_ID", "0", 0); + + std::vector tensor_shape = {6, 8, 4}; + phi::DeviceContextPool& pool = phi::DeviceContextPool::Instance(); + auto* context = reinterpret_cast(pool.Get(phi::GPUPlace())); + + std::vector mesh_shape = {6}; + std::vector process_ids = {0, 1, 2, 3, 4, 5}; + std::vector dim_names = {"x"}; + ProcessMesh mesh(mesh_shape, process_ids, dim_names); + + std::shared_ptr out_dist_attr = + std::make_shared(tensor_shape); + std::vector out_dims_mapping = {0, -1}; + out_dist_attr->set_dims_mapping(out_dims_mapping); + out_dist_attr->set_process_mesh(mesh); + + std::shared_ptr input = + ConstructReplicatedDistGPU(context, tensor_shape, mesh); + + RToSReshardFunction r_to_s_func; + std::shared_ptr output = + r_to_s_func.Eval(*context, *input, out_dist_attr); + + CHECK_EQ(r_to_s_func.IsSuitable(*input, out_dist_attr), true); + CHECK_EQ(output->numel(), 32); + CHECK_EQ(output->dims(), DDim({1, 8, 4})); +} +#endif + +TEST(reshard_r_to_s, r_to_s_diff_placement) { + std::vector tensor_shape = {6, 8}; + phi::DeviceContextPool& pool = phi::DeviceContextPool::Instance(); + auto* context = reinterpret_cast(pool.Get(phi::CPUPlace())); + + std::vector mesh_shape = {4}; + std::vector process_ids = {0, 1, 2, 3}; + std::vector dim_names = {"x"}; + ProcessMesh mesh(mesh_shape, process_ids, dim_names); + + std::shared_ptr input = + ConstructReplicatedDistCPU(context, tensor_shape, mesh); + + std::vector out_process_ids = {2, 3, 4, 5}; + ProcessMesh out_mesh(mesh_shape, out_process_ids, dim_names); + std::shared_ptr out_dist_attr = + std::make_shared(tensor_shape); + std::vector out_dims_mapping = {-1, 0}; + out_dist_attr->set_dims_mapping(out_dims_mapping); + out_dist_attr->set_process_mesh(out_mesh); + + RToSReshardFunction r_to_s_func; + CHECK_EQ(r_to_s_func.IsSuitable(*input, out_dist_attr), false); +} + +TEST(reshard_r_to_s, r_to_s_same_placement_nd_mesh) { + std::vector tensor_shape = {6, 12}; + phi::DeviceContextPool& pool = phi::DeviceContextPool::Instance(); + auto* context = reinterpret_cast(pool.Get(phi::CPUPlace())); + + std::vector mesh_shape = {4, 2}; + std::vector process_ids = {0, 1, 2, 3, 4, 5, 6, 7}; + std::vector dim_names = {"x", "y"}; + ProcessMesh mesh(mesh_shape, process_ids, dim_names); + + std::shared_ptr input = + ConstructReplicatedDistCPU(context, tensor_shape, mesh); + + std::shared_ptr out_dist_attr = + std::make_shared(tensor_shape); + std::vector out_dims_mapping = {1, 0}; + out_dist_attr->set_dims_mapping(out_dims_mapping); + out_dist_attr->set_process_mesh(mesh); + + RToSReshardFunction r_to_s_func; + + CHECK_EQ(r_to_s_func.IsSuitable(*input, out_dist_attr), false); +} + +} // namespace tests +} // namespace auto_parallel +} // namespace distributed +} // namespace phi