Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
41009707
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
41009707
编写于
4年前
作者:
M
mindspore-ci-bot
提交者:
Gitee
4年前
浏览文件
操作
浏览文件
下载
差异文件
!5330 Add comma seperator for python tensor __repr__().
Merge pull request !5330 from ZhangQinghua/master
上级
cd5d8caa
a91e9fe0
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
61 addition
and
30 deletion
+61
-30
mindspore/ccsrc/utils/tensorprint_utils.cc
mindspore/ccsrc/utils/tensorprint_utils.cc
+1
-1
mindspore/core/ir/tensor.cc
mindspore/core/ir/tensor.cc
+55
-28
mindspore/core/ir/tensor.h
mindspore/core/ir/tensor.h
+5
-1
未找到文件。
mindspore/ccsrc/utils/tensorprint_utils.cc
浏览文件 @
41009707
...
...
@@ -210,7 +210,7 @@ bool ConvertDataItem2Tensor(const std::vector<tdt::DataItem> &items) {
mindspore
::
tensor
::
Tensor
print_tensor
(
type_id
,
tensor_shape
);
auto
memory_size
=
totaldims
*
type_size_map
[
item
.
tensorType_
];
if
(
PrintTensorToString
(
str_data_ptr
->
data
(),
&
print_tensor
,
memory_size
))
{
buf
<<
print_tensor
.
ToString
Repr
()
<<
std
::
endl
;
buf
<<
print_tensor
.
ToString
NoLimit
()
<<
std
::
endl
;
}
}
}
...
...
This diff is collapsed.
Click to expand it.
mindspore/core/ir/tensor.cc
浏览文件 @
41009707
...
...
@@ -213,7 +213,7 @@ class TensorDataImpl : public TensorData {
std
::
equal
(
data_
.
get
(),
data_
.
get
()
+
data_size_
,
ptr
->
data_
.
get
());
}
std
::
string
ToString
(
const
TypeId
type
,
const
ShapeVector
&
shape
)
const
override
{
std
::
string
ToString
(
const
TypeId
type
,
const
ShapeVector
&
shape
,
bool
use_comma
)
const
override
{
constexpr
auto
valid
=
std
::
is_same
<
T
,
bool
>::
value
||
std
::
is_same
<
T
,
uint8_t
>::
value
||
std
::
is_same
<
T
,
int8_t
>::
value
||
std
::
is_same
<
T
,
int16_t
>::
value
||
std
::
is_same
<
T
,
int32_t
>::
value
||
std
::
is_same
<
T
,
int64_t
>::
value
||
...
...
@@ -229,16 +229,16 @@ class TensorDataImpl : public TensorData {
std
::
ostringstream
ss
;
if
(
data_size_
==
1
&&
ndim_
==
0
)
{
// Scalar
OutputDataString
(
ss
,
0
,
0
,
1
);
OutputDataString
(
ss
,
0
,
0
,
1
,
false
);
return
ss
.
str
();
}
ssize_t
cursor
=
0
;
SummaryStringRecursive
(
ss
,
shape
,
&
cursor
,
0
);
SummaryStringRecursive
(
ss
,
shape
,
&
cursor
,
0
,
use_comma
);
return
ss
.
str
();
}
private:
void
OutputDataString
(
std
::
ostringstream
&
ss
,
ssize_t
cursor
,
ssize_t
start
,
ssize_t
end
)
const
{
void
OutputDataString
(
std
::
ostringstream
&
ss
,
ssize_t
cursor
,
ssize_t
start
,
ssize_t
end
,
bool
use_comma
)
const
{
const
bool
isScalar
=
ndim_
==
0
&&
end
-
start
==
1
;
constexpr
auto
isFloat
=
std
::
is_same
<
T
,
float16
>::
value
||
std
::
is_same
<
T
,
float
>::
value
||
std
::
is_same
<
T
,
double
>::
value
;
...
...
@@ -265,33 +265,43 @@ class TensorDataImpl : public TensorData {
ss
<<
std
::
setw
(
5
)
<<
std
::
setiosflags
(
std
::
ios
::
right
)
<<
(
value
?
"True"
:
"False"
);
}
}
else
{
constexpr
auto
isSigned
=
std
::
is_same
<
T
,
int8_t
>::
value
||
std
::
is_same
<
T
,
int16_t
>::
value
||
std
::
is_same
<
T
,
int32_t
>::
value
||
std
::
is_same
<
T
,
int64_t
>::
value
;
constexpr
auto
isSigned
=
std
::
is_same
<
T
,
int64_t
>::
value
;
if
constexpr
(
isSigned
)
{
if
(
!
isScalar
&&
static_cast
<
int64_t
>
(
value
)
>=
0
)
{
ss
<<
' '
;
}
}
// Set width and indent for different int type.
// Set width and indent for different int type
with signed position
.
//
// int8/uint8 width: 3
// int16/uint16 width: 5
// int32/uint32 width: 10
// int64/uint64 width: NOT SET
if
constexpr
(
std
::
is_same
<
T
,
int8_t
>::
value
)
{
ss
<<
std
::
setw
(
3
)
<<
std
::
setiosflags
(
std
::
ios
::
right
)
<<
static_cast
<
int16_t
>
(
value
);
}
else
if
constexpr
(
std
::
is_same
<
T
,
uint8_t
>::
value
)
{
// uint8 width: 3, [0, 255]
// int8 width: 4, [-128, 127]
// uint16 width: 5, [0, 65535]
// int16 width: 6, [-32768, 32767]
// uint32 width: 10, [0, 4294967295]
// int32 width: 11, [-2147483648, 2147483647]
// uint64 width: NOT SET (20, [0, 18446744073709551615])
// int64 width: NOT SET (20, [-9223372036854775808, 9223372036854775807])
if
constexpr
(
std
::
is_same
<
T
,
uint8_t
>::
value
)
{
ss
<<
std
::
setw
(
3
)
<<
std
::
setiosflags
(
std
::
ios
::
right
)
<<
static_cast
<
uint16_t
>
(
value
);
}
else
if
constexpr
(
std
::
is_same
<
T
,
int16_t
>::
value
||
std
::
is_same
<
T
,
uint16_t
>::
value
)
{
}
else
if
constexpr
(
std
::
is_same
<
T
,
int8_t
>::
value
)
{
ss
<<
std
::
setw
(
4
)
<<
std
::
setiosflags
(
std
::
ios
::
right
)
<<
static_cast
<
int16_t
>
(
value
);
}
else
if
constexpr
(
std
::
is_same
<
T
,
uint16_t
>::
value
)
{
ss
<<
std
::
setw
(
5
)
<<
std
::
setiosflags
(
std
::
ios
::
right
)
<<
value
;
}
else
if
constexpr
(
std
::
is_same
<
T
,
int32_t
>::
value
||
std
::
is_same
<
T
,
uint32_t
>::
value
)
{
}
else
if
constexpr
(
std
::
is_same
<
T
,
int16_t
>::
value
)
{
ss
<<
std
::
setw
(
6
)
<<
std
::
setiosflags
(
std
::
ios
::
right
)
<<
value
;
}
else
if
constexpr
(
std
::
is_same
<
T
,
uint32_t
>::
value
)
{
ss
<<
std
::
setw
(
10
)
<<
std
::
setiosflags
(
std
::
ios
::
right
)
<<
value
;
}
else
if
constexpr
(
std
::
is_same
<
T
,
int32_t
>::
value
)
{
ss
<<
std
::
setw
(
11
)
<<
std
::
setiosflags
(
std
::
ios
::
right
)
<<
value
;
}
else
{
ss
<<
value
;
}
}
if
(
!
isScalar
&&
i
!=
end
-
1
)
{
if
(
use_comma
)
{
ss
<<
','
;
}
ss
<<
' '
;
}
if
(
!
isScalar
&&
ndim_
==
1
&&
(
i
+
1
)
%
linefeedThreshold
==
0
)
{
...
...
@@ -301,7 +311,8 @@ class TensorDataImpl : public TensorData {
}
}
void
SummaryStringRecursive
(
std
::
ostringstream
&
ss
,
const
ShapeVector
&
shape
,
ssize_t
*
cursor
,
ssize_t
depth
)
const
{
void
SummaryStringRecursive
(
std
::
ostringstream
&
ss
,
const
ShapeVector
&
shape
,
ssize_t
*
cursor
,
ssize_t
depth
,
bool
use_comma
)
const
{
if
(
depth
>=
static_cast
<
ssize_t
>
(
ndim_
))
{
return
;
}
...
...
@@ -309,11 +320,11 @@ class TensorDataImpl : public TensorData {
if
(
depth
==
static_cast
<
ssize_t
>
(
ndim_
)
-
1
)
{
// Bottom dimension
ssize_t
num
=
shape
[
depth
];
if
(
num
>
kThreshold
&&
ndim_
>
1
)
{
OutputDataString
(
ss
,
*
cursor
,
0
,
kThreshold
/
2
);
OutputDataString
(
ss
,
*
cursor
,
0
,
kThreshold
/
2
,
use_comma
);
ss
<<
' '
<<
kEllipsis
<<
' '
;
OutputDataString
(
ss
,
*
cursor
,
num
-
kThreshold
/
2
,
num
);
OutputDataString
(
ss
,
*
cursor
,
num
-
kThreshold
/
2
,
num
,
use_comma
);
}
else
{
OutputDataString
(
ss
,
*
cursor
,
0
,
num
);
OutputDataString
(
ss
,
*
cursor
,
0
,
num
,
use_comma
);
}
*
cursor
+=
num
;
}
else
{
// Middle dimension
...
...
@@ -321,13 +332,19 @@ class TensorDataImpl : public TensorData {
// Handle the first half.
for
(
ssize_t
i
=
0
;
i
<
std
::
min
(
static_cast
<
ssize_t
>
(
kThreshold
/
2
),
num
);
i
++
)
{
if
(
i
>
0
)
{
if
(
use_comma
)
{
ss
<<
','
;
}
ss
<<
'\n'
;
ss
<<
std
::
setw
(
depth
+
1
)
<<
' '
;
// Add the indent.
}
SummaryStringRecursive
(
ss
,
shape
,
cursor
,
depth
+
1
);
SummaryStringRecursive
(
ss
,
shape
,
cursor
,
depth
+
1
,
use_comma
);
}
// Handle the ignored part.
if
(
num
>
kThreshold
)
{
if
(
use_comma
)
{
ss
<<
','
;
}
ss
<<
'\n'
;
ss
<<
std
::
setw
(
depth
+
1
)
<<
' '
;
// Add the indent.
ss
<<
kEllipsis
;
...
...
@@ -343,10 +360,14 @@ class TensorDataImpl : public TensorData {
}
// Handle the second half.
if
(
num
>
kThreshold
/
2
)
{
for
(
ssize_t
i
=
num
-
kThreshold
/
2
;
i
<
num
;
i
++
)
{
auto
continue_pos
=
num
-
kThreshold
/
2
;
for
(
ssize_t
i
=
continue_pos
;
i
<
num
;
i
++
)
{
if
(
use_comma
&&
i
!=
continue_pos
)
{
ss
<<
','
;
}
ss
<<
'\n'
;
ss
<<
std
::
setw
(
depth
+
1
)
<<
' '
;
// Add the indent.
SummaryStringRecursive
(
ss
,
shape
,
cursor
,
depth
+
1
);
SummaryStringRecursive
(
ss
,
shape
,
cursor
,
depth
+
1
,
use_comma
);
}
}
}
...
...
@@ -487,29 +508,35 @@ std::string Tensor::GetShapeAndDataTypeInfo() const {
return
buf
.
str
();
}
std
::
string
Tensor
::
ToString
()
const
{
constexpr
int
small_tensor_size
=
30
;
std
::
string
Tensor
::
ToStringInternal
(
int
limit_size
)
const
{
std
::
ostringstream
buf
;
auto
dtype
=
Dtype
();
MS_EXCEPTION_IF_NULL
(
dtype
);
data_sync
();
buf
<<
"Tensor(shape="
<<
ShapeToString
(
shape_
)
<<
", dtype="
<<
dtype
->
ToString
()
<<
','
;
if
(
DataSize
()
<
small_tensor
_size
)
{
if
(
limit_size
<=
0
||
DataSize
()
<
limit
_size
)
{
// Only print data for small tensor.
buf
<<
((
data
().
ndim
()
>
1
)
?
'\n'
:
' '
)
<<
data
().
ToString
(
data_type_
,
shape_
)
<<
')'
;
buf
<<
((
data
().
ndim
()
>
1
)
?
'\n'
:
' '
)
<<
data
().
ToString
(
data_type_
,
shape_
,
false
)
<<
')'
;
}
else
{
buf
<<
" [...])"
;
}
return
buf
.
str
();
}
std
::
string
Tensor
::
ToString
()
const
{
constexpr
int
small_tensor_size
=
30
;
return
ToStringInternal
(
small_tensor_size
);
}
std
::
string
Tensor
::
ToStringNoLimit
()
const
{
return
ToStringInternal
(
0
);
}
std
::
string
Tensor
::
ToStringRepr
()
const
{
std
::
ostringstream
buf
;
auto
dtype
=
Dtype
();
MS_EXCEPTION_IF_NULL
(
dtype
);
data_sync
();
buf
<<
"Tensor(shape="
<<
ShapeToString
(
shape_
)
<<
", dtype="
<<
dtype
->
ToString
()
<<
','
<<
((
data
().
ndim
()
>
1
)
?
'\n'
:
' '
)
<<
data
().
ToString
(
data_type_
,
shape_
)
<<
')'
;
<<
((
data
().
ndim
()
>
1
)
?
'\n'
:
' '
)
<<
data
().
ToString
(
data_type_
,
shape_
,
true
)
<<
')'
;
return
buf
.
str
();
}
...
...
This diff is collapsed.
Click to expand it.
mindspore/core/ir/tensor.h
浏览文件 @
41009707
...
...
@@ -53,7 +53,7 @@ class TensorData {
/// Is data equals.
virtual
bool
equals
(
const
TensorData
&
other
)
const
=
0
;
/// To string.
virtual
std
::
string
ToString
(
const
TypeId
type
,
const
ShapeVector
&
shape
)
const
=
0
;
virtual
std
::
string
ToString
(
const
TypeId
type
,
const
ShapeVector
&
shape
,
bool
use_comma
)
const
=
0
;
};
using
TensorDataPtr
=
std
::
shared_ptr
<
TensorData
>
;
...
...
@@ -208,6 +208,10 @@ class Tensor : public MetaTensor {
std
::
string
GetShapeAndDataTypeInfo
()
const
;
std
::
string
ToStringInternal
(
int
limit_size
)
const
;
std
::
string
ToStringNoLimit
()
const
;
std
::
string
ToString
()
const
override
;
std
::
string
ToStringRepr
()
const
;
...
...
This diff is collapsed.
Click to expand it.
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录
新手
引导
客服
返回
顶部