Unverified commit ae2d8ba1, authored by Chen Weihang, committed by GitHub

[AutoParallel] Simplify DistTensor namespace path (#55593)

* simplify dist tensor namespace path

* fix tensor dist attr decl error
Parent: a3cf25e3
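Note: the two commit-message bullets above map directly onto the hunks below. DistTensor moves from phi::distributed::auto_parallel up into phi::distributed, so every qualified name gets one level shorter. A minimal standalone C++ sketch of the effect at call sites (the empty class body is a stand-in, not Paddle's real definition):

namespace phi {
namespace distributed {
class DistTensor {};  // previously phi::distributed::auto_parallel::DistTensor
}  // namespace distributed
}  // namespace phi

// Before this commit: using phi::distributed::auto_parallel::DistTensor;
using phi::distributed::DistTensor;  // after: one namespace level shorter

int main() {
  DistTensor t;  // unqualified use works the same either way
  (void)t;
  return 0;
}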
@@ -45,7 +45,7 @@ limitations under the License. */
 #ifdef PADDLE_WITH_DISTRIBUTE
 #include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
 #include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
-using phi::distributed::auto_parallel::DistTensor;
+using phi::distributed::DistTensor;
 using phi::distributed::auto_parallel::TensorDistAttr;
 #endif
@@ -801,8 +801,8 @@ static PyObject* tensor_method_get_underline_tensor(TensorObject* self,
     return ToPyObject(tensor);
   } else if (self->tensor.is_dist_tensor()) {
 #ifdef PADDLE_WITH_DISTRIBUTE
-    auto* tensor = static_cast<phi::distributed::auto_parallel::DistTensor*>(
-        self->tensor.impl().get());
+    auto* tensor =
+        static_cast<phi::distributed::DistTensor*>(self->tensor.impl().get());
     VLOG(6) << "dist tensor: " << tensor->defined();
     return ToPyObject(tensor);
 #else
@@ -164,9 +164,8 @@ PyObject* tensor_properties_get_dist_attr(TensorObject* self, void* closure) {
   EAGER_TRY
   if (self->tensor.is_dist_tensor()) {
 #ifdef PADDLE_WITH_DISTRIBUTE
-    phi::distributed::auto_parallel::DistTensor* dist_tensor =
-        static_cast<phi::distributed::auto_parallel::DistTensor*>(
-            self->tensor.impl().get());
+    phi::distributed::DistTensor* dist_tensor =
+        static_cast<phi::distributed::DistTensor*>(self->tensor.impl().get());
     return ToPyObject(dist_tensor->dist_attr().get());
 #else
     RETURN_PY_NONE
@@ -859,7 +859,7 @@ PyObject* ToPyObject(const phi::DenseTensor* value) {
 }
 #ifdef PADDLE_WITH_DISTRIBUTE
-PyObject* ToPyObject(const phi::distributed::auto_parallel::DistTensor* value) {
+PyObject* ToPyObject(const phi::distributed::DistTensor* value) {
   auto obj = ::pybind11::cast(value, py::return_value_policy::reference);
   obj.inc_ref();
   return obj.ptr();
@@ -113,7 +113,7 @@ PyObject* ToPyObject(const std::vector<std::vector<paddle::Tensor>>& value,
 PyObject* ToPyObject(const platform::Place& value);
 PyObject* ToPyObject(const phi::DenseTensor* value);
 #ifdef PADDLE_WITH_DISTRIBUTE
-PyObject* ToPyObject(const phi::distributed::auto_parallel::DistTensor* value);
+PyObject* ToPyObject(const phi::distributed::DistTensor* value);
 PyObject* ToPyObject(
     const phi::distributed::auto_parallel::TensorDistAttr* value);
 #endif
@@ -1025,7 +1025,7 @@ void BindTensor(pybind11::module &m) {  // NOLINT
 #endif
 #ifdef PADDLE_WITH_DISTRIBUTE
-  using phi::distributed::auto_parallel::DistTensor;
+  using phi::distributed::DistTensor;
   py::class_<DistTensor>(m, "DistTensor")
       .def(
           "get_tensor",
@@ -133,7 +133,7 @@ bool Tensor::is_dense_tensor() const {
 }
 bool Tensor::is_dist_tensor() const {
 #ifdef PADDLE_WITH_DISTRIBUTE
-  return phi::distributed::auto_parallel::DistTensor::classof(impl_.get());
+  return phi::distributed::DistTensor::classof(impl_.get());
 #else
   return false;
 #endif
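The is_dist_tensor() hunk above relies on DistTensor::classof(), an LLVM-style type check backed by Paddle's TypeInfoTraits (re-instantiated in the last hunk of this commit). A simplified sketch of the pattern, with assumed names rather than Paddle's actual type_info machinery:

#include <cstdint>
#include <memory>

struct TensorBase {
  virtual ~TensorBase() = default;
  virtual int8_t type_id() const = 0;  // each concrete tensor reports its id
};

struct DistTensorLike : TensorBase {
  static constexpr int8_t kTypeId = 2;
  int8_t type_id() const override { return kTypeId; }
  // Mirrors DistTensor::classof(impl_.get()) in the hunk above:
  // a cheap id comparison instead of dynamic_cast.
  static bool classof(const TensorBase* b) {
    return b != nullptr && b->type_id() == kTypeId;
  }
};

bool is_dist_tensor(const std::shared_ptr<TensorBase>& impl) {
  return DistTensorLike::classof(impl.get());
}

int main() {
  auto impl = std::make_shared<DistTensorLike>();
  return is_dist_tensor(impl) ? 0 : 1;  // returns 0: the check succeeds
}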
@@ -30,9 +30,7 @@ namespace phi {
 class DenseTensorUtils;
 namespace distributed {
-namespace auto_parallel {
 class DistTensor;
-}  // namespace auto_parallel
 }  // namespace distributed
 /// \brief The Dense tensor stores values in a contiguous sequential block
@@ -186,7 +184,7 @@ class DenseTensor : public TensorBase,
  private:
  friend class DenseTensorUtils;
- friend class phi::distributed::auto_parallel::DistTensor;
+ friend class phi::distributed::DistTensor;
 protected:
  DenseTensorMeta meta_;
@@ -16,7 +16,6 @@
 namespace phi {
 namespace distributed {
-namespace auto_parallel {
 void* DistTensor::AllocateFrom(Allocator* allocator,
                                DataType dtype,
@@ -59,6 +58,5 @@ void DistTensor::set_meta(const DenseTensorMeta& meta) {
   meta_ = meta;
 }
-}  // namespace auto_parallel
 }  // namespace distributed
 }  // namespace phi
@@ -18,11 +18,12 @@
 #include "paddle/phi/core/dense_tensor.h"
 namespace phi {
 namespace distributed {
-namespace auto_parallel {
+namespace auto_parallel {
+class TensorDistAttr;
+}
+using auto_parallel::TensorDistAttr;
 class DistTensor final
     : public phi::TensorBase,
@@ -125,6 +126,5 @@ class DistTensor final
   std::unique_ptr<DenseTensor> value_{nullptr};
 };
-}  // namespace auto_parallel
 }  // namespace distributed
 }  // namespace phi
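The two dist_tensor.h hunks above are the "fix tensor dist attr decl error" bullet: a forward declaration must be written in the class's real namespace, otherwise it declares a different type. A hedged illustration of the wrong and right forms (file layout assumed, not copied from Paddle):

// Wrong: this would declare a brand-new phi::distributed::TensorDistAttr,
// unrelated to the real phi::distributed::auto_parallel::TensorDistAttr,
// so later definitions and uses would clash.
//
// namespace phi { namespace distributed {
// class TensorDistAttr;  // mismatched declaration
// }}

// Right (as in the hunk above): declare it in the owning namespace,
// then lift the name one level with a using-declaration.
namespace phi {
namespace distributed {
namespace auto_parallel {
class TensorDistAttr;  // matches the definition in dist_attr.h
}  // namespace auto_parallel
using auto_parallel::TensorDistAttr;  // phi::distributed::TensorDistAttr now resolves
}  // namespace distributed
}  // namespace phi

// A pointer to the still-incomplete type is fine, so this compiles as-is:
phi::distributed::TensorDistAttr* g_attr = nullptr;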
@@ -56,8 +56,7 @@ template class TypeInfoTraits<phi::DeviceContext, CPUContext>;
 template class TypeInfoTraits<phi::DeviceContext, CustomContext>;
 #ifdef PADDLE_WITH_DISTRIBUTE
-template class TypeInfoTraits<phi::TensorBase,
-                              phi::distributed::auto_parallel::DistTensor>;
+template class TypeInfoTraits<phi::TensorBase, phi::distributed::DistTensor>;
 #endif
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) || \