Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
59b13586
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
59b13586
编写于
9月 01, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
9月 01, 2020
浏览文件
操作
浏览文件
下载
差异文件
!5616 Fix problems in Tensor.from_numpy()
Merge pull request !5616 from hewei/fix_tensor_from_numpy2
上级
fa459377
cbfd4c5f
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
22 addition
and
39 deletion
+22
-39
mindspore/ccsrc/pybind_api/ir/tensor_py.cc
mindspore/ccsrc/pybind_api/ir/tensor_py.cc
+20
-39
mindspore/core/ir/tensor.h
mindspore/core/ir/tensor.h
+2
-0
未找到文件。
mindspore/ccsrc/pybind_api/ir/tensor_py.cc
浏览文件 @
59b13586
...
...
@@ -120,51 +120,24 @@ static bool IsCContiguous(const py::array &input) {
// TensorDataNumpy implements TensorData using numpy array.
class
TensorDataNumpy
:
public
TensorData
{
public:
explicit
TensorDataNumpy
(
const
py
::
array
&
input
)
:
data_
(
input
)
{
if
(
!
IsCContiguous
(
data_
))
{
// Call numpy.ascontiguousarray() to convert data to C contiguous if it is not.
auto
np
=
py
::
module
::
import
(
"numpy"
);
auto
convert
=
np
.
attr
(
"ascontiguousarray"
);
data_
=
convert
(
data_
);
}
}
explicit
TensorDataNumpy
(
py
::
buffer_info
&&
buffer
)
:
buffer_
(
std
::
move
(
buffer
))
{}
/// Total number of elements.
ssize_t
size
()
const
override
{
return
data_
.
size
()
;
}
ssize_t
size
()
const
override
{
return
buffer_
.
size
;
}
/// Byte size of a single element.
ssize_t
itemsize
()
const
override
{
return
data_
.
itemsize
()
;
}
ssize_t
itemsize
()
const
override
{
return
buffer_
.
itemsize
;
}
/// Total number of bytes.
ssize_t
nbytes
()
const
override
{
return
data_
.
nbytes
()
;
}
ssize_t
nbytes
()
const
override
{
return
buffer_
.
itemsize
*
buffer_
.
size
;
}
/// Number of dimensions.
ssize_t
ndim
()
const
override
{
return
data_
.
ndim
()
;
}
ssize_t
ndim
()
const
override
{
return
buffer_
.
ndim
;
}
/// Data pointer.
void
*
data
()
override
{
return
data_
.
request
().
ptr
;
}
const
void
*
const_data
()
const
override
{
return
data_
.
request
().
ptr
;
}
/// Is data equals.
bool
equals
(
const
TensorData
&
other
)
const
override
{
auto
ptr
=
dynamic_cast
<
const
TensorDataNumpy
*>
(
&
other
);
if
(
ptr
==
nullptr
)
{
// Not same type, compare data byte by byte.
return
TensorData
::
equals
(
other
);
}
return
NumpyEquals
(
*
ptr
);
}
void
*
data
()
override
{
return
buffer_
.
ptr
;
}
bool
NumpyEquals
(
const
TensorDataNumpy
&
other
)
const
{
auto
all_data_equal
=
[
&
other
,
this
]()
->
bool
{
auto
np
=
py
::
module
::
import
(
"numpy"
);
auto
equal
=
np
.
attr
(
"equal"
)(
data_
,
other
.
data_
);
auto
all_equal
=
np
.
attr
(
"all"
)(
equal
);
return
all_equal
.
cast
<
bool
>
();
};
return
this
==
&
other
||
data_
.
is
(
other
.
data_
)
||
all_data_equal
();
}
const
void
*
const_data
()
const
override
{
return
buffer_
.
ptr
;
}
/// To string.
std
::
string
ToString
(
const
TypeId
type
,
const
ShapeVector
&
shape
,
bool
use_comma
)
const
override
{
...
...
@@ -174,17 +147,21 @@ class TensorDataNumpy : public TensorData {
kwargs
[
"separator"
]
=
", "
;
auto
np
=
py
::
module
::
import
(
"numpy"
);
auto
array2string
=
np
.
attr
(
"array2string"
);
return
py
::
str
(
array2string
(
data_
,
**
kwargs
));
return
py
::
str
(
array2string
(
py_array
()
,
**
kwargs
));
}
// without comma.
return
py
::
str
(
data_
);
return
py
::
str
(
py_array
()
);
}
/// py::array object.
py
::
array
py_array
()
const
{
return
data_
;
}
py
::
array
py_array
()
const
{
// Use dummy owner to avoid copy data.
py
::
str
dummyOwner
;
return
py
::
array
(
py
::
dtype
(
buffer_
),
buffer_
.
shape
,
buffer_
.
strides
,
buffer_
.
ptr
,
dummyOwner
);
}
private:
mutable
py
::
array
data
_
;
py
::
buffer_info
buffer
_
;
};
TensorPtr
TensorPy
::
MakeTensor
(
const
py
::
array
&
input
,
const
TypePtr
&
type_ptr
)
{
...
...
@@ -226,6 +203,10 @@ TensorPtr TensorPy::MakeTensor(const py::array &input, const TypePtr &type_ptr)
/// Creates a Tensor from a numpy array without copy
TensorPtr
TensorPy
::
MakeTensorNoCopy
(
const
py
::
array
&
input
)
{
// Check format.
if
(
!
IsCContiguous
(
input
))
{
MS_LOG
(
EXCEPTION
)
<<
"Array should be C contiguous."
;
}
// Get input buffer info.
py
::
buffer_info
buf
=
input
.
request
();
// Get tensor dtype and check it.
...
...
@@ -236,7 +217,7 @@ TensorPtr TensorPy::MakeTensorNoCopy(const py::array &input) {
// Get tensor shape.
ShapeVector
shape
(
buf
.
shape
.
begin
(),
buf
.
shape
.
end
());
// Make a tensor with shared data with numpy array.
auto
tensor_data
=
std
::
make_shared
<
TensorDataNumpy
>
(
input
);
auto
tensor_data
=
std
::
make_shared
<
TensorDataNumpy
>
(
std
::
move
(
buf
)
);
return
std
::
make_shared
<
Tensor
>
(
dtype
,
shape
,
tensor_data
);
}
...
...
mindspore/core/ir/tensor.h
浏览文件 @
59b13586
...
...
@@ -42,6 +42,8 @@ namespace tensor {
// Tensor data interface.
class
TensorData
{
public:
/// virtual destructor is required for base classes.
virtual
~
TensorData
()
=
default
;
/// Total number of elements.
virtual
ssize_t
size
()
const
=
0
;
/// Byte size of a single element.
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录