Unverified commit ae2d8ba1, authored by Chen Weihang, committed by GitHub

[AutoParallel] Simplify DistTensor namespace path (#55593)

* simplify dist tensor namespace path

* fix tensor dist attr decl error
Parent a3cf25e3
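In plain terms, this commit moves DistTensor from the nested phi::distributed::auto_parallel namespace up into phi::distributed, so every qualified use of the class becomes one qualifier shorter. A minimal standalone sketch of the effect (the class body below is a stand-in, not Paddle's real definition):

// Stand-in illustration of the namespace simplification in this commit.
// Before: callers wrote phi::distributed::auto_parallel::DistTensor.
// After:  callers write phi::distributed::DistTensor.
namespace phi {
namespace distributed {
class DistTensor {};  // stand-in; the real class derives from phi::TensorBase
}  // namespace distributed
}  // namespace phi

int main() {
  phi::distributed::DistTensor t;  // the shorter qualified name now suffices
  (void)t;
  return 0;
}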
@@ -45,7 +45,7 @@ limitations under the License. */
 #ifdef PADDLE_WITH_DISTRIBUTE
 #include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
 #include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
-using phi::distributed::auto_parallel::DistTensor;
+using phi::distributed::DistTensor;
 using phi::distributed::auto_parallel::TensorDistAttr;
 #endif
...
@@ -801,8 +801,8 @@ static PyObject* tensor_method_get_underline_tensor(TensorObject* self,
     return ToPyObject(tensor);
   } else if (self->tensor.is_dist_tensor()) {
 #ifdef PADDLE_WITH_DISTRIBUTE
-    auto* tensor = static_cast<phi::distributed::auto_parallel::DistTensor*>(
-        self->tensor.impl().get());
+    auto* tensor =
+        static_cast<phi::distributed::DistTensor*>(self->tensor.impl().get());
     VLOG(6) << "dist tensor: " << tensor->defined();
     return ToPyObject(tensor);
 #else
...
@@ -164,9 +164,8 @@ PyObject* tensor_properties_get_dist_attr(TensorObject* self, void* closure) {
   EAGER_TRY
   if (self->tensor.is_dist_tensor()) {
 #ifdef PADDLE_WITH_DISTRIBUTE
-    phi::distributed::auto_parallel::DistTensor* dist_tensor =
-        static_cast<phi::distributed::auto_parallel::DistTensor*>(
-            self->tensor.impl().get());
+    phi::distributed::DistTensor* dist_tensor =
+        static_cast<phi::distributed::DistTensor*>(self->tensor.impl().get());
     return ToPyObject(dist_tensor->dist_attr().get());
 #else
     RETURN_PY_NONE
...
@@ -859,7 +859,7 @@ PyObject* ToPyObject(const phi::DenseTensor* value) {
 }
 #ifdef PADDLE_WITH_DISTRIBUTE
-PyObject* ToPyObject(const phi::distributed::auto_parallel::DistTensor* value) {
+PyObject* ToPyObject(const phi::distributed::DistTensor* value) {
   auto obj = ::pybind11::cast(value, py::return_value_policy::reference);
   obj.inc_ref();
   return obj.ptr();
...
@@ -113,7 +113,7 @@ PyObject* ToPyObject(const std::vector<std::vector<paddle::Tensor>>& value,
 PyObject* ToPyObject(const platform::Place& value);
 PyObject* ToPyObject(const phi::DenseTensor* value);
 #ifdef PADDLE_WITH_DISTRIBUTE
-PyObject* ToPyObject(const phi::distributed::auto_parallel::DistTensor* value);
+PyObject* ToPyObject(const phi::distributed::DistTensor* value);
 PyObject* ToPyObject(
     const phi::distributed::auto_parallel::TensorDistAttr* value);
 #endif
...
@@ -1025,7 +1025,7 @@ void BindTensor(pybind11::module &m) { // NOLINT
 #endif
 #ifdef PADDLE_WITH_DISTRIBUTE
-  using phi::distributed::auto_parallel::DistTensor;
+  using phi::distributed::DistTensor;
   py::class_<DistTensor>(m, "DistTensor")
       .def(
           "get_tensor",
...
@@ -133,7 +133,7 @@ bool Tensor::is_dense_tensor() const {
 }
 bool Tensor::is_dist_tensor() const {
 #ifdef PADDLE_WITH_DISTRIBUTE
-  return phi::distributed::auto_parallel::DistTensor::classof(impl_.get());
+  return phi::distributed::DistTensor::classof(impl_.get());
 #else
   return false;
 #endif
...
@@ -30,9 +30,7 @@ namespace phi {
 class DenseTensorUtils;
 namespace distributed {
-namespace auto_parallel {
 class DistTensor;
-} // namespace auto_parallel
 } // namespace distributed
 /// \brief The Dense tensor stores values in a contiguous sequential block
@@ -186,7 +184,7 @@ class DenseTensor : public TensorBase,
  private:
   friend class DenseTensorUtils;
-  friend class phi::distributed::auto_parallel::DistTensor;
+  friend class phi::distributed::DistTensor;
  protected:
   DenseTensorMeta meta_;
...
@@ -16,7 +16,6 @@
 namespace phi {
 namespace distributed {
-namespace auto_parallel {
 void* DistTensor::AllocateFrom(Allocator* allocator,
                                DataType dtype,
@@ -59,6 +58,5 @@ void DistTensor::set_meta(const DenseTensorMeta& meta) {
   meta_ = meta;
 }
-} // namespace auto_parallel
 } // namespace distributed
 } // namespace phi
@@ -18,11 +18,12 @@
 #include "paddle/phi/core/dense_tensor.h"
 namespace phi {
 namespace distributed {
-namespace auto_parallel {
+namespace auto_parallel {
 class TensorDistAttr;
+}
+using auto_parallel::TensorDistAttr;
 class DistTensor final
     : public phi::TensorBase,
@@ -125,6 +126,5 @@ class DistTensor final
   std::unique_ptr<DenseTensor> value_{nullptr};
 };
-} // namespace auto_parallel
 } // namespace distributed
 } // namespace phi
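The dist_tensor.h hunk above shows the companion pattern for TensorDistAttr: the class stays declared in the nested auto_parallel namespace, and a using-declaration re-exports the name one level up so callers can use the shorter spelling. A small self-contained sketch of that pattern, with the real class names but none of Paddle's actual headers:

// Forward-declare a class in a nested namespace, then re-export its name
// one level up with a using-declaration (illustrative sketch only).
namespace phi {
namespace distributed {

namespace auto_parallel {
class TensorDistAttr;  // still declared in the nested namespace
}  // namespace auto_parallel

using auto_parallel::TensorDistAttr;  // re-export the shorter name

}  // namespace distributed
}  // namespace phi

// Both spellings now refer to the same (incomplete) type.
phi::distributed::TensorDistAttr* a = nullptr;
phi::distributed::auto_parallel::TensorDistAttr* b = nullptr;

int main() { return (a == b) ? 0 : 1; }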
@@ -56,8 +56,7 @@ template class TypeInfoTraits<phi::DeviceContext, CPUContext>;
 template class TypeInfoTraits<phi::DeviceContext, CustomContext>;
 #ifdef PADDLE_WITH_DISTRIBUTE
-template class TypeInfoTraits<phi::TensorBase,
-                              phi::distributed::auto_parallel::DistTensor>;
+template class TypeInfoTraits<phi::TensorBase, phi::distributed::DistTensor>;
 #endif
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) || \
...