提交 a2b32356 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!4047 add attr 'shape' and 'dtype' and interface 'asnumpy' for Tensor

Merge pull request !4047 from zhangbuxue/add_attr_dtype_and_shape_and_interface_asnumpy_for_tensor
...@@ -268,7 +268,7 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) { ...@@ -268,7 +268,7 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
}), }),
py::arg("input"), py::arg("dtype") = nullptr) py::arg("input"), py::arg("dtype") = nullptr)
.def_property("init_flag", &Tensor::is_init, &Tensor::set_init_flag) .def_property("init_flag", &Tensor::is_init, &Tensor::set_init_flag)
.def_property_readonly("dtype", &Tensor::Dtype, R"mydelimiter( .def_property_readonly("_dtype", &Tensor::Dtype, R"mydelimiter(
Get the tensor's data type. Get the tensor's data type.
Returns: Returns:
...@@ -279,7 +279,7 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) { ...@@ -279,7 +279,7 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
>>> data.dtype >>> data.dtype
Int32 Int32
)mydelimiter") )mydelimiter")
.def_property_readonly("shape", TensorPy::GetPyTupleShape, R"mydelimiter( .def_property_readonly("_shape", TensorPy::GetPyTupleShape, R"mydelimiter(
Get the tensor's shape. Get the tensor's shape.
Returns: Returns:
......
...@@ -208,13 +208,41 @@ class Tensor(Tensor_): ...@@ -208,13 +208,41 @@ class Tensor(Tensor_):
return "Unknown Tensor type!" return "Unknown Tensor type!"
return str(self.asnumpy()) return str(self.asnumpy())
@property
def shape(self):
"""The shape of tensor."""
return self._shape
@property
def dtype(self):
"""The dtype of tensor."""
return self._dtype
@property
def virtual_flag(self):
"""Mark tensor is virtual."""
return self._virtual_flag
@virtual_flag.setter
def virtual_flag(self, value):
"""The setter of virtual_flag."""
if not isinstance(value, bool):
raise TypeError("virtual_flag must be bool.")
self._virtual_flag = value
def asnumpy(self):
"""Convert tensor to numpy array."""
return Tensor_.asnumpy(self)
def all(self, axis=(), keep_dims=False): def all(self, axis=(), keep_dims=False):
""" """
Check all array elements along a given axis evaluate to True. Check all array elements along a given axis evaluate to True.
Args: Args:
axis (Union[None, int, tuple(int)): Dimensions of reduction. axis (Union[None, int, tuple(int)): Dimensions of reduction.
Default: (), reduce all dimensions.
keep_dims (bool): Whether to keep the reduced dimensions. keep_dims (bool): Whether to keep the reduced dimensions.
Default : False, don't keep these reduced dimensions.
Returns: Returns:
Tensor, has the same data type as x. Tensor, has the same data type as x.
...@@ -228,7 +256,9 @@ class Tensor(Tensor_): ...@@ -228,7 +256,9 @@ class Tensor(Tensor_):
Args: Args:
axis (Union[None, int, tuple(int)): Dimensions of reduction. axis (Union[None, int, tuple(int)): Dimensions of reduction.
Default: (), reduce all dimensions.
keep_dims (bool): Whether to keep the reduced dimensions. keep_dims (bool): Whether to keep the reduced dimensions.
Default : False, don't keep these reduced dimensions.
Returns: Returns:
Tensor, has the same data type as x. Tensor, has the same data type as x.
...@@ -236,18 +266,6 @@ class Tensor(Tensor_): ...@@ -236,18 +266,6 @@ class Tensor(Tensor_):
return tensor_operator_registry.get('any')(keep_dims)(self, axis) return tensor_operator_registry.get('any')(keep_dims)(self, axis)
@property
def virtual_flag(self):
"""Mark tensor is virtual."""
return self._virtual_flag
@virtual_flag.setter
def virtual_flag(self, value):
"""The setter of virtual_flag."""
if not isinstance(value, bool):
raise TypeError("virtual_flag must be bool.")
self._virtual_flag = value
class IndexedSlices: class IndexedSlices:
""" """
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册