Unverified commit 99795a13 authored by LiYuRio, committed by GitHub

[Reshard] Support create shard tensor and non-zero dim reshard (#56553)

* support create shard dist tensor

* support non-zero shard to replicated

* change reshard signature
Parent d3f4596a
......@@ -388,6 +388,26 @@ PyObject* tensor_properties_get_dist_attr(TensorObject* self, void* closure) {
EAGER_CATCH_AND_THROW_RETURN_NULL
}
PyObject* tensor_properties_get_local_shape(TensorObject* self, void* closure) {
EAGER_TRY
if (self->tensor.is_dist_tensor()) {
#ifdef PADDLE_WITH_DISTRIBUTE
phi::distributed::DistTensor* dist_tensor =
static_cast<phi::distributed::DistTensor*>(self->tensor.impl().get());
return ToPyObject(phi::vectorize<int64_t>(dist_tensor->local_dims()));
#else
PADDLE_THROW(platform::errors::Unavailable(
"The `_local_shape` property of (Dist)Tensor is not supported "
"in the current PaddlePaddle, please recompile and install "
"PaddlePaddle "
"with the option of `WITH_DISTRIBUTE=ON`."));
#endif
} else {
RETURN_PY_NONE
}
EAGER_CATCH_AND_THROW_RETURN_NULL
}
PyDoc_STRVAR(tensor_shape__doc__,
R"DOC(shape
......@@ -716,6 +736,11 @@ struct PyGetSetDef variable_properties[] = { // NOLINT
(setter)tensor_properties_set_persistable,
tensor_persistable__doc__,
nullptr},
{"_local_shape",
(getter)tensor_properties_get_local_shape,
nullptr,
nullptr,
nullptr},
{"shape",
(getter)tensor_properties_get_shape,
nullptr,
......
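The new `_local_shape` property simply exposes `DistTensor::local_dims()`, i.e. the shape of the dense tensor held by the current rank. Below is a rough, standalone sketch of how that local shape relates to the global shape under a balanced shard; it is not Paddle's implementation, and `ComputeLocalShape`, the mesh-shape argument, and the "-1 means replicated" convention for dims_mapping are assumptions drawn from the surrounding diff.

#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical helper: derive the per-rank (local) shape from the global
// shape, the process-mesh shape, and a dims_mapping where -1 means the
// tensor axis is replicated and k >= 0 means it is split over mesh axis k.
// Assumes a balanced split (each sharded dim divisible by its mesh dim).
std::vector<int64_t> ComputeLocalShape(const std::vector<int64_t>& global_shape,
                                       const std::vector<int64_t>& mesh_shape,
                                       const std::vector<int64_t>& dims_mapping) {
  std::vector<int64_t> local_shape = global_shape;
  for (size_t i = 0; i < dims_mapping.size(); ++i) {
    int64_t mesh_axis = dims_mapping[i];
    if (mesh_axis != -1) {
      assert(global_shape[i] % mesh_shape[mesh_axis] == 0);
      local_shape[i] = global_shape[i] / mesh_shape[mesh_axis];
    }
  }
  return local_shape;
}

int main() {
  // Global [10, 20] tensor, 1-D mesh of 2 processes, sharded on axis 1.
  std::vector<int64_t> local = ComputeLocalShape({10, 20}, {2}, {-1, 0});
  std::cout << local[0] << " " << local[1] << std::endl;  // prints "10 10"
  return 0;
}

For a global [10, 20] tensor sharded across two processes, each rank's `_local_shape` halves the sharded dimension, which is the kind of check the updated reshard tests below perform via `out._local_shape`.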
......@@ -9,6 +9,7 @@ if(WITH_DISTRIBUTE)
dist_tensor.cc
reshard_function.cc
reshard_split_functor.cc
reshard_concat_functor.cc
reshard_all_gather_functor.cc
r_to_s_reshard_function.cc
s_to_r_reshard_function.cc)
......
......@@ -14,6 +14,10 @@
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#include "paddle/phi/backends/context_pool.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard_function.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h"
namespace phi {
namespace distributed {
......@@ -27,15 +31,21 @@ inline void check_defined(const DistTensor& dist_tensor,
method_hint));
}
// TODO(chenweihang): Reshard the input global value into local value
DistTensor::DistTensor(const phi::DenseTensor& global_value,
const TensorDistAttr& dist_attr)
: dims_(global_value.dims()), dist_attr_(dist_attr), value_(global_value) {}
DistTensor::DistTensor(const phi::DenseTensor& value,
const DDim& dims,
const TensorDistAttr& dist_attr)
: dims_(dims), dist_attr_(dist_attr), value_(value) {}
: dims_(global_value.dims()), dist_attr_(dist_attr), value_(global_value) {
if (!IsDimsMappingReplicated(dist_attr_.dims_mapping())) {
// 1. create replicated global tensor
int64_t dims_size = global_value.dims().size();
std::vector<int64_t> dims_mapping(dims_size, -1);
dist_attr_.set_dims_mapping(dims_mapping);
// 2. reshard from replicated to other state
auto* func = ChooseProperReshardFunction(*this, dist_attr);
auto* dev_ctx = DeviceContextPool::Instance().Get(global_value.place());
func->Eval(dev_ctx, *this, dist_attr, this);
}
}
DistTensor::DistTensor(const DDim& dims, const TensorDistAttr& dist_attr)
: dims_(dims), dist_attr_(dist_attr) {}
......
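The reworked `DistTensor(global_value, dist_attr)` constructor follows a two-step flow: it first stores the value as fully replicated (every dims_mapping entry set to -1), then, if the requested dist_attr is not replicated, picks a suitable reshard function and evaluates it with `this` as the output, converting the replicated value into the requested shard. A minimal stand-in for the `IsDimsMappingReplicated` check that gates this path is shown below; it is illustrative only, not phi's implementation.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Minimal stand-in for the IsDimsMappingReplicated check used above:
// a tensor is fully replicated iff every axis maps to no mesh axis (-1).
bool IsDimsMappingReplicated(const std::vector<int64_t>& dims_mapping) {
  return std::all_of(dims_mapping.begin(), dims_mapping.end(),
                     [](int64_t m) { return m == -1; });
}

int main() {
  std::cout << std::boolalpha
            << IsDimsMappingReplicated({-1, -1}) << "\n"   // true: skip reshard
            << IsDimsMappingReplicated({-1, 0}) << "\n";   // false: reshard to shard(1)
  return 0;
}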
......@@ -21,26 +21,23 @@
namespace phi {
namespace distributed {
class ReshardFunction;
class DistTensor final
: public phi::TensorBase,
public phi::TypeInfoTraits<phi::TensorBase, DistTensor> {
public:
/// \brief Be careful when creating a dist tensor with the default constructor.
/// It should only be used in reshard for now, and the dist properties
/// will be set by reshard later.
DistTensor() = default;
/// \brief Construct a dist tensor based on a dense tensor.
/// \param global_value The global dense tensor of the current tensor.
/// \param dist_attr The distributed attributes of the current tensor.
DistTensor(const phi::DenseTensor& global_value,
const TensorDistAttr& dist_attr);
// TODO(chenweihang): Remove this constructor after added reshard impl
/// \brief Construct a dist tensor based on a dense tensor.
/// \param value The local dense tensor of the current tensor.
/// \param dims The global dimension of the current tensor.
/// \param dist_attr The distributed attributes of the current tensor.
DistTensor(const phi::DenseTensor& value,
const DDim& dims,
const TensorDistAttr& dist_attr);
/// \brief Construct an empty dist tensor (for infer spmd)
/// \param dims The global dimension of the current Tensor.
/// \param dist_attr The distributed attributes of the current tensor.
......@@ -109,6 +106,8 @@ class DistTensor final
bool fake_alloc = false) override;
private:
friend class ReshardFunction;
// The global dimensions(shape)
DDim dims_;
// The distributed attributes
......
......@@ -46,10 +46,10 @@ bool RToSReshardFunction::IsSuitable(const DistTensor& in,
return flag;
}
std::shared_ptr<DistTensor> RToSReshardFunction::Eval(
phi::DeviceContext* dev_ctx,
const DistTensor& in,
const TensorDistAttr& out_dist_attr) {
void RToSReshardFunction::Eval(phi::DeviceContext* dev_ctx,
const DistTensor& in,
const TensorDistAttr& out_dist_attr,
DistTensor* out) {
const auto& out_dims_mapping = out_dist_attr.dims_mapping();
const auto& out_process_mesh = out_dist_attr.process_mesh();
const DenseTensor& in_physical_tensor_cur_rank = in.value();
......@@ -63,14 +63,6 @@ std::shared_ptr<DistTensor> RToSReshardFunction::Eval(
int64_t split_axis = split_axis_to_mesh_axis.begin()->first;
int64_t mesh_axis = split_axis_to_mesh_axis.begin()->second;
PADDLE_ENFORCE_LT(
mesh_axis,
out_process_mesh.ndim(),
phi::errors::OutOfRange(
"The mesh axis %lld exceed the size of process mesh %lld.",
mesh_axis,
out_process_mesh.ndim()));
int64_t num_of_process = out_process_mesh.shape()[mesh_axis];
VLOG(3) << "RToSReshard: Tensor will be split on axis " << split_axis
<< ". Split will use axis " << mesh_axis << " of process_mesh."
......@@ -86,13 +78,15 @@ std::shared_ptr<DistTensor> RToSReshardFunction::Eval(
VLOG(3) << "The current process will remain the idx "
<< coord_in_mesh[mesh_axis] << " piece of tensor";
out_physical_tensor_cur_rank = split_out_vec[coord_in_mesh[mesh_axis]];
VLOG(3) << "The shape of physical tensor after split is "
<< out_physical_tensor_cur_rank.dims();
return std::make_shared<DistTensor>(
out_physical_tensor_cur_rank, in.dims(), out_dist_attr);
set_dist_props(out, out_physical_tensor_cur_rank, in.dims(), out_dist_attr);
}
REGISTER_RESHARD_FUNC(RToSReshardFunction);
} // namespace distributed
} // namespace phi
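Both reshard functions rely on `GetSplitAxisWithDimsMapping` (from reshard_utils) to map each sharded tensor axis to the mesh axis it is split over. The self-contained stand-in below shows the intended behaviour; the helper itself is illustrative, and only the "-1 means replicated" convention is taken from the diff.

#include <cstdint>
#include <iostream>
#include <map>
#include <vector>

// Stand-in for the GetSplitAxisWithDimsMapping helper used by the reshard
// functions above: collect every sharded tensor axis together with the mesh
// axis it is split over (entries equal to -1 are replicated and skipped).
std::map<int64_t, int64_t> GetSplitAxisWithDimsMapping(
    const std::vector<int64_t>& dims_mapping) {
  std::map<int64_t, int64_t> split_axis_to_mesh_axis;
  for (size_t i = 0; i < dims_mapping.size(); ++i) {
    if (dims_mapping[i] != -1) {
      split_axis_to_mesh_axis[static_cast<int64_t>(i)] = dims_mapping[i];
    }
  }
  return split_axis_to_mesh_axis;
}

int main() {
  // dims_mapping [-1, 0]: tensor axis 1 is split over mesh axis 0.
  auto m = GetSplitAxisWithDimsMapping({-1, 0});
  for (const auto& kv : m) {
    std::cout << "tensor axis " << kv.first
              << " -> mesh axis " << kv.second << "\n";
  }
  return 0;
}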
......@@ -27,10 +27,10 @@ class RToSReshardFunction final : public ReshardFunction {
bool IsSuitable(const DistTensor& in,
const TensorDistAttr& out_dist_attr) override;
std::shared_ptr<DistTensor> Eval(
DeviceContext* dev_ctx,
const DistTensor& in,
const TensorDistAttr& out_dist_attr) override;
void Eval(DeviceContext* dev_ctx,
const DistTensor& in,
const TensorDistAttr& out_dist_attr,
DistTensor* out) override;
};
} // namespace distributed
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/distributed/auto_parallel/reshard_concat_functor.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/concat_kernel.h"
namespace phi {
namespace distributed {
DenseTensor ReshardConcatFunctor(const DeviceContext& dev_ctx,
const std::vector<const DenseTensor*>& input,
int64_t axis) {
DenseTensor result;
auto dtype = (*input.begin())->dtype();
if (phi::CPUContext::classof(&dev_ctx)) {
PD_VISIT_ALL_TYPES(
dtype, "Concat", ([&] {
Concat<data_t>(
static_cast<const CPUContext&>(dev_ctx), input, axis, &result);
}));
return result;
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (phi::GPUContext::classof(&dev_ctx)) {
PD_VISIT_ALL_TYPES(
dtype, "Concat", ([&] {
Concat<data_t>(
static_cast<const GPUContext&>(dev_ctx), input, axis, &result);
}));
return result;
}
#endif
PADDLE_THROW(phi::errors::Unimplemented(
"The concat in reshard only supported on CPU and GPU for now."));
}
} // namespace distributed
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cstdint>
#include <vector>
namespace phi {
class DeviceContext;
class DenseTensor;
namespace distributed {
DenseTensor ReshardConcatFunctor(const DeviceContext& dev_ctx,
const std::vector<const DenseTensor*>& input,
int64_t axis);
} // namespace distributed
} // namespace phi
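`ReshardConcatFunctor` uses `PD_VISIT_ALL_TYPES` to turn the runtime `dtype` of the first input into a compile-time template parameter for the `Concat` kernel, once per supported backend. The toy analogue of that dispatch pattern below is illustrative only; the `DType` enum, `ConcatStub`, and `DispatchConcat` are not phi's API.

#include <cstdint>
#include <iostream>
#include <stdexcept>

// Toy analogue of the dtype dispatch performed by PD_VISIT_ALL_TYPES:
// a runtime dtype value selects which instantiation of a templated
// kernel stub to call.
enum class DType { kFloat32, kInt64 };

template <typename T>
void ConcatStub() {
  std::cout << "concat elements of " << sizeof(T) << " bytes\n";
}

void DispatchConcat(DType dtype) {
  switch (dtype) {
    case DType::kFloat32:
      ConcatStub<float>();
      break;
    case DType::kInt64:
      ConcatStub<int64_t>();
      break;
    default:
      throw std::runtime_error("unsupported dtype");
  }
}

int main() {
  DispatchConcat(DType::kFloat32);  // 4 bytes
  DispatchConcat(DType::kInt64);    // 8 bytes
  return 0;
}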
......@@ -13,10 +13,53 @@
// limitations under the License.
#include "paddle/phi/core/distributed/auto_parallel/reshard_function.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
namespace phi {
namespace distributed {} // namespace distributed
namespace distributed {
std::shared_ptr<DistTensor> ReshardFunction::Eval(
DeviceContext* dev_ctx,
const DistTensor& in,
const TensorDistAttr& out_dist_attr) {
std::shared_ptr<DistTensor> out = std::make_shared<DistTensor>();
Eval(dev_ctx, in, out_dist_attr, out.get());
return out;
}
void ReshardFunction::set_dist_props(DistTensor* tensor,
const DenseTensor& value,
const DDim& dims,
const TensorDistAttr& dist_attr) {
PADDLE_ENFORCE_EQ(dist_attr.verify(vectorize(dims)),
true,
phi::errors::InvalidArgument(
"The input dist_attr and dims are improper."));
tensor->value_ = value;
tensor->dims_ = dims;
tensor->dist_attr_ = dist_attr;
}
ReshardFunction* ChooseProperReshardFunction(
const DistTensor& in, const TensorDistAttr& out_dist_attr) {
for (const auto& func : GetReshardFunctionList()) {
if (func->IsSuitable(in, out_dist_attr)) {
return func.get();
}
}
PADDLE_THROW(phi::errors::Unimplemented(
"Can not reshard from in_dist_attr=%s to out_dist_attr=%s.",
in.dist_attr().to_string(),
out_dist_attr.to_string()));
}
std::vector<std::unique_ptr<ReshardFunction>>& GetReshardFunctionList() {
static std::vector<std::unique_ptr<ReshardFunction>> func_list;
return func_list;
}
} // namespace distributed
} // namespace phi
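The base class now layers the two `Eval` overloads: the `shared_ptr`-returning version is a thin non-virtual wrapper that allocates an empty `DistTensor` and forwards to the pure-virtual out-parameter version, which concrete functions fill via `set_dist_props`. A small sketch of that wrapper pattern under assumed names follows; `Transform`, `Upper`, and `Apply` are illustrative, not Paddle code.

#include <cctype>
#include <iostream>
#include <memory>
#include <string>

// Sketch of the wrapper pattern used for Eval above: the convenient
// shared_ptr-returning entry point is non-virtual and forwards to a pure
// virtual out-parameter overload, so derived classes fill a caller-owned
// object instead of allocating their own result.
class Transform {
 public:
  virtual ~Transform() = default;

  // Non-virtual convenience overload: allocate, then delegate.
  std::shared_ptr<std::string> Apply(const std::string& in) {
    auto out = std::make_shared<std::string>();
    Apply(in, out.get());
    return out;
  }

  // The overload derived classes actually implement.
  virtual void Apply(const std::string& in, std::string* out) = 0;
};

class Upper final : public Transform {
 public:
  using Transform::Apply;  // keep the convenience overload visible
  void Apply(const std::string& in, std::string* out) override {
    *out = in;
    for (char& c : *out) {
      c = static_cast<char>(std::toupper(static_cast<unsigned char>(c)));
    }
  }
};

int main() {
  Upper u;
  std::cout << *u.Apply("reshard") << "\n";  // prints "RESHARD"
  return 0;
}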
......@@ -14,6 +14,9 @@
#pragma once
#include <memory>
#include <vector>
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
class DeviceContext;
......@@ -31,11 +34,35 @@ class ReshardFunction {
virtual bool IsSuitable(const DistTensor& in,
const TensorDistAttr& out_dist_attr) = 0;
virtual std::shared_ptr<DistTensor> Eval(
DeviceContext* dev_ctx,
const DistTensor& in,
const TensorDistAttr& out_dist_attr) = 0;
std::shared_ptr<DistTensor> Eval(DeviceContext* dev_ctx,
const DistTensor& in,
const TensorDistAttr& out_dist_attr);
virtual void Eval(DeviceContext* dev_ctx,
const DistTensor& in,
const TensorDistAttr& out_dist_attr,
DistTensor* out) = 0;
protected:
void set_dist_props(DistTensor* tensor,
const DenseTensor& value,
const DDim& dims,
const TensorDistAttr& dist_attr);
};
std::vector<std::unique_ptr<ReshardFunction>>& GetReshardFunctionList();
#define REGISTER_RESHARD_FUNC(func_type) \
class __RegisterReshard_##func_type { \
public: \
__RegisterReshard_##func_type() { \
GetReshardFunctionList().emplace_back(std::make_unique<func_type>()); \
} \
}; \
static __RegisterReshard_##func_type local_reshard_func_##func_type
ReshardFunction* ChooseProperReshardFunction(
const DistTensor& in, const TensorDistAttr& out_dist_attr);
} // namespace distributed
} // namespace phi
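`REGISTER_RESHARD_FUNC` implements a classic static self-registration scheme: each translation unit defines a file-local static object whose constructor pushes an instance of the function into the registry returned by `GetReshardFunctionList()`, and `ChooseProperReshardFunction` later scans that list for the first entry whose `IsSuitable` accepts the input tensor and output dist_attr. A compilable sketch of the same pattern with illustrative names (`Rule`, `REGISTER_RULE`, `ChooseRule` are not Paddle's identifiers):

#include <iostream>
#include <memory>
#include <vector>

// Self-contained sketch of the registration pattern behind
// REGISTER_RESHARD_FUNC: a file-local static object whose constructor pushes
// an instance into a function-local registry, which a chooser then scans.
class Rule {
 public:
  virtual ~Rule() = default;
  virtual bool CanHandle(int tag) const = 0;
  virtual const char* Name() const = 0;
};

std::vector<std::unique_ptr<Rule>>& GetRuleList() {
  static std::vector<std::unique_ptr<Rule>> rules;
  return rules;
}

#define REGISTER_RULE(rule_type)                                   \
  class __Register_##rule_type {                                   \
   public:                                                         \
    __Register_##rule_type() {                                     \
      GetRuleList().emplace_back(std::make_unique<rule_type>());   \
    }                                                              \
  };                                                               \
  static __Register_##rule_type local_register_##rule_type

class EvenRule final : public Rule {
 public:
  bool CanHandle(int tag) const override { return tag % 2 == 0; }
  const char* Name() const override { return "EvenRule"; }
};
REGISTER_RULE(EvenRule);

// Mirrors ChooseProperReshardFunction: return the first suitable entry.
Rule* ChooseRule(int tag) {
  for (const auto& rule : GetRuleList()) {
    if (rule->CanHandle(tag)) return rule.get();
  }
  return nullptr;
}

int main() {
  Rule* r = ChooseRule(4);
  std::cout << (r ? r->Name() : "none") << "\n";  // prints "EvenRule"
  return 0;
}

Keeping the registry inside a function (a Meyers singleton) avoids static-initialization-order problems between the registering objects and the list they push into.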
......@@ -15,12 +15,13 @@
#include "paddle/phi/core/distributed/auto_parallel/s_to_r_reshard_function.h"
#include "glog/logging.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard_all_gather_functor.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard_concat_functor.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard_split_functor.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
#include "paddle/phi/core/distributed/store/tcp_store.h"
namespace phi {
namespace distributed {
......@@ -43,18 +44,25 @@ bool SToRReshardFunction::IsSuitable(const DistTensor& in,
flag &= (out_process_mesh.ndim() == 1);
flag &= (in_process_mesh == out_process_mesh);
// Ensure the tensor is split evenly (balanced); otherwise we would need
// send/recv rather than all_gather.
std::map<int64_t, int64_t> split_axis_to_mesh_axis =
GetSplitAxisWithDimsMapping(in_dims_mapping);
int64_t split_axis = split_axis_to_mesh_axis.begin()->first;
int64_t num_of_process = in_process_mesh.size();
flag &=
(in.local_dims()[split_axis] * num_of_process == in.dims()[split_axis]);
return flag;
}
std::shared_ptr<DistTensor> SToRReshardFunction::Eval(
DeviceContext* dev_ctx,
const DistTensor& in,
const TensorDistAttr& out_dist_attr) {
// TODO(liyurui): Only support transferring shard(0) to replicate for now.
// Concat is needed when transferring shard(x) to replicate; this will be
// supported later.
void SToRReshardFunction::Eval(DeviceContext* dev_ctx,
const DistTensor& in,
const TensorDistAttr& out_dist_attr,
DistTensor* out) {
const DenseTensor& in_physical_tensor_cur_rank = in.value();
const auto& in_dist_attr = in.dist_attr();
const auto& in_dims_mapping = in_dist_attr.dims_mapping();
const auto& in_process_mesh = in_dist_attr.process_mesh();
const auto& in_process_ids = in_process_mesh.process_ids();
......@@ -64,9 +72,41 @@ std::shared_ptr<DistTensor> SToRReshardFunction::Eval(
DenseTensor out_all_gather = ReshardAllGatherFunctor(
dev_ctx, in_physical_tensor_cur_rank, in_process_ids);
return std::make_shared<DistTensor>(
out_all_gather, out_all_gather.dims(), out_dist_attr);
std::map<int64_t, int64_t> split_axis_to_mesh_axis =
GetSplitAxisWithDimsMapping(in_dims_mapping);
int64_t split_axis = split_axis_to_mesh_axis.begin()->first;
if (split_axis == 0) {
// If the input dist tensor is shard(0), the subsequent split
// and concat are unnecessary.
set_dist_props(out, out_all_gather, out_all_gather.dims(), out_dist_attr);
} else {
// Since all_gather always concatenates the gathered tensors on axis 0,
// we first need to split the result on axis 0,
// then concat the split pieces on the input's split axis.
int64_t default_split_axis = 0;
int64_t num_of_process = in_process_ids.size();
IntArray sections(std::vector<int64_t>(
num_of_process,
in_physical_tensor_cur_rank.dims()[default_split_axis]));
std::vector<DenseTensor> split_out_vec = ReshardSplitFunctor(
*dev_ctx, out_all_gather, sections, default_split_axis);
// Concat the split results on the correct axis.
std::vector<const DenseTensor*> concat_input_vec;
for (const auto& tensor : split_out_vec) {
concat_input_vec.emplace_back(&tensor);
}
DenseTensor concat_out_tensor =
ReshardConcatFunctor(*dev_ctx, concat_input_vec, split_axis);
set_dist_props(
out, concat_out_tensor, concat_out_tensor.dims(), out_dist_attr);
}
}
REGISTER_RESHARD_FUNC(SToRReshardFunction);
} // namespace distributed
} // namespace phi
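The key subtlety in the non-zero-axis path is that all_gather always stacks the per-rank shards along axis 0, regardless of which axis the tensor was originally sharded on. So when split_axis != 0, the gathered tensor is first split back into num_of_process pieces along axis 0, and those pieces are then concatenated along the real split axis. Below is a shape-only walkthrough under an assumed global shape (the [10, 20] example mirrors the shape used in the s_to_r tests further down).

#include <cstdint>
#include <iostream>
#include <vector>

// Shape-only walkthrough of the non-zero-axis S-to-R path above, assuming a
// global [10, 20] tensor sharded on axis 1 across a 1-D mesh of 2 processes.
int main() {
  const int64_t num_of_process = 2;
  std::vector<int64_t> global_shape = {10, 20};
  const int64_t split_axis = 1;

  // Each rank holds a balanced shard of the split axis: [10, 10].
  std::vector<int64_t> local_shape = global_shape;
  local_shape[split_axis] /= num_of_process;

  // all_gather always stacks the shards along axis 0: [20, 10].
  std::vector<int64_t> gathered_shape = local_shape;
  gathered_shape[0] *= num_of_process;

  // Split the gathered tensor back into per-rank pieces along axis 0
  // (each piece is [10, 10]), then concat them along the real split
  // axis 1 to recover the replicated global shape [10, 20].
  std::vector<int64_t> restored_shape = local_shape;
  restored_shape[split_axis] *= num_of_process;

  std::cout << restored_shape[0] << " x " << restored_shape[1] << "\n";
  return 0;
}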
......@@ -26,10 +26,10 @@ class SToRReshardFunction final : public ReshardFunction {
bool IsSuitable(const DistTensor& in,
const TensorDistAttr& out_dist_attr) override;
std::shared_ptr<DistTensor> Eval(
DeviceContext* dev_ctx,
const DistTensor& in,
const TensorDistAttr& out_dist_attr) override;
void Eval(DeviceContext* dev_ctx,
const DistTensor& in,
const TensorDistAttr& out_dist_attr,
DistTensor* out) override;
};
} // namespace distributed
......
......@@ -27,9 +27,10 @@ void ConcatKernel(const Context& dev_ctx,
DenseTensor* out);
template <typename T, typename Context>
DenseTensor Concat(const Context& dev_ctx,
const std::vector<const DenseTensor*>& x,
const Scalar& axis) {
void Concat(const Context& dev_ctx,
const std::vector<const DenseTensor*>& x,
const Scalar& axis,
DenseTensor* dense_out) {
std::vector<MetaTensor> meta_x;
meta_x.reserve(x.size());
std::vector<const MetaTensor*> meta_x_ptr;
......@@ -38,10 +39,17 @@ DenseTensor Concat(const Context& dev_ctx,
meta_x_ptr.push_back(&meta_x.back());
}
DenseTensor dense_out;
MetaTensor meta_out(&dense_out);
MetaTensor meta_out(dense_out);
ConcatInferMeta(meta_x_ptr, axis.to<int>(), &meta_out);
ConcatKernel<T, Context>(dev_ctx, x, axis, &dense_out);
ConcatKernel<T, Context>(dev_ctx, x, axis, dense_out);
}
template <typename T, typename Context>
DenseTensor Concat(const Context& dev_ctx,
const std::vector<const DenseTensor*>& x,
const Scalar& axis) {
DenseTensor dense_out;
Concat<T, Context>(dev_ctx, x, axis, &dense_out);
return dense_out;
}
......
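The `Concat` helper is split into an out-parameter core plus a thin by-value wrapper, so callers such as `ReshardConcatFunctor` can decide where the result lives while existing call sites keep the convenient return-by-value form. A minimal sketch of that layering with illustrative names (`JoinInto`/`Join` are not Paddle APIs):

#include <iostream>
#include <string>
#include <vector>

// Out-parameter core: writes into a caller-provided destination.
void JoinInto(const std::vector<std::string>& parts, std::string* out) {
  out->clear();
  for (const auto& p : parts) *out += p;
}

// Thin by-value wrapper that preserves the convenient return form.
std::string Join(const std::vector<std::string>& parts) {
  std::string out;
  JoinInto(parts, &out);
  return out;
}

int main() {
  std::cout << Join({"a", "b", "c"}) << "\n";  // prints "abc"
  return 0;
}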
......@@ -29,6 +29,8 @@ PD_REGISTER_KERNEL(concat_grad,
bool,
int64_t,
int,
int8_t,
int16_t,
uint8_t,
phi::dtype::float16,
phi::dtype::complex<float>,
......
......@@ -125,6 +125,7 @@ PD_REGISTER_KERNEL(concat,
int,
uint8_t,
int8_t,
int16_t,
phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
......
......@@ -31,6 +31,8 @@ PD_REGISTER_KERNEL(concat_grad,
int64_t,
int,
uint8_t,
int8_t,
int16_t,
phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
......
......@@ -122,6 +122,7 @@ PD_REGISTER_KERNEL(concat,
int,
uint8_t,
int8_t,
int16_t,
phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
......
......@@ -85,7 +85,7 @@ std::vector<DenseTensor> Split(const Context& dev_ctx,
size_t out_number = sections.GetData().size();
std::vector<DenseTensor> result(out_number);
Split(dev_ctx, x, sections, axis, &result);
Split<T, Context>(dev_ctx, x, sections, axis, &result);
return result;
}
......
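A plausible reading of the `Split<T, Context>` change: in the out-parameter overload, the template parameter `T` only selects the kernel and never appears in the argument list, so it cannot be deduced at the call site and has to be spelled explicitly. A tiny illustration of that deduction rule with made-up names (`FillImpl`/`Fill` are not Paddle APIs):

#include <iostream>
#include <vector>

// T only affects the body (here, the value written), not any argument,
// so callers must pass it explicitly.
template <typename T>
void FillImpl(std::vector<int>* out, int n) {
  out->assign(n, static_cast<int>(sizeof(T)));
}

template <typename T>
std::vector<int> Fill(int n) {
  std::vector<int> result;
  // FillImpl(&result, n);     // error: T cannot be deduced
  FillImpl<T>(&result, n);     // explicit template argument, as in the fix
  return result;
}

int main() {
  for (int v : Fill<double>(3)) std::cout << v << " ";  // prints "8 8 8"
  std::cout << "\n";
  return 0;
}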
......@@ -68,7 +68,8 @@ class TestReshardRToS:
else out_shape[self._shard] // 2 + 1
)
assert np.equal(out.numpy().shape, out_shape).all()
assert np.equal(out.shape, input_tensor.shape).all()
assert np.equal(out._local_shape, out_shape).all()
if __name__ == '__main__':
......
......@@ -57,9 +57,8 @@ class TestReshardSToR:
assert reshard_func.is_suitable(input_tensor, out_dist_attr)
out = reshard_func.eval(dev_ctx, input_tensor, out_dist_attr)
out_shape = list(self._shape)
out_shape[self._shard] = out_shape[self._shard] * 2
assert np.equal(out.shape, out_shape).all()
assert np.equal(out.shape, out._local_shape).all()
assert np.equal(out.shape, input_tensor.shape).all()
if __name__ == '__main__':
......
......@@ -25,7 +25,7 @@ class TestDistTensor(unittest.TestCase):
def test_dist_tensor_creation(self):
shape = [10, 5]
mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"])
dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=['x', 'y'])
dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=[None, None])
# create dist tensor using numpy
dist_tensor_with_numpy = dist.shard_tensor(
......
......@@ -24,9 +24,9 @@ class TestReshardSToR(test_base.CommunicationTestDistBase):
"shape": "(10, 20)",
"dtype": "float32",
"seeds": str(self._seeds),
"shard": "0",
}
self._changeable_envs = {
"shard": ["0", "1"],
"backend": ["cpu", "gpu"],
}
......
......@@ -27,7 +27,7 @@ class TestDistAttrBasic(unittest.TestCase):
exception = None
try:
mesh = [[0, 1], [2, 3]]
dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=['x', 'y'])
dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=[None, None])
except ValueError as ex:
self.assertIn(
"The mesh must be an instance of paddle.distributed.ProcessMesh",
......@@ -44,7 +44,7 @@ class TestDistAttrBasic(unittest.TestCase):
[[2, 4, 5], [0, 1, 3]], dim_names=["x", "y"]
)
dist_attr = dist.DistAttr(
mesh=mesh, sharding_specs={"x": 0, "y": 1}
mesh=mesh, sharding_specs={"x": None, "y": None}
)
except ValueError as ex:
self.assertIn(
......@@ -63,7 +63,7 @@ class TestShardTensorDynamic(unittest.TestCase):
def test_dynamic(self):
dist_attr = dist.DistAttr(
mesh=self.mesh, sharding_specs=['x', None, None]
mesh=self.mesh, sharding_specs=[None, None, None]
)
input = paddle.rand([4, 1024, 512])
......@@ -71,7 +71,7 @@ class TestShardTensorDynamic(unittest.TestCase):
print(dist_attr.dims_mapping)
self.assertEqual(d_tensor.dist_attr.process_mesh, self.mesh)
self.assertEqual(d_tensor.dist_attr.dims_mapping, [0, -1, -1])
self.assertEqual(d_tensor.dist_attr.dims_mapping, [-1, -1, -1])
self.assertTrue(d_tensor.dist_attr.is_annotated("process_mesh"))
self.assertTrue(d_tensor.dist_attr.is_annotated("dims_mapping"))
......@@ -111,7 +111,7 @@ class TestShardTensorStaticDy2Static(unittest.TestCase):
[[0, 1, 2, 3], [4, 5, 6, 7]], dim_names=["x", "y"]
)
dist_attr = dist.DistAttr(
mesh=mesh, sharding_specs=['x', None, None]
mesh=mesh, sharding_specs=[None, None, None]
)
input = paddle.rand([4, 1024, 512])
......@@ -126,7 +126,7 @@ class TestShardTensorStaticDy2Static(unittest.TestCase):
static_tensor
)
self.assertEqual(dist_input.dist_attr.process_mesh, mesh)
self.assertEqual(dist_input.dist_attr.dims_mapping, [0, -1, -1])
self.assertEqual(dist_input.dist_attr.dims_mapping, [-1, -1, -1])
self.assertTrue(dist_input.dist_attr.is_annotated("process_mesh"))
self.assertTrue(dist_input.dist_attr.is_annotated("dims_mapping"))
......
......@@ -46,20 +46,14 @@ TEST(dist_tensor, constructor) {
EXPECT_EQ(dist_x1.local_dims()[0], 3L);
EXPECT_EQ(dist_x1.local_dims()[1], 4L);
DenseTensor x2(alloc, meta);
DistTensor dist_x2(x2, dims, dist_attr);
EXPECT_TRUE(dist_x2.defined());
EXPECT_TRUE(dist_x2.initialized());
EXPECT_TRUE(dist_x1.valid());
// empty construct
DistTensor dist_x3(dims, dist_attr);
EXPECT_TRUE(!dist_x3.defined());
EXPECT_TRUE(!dist_x3.initialized());
DistTensor dist_x2(dims, dist_attr);
EXPECT_TRUE(!dist_x2.defined());
EXPECT_TRUE(!dist_x2.initialized());
// allocate error test
bool caught_exception = false;
try {
dist_x3.AllocateFrom(alloc, phi::DataType::FLOAT32, 12L, false);
dist_x2.AllocateFrom(alloc, phi::DataType::FLOAT32, 12L, false);
} catch (phi::EnforceNotMet& error) {
caught_exception = true;
EXPECT_NE(std::string(error.what()).find("Unavailable"), 0UL);
......