magicwindyyd / mindspore (forked from MindSpore / mindspore)
Commit a0956538
Authored Jul 13, 2020 by kpy
Committed by kuangpeiyu on Aug 04, 2020
optimize infer in pynative mode
Parent: 61639d90
Showing 13 changed files with 236 additions and 145 deletions (+236, -145)
mindspore/ccsrc/pipeline/jit/action.cc                    +1   -1
mindspore/ccsrc/pipeline/jit/resource.cc                  +1   -0
mindspore/ccsrc/pipeline/pynative/base.h                  +3   -3
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc     +168 -96
mindspore/ccsrc/pipeline/pynative/pynative_execute.h      +19  -4
mindspore/ccsrc/utils/primitive_py.cc                     +1   -0
mindspore/core/ir/primitive.cc                            +25  -0
mindspore/core/ir/primitive.h                             +13  -16
mindspore/core/utils/profile.cc                           +1   -1
mindspore/ops/functional.py                               +1   -1
mindspore/ops/operations/array_ops.py                     +1   -1
mindspore/ops/primitive.py                                +1   -1
tests/ut/cpp/pynative/pynative_execute_test.cc            +1   -21
mindspore/ccsrc/pipeline/jit/action.cc

@@ -351,7 +351,7 @@ bool ExecuteAction(const ResourcePtr &res) {
   }
   auto graph_id = res->results()[kOutput].cast<GraphId>();
   std::shared_ptr<compile::Backend> bc_ptr = res->results()[kBackend].cast<std::shared_ptr<compile::Backend>>();
-  std::shared_ptr<compile::MsBackend> msbc_ptr = std::dynamic_pointer_cast<compile::MsBackend>(bc_ptr);
+  compile::MsBackend *msbc_ptr = std::dynamic_pointer_cast<compile::MsBackend>(bc_ptr).get();
   MS_EXCEPTION_IF_NULL(msbc_ptr);
   compile::VmEvalFuncPtr run =
     std::make_shared<compile::VmEvalFunc>([msbc_ptr, graph_id](const VectorRef &args) -> BaseRef {
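The change above keeps ownership of the backend in bc_ptr and captures only a raw pointer in the VM closure, so the lambda no longer extends the backend's lifetime. A minimal stand-alone sketch of that difference, with Backend and Eval as illustrative stand-ins rather than MindSpore types:

#include <functional>
#include <iostream>
#include <memory>

struct Backend {
  int Eval(int x) const { return x * 2; }
};

int main() {
  auto owner = std::make_shared<Backend>();

  // Capturing the shared_ptr makes the closure a co-owner of the backend.
  std::function<int(int)> run_shared = [owner](int x) { return owner->Eval(x); };
  std::cout << "use_count after shared capture: " << owner.use_count() << "\n";  // 2

  // Capturing the raw pointer leaves ownership solely with 'owner';
  // the closure must not outlive it.
  Backend *raw = owner.get();
  std::function<int(int)> run_raw = [raw](int x) { return raw->Eval(x); };
  std::cout << "use_count after raw capture: " << owner.use_count() << "\n";  // still 2

  std::cout << run_shared(3) << " " << run_raw(3) << "\n";
  return 0;
}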
mindspore/ccsrc/pipeline/jit/resource.cc

@@ -205,6 +205,7 @@ Resource::Resource(const py::object &obj)
 Resource::~Resource() {
   MS_LOG(DEBUG) << "Resource clear";
+  std::unordered_map<std::string, Any>().swap(results_);
   // If exit normally, these global variables will be cleaned
   // in Resource::Clean call by MsPipeline::Compile, but if exit with MS_LOGEXCEPTION,
   // these global variables may not being cleaned, it may
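The added line uses the swap-with-an-empty-temporary idiom to force results_ to release its storage in the destructor. A small self-contained sketch of the idiom, using a plain string payload instead of MindSpore's Any:

#include <iostream>
#include <string>
#include <unordered_map>

int main() {
  std::unordered_map<std::string, std::string> results;
  results["kOutput"] = "graph_1";
  results["kBackend"] = "ms_backend";

  // clear() erases the elements but may keep the bucket array allocated;
  // swapping with a default-constructed temporary releases everything at once.
  std::unordered_map<std::string, std::string>().swap(results);

  std::cout << "size after swap: " << results.size()
            << ", buckets: " << results.bucket_count() << "\n";
  return 0;
}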
mindspore/ccsrc/pipeline/pynative/base.h

@@ -54,12 +54,12 @@ struct OpExecInfo {
   AbstractBasePtr abstract;
   ValuePtr value = nullptr;
-  py::tuple op_inputs;
-  py::tuple inputs_mask;
+  py::list op_inputs;
   py::dict op_attrs;
+  std::vector<bool> inputs_mask;
 };
 using OpExecInfoPtr = std::shared_ptr<OpExecInfo>;
-OpExecInfoPtr GenerateOpExecInfo(const py::args &args, py::list *const out_args);
+OpExecInfoPtr GenerateOpExecInfo(const py::args &args);
 const std::set<std::string> ignore_infer_prim = {"make_ref", "mixed_precision_cast"};
 }  // namespace pynative
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc

(This diff is collapsed in the original view and is not expanded here; +168, -96.)
mindspore/ccsrc/pipeline/pynative/pynative_execute.h

@@ -41,12 +41,20 @@ namespace py = pybind11;
 using ResourcePtr = std::shared_ptr<pipeline::Resource>;
 using GradOperationPtr = std::shared_ptr<prim::GradOperation>;
+struct PrimAbsInfo {
+  abstract::AbstractBasePtr abs;
+  std::unordered_map<std::string, ValuePtr> attrs;
+};
+using AbstractListMap = std::unordered_map<abstract::AbstractBasePtrList, PrimAbsInfo,
+                                           abstract::AbstractBasePtrListHasher, abstract::AbstractBasePtrListEqual>;
 py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status);
 py::tuple RunOp(const py::args &args);
-py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &py_args, py::tuple *const out_args,
+void ConvertInputs(const PrimitivePyPtr &prim, const py::list &py_args, py::tuple *const out_args,
                    py::list *const out_args_list);
 void ClearPyNativeSession();

@@ -82,7 +90,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
   void ClearRes();
   bool grad_flag() { return grad_flag_; }
   void set_grad_flag(bool flag) { grad_flag_ = flag; }
-  AnfNodePtr GetInput(const py::object &obj, const py::object &op_mask);
+  AnfNodePtr GetInput(const py::object &obj, bool op_mask);
   AnfNodePtr GetObjNode(const py::object &obj);
   FuncGraphPtr curr_g() { return curr_g_; }
   void set_pyobj(FuncGraphPtr g, const std::string obj) { graph_info_map_[g].objects.push_back(obj); }

@@ -95,11 +103,14 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
   void set_obj_node_map(FuncGraphPtr g, const std::string obj, AnfNodePtr node, std::vector<int> index) {
     graph_info_map_[g].obj_node_map[obj] = std::make_pair(node, index);
   }
-  CNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, const py::args &args, const py::tuple &out);
+  AnfNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<bool> *op_masks,
+                       abstract::AbstractBasePtrList *args_spec_list);
+  void MakeCNode(const OpExecInfoPtr &op_exec_info, const py::object &out, const AnfNodePtr &cnode);
   ValuePtr GetForwardValue(const OpExecInfoPtr &op_exec_info);
   void SaveOpForwardValue(const OpExecInfoPtr &op_exec_info, const ValuePtr &value);
   void SaveForwardResult(const CNodePtr &cnode, const py::object &out);
   void SaveAllResult(const OpExecInfoPtr &op_exec_info, const CNodePtr &cnode, const py::tuple &out);
   py::object Run(const py::tuple &args, const py::object &phase);
   void Pushp();

@@ -108,6 +119,8 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
                    size_t arg_size);
   void SetTupleOutput(const py::object &obj, const AnfNodePtr &cnode, std::vector<int> idx);
   AnfNodePtr MakeValueNode(const py::object &obj, const std::string &obj_id);
-  py::tuple RunOpInner(const py::args &args);
+  py::tuple RunOpInner(const OpExecInfoPtr &op_exec_info);
   ~PynativeExecutor();

@@ -123,10 +136,12 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
   std::unordered_map<FuncGraphPtr, GraphInfo> graph_info_map_;
   std::unordered_map<std::string, ValuePtr> op_forward_map_;
   std::unordered_map<std::string, size_t> op_id_map_;
+  std::unordered_map<std::string, abstract::AbstractBasePtr> node_abs_map_;
   std::stack<FuncGraphPtr> graph_p_;
   FuncGraphPtr top_g_;
   FuncGraphPtr df_builder_;
   FuncGraphPtr curr_g_;
+  std::unordered_map<std::string, AbstractListMap> prim_abs_list;
 };
 using PynativeExecutorPtr = std::shared_ptr<PynativeExecutor>;
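The new PrimAbsInfo / AbstractListMap members give PynativeExecutor a per-primitive cache of inference results keyed by the abstract signature of the arguments, which is what lets repeated pynative ops skip re-running infer. A rough, self-contained sketch of that caching shape; the Abstract type, the hasher, and the "infer" step are simplified stand-ins, not the MindSpore definitions:

#include <cstddef>
#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Stand-in for an abstract value: just a dtype/shape summary string.
using Abstract = std::string;
using AbstractList = std::vector<Abstract>;

struct AbstractListHasher {
  std::size_t operator()(const AbstractList &list) const {
    std::size_t seed = list.size();
    for (const auto &a : list) {
      seed ^= std::hash<Abstract>()(a) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
    }
    return seed;
  }
};

struct PrimAbsInfo {
  Abstract abs;  // cached output abstract
};

using AbstractListMap = std::unordered_map<AbstractList, PrimAbsInfo, AbstractListHasher>;

// One cache per primitive name, keyed by the abstract signature of its inputs.
std::unordered_map<std::string, AbstractListMap> prim_abs_list;

Abstract InferOrLookup(const std::string &prim, const AbstractList &args) {
  auto &cache = prim_abs_list[prim];
  auto it = cache.find(args);
  if (it != cache.end()) {
    return it->second.abs;  // cache hit: skip the infer step entirely
  }
  Abstract out = "infer(" + prim + ")";  // pretend this is the expensive infer call
  cache[args] = PrimAbsInfo{out};
  return out;
}

int main() {
  AbstractList sig = {"Tensor[f32,(2,3)]", "Tensor[f32,(2,3)]"};
  std::cout << InferOrLookup("Add", sig) << "\n";  // miss: runs "infer" and stores the result
  std::cout << InferOrLookup("Add", sig) << "\n";  // hit: returns the cached abstract
  return 0;
}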
mindspore/ccsrc/utils/primitive_py.cc

@@ -220,6 +220,7 @@ REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) {
       .def("add_attr", &PrimitivePy::AddPyAttr, "add primitive attr")
       .def("get_attr_dict", &PrimitivePy::GetAttrDict, "get primitive attr")
       .def("set_prim_type", &PrimitivePy::set_prim_type, "Set primitive type.")
+      .def("set_is_const_value", &PrimitivePy::set_is_const_value, "Set primitive is const value.")
       .def("set_signatures", &PrimitivePy::set_signatures, "Set primitive inputs signature.")
       .def("register_hook", &PrimitivePy::set_hook, "Set primitive hook function.")
       .def("set_instance_name", &PrimitivePy::set_instance_name, "Set primitive instance name.");
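The added binding follows the usual pybind11 chained .def pattern. A minimal stand-alone sketch of that pattern, assuming a hypothetical PrimitiveStub class and module name rather than the real PrimitivePy binding:

#include <pybind11/pybind11.h>

namespace py = pybind11;

// Illustrative stand-in for the bound C++ class.
class PrimitiveStub {
 public:
  void set_is_const_value(bool value) { is_const_value_ = value; }
  bool is_const_value() const { return is_const_value_; }

 private:
  bool is_const_value_ = false;
};

PYBIND11_MODULE(primitive_stub, m) {
  py::class_<PrimitiveStub>(m, "PrimitiveStub")
      .def(py::init<>())
      .def("set_is_const_value", &PrimitiveStub::set_is_const_value, "Set primitive is const value.")
      .def("is_const_value", &PrimitiveStub::is_const_value, "Get primitive is const value.");
}

From Python the new method is then callable as primitive_stub.PrimitiveStub().set_is_const_value(True), which is the shape of call the Python-side changes below rely on.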
mindspore/core/ir/primitive.cc

@@ -21,6 +21,31 @@
 namespace mindspore {
+static std::string MakeId() {
+  // Use atomic to make id generator thread safe.
+  static std::atomic<uint64_t> last_id{1};
+  return "P" + std::to_string(last_id.fetch_add(1, std::memory_order_relaxed));
+}
+
+Primitive::Primitive(const std::string &name, const bool is_base, const PrimType prim_type)
+    : Named(name),
+      is_base_(is_base),
+      has_signature_(false),
+      prim_type_(prim_type),
+      record_evaluate_add_attr_(false),
+      is_const_value_(false),
+      id_(MakeId()) {}
+
+Primitive::Primitive(const Primitive &prim)
+    : Named(prim),
+      attrs_(prim.attrs_),
+      instance_name_(prim.instance_name_),
+      is_base_(prim.is_base_),
+      has_signature_(prim.has_signature_),
+      prim_type_(prim.prim_type_),
+      record_evaluate_add_attr_(false),
+      id_(prim.id_) {}
+
 abstract::AbstractBasePtr Primitive::ToAbstract() {
   return std::make_shared<abstract::PrimitiveAbstractClosure>(shared_from_base<Primitive>(), nullptr);
 }
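MakeId() gives every Primitive instance a process-unique id via an atomic counter; in the pynative path such an id can serve as a stable key for the caches above. A small usage sketch that reuses the same generator shape and checks uniqueness under concurrency; the surrounding harness is illustrative, not part of the MindSpore build:

#include <atomic>
#include <cstdint>
#include <iostream>
#include <set>
#include <string>
#include <thread>
#include <vector>

// Same shape as the generator added in primitive.cc.
static std::string MakeId() {
  static std::atomic<uint64_t> last_id{1};
  return "P" + std::to_string(last_id.fetch_add(1, std::memory_order_relaxed));
}

int main() {
  std::vector<std::vector<std::string>> per_thread(4);
  std::vector<std::thread> workers;
  for (int t = 0; t < 4; ++t) {
    workers.emplace_back([&per_thread, t]() {
      // Each thread appends only to its own vector, so no extra locking is needed.
      for (int i = 0; i < 1000; ++i) {
        per_thread[t].push_back(MakeId());
      }
    });
  }
  for (auto &w : workers) {
    w.join();
  }

  std::set<std::string> unique;
  for (const auto &ids : per_thread) {
    unique.insert(ids.begin(), ids.end());
  }
  std::cout << "generated 4000 ids, unique: " << unique.size() << "\n";  // expect 4000
  return 0;
}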
mindspore/core/ir/primitive.h

@@ -40,22 +40,8 @@ enum PrimType {
 class Primitive : public Named {
  public:
-  explicit Primitive(const std::string &name, const bool is_base = true, const PrimType prim_type = kPrimTypeBuiltIn)
-      : Named(name),
-        is_base_(is_base),
-        has_signature_(false),
-        prim_type_(prim_type),
-        record_evaluate_add_attr_(false) {}
-  Primitive(const Primitive &prim)
-      : Named(prim),
-        attrs_(prim.attrs_),
-        instance_name_(prim.instance_name_),
-        is_base_(prim.is_base_),
-        has_signature_(prim.has_signature_),
-        prim_type_(prim.prim_type_),
-        record_evaluate_add_attr_(false) {}
+  explicit Primitive(const std::string &name, const bool is_base = true, const PrimType prim_type = kPrimTypeBuiltIn);
+  Primitive(const Primitive &prim);
   MS_DECLARE_PARENT(Primitive, Named);
   abstract::AbstractBasePtr ToAbstract();
   abstract::AbstractBasePtr ToPrimAbstract(const AnfNodePtr &anf_node);

@@ -91,6 +77,12 @@ class Primitive : public Named {
   const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; }
   const std::unordered_map<std::string, ValuePtr> &evaluate_added_attrs() const { return evaluate_added_attrs_; }
+  void set_evaluate_added_attrs(const std::unordered_map<std::string, ValuePtr> &attrs) {
+    for (auto &attr : attrs) {
+      MS_LOG(INFO) << " set evalu attrl " << name() << attr.first;
+      attrs_[attr.first] = attr.second;
+    }
+  }
   // if Primitive has any attribute, for Primitives like scalar_add, return, etc, don't have any attribute.
   bool HasAttr() const { return !attrs_.empty(); }

@@ -117,6 +109,9 @@ class Primitive : public Named {
   bool is_base() const { return is_base_; }
   virtual BaseRef RunHookFunction(const VectorRef &args) const { MS_LOG(EXCEPTION) << "call a empty function!"; }
   virtual void CopyHookFunction(const PrimitivePtr &primitive) { MS_LOG(EXCEPTION) << "call a empty function!"; }
+  void set_is_const_value(bool value) { is_const_value_ = value; }
+  bool is_const_value() const { return is_const_value_; }
+  std::string id() const { return id_; }
  protected:
   std::unordered_map<std::string, ValuePtr> attrs_;

@@ -128,6 +123,8 @@ class Primitive : public Named {
   bool has_signature_;
   PrimType prim_type_;
   bool record_evaluate_add_attr_;
+  bool is_const_value_;
+  std::string id_{""};
 };
 inline std::ostream &operator<<(std::ostream &os, const PrimitivePtr &p) {
mindspore/core/utils/profile.cc

@@ -335,7 +335,7 @@ static void PrintTimeStat(std::ostringstream &oss, const TimeInfoGroup &group, c
 void MsProfile::Print() {
   GetProfile()->Print();
   std::vector<std::string> items = {"substitution.", "renormalize.", "replace.", "match.",
-                                    "func_graph_cloner_run.", "meta_graph.", "manager."};
+                                    "func_graph_cloner_run.", "meta_graph.", "manager.", "pynative"};
   std::vector<TimeInfoGroup> groups(items.size() + 1);
   const auto &stat = GetSingleton().time_stat_;
   // group all time infos
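The only change here appends "pynative" to the prefix list that MsProfile::Print uses to bucket timing entries, so pynative-tagged timers get their own group in the report. A rough sketch of that prefix-bucketing idea; the stat names and values are invented for illustration and this is not the MsProfile implementation:

#include <iostream>
#include <map>
#include <string>
#include <utility>
#include <vector>

int main() {
  // Prefixes used to bucket timing entries, mirroring the list in MsProfile::Print.
  std::vector<std::string> items = {"substitution.", "renormalize.", "replace.", "match.",
                                    "func_graph_cloner_run.", "meta_graph.", "manager.", "pynative"};
  std::vector<std::pair<std::string, double>> stats = {
      {"pynative.infer", 1.5}, {"renormalize.infer", 3.0}, {"unmatched.step", 0.2}};

  std::map<std::string, double> grouped;
  for (const auto &entry : stats) {
    std::string bucket = "other";
    for (const auto &prefix : items) {
      if (entry.first.compare(0, prefix.size(), prefix) == 0) {
        bucket = prefix;  // first matching prefix wins
        break;
      }
    }
    grouped[bucket] += entry.second;
  }
  for (const auto &g : grouped) {
    std::cout << g.first << ": " << g.second << "s\n";
  }
  return 0;
}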
mindspore/ops/functional.py

@@ -28,7 +28,7 @@ hastype = Primitive('hastype')
 cast = P.Cast()
 dtype = P.DType()
 isconstant = Primitive('is_constant')
-isconstant.add_prim_attr('const_value', True)
+isconstant.set_is_const_value(True)
 issubclass_ = P.IsSubClass()
mindspore/ops/operations/array_ops.py

@@ -1027,7 +1027,7 @@ class InvertPermutation(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self):
         """init InvertPermutation"""
-        self.const_value = True
+        self.set_is_const_value(True)

     def __infer__(self, x):
         x_shp = x['shape']
mindspore/ops/primitive.py

@@ -352,7 +352,7 @@ def constexpr(fn=None, get_instance=True, name=None):
         def __init__(self):
             op_name = name if name else fn.__name__
             PrimitiveWithInfer.__init__(self, op_name)
-            self.const_value = True
+            self.set_is_const_value(True)

         def infer_value(self, *args):
             return fn(*args)
tests/ut/cpp/pynative/pynative_execute_test.cc

@@ -65,27 +65,7 @@ OpExecInfoPtr ConstructOpExecInfo() {
   py::none py_none;
   py::args args = py::make_tuple(conv_obj, op_name, op_inputs);
   py::list args_input = args[PY_INPUTS];
-  return GenerateOpExecInfo(args, &args_input);
+  return GenerateOpExecInfo(args);
 }
-
-TEST_F(TestPynativeExecute, TestRunOpInVM) {
-  py::tuple result;
-  PynativeStatusCode status;
-  auto op_exec_info_ptr = ConstructOpExecInfo();
-  result = pynative::RunOpInVM(op_exec_info_ptr, &status);
-  ASSERT_EQ(status, PYNATIVE_SUCCESS);
-}
-
-TEST_F(TestPynativeExecute, TestRunOp) {
-  py::none py_none;
-  auto op_exec_info_ptr = ConstructOpExecInfo();
-  py::tuple outputs = pynative::RunOp(
-    py::make_tuple(op_exec_info_ptr->py_primitive, op_exec_info_ptr->op_name, op_exec_info_ptr->op_inputs));
-  if (outputs.size() == 0) {
-    FAIL();
-  } else {
-    SUCCEED();
-  }
-}

 TEST_F(TestPynativeExecute, TestCreateContext) {