提交 a2b32356 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!4047 add attr 'shape' and 'dtype' and interface 'asnumpy' for Tensor

Merge pull request !4047 from zhangbuxue/add_attr_dtype_and_shape_and_interface_asnumpy_for_tensor
......@@ -268,7 +268,7 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
}),
py::arg("input"), py::arg("dtype") = nullptr)
.def_property("init_flag", &Tensor::is_init, &Tensor::set_init_flag)
.def_property_readonly("dtype", &Tensor::Dtype, R"mydelimiter(
.def_property_readonly("_dtype", &Tensor::Dtype, R"mydelimiter(
Get the tensor's data type.
Returns:
......@@ -279,7 +279,7 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
>>> data.dtype
Int32
)mydelimiter")
.def_property_readonly("shape", TensorPy::GetPyTupleShape, R"mydelimiter(
.def_property_readonly("_shape", TensorPy::GetPyTupleShape, R"mydelimiter(
Get the tensor's shape.
Returns:
......
......@@ -208,13 +208,41 @@ class Tensor(Tensor_):
return "Unknown Tensor type!"
return str(self.asnumpy())
@property
def shape(self):
"""The shape of tensor."""
return self._shape
@property
def dtype(self):
"""The dtype of tensor."""
return self._dtype
@property
def virtual_flag(self):
"""Mark tensor is virtual."""
return self._virtual_flag
@virtual_flag.setter
def virtual_flag(self, value):
"""The setter of virtual_flag."""
if not isinstance(value, bool):
raise TypeError("virtual_flag must be bool.")
self._virtual_flag = value
def asnumpy(self):
"""Convert tensor to numpy array."""
return Tensor_.asnumpy(self)
def all(self, axis=(), keep_dims=False):
"""
Check all array elements along a given axis evaluate to True.
Args:
axis (Union[None, int, tuple(int)): Dimensions of reduction.
Default: (), reduce all dimensions.
keep_dims (bool): Whether to keep the reduced dimensions.
Default : False, don't keep these reduced dimensions.
Returns:
Tensor, has the same data type as x.
......@@ -228,7 +256,9 @@ class Tensor(Tensor_):
Args:
axis (Union[None, int, tuple(int)): Dimensions of reduction.
Default: (), reduce all dimensions.
keep_dims (bool): Whether to keep the reduced dimensions.
Default : False, don't keep these reduced dimensions.
Returns:
Tensor, has the same data type as x.
......@@ -236,18 +266,6 @@ class Tensor(Tensor_):
return tensor_operator_registry.get('any')(keep_dims)(self, axis)
@property
def virtual_flag(self):
"""Mark tensor is virtual."""
return self._virtual_flag
@virtual_flag.setter
def virtual_flag(self, value):
"""The setter of virtual_flag."""
if not isinstance(value, bool):
raise TypeError("virtual_flag must be bool.")
self._virtual_flag = value
class IndexedSlices:
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册