Unverified commit 99795a13, authored by LiYuRio, committed by GitHub

[Reshard] Support create shard tensor and non-zero dim reshard (#56553)

* support create shard dist tensor

* support non-zero shard to replicated

* change reshard signature
Parent d3f4596a
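A rough illustration of what this change enables from the Python side (not part of the diff): after this commit, `dist.shard_tensor` with a sharded `sharding_specs` goes through the new `DistTensor` constructor, which reshards the replicated global value onto the mesh, and the new `_local_shape` property (added in the bindings below) exposes the per-rank shape. This is a minimal sketch assuming a 2-process launch (e.g. via `paddle.distributed.launch`); the exact `shard_tensor` keyword arguments are assumed here and may differ in other Paddle versions.

import paddle
import paddle.distributed as dist

# 1-D mesh with two processes; mesh axis "x" carries the shard.
mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
# Shard tensor axis 0 over mesh axis "x"; axis 1 stays replicated.
dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=["x", None])

dense = paddle.rand([10, 20])
d_tensor = dist.shard_tensor(dense, dist_attr=dist_attr)

print(d_tensor.shape)         # global shape: [10, 20]
print(d_tensor._local_shape)  # per-rank shape on a 2-card mesh: [5, 20]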
@@ -388,6 +388,26 @@ PyObject* tensor_properties_get_dist_attr(TensorObject* self, void* closure) {
  EAGER_CATCH_AND_THROW_RETURN_NULL
}
PyObject* tensor_properties_get_local_shape(TensorObject* self, void* closure) {
EAGER_TRY
if (self->tensor.is_dist_tensor()) {
#ifdef PADDLE_WITH_DISTRIBUTE
phi::distributed::DistTensor* dist_tensor =
static_cast<phi::distributed::DistTensor*>(self->tensor.impl().get());
return ToPyObject(phi::vectorize<int64_t>(dist_tensor->local_dims()));
#else
PADDLE_THROW(platform::errors::Unavailable(
"The `_local_shape` property of (Dist)Tensor is not supported "
"in the current PaddlePaddle, please recompile and install "
"PaddlePaddle "
"with the option of `WITH_DISTRIBUTE=ON`."));
#endif
} else {
RETURN_PY_NONE
}
EAGER_CATCH_AND_THROW_RETURN_NULL
}
PyDoc_STRVAR(tensor_shape__doc__,
             R"DOC(shape
@@ -716,6 +736,11 @@ struct PyGetSetDef variable_properties[] = {  // NOLINT
     (setter)tensor_properties_set_persistable,
     tensor_persistable__doc__,
     nullptr},
{"_local_shape",
(getter)tensor_properties_get_local_shape,
nullptr,
nullptr,
nullptr},
{"shape", {"shape",
(getter)tensor_properties_get_shape, (getter)tensor_properties_get_shape,
nullptr, nullptr,
......
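For reference (not part of the diff): as bound above, `_local_shape` is only meaningful for dist tensors; on a plain dense tensor the getter falls through to `RETURN_PY_NONE`. A one-line sketch of that behaviour, assuming eager mode:

import paddle

x = paddle.ones([4, 8])   # ordinary (non-distributed) tensor
print(x._local_shape)     # None: only a DistTensor reports a local shape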
@@ -9,6 +9,7 @@ if(WITH_DISTRIBUTE)
      dist_tensor.cc
      reshard_function.cc
      reshard_split_functor.cc
reshard_concat_functor.cc
      reshard_all_gather_functor.cc
      r_to_s_reshard_function.cc
      s_to_r_reshard_function.cc)
......
@@ -14,6 +14,10 @@
 #include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"

+#include "paddle/phi/backends/context_pool.h"
+#include "paddle/phi/core/distributed/auto_parallel/reshard_function.h"
+#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h"
+
 namespace phi {
 namespace distributed {
@@ -27,15 +31,21 @@ inline void check_defined(const DistTensor& dist_tensor,
       method_hint));
 }

-// TODO(chenweihang): Reshard the input global value into local value
 DistTensor::DistTensor(const phi::DenseTensor& global_value,
                        const TensorDistAttr& dist_attr)
-    : dims_(global_value.dims()), dist_attr_(dist_attr), value_(global_value) {}
-
-DistTensor::DistTensor(const phi::DenseTensor& value,
-                       const DDim& dims,
-                       const TensorDistAttr& dist_attr)
-    : dims_(dims), dist_attr_(dist_attr), value_(value) {}
+    : dims_(global_value.dims()), dist_attr_(dist_attr), value_(global_value) {
+  if (!IsDimsMappingReplicated(dist_attr_.dims_mapping())) {
+    // 1. create replicated global tensor
+    int64_t dims_size = global_value.dims().size();
+    std::vector<int64_t> dims_mapping(dims_size, -1);
+    dist_attr_.set_dims_mapping(dims_mapping);
+    // 2. reshard from replicated to other state
+    auto* func = ChooseProperReshardFunction(*this, dist_attr);
+    auto* dev_ctx = DeviceContextPool::Instance().Get(global_value.place());
+    func->Eval(dev_ctx, *this, dist_attr, this);
+  }
+}

 DistTensor::DistTensor(const DDim& dims, const TensorDistAttr& dist_attr)
     : dims_(dims), dist_attr_(dist_attr) {}
......
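To make the constructor's two steps concrete: the incoming global_value is first treated as fully replicated (dims_mapping all -1) and then resharded to the requested placement, so each rank ends up holding only its slice. A numpy-only sketch of the replicated-to-shard(0) case (illustration, not Paddle code; rank and mesh size are assumed):

import numpy as np

global_value = np.arange(12).reshape(6, 2)   # the replicated global value
num_ranks, rank = 2, 1                       # assume a 1-D mesh of 2 processes
local_value = np.split(global_value, num_ranks, axis=0)[rank]  # this rank's shard
print(local_value.shape)                     # (3, 2): what DistTensor::local_dims() reports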
@@ -21,26 +21,23 @@
 namespace phi {
 namespace distributed {

+class ReshardFunction;
+
 class DistTensor final
     : public phi::TensorBase,
       public phi::TypeInfoTraits<phi::TensorBase, DistTensor> {
  public:
+  /// \brief Be careful when creating a dist tensor with the default
+  /// constructor. For now it should only be used in reshard; the dist
+  /// properties will be set by reshard later.
+  DistTensor() = default;
+
   /// \brief Construct a dist tensor based dense tensor.
   /// \param global_value The global dense tensor of the current tensor.
   /// \param dist_attr The distributed attributes of the current tensor.
   DistTensor(const phi::DenseTensor& global_value,
              const TensorDistAttr& dist_attr);

-  // TODO(chenweihang): Remove this constructor after added reshard impl
-  /// \brief Construct a dist tensor based dense tensor.
-  /// \param value The local dense tensor of the current tensor.
-  /// \param dims The global dimension of the currnet tensor.
-  /// \param dist_attr The distributed attributes of the current tensor.
-  DistTensor(const phi::DenseTensor& value,
-             const DDim& dims,
-             const TensorDistAttr& dist_attr);
-
   /// \brief Construct a empty dist tensor (for infer spmd)
   /// \param dims The global dimension of the currnet Tensor.
   /// \param dist_attr The distributed attributes of the current tensor.
@@ -109,6 +106,8 @@ class DistTensor final
                     bool fake_alloc = false) override;

  private:
+  friend class ReshardFunction;
+
   // The global dimensions(shape)
   DDim dims_;
   // The distributed attributes
......
@@ -46,10 +46,10 @@ bool RToSReshardFunction::IsSuitable(const DistTensor& in,
   return flag;
 }

-std::shared_ptr<DistTensor> RToSReshardFunction::Eval(
-    phi::DeviceContext* dev_ctx,
-    const DistTensor& in,
-    const TensorDistAttr& out_dist_attr) {
+void RToSReshardFunction::Eval(phi::DeviceContext* dev_ctx,
+                               const DistTensor& in,
+                               const TensorDistAttr& out_dist_attr,
+                               DistTensor* out) {
   const auto& out_dims_mapping = out_dist_attr.dims_mapping();
   const auto& out_process_mesh = out_dist_attr.process_mesh();
   const DenseTensor& in_physical_tensor_cur_rank = in.value();
@@ -63,14 +63,6 @@ std::shared_ptr<DistTensor> RToSReshardFunction::Eval(
   int64_t split_axis = split_axis_to_mesh_axis.begin()->first;
   int64_t mesh_axis = split_axis_to_mesh_axis.begin()->second;

-  PADDLE_ENFORCE_LT(
-      mesh_axis,
-      out_process_mesh.ndim(),
-      phi::errors::OutOfRange(
-          "The mesh axis %lld exceed the size of process mesh %lld.",
-          mesh_axis,
-          out_process_mesh.ndim()));
-
   int64_t num_of_process = out_process_mesh.shape()[mesh_axis];
   VLOG(3) << "RToSReshard: Tensor will be split on axis " << split_axis
           << ". Split will use axis " << mesh_axis << " of process_mesh."
@@ -86,13 +78,15 @@ std::shared_ptr<DistTensor> RToSReshardFunction::Eval(
   VLOG(3) << "The current process will remain the idx "
           << coord_in_mesh[mesh_axis] << " piece of tensor";

   out_physical_tensor_cur_rank = split_out_vec[coord_in_mesh[mesh_axis]];
   VLOG(3) << "The shape of physical tensor after split is "
           << out_physical_tensor_cur_rank.dims();

-  return std::make_shared<DistTensor>(
-      out_physical_tensor_cur_rank, in.dims(), out_dist_attr);
+  set_dist_props(out, out_physical_tensor_cur_rank, in.dims(), out_dist_attr);
 }

+REGISTER_RESHARD_FUNC(RToSReshardFunction);
+
 } // namespace distributed
 } // namespace phi
@@ -27,10 +27,10 @@ class RToSReshardFunction final : public ReshardFunction {
   bool IsSuitable(const DistTensor& in,
                   const TensorDistAttr& out_dist_attr) override;

-  std::shared_ptr<DistTensor> Eval(
-      DeviceContext* dev_ctx,
-      const DistTensor& in,
-      const TensorDistAttr& out_dist_attr) override;
+  void Eval(DeviceContext* dev_ctx,
+            const DistTensor& in,
+            const TensorDistAttr& out_dist_attr,
+            DistTensor* out) override;
 };

 } // namespace distributed
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/distributed/auto_parallel/reshard_concat_functor.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/concat_kernel.h"
namespace phi {
namespace distributed {
DenseTensor ReshardConcatFunctor(const DeviceContext& dev_ctx,
const std::vector<const DenseTensor*>& input,
int64_t axis) {
DenseTensor result;
auto dtype = (*input.begin())->dtype();
if (phi::CPUContext::classof(&dev_ctx)) {
PD_VISIT_ALL_TYPES(
dtype, "Concat", ([&] {
Concat<data_t>(
static_cast<const CPUContext&>(dev_ctx), input, axis, &result);
}));
return result;
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (phi::GPUContext::classof(&dev_ctx)) {
PD_VISIT_ALL_TYPES(
dtype, "Concat", ([&] {
Concat<data_t>(
static_cast<const GPUContext&>(dev_ctx), input, axis, &result);
}));
return result;
}
#endif
PADDLE_THROW(phi::errors::Unimplemented(
"The concat in reshard only supported on CPU and GPU for now."));
}
} // namespace distributed
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cstdint>
#include <vector>
namespace phi {
class DeviceContext;
class DenseTensor;
namespace distributed {
DenseTensor ReshardConcatFunctor(const DeviceContext& dev_ctx,
const std::vector<const DenseTensor*>& input,
int64_t axis);
} // namespace distributed
} // namespace phi
@@ -13,10 +13,53 @@
 // limitations under the License.

 #include "paddle/phi/core/distributed/auto_parallel/reshard_function.h"

+#include "paddle/phi/core/device_context.h"
 #include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
 #include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"

 namespace phi {
-namespace distributed {}  // namespace distributed
+namespace distributed {
+
+std::shared_ptr<DistTensor> ReshardFunction::Eval(
+    DeviceContext* dev_ctx,
+    const DistTensor& in,
+    const TensorDistAttr& out_dist_attr) {
+  std::shared_ptr<DistTensor> out = std::make_shared<DistTensor>();
+  Eval(dev_ctx, in, out_dist_attr, out.get());
+  return out;
+}
+
+void ReshardFunction::set_dist_props(DistTensor* tensor,
+                                     const DenseTensor& value,
+                                     const DDim& dims,
+                                     const TensorDistAttr& dist_attr) {
+  PADDLE_ENFORCE_EQ(dist_attr.verify(vectorize(dims)),
+                    true,
+                    phi::errors::InvalidArgument(
+                        "The input dist_attr and dims are improper."));
+
+  tensor->value_ = value;
+  tensor->dims_ = dims;
+  tensor->dist_attr_ = dist_attr;
+}
+
+ReshardFunction* ChooseProperReshardFunction(
+    const DistTensor& in, const TensorDistAttr& out_dist_attr) {
+  for (const auto& func : GetReshardFunctionList()) {
+    if (func->IsSuitable(in, out_dist_attr)) {
+      return func.get();
+    }
+  }
+  PADDLE_THROW(phi::errors::Unimplemented(
+      "Can not reshard from in_dist_attr=%s to out_dist_attr=%s.",
+      in.dist_attr().to_string(),
+      out_dist_attr.to_string()));
+}
+
+std::vector<std::unique_ptr<ReshardFunction>>& GetReshardFunctionList() {
+  static std::vector<std::unique_ptr<ReshardFunction>> func_list;
+  return func_list;
+}
+
+} // namespace distributed
 } // namespace phi
@@ -14,6 +14,9 @@
 #pragma once

 #include <memory>
+#include <vector>
+
+#include "paddle/phi/core/dense_tensor.h"

 namespace phi {
 class DeviceContext;
@@ -31,11 +34,35 @@ class ReshardFunction {
   virtual bool IsSuitable(const DistTensor& in,
                           const TensorDistAttr& out_dist_attr) = 0;

-  virtual std::shared_ptr<DistTensor> Eval(
-      DeviceContext* dev_ctx,
-      const DistTensor& in,
-      const TensorDistAttr& out_dist_attr) = 0;
+  std::shared_ptr<DistTensor> Eval(DeviceContext* dev_ctx,
+                                   const DistTensor& in,
+                                   const TensorDistAttr& out_dist_attr);
+
+  virtual void Eval(DeviceContext* dev_ctx,
+                    const DistTensor& in,
+                    const TensorDistAttr& out_dist_attr,
+                    DistTensor* out) = 0;
+
+ protected:
+  void set_dist_props(DistTensor* tensor,
+                      const DenseTensor& value,
+                      const DDim& dims,
+                      const TensorDistAttr& dist_attr);
 };

+std::vector<std::unique_ptr<ReshardFunction>>& GetReshardFunctionList();
+
+#define REGISTER_RESHARD_FUNC(func_type)                                    \
+  class __RegisterReshard_##func_type {                                     \
+   public:                                                                  \
+    __RegisterReshard_##func_type() {                                       \
+      GetReshardFunctionList().emplace_back(std::make_unique<func_type>()); \
+    }                                                                       \
+  };                                                                        \
+  static __RegisterReshard_##func_type local_reshard_func_##func_type
+
+ReshardFunction* ChooseProperReshardFunction(
+    const DistTensor& in, const TensorDistAttr& out_dist_attr);
+
 } // namespace distributed
 } // namespace phi
@@ -15,12 +15,13 @@
 #include "paddle/phi/core/distributed/auto_parallel/s_to_r_reshard_function.h"

 #include "glog/logging.h"
+#include "paddle/phi/common/int_array.h"
 #include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
 #include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
 #include "paddle/phi/core/distributed/auto_parallel/reshard_all_gather_functor.h"
+#include "paddle/phi/core/distributed/auto_parallel/reshard_concat_functor.h"
+#include "paddle/phi/core/distributed/auto_parallel/reshard_split_functor.h"
 #include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h"
-#include "paddle/phi/core/distributed/comm_context_manager.h"
-#include "paddle/phi/core/distributed/store/tcp_store.h"

 namespace phi {
 namespace distributed {
@@ -43,18 +44,25 @@ bool SToRReshardFunction::IsSuitable(const DistTensor& in,
   flag &= (out_process_mesh.ndim() == 1);
   flag &= (in_process_mesh == out_process_mesh);

+  // Ensure the tensor is evenly split; otherwise we would need send/recv
+  // rather than all_gather.
+  std::map<int64_t, int64_t> split_axis_to_mesh_axis =
+      GetSplitAxisWithDimsMapping(in_dims_mapping);
+  int64_t split_axis = split_axis_to_mesh_axis.begin()->first;
+  int64_t num_of_process = in_process_mesh.size();
+  flag &=
+      (in.local_dims()[split_axis] * num_of_process == in.dims()[split_axis]);
+
   return flag;
 }

-std::shared_ptr<DistTensor> SToRReshardFunction::Eval(
-    DeviceContext* dev_ctx,
-    const DistTensor& in,
-    const TensorDistAttr& out_dist_attr) {
-  // TODO(liyurui): Only support transfer shard(0) to replicate for now.
-  // Concat is needed when transfer shard(x) to replicate, will be supported
-  // later.
+void SToRReshardFunction::Eval(DeviceContext* dev_ctx,
+                               const DistTensor& in,
+                               const TensorDistAttr& out_dist_attr,
+                               DistTensor* out) {
   const DenseTensor& in_physical_tensor_cur_rank = in.value();
   const auto& in_dist_attr = in.dist_attr();
+  const auto& in_dims_mapping = in_dist_attr.dims_mapping();
   const auto& in_process_mesh = in_dist_attr.process_mesh();
   const auto& in_process_ids = in_process_mesh.process_ids();
@@ -64,9 +72,41 @@ std::shared_ptr<DistTensor> SToRReshardFunction::Eval(
   DenseTensor out_all_gather = ReshardAllGatherFunctor(
       dev_ctx, in_physical_tensor_cur_rank, in_process_ids);

-  return std::make_shared<DistTensor>(
-      out_all_gather, out_all_gather.dims(), out_dist_attr);
+  std::map<int64_t, int64_t> split_axis_to_mesh_axis =
+      GetSplitAxisWithDimsMapping(in_dims_mapping);
+  int64_t split_axis = split_axis_to_mesh_axis.begin()->first;
+
+  if (split_axis == 0) {
+    // If the input dist tensor is shard(0), the subsequent split
+    // and concat are unnecessary.
+    set_dist_props(out, out_all_gather, out_all_gather.dims(), out_dist_attr);
+  } else {
+    // Since all_gather always concatenates the result on axis 0, we first
+    // split the result on axis 0, then concatenate the pieces back on the
+    // input's split axis.
+    int64_t default_split_axis = 0;
+    int64_t num_of_process = in_process_ids.size();
+
+    IntArray sections(std::vector<int64_t>(
+        num_of_process,
+        in_physical_tensor_cur_rank.dims()[default_split_axis]));
+    std::vector<DenseTensor> split_out_vec = ReshardSplitFunctor(
+        *dev_ctx, out_all_gather, sections, default_split_axis);
+
+    // Concat the split pieces on the correct axis.
+    std::vector<const DenseTensor*> concat_input_vec;
+    for (const auto& tensor : split_out_vec) {
+      concat_input_vec.emplace_back(&tensor);
+    }
+    DenseTensor concat_out_tensor =
+        ReshardConcatFunctor(*dev_ctx, concat_input_vec, split_axis);
+
+    set_dist_props(
+        out, concat_out_tensor, concat_out_tensor.dims(), out_dist_attr);
+  }
 }

+REGISTER_RESHARD_FUNC(SToRReshardFunction);
+
 } // namespace distributed
 } // namespace phi
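To see why the extra split/concat is needed when the shard axis is not 0 (illustration only, plain numpy rather than the functors above): all_gather always stacks the per-rank pieces along axis 0, so for shard(1) the gathered buffer has the wrong layout and must be re-split on axis 0 and concatenated back on the original shard axis.

import numpy as np

global_t = np.arange(24).reshape(4, 6)
rank_pieces = np.split(global_t, 2, axis=1)     # what each rank holds under shard(1)
gathered = np.concatenate(rank_pieces, axis=0)  # all_gather result: shape (8, 3)
restored = np.concatenate(np.split(gathered, 2, axis=0), axis=1)
assert (restored == global_t).all()             # back to the global tensor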
@@ -26,10 +26,10 @@ class SToRReshardFunction final : public ReshardFunction {
   bool IsSuitable(const DistTensor& in,
                   const TensorDistAttr& out_dist_attr) override;

-  std::shared_ptr<DistTensor> Eval(
-      DeviceContext* dev_ctx,
-      const DistTensor& in,
-      const TensorDistAttr& out_dist_attr) override;
+  void Eval(DeviceContext* dev_ctx,
+            const DistTensor& in,
+            const TensorDistAttr& out_dist_attr,
+            DistTensor* out) override;
 };

 } // namespace distributed
......
@@ -27,9 +27,10 @@ void ConcatKernel(const Context& dev_ctx,
                   DenseTensor* out);

 template <typename T, typename Context>
-DenseTensor Concat(const Context& dev_ctx,
-                   const std::vector<const DenseTensor*>& x,
-                   const Scalar& axis) {
+void Concat(const Context& dev_ctx,
+            const std::vector<const DenseTensor*>& x,
+            const Scalar& axis,
+            DenseTensor* dense_out) {
   std::vector<MetaTensor> meta_x;
   meta_x.reserve(x.size());
   std::vector<const MetaTensor*> meta_x_ptr;
@@ -38,10 +39,17 @@ DenseTensor Concat(const Context& dev_ctx,
     meta_x_ptr.push_back(&meta_x.back());
   }

-  DenseTensor dense_out;
-  MetaTensor meta_out(&dense_out);
+  MetaTensor meta_out(dense_out);
   ConcatInferMeta(meta_x_ptr, axis.to<int>(), &meta_out);
-  ConcatKernel<T, Context>(dev_ctx, x, axis, &dense_out);
+  ConcatKernel<T, Context>(dev_ctx, x, axis, dense_out);
+}
+
+template <typename T, typename Context>
+DenseTensor Concat(const Context& dev_ctx,
+                   const std::vector<const DenseTensor*>& x,
+                   const Scalar& axis) {
+  DenseTensor dense_out;
+  Concat<T, Context>(dev_ctx, x, axis, &dense_out);
   return dense_out;
 }
......
@@ -29,6 +29,8 @@ PD_REGISTER_KERNEL(concat_grad,
                    bool,
                    int64_t,
                    int,
int8_t,
int16_t,
                   uint8_t,
                   phi::dtype::float16,
                   phi::dtype::complex<float>,
......
@@ -125,6 +125,7 @@ PD_REGISTER_KERNEL(concat,
                    int,
                    uint8_t,
                    int8_t,
int16_t,
                   phi::dtype::float16,
                   phi::dtype::bfloat16,
                   phi::dtype::complex<float>,
......
@@ -31,6 +31,8 @@ PD_REGISTER_KERNEL(concat_grad,
                    int64_t,
                    int,
                    uint8_t,
int8_t,
int16_t,
                   phi::dtype::float16,
                   phi::dtype::bfloat16,
                   phi::dtype::complex<float>,
......
@@ -122,6 +122,7 @@ PD_REGISTER_KERNEL(concat,
                    int,
                    uint8_t,
                    int8_t,
int16_t,
                   phi::dtype::float16,
                   phi::dtype::bfloat16,
                   phi::dtype::complex<float>,
......
@@ -85,7 +85,7 @@ std::vector<DenseTensor> Split(const Context& dev_ctx,
   size_t out_number = sections.GetData().size();
   std::vector<DenseTensor> result(out_number);

-  Split(dev_ctx, x, sections, axis, &result);
+  Split<T, Context>(dev_ctx, x, sections, axis, &result);

   return result;
 }
......
@@ -68,7 +68,8 @@ class TestReshardRToS:
             else out_shape[self._shard] // 2 + 1
         )

-        assert np.equal(out.numpy().shape, out_shape).all()
+        assert np.equal(out.shape, input_tensor.shape).all()
+        assert np.equal(out._local_shape, out_shape).all()

 if __name__ == '__main__':
......
@@ -57,9 +57,8 @@ class TestReshardSToR:
         assert reshard_func.is_suitable(input_tensor, out_dist_attr)
         out = reshard_func.eval(dev_ctx, input_tensor, out_dist_attr)

-        out_shape = list(self._shape)
-        out_shape[self._shard] = out_shape[self._shard] * 2
-        assert np.equal(out.shape, out_shape).all()
+        assert np.equal(out.shape, out._local_shape).all()
+        assert np.equal(out.shape, input_tensor.shape).all()

 if __name__ == '__main__':
......
@@ -25,7 +25,7 @@ class TestDistTensor(unittest.TestCase):
     def test_dist_tensor_creation(self):
         shape = [10, 5]
         mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"])
-        dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=['x', 'y'])
+        dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=[None, None])

         # create dist tensor using numpy
         dist_tensor_with_numpy = dist.shard_tensor(
......
@@ -24,9 +24,9 @@ class TestReshardSToR(test_base.CommunicationTestDistBase):
             "shape": "(10, 20)",
             "dtype": "float32",
             "seeds": str(self._seeds),
-            "shard": "0",
         }
         self._changeable_envs = {
+            "shard": ["0", "1"],
             "backend": ["cpu", "gpu"],
         }
......
@@ -27,7 +27,7 @@ class TestDistAttrBasic(unittest.TestCase):
         exception = None
         try:
             mesh = [[0, 1], [2, 3]]
-            dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=['x', 'y'])
+            dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=[None, None])
         except ValueError as ex:
             self.assertIn(
                 "The mesh must be an instance of paddle.distributed.ProcessMesh",
@@ -44,7 +44,7 @@ class TestDistAttrBasic(unittest.TestCase):
                 [[2, 4, 5], [0, 1, 3]], dim_names=["x", "y"]
             )
             dist_attr = dist.DistAttr(
-                mesh=mesh, sharding_specs={"x": 0, "y": 1}
+                mesh=mesh, sharding_specs={"x": None, "y": None}
             )
         except ValueError as ex:
             self.assertIn(
@@ -63,7 +63,7 @@ class TestShardTensorDynamic(unittest.TestCase):
     def test_dynamic(self):
         dist_attr = dist.DistAttr(
-            mesh=self.mesh, sharding_specs=['x', None, None]
+            mesh=self.mesh, sharding_specs=[None, None, None]
         )

         input = paddle.rand([4, 1024, 512])
@@ -71,7 +71,7 @@ class TestShardTensorDynamic(unittest.TestCase):
         print(dist_attr.dims_mapping)

         self.assertEqual(d_tensor.dist_attr.process_mesh, self.mesh)
-        self.assertEqual(d_tensor.dist_attr.dims_mapping, [0, -1, -1])
+        self.assertEqual(d_tensor.dist_attr.dims_mapping, [-1, -1, -1])
         self.assertTrue(d_tensor.dist_attr.is_annotated("process_mesh"))
         self.assertTrue(d_tensor.dist_attr.is_annotated("dims_mapping"))
@@ -111,7 +111,7 @@ class TestShardTensorStaticDy2Static(unittest.TestCase):
             [[0, 1, 2, 3], [4, 5, 6, 7]], dim_names=["x", "y"]
         )
         dist_attr = dist.DistAttr(
-            mesh=mesh, sharding_specs=['x', None, None]
+            mesh=mesh, sharding_specs=[None, None, None]
         )

         input = paddle.rand([4, 1024, 512])
@@ -126,7 +126,7 @@ class TestShardTensorStaticDy2Static(unittest.TestCase):
             static_tensor
         )
         self.assertEqual(dist_input.dist_attr.process_mesh, mesh)
-        self.assertEqual(dist_input.dist_attr.dims_mapping, [0, -1, -1])
+        self.assertEqual(dist_input.dist_attr.dims_mapping, [-1, -1, -1])
         self.assertTrue(dist_input.dist_attr.is_annotated("process_mesh"))
         self.assertTrue(dist_input.dist_attr.is_annotated("dims_mapping"))
......
@@ -46,20 +46,14 @@ TEST(dist_tensor, constructor) {
   EXPECT_EQ(dist_x1.local_dims()[0], 3L);
   EXPECT_EQ(dist_x1.local_dims()[1], 4L);

-  DenseTensor x2(alloc, meta);
-  DistTensor dist_x2(x2, dims, dist_attr);
-  EXPECT_TRUE(dist_x2.defined());
-  EXPECT_TRUE(dist_x2.initialized());
-  EXPECT_TRUE(dist_x1.valid());
-
   // empty construct
-  DistTensor dist_x3(dims, dist_attr);
-  EXPECT_TRUE(!dist_x3.defined());
-  EXPECT_TRUE(!dist_x3.initialized());
+  DistTensor dist_x2(dims, dist_attr);
+  EXPECT_TRUE(!dist_x2.defined());
+  EXPECT_TRUE(!dist_x2.initialized());

   // allocate error test
   bool caught_exception = false;
   try {
-    dist_x3.AllocateFrom(alloc, phi::DataType::FLOAT32, 12L, false);
+    dist_x2.AllocateFrom(alloc, phi::DataType::FLOAT32, 12L, false);
   } catch (phi::EnforceNotMet& error) {
     caught_exception = true;
     EXPECT_NE(std::string(error.what()).find("Unavailable"), 0UL);
......