Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
1835ff33
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
1835ff33
编写于
7月 24, 2020
作者:
L
lvliang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
optimize-the-time-of-producting-cacha-key-in-pynative
上级
0b407dfe
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
18 addition
and
16 deletion
+18
-16
mindspore/ccsrc/pipeline/pynative/base.h
mindspore/ccsrc/pipeline/pynative/base.h
+2
-1
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
+15
-14
tests/st/pynative/test_pynative_resnet50.py
tests/st/pynative/test_pynative_resnet50.py
+1
-1
未找到文件。
mindspore/ccsrc/pipeline/pynative/base.h
浏览文件 @
1835ff33
...
...
@@ -49,8 +49,9 @@ enum PynativeStatusCode {
enum
RunOpArgsEnum
{
PY_PRIM
=
0
,
PY_NAME
,
PY_INPUTS
,
PY_ARGS_NUM
};
struct
OpExecInfo
{
PrimitivePyPtr
py_primitive
;
std
::
string
op_name
;
std
::
string
prim_id
;
PrimitivePyPtr
py_primitive
;
AbstractBasePtr
abstract
;
ValuePtr
value
=
nullptr
;
...
...
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
浏览文件 @
1835ff33
...
...
@@ -144,6 +144,7 @@ static std::string GetId(const py::object &obj) {
static
std
::
string
GetOpId
(
const
OpExecInfoPtr
&
op_exec_info
)
{
auto
id
=
GetId
(
op_exec_info
->
py_primitive
->
GetPyObj
());
op_exec_info
->
prim_id
=
id
;
return
id
;
}
...
...
@@ -306,6 +307,8 @@ OpExecInfoPtr GenerateOpExecInfo(const py::args &args) {
auto
inst
=
PynativeExecutor
::
GetInstance
();
if
(
inst
->
grad_flag
())
{
op_exec_info
->
value
=
inst
->
GetForwardValue
(
op_exec_info
);
}
else
{
(
void
)
GetOpId
(
op_exec_info
);
}
op_exec_info
->
op_inputs
=
args
[
PY_INPUTS
];
ConvertInputs
(
prim
,
args
[
PY_INPUTS
],
op_exec_info
);
...
...
@@ -317,23 +320,21 @@ std::string GetSingleOpGraphInfo(const OpExecInfoPtr &op_exec_info,
MS_EXCEPTION_IF_NULL
(
op_exec_info
);
std
::
string
graph_info
;
// get input tensor info
size_t
input_num
=
op_exec_info
->
op_inputs
.
size
();
for
(
size_t
index
=
0
;
index
<
input_num
;
++
index
)
{
auto
input
=
op_exec_info
->
op_inputs
[
index
];
if
(
py
::
isinstance
<
tensor
::
Tensor
>
(
input
))
{
auto
tensor_ptr
=
py
::
cast
<
tensor
::
TensorPtr
>
(
input
);
(
void
)
graph_info
.
append
(
tensor_ptr
->
GetShapeAndDataTypeInfo
()
+
"_"
);
}
for
(
const
auto
&
tensor
:
input_tensors
)
{
MS_EXCEPTION_IF_NULL
(
tensor
);
auto
tensor_shape
=
tensor
->
shape
();
(
void
)
std
::
for_each
(
tensor_shape
.
begin
(),
tensor_shape
.
end
(),
[
&
](
const
auto
&
dim
)
{
(
void
)
graph_info
.
append
(
std
::
to_string
(
dim
)
+
"_"
);
});
(
void
)
graph_info
.
append
(
std
::
to_string
(
tensor
->
data_type
())
+
"_"
);
}
// get prim and abstract info
MS_EXCEPTION_IF_NULL
(
op_exec_info
->
abstract
);
(
void
)
graph_info
.
append
(
std
::
to_string
((
uintptr_t
)(
op_exec_info
->
py_primitive
.
get
()))
+
"_"
+
op_exec_info
->
abstract
->
ToString
());
(
void
)
graph_info
.
append
(
op_exec_info
->
prim_id
+
"_"
);
// get attr info
auto
attr_map
=
op_exec_info
->
py_primitive
->
evaluate_added_attrs
();
for
(
const
auto
&
element
:
attr_map
)
{
(
void
)
graph_info
.
append
(
element
.
second
->
ToString
()
+
" "
);
}
const
auto
&
op_prim
=
op_exec_info
->
py_primitive
;
MS_EXCEPTION_IF_NULL
(
op_prim
);
const
auto
&
attr_map
=
op_prim
->
evaluate_added_attrs
();
(
void
)
std
::
for_each
(
attr_map
.
begin
(),
attr_map
.
end
(),
[
&
](
const
auto
&
element
)
{
(
void
)
graph_info
.
append
(
element
.
second
->
ToString
()
+
"_"
);
});
return
graph_info
;
}
...
...
tests/st/pynative/test_pynative_resnet50.py
浏览文件 @
1835ff33
...
...
@@ -428,7 +428,7 @@ def test_pynative_resnet50():
end_time
=
time
.
time
()
cost_time
=
end_time
-
start_time
print
(
"======step: "
,
step
,
" loss: "
,
loss_output
.
asnumpy
(),
" cost time: "
,
cost_time
)
if
step
>
1
and
cost_time
>
0.
32
:
if
step
>
1
and
cost_time
>
0.
21
:
exceed_num
=
exceed_num
+
1
assert
exceed_num
<
10
\ No newline at end of file
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录