From 99795a138cb1b35dca1d48becfc37aeccf2d731d Mon Sep 17 00:00:00 2001
From: LiYuRio <63526175+LiYuRio@users.noreply.github.com>
Date: Fri, 25 Aug 2023 15:55:36 +0800
Subject: [PATCH] [Reshard] Support create shard tensor and non-zero dim
 reshard (#56553)

* support create shard dist tensor

* support non-zero shard to replicated

* change reshard signature

---
 paddle/fluid/pybind/eager_properties.cc       | 25 ++++++++
 .../distributed/auto_parallel/CMakeLists.txt  |  1 +
 .../distributed/auto_parallel/dist_tensor.cc  | 24 ++++---
 .../distributed/auto_parallel/dist_tensor.h   | 17 +++--
 .../auto_parallel/r_to_s_reshard_function.cc  | 22 +++----
 .../auto_parallel/r_to_s_reshard_function.h   |  8 +--
 .../auto_parallel/reshard_concat_functor.cc   | 55 ++++++++++++++++
 .../auto_parallel/reshard_concat_functor.h    | 30 +++++++++
 .../auto_parallel/reshard_function.cc         | 47 +++++++++++++-
 .../auto_parallel/reshard_function.h          | 35 +++++++++--
 .../auto_parallel/s_to_r_reshard_function.cc  | 62 +++++++++++++++----
 .../auto_parallel/s_to_r_reshard_function.h   |  8 +--
 paddle/phi/kernels/concat_kernel.h            | 20 ++++--
 paddle/phi/kernels/cpu/concat_grad_kernel.cc  |  2 +
 paddle/phi/kernels/cpu/concat_kernel.cc       |  1 +
 paddle/phi/kernels/gpu/concat_grad_kernel.cu  |  2 +
 paddle/phi/kernels/gpu/concat_kernel.cu       |  1 +
 paddle/phi/kernels/split_kernel.h             |  2 +-
 test/auto_parallel/reshard_r_to_s.py          |  3 +-
 test/auto_parallel/reshard_s_to_r.py          |  5 +-
 test/auto_parallel/test_dist_tensor.py        |  2 +-
 test/auto_parallel/test_reshard_s_to_r.py     |  2 +-
 test/auto_parallel/test_shard_tensor_api.py   | 12 ++--
 test/cpp/auto_parallel/dist_tensor_test.cc    | 14 ++---
 24 files changed, 316 insertions(+), 84 deletions(-)
 create mode 100644 paddle/phi/core/distributed/auto_parallel/reshard_concat_functor.cc
 create mode 100644 paddle/phi/core/distributed/auto_parallel/reshard_concat_functor.h

diff --git a/paddle/fluid/pybind/eager_properties.cc b/paddle/fluid/pybind/eager_properties.cc
index 1da70fda183..5ada60c81da 100644
--- a/paddle/fluid/pybind/eager_properties.cc
+++ b/paddle/fluid/pybind/eager_properties.cc
@@ -388,6 +388,26 @@ PyObject* tensor_properties_get_dist_attr(TensorObject* self, void* closure) {
   EAGER_CATCH_AND_THROW_RETURN_NULL
 }
 
+PyObject* tensor_properties_get_local_shape(TensorObject* self, void* closure) {
+  EAGER_TRY
+  if (self->tensor.is_dist_tensor()) {
+#ifdef PADDLE_WITH_DISTRIBUTE
+    phi::distributed::DistTensor* dist_tensor =
+        static_cast<phi::distributed::DistTensor*>(self->tensor.impl().get());
+    return ToPyObject(phi::vectorize(dist_tensor->local_dims()));
+#else
+    PADDLE_THROW(platform::errors::Unavailable(
+        "The `_local_shape` property of (Dist)Tensor is not supported "
+        "in the current PaddlePaddle, please recompile and install "
+        "PaddlePaddle "
+        "with the option of `WITH_DISTRIBUTE=ON`."));
+#endif
+  } else {
+    RETURN_PY_NONE
+  }
+  EAGER_CATCH_AND_THROW_RETURN_NULL
+}
+
 PyDoc_STRVAR(tensor_shape__doc__,
              R"DOC(shape
 
@@ -716,6 +736,11 @@ struct PyGetSetDef variable_properties[] = {  // NOLINT
      (setter)tensor_properties_set_persistable,
      tensor_persistable__doc__,
      nullptr},
+    {"_local_shape",
+     (getter)tensor_properties_get_local_shape,
+     nullptr,
+     nullptr,
+     nullptr},
     {"shape",
      (getter)tensor_properties_get_shape,
      nullptr,
diff --git a/paddle/phi/core/distributed/auto_parallel/CMakeLists.txt b/paddle/phi/core/distributed/auto_parallel/CMakeLists.txt
index 031b3033458..91cbe4a3ff4 100644
--- a/paddle/phi/core/distributed/auto_parallel/CMakeLists.txt
+++ b/paddle/phi/core/distributed/auto_parallel/CMakeLists.txt
@@ -9,6 +9,7 @@ if(WITH_DISTRIBUTE)
       dist_tensor.cc
       reshard_function.cc
       reshard_split_functor.cc
+      reshard_concat_functor.cc
       reshard_all_gather_functor.cc
       r_to_s_reshard_function.cc
       s_to_r_reshard_function.cc)
diff --git a/paddle/phi/core/distributed/auto_parallel/dist_tensor.cc b/paddle/phi/core/distributed/auto_parallel/dist_tensor.cc
index a4aaa6d3027..c762eeb1f9f 100644
--- a/paddle/phi/core/distributed/auto_parallel/dist_tensor.cc
+++ b/paddle/phi/core/distributed/auto_parallel/dist_tensor.cc
@@ -14,6 +14,10 @@
 
 #include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
 
+#include "paddle/phi/backends/context_pool.h"
+#include "paddle/phi/core/distributed/auto_parallel/reshard_function.h"
+#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h"
+
 namespace phi {
 namespace distributed {
 
@@ -27,15 +31,21 @@ inline void check_defined(const DistTensor& dist_tensor,
                         method_hint));
 }
 
-// TODO(chenweihang): Reshard the input global value into local value
 DistTensor::DistTensor(const phi::DenseTensor& global_value,
                        const TensorDistAttr& dist_attr)
-    : dims_(global_value.dims()), dist_attr_(dist_attr), value_(global_value) {}
-
-DistTensor::DistTensor(const phi::DenseTensor& value,
-                       const DDim& dims,
-                       const TensorDistAttr& dist_attr)
-    : dims_(dims), dist_attr_(dist_attr), value_(value) {}
+    : dims_(global_value.dims()), dist_attr_(dist_attr), value_(global_value) {
+  if (!IsDimsMappingReplicated(dist_attr_.dims_mapping())) {
+    // 1. create replicated global tensor
+    int64_t dims_size = global_value.dims().size();
+    std::vector<int64_t> dims_mapping(dims_size, -1);
+    dist_attr_.set_dims_mapping(dims_mapping);
+
+    // 2. reshard from replicated to other state
+    auto* func = ChooseProperReshardFunction(*this, dist_attr);
+    auto* dev_ctx = DeviceContextPool::Instance().Get(global_value.place());
+    func->Eval(dev_ctx, *this, dist_attr, this);
+  }
+}
 
 DistTensor::DistTensor(const DDim& dims, const TensorDistAttr& dist_attr)
     : dims_(dims), dist_attr_(dist_attr) {}
diff --git a/paddle/phi/core/distributed/auto_parallel/dist_tensor.h b/paddle/phi/core/distributed/auto_parallel/dist_tensor.h
index 517c35cd731..e3a738b2ba1 100644
--- a/paddle/phi/core/distributed/auto_parallel/dist_tensor.h
+++ b/paddle/phi/core/distributed/auto_parallel/dist_tensor.h
@@ -21,26 +21,23 @@ namespace phi {
 namespace distributed {
 
+class ReshardFunction;
+
 class DistTensor final
     : public phi::TensorBase,
       public phi::TypeInfoTraits<phi::TensorBase, DistTensor> {
  public:
+  /// \brief Be careful when creating a dist tensor with the default
+  /// constructor. It should only be used by reshard for now, and the dist
+  /// properties will be set by reshard later.
+  DistTensor() = default;
+
   /// \brief Construct a dist tensor based on a dense tensor.
   /// \param global_value The global dense tensor of the current tensor.
   /// \param dist_attr The distributed attributes of the current tensor.
   DistTensor(const phi::DenseTensor& global_value,
              const TensorDistAttr& dist_attr);
 
-  // TODO(chenweihang): Remove this constructor after added reshard impl
-  /// \brief Construct a dist tensor based on a dense tensor.
-  /// \param value The local dense tensor of the current tensor.
-  /// \param dims The global dimension of the current tensor.
-  /// \param dist_attr The distributed attributes of the current tensor.
-  DistTensor(const phi::DenseTensor& value,
-             const DDim& dims,
-             const TensorDistAttr& dist_attr);
-
   /// \brief Construct an empty dist tensor (for infer spmd)
   /// \param dims The global dimension of the current Tensor.
   /// \param dist_attr The distributed attributes of the current tensor.
@@ -109,6 +106,8 @@ class DistTensor final
                     bool fake_alloc = false) override;
 
  private:
+  friend class ReshardFunction;
+
   // The global dimensions(shape)
   DDim dims_;
   // The distributed attributes
diff --git a/paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.cc
index e8ffe9e9318..3a60e226793 100644
--- a/paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.cc
+++ b/paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.cc
@@ -46,10 +46,10 @@ bool RToSReshardFunction::IsSuitable(const DistTensor& in,
   return flag;
 }
 
-std::shared_ptr<DistTensor> RToSReshardFunction::Eval(
-    phi::DeviceContext* dev_ctx,
-    const DistTensor& in,
-    const TensorDistAttr& out_dist_attr) {
+void RToSReshardFunction::Eval(phi::DeviceContext* dev_ctx,
+                               const DistTensor& in,
+                               const TensorDistAttr& out_dist_attr,
+                               DistTensor* out) {
   const auto& out_dims_mapping = out_dist_attr.dims_mapping();
   const auto& out_process_mesh = out_dist_attr.process_mesh();
   const DenseTensor& in_physical_tensor_cur_rank = in.value();
@@ -63,14 +63,6 @@ std::shared_ptr<DistTensor> RToSReshardFunction::Eval(
   int64_t split_axis = split_axis_to_mesh_axis.begin()->first;
   int64_t mesh_axis = split_axis_to_mesh_axis.begin()->second;
 
-  PADDLE_ENFORCE_LT(
-      mesh_axis,
-      out_process_mesh.ndim(),
-      phi::errors::OutOfRange(
-          "The mesh axis %lld exceed the size of process mesh %lld.",
-          mesh_axis,
-          out_process_mesh.ndim()));
-
   int64_t num_of_process = out_process_mesh.shape()[mesh_axis];
   VLOG(3) << "RToSReshard: Tensor will be split on axis " << split_axis
           << ". Split will use axis " << mesh_axis << " of process_mesh."
@@ -86,13 +78,15 @@ std::shared_ptr<DistTensor> RToSReshardFunction::Eval(
   VLOG(3) << "The current process will remain the idx "
           << coord_in_mesh[mesh_axis] << " piece of tensor";
 
+  out_physical_tensor_cur_rank = split_out_vec[coord_in_mesh[mesh_axis]];
   VLOG(3) << "The shape of physical tensor after split is "
           << out_physical_tensor_cur_rank.dims();
 
-  return std::make_shared<DistTensor>(
-      out_physical_tensor_cur_rank, in.dims(), out_dist_attr);
+  set_dist_props(out, out_physical_tensor_cur_rank, in.dims(), out_dist_attr);
 }
 
+REGISTER_RESHARD_FUNC(RToSReshardFunction);
+
 }  // namespace distributed
 }  // namespace phi
diff --git a/paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h b/paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h
index 47bdd8c5d19..3a86ff0cfa0 100644
--- a/paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h
+++ b/paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h
@@ -27,10 +27,10 @@ class RToSReshardFunction final : public ReshardFunction {
   bool IsSuitable(const DistTensor& in,
                   const TensorDistAttr& out_dist_attr) override;
 
-  std::shared_ptr<DistTensor> Eval(
-      DeviceContext* dev_ctx,
-      const DistTensor& in,
-      const TensorDistAttr& out_dist_attr) override;
+  void Eval(DeviceContext* dev_ctx,
+            const DistTensor& in,
+            const TensorDistAttr& out_dist_attr,
+            DistTensor* out) override;
 };
 
 }  // namespace distributed
diff --git a/paddle/phi/core/distributed/auto_parallel/reshard_concat_functor.cc b/paddle/phi/core/distributed/auto_parallel/reshard_concat_functor.cc
new file mode 100644
index 00000000000..49115dbffd0
--- /dev/null
+++ b/paddle/phi/core/distributed/auto_parallel/reshard_concat_functor.cc
@@ -0,0 +1,55 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/core/distributed/auto_parallel/reshard_concat_functor.h"
+
+#include "paddle/phi/backends/all_context.h"
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/device_context.h"
+#include "paddle/phi/core/visit_type.h"
+#include "paddle/phi/kernels/concat_kernel.h"
+
+namespace phi {
+namespace distributed {
+
+DenseTensor ReshardConcatFunctor(const DeviceContext& dev_ctx,
+                                 const std::vector<const DenseTensor*>& input,
+                                 int64_t axis) {
+  DenseTensor result;
+  auto dtype = (*input.begin())->dtype();
+
+  if (phi::CPUContext::classof(&dev_ctx)) {
+    PD_VISIT_ALL_TYPES(
+        dtype, "Concat", ([&] {
+          Concat<data_t>(
+              static_cast<const CPUContext&>(dev_ctx), input, axis, &result);
+        }));
+    return result;
+  }
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+  if (phi::GPUContext::classof(&dev_ctx)) {
+    PD_VISIT_ALL_TYPES(
+        dtype, "Concat", ([&] {
+          Concat<data_t>(
+              static_cast<const GPUContext&>(dev_ctx), input, axis, &result);
+        }));
+    return result;
+  }
+#endif
+  PADDLE_THROW(phi::errors::Unimplemented(
+      "The concat in reshard is only supported on CPU and GPU for now."));
+}
+
+}  // namespace distributed
+}  // namespace phi
diff --git a/paddle/phi/core/distributed/auto_parallel/reshard_concat_functor.h b/paddle/phi/core/distributed/auto_parallel/reshard_concat_functor.h
new file mode 100644
index 00000000000..ce4798458bd
--- /dev/null
+++ b/paddle/phi/core/distributed/auto_parallel/reshard_concat_functor.h
@@ -0,0 +1,30 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <vector>
+
+namespace phi {
+class DeviceContext;
+class DenseTensor;
+namespace distributed {
+
+DenseTensor ReshardConcatFunctor(const DeviceContext& dev_ctx,
+                                 const std::vector<const DenseTensor*>& input,
+                                 int64_t axis);
+
+}  // namespace distributed
+}  // namespace phi
diff --git a/paddle/phi/core/distributed/auto_parallel/reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/reshard_function.cc
index 04bbc4a09fe..637af9641d3 100644
--- a/paddle/phi/core/distributed/auto_parallel/reshard_function.cc
+++ b/paddle/phi/core/distributed/auto_parallel/reshard_function.cc
@@ -13,10 +13,53 @@
 // limitations under the License.
#include "paddle/phi/core/distributed/auto_parallel/reshard_function.h" -#include "paddle/phi/core/device_context.h" + #include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" #include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h" namespace phi { -namespace distributed {} // namespace distributed +namespace distributed { + +std::shared_ptr ReshardFunction::Eval( + DeviceContext* dev_ctx, + const DistTensor& in, + const TensorDistAttr& out_dist_attr) { + std::shared_ptr out = std::make_shared(); + Eval(dev_ctx, in, out_dist_attr, out.get()); + return out; +} + +void ReshardFunction::set_dist_props(DistTensor* tensor, + const DenseTensor& value, + const DDim& dims, + const TensorDistAttr& dist_attr) { + PADDLE_ENFORCE_EQ(dist_attr.verify(vectorize(dims)), + true, + phi::errors::InvalidArgument( + "The input dist_attr and dims are improper.")); + + tensor->value_ = value; + tensor->dims_ = dims; + tensor->dist_attr_ = dist_attr; +} + +ReshardFunction* ChooseProperReshardFunction( + const DistTensor& in, const TensorDistAttr& out_dist_attr) { + for (const auto& func : GetReshardFunctionList()) { + if (func->IsSuitable(in, out_dist_attr)) { + return func.get(); + } + } + PADDLE_THROW(phi::errors::Unimplemented( + "Can not reshard from in_dist_attr=%s to out_dist_attr=%s.", + in.dist_attr().to_string(), + out_dist_attr.to_string())); +} + +std::vector>& GetReshardFunctionList() { + static std::vector> func_list; + return func_list; +} + +} // namespace distributed } // namespace phi diff --git a/paddle/phi/core/distributed/auto_parallel/reshard_function.h b/paddle/phi/core/distributed/auto_parallel/reshard_function.h index d34b7cb8040..305a9af337c 100644 --- a/paddle/phi/core/distributed/auto_parallel/reshard_function.h +++ b/paddle/phi/core/distributed/auto_parallel/reshard_function.h @@ -14,6 +14,9 @@ #pragma once #include +#include + +#include "paddle/phi/core/dense_tensor.h" namespace phi { class DeviceContext; @@ -31,11 +34,35 @@ class ReshardFunction { virtual bool IsSuitable(const DistTensor& in, const TensorDistAttr& out_dist_attr) = 0; - virtual std::shared_ptr Eval( - DeviceContext* dev_ctx, - const DistTensor& in, - const TensorDistAttr& out_dist_attr) = 0; + std::shared_ptr Eval(DeviceContext* dev_ctx, + const DistTensor& in, + const TensorDistAttr& out_dist_attr); + + virtual void Eval(DeviceContext* dev_ctx, + const DistTensor& in, + const TensorDistAttr& out_dist_attr, + DistTensor* out) = 0; + + protected: + void set_dist_props(DistTensor* tensor, + const DenseTensor& value, + const DDim& dims, + const TensorDistAttr& dist_attr); }; +std::vector>& GetReshardFunctionList(); + +#define REGISTER_RESHARD_FUNC(func_type) \ + class __RegisterReshard_##func_type { \ + public: \ + __RegisterReshard_##func_type() { \ + GetReshardFunctionList().emplace_back(std::make_unique()); \ + } \ + }; \ + static __RegisterReshard_##func_type local_reshard_func_##func_type + +ReshardFunction* ChooseProperReshardFunction( + const DistTensor& in, const TensorDistAttr& out_dist_attr); + } // namespace distributed } // namespace phi diff --git a/paddle/phi/core/distributed/auto_parallel/s_to_r_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/s_to_r_reshard_function.cc index b99b4bbbfc3..e10587237ba 100644 --- a/paddle/phi/core/distributed/auto_parallel/s_to_r_reshard_function.cc +++ b/paddle/phi/core/distributed/auto_parallel/s_to_r_reshard_function.cc @@ -15,12 +15,13 @@ #include "paddle/phi/core/distributed/auto_parallel/s_to_r_reshard_function.h" #include 
"glog/logging.h" +#include "paddle/phi/common/int_array.h" #include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" #include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h" #include "paddle/phi/core/distributed/auto_parallel/reshard_all_gather_functor.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard_concat_functor.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard_split_functor.h" #include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h" -#include "paddle/phi/core/distributed/comm_context_manager.h" -#include "paddle/phi/core/distributed/store/tcp_store.h" namespace phi { namespace distributed { @@ -43,18 +44,25 @@ bool SToRReshardFunction::IsSuitable(const DistTensor& in, flag &= (out_process_mesh.ndim() == 1); flag &= (in_process_mesh == out_process_mesh); + // Ensure the tensor is balanced split, or we need send/recv rather than + // all_gather + std::map split_axis_to_mesh_axis = + GetSplitAxisWithDimsMapping(in_dims_mapping); + int64_t split_axis = split_axis_to_mesh_axis.begin()->first; + int64_t num_of_process = in_process_mesh.size(); + flag &= + (in.local_dims()[split_axis] * num_of_process == in.dims()[split_axis]); + return flag; } -std::shared_ptr SToRReshardFunction::Eval( - DeviceContext* dev_ctx, - const DistTensor& in, - const TensorDistAttr& out_dist_attr) { - // TODO(liyurui): Only support transfer shard(0) to replicate for now. - // Concat is needed when transfer shard(x) to replicate, will be supported - // later. +void SToRReshardFunction::Eval(DeviceContext* dev_ctx, + const DistTensor& in, + const TensorDistAttr& out_dist_attr, + DistTensor* out) { const DenseTensor& in_physical_tensor_cur_rank = in.value(); const auto& in_dist_attr = in.dist_attr(); + const auto& in_dims_mapping = in_dist_attr.dims_mapping(); const auto& in_process_mesh = in_dist_attr.process_mesh(); const auto& in_process_ids = in_process_mesh.process_ids(); @@ -64,9 +72,41 @@ std::shared_ptr SToRReshardFunction::Eval( DenseTensor out_all_gather = ReshardAllGatherFunctor( dev_ctx, in_physical_tensor_cur_rank, in_process_ids); - return std::make_shared( - out_all_gather, out_all_gather.dims(), out_dist_attr); + std::map split_axis_to_mesh_axis = + GetSplitAxisWithDimsMapping(in_dims_mapping); + int64_t split_axis = split_axis_to_mesh_axis.begin()->first; + + if (split_axis == 0) { + // If the input dist tensor is shard(0), the subsequent split + // and concat is unnecessary. + set_dist_props(out, out_all_gather, out_all_gather.dims(), out_dist_attr); + } else { + // Since the result of all_gather always concat the tensor on axis 0, + // first we need to split the result on axis 0, + // then we need to concat the split result on input split axis. + int64_t default_split_axis = 0; + int64_t num_of_process = in_process_ids.size(); + + IntArray sections(std::vector( + num_of_process, + in_physical_tensor_cur_rank.dims()[default_split_axis])); + std::vector split_out_vec = ReshardSplitFunctor( + *dev_ctx, out_all_gather, sections, default_split_axis); + + // Concat the result after split on correct axis. 
+    std::vector<const DenseTensor*> concat_input_vec;
+    for (const auto& tensor : split_out_vec) {
+      concat_input_vec.emplace_back(&tensor);
+    }
+    DenseTensor concat_out_tensor =
+        ReshardConcatFunctor(*dev_ctx, concat_input_vec, split_axis);
+
+    set_dist_props(
+        out, concat_out_tensor, concat_out_tensor.dims(), out_dist_attr);
+  }
 }
 
+REGISTER_RESHARD_FUNC(SToRReshardFunction);
+
 }  // namespace distributed
 }  // namespace phi
diff --git a/paddle/phi/core/distributed/auto_parallel/s_to_r_reshard_function.h b/paddle/phi/core/distributed/auto_parallel/s_to_r_reshard_function.h
index aa97f167f49..869b4ed9178 100644
--- a/paddle/phi/core/distributed/auto_parallel/s_to_r_reshard_function.h
+++ b/paddle/phi/core/distributed/auto_parallel/s_to_r_reshard_function.h
@@ -26,10 +26,10 @@ class SToRReshardFunction final : public ReshardFunction {
   bool IsSuitable(const DistTensor& in,
                   const TensorDistAttr& out_dist_attr) override;
 
-  std::shared_ptr<DistTensor> Eval(
-      DeviceContext* dev_ctx,
-      const DistTensor& in,
-      const TensorDistAttr& out_dist_attr) override;
+  void Eval(DeviceContext* dev_ctx,
+            const DistTensor& in,
+            const TensorDistAttr& out_dist_attr,
+            DistTensor* out) override;
 };
 
 }  // namespace distributed
diff --git a/paddle/phi/kernels/concat_kernel.h b/paddle/phi/kernels/concat_kernel.h
index f5ac2d3cbb7..d3b99449a06 100644
--- a/paddle/phi/kernels/concat_kernel.h
+++ b/paddle/phi/kernels/concat_kernel.h
@@ -27,9 +27,10 @@ void ConcatKernel(const Context& dev_ctx,
                   DenseTensor* out);
 
 template <typename T, typename Context>
-DenseTensor Concat(const Context& dev_ctx,
-                   const std::vector<const DenseTensor*>& x,
-                   const Scalar& axis) {
+void Concat(const Context& dev_ctx,
+            const std::vector<const DenseTensor*>& x,
+            const Scalar& axis,
+            DenseTensor* dense_out) {
   std::vector<MetaTensor> meta_x;
   meta_x.reserve(x.size());
   std::vector<const MetaTensor*> meta_x_ptr;
@@ -38,10 +39,17 @@ DenseTensor Concat(const Context& dev_ctx,
     meta_x_ptr.push_back(&meta_x.back());
   }
 
-  DenseTensor dense_out;
-  MetaTensor meta_out(&dense_out);
+  MetaTensor meta_out(dense_out);
   ConcatInferMeta(meta_x_ptr, axis.to<int>(), &meta_out);
-  ConcatKernel<T, Context>(dev_ctx, x, axis, &dense_out);
+  ConcatKernel<T, Context>(dev_ctx, x, axis, dense_out);
+}
+
+template <typename T, typename Context>
+DenseTensor Concat(const Context& dev_ctx,
+                   const std::vector<const DenseTensor*>& x,
+                   const Scalar& axis) {
+  DenseTensor dense_out;
+  Concat<T, Context>(dev_ctx, x, axis, &dense_out);
   return dense_out;
 }
 
diff --git a/paddle/phi/kernels/cpu/concat_grad_kernel.cc b/paddle/phi/kernels/cpu/concat_grad_kernel.cc
index 56ed95769fe..b7e33ac21f9 100644
--- a/paddle/phi/kernels/cpu/concat_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/concat_grad_kernel.cc
@@ -29,6 +29,8 @@ PD_REGISTER_KERNEL(concat_grad,
                    bool,
                    int64_t,
                    int,
+                   int8_t,
+                   int16_t,
                    uint8_t,
                    phi::dtype::float16,
                    phi::dtype::complex<float>,
diff --git a/paddle/phi/kernels/cpu/concat_kernel.cc b/paddle/phi/kernels/cpu/concat_kernel.cc
index 770a6fbd2df..8ff3a4bff35 100644
--- a/paddle/phi/kernels/cpu/concat_kernel.cc
+++ b/paddle/phi/kernels/cpu/concat_kernel.cc
@@ -125,6 +125,7 @@ PD_REGISTER_KERNEL(concat,
                    int,
                    uint8_t,
                    int8_t,
+                   int16_t,
                    phi::dtype::float16,
                    phi::dtype::bfloat16,
                    phi::dtype::complex<float>,
diff --git a/paddle/phi/kernels/gpu/concat_grad_kernel.cu b/paddle/phi/kernels/gpu/concat_grad_kernel.cu
index 2445978daca..177c65bc9b6 100644
--- a/paddle/phi/kernels/gpu/concat_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/concat_grad_kernel.cu
@@ -31,6 +31,8 @@ PD_REGISTER_KERNEL(concat_grad,
                    int64_t,
                    int,
                    uint8_t,
+                   int8_t,
+                   int16_t,
                    phi::dtype::float16,
                    phi::dtype::bfloat16,
                    phi::dtype::complex<float>,
diff --git a/paddle/phi/kernels/gpu/concat_kernel.cu b/paddle/phi/kernels/gpu/concat_kernel.cu
index 47e5a220e66..f0dc0c91534 100644
--- a/paddle/phi/kernels/gpu/concat_kernel.cu
+++ b/paddle/phi/kernels/gpu/concat_kernel.cu
@@ -122,6 +122,7 @@ PD_REGISTER_KERNEL(concat,
                    int,
                    uint8_t,
                    int8_t,
+                   int16_t,
                    phi::dtype::float16,
                    phi::dtype::bfloat16,
                    phi::dtype::complex<float>,
diff --git a/paddle/phi/kernels/split_kernel.h b/paddle/phi/kernels/split_kernel.h
index 7a6b7173961..2869bf3206f 100644
--- a/paddle/phi/kernels/split_kernel.h
+++ b/paddle/phi/kernels/split_kernel.h
@@ -85,7 +85,7 @@ std::vector<DenseTensor> Split(const Context& dev_ctx,
   size_t out_number = sections.GetData().size();
   std::vector<DenseTensor> result(out_number);
 
-  Split(dev_ctx, x, sections, axis, &result);
+  Split<T, Context>(dev_ctx, x, sections, axis, &result);
 
   return result;
 }
diff --git a/test/auto_parallel/reshard_r_to_s.py b/test/auto_parallel/reshard_r_to_s.py
index e52ea4af39c..814b0ef0dd7 100644
--- a/test/auto_parallel/reshard_r_to_s.py
+++ b/test/auto_parallel/reshard_r_to_s.py
@@ -68,7 +68,8 @@ class TestReshardRToS:
             else out_shape[self._shard] // 2 + 1
         )
 
-        assert np.equal(out.numpy().shape, out_shape).all()
+        assert np.equal(out.shape, input_tensor.shape).all()
+        assert np.equal(out._local_shape, out_shape).all()
 
 
 if __name__ == '__main__':
diff --git a/test/auto_parallel/reshard_s_to_r.py b/test/auto_parallel/reshard_s_to_r.py
index a6aa1be5b90..90ba0cc655f 100644
--- a/test/auto_parallel/reshard_s_to_r.py
+++ b/test/auto_parallel/reshard_s_to_r.py
@@ -57,9 +57,8 @@ class TestReshardSToR:
         assert reshard_func.is_suitable(input_tensor, out_dist_attr)
 
         out = reshard_func.eval(dev_ctx, input_tensor, out_dist_attr)
-        out_shape = list(self._shape)
-        out_shape[self._shard] = out_shape[self._shard] * 2
-        assert np.equal(out.shape, out_shape).all()
+        assert np.equal(out.shape, out._local_shape).all()
+        assert np.equal(out.shape, input_tensor.shape).all()
 
 
 if __name__ == '__main__':
diff --git a/test/auto_parallel/test_dist_tensor.py b/test/auto_parallel/test_dist_tensor.py
index 55db063311c..45aa8c9fbca 100644
--- a/test/auto_parallel/test_dist_tensor.py
+++ b/test/auto_parallel/test_dist_tensor.py
@@ -25,7 +25,7 @@ class TestDistTensor(unittest.TestCase):
     def test_dist_tensor_creation(self):
         shape = [10, 5]
         mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"])
-        dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=['x', 'y'])
+        dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=[None, None])
 
         # create dist tensor using numpy
         dist_tensor_with_numpy = dist.shard_tensor(
diff --git a/test/auto_parallel/test_reshard_s_to_r.py b/test/auto_parallel/test_reshard_s_to_r.py
index ca0c3edf1f1..fd67df648a9 100644
--- a/test/auto_parallel/test_reshard_s_to_r.py
+++ b/test/auto_parallel/test_reshard_s_to_r.py
@@ -24,9 +24,9 @@ class TestReshardSToR(test_base.CommunicationTestDistBase):
             "shape": "(10, 20)",
             "dtype": "float32",
             "seeds": str(self._seeds),
-            "shard": "0",
         }
         self._changeable_envs = {
+            "shard": ["0", "1"],
             "backend": ["cpu", "gpu"],
         }
 
diff --git a/test/auto_parallel/test_shard_tensor_api.py b/test/auto_parallel/test_shard_tensor_api.py
index 764cbdc36e2..af96c715131 100644
--- a/test/auto_parallel/test_shard_tensor_api.py
+++ b/test/auto_parallel/test_shard_tensor_api.py
@@ -27,7 +27,7 @@ class TestDistAttrBasic(unittest.TestCase):
         exception = None
         try:
             mesh = [[0, 1], [2, 3]]
-            dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=['x', 'y'])
+            dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=[None, None])
         except ValueError as ex:
             self.assertIn(
                 "The mesh must be an instance of paddle.distributed.ProcessMesh",
@@ -44,7 +44,7 @@ class
TestDistAttrBasic(unittest.TestCase): [[2, 4, 5], [0, 1, 3]], dim_names=["x", "y"] ) dist_attr = dist.DistAttr( - mesh=mesh, sharding_specs={"x": 0, "y": 1} + mesh=mesh, sharding_specs={"x": None, "y": None} ) except ValueError as ex: self.assertIn( @@ -63,7 +63,7 @@ class TestShardTensorDynamic(unittest.TestCase): def test_dynamic(self): dist_attr = dist.DistAttr( - mesh=self.mesh, sharding_specs=['x', None, None] + mesh=self.mesh, sharding_specs=[None, None, None] ) input = paddle.rand([4, 1024, 512]) @@ -71,7 +71,7 @@ class TestShardTensorDynamic(unittest.TestCase): print(dist_attr.dims_mapping) self.assertEqual(d_tensor.dist_attr.process_mesh, self.mesh) - self.assertEqual(d_tensor.dist_attr.dims_mapping, [0, -1, -1]) + self.assertEqual(d_tensor.dist_attr.dims_mapping, [-1, -1, -1]) self.assertTrue(d_tensor.dist_attr.is_annotated("process_mesh")) self.assertTrue(d_tensor.dist_attr.is_annotated("dims_mapping")) @@ -111,7 +111,7 @@ class TestShardTensorStaticDy2Static(unittest.TestCase): [[0, 1, 2, 3], [4, 5, 6, 7]], dim_names=["x", "y"] ) dist_attr = dist.DistAttr( - mesh=mesh, sharding_specs=['x', None, None] + mesh=mesh, sharding_specs=[None, None, None] ) input = paddle.rand([4, 1024, 512]) @@ -126,7 +126,7 @@ class TestShardTensorStaticDy2Static(unittest.TestCase): static_tensor ) self.assertEqual(dist_input.dist_attr.process_mesh, mesh) - self.assertEqual(dist_input.dist_attr.dims_mapping, [0, -1, -1]) + self.assertEqual(dist_input.dist_attr.dims_mapping, [-1, -1, -1]) self.assertTrue(dist_input.dist_attr.is_annotated("process_mesh")) self.assertTrue(dist_input.dist_attr.is_annotated("dims_mapping")) diff --git a/test/cpp/auto_parallel/dist_tensor_test.cc b/test/cpp/auto_parallel/dist_tensor_test.cc index ab9bdc51477..c190c0e7b17 100644 --- a/test/cpp/auto_parallel/dist_tensor_test.cc +++ b/test/cpp/auto_parallel/dist_tensor_test.cc @@ -46,20 +46,14 @@ TEST(dist_tensor, constructor) { EXPECT_EQ(dist_x1.local_dims()[0], 3L); EXPECT_EQ(dist_x1.local_dims()[1], 4L); - DenseTensor x2(alloc, meta); - DistTensor dist_x2(x2, dims, dist_attr); - EXPECT_TRUE(dist_x2.defined()); - EXPECT_TRUE(dist_x2.initialized()); - EXPECT_TRUE(dist_x1.valid()); - // empty construct - DistTensor dist_x3(dims, dist_attr); - EXPECT_TRUE(!dist_x3.defined()); - EXPECT_TRUE(!dist_x3.initialized()); + DistTensor dist_x2(dims, dist_attr); + EXPECT_TRUE(!dist_x2.defined()); + EXPECT_TRUE(!dist_x2.initialized()); // allocate error test bool caught_exception = false; try { - dist_x3.AllocateFrom(alloc, phi::DataType::FLOAT32, 12L, false); + dist_x2.AllocateFrom(alloc, phi::DataType::FLOAT32, 12L, false); } catch (phi::EnforceNotMet& error) { caught_exception = true; EXPECT_NE(std::string(error.what()).find("Unavailable"), 0UL); -- GitLab
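
For reference, the behavior this patch introduces can be exercised end to end with the short Python sketch below. It is a minimal illustration rather than part of the patch: it assumes a two-rank launch (e.g. via python -m paddle.distributed.launch) and uses only APIs that appear in the tests above (ProcessMesh, DistAttr, shard_tensor, and the new `_local_shape` property).

    # Minimal usage sketch, assuming two ranks; mirrors the patterns in
    # test/auto_parallel/reshard_r_to_s.py above.
    import paddle
    import paddle.distributed as dist

    # One-dimensional process mesh over ranks 0 and 1.
    mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
    # Shard axis 0 of the tensor across mesh dim "x".
    dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=["x", None])

    a = paddle.ones([10, 20])
    # With this patch, constructing a sharded DistTensor from a global
    # value reshards the replicated value into the requested state.
    d_tensor = dist.shard_tensor(a, dist_attr=dist_attr)

    print(d_tensor.shape)         # global shape: [10, 20]
    print(d_tensor._local_shape)  # per-rank local shape: [5, 20]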