未验证 提交 95fbbc5b 编写于 作者: Z zhangkaihuo 提交者: GitHub

Call sparse op from python (#40608)

* call sparse api from python
上级 a8e5c9be
......@@ -730,7 +730,7 @@ def GenerateNodeCreationCodes(
else:
# Tuple api_result
if IsPlainTensorType(rtype):
output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&api_result[{pos}]);"
output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&std::get<{pos}>(api_result));"
else:
assert IsVectorTensorType(rtype)
output_autograd_meta = f" std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&api_result[{pos}]);\n"
......@@ -767,8 +767,11 @@ def GenerateNodeCreationCodes(
else:
set_tensor_wrappers = f" grad_node->SetTensorWrapper{name}({name}, true);"
else:
if IsVectorTensorType(atype):
tw_name = f"api_result[{pos}]"
if num_fwd_outputs > 1:
# Aligned with forward output position
assert name in forward_outputs_position_map.keys()
fwd_output_pos = forward_outputs_position_map[name][1]
tw_name = f"std::get<{fwd_output_pos}>(api_result)"
else:
tw_name = f"api_result"
......@@ -805,8 +808,8 @@ def GenerateNodeCreationCodes(
set_retain_grad = f" egr::EagerUtils::CheckAndRetainGrad(api_result);"
set_grad_in_meta = f" grad_node->SetGradInMeta(api_result, {pos});"
else:
set_retain_grad = f" egr::EagerUtils::CheckAndRetainGrad(api_result[{pos}]);"
set_grad_in_meta = f" grad_node->SetGradInMeta(api_result[{pos}], {pos});"
set_retain_grad = f" egr::EagerUtils::CheckAndRetainGrad(std::get<{pos}>(api_result));"
set_grad_in_meta = f" grad_node->SetGradInMeta(std::get<{pos}>(api_result), {pos});"
set_out_rank_list.append(set_out_rank)
set_history_list.append(set_history)
......@@ -934,7 +937,7 @@ def GenerateForwardDefinition(fwd_api_name, bwd_api_name,
returns_list[0] = f"api_result"
else:
# Tuple api_result
returns_list[pos] = f"api_result[{pos}]"
returns_list[pos] = f"std::get<{pos}>(api_result)"
if IsPlainTensorType(rtype):
returns_type_list[pos] = "paddle::experimental::Tensor"
......@@ -1084,7 +1087,7 @@ def GenerateNodeCCFile(filepath, node_definition_str):
#include "paddle/fluid/eager/api/generated/eager_generated/backwards/nodes.h"
#include "paddle/fluid/eager/to_static/run_program_op_node.h"
#include "paddle/phi/api/include/sparse_api.h"
#include "paddle/phi/api/backward/sparse_bw_api.h"
"""
file_contents += node_definition_str
with open(filepath, 'a') as f:
......
......@@ -337,7 +337,7 @@ class PythonCSingleFunctionGenerator:
"paddle::experimental::", namespace, forward_api_name)
else:
fwd_function_name = FUNCTION_NAME_TEMPLATE.format(
"", namespace, GetForwardFunctionName(forward_api_name))
"::", namespace, GetForwardFunctionName(forward_api_name))
# Generate Record Event for performance profiling
pythonc_record_event_str = RECORD_EVENT_TEMPLATE.format(
......
......@@ -36,6 +36,8 @@ limitations under the License. */
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/sparse_csr_tensor.h"
namespace paddle {
namespace pybind {
......@@ -718,6 +720,98 @@ static PyObject* set_grad_type(TensorObject* self, PyObject* args,
EAGER_CATCH_AND_THROW_RETURN_NULL
}
static PyObject* tensor_method_get_non_zero_indices(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
EAGER_TRY
PADDLE_ENFORCE(self->tensor.is_sparse_coo_tensor(),
paddle::platform::errors::Fatal(
"this method is only effective for SparseCooTensor"));
auto sparse_coo_tensor =
std::dynamic_pointer_cast<phi::SparseCooTensor>(self->tensor.impl());
paddle::experimental::Tensor tensor(std::make_shared<phi::DenseTensor>(
sparse_coo_tensor->non_zero_indices()));
return ToPyObject(tensor);
EAGER_CATCH_AND_THROW_RETURN_NULL
}
static PyObject* tensor_method_get_non_zero_elements(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
EAGER_TRY
PADDLE_ENFORCE(
self->tensor.is_sparse_coo_tensor() ||
self->tensor.is_sparse_csr_tensor(),
paddle::platform::errors::Fatal("this method is only effective for "
"SparseCooTensor or SparseCsrTensor"));
if (self->tensor.is_sparse_coo_tensor()) {
auto sparse_coo_tensor =
std::dynamic_pointer_cast<phi::SparseCooTensor>(self->tensor.impl());
paddle::experimental::Tensor tensor(std::make_shared<phi::DenseTensor>(
sparse_coo_tensor->non_zero_elements()));
return ToPyObject(tensor);
} else {
auto sparse_csr_tensor =
std::dynamic_pointer_cast<phi::SparseCsrTensor>(self->tensor.impl());
paddle::experimental::Tensor tensor(std::make_shared<phi::DenseTensor>(
sparse_csr_tensor->non_zero_elements()));
return ToPyObject(tensor);
}
EAGER_CATCH_AND_THROW_RETURN_NULL
}
static PyObject* tensor_method_get_non_zero_crows(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
EAGER_TRY
PADDLE_ENFORCE(self->tensor.is_sparse_csr_tensor(),
paddle::platform::errors::Fatal(
"this method is only effective for SparseCsrTensor"));
auto sparse_csr_tensor =
std::dynamic_pointer_cast<phi::SparseCsrTensor>(self->tensor.impl());
paddle::experimental::Tensor tensor(
std::make_shared<phi::DenseTensor>(sparse_csr_tensor->non_zero_crows()));
return ToPyObject(tensor);
EAGER_CATCH_AND_THROW_RETURN_NULL
}
static PyObject* tensor_method_get_non_zero_cols(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
EAGER_TRY
PADDLE_ENFORCE(self->tensor.is_sparse_csr_tensor(),
paddle::platform::errors::Fatal(
"this method is only effective for SparseCsrTensor"));
auto sparse_csr_tensor =
std::dynamic_pointer_cast<phi::SparseCsrTensor>(self->tensor.impl());
paddle::experimental::Tensor tensor(
std::make_shared<phi::DenseTensor>(sparse_csr_tensor->non_zero_cols()));
return ToPyObject(tensor);
EAGER_CATCH_AND_THROW_RETURN_NULL
}
static PyObject* tensor_method_is_sparse(TensorObject* self, PyObject* args,
PyObject* kwargs) {
EAGER_TRY
return ToPyObject(self->tensor.is_sparse_coo_tensor() ||
self->tensor.is_sparse_csr_tensor());
EAGER_CATCH_AND_THROW_RETURN_NULL
}
static PyObject* tensor_method_is_sparse_coo(TensorObject* self, PyObject* args,
PyObject* kwargs) {
EAGER_TRY
return ToPyObject(self->tensor.is_sparse_coo_tensor());
EAGER_CATCH_AND_THROW_RETURN_NULL
}
static PyObject* tensor_method_is_sparse_csr(TensorObject* self, PyObject* args,
PyObject* kwargs) {
EAGER_TRY
return ToPyObject(self->tensor.is_sparse_csr_tensor());
EAGER_CATCH_AND_THROW_RETURN_NULL
}
static PyObject* tensor__inplace_version(TensorObject* self, PyObject* args,
PyObject* kwargs) {
EAGER_TRY
......@@ -775,6 +869,26 @@ PyMethodDef variable_methods[] = {
METH_VARARGS | METH_KEYWORDS, NULL},
{"_set_grad_type", (PyCFunction)(void (*)(void))set_grad_type,
METH_VARARGS | METH_KEYWORDS, NULL},
/***the method of sparse tensor****/
{"non_zero_indices",
(PyCFunction)(void (*)(void))tensor_method_get_non_zero_indices,
METH_VARARGS | METH_KEYWORDS, NULL},
{"non_zero_elements",
(PyCFunction)(void (*)(void))tensor_method_get_non_zero_elements,
METH_VARARGS | METH_KEYWORDS, NULL},
{"non_zero_crows",
(PyCFunction)(void (*)(void))tensor_method_get_non_zero_crows,
METH_VARARGS | METH_KEYWORDS, NULL},
{"non_zero_cols",
(PyCFunction)(void (*)(void))tensor_method_get_non_zero_cols,
METH_VARARGS | METH_KEYWORDS, NULL},
{"is_sparse", (PyCFunction)(void (*)(void))tensor_method_is_sparse,
METH_VARARGS | METH_KEYWORDS, NULL},
{"is_sparse_coo", (PyCFunction)(void (*)(void))tensor_method_is_sparse_coo,
METH_VARARGS | METH_KEYWORDS, NULL},
{"is_sparse_csr", (PyCFunction)(void (*)(void))tensor_method_is_sparse_csr,
METH_VARARGS | METH_KEYWORDS, NULL},
/***the method of sparse tensor****/
{"_inplace_version", (PyCFunction)(void (*)(void))tensor__inplace_version,
METH_VARARGS | METH_KEYWORDS, NULL},
{NULL, NULL, 0, NULL}};
......
......@@ -225,6 +225,22 @@ class PADDLE_API Tensor final {
*/
bool is_selected_rows() const;
/**
* @brief Determine whether tensor is SparseCooTensor
*
* @return true
* @return false
*/
bool is_sparse_coo_tensor() const;
/**
* @brief Determine whether tensor is SparseCsrTensor
*
* @return true
* @return false
*/
bool is_sparse_csr_tensor() const;
/* Part 3: Device and Backend methods */
/**
......
......@@ -25,25 +25,24 @@ namespace paddle {
namespace experimental {
namespace sparse {
Tensor to_sparse_coo_impl(const Tensor& x,
Backend backend,
const int64_t sparse_dim) {
Tensor to_sparse_coo_impl(const Tensor& x, const int64_t sparse_dim) {
if (x.layout() == phi::DataLayout::SPARSE_COO) {
return x;
}
// 1. Get kernel signature and kernel
auto kernel_key_set = ParseKernelKeyByInputArgs(x);
kernel_key_set.backend_set = kernel_key_set.backend_set | BackendSet(backend);
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
std::string kernel_name = "dense_to_sparse_coo";
if (x.layout() == phi::DataLayout::SPARSE_CSR) {
kernel_name = "sparse_csr_to_coo";
}
auto kernel_key_set = ParseKernelKeyByInputArgs(x);
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
kernel_name, kernel_key);
VLOG(6) << "to API kernel key: " << kernel_key;
VLOG(6) << "add API kernel key: " << kernel_key;
VLOG(6) << "to API kernel: " << kernel;
// 2. Get Device Context
......@@ -62,18 +61,18 @@ Tensor to_sparse_coo_impl(const Tensor& x,
// 4. InferMeta
auto indices_meta =
phi::DenseTensorMeta(phi::DataType::INT64, {-1}, phi::DataLayout::NCHW);
auto elements_meta = phi::DenseTensorMeta(x.dtype(), {-1}, x.layout());
phi::DenseTensorMeta(phi::DataType::INT64, {1}, phi::DataLayout::NCHW);
auto elements_meta = phi::DenseTensorMeta(x.dtype(), {1}, x.layout());
// 5. Prepare outputs
// create empty SparseCooTensor
phi::DenseTensor non_zero_indices(
phi::make_intrusive<paddle::experimental::SharedStorage>(
phi::TransToPhiPlace(backend)),
phi::TransToPhiPlace(kernel_key.backend())),
std::move(indices_meta));
phi::DenseTensor non_zero_elements(
phi::make_intrusive<paddle::experimental::SharedStorage>(
phi::TransToPhiPlace(backend)),
phi::TransToPhiPlace(kernel_key.backend())),
std::move(elements_meta));
auto coo = std::make_shared<phi::SparseCooTensor>(
non_zero_indices, non_zero_elements, x.dims());
......@@ -88,23 +87,23 @@ Tensor to_sparse_coo_impl(const Tensor& x,
return out;
}
Tensor to_sparse_csr_impl(const Tensor& x, Backend backend) {
Tensor to_sparse_csr_impl(const Tensor& x) {
if (x.layout() == phi::DataLayout::SPARSE_CSR) {
return x;
}
// 1. Get kernel signature and kernel
auto kernel_key_set = ParseKernelKeyByInputArgs(x);
kernel_key_set.backend_set = kernel_key_set.backend_set | BackendSet(backend);
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
std::string kernel_name = "dense_to_sparse_csr";
if (x.layout() == phi::DataLayout::SPARSE_COO) {
kernel_name = "sparse_coo_to_csr";
}
auto kernel_key_set = ParseKernelKeyByInputArgs(x);
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
kernel_name, kernel_key);
VLOG(6) << "to API kernel key: " << kernel_key;
VLOG(6) << "add API kernel key: " << kernel_key;
VLOG(6) << "to API kernel: " << kernel;
// 2. Get Device Context
......@@ -122,24 +121,24 @@ Tensor to_sparse_csr_impl(const Tensor& x, Backend backend) {
// 4. InferMeta
auto crows_meta =
phi::DenseTensorMeta(phi::DataType::INT64, {-1}, phi::DataLayout::NCHW);
phi::DenseTensorMeta(phi::DataType::INT64, {1}, phi::DataLayout::NCHW);
auto cols_meta =
phi::DenseTensorMeta(phi::DataType::INT64, {-1}, phi::DataLayout::NCHW);
auto elements_meta = phi::DenseTensorMeta(x.dtype(), {-1}, x.layout());
phi::DenseTensorMeta(phi::DataType::INT64, {1}, phi::DataLayout::NCHW);
auto elements_meta = phi::DenseTensorMeta(x.dtype(), {1}, x.layout());
// 5. Prepare outputs
// create empty SparseCooTensor
phi::DenseTensor non_zero_crows(
phi::make_intrusive<paddle::experimental::SharedStorage>(
phi::TransToPhiPlace(backend)),
phi::TransToPhiPlace(kernel_key.backend())),
std::move(crows_meta));
phi::DenseTensor non_zero_cols(
phi::make_intrusive<paddle::experimental::SharedStorage>(
phi::TransToPhiPlace(backend)),
phi::TransToPhiPlace(kernel_key.backend())),
std::move(cols_meta));
phi::DenseTensor non_zero_elements(
phi::make_intrusive<paddle::experimental::SharedStorage>(
phi::TransToPhiPlace(backend)),
phi::TransToPhiPlace(kernel_key.backend())),
std::move(elements_meta));
auto csr = std::make_shared<phi::SparseCsrTensor>(
non_zero_crows, non_zero_cols, non_zero_elements, x.dims());
......@@ -154,24 +153,25 @@ Tensor to_sparse_csr_impl(const Tensor& x, Backend backend) {
return out;
}
Tensor to_dense_impl(const Tensor& x, Backend backend) {
Tensor to_dense_impl(const Tensor& x) {
if (x.layout() != phi::DataLayout::SPARSE_CSR &&
x.layout() != phi::DataLayout::SPARSE_COO) {
return x;
}
// 1. Get kernel signature and kernel
auto kernel_key_set = ParseKernelKeyByInputArgs(x);
kernel_key_set.backend_set = kernel_key_set.backend_set | BackendSet(backend);
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
std::string kernel_name = "sparse_coo_to_dense";
if (x.layout() == phi::DataLayout::SPARSE_CSR) {
kernel_name = "sparse_csr_to_dense";
}
auto kernel_key_set = ParseKernelKeyByInputArgs(x);
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
kernel_name, kernel_key);
VLOG(6) << "to API kernel key: " << kernel_key;
VLOG(6) << "add API kernel key: " << kernel_key;
VLOG(6) << "to API kernel: " << kernel;
// 2. Get Device Context
......@@ -194,7 +194,7 @@ Tensor to_dense_impl(const Tensor& x, Backend backend) {
// create empty SparseCooTensor
auto dense_out = std::make_shared<phi::DenseTensor>(
phi::make_intrusive<paddle::experimental::SharedStorage>(
phi::TransToPhiPlace(backend)),
phi::TransToPhiPlace(kernel_key.backend())),
std::move(dense_meta));
kernel_context.EmplaceBackOutput(dense_out.get());
......
......@@ -21,13 +21,11 @@ namespace paddle {
namespace experimental {
namespace sparse {
Tensor to_dense_impl(const Tensor& x, Backend backend);
Tensor to_dense_impl(const Tensor& x);
Tensor to_sparse_coo_impl(const Tensor& x,
Backend backend,
const int64_t sparse_dim);
Tensor to_sparse_coo_impl(const Tensor& x, const int64_t sparse_dim);
Tensor to_sparse_csr_impl(const Tensor& x, Backend backend);
Tensor to_sparse_csr_impl(const Tensor& x);
} // namespace sparse
} // namespace experimental
......
......@@ -25,6 +25,8 @@ limitations under the License. */
#include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/selected_rows.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/sparse_csr_tensor.h"
#include "paddle/phi/core/tensor_base.h"
#include "paddle/phi/core/tensor_meta.h"
#include "paddle/phi/core/tensor_utils.h"
......@@ -132,6 +134,12 @@ bool Tensor::is_dense_tensor() const {
bool Tensor::is_selected_rows() const {
return phi::SelectedRows::classof(impl_.get());
}
bool Tensor::is_sparse_coo_tensor() const {
return phi::SparseCooTensor::classof(impl_.get());
}
bool Tensor::is_sparse_csr_tensor() const {
return phi::SparseCsrTensor::classof(impl_.get());
}
/* Part 3: Device and Backend methods */
PlaceType Tensor::place() const {
......
......@@ -53,8 +53,7 @@ TEST(API, to_sparse_coo) {
// 1. test dense_to_sparse_coo
paddle::experimental::Tensor x(dense_x);
auto out = paddle::experimental::sparse::to_sparse_coo(
x, phi::Backend::CPU, sparse_dim);
auto out = paddle::experimental::sparse::to_sparse_coo(x, sparse_dim);
auto coo = std::dynamic_pointer_cast<phi::SparseCooTensor>(out.impl());
ASSERT_EQ(coo->nnz(), non_zero_num);
int cmp_indices = memcmp(coo->non_zero_indices().data<int64_t>(),
......@@ -91,8 +90,7 @@ TEST(API, to_sparse_coo) {
auto csr =
std::make_shared<phi::SparseCsrTensor>(crows, cols, values, dense_dims);
paddle::experimental::Tensor csr_x(csr);
auto out2 = paddle::experimental::sparse::to_sparse_coo(
csr_x, phi::Backend::CPU, sparse_dim);
auto out2 = paddle::experimental::sparse::to_sparse_coo(csr_x, sparse_dim);
auto coo2 = std::dynamic_pointer_cast<phi::SparseCooTensor>(out.impl());
ASSERT_EQ(coo2->nnz(), non_zero_num);
......@@ -132,7 +130,7 @@ TEST(API, to_sparse_csr) {
// 1. test dense_to_sparse_csr
paddle::experimental::Tensor x(dense_x);
auto out = paddle::experimental::sparse::to_sparse_csr(x, phi::Backend::CPU);
auto out = paddle::experimental::sparse::to_sparse_csr(x);
auto csr = std::dynamic_pointer_cast<phi::SparseCsrTensor>(out.impl());
auto check = [&](const phi::SparseCsrTensor& csr) {
ASSERT_EQ(csr.non_zero_cols().numel(), non_zero_num);
......@@ -170,8 +168,7 @@ TEST(API, to_sparse_csr) {
auto coo =
std::make_shared<phi::SparseCooTensor>(indices, values, dense_dims);
paddle::experimental::Tensor coo_x(coo);
auto out2 =
paddle::experimental::sparse::to_sparse_csr(coo_x, phi::Backend::CPU);
auto out2 = paddle::experimental::sparse::to_sparse_csr(coo_x);
auto csr2 = std::dynamic_pointer_cast<phi::SparseCsrTensor>(out.impl());
check(*csr2);
......@@ -212,7 +209,7 @@ TEST(API, to_dense) {
std::make_shared<phi::SparseCooTensor>(indices, values, dense_dims);
paddle::experimental::Tensor coo_x(coo);
auto out = paddle::experimental::sparse::to_dense(coo_x, phi::Backend::CPU);
auto out = paddle::experimental::sparse::to_dense(coo_x);
auto dense_out = std::dynamic_pointer_cast<phi::DenseTensor>(out.impl());
int cmp1 =
memcmp(dense_out->data<float>(), &dense_data[0][0], 9 * sizeof(float));
......@@ -237,7 +234,7 @@ TEST(API, to_dense) {
auto csr =
std::make_shared<phi::SparseCsrTensor>(crows, cols, values, dense_dims);
paddle::experimental::Tensor csr_x(csr);
auto out2 = paddle::experimental::sparse::to_dense(csr_x, phi::Backend::CPU);
auto out2 = paddle::experimental::sparse::to_dense(csr_x);
auto dense_out2 = std::dynamic_pointer_cast<phi::DenseTensor>(out.impl());
int cmp2 =
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
from paddle import _C_ops
from paddle.fluid.framework import _test_eager_guard
class TestSparseUtils(unittest.TestCase):
def test_to_sparse_coo(self):
with _test_eager_guard():
x = [[0, 1, 0, 2], [0, 0, 3, 0], [4, 5, 0, 0]]
non_zero_indices = [[0, 0, 1, 2, 2], [1, 3, 2, 0, 1]]
non_zero_elements = [1, 2, 3, 4, 5]
dense_x = paddle.to_tensor(x)
#TODO(zhangkaihuo): change to test the corresponding API
out = _C_ops.final_state_to_sparse_coo(dense_x, 2)
print(out)
assert np.array_equal(out.non_zero_indices().numpy(),
non_zero_indices)
assert np.array_equal(out.non_zero_elements().numpy(),
non_zero_elements)
dense_tensor = _C_ops.final_state_to_dense(out)
assert np.array_equal(dense_tensor.numpy(), x)
def test_to_sparse_csr(self):
with _test_eager_guard():
x = [[0, 1, 0, 2], [0, 0, 3, 0], [4, 5, 0, 0]]
non_zero_crows = [0, 2, 3, 5]
non_zero_cols = [1, 3, 2, 0, 1]
non_zero_elements = [1, 2, 3, 4, 5]
dense_x = paddle.to_tensor(x)
out = _C_ops.final_state_to_sparse_csr(dense_x)
print(out)
assert np.array_equal(out.non_zero_crows().numpy(), non_zero_crows)
assert np.array_equal(out.non_zero_cols().numpy(), non_zero_cols)
assert np.array_equal(out.non_zero_elements().numpy(),
non_zero_elements)
dense_tensor = _C_ops.final_state_to_dense(out)
assert np.array_equal(dense_tensor.numpy(), x)
if __name__ == "__main__":
unittest.main()
......@@ -263,14 +263,7 @@ def to_string(var, prefix='Tensor'):
data=data)
def tensor_to_string(tensor, prefix='Tensor'):
indent = len(prefix) + 1
_template = "{prefix}(shape={shape}, dtype={dtype}, place={place}, stop_gradient={stop_gradient},\n{indent}{data})"
if not tensor._is_initialized():
return "Tensor(Not initialized)"
def _format_dense_tensor(tensor, indent):
np_tensor = tensor.numpy()
if len(tensor.shape) == 0:
......@@ -288,6 +281,26 @@ def tensor_to_string(tensor, prefix='Tensor'):
data = _format_tensor(
np_tensor, sumary, indent=indent, max_width=max_width, signed=signed)
return data
def sparse_tensor_to_string(tensor, prefix='Tensor'):
indent = len(prefix) + 1
_template = "{prefix}(shape={shape}, dtype={dtype}, place={place}, stop_gradient={stop_gradient}, \n{indent}{data})"
if tensor.is_sparse_coo():
indices_tensor = tensor.non_zero_indices()
elements_tensor = tensor.non_zero_elements()
indices_data = _format_dense_tensor(indices_tensor, indent)
elements_data = _format_dense_tensor(elements_tensor, indent)
data = 'non_zero_indices=' + indices_data + ',\nnon_zero_elements=' + elements_data
else:
crows_tensor = tensor.non_zero_crows()
cols_tensor = tensor.non_zero_cols()
elements_tensor = tensor.non_zero_elements()
crows_data = _format_dense_tensor(crows_tensor, indent)
cols_data = _format_dense_tensor(cols_tensor, indent)
elements_data = _format_dense_tensor(elements_tensor, indent)
data = 'non_zero_crows=' + crows_data + ',\nnon_zero_cols=' + cols_data + ',\nnon_zero_elements=' + elements_data
return _template.format(
prefix=prefix,
......@@ -297,3 +310,25 @@ def tensor_to_string(tensor, prefix='Tensor'):
stop_gradient=tensor.stop_gradient,
indent=' ' * indent,
data=data)
def tensor_to_string(tensor, prefix='Tensor'):
indent = len(prefix) + 1
_template = "{prefix}(shape={shape}, dtype={dtype}, place={place}, stop_gradient={stop_gradient},\n{indent}{data})"
if not tensor._is_initialized():
return "Tensor(Not initialized)"
if tensor.is_sparse():
return sparse_tensor_to_string(tensor, prefix)
else:
data = _format_dense_tensor(tensor, indent)
return _template.format(
prefix=prefix,
shape=tensor.shape,
dtype=tensor.dtype,
place=tensor._place_str,
stop_gradient=tensor.stop_gradient,
indent=' ' * indent,
data=data)
......@@ -4,18 +4,19 @@
kernel :
func : sparse_conv3d
layout : x
backward : conv3d_grad
- api : to_dense
args : (Tensor x, Backend backend)
args : (Tensor x)
output : Tensor(out@DenseTensor)
invoke : to_dense_impl(x, backend)
invoke : to_dense_impl(x)
- api : to_sparse_coo
args : (Tensor x, Backend backend, int64 sparse_dim)
args : (Tensor x, int64 sparse_dim)
output : Tensor(out@SparseCooTensor)
invoke : to_sparse_coo_impl(x, backend, sparse_dim)
invoke : to_sparse_coo_impl(x, sparse_dim)
- api : to_sparse_csr
args : (Tensor x, Backend backend)
args : (Tensor x)
output : Tensor(out@SparseCsrTensor)
invoke : to_sparse_csr_impl(x, backend)
invoke : to_sparse_csr_impl(x)
......@@ -192,9 +192,7 @@ def source_include(header_file_path):
def api_register():
return """
PD_REGISTER_API(Test);
"""
return ""
def api_namespace():
......
......@@ -115,9 +115,7 @@ def source_include(header_file_path):
def api_register():
return """
PD_REGISTER_API(Test);
"""
return ""
def api_namespace():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册