Unverified commit 8495377a authored by Chen Weihang, committed by GitHub

[AutoParallel] Polish dist tensor design (#56368)

* polish dist tensor design

* adjust constructor

* polish details

* polish details design

* fix compile error

* refactor init tensor impl

* fix reshard test

* polish details

* add unittest for coverage
Parent commit ffff3da0
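Summary for reviewers (not part of the commit message): DistTensor now stores its global dims, a value-typed TensorDistAttr, and a value-typed DenseTensor instead of a DenseTensorMeta plus shared_ptr members, and the mutable accessors become const. A minimal sketch of the reworked interface, using only names that appear in the diff below (the wrapper function itself is hypothetical):

// Sketch only -- illustrates the constructors/accessors introduced in this diff.
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"

using phi::distributed::DistTensor;
using phi::distributed::TensorDistAttr;

void Sketch(const phi::DenseTensor& global_value) {
  // Construct from a global DenseTensor plus distributed attributes.
  TensorDistAttr dist_attr(phi::vectorize(global_value.dims()));
  DistTensor dist(global_value, dist_attr);

  // The local DenseTensor is now exposed read-only; writers need a const_cast.
  const phi::DenseTensor& local = dist.value();

  // Global and local shapes are reported separately.
  const phi::DDim& global_dims = dist.dims();
  const phi::DDim& local_dims = dist.local_dims();
  (void)local;
  (void)global_dims;
  (void)local_dims;
}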
......@@ -112,7 +112,7 @@ void GradNodeBase::SetGradInMeta(const paddle::Tensor& fwd_out,
return;
}
phi::DenseTensor* dense_tensor = nullptr;
const phi::DenseTensor* dense_tensor = nullptr;
// Record TensorMeta
if (phi::DenseTensor::classof(fwd_out.impl().get())) {
// Only Copy Meta
......@@ -130,8 +130,8 @@ void GradNodeBase::SetGradInMeta(const paddle::Tensor& fwd_out,
// TODO(chenweihang): DistTensor contains global and local meta, here
// only set the local meta now, we should set global meta later
dense_tensor =
static_cast<phi::distributed::DistTensor*>(fwd_out.impl().get())
->mutable_value();
&(static_cast<phi::distributed::DistTensor*>(fwd_out.impl().get())
->value());
#endif
} else {
VLOG(7) << "Unable to initialize the DenseTensorMeta of GradSlotMeta with "
......@@ -270,16 +270,16 @@ void GradNodeBase::SetGradOutMeta(const paddle::Tensor& fwd_in,
meta.SetPlace(fwd_in.place());
#ifdef PADDLE_WITH_DISTRIBUTE
} else if (phi::distributed::DistTensor::classof(fwd_in.impl().get())) {
phi::DenseTensor* dense_tensor =
const phi::DenseTensor& dense_tensor =
static_cast<phi::distributed::DistTensor*>(fwd_in.impl().get())
->mutable_value();
->value();
PADDLE_ENFORCE_NE(
dense_tensor->meta().dtype,
dense_tensor.meta().dtype,
phi::DataType::UNDEFINED,
paddle::platform::errors::Fatal("Attempting to copy DenseTensorMeta "
"with phi::DataType::UNDEFINED,"
"which is illegal."));
meta.SetTensorMeta(dense_tensor->meta());
meta.SetTensorMeta(dense_tensor.meta());
meta.SetPlace(fwd_in.place());
#endif
} else {
......
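The two hunks above switch the autograd meta-recording code from mutable_value() to the new const value() accessor. A condensed sketch of the dispatch they implement (the helper name is hypothetical, not in the diff):

// Fetch the (local) DenseTensor meta from either a DenseTensor or a DistTensor
// impl -- the same pattern used in SetGradInMeta/SetGradOutMeta above.
const phi::DenseTensorMeta& GetLocalMeta(const paddle::Tensor& t) {
  if (phi::DenseTensor::classof(t.impl().get())) {
    return static_cast<phi::DenseTensor*>(t.impl().get())->meta();
  }
  // DistTensor: only the meta of its local DenseTensor value is recorded here;
  // setting the global meta is left as a TODO in the diff above.
  return static_cast<phi::distributed::DistTensor*>(t.impl().get())
      ->value()
      .meta();
}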
......@@ -94,12 +94,9 @@ void GradTensorHolder::CopyValueFromTensor(size_t slot_id,
// TODO(chenweihang): replace by valid dist_attr later
auto temp =
paddle::experimental::full(t.shape(), 1, t.dtype(), t.place());
auto dense_temp =
std::dynamic_pointer_cast<phi::DenseTensor>(temp.impl());
auto dense_temp = static_cast<phi::DenseTensor*>(temp.impl().get());
auto dist_tensor = std::make_shared<phi::distributed::DistTensor>(
dense_temp,
dense_temp->meta(),
std::make_shared<phi::distributed::TensorDistAttr>());
*dense_temp, phi::distributed::TensorDistAttr());
temp.set_impl(dist_tensor);
buffer_[slot_id][rank] = temp;
#endif
......
......@@ -121,8 +121,7 @@ void BindAutoParallel(py::module *m) {
"is_suitable",
[](phi::distributed::ReshardFunction &self,
py::handle py_tensor,
const std::shared_ptr<phi::distributed::TensorDistAttr>
&dist_attr) {
const phi::distributed::TensorDistAttr &dist_attr) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
auto p_dist =
std::dynamic_pointer_cast<phi::distributed::DistTensor>(
......@@ -135,8 +134,7 @@ void BindAutoParallel(py::module *m) {
[](phi::distributed::ReshardFunction &self,
phi::DeviceContext *dev_ctx,
py::handle py_tensor,
const std::shared_ptr<phi::distributed::TensorDistAttr>
&dist_attr) {
const phi::distributed::TensorDistAttr &dist_attr) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
auto p_dist =
std::dynamic_pointer_cast<phi::distributed::DistTensor>(
......@@ -281,8 +279,7 @@ void BindAutoParallel(py::module *m) {
py::arg("memo"))
.def("__str__", &DeviceMesh::to_string);
py::class_<TensorDistAttr, std::shared_ptr<TensorDistAttr>> py_dist_attr(
*m, "TensorDistAttr");
py::class_<TensorDistAttr> py_dist_attr(*m, "TensorDistAttr");
g_tensor_dist_attr_pytype =
reinterpret_cast<PyTypeObject *>(py_dist_attr.ptr());
py_dist_attr.def(py::init<>())
......
......@@ -68,52 +68,6 @@ PyObject* TensorNew(PyTypeObject* type, PyObject* args, PyObject* kwargs) {
return obj;
}
#ifdef PADDLE_WITH_DISTRIBUTE
void EmptyDistTensorInitializer(
TensorObject* self,
const std::string& name,
const paddle::platform::Place& place,
const std::shared_ptr<TensorDistAttr>& dist_attr,
bool persistable = false,
int stop_gradient = -1,
framework::proto::VarType::Type dtype =
paddle::framework::proto::VarType::FP32,
const std::vector<int>& dims = {0}) {
auto ddims = phi::make_ddim(dims);
self->tensor.set_name(name);
auto autograd_meta = egr::EagerUtils::autograd_meta(&(self->tensor));
autograd_meta->SetPersistable(persistable);
if (stop_gradient != -1) {
autograd_meta->SetStopGradient(static_cast<bool>(stop_gradient));
}
std::shared_ptr<DistTensor> dist_tensor = nullptr;
if (dims.size() == 1 && dims[0] == 0) {
std::shared_ptr<phi::Allocation> allocation_ptr = nullptr;
dist_tensor = std::make_shared<DistTensor>(
allocation_ptr,
phi::DenseTensorMeta(paddle::framework::TransToPhiDataType(dtype),
ddims),
dist_attr);
} else {
dist_tensor = std::make_shared<DistTensor>(
std::make_shared<phi::Allocation>(),
phi::DenseTensorMeta(paddle::framework::TransToPhiDataType(dtype),
ddims),
dist_attr);
}
self->tensor.set_impl(dist_tensor);
if (!autograd_meta->GetMutableGradNode()) {
autograd_meta->SetGradNode(
std::make_shared<egr::GradNodeAccumulation>(autograd_meta));
VLOG(3) << "Tensor(" << name
<< ") have not GradNode, add GradNodeAccumulation"
<< autograd_meta->GradNode() << " for it.";
}
}
#endif
// TODO(jiabin): Overload this once we need more constructor in Python
void EmptyTensorInitializer(TensorObject* self,
const std::string& name,
......@@ -184,44 +138,71 @@ void EmptyStringTensorInitializer(TensorObject* self,
}
#ifdef PADDLE_WITH_DISTRIBUTE
void InitDistTensorWithNumpyValue(TensorObject* self,
const py::object& array,
const paddle::platform::Place& place,
bool zero_copy = false) {
PADDLE_ENFORCE_EQ(
self->tensor.defined(),
true,
paddle::platform::errors::Unavailable(
"Calling InitDistTensorWithNumpyValue of Eager Tensor without "
"EmptyDistTensorInitializer is "
"forbidden. Please check your code and make sure you new a "
"eager tensor before init it with NumPy."));
DistTensor* dist_tensor_ptr =
static_cast<DistTensor*>(self->tensor.impl().get());
phi::DenseTensor* impl_ptr =
static_cast<phi::DenseTensor*>(dist_tensor_ptr->mutable_value());
void CreateDistTensorWithNumpyValue(TensorObject* self,
const std::string& name,
const paddle::platform::Place& place,
const TensorDistAttr& dist_attr,
const py::object& array,
bool persistable = false,
int stop_gradient = -1,
bool zero_copy = false,
framework::proto::VarType::Type dtype =
paddle::framework::proto::VarType::FP32,
const std::vector<int>& dims = {0}) {
auto ddims = phi::make_ddim(dims);
self->tensor.set_name(name);
auto autograd_meta = egr::EagerUtils::autograd_meta(&(self->tensor));
autograd_meta->SetPersistable(persistable);
if (stop_gradient != -1) {
autograd_meta->SetStopGradient(static_cast<bool>(stop_gradient));
}
phi::DenseTensor dense_tensor;
if (dims.size() == 1 && dims[0] == 0) {
std::shared_ptr<phi::Allocation> allocation_ptr = nullptr;
dense_tensor = phi::DenseTensor(
nullptr,
phi::DenseTensorMeta(paddle::framework::TransToPhiDataType(dtype),
ddims));
} else {
dense_tensor = phi::DenseTensor(
std::make_shared<phi::Allocation>(),
phi::DenseTensorMeta(paddle::framework::TransToPhiDataType(dtype),
ddims));
}
if (platform::is_cpu_place(place)) {
SetTensorFromPyArray<platform::CPUPlace>(impl_ptr, array, place, zero_copy);
SetTensorFromPyArray<platform::CPUPlace>(
&dense_tensor, array, place, zero_copy);
} else if (platform::is_xpu_place(place)) {
SetTensorFromPyArray<platform::XPUPlace>(impl_ptr, array, place, zero_copy);
SetTensorFromPyArray<platform::XPUPlace>(
&dense_tensor, array, place, zero_copy);
} else if (platform::is_gpu_place(place)) {
SetTensorFromPyArray<platform::CUDAPlace>(
impl_ptr, array, place, zero_copy);
&dense_tensor, array, place, zero_copy);
} else if (platform::is_cuda_pinned_place(place)) {
SetTensorFromPyArray<platform::CUDAPinnedPlace>(
impl_ptr, array, place, zero_copy);
&dense_tensor, array, place, zero_copy);
} else if (platform::is_custom_place(place)) {
SetTensorFromPyArray<platform::CustomPlace>(
impl_ptr, array, place, zero_copy);
&dense_tensor, array, place, zero_copy);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Place should be one of "
"CPUPlace/XPUPlace/CUDAPlace/CUDAPinnedPlace/CustomPlace"));
}
// TODO(dev): dist_tensor meta is not equal to dense tensor meta
dist_tensor_ptr->set_meta(impl_ptr->meta());
auto dist_tensor =
std::make_shared<phi::distributed::DistTensor>(dense_tensor, dist_attr);
self->tensor.set_impl(dist_tensor);
if (!autograd_meta->GetMutableGradNode()) {
autograd_meta->SetGradNode(
std::make_shared<egr::GradNodeAccumulation>(autograd_meta));
VLOG(3) << "Tensor(" << name
<< ") have not GradNode, add GradNodeAccumulation"
<< autograd_meta->GradNode() << " for it.";
}
}
#endif
......@@ -284,12 +265,11 @@ void InitStringTensorWithNumpyValue(TensorObject* self, const py::object& obj) {
}
#ifdef PADDLE_WITH_DISTRIBUTE
void InitDistTensorWithTensor(
TensorObject* self,
const paddle::Tensor& src,
const paddle::platform::Place& place,
const std::string& name,
const std::shared_ptr<TensorDistAttr>& dist_attr) {
void InitDistTensorWithTensor(TensorObject* self,
const paddle::Tensor& src,
const paddle::platform::Place& place,
const std::string& name,
const TensorDistAttr& dist_attr) {
PADDLE_ENFORCE(src.is_dense_tensor(),
paddle::platform::errors::InvalidArgument(
"DistTensor can only initialize by DenseTensor"));
......@@ -297,15 +277,13 @@ void InitDistTensorWithTensor(
if (place == src.place()) {
std::shared_ptr<phi::DenseTensor> tensor =
std::static_pointer_cast<phi::DenseTensor>(src.impl());
self->tensor.set_impl(
std::make_shared<DistTensor>(tensor, tensor->meta(), dist_attr));
self->tensor.set_impl(std::make_shared<DistTensor>(*tensor, dist_attr));
VLOG(4) << "Same place, do ShareDataWith for DistTensor.";
} else {
std::shared_ptr<phi::DenseTensor> tensor =
std::static_pointer_cast<phi::DenseTensor>(
src.copy_to(place, true).impl());
self->tensor.set_impl(
std::make_shared<DistTensor>(tensor, tensor->meta(), dist_attr));
self->tensor.set_impl(std::make_shared<DistTensor>(*tensor, dist_attr));
VLOG(4) << "Different place, do TensorCopy for DistTensor.";
}
if (src.get_autograd_meta()) {
......@@ -416,13 +394,13 @@ paddle::platform::Place ParsePlace(
}
#ifdef PADDLE_WITH_DISTRIBUTE
std::shared_ptr<TensorDistAttr> ParseDistAttrArgs(
TensorDistAttr ParseDistAttrArgs(
std::unordered_map<std::string, PyObject*> kws_map,
std::unordered_map<std::string, Py_ssize_t> kw_order_map,
PyObject* args,
bool flag_kwargs,
Py_ssize_t args_num) {
std::shared_ptr<TensorDistAttr> dist_attr = nullptr;
TensorDistAttr dist_attr;
if (kw_order_map["dist_attr"] <= args_num) {
dist_attr = CastPyArg2DistAttr(
PyTuple_GET_ITEM(args, kw_order_map["dist_attr"] - 1),
......@@ -530,13 +508,18 @@ void AutoInitTensorByPyArray(TensorObject* py_tensor_ptr,
"stop_gradient", kws_map, kw_order_map, args, flag_kwargs, args_num);
#ifdef PADDLE_WITH_DISTRIBUTE
std::shared_ptr<TensorDistAttr> dist_attr =
TensorDistAttr dist_attr =
ParseDistAttrArgs(kws_map, kw_order_map, args, flag_kwargs, args_num);
if (dist_attr) {
EmptyDistTensorInitializer(
py_tensor_ptr, act_name, place, dist_attr, persistable, stop_gradient);
InitDistTensorWithNumpyValue(py_tensor_ptr, numpy_value, place, zero_copy);
if (!dist_attr.empty()) {
CreateDistTensorWithNumpyValue(py_tensor_ptr,
act_name,
place,
dist_attr,
numpy_value,
persistable,
stop_gradient,
zero_copy);
return;
}
#endif
......@@ -572,7 +555,7 @@ void AutoInitTensorByTensor(TensorObject* py_tensor_ptr,
act_name = ParseName(kws_map, kw_order_map, args, flag_kwargs, args_num);
#ifdef PADDLE_WITH_DISTRIBUTE
std::shared_ptr<TensorDistAttr> dist_attr =
TensorDistAttr dist_attr =
ParseDistAttrArgs(kws_map, kw_order_map, args, flag_kwargs, args_num);
#endif
......@@ -595,7 +578,7 @@ void AutoInitTensorByTensor(TensorObject* py_tensor_ptr,
}
}
#ifdef PADDLE_WITH_DISTRIBUTE
if (dist_attr) {
if (!dist_attr.empty()) {
InitDistTensorWithTensor(
py_tensor_ptr, src_tensor, place, act_name, dist_attr);
} else {
......
......@@ -147,6 +147,15 @@ static PyObject* tensor_method_numpy(TensorObject* self,
return array;
}
auto tensor_dims = self->tensor.shape();
#ifdef PADDLE_WITH_DISTRIBUTE
// Now the DistTensor's numpy() returns the local tensor value
if (self->tensor.is_dist_tensor()) {
tensor_dims = phi::vectorize(
static_cast<phi::distributed::DistTensor*>(self->tensor.impl().get())
->value()
.dims());
}
#endif
auto numpy_dtype = TensorDtype2NumpyDtype(self->tensor.type());
auto sizeof_dtype = phi::SizeOf(self->tensor.type());
Py_intptr_t py_dims[paddle::framework::DDim::kMaxRank]; // NOLINT
......
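As the new comment notes, numpy() on a DistTensor materializes the local value, so its shape can differ from the tensor's global shape. An illustrative example (the concrete sharding and sizes are assumed, not taken from the diff):

// Example: a [4, 6] global tensor sharded along dim 0 across 2 ranks.
auto* dist =
    static_cast<phi::distributed::DistTensor*>(self->tensor.impl().get());
// dist->dims()          -> [4, 6]  global shape (what Tensor.shape reports)
// dist->value().dims()  -> [2, 6]  this rank's local shard (what numpy() sees)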
......@@ -378,7 +378,7 @@ PyObject* tensor_properties_get_dist_attr(TensorObject* self, void* closure) {
#ifdef PADDLE_WITH_DISTRIBUTE
phi::distributed::DistTensor* dist_tensor =
static_cast<phi::distributed::DistTensor*>(self->tensor.impl().get());
return ToPyObject(dist_tensor->dist_attr().get());
return ToPyObject(&dist_tensor->dist_attr());
#else
RETURN_PY_NONE
#endif
......
......@@ -547,11 +547,10 @@ platform::Place CastPyArg2Place(PyObject* obj, ssize_t arg_pos) {
#ifdef PADDLE_WITH_DISTRIBUTE
using phi::distributed::TensorDistAttr;
std::shared_ptr<TensorDistAttr> CastPyArg2DistAttr(PyObject* obj,
ssize_t arg_pos) {
TensorDistAttr CastPyArg2DistAttr(PyObject* obj, ssize_t arg_pos) {
if (PyObject_IsInstance(
obj, reinterpret_cast<PyObject*>(g_tensor_dist_attr_pytype))) {
return ::pybind11::handle(obj).cast<std::shared_ptr<TensorDistAttr>>();
return ::pybind11::handle(obj).cast<TensorDistAttr>();
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"argument (position %d) must be "
......
......@@ -313,8 +313,8 @@ paddle::DataType CastPyArg2DataTypeDirectly(PyObject* obj,
ssize_t arg_pos);
#ifdef PADDLE_WITH_DISTRIBUTE
std::shared_ptr<phi::distributed::TensorDistAttr> CastPyArg2DistAttr(
PyObject* obj, ssize_t arg_pos);
phi::distributed::TensorDistAttr CastPyArg2DistAttr(PyObject* obj,
ssize_t arg_pos);
#endif
paddle::optional<paddle::Tensor> GetOptionalTensorFromArgs(
......
......@@ -1029,7 +1029,7 @@ void BindTensor(pybind11::module &m) { // NOLINT
py::class_<DistTensor>(m, "DistTensor")
.def(
"get_tensor",
[](DistTensor &self) { return self.mutable_value(); },
[](DistTensor &self) { return self.value(); },
py::return_value_policy::reference)
.def("numel",
[](DistTensor &self) -> int64_t { return self.value().numel(); });
......
......@@ -539,11 +539,9 @@ phi::distributed::DistTensor* SetKernelDistOutput(Tensor* out) {
if (out) {
// TODO(chenweihang): now all dist case are nullptr
if (out->impl() == nullptr) {
auto dense_t = std::make_shared<phi::DenseTensor>();
// TODO(chenweihang): polish code, dist_attr is null now
auto dist_attr = std::make_shared<phi::distributed::TensorDistAttr>();
auto dist_t = std::make_shared<phi::distributed::DistTensor>(
dense_t, phi::DenseTensorMeta(), dist_attr);
phi::DDim(), phi::distributed::TensorDistAttr());
out->set_impl(dist_t);
}
return static_cast<phi::distributed::DistTensor*>(out->impl().get());
......
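The kernel-output path above now starts from an "empty" DistTensor that carries only global dims and distributed attributes; the local value is created and written later. A sketch of that flow under assumed names, using the constructor introduced in this commit:

// 1. Create the output holder: global dims and dist_attr only, no storage yet.
auto dist_t = std::make_shared<phi::distributed::DistTensor>(
    phi::DDim(), phi::distributed::TensorDistAttr());
out->set_impl(dist_t);
// 2. InferMeta later records the global shape via MetaTensor::set_dims, which
//    forwards to DistTensor::set_dims (see meta_tensor.cc further below).
// 3. The generated kernel call writes into the local value through a mutable
//    pointer, e.g. const_cast<phi::DenseTensor*>(&dist_t->value()).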
......@@ -247,11 +247,11 @@ void CheckAndTrans2Contiguous(phi::DenseTensor* tensor) {
}
}
phi::DenseTensor TransformData(phi::DenseTensor* tensor,
phi::DenseTensor TransformData(const phi::DenseTensor& tensor,
const phi::TensorArgDef& target_args_def,
const TransformFlag& transform_flag,
bool is_stride_kernel) {
phi::DenseTensor out = *tensor;
phi::DenseTensor out = tensor;
bool trans_layout = false;
bool trans_dtype = false;
......@@ -259,11 +259,11 @@ phi::DenseTensor TransformData(phi::DenseTensor* tensor,
out = Trans2Contiguous(out);
}
if (NeedTransformLayout(tensor->layout(),
if (NeedTransformLayout(tensor.layout(),
target_args_def.layout,
tensor->place(),
tensor.place(),
transform_flag) &&
tensor->dims().size() != 1) {
tensor.dims().size() != 1) {
if (NeedTransform2Contiguous(false, out.meta().is_contiguous())) {
out = Trans2Contiguous(out);
}
......@@ -272,7 +272,7 @@ phi::DenseTensor TransformData(phi::DenseTensor* tensor,
}
if (NeedTransformDataType(
tensor->dtype(), target_args_def.dtype, transform_flag)) {
tensor.dtype(), target_args_def.dtype, transform_flag)) {
if (NeedTransform2Contiguous(false, out.meta().is_contiguous())) {
out = Trans2Contiguous(out);
}
......@@ -284,8 +284,14 @@ phi::DenseTensor TransformData(phi::DenseTensor* tensor,
out.place(), target_args_def.backend, transform_flag)) {
out = TransDataPlace(out, phi::TransToPhiPlace(target_args_def.backend));
if (!trans_layout && !trans_dtype &&
tensor->place().GetType() == AllocationType::GPUPINNED) {
tensor->ShareBufferWith(out);
tensor.place().GetType() == AllocationType::GPUPINNED) {
// Sharing the buffer on GPUPINNED place is a special case kept for historical
// reasons: it should not be implemented this way, but the performance of
// existing models depends on the inplace operation here and would regress if
// it were removed, so it is retained and requires a `const_cast`.
const_cast<phi::DenseTensor&>(tensor).ShareBufferWith(out);
}
}
return out;
......@@ -314,7 +320,7 @@ std::shared_ptr<phi::DenseTensor> PrepareData(
return std::static_pointer_cast<phi::DenseTensor>(tensor_in);
}
phi::DenseTensor out = TransformData(
&dense_tensor, target_args_def, transform_flag, is_stride_kernel);
dense_tensor, target_args_def, transform_flag, is_stride_kernel);
return std::make_shared<phi::DenseTensor>(std::move(out));
}
return nullptr;
......@@ -359,7 +365,7 @@ std::unique_ptr<std::vector<phi::DenseTensor>> PrepareData(
*std::dynamic_pointer_cast<phi::DenseTensor>(tensor_in));
} else {
pt_tensors->emplace_back(
TransformData((static_cast<phi::DenseTensor*>(tensor_in.get())),
TransformData(*(static_cast<phi::DenseTensor*>(tensor_in.get())),
target_args_def,
transform_flag,
is_stride_kernel));
......@@ -583,7 +589,7 @@ std::shared_ptr<phi::distributed::DistTensor> PrepareDataForDistTensor(
if (tensor_in) {
phi::distributed::DistTensor* dist_tensor =
static_cast<phi::distributed::DistTensor*>(tensor_in.get());
phi::DenseTensor& dense_tensor = *(dist_tensor->mutable_value());
const phi::DenseTensor& dense_tensor = dist_tensor->value();
if (!transform_flag.NeedTransform() || !dense_tensor.initialized() ||
(!NeedTransformPlace(
dense_tensor.place(), target_args_def.backend, transform_flag) &&
......@@ -598,15 +604,13 @@ std::shared_ptr<phi::distributed::DistTensor> PrepareDataForDistTensor(
return std::static_pointer_cast<phi::distributed::DistTensor>(tensor_in);
}
phi::DenseTensor out = TransformData(
&dense_tensor, target_args_def, transform_flag, is_stride_kernel);
dense_tensor, target_args_def, transform_flag, is_stride_kernel);
// TODO(chenweihang): The global meta in DistTensor is not changed,
// but the local meta in the DenseTensor may be changed, such as a layout
// change (NCHW->NHWC), so the new DistTensor's meta may not be unified.
VLOG(6) << "PrepareDataForDistTensor return transformed dist tensor";
return std::make_shared<phi::distributed::DistTensor>(
std::make_shared<phi::DenseTensor>(std::move(out)),
dist_tensor->meta(),
dist_tensor->dist_attr());
out, dist_tensor->dist_attr());
}
return nullptr;
}
......
......@@ -69,11 +69,11 @@ INPLACE_API_OUT_CREATION_TEMPLATE = """
"""
SINGLE_OUT_CREATION_TEMPLATE = """
auto dist_out = SetKernelDistOutput(&api_output);
auto dense_out = dist_out->mutable_value();
auto dense_out = const_cast<phi::DenseTensor*>(&dist_out->value());
"""
MULTI_SINGLE_OUT_CREATION_TEMPLATE = """
auto dist_out_{} = SetKernelDistOutput({});
auto dense_out_{} = dist_out_{}->mutable_value();
auto dense_out_{} = const_cast<phi::DenseTensor*>(&dist_out_{}->value());
"""
# TODO(chenweihang): support vector and tuple output later
......@@ -81,7 +81,7 @@ VECTOR_OUT_CREATION_TEMPLATE = """
"""
MULTI_VECTOR_OUT_CREATION_TEMPLATE = """
auto dist_out_{} = {}({}, {});
auto dense_out_{} = dist_out_{}->mutable_value();
auto dense_out_{} = const_cast<phi::DenseTensor*>(&dist_out_{}->value());
"""
TUPLE_OUT_CREATION_TEMPLATE = """
"""
......@@ -118,11 +118,11 @@ INPUT_RESHARD_TEMPLATE = """
# 5. PrepareData
SINGLE_PREPARE_DATA_TEMPLATE = """
auto dist_input_{} = PrepareDataForDistTensor({}, GetKernelInputArgDef(kernel.InputAt({}), kernel_backend), {}, kernel_result.is_stride_kernel);
auto input_{} = dist_input_{}->mutable_value();
auto input_{} = &dist_input_{}->value();
"""
INFER_META_SINGLE_INPUT_TEMPLATE = """
auto dist_input_{} = {}.impl();
auto input_{} = static_cast<phi::distributed::DistTensor*>(dist_input_{}.get())->mutable_value();
auto input_{} = &(static_cast<phi::distributed::DistTensor*>(dist_input_{}.get())->value());
"""
INFER_META_OPTIONAL_INPUT_TEMPLATE = """
paddle::optional<phi::TensorBase> input_{} = {} ? paddle::optional<phi::TensorBase>(*{}->impl()) : paddle::none;
......
......@@ -25,14 +25,14 @@ from dist_api_gen import DistForwardAPI
# 1. Create API Outputs
SINGLE_OUT_CREATION_TEMPLATE = """
auto dist_out = SetKernelDistOutput({});
auto dense_out = dist_out->mutable_value();
auto dense_out = const_cast<phi::DenseTensor*>(&dist_out->value());
"""
INPLACE_OUT_CREATION_TEMPLATE = """
*{} = {};
"""
MULTI_SINGLE_OUT_CREATION_TEMPLATE = """
auto dist_out_{} = SetKernelDistOutput({});
auto dense_out_{} = dist_out_{}->mutable_value();
auto dense_out_{} = const_cast<phi::DenseTensor*>(&dist_out_{}->value());
"""
......
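For context, a hypothetical expansion of the single-output and single-input templates above (API name, argument names, and the transform flag are placeholders, not generated code copied from the repo):

// Output creation: value() is const now, so the generated code const_casts once
// to obtain the DenseTensor* the kernel writes into.
auto dist_out = SetKernelDistOutput(&api_output);
auto dense_out = const_cast<phi::DenseTensor*>(&dist_out->value());

// Input preparation: read-only access to the prepared local DenseTensor.
auto dist_input_x = PrepareDataForDistTensor(
    x,
    GetKernelInputArgDef(kernel.InputAt(0), kernel_backend),
    {},  // transform flag placeholder
    kernel_result.is_stride_kernel);
auto input_x = &dist_input_x->value();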
......@@ -344,5 +344,9 @@ std::string TensorDistAttr::partial_status_string() const {
return partial_status_str;
}
bool TensorDistAttr::empty() const {
return process_mesh_.empty() || dims_mapping_.empty();
}
} // namespace distributed
} // namespace phi
......@@ -130,6 +130,8 @@ class TensorDistAttr {
void parse_from_string(const std::string& data);
bool empty() const;
private:
static std::vector<std::string> fields_;
ProcessMesh process_mesh_;
......
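The new empty() predicate replaces the null-shared_ptr checks in the eager-tensor constructors above. A short sketch of its intent (behavior inferred from the definition; not additional code from the commit):

// A default-constructed TensorDistAttr has neither a process mesh nor a dims
// mapping, so empty() is true and it now stands in for "no dist_attr passed"
// (previously a null std::shared_ptr<TensorDistAttr>).
phi::distributed::TensorDistAttr dist_attr;
if (!dist_attr.empty()) {
  // Taken only once both a process mesh and a dims mapping have been set,
  // e.g. for attributes parsed from the Python-side TensorDistAttr object.
}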
......@@ -17,45 +17,84 @@
namespace phi {
namespace distributed {
void* DistTensor::AllocateFrom(Allocator* allocator,
DataType dtype,
size_t requested_size,
bool fake_alloc) {
return value_->AllocateFrom(allocator, dtype, requested_size, fake_alloc);
inline void check_defined(const DistTensor& dist_tensor,
std::string method_hint) {
PADDLE_ENFORCE_EQ(
dist_tensor.defined(),
true,
phi::errors::Unimplemented(
"DistTensor is not defined yet when `%s` method is called.",
method_hint));
}
const Place& DistTensor::place() const {
// TODO(chenweihang): Reshard the input global value into local value
DistTensor::DistTensor(const phi::DenseTensor& global_value,
const TensorDistAttr& dist_attr)
: dims_(global_value.dims()), dist_attr_(dist_attr), value_(global_value) {}
DistTensor::DistTensor(const phi::DenseTensor& value,
const DDim& dims,
const TensorDistAttr& dist_attr)
: dims_(dims), dist_attr_(dist_attr), value_(value) {}
DistTensor::DistTensor(const DDim& dims, const TensorDistAttr& dist_attr)
: dims_(dims), dist_attr_(dist_attr) {}
void DistTensor::set_dims(const DDim& dims) {
PADDLE_ENFORCE_EQ(
value_ && value_->holder_,
true,
phi::errors::PreconditionNotMet(
"Tensor not initialized yet when DistTensor::place() is called."));
return value_->holder_->place();
this->initialized(),
false,
phi::errors::Unimplemented(
"DistTensor's set_dims method can only be used when the `value` "
"is not initialized (generally used in the InferMeta and "
"InferSPMD stages)."));
dims_ = dims;
}
int64_t DistTensor::numel() const {
if (meta_.is_scalar) {
return 1;
}
return product(meta_.dims);
check_defined(*this, "numel");
return value_.numel();
}
void DistTensor::set_meta(DenseTensorMeta&& meta) {
PADDLE_ENFORCE_EQ(meta_.valid(),
false,
phi::errors::InvalidArgument(
"Only when the original attribute of Tensor is "
"incomplete, can it be reset."));
meta_ = std::move(meta);
const DDim& DistTensor::local_dims() const {
check_defined(*this, "local_dims");
return value_.dims();
}
void DistTensor::set_meta(const DenseTensorMeta& meta) {
PADDLE_ENFORCE_EQ(
meta.valid(),
true,
phi::errors::InvalidArgument(
"Input meta is invalid, please check the meta attribute."));
meta_ = meta;
bool DistTensor::valid() const {
check_defined(*this, "valid");
return value_.valid();
}
bool DistTensor::defined() const { return value_.holder_ != nullptr; }
bool DistTensor::initialized() const {
return value_.holder_ != nullptr && value_.holder_->ptr();
}
DataType DistTensor::dtype() const {
check_defined(*this, "dtype");
return value_.dtype();
}
DataLayout DistTensor::layout() const {
check_defined(*this, "layout");
return value_.layout();
}
const Place& DistTensor::place() const {
check_defined(*this, "place");
return value_.holder_->place();
}
void* DistTensor::AllocateFrom(Allocator* allocator,
DataType dtype,
size_t requested_size,
bool fake_alloc) {
PADDLE_THROW(phi::errors::Unavailable(
"The DistTensor Cannot allocate memory directly and needs to perform "
"memory operations through its DenseTensor value."));
return nullptr;
}
} // namespace distributed
......
......@@ -15,116 +15,106 @@
#pragma once
#include <memory>
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
namespace phi {
class DenseTensorUtils;
namespace distributed {
class TensorDistAttr;
class DistTensor final
: public phi::TensorBase,
public phi::TypeInfoTraits<phi::TensorBase, DistTensor> {
public:
/// \brief Construct a dist tensor and allocate space.
/// \param a The allocator used to allocate space.
/// \param meta The meta data of dist tensor.
DistTensor(Allocator* a,
const DenseTensorMeta& meta,
const std::shared_ptr<TensorDistAttr>& dist_attr)
: meta_(meta), dist_attr_(dist_attr) {
value_ = std::make_unique<DenseTensor>(a, meta);
}
DistTensor(Allocator* a,
DenseTensorMeta&& meta,
const std::shared_ptr<TensorDistAttr>& dist_attr)
: meta_(std::move(meta)), dist_attr_(dist_attr) {
value_ = std::make_unique<DenseTensor>(a, meta);
}
DistTensor(const std::shared_ptr<phi::Allocation>& holder,
const DenseTensorMeta& meta,
const std::shared_ptr<TensorDistAttr>& dist_attr)
: meta_(meta), dist_attr_(dist_attr) {
value_ = std::make_unique<DenseTensor>(holder, meta);
}
DistTensor(const std::shared_ptr<phi::DenseTensor>& dense_tensor,
const DenseTensorMeta& meta,
const std::shared_ptr<TensorDistAttr>& dist_attr)
: meta_(meta), dist_attr_(dist_attr) {
value_ = std::make_unique<DenseTensor>(*dense_tensor);
}
/// \brief Construct a dist tensor based on a dense tensor.
/// \param global_value The global dense tensor of the current tensor.
/// \param dist_attr The distributed attributes of the current tensor.
DistTensor(const phi::DenseTensor& global_value,
const TensorDistAttr& dist_attr);
// TODO(chenweihang): Remove this constructor after added reshard impl
/// \brief Construct a dist tensor based on a dense tensor.
/// \param value The local dense tensor of the current tensor.
/// \param dims The global dimension of the current tensor.
/// \param dist_attr The distributed attributes of the current tensor.
DistTensor(const phi::DenseTensor& value,
const DDim& dims,
const TensorDistAttr& dist_attr);
/// \brief Construct an empty dist tensor (for infer spmd).
/// \param dims The global dimension of the current Tensor.
/// \param dist_attr The distributed attributes of the current tensor.
DistTensor(const DDim& dims, const TensorDistAttr& dist_attr);
/// \brief Destroy the tensor object and release exclusive resources.
~DistTensor() = default;
/// \brief Returns the name of the class for type traits.
/// \return The name of the class.
static const char* name() { return "DistTensor"; }
const DenseTensor& value() const { return *value_; }
/// \brief Returns the global dims of the dist tensor.
/// \return The global dims of the dist tensor.
const DDim& dims() const override { return dims_; }
/// \brief Set the global dims of the dist tensor.
/// \return void
void set_dims(const DDim& dims);
/// \brief Returns the dist attr of current dist tensor.
/// \return The TensorDistAttr's const reference
const TensorDistAttr& dist_attr() const { return dist_attr_; }
DenseTensor* mutable_value() { return value_.get(); }
/// \brief Returns the dense tensor value's const reference in dist tensor.
/// \return The DenseTensor value's const reference
const DenseTensor& value() const { return value_; }
const std::shared_ptr<TensorDistAttr>& dist_attr() const {
return dist_attr_;
}
/// \brief Returns the local dims of the dist tensor.
/// \return The local dims of the dist tensor.
const DDim& local_dims() const;
/// \brief Returns the number of elements contained in tensor.
/// \brief Returns the global number of elements contained in tensor.
/// \return The number of elements contained in tensor.
int64_t numel() const override;
/// \brief Returns the dims of the tensor.
/// \return The dims of the tensor.
const DDim& dims() const override { return meta_.dims; }
/// \brief Test whether the dense tensor value's storage is allocated.
/// \return Whether the dense tensor value's storage is allocated.
bool initialized() const override;
/// \brief Test whether the storage is allocated.
/// \return Whether the storage is allocated.
bool initialized() const override {
return value_ && value_->holder_ && value_->holder_->ptr();
}
bool defined() const { return value_ && value_->holder_; }
/// \brief Test whether the dense tensor value is defined.
/// \return Whether the dense tensor value is defined.
bool defined() const;
/// \brief Test whether the metadata is valid.
/// \return Whether the metadata is valid.
bool valid() const override { return meta_.valid(); }
/// \brief Allocate memory with requested size from allocator.
/// \return The mutable data pointer value of type T.
void* AllocateFrom(Allocator* allocator,
DataType dtype,
size_t requested_size = 0,
bool fake_alloc = false) override;
bool valid() const override;
/// \brief Returns the data type of the tensor.
/// \return The data type of the tensor.
DataType dtype() const override { return meta_.dtype; }
DataType dtype() const override;
/// \brief Returns the data layout of the tensor.
/// \return The data layout of the tensor.
DataLayout layout() const override { return meta_.layout; }
DataLayout layout() const override;
/// \brief Returns the data place of the tensor.
/// \return The data place of the tensor.
const Place& place() const override;
const DenseTensorMeta& meta() const noexcept { return meta_; }
/// \brief Sets the meta information of the tensor. Only when the original
/// attribute of Tensor is incomplete, can it be reset.
/// \param meta The meta information of the tensor.
void set_meta(DenseTensorMeta&& meta);
void set_meta(const DenseTensorMeta& meta);
/// \brief Allocate memory with requested size from allocator.
/// \return The mutable data pointer value of type T.
void* AllocateFrom(Allocator* allocator,
DataType dtype,
size_t requested_size = 0,
bool fake_alloc = false) override;
private:
friend class phi::DenseTensorUtils;
DenseTensorMeta meta_;
std::shared_ptr<TensorDistAttr> dist_attr_{nullptr};
std::unique_ptr<DenseTensor> value_{nullptr};
// The global dimensions(shape)
DDim dims_;
// The distributed attributes
TensorDistAttr dist_attr_;
// The local DenseTensor value
DenseTensor value_;
};
} // namespace distributed
......
......@@ -25,20 +25,19 @@
namespace phi {
namespace distributed {
bool RToSReshardFunction::IsSuitable(
const DistTensor& in,
const std::shared_ptr<TensorDistAttr>& out_dist_attr) {
bool RToSReshardFunction::IsSuitable(const DistTensor& in,
const TensorDistAttr& out_dist_attr) {
bool flag = true;
const auto& in_dist_attr = in.dist_attr();
const auto& in_dims_mapping = in_dist_attr->dims_mapping();
const auto& out_dims_mapping = out_dist_attr->dims_mapping();
const auto& in_dims_mapping = in_dist_attr.dims_mapping();
const auto& out_dims_mapping = out_dist_attr.dims_mapping();
flag &= IsDimsMappingReplicated(in_dims_mapping);
flag &= IsDimsMappingShard(out_dims_mapping);
const auto& in_process_mesh = in_dist_attr->process_mesh();
const auto& out_process_mesh = out_dist_attr->process_mesh();
const auto& in_process_mesh = in_dist_attr.process_mesh();
const auto& out_process_mesh = out_dist_attr.process_mesh();
flag &= (in_process_mesh.ndim() == 1);
flag &= (out_process_mesh.ndim() == 1);
......@@ -50,9 +49,9 @@ bool RToSReshardFunction::IsSuitable(
std::shared_ptr<DistTensor> RToSReshardFunction::Eval(
phi::DeviceContext* dev_ctx,
const DistTensor& in,
const std::shared_ptr<TensorDistAttr>& out_dist_attr) {
const auto& out_dims_mapping = out_dist_attr->dims_mapping();
const auto& out_process_mesh = out_dist_attr->process_mesh();
const TensorDistAttr& out_dist_attr) {
const auto& out_dims_mapping = out_dist_attr.dims_mapping();
const auto& out_process_mesh = out_dist_attr.process_mesh();
const DenseTensor& in_physical_tensor_cur_rank = in.value();
DenseTensor out_physical_tensor_cur_rank;
......@@ -92,9 +91,7 @@ std::shared_ptr<DistTensor> RToSReshardFunction::Eval(
<< out_physical_tensor_cur_rank.dims();
return std::make_shared<DistTensor>(
std::make_shared<DenseTensor>(out_physical_tensor_cur_rank),
out_physical_tensor_cur_rank.meta(),
out_dist_attr);
out_physical_tensor_cur_rank, in.dims(), out_dist_attr);
}
} // namespace distributed
......
......@@ -24,14 +24,13 @@ class RToSReshardFunction final : public ReshardFunction {
RToSReshardFunction() = default;
~RToSReshardFunction() = default;
bool IsSuitable(
const DistTensor& in,
const std::shared_ptr<TensorDistAttr>& out_dist_attr) override;
bool IsSuitable(const DistTensor& in,
const TensorDistAttr& out_dist_attr) override;
std::shared_ptr<DistTensor> Eval(
DeviceContext* dev_ctx,
const DistTensor& in,
const std::shared_ptr<TensorDistAttr>& out_dist_attr) override;
const TensorDistAttr& out_dist_attr) override;
};
} // namespace distributed
......
......@@ -28,14 +28,13 @@ class ReshardFunction {
ReshardFunction() = default;
virtual ~ReshardFunction() = default;
virtual bool IsSuitable(
const DistTensor& in,
const std::shared_ptr<TensorDistAttr>& out_dist_attr) = 0;
virtual bool IsSuitable(const DistTensor& in,
const TensorDistAttr& out_dist_attr) = 0;
virtual std::shared_ptr<DistTensor> Eval(
DeviceContext* dev_ctx,
const DistTensor& in,
const std::shared_ptr<TensorDistAttr>& out_dist_attr) = 0;
const TensorDistAttr& out_dist_attr) = 0;
};
} // namespace distributed
......
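A minimal usage sketch of the value-typed reshard interface above (the concrete function, variable names, and attribute values are assumed for illustration):

// Reshard a replicated DistTensor to a sharded one on the current rank.
phi::distributed::RToSReshardFunction reshard;
if (reshard.IsSuitable(dist_in, out_dist_attr)) {
  std::shared_ptr<phi::distributed::DistTensor> dist_out =
      reshard.Eval(dev_ctx, dist_in, out_dist_attr);
  // dist_out keeps the global dims of dist_in (see the r_to_s Eval above) and
  // adopts out_dist_attr; its value() holds this rank's local shard.
}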
......@@ -25,20 +25,19 @@
namespace phi {
namespace distributed {
bool SToRReshardFunction::IsSuitable(
const DistTensor& in,
const std::shared_ptr<TensorDistAttr>& out_dist_attr) {
bool SToRReshardFunction::IsSuitable(const DistTensor& in,
const TensorDistAttr& out_dist_attr) {
bool flag = true;
const auto& in_dist_attr = in.dist_attr();
const auto& in_dims_mapping = in_dist_attr->dims_mapping();
const auto& out_dims_mapping = out_dist_attr->dims_mapping();
const auto& in_dims_mapping = in_dist_attr.dims_mapping();
const auto& out_dims_mapping = out_dist_attr.dims_mapping();
flag &= IsDimsMappingShard(in_dims_mapping);
flag &= IsDimsMappingReplicated(out_dims_mapping);
const auto& in_process_mesh = in_dist_attr->process_mesh();
const auto& out_process_mesh = out_dist_attr->process_mesh();
const auto& in_process_mesh = in_dist_attr.process_mesh();
const auto& out_process_mesh = out_dist_attr.process_mesh();
flag &= (in_process_mesh.ndim() == 1);
flag &= (out_process_mesh.ndim() == 1);
......@@ -50,13 +49,13 @@ bool SToRReshardFunction::IsSuitable(
std::shared_ptr<DistTensor> SToRReshardFunction::Eval(
DeviceContext* dev_ctx,
const DistTensor& in,
const std::shared_ptr<TensorDistAttr>& out_dist_attr) {
const TensorDistAttr& out_dist_attr) {
// TODO(liyurui): Only support transfer shard(0) to replicate for now.
// Concat is needed when transfer shard(x) to replicate, will be supported
// later.
const DenseTensor& in_physical_tensor_cur_rank = in.value();
const auto& in_dist_attr = in.dist_attr();
const auto& in_process_mesh = in_dist_attr->process_mesh();
const auto& in_process_mesh = in_dist_attr.process_mesh();
const auto& in_process_ids = in_process_mesh.process_ids();
// Since the precondition ensure the out_process_ids is equal to the
......@@ -66,9 +65,7 @@ std::shared_ptr<DistTensor> SToRReshardFunction::Eval(
dev_ctx, in_physical_tensor_cur_rank, in_process_ids);
return std::make_shared<DistTensor>(
std::make_shared<DenseTensor>(out_all_gather),
out_all_gather.meta(),
out_dist_attr);
out_all_gather, out_all_gather.dims(), out_dist_attr);
}
} // namespace distributed
......
......@@ -23,14 +23,13 @@ class SToRReshardFunction final : public ReshardFunction {
SToRReshardFunction() = default;
~SToRReshardFunction() = default;
bool IsSuitable(
const DistTensor& in,
const std::shared_ptr<TensorDistAttr>& out_dist_attr) override;
bool IsSuitable(const DistTensor& in,
const TensorDistAttr& out_dist_attr) override;
std::shared_ptr<DistTensor> Eval(
DeviceContext* dev_ctx,
const DistTensor& in,
const std::shared_ptr<TensorDistAttr>& out_dist_attr) override;
const TensorDistAttr& out_dist_attr) override;
};
} // namespace distributed
......
......@@ -90,9 +90,7 @@ void MetaTensor::set_dims(const DDim& dims) {
->dims = dims;
#ifdef PADDLE_WITH_DISTRIBUTE
} else if (phi::distributed::DistTensor::classof(tensor_)) {
DenseTensorUtils::GetMutableMeta(
static_cast<distributed::DistTensor*>(tensor_))
->dims = dims;
static_cast<distributed::DistTensor*>(tensor_)->set_dims(dims);
#endif
} else {
PADDLE_THROW(phi::errors::Unimplemented(
......@@ -127,9 +125,7 @@ void MetaTensor::set_dtype(DataType dtype) {
->dtype = dtype;
#ifdef PADDLE_WITH_DISTRIBUTE
} else if (phi::distributed::DistTensor::classof(tensor_)) {
DenseTensorUtils::GetMutableMeta(
static_cast<distributed::DistTensor*>(tensor_))
->dtype = dtype;
// skip, DistTensor no need to set dtype
#endif
} else {
PADDLE_THROW(phi::errors::Unimplemented(
......@@ -163,9 +159,7 @@ void MetaTensor::set_layout(DataLayout layout) {
->layout = layout;
#ifdef PADDLE_WITH_DISTRIBUTE
} else if (phi::distributed::DistTensor::classof(tensor_)) {
DenseTensorUtils::GetMutableMeta(
static_cast<distributed::DistTensor*>(tensor_))
->layout = layout;
// skip, DistTensor no need to set layout
#endif
} else {
PADDLE_THROW(phi::errors::Unimplemented(
......
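With the hunks above, InferMeta only records the global dims on a DistTensor; dtype and layout are taken from the local DenseTensor value once it exists. A small sketch (the output pointer and dims are hypothetical):

phi::MetaTensor meta_out(dist_out);           // dist_out: DistTensor* output
meta_out.set_dims(phi::make_ddim({16, 32}));  // forwards to DistTensor::set_dims
meta_out.set_dtype(phi::DataType::FLOAT32);   // intentionally a no-op for DistTensor
meta_out.set_layout(phi::DataLayout::NCHW);   // likewise skipped (see above)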
......@@ -21,16 +21,9 @@ limitations under the License. */
#include "paddle/phi/core/sparse_csr_tensor.h"
#include "paddle/phi/core/tensor_array.h"
#include "paddle/phi/core/tensor_meta.h"
#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#endif
namespace phi {
// TODO(chenweihang): DenseTensorUtils has been abused during the development
// process, and now its semantics are incorrect. It can not only operate
// DenseTensors, but also other types of Tensors, requiring renaming or
// splitting
class DenseTensorUtils {
public:
static DenseTensorMeta* GetMutableMeta(DenseTensor* tensor) {
......@@ -45,12 +38,6 @@ class DenseTensorUtils {
return &(tensor->meta_);
}
#ifdef PADDLE_WITH_DISTRIBUTE
static DenseTensorMeta* GetMutableMeta(distributed::DistTensor* tensor) {
return &(tensor->meta_);
}
#endif
static const std::shared_ptr<phi::Allocation>& GetHolder(
const DenseTensor& tensor) {
return tensor.holder_;
......
......@@ -68,7 +68,7 @@ class TestReshardRToS:
else out_shape[self._shard] // 2 + 1
)
assert np.equal(out.shape, out_shape).all()
assert np.equal(out.numpy().shape, out_shape).all()
if __name__ == '__main__':
......
......@@ -14,13 +14,15 @@ limitations under the License. */
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#include <iostream>
#include "gtest/gtest.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "test/cpp/phi/core/allocator.h"
namespace phi {
namespace distributed {
namespace auto_parallel {
namespace tests {
TEST(dist_tensor, constructor) {
......@@ -32,27 +34,39 @@ TEST(dist_tensor, constructor) {
DDim dims({3, 4});
DenseTensorMeta meta(dtype, dims);
auto dist_attr = std::make_shared<TensorDistAttr>(phi::vectorize(dims));
DistTensor x1(alloc, meta, dist_attr);
EXPECT_TRUE(x1.defined());
EXPECT_TRUE(x1.initialized());
auto dist_attr = TensorDistAttr(phi::vectorize(dims));
DistTensor x2(alloc, DenseTensorMeta(dtype, dims), dist_attr);
EXPECT_TRUE(x2.defined());
EXPECT_TRUE(x2.initialized());
// copy construct
DenseTensor x1(alloc, meta);
DistTensor dist_x1(x1, dist_attr);
EXPECT_TRUE(dist_x1.defined());
EXPECT_TRUE(dist_x1.initialized());
EXPECT_TRUE(dist_x1.valid());
EXPECT_EQ(dist_x1.numel(), 12L);
EXPECT_EQ(dist_x1.local_dims()[0], 3L);
EXPECT_EQ(dist_x1.local_dims()[1], 4L);
DistTensor x3(x2.value().Holder(), meta, dist_attr);
EXPECT_TRUE(x3.defined());
EXPECT_TRUE(x3.initialized());
DenseTensor x2(alloc, meta);
DistTensor dist_x2(x2, dims, dist_attr);
EXPECT_TRUE(dist_x2.defined());
EXPECT_TRUE(dist_x2.initialized());
EXPECT_TRUE(dist_x2.valid());
auto a = std::make_shared<DenseTensor>(alloc, DenseTensorMeta(dtype, dims));
DistTensor x4(a, a->meta(), dist_attr);
EXPECT_TRUE(x4.defined());
EXPECT_TRUE(x4.initialized());
// empty construct
DistTensor dist_x3(dims, dist_attr);
EXPECT_TRUE(!dist_x3.defined());
EXPECT_TRUE(!dist_x3.initialized());
// allocate error test
bool caught_exception = false;
try {
dist_x3.AllocateFrom(alloc, phi::DataType::FLOAT32, 12L, false);
} catch (phi::EnforceNotMet& error) {
caught_exception = true;
EXPECT_NE(std::string(error.what()).find("Unavailable"), std::string::npos);
}
EXPECT_TRUE(caught_exception);
}
} // namespace tests
} // namespace auto_parallel
} // namespace distributed
} // namespace phi