Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
1971aea8
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
1971aea8
编写于
5月 16, 2020
作者:
Z
zjun
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix tensor print order
上级
93d95d70
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
34 addition
and
33 deletion
+34
-33
mindspore/ccsrc/utils/tensorprint_utils.cc
mindspore/ccsrc/utils/tensorprint_utils.cc
+34
-33
未找到文件。
mindspore/ccsrc/utils/tensorprint_utils.cc
浏览文件 @
1971aea8
...
...
@@ -50,6 +50,7 @@ bool ParseTensorShape(const std::string &input_shape_str, std::vector<int> *cons
if
(
tensor_shape
==
nullptr
)
{
return
false
;
}
MS_EXCEPTION_IF_NULL
(
dims
);
std
::
string
shape_str
=
input_shape_str
;
if
(
shape_str
.
size
()
<=
2
)
{
return
false
;
...
...
@@ -71,6 +72,8 @@ bool ParseTensorShape(const std::string &input_shape_str, std::vector<int> *cons
bool
PrintTensorToString
(
const
char
*
str_data_ptr
,
mindspore
::
tensor
::
Tensor
*
const
print_tensor
,
const
size_t
&
memory_size
)
{
MS_EXCEPTION_IF_NULL
(
str_data_ptr
);
MS_EXCEPTION_IF_NULL
(
print_tensor
);
auto
*
tensor_data_ptr
=
static_cast
<
uint8_t
*>
(
print_tensor
->
data_c
(
true
));
MS_EXCEPTION_IF_NULL
(
tensor_data_ptr
);
auto
cp_ret
=
...
...
@@ -83,55 +86,57 @@ bool PrintTensorToString(const char *str_data_ptr, mindspore::tensor::Tensor *co
}
template
<
typename
T
>
void
PrintScalarToString
(
const
char
*
str_data_ptr
,
const
string
&
tensor_type
)
{
void
PrintScalarToString
(
const
char
*
str_data_ptr
,
const
string
&
tensor_type
,
std
::
ostringstream
*
buf
)
{
MS_EXCEPTION_IF_NULL
(
str_data_ptr
);
MS_EXCEPTION_IF_NULL
(
buf
);
const
T
*
data_ptr
=
reinterpret_cast
<
const
T
*>
(
str_data_ptr
);
std
::
ostringstream
buf_scalar
;
buf_scalar
<<
"Tensor shape :1 "
<<
tensor_type
;
buf_scalar
<<
"
\n
val:"
;
buf_scalar
<<
*
data_ptr
;
std
::
cout
<<
buf_scalar
.
str
()
<<
std
::
endl
;
*
buf
<<
"Tensor shape:[1] "
<<
tensor_type
;
*
buf
<<
"
\n
val:"
;
*
buf
<<
*
data_ptr
<<
"
\n
"
;
}
void
PrintScalarToBoolString
(
const
char
*
str_data_ptr
,
const
string
&
tensor_type
)
{
void
PrintScalarToBoolString
(
const
char
*
str_data_ptr
,
const
string
&
tensor_type
,
std
::
ostringstream
*
buf
)
{
MS_EXCEPTION_IF_NULL
(
str_data_ptr
);
MS_EXCEPTION_IF_NULL
(
buf
);
const
bool
*
data_ptr
=
reinterpret_cast
<
const
bool
*>
(
str_data_ptr
);
std
::
ostringstream
buf_scalar
;
buf_scalar
<<
"Tensor shape :1 "
<<
tensor_type
;
buf_scalar
<<
"
\n
val:"
;
if
(
*
data_ptr
==
true
)
{
buf_scalar
<<
"True"
;
*
buf
<<
"Tensor shape:[1] "
<<
tensor_type
;
*
buf
<<
"
\n
val:"
;
if
(
*
data_ptr
)
{
*
buf
<<
"True
\n
"
;
}
else
{
buf_scalar
<<
"False
"
;
*
buf
<<
"False
\n
"
;
}
std
::
cout
<<
buf_scalar
.
str
()
<<
std
::
endl
;
}
void
convertDataItem2Scalar
(
const
char
*
str_data_ptr
,
const
string
&
tensor_type
)
{
void
convertDataItem2Scalar
(
const
char
*
str_data_ptr
,
const
string
&
tensor_type
,
std
::
ostringstream
*
buf
)
{
MS_EXCEPTION_IF_NULL
(
str_data_ptr
);
MS_EXCEPTION_IF_NULL
(
buf
);
auto
type_iter
=
print_type_map
.
find
(
tensor_type
);
auto
type_id
=
type_iter
->
second
;
if
(
type_id
==
TypeId
::
kNumberTypeBool
)
{
PrintScalarToBoolString
(
str_data_ptr
,
tensor_type
);
PrintScalarToBoolString
(
str_data_ptr
,
tensor_type
,
buf
);
}
else
if
(
type_id
==
TypeId
::
kNumberTypeInt8
)
{
PrintScalarToString
<
int8_t
>
(
str_data_ptr
,
tensor_type
);
PrintScalarToString
<
int8_t
>
(
str_data_ptr
,
tensor_type
,
buf
);
}
else
if
(
type_id
==
TypeId
::
kNumberTypeUInt8
)
{
PrintScalarToString
<
uint8_t
>
(
str_data_ptr
,
tensor_type
);
PrintScalarToString
<
uint8_t
>
(
str_data_ptr
,
tensor_type
,
buf
);
}
else
if
(
type_id
==
TypeId
::
kNumberTypeInt16
)
{
PrintScalarToString
<
int16_t
>
(
str_data_ptr
,
tensor_type
);
PrintScalarToString
<
int16_t
>
(
str_data_ptr
,
tensor_type
,
buf
);
}
else
if
(
type_id
==
TypeId
::
kNumberTypeUInt16
)
{
PrintScalarToString
<
uint16_t
>
(
str_data_ptr
,
tensor_type
);
PrintScalarToString
<
uint16_t
>
(
str_data_ptr
,
tensor_type
,
buf
);
}
else
if
(
type_id
==
TypeId
::
kNumberTypeInt32
)
{
PrintScalarToString
<
int32_t
>
(
str_data_ptr
,
tensor_type
);
PrintScalarToString
<
int32_t
>
(
str_data_ptr
,
tensor_type
,
buf
);
}
else
if
(
type_id
==
TypeId
::
kNumberTypeUInt32
)
{
PrintScalarToString
<
uint32_t
>
(
str_data_ptr
,
tensor_type
);
PrintScalarToString
<
uint32_t
>
(
str_data_ptr
,
tensor_type
,
buf
);
}
else
if
(
type_id
==
TypeId
::
kNumberTypeInt64
)
{
PrintScalarToString
<
int64_t
>
(
str_data_ptr
,
tensor_type
);
PrintScalarToString
<
int64_t
>
(
str_data_ptr
,
tensor_type
,
buf
);
}
else
if
(
type_id
==
TypeId
::
kNumberTypeUInt64
)
{
PrintScalarToString
<
uint64_t
>
(
str_data_ptr
,
tensor_type
);
PrintScalarToString
<
uint64_t
>
(
str_data_ptr
,
tensor_type
,
buf
);
}
else
if
(
type_id
==
TypeId
::
kNumberTypeFloat16
)
{
PrintScalarToString
<
float16
>
(
str_data_ptr
,
tensor_type
);
PrintScalarToString
<
float16
>
(
str_data_ptr
,
tensor_type
,
buf
);
}
else
if
(
type_id
==
TypeId
::
kNumberTypeFloat32
)
{
PrintScalarToString
<
float
>
(
str_data_ptr
,
tensor_type
);
PrintScalarToString
<
float
>
(
str_data_ptr
,
tensor_type
,
buf
);
}
else
if
(
type_id
==
TypeId
::
kNumberTypeFloat64
)
{
PrintScalarToString
<
double
>
(
str_data_ptr
,
tensor_type
);
PrintScalarToString
<
double
>
(
str_data_ptr
,
tensor_type
,
buf
);
}
else
{
MS_LOG
(
EXCEPTION
)
<<
"Cannot print scalar because of unsupport data type: "
<<
tensor_type
<<
"."
;
}
...
...
@@ -142,11 +147,7 @@ bool judgeLengthValid(const size_t str_len, const string &tensor_type) {
if
(
type_iter
==
type_size_map
.
end
())
{
MS_LOG
(
EXCEPTION
)
<<
"type of scalar to print is not support."
;
}
if
(
str_len
!=
type_iter
->
second
)
{
return
false
;
}
return
true
;
return
str_len
==
type_iter
->
second
;
}
#ifndef NO_DLIB
...
...
@@ -166,7 +167,7 @@ bool ConvertDataItem2Tensor(const std::vector<tdt::DataItem> &items) {
if
(
!
judgeLengthValid
(
str_data_ptr
->
size
(),
item
.
tensorType_
))
{
MS_LOG
(
EXCEPTION
)
<<
"Print op receive data length is invalid."
;
}
convertDataItem2Scalar
(
str_data_ptr
->
data
(),
item
.
tensorType_
);
convertDataItem2Scalar
(
str_data_ptr
->
data
(),
item
.
tensorType_
,
&
buf
);
continue
;
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录