未验证 提交 5c95e5c8 编写于 作者: O OccupyMars2025 提交者: GitHub

take some notes about sparse API (#45720)

上级 23def396
...@@ -523,8 +523,8 @@ static PyObject* eager_api_sparse_coo_tensor(PyObject* self, ...@@ -523,8 +523,8 @@ static PyObject* eager_api_sparse_coo_tensor(PyObject* self,
std::dynamic_pointer_cast<phi::DenseTensor>(non_zero_indices.impl()); std::dynamic_pointer_cast<phi::DenseTensor>(non_zero_indices.impl());
auto dense_elements = auto dense_elements =
std::dynamic_pointer_cast<phi::DenseTensor>(non_zero_elements.impl()); std::dynamic_pointer_cast<phi::DenseTensor>(non_zero_elements.impl());
// TODO(zhangkaihuo): After create SparseTensor, call coalesced() to sort and // TODO(zhangkaihuo): After creating SparseCooTensor, call coalesced() to sort
// merge duplicate indices // and merge duplicate indices
std::shared_ptr<phi::SparseCooTensor> coo_tensor = std::shared_ptr<phi::SparseCooTensor> coo_tensor =
std::make_shared<phi::SparseCooTensor>( std::make_shared<phi::SparseCooTensor>(
*dense_indices, *dense_elements, phi::make_ddim(dense_shape)); *dense_indices, *dense_elements, phi::make_ddim(dense_shape));
...@@ -537,7 +537,7 @@ static PyObject* eager_api_sparse_coo_tensor(PyObject* self, ...@@ -537,7 +537,7 @@ static PyObject* eager_api_sparse_coo_tensor(PyObject* self,
autograd_meta->SetStopGradient(static_cast<bool>(stop_gradient)); autograd_meta->SetStopGradient(static_cast<bool>(stop_gradient));
if (!autograd_meta->GetMutableGradNode()) { if (!autograd_meta->GetMutableGradNode()) {
VLOG(3) << "Tensor(" << name VLOG(3) << "Tensor(" << name
<< ") have not GradNode, add GradNodeAccumulation for it."; << ") doesn't have GradNode, add GradNodeAccumulation to it.";
autograd_meta->SetGradNode( autograd_meta->SetGradNode(
std::make_shared<egr::GradNodeAccumulation>(autograd_meta)); std::make_shared<egr::GradNodeAccumulation>(autograd_meta));
} }
......
...@@ -92,7 +92,7 @@ int TensorDtype2NumpyDtype(phi::DataType dtype) { ...@@ -92,7 +92,7 @@ int TensorDtype2NumpyDtype(phi::DataType dtype) {
} }
bool PyObject_CheckLongOrConvertToLong(PyObject** obj) { bool PyObject_CheckLongOrConvertToLong(PyObject** obj) {
if ((PyLong_Check(*obj) && !PyBool_Check(*obj))) { if (PyLong_Check(*obj) && !PyBool_Check(*obj)) {
return true; return true;
} }
...@@ -129,7 +129,7 @@ bool PyObject_CheckStr(PyObject* obj) { return PyUnicode_Check(obj); } ...@@ -129,7 +129,7 @@ bool PyObject_CheckStr(PyObject* obj) { return PyUnicode_Check(obj); }
bool CastPyArg2AttrBoolean(PyObject* obj, ssize_t arg_pos) { bool CastPyArg2AttrBoolean(PyObject* obj, ssize_t arg_pos) {
if (obj == Py_None) { if (obj == Py_None) {
return false; // To be compatible with QA integration testing. Some return false; // To be compatible with QA integration testing. Some
// test case pass in None. // test cases pass in None.
} else if (obj == Py_True) { } else if (obj == Py_True) {
return true; return true;
} else if (obj == Py_False) { } else if (obj == Py_False) {
...@@ -305,7 +305,7 @@ std::vector<int> CastPyArg2VectorOfInt(PyObject* obj, size_t arg_pos) { ...@@ -305,7 +305,7 @@ std::vector<int> CastPyArg2VectorOfInt(PyObject* obj, size_t arg_pos) {
Py_ssize_t len = PyList_Size(obj); Py_ssize_t len = PyList_Size(obj);
PyObject* item = nullptr; PyObject* item = nullptr;
for (Py_ssize_t i = 0; i < len; i++) { for (Py_ssize_t i = 0; i < len; i++) {
item = PyList_GetItem(obj, i); item = PyList_GET_ITEM(obj, i);
if (PyObject_CheckLongOrConvertToLong(&item)) { if (PyObject_CheckLongOrConvertToLong(&item)) {
result.emplace_back(static_cast<int>(PyLong_AsLong(item))); result.emplace_back(static_cast<int>(PyLong_AsLong(item)));
} else { } else {
...@@ -321,13 +321,13 @@ std::vector<int> CastPyArg2VectorOfInt(PyObject* obj, size_t arg_pos) { ...@@ -321,13 +321,13 @@ std::vector<int> CastPyArg2VectorOfInt(PyObject* obj, size_t arg_pos) {
Py_ssize_t len = PyTuple_Size(obj); Py_ssize_t len = PyTuple_Size(obj);
PyObject* item = nullptr; PyObject* item = nullptr;
for (Py_ssize_t i = 0; i < len; i++) { for (Py_ssize_t i = 0; i < len; i++) {
item = PyTuple_GetItem(obj, i); item = PyTuple_GET_ITEM(obj, i);
if (PyObject_CheckLongOrConvertToLong(&item)) { if (PyObject_CheckLongOrConvertToLong(&item)) {
result.emplace_back(static_cast<int>(PyLong_AsLong(item))); result.emplace_back(static_cast<int>(PyLong_AsLong(item)));
} else { } else {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
"argument (position %d) must be " "argument (position %d) must be "
"list of bool, but got %s at pos %d", "list of int, but got %s at pos %d",
arg_pos + 1, arg_pos + 1,
reinterpret_cast<PyTypeObject*>(item->ob_type)->tp_name, reinterpret_cast<PyTypeObject*>(item->ob_type)->tp_name,
i)); i));
......
...@@ -574,7 +574,7 @@ class PADDLE_API Tensor final { ...@@ -574,7 +574,7 @@ class PADDLE_API Tensor final {
* unified to Tensor, but Tensor itself is heterogeneous. * unified to Tensor, but Tensor itself is heterogeneous.
* *
* Tensor can generally be represented by void* and size_t, place. * Tensor can generally be represented by void* and size_t, place.
* This is suitable for most scenarios including CPU, GPU, HIP, CPU, etc., * This is suitable for most scenarios including CPU, GPU, HIP, NPU, etc.,
* but there are a few cases where this definition cannot be described, * but there are a few cases where this definition cannot be described,
* such as the Tensor representation in third-party lib such as Metal, * such as the Tensor representation in third-party lib such as Metal,
* OpenCL, etc., as well as some special Tensor implementations, including * OpenCL, etc., as well as some special Tensor implementations, including
......
...@@ -29,7 +29,7 @@ namespace phi { ...@@ -29,7 +29,7 @@ namespace phi {
class DenseTensorUtils; class DenseTensorUtils;
/// \brief The Dense tensor store values in a contiguous sequential block /// \brief The Dense tensor stores values in a contiguous sequential block
/// of memory where all values are represented. Tensors or multi-dimensional /// of memory where all values are represented. Tensors or multi-dimensional
/// arrays are used in math operators. /// arrays are used in math operators.
/// During the entire life cycle of a DenseTensor, its device type and key /// During the entire life cycle of a DenseTensor, its device type and key
......
...@@ -55,12 +55,12 @@ class TensorBase { ...@@ -55,12 +55,12 @@ class TensorBase {
virtual bool valid() const = 0; virtual bool valid() const = 0;
/// \brief Test whether the storage is allocated. /// \brief Test whether the storage is allocated.
/// return Whether the storage is allocated. /// \return Whether the storage is allocated.
virtual bool initialized() const = 0; virtual bool initialized() const = 0;
// TODO(Aurelius84): This interface is under intermediate state now. // TODO(Aurelius84): This interface is under intermediate state now.
// We will remove DataType argument in the future. Please DO NOT // We will remove DataType argument in the future. Please DO NOT
// rely on Datatype to much when design and implement other feature. // rely on Datatype too much when designing and implementing other features.
/// \brief Allocate memory with requested size from allocator. /// \brief Allocate memory with requested size from allocator.
/// \return The mutable data pointer value of type T. /// \return The mutable data pointer value of type T.
...@@ -70,7 +70,7 @@ class TensorBase { ...@@ -70,7 +70,7 @@ class TensorBase {
/// \brief Return the type information of the derived class to support /// \brief Return the type information of the derived class to support
/// safely downcast in non-rtti environment. /// safely downcast in non-rtti environment.
/// return The type information of the derived class. /// \return The type information of the derived class.
TypeInfo<TensorBase> type_info() const { return type_info_; } TypeInfo<TensorBase> type_info() const { return type_info_; }
private: private:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册