Commit e5c0e052
Authored June 19, 2020 by rick_sanchez

Fix codex and codereview.

Parent: 8867c67d
Showing 2 changed files with 106 additions and 74 deletions:

  mindspore/ccsrc/pynative/pynative_execute.cc  +102 -73
  mindspore/ccsrc/pynative/pynative_execute.h   +4 -1
mindspore/ccsrc/pynative/pynative_execute.cc
@@ -110,7 +110,40 @@ py::object GetTupleObj(const py::object &obj) {
   return obj_tuple;
 }
 
-py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tuple *out_args) {
+std::map<SignatureEnumDType, std::vector<size_t>> GetTypeIndex(const std::vector<SignatureEnumDType> &dtypes) {
+  std::map<SignatureEnumDType, std::vector<size_t>> type_indexes;
+  for (size_t i = 0; i < dtypes.size(); ++i) {
+    auto it = type_indexes.find(dtypes[i]);
+    if (it == type_indexes.end()) {
+      (void)type_indexes.insert(std::make_pair(dtypes[i], std::vector<size_t>{i}));
+    } else {
+      it->second.push_back(i);
+    }
+  }
+  return type_indexes;
+}
+
+std::map<SignatureEnumDType, size_t> GetDstType(const py::tuple &py_args,
+                                                const std::map<SignatureEnumDType, std::vector<size_t>> &type_indexes) {
+  std::map<SignatureEnumDType, size_t> dst_type;
+  for (auto it = type_indexes.begin(); it != type_indexes.end(); (void)++it) {
+    auto type = it->first;
+    auto indexes = it->second;
+    if (indexes.size() < 2) {
+      continue;
+    }
+    size_t m_index = indexes[0];
+    for (size_t i = 1; i < indexes.size(); ++i) {
+      if (py::isinstance<tensor::Tensor>(py_args[indexes[i]])) {
+        m_index = indexes[i];
+      }
+    }
+    (void)dst_type.insert(std::make_pair(type, m_index));
+  }
+  return dst_type;
+}
+
+py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tuple *const out_args) {
   auto &py_args = *out_args;
   py::tuple input_mask(args.size());
   for (size_t i = 0; i < args.size(); ++i) {
@@ -129,30 +162,8 @@ py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tu
   if (dtypes.size() == 0 || static_cast<int>(dtypes.size()) == empty_dtype_count) {
     return input_mask;
   }
-  std::map<SignatureEnumDType, std::vector<size_t>> type_indexs;
-  for (size_t i = 0; i < dtypes.size(); ++i) {
-    auto it = type_indexs.find(dtypes[i]);
-    if (it == type_indexs.end()) {
-      (void)type_indexs.insert(std::make_pair(dtypes[i], std::vector<size_t>{i}));
-    } else {
-      it->second.push_back(i);
-    }
-  }
-  std::map<SignatureEnumDType, size_t> dst_type;
-  for (auto it = type_indexs.begin(); it != type_indexs.end(); (void)++it) {
-    auto type = it->first;
-    auto indexs = it->second;
-    if (indexs.size() < 2) {
-      continue;
-    }
-    size_t m_index = indexs[0];
-    for (size_t i = 1; i < indexs.size(); ++i) {
-      if (py::isinstance<tensor::Tensor>(py_args[indexs[i]])) {
-        m_index = indexs[i];
-      }
-    }
-    (void)dst_type.insert(std::make_pair(type, m_index));
-  }
+  auto type_indexes = GetTypeIndex(dtypes);
+  auto dst_type = GetDstType(py_args, type_indexes);
   for (size_t i = 0; i < py_args.size(); ++i) {
     auto it = dst_type.find(dtypes[i]);
     if (it != dst_type.end() && it->second != i &&
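To make the behavior of the two extracted helpers concrete, here is a minimal self-contained sketch of the same grouping-and-selection logic. MockDType and the is_tensor vector are stand-ins invented purely for illustration; the real code keys on SignatureEnumDType and tests py::isinstance<tensor::Tensor> against a py::tuple of arguments.

// Illustrative sketch only: MockDType and is_tensor stand in for the real
// SignatureEnumDType and py::tuple used in pynative_execute.cc.
#include <cstddef>
#include <iostream>
#include <map>
#include <vector>

enum class MockDType { kInt, kFloat };

// Group argument positions by declared dtype (mirrors GetTypeIndex).
std::map<MockDType, std::vector<size_t>> GetTypeIndex(const std::vector<MockDType> &dtypes) {
  std::map<MockDType, std::vector<size_t>> type_indexes;
  for (size_t i = 0; i < dtypes.size(); ++i) {
    type_indexes[dtypes[i]].push_back(i);
  }
  return type_indexes;
}

// For each dtype shared by two or more arguments, pick the index of the last
// tensor argument as the implicit-conversion target (mirrors GetDstType).
std::map<MockDType, size_t> GetDstType(const std::vector<bool> &is_tensor,
                                       const std::map<MockDType, std::vector<size_t>> &type_indexes) {
  std::map<MockDType, size_t> dst_type;
  for (const auto &entry : type_indexes) {
    const auto &indexes = entry.second;
    if (indexes.size() < 2) continue;
    size_t m_index = indexes[0];
    for (size_t i = 1; i < indexes.size(); ++i) {
      if (is_tensor[indexes[i]]) m_index = indexes[i];
    }
    dst_type[entry.first] = m_index;
  }
  return dst_type;
}

int main() {
  // Three args share kFloat; args 0 and 2 are tensors, arg 1 is a scalar.
  std::vector<MockDType> dtypes = {MockDType::kFloat, MockDType::kFloat, MockDType::kFloat, MockDType::kInt};
  std::vector<bool> is_tensor = {true, false, true, false};
  auto dst = GetDstType(is_tensor, GetTypeIndex(dtypes));
  std::cout << dst.at(MockDType::kFloat) << "\n";  // prints 2: last tensor in the kFloat group wins
  return 0;
}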
@@ -542,28 +553,7 @@ AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj) {
   return curr_g_->NewCNode(tuple_get_item_inputs);
 }
 
-py::tuple RunOp(const py::args &args) {
-  MS_LOG(DEBUG) << "RunOp start" << args.size();
-  py::object result;
-  // returns a null py::tuple on error
-  py::tuple err_ret(0);
-  PynativeStatusCode status = PYNATIVE_UNKNOWN_STATE;
-  OpExecInfoPtr op_exec_info = GenerateOpExecInfo(args);
-  MS_EXCEPTION_IF_NULL(op_exec_info);
-  if (op_exec_info->abstract != nullptr) {
-    py::dict output = abstract::ConvertAbstractToPython(op_exec_info->abstract);
-    if (!output["value"].is_none()) {
-      py::tuple value_ret(1);
-      value_ret[0] = output["value"];
-      return value_ret;
-    }
-    if (py::hasattr(op_exec_info->py_primitive->GetPyObj(), "const_value")) {
-      py::tuple value_ret(1);
-      value_ret[0] = "";
-      return value_ret;
-    }
-  }
+py::tuple RunOp(const OpExecInfoPtr &op_exec_info, const py::args &args) {
+  MS_LOG(INFO) << "RunOp start, op name is: " << op_exec_info->op_name;
   mindspore::parse::python_adapter::set_python_env_flag(true);
   MsBackendPolicy backend_policy;
@@ -584,7 +574,10 @@ py::tuple RunOp(const py::args &args) {
   if (vm_operators.find(op_exec_info->op_name) != vm_operators.end()) {
     backend_policy = kMsBackendVmOnly;
   }
-  result = RunOpWithBackendPolicy(backend_policy, op_exec_info, &status);
+  PynativeStatusCode status = PYNATIVE_UNKNOWN_STATE;
+  // returns a null py::tuple on error
+  py::tuple err_ret(0);
+  py::object result = RunOpWithBackendPolicy(backend_policy, op_exec_info, &status);
   if (status != PYNATIVE_SUCCESS) {
     MS_LOG(ERROR) << "Failed to run " << op_exec_info->op_name;
     return err_ret;
@@ -599,6 +592,26 @@ py::tuple RunOp(const py::args &args) {
   return result;
 }
 
+py::tuple RunOp(const py::args &args) {
+  MS_LOG(DEBUG) << "RunOp start" << args.size();
+  OpExecInfoPtr op_exec_info = GenerateOpExecInfo(args);
+  MS_EXCEPTION_IF_NULL(op_exec_info);
+  if (op_exec_info->abstract != nullptr) {
+    py::dict output = abstract::ConvertAbstractToPython(op_exec_info->abstract);
+    if (!output["value"].is_none()) {
+      py::tuple value_ret(1);
+      value_ret[0] = output["value"];
+      return value_ret;
+    }
+    if (py::hasattr(op_exec_info->py_primitive->GetPyObj(), "const_value")) {
+      py::tuple value_ret(1);
+      value_ret[0] = "";
+      return value_ret;
+    }
+  }
+  return RunOp(op_exec_info, args);
+}
+
 void ClearPyNativeSession() { session = nullptr; }
 
 PynativeExecutor::~PynativeExecutor() { ClearRes(); }
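The net effect of the three hunks above is a two-stage RunOp: the py::args overload answers directly when inference has already produced a constant value, and otherwise delegates to the overload that selects a backend and executes. A rough dependency-free mock of that shape follows; every type below is a placeholder for illustration, not a MindSpore API.

#include <iostream>
#include <optional>
#include <string>

// Stand-in for OpExecInfo: carries the op name and, when inference already
// produced a constant, its value (both fields are illustrative assumptions).
struct MockOpExecInfo {
  std::string op_name;
  std::optional<int> inferred_value;
};

// Execution overload: the only place that talks to a backend.
int RunOp(const MockOpExecInfo &info) {
  std::cout << "executing " << info.op_name << " on backend\n";
  return 42;  // pretend backend result
}

// Front-door overload: short-circuits when inference already has the answer,
// mirroring the `abstract != nullptr` fast path, then delegates.
int RunOp(const std::string &op_name, std::optional<int> inferred) {
  MockOpExecInfo info{op_name, inferred};
  if (info.inferred_value.has_value()) {
    return *info.inferred_value;  // no backend call needed
  }
  return RunOp(info);
}

int main() {
  std::cout << RunOp("Add", 7) << "\n";                // fast path: prints 7
  std::cout << RunOp("MatMul", std::nullopt) << "\n";  // executes: prints 42
  return 0;
}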
@@ -732,7 +745,11 @@ void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, c
       return;
     }
   }
+  EndGraphByOutId(out_id, cell, out, args);
+}
+
+void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::object &cell, const py::object &out,
+                                       const py::args &args) {
   AnfNodePtr output_node;
   if (graph_info_map_[curr_g_].param_map.count(out_id)) {
     output_node = graph_info_map_[curr_g_].param_map[out_id];
@@ -776,27 +793,7 @@ void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, c
   }
 }
 
-void PynativeExecutor::GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights,
-                               const py::args &args) {
-  MS_LOG(INFO) << "GradNet start" << args.size();
-  std::size_t size = args.size();
-  auto cell_id = GetId(cell);
-  if (graph_map_.count(cell_id) != 0) {
-    MS_LOG(DEBUG) << "GradNet already compiled";
-    return;
-  }
-  MS_LOG(DEBUG) << "GradNet first compiled";
-  std::vector<AnfNodePtr> new_params;
-  for (size_t i = 0; i < size; i++) {
-    ParameterPtr p = std::make_shared<Parameter>(df_builder_);
-    new_params.push_back(p);
-  }
-  MS_LOG(DEBUG) << "GradNet start weight size" << df_builder_->parameters().size();
-  new_params.insert(new_params.end(), df_builder_->parameters().begin(), df_builder_->parameters().end());
-  df_builder_->set_parameters(new_params);
-  resource_->manager()->SetParameters(df_builder_, new_params);
+std::vector<AnfNodePtr> PynativeExecutor::GetWeightsArgs(const py::object &weights) {
   std::vector<AnfNodePtr> w_args;
   if (py::hasattr(weights, "__parameter_tuple__")) {
     auto tuple = weights.cast<py::tuple>();
@@ -821,12 +818,12 @@ void PynativeExecutor::GradNet(const GradOperationPtr &grad, const py::object &c
   } else {
     MS_LOG(EXCEPTION) << "training not paramter_tuple";
   }
-  MS_EXCEPTION_IF_NULL(resource_->func_graph());
-  auto g = GradGraph(resource_->func_graph(), grad, w_args, size);
-  resource_->set_func_graph(g);
+  return w_args;
+}
+
+// get the parameters items and add the value to args_spec
+abstract::AbstractBasePtrList PynativeExecutor::GetArgsSpec(const py::args &args) {
+  abstract::AbstractBasePtrList args_spec;
+  std::size_t size = args.size();
   for (std::size_t i = 0; i < size; i++) {
     ValuePtr converted = nullptr;
     bool succ = parse::ConvertData(args[i], &converted);
@@ -852,6 +849,38 @@ void PynativeExecutor::GradNet(const GradOperationPtr &grad, const py::object &c
       param_node->set_abstract(ptr);
     }
   }
+  return args_spec;
+}
+
+void PynativeExecutor::GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights,
+                               const py::args &args) {
+  MS_LOG(INFO) << "GradNet start" << args.size();
+  std::size_t size = args.size();
+  auto cell_id = GetId(cell);
+  if (graph_map_.count(cell_id) != 0) {
+    MS_LOG(DEBUG) << "GradNet already compiled";
+    return;
+  }
+  MS_LOG(DEBUG) << "GradNet first compiled";
+  std::vector<AnfNodePtr> new_params;
+  for (size_t i = 0; i < size; i++) {
+    ParameterPtr p = std::make_shared<Parameter>(df_builder_);
+    new_params.push_back(p);
+  }
+  MS_LOG(DEBUG) << "GradNet start weight size" << df_builder_->parameters().size();
+  new_params.insert(new_params.end(), df_builder_->parameters().begin(), df_builder_->parameters().end());
+  df_builder_->set_parameters(new_params);
+  resource_->manager()->SetParameters(df_builder_, new_params);
+
+  std::vector<AnfNodePtr> w_args = GetWeightsArgs(weights);
+  MS_EXCEPTION_IF_NULL(resource_->func_graph());
+  auto g = GradGraph(resource_->func_graph(), grad, w_args, size);
+  resource_->set_func_graph(g);
+
+  // get the parameters items and add the value to args_spec
+  abstract::AbstractBasePtrList args_spec = GetArgsSpec(args);
+  MS_LOG(DEBUG) << "Args_spec size" << args_spec.size();
+  resource_->set_args_spec(args_spec);
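After this hunk, GradNet reads as a pipeline over the newly extracted helpers: allocate fresh parameters, resolve weight nodes with GetWeightsArgs, build the grad graph, then derive argument specs with GetArgsSpec. A compressed skeleton of that control flow, with every type replaced by a placeholder invented for illustration:

#include <iostream>
#include <string>
#include <vector>

using Node = std::string;  // placeholder for AnfNodePtr
using Spec = std::string;  // placeholder for abstract::AbstractBasePtr

// Placeholder helpers mirroring the extracted methods.
std::vector<Node> GetWeightsArgs(const std::vector<std::string> &weights) {
  return std::vector<Node>(weights.begin(), weights.end());
}

std::vector<Spec> GetArgsSpec(const std::vector<int> &args) {
  std::vector<Spec> spec;
  for (int a : args) spec.push_back("abstract(" + std::to_string(a) + ")");
  return spec;
}

// GradNet's post-refactor shape: each stage is one helper call.
void GradNet(const std::vector<std::string> &weights, const std::vector<int> &args) {
  std::vector<Node> w_args = GetWeightsArgs(weights);                 // 1. weight nodes
  std::cout << "grad graph over " << w_args.size() << " weights\n";  // 2. build graph (elided)
  std::vector<Spec> args_spec = GetArgsSpec(args);                    // 3. argument specs
  std::cout << "args_spec size " << args_spec.size() << "\n";
}

int main() {
  GradNet({"w1", "w2"}, {3, 4});
  return 0;
}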
mindspore/ccsrc/pynative/pynative_execute.h
@@ -44,7 +44,7 @@ py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
 py::tuple RunOp(const py::args &args);
 
-py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &py_args, py::tuple *out_args);
+py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &py_args, py::tuple *const out_args);
 
 void ClearPyNativeSession();
@@ -67,6 +67,9 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
   }
   void NewGraph(const py::object &cell, const py::args &args);
   void EndGraph(const py::object &cell, const py::object &out, const py::args &args);
+  void EndGraphByOutId(const std::string &out_id, const py::object &cell, const py::object &out, const py::args &args);
+  std::vector<AnfNodePtr> GetWeightsArgs(const py::object &weights);
+  abstract::AbstractBasePtrList GetArgsSpec(const py::args &args);
   void GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, const py::args &args);
   void Clear(const std::string &flag = "");
   void Clean();