Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MindSpore
mindspore
提交
a0956538
M
mindspore
项目概览
MindSpore
/
mindspore
通知
35
Star
15
Fork
15
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
a0956538
编写于
7月 13, 2020
作者:
K
kpy
提交者:
kuangpeiyu
8月 04, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
optimize infer in pynative mode
上级
61639d90
变更
13
展开全部
隐藏空白更改
内联
并排
Showing
13 changed file
with
236 addition
and
145 deletion
+236
-145
mindspore/ccsrc/pipeline/jit/action.cc
mindspore/ccsrc/pipeline/jit/action.cc
+1
-1
mindspore/ccsrc/pipeline/jit/resource.cc
mindspore/ccsrc/pipeline/jit/resource.cc
+1
-0
mindspore/ccsrc/pipeline/pynative/base.h
mindspore/ccsrc/pipeline/pynative/base.h
+3
-3
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
+168
-96
mindspore/ccsrc/pipeline/pynative/pynative_execute.h
mindspore/ccsrc/pipeline/pynative/pynative_execute.h
+19
-4
mindspore/ccsrc/utils/primitive_py.cc
mindspore/ccsrc/utils/primitive_py.cc
+1
-0
mindspore/core/ir/primitive.cc
mindspore/core/ir/primitive.cc
+25
-0
mindspore/core/ir/primitive.h
mindspore/core/ir/primitive.h
+13
-16
mindspore/core/utils/profile.cc
mindspore/core/utils/profile.cc
+1
-1
mindspore/ops/functional.py
mindspore/ops/functional.py
+1
-1
mindspore/ops/operations/array_ops.py
mindspore/ops/operations/array_ops.py
+1
-1
mindspore/ops/primitive.py
mindspore/ops/primitive.py
+1
-1
tests/ut/cpp/pynative/pynative_execute_test.cc
tests/ut/cpp/pynative/pynative_execute_test.cc
+1
-21
未找到文件。
mindspore/ccsrc/pipeline/jit/action.cc
浏览文件 @
a0956538
...
...
@@ -351,7 +351,7 @@ bool ExecuteAction(const ResourcePtr &res) {
}
auto
graph_id
=
res
->
results
()[
kOutput
].
cast
<
GraphId
>
();
std
::
shared_ptr
<
compile
::
Backend
>
bc_ptr
=
res
->
results
()[
kBackend
].
cast
<
std
::
shared_ptr
<
compile
::
Backend
>>
();
std
::
shared_ptr
<
compile
::
MsBackend
>
msbc_ptr
=
std
::
dynamic_pointer_cast
<
compile
::
MsBackend
>
(
bc_ptr
);
compile
::
MsBackend
*
msbc_ptr
=
std
::
dynamic_pointer_cast
<
compile
::
MsBackend
>
(
bc_ptr
).
get
(
);
MS_EXCEPTION_IF_NULL
(
msbc_ptr
);
compile
::
VmEvalFuncPtr
run
=
std
::
make_shared
<
compile
::
VmEvalFunc
>
([
msbc_ptr
,
graph_id
](
const
VectorRef
&
args
)
->
BaseRef
{
...
...
mindspore/ccsrc/pipeline/jit/resource.cc
浏览文件 @
a0956538
...
...
@@ -205,6 +205,7 @@ Resource::Resource(const py::object &obj)
Resource
::~
Resource
()
{
MS_LOG
(
DEBUG
)
<<
"Resource clear"
;
std
::
unordered_map
<
std
::
string
,
Any
>
().
swap
(
results_
);
// If exit normally, these global variables will be cleaned
// in Resource::Clean call by MsPipeline::Compile, but if exit with MS_LOGEXCEPTION,
// these global variables may not being cleaned, it may
...
...
mindspore/ccsrc/pipeline/pynative/base.h
浏览文件 @
a0956538
...
...
@@ -54,12 +54,12 @@ struct OpExecInfo {
AbstractBasePtr
abstract
;
ValuePtr
value
=
nullptr
;
py
::
tuple
op_inputs
;
py
::
tuple
inputs_mask
;
py
::
list
op_inputs
;
py
::
dict
op_attrs
;
std
::
vector
<
bool
>
inputs_mask
;
};
using
OpExecInfoPtr
=
std
::
shared_ptr
<
OpExecInfo
>
;
OpExecInfoPtr
GenerateOpExecInfo
(
const
py
::
args
&
args
,
py
::
list
*
const
out_args
);
OpExecInfoPtr
GenerateOpExecInfo
(
const
py
::
args
&
args
);
const
std
::
set
<
std
::
string
>
ignore_infer_prim
=
{
"make_ref"
,
"mixed_precision_cast"
};
}
// namespace pynative
...
...
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
浏览文件 @
a0956538
此差异已折叠。
点击以展开。
mindspore/ccsrc/pipeline/pynative/pynative_execute.h
浏览文件 @
a0956538
...
...
@@ -41,12 +41,20 @@ namespace py = pybind11;
using
ResourcePtr
=
std
::
shared_ptr
<
pipeline
::
Resource
>
;
using
GradOperationPtr
=
std
::
shared_ptr
<
prim
::
GradOperation
>
;
struct
PrimAbsInfo
{
abstract
::
AbstractBasePtr
abs
;
std
::
unordered_map
<
std
::
string
,
ValuePtr
>
attrs
;
};
using
AbstractListMap
=
std
::
unordered_map
<
abstract
::
AbstractBasePtrList
,
PrimAbsInfo
,
abstract
::
AbstractBasePtrListHasher
,
abstract
::
AbstractBasePtrListEqual
>
;
py
::
object
RunOpInVM
(
const
OpExecInfoPtr
&
op_exec_info
,
PynativeStatusCode
*
status
);
py
::
tuple
RunOp
(
const
py
::
args
&
args
);
py
::
tuple
ConvertInputs
(
const
PrimitivePyPtr
&
prim
,
const
py
::
list
&
py_args
,
py
::
tuple
*
const
out_args
,
py
::
list
*
const
out_args_list
);
void
ConvertInputs
(
const
PrimitivePyPtr
&
prim
,
const
py
::
list
&
py_args
,
py
::
tuple
*
const
out_args
,
py
::
list
*
const
out_args_list
);
void
ClearPyNativeSession
();
...
...
@@ -82,7 +90,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
void
ClearRes
();
bool
grad_flag
()
{
return
grad_flag_
;
}
void
set_grad_flag
(
bool
flag
)
{
grad_flag_
=
flag
;
}
AnfNodePtr
GetInput
(
const
py
::
object
&
obj
,
const
py
::
object
&
op_mask
);
AnfNodePtr
GetInput
(
const
py
::
object
&
obj
,
bool
op_mask
);
AnfNodePtr
GetObjNode
(
const
py
::
object
&
obj
);
FuncGraphPtr
curr_g
()
{
return
curr_g_
;
}
void
set_pyobj
(
FuncGraphPtr
g
,
const
std
::
string
obj
)
{
graph_info_map_
[
g
].
objects
.
push_back
(
obj
);
}
...
...
@@ -95,11 +103,14 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
void
set_obj_node_map
(
FuncGraphPtr
g
,
const
std
::
string
obj
,
AnfNodePtr
node
,
std
::
vector
<
int
>
index
)
{
graph_info_map_
[
g
].
obj_node_map
[
obj
]
=
std
::
make_pair
(
node
,
index
);
}
CNodePtr
MakeCNode
(
const
OpExecInfoPtr
&
op_exec_info
,
const
py
::
args
&
args
,
const
py
::
tuple
&
out
);
AnfNodePtr
MakeCNode
(
const
OpExecInfoPtr
&
op_exec_info
,
std
::
vector
<
bool
>
*
op_masks
,
abstract
::
AbstractBasePtrList
*
args_spec_list
);
void
MakeCNode
(
const
OpExecInfoPtr
&
op_exec_info
,
const
py
::
object
&
out
,
const
AnfNodePtr
&
cnode
);
ValuePtr
GetForwardValue
(
const
OpExecInfoPtr
&
op_exec_info
);
void
SaveOpForwardValue
(
const
OpExecInfoPtr
&
op_exec_info
,
const
ValuePtr
&
value
);
void
SaveForwardResult
(
const
CNodePtr
&
cnode
,
const
py
::
object
&
out
);
void
SaveAllResult
(
const
OpExecInfoPtr
&
op_exec_info
,
const
CNodePtr
&
cnode
,
const
py
::
tuple
&
out
);
py
::
object
Run
(
const
py
::
tuple
&
args
,
const
py
::
object
&
phase
);
void
Pushp
();
...
...
@@ -108,6 +119,8 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
size_t
arg_size
);
void
SetTupleOutput
(
const
py
::
object
&
obj
,
const
AnfNodePtr
&
cnode
,
std
::
vector
<
int
>
idx
);
AnfNodePtr
MakeValueNode
(
const
py
::
object
&
obj
,
const
std
::
string
&
obj_id
);
py
::
tuple
RunOpInner
(
const
py
::
args
&
args
);
py
::
tuple
RunOpInner
(
const
OpExecInfoPtr
&
op_exec_info
);
~
PynativeExecutor
();
...
...
@@ -123,10 +136,12 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
std
::
unordered_map
<
FuncGraphPtr
,
GraphInfo
>
graph_info_map_
;
std
::
unordered_map
<
std
::
string
,
ValuePtr
>
op_forward_map_
;
std
::
unordered_map
<
std
::
string
,
size_t
>
op_id_map_
;
std
::
unordered_map
<
std
::
string
,
abstract
::
AbstractBasePtr
>
node_abs_map_
;
std
::
stack
<
FuncGraphPtr
>
graph_p_
;
FuncGraphPtr
top_g_
;
FuncGraphPtr
df_builder_
;
FuncGraphPtr
curr_g_
;
std
::
unordered_map
<
std
::
string
,
AbstractListMap
>
prim_abs_list
;
};
using
PynativeExecutorPtr
=
std
::
shared_ptr
<
PynativeExecutor
>
;
...
...
mindspore/ccsrc/utils/primitive_py.cc
浏览文件 @
a0956538
...
...
@@ -220,6 +220,7 @@ REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) {
.
def
(
"add_attr"
,
&
PrimitivePy
::
AddPyAttr
,
"add primitive attr"
)
.
def
(
"get_attr_dict"
,
&
PrimitivePy
::
GetAttrDict
,
"get primitive attr"
)
.
def
(
"set_prim_type"
,
&
PrimitivePy
::
set_prim_type
,
"Set primitive type."
)
.
def
(
"set_is_const_value"
,
&
PrimitivePy
::
set_is_const_value
,
"Set primitive is const value."
)
.
def
(
"set_signatures"
,
&
PrimitivePy
::
set_signatures
,
"Set primitive inputs signature."
)
.
def
(
"register_hook"
,
&
PrimitivePy
::
set_hook
,
"Set primitive hook function."
)
.
def
(
"set_instance_name"
,
&
PrimitivePy
::
set_instance_name
,
"Set primitive instance name."
);
...
...
mindspore/core/ir/primitive.cc
浏览文件 @
a0956538
...
...
@@ -21,6 +21,31 @@
namespace
mindspore
{
static
std
::
string
MakeId
()
{
// Use atomic to make id generator thread safe.
static
std
::
atomic
<
uint64_t
>
last_id
{
1
};
return
"P"
+
std
::
to_string
(
last_id
.
fetch_add
(
1
,
std
::
memory_order_relaxed
));
}
Primitive
::
Primitive
(
const
std
::
string
&
name
,
const
bool
is_base
,
const
PrimType
prim_type
)
:
Named
(
name
),
is_base_
(
is_base
),
has_signature_
(
false
),
prim_type_
(
prim_type
),
record_evaluate_add_attr_
(
false
),
is_const_value_
(
false
),
id_
(
MakeId
())
{}
Primitive
::
Primitive
(
const
Primitive
&
prim
)
:
Named
(
prim
),
attrs_
(
prim
.
attrs_
),
instance_name_
(
prim
.
instance_name_
),
is_base_
(
prim
.
is_base_
),
has_signature_
(
prim
.
has_signature_
),
prim_type_
(
prim
.
prim_type_
),
record_evaluate_add_attr_
(
false
),
id_
(
prim
.
id_
)
{}
abstract
::
AbstractBasePtr
Primitive
::
ToAbstract
()
{
return
std
::
make_shared
<
abstract
::
PrimitiveAbstractClosure
>
(
shared_from_base
<
Primitive
>
(),
nullptr
);
}
...
...
mindspore/core/ir/primitive.h
浏览文件 @
a0956538
...
...
@@ -40,22 +40,8 @@ enum PrimType {
class
Primitive
:
public
Named
{
public:
explicit
Primitive
(
const
std
::
string
&
name
,
const
bool
is_base
=
true
,
const
PrimType
prim_type
=
kPrimTypeBuiltIn
)
:
Named
(
name
),
is_base_
(
is_base
),
has_signature_
(
false
),
prim_type_
(
prim_type
),
record_evaluate_add_attr_
(
false
)
{}
Primitive
(
const
Primitive
&
prim
)
:
Named
(
prim
),
attrs_
(
prim
.
attrs_
),
instance_name_
(
prim
.
instance_name_
),
is_base_
(
prim
.
is_base_
),
has_signature_
(
prim
.
has_signature_
),
prim_type_
(
prim
.
prim_type_
),
record_evaluate_add_attr_
(
false
)
{}
explicit
Primitive
(
const
std
::
string
&
name
,
const
bool
is_base
=
true
,
const
PrimType
prim_type
=
kPrimTypeBuiltIn
);
Primitive
(
const
Primitive
&
prim
);
MS_DECLARE_PARENT
(
Primitive
,
Named
);
abstract
::
AbstractBasePtr
ToAbstract
();
abstract
::
AbstractBasePtr
ToPrimAbstract
(
const
AnfNodePtr
&
anf_node
);
...
...
@@ -91,6 +77,12 @@ class Primitive : public Named {
const
std
::
unordered_map
<
std
::
string
,
ValuePtr
>
&
attrs
()
const
{
return
attrs_
;
}
const
std
::
unordered_map
<
std
::
string
,
ValuePtr
>
&
evaluate_added_attrs
()
const
{
return
evaluate_added_attrs_
;
}
void
set_evaluate_added_attrs
(
const
std
::
unordered_map
<
std
::
string
,
ValuePtr
>
&
attrs
)
{
for
(
auto
&
attr
:
attrs
)
{
MS_LOG
(
INFO
)
<<
" set evalu attrl "
<<
name
()
<<
attr
.
first
;
attrs_
[
attr
.
first
]
=
attr
.
second
;
}
}
// if Primitive has any attribute, for Primitives like scalar_add, return, etc, don't have any attribute.
bool
HasAttr
()
const
{
return
!
attrs_
.
empty
();
}
...
...
@@ -117,6 +109,9 @@ class Primitive : public Named {
bool
is_base
()
const
{
return
is_base_
;
}
virtual
BaseRef
RunHookFunction
(
const
VectorRef
&
args
)
const
{
MS_LOG
(
EXCEPTION
)
<<
"call a empty function!"
;
}
virtual
void
CopyHookFunction
(
const
PrimitivePtr
&
primitive
)
{
MS_LOG
(
EXCEPTION
)
<<
"call a empty function!"
;
}
void
set_is_const_value
(
bool
value
)
{
is_const_value_
=
value
;
}
bool
is_const_value
()
const
{
return
is_const_value_
;
}
std
::
string
id
()
const
{
return
id_
;
}
protected:
std
::
unordered_map
<
std
::
string
,
ValuePtr
>
attrs_
;
...
...
@@ -128,6 +123,8 @@ class Primitive : public Named {
bool
has_signature_
;
PrimType
prim_type_
;
bool
record_evaluate_add_attr_
;
bool
is_const_value_
;
std
::
string
id_
{
""
};
};
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
PrimitivePtr
&
p
)
{
...
...
mindspore/core/utils/profile.cc
浏览文件 @
a0956538
...
...
@@ -335,7 +335,7 @@ static void PrintTimeStat(std::ostringstream &oss, const TimeInfoGroup &group, c
void
MsProfile
::
Print
()
{
GetProfile
()
->
Print
();
std
::
vector
<
std
::
string
>
items
=
{
"substitution."
,
"renormalize."
,
"replace."
,
"match."
,
"func_graph_cloner_run."
,
"meta_graph."
,
"manager."
};
"func_graph_cloner_run."
,
"meta_graph."
,
"manager."
,
"pynative"
};
std
::
vector
<
TimeInfoGroup
>
groups
(
items
.
size
()
+
1
);
const
auto
&
stat
=
GetSingleton
().
time_stat_
;
// group all time infos
...
...
mindspore/ops/functional.py
浏览文件 @
a0956538
...
...
@@ -28,7 +28,7 @@ hastype = Primitive('hastype')
cast
=
P
.
Cast
()
dtype
=
P
.
DType
()
isconstant
=
Primitive
(
'is_constant'
)
isconstant
.
add_prim_attr
(
'const_value'
,
True
)
isconstant
.
set_is_const_value
(
True
)
issubclass_
=
P
.
IsSubClass
()
...
...
mindspore/ops/operations/array_ops.py
浏览文件 @
a0956538
...
...
@@ -1027,7 +1027,7 @@ class InvertPermutation(PrimitiveWithInfer):
@
prim_attr_register
def
__init__
(
self
):
"""init InvertPermutation"""
self
.
const_value
=
True
self
.
set_is_const_value
(
True
)
def
__infer__
(
self
,
x
):
x_shp
=
x
[
'shape'
]
...
...
mindspore/ops/primitive.py
浏览文件 @
a0956538
...
...
@@ -352,7 +352,7 @@ def constexpr(fn=None, get_instance=True, name=None):
def
__init__
(
self
):
op_name
=
name
if
name
else
fn
.
__name__
PrimitiveWithInfer
.
__init__
(
self
,
op_name
)
self
.
const_value
=
True
self
.
set_is_const_value
(
True
)
def
infer_value
(
self
,
*
args
):
return
fn
(
*
args
)
...
...
tests/ut/cpp/pynative/pynative_execute_test.cc
浏览文件 @
a0956538
...
...
@@ -65,27 +65,7 @@ OpExecInfoPtr ConstructOpExecInfo() {
py
::
none
py_none
;
py
::
args
args
=
py
::
make_tuple
(
conv_obj
,
op_name
,
op_inputs
);
py
::
list
args_input
=
args
[
PY_INPUTS
];
return
GenerateOpExecInfo
(
args
,
&
args_input
);
}
TEST_F
(
TestPynativeExecute
,
TestRunOpInVM
)
{
py
::
tuple
result
;
PynativeStatusCode
status
;
auto
op_exec_info_ptr
=
ConstructOpExecInfo
();
result
=
pynative
::
RunOpInVM
(
op_exec_info_ptr
,
&
status
);
ASSERT_EQ
(
status
,
PYNATIVE_SUCCESS
);
}
TEST_F
(
TestPynativeExecute
,
TestRunOp
)
{
py
::
none
py_none
;
auto
op_exec_info_ptr
=
ConstructOpExecInfo
();
py
::
tuple
outputs
=
pynative
::
RunOp
(
py
::
make_tuple
(
op_exec_info_ptr
->
py_primitive
,
op_exec_info_ptr
->
op_name
,
op_exec_info_ptr
->
op_inputs
));
if
(
outputs
.
size
()
==
0
)
{
FAIL
();
}
else
{
SUCCEED
();
}
return
GenerateOpExecInfo
(
args
);
}
TEST_F
(
TestPynativeExecute
,
TestCreateContext
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录