Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
6566b383
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
6566b383
编写于
7月 16, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
7月 16, 2020
浏览文件
操作
浏览文件
下载
差异文件
!3033 decoupling primitive of compute function
Merge pull request !3033 from lianliguang/primi-decoupling-v2
上级
a581766b
50e2fda5
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
86 addition
and
45 deletion
+86
-45
mindspore/ccsrc/backend/session/kernel_graph.cc
mindspore/ccsrc/backend/session/kernel_graph.cc
+1
-1
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
+7
-5
mindspore/ccsrc/utils/primitive_utils.cc
mindspore/ccsrc/utils/primitive_utils.cc
+24
-0
mindspore/ccsrc/utils/primitive_utils.h
mindspore/ccsrc/utils/primitive_utils.h
+5
-0
mindspore/ccsrc/vm/vmimpl.cc
mindspore/ccsrc/vm/vmimpl.cc
+6
-18
mindspore/core/ir/primitive.h
mindspore/core/ir/primitive.h
+1
-0
mindspore/core/ir/primitive_py.cc
mindspore/core/ir/primitive_py.cc
+28
-8
mindspore/core/ir/primitive_py.h
mindspore/core/ir/primitive_py.h
+4
-1
tests/ut/cpp/operator/ops_test.cc
tests/ut/cpp/operator/ops_test.cc
+1
-2
tests/ut/cpp/parallel/step_parallel_test.cc
tests/ut/cpp/parallel/step_parallel_test.cc
+1
-2
tests/ut/cpp/vm/segment_runner_test.cc
tests/ut/cpp/vm/segment_runner_test.cc
+8
-8
未找到文件。
mindspore/ccsrc/backend/session/kernel_graph.cc
浏览文件 @
6566b383
...
...
@@ -307,7 +307,7 @@ CNodePtr KernelGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) {
if
(
inputs
.
size
()
==
1
||
!
feature_map_input_indexs
.
empty
())
{
kernel_info
->
SetFeatureMapFlag
(
true
);
}
if
(
AnfAlgo
::
IsReal
CNode
Kernel
(
cnode
))
{
if
(
AnfAlgo
::
IsRealKernel
(
cnode
))
{
AnfAlgo
::
SetNodeAttr
(
kIsFeatureMapOutput
,
MakeValue
(
kernel_info
->
is_feature_map
()),
cnode
);
AnfAlgo
::
SetNodeAttr
(
kIsFeatureMapInputList
,
MakeValue
(
feature_map_input_indexs
),
cnode
);
}
...
...
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
浏览文件 @
6566b383
...
...
@@ -363,19 +363,21 @@ py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
MS_LOG
(
INFO
)
<<
"RunOpInVM end"
;
return
std
::
move
(
result
);
}
auto
func
=
op_exec_info
->
py_primitive
->
GetComputeFunction
();
if
(
py
::
isinstance
<
py
::
none
>
(
func
))
{
MS_LOG
(
ERROR
)
<<
"VM failed to get func"
;
auto
primitive
=
op_exec_info
->
py_primitive
;
MS_EXCEPTION_IF_NULL
(
primitive
);
auto
result
=
primitive
->
RunPyComputeFunction
(
op_exec_info
->
op_inputs
);
if
(
py
::
isinstance
<
py
::
none
>
(
result
))
{
MS_LOG
(
ERROR
)
<<
"VM got the result none, please check whether it is failed to get func"
;
*
status
=
PYNATIVE_OP_NOT_IMPLEMENTED_ERR
;
py
::
tuple
err_ret
(
0
);
return
std
::
move
(
err_ret
);
}
// execute op
py
::
tuple
result
=
py
::
make_tuple
(
func
(
*
op_exec_info
->
op_inputs
)
);
py
::
tuple
tuple_result
=
py
::
make_tuple
(
result
);
*
status
=
PYNATIVE_SUCCESS
;
MS_LOG
(
INFO
)
<<
"RunOpInVM end"
;
return
std
::
move
(
result
);
return
std
::
move
(
tuple_
result
);
}
bool
RunOpConvertConstInputToAttr
(
const
py
::
object
&
input_object
,
size_t
input_index
,
const
PrimitivePtr
&
op_prim
,
...
...
mindspore/ccsrc/utils/primitive_utils.cc
浏览文件 @
6566b383
...
...
@@ -15,6 +15,9 @@
*/
#include "utils/primitive_utils.h"
#include <memory>
#include "pipeline/jit/parse/python_adapter.h"
#include "utils/log_adapter.h"
#include "common/utils.h"
...
...
@@ -43,4 +46,25 @@ py::function GetComputeFunction(std::string name) {
py
::
object
fn
=
mod
.
attr
(
common
::
SafeCStr
(
name
));
return
fn
;
}
py
::
tuple
ConvertDatatoPyTuple
(
const
VectorRef
&
args
)
{
auto
py_args
=
py
::
tuple
(
args
.
size
());
size_t
i
=
0
;
for
(
auto
&
arg
:
args
)
{
py_args
[
i
]
=
BaseRefToPyData
(
arg
);
MS_LOG
(
DEBUG
)
<<
"arg:"
<<
i
<<
":"
<<
arg
.
ToString
();
i
++
;
}
return
py_args
;
}
BaseRef
RunComputeFunction
(
const
PrimitivePtr
&
prim
,
const
VectorRef
&
args
)
{
auto
func
=
GetComputeFunction
(
prim
->
name
());
if
(
py
::
isinstance
<
py
::
none
>
(
func
))
{
MS_LOG
(
EXCEPTION
)
<<
prim
->
name
()
<<
" 's compute function run failed, please check whether it is not implemented"
;
}
auto
py_args
=
ConvertDatatoPyTuple
(
args
);
py
::
object
obj
=
func
(
*
py_args
);
return
std
::
make_shared
<
PyObjectRef
>
(
obj
);
}
}
// namespace mindspore
mindspore/ccsrc/utils/primitive_utils.h
浏览文件 @
6566b383
...
...
@@ -19,6 +19,7 @@
#include <string>
#include "pybind11/pybind11.h"
#include "utils/base_ref.h"
namespace
py
=
pybind11
;
...
...
@@ -28,6 +29,10 @@ py::function GetBpropFunctionByObj(py::object obj);
py
::
function
GetBpropFunction
(
std
::
string
name
);
py
::
function
GetComputeFunction
(
std
::
string
name
);
BaseRef
RunComputeFunction
(
const
PrimitivePtr
&
prim
,
const
VectorRef
&
args
);
py
::
tuple
ConvertDatatoPyTuple
(
const
VectorRef
&
args
);
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_UTILS_PRIMITIVE_UTILS_H_
mindspore/ccsrc/vm/vmimpl.cc
浏览文件 @
6566b383
...
...
@@ -440,25 +440,13 @@ VectorRef VM::RunGraph(const FuncGraphPtr &g, const VectorRef &args) {
}
BaseRef
RunOperation
(
const
PrimitivePtr
&
prim
,
const
VectorRef
&
args
)
{
PrimitivePyPtr
operation
=
dyn_cast
<
PrimitivePy
>
(
prim
);
MS_LOG
(
DEBUG
)
<<
"operation start "
<<
prim
->
name
();
auto
func
=
operation
!=
nullptr
?
operation
->
GetComputeFunction
()
:
GetComputeFunction
(
prim
->
name
());
if
(
py
::
isinstance
<
py
::
none
>
(
func
))
{
MS_LOG
(
EXCEPTION
)
<<
prim
->
name
()
<<
" 's compute function is not implemented"
;
}
py
::
tuple
py_args
=
py
::
tuple
(
args
.
size
());
MS_LOG
(
DEBUG
)
<<
"input for operation:"
;
size_t
i
=
0
;
for
(
auto
&
arg
:
args
)
{
py_args
[
i
]
=
BaseRefToPyData
(
arg
);
MS_LOG
(
DEBUG
)
<<
"arg: "
<<
i
<<
":"
;
i
++
;
}
py
::
object
obj
=
func
(
*
py_args
);
MS_LOG
(
DEBUG
)
<<
"result:"
<<
py
::
str
(
obj
);
return
obj
;
MS_EXCEPTION_IF_NULL
(
prim
);
auto
result
=
prim
->
RunComputeFunction
(
args
);
if
(
result
.
is_null
())
{
return
RunComputeFunction
(
prim
,
args
);
}
return
result
;
}
}
// namespace compile
...
...
mindspore/core/ir/primitive.h
浏览文件 @
6566b383
...
...
@@ -83,6 +83,7 @@ class Primitive : public Named {
void
set_attr
(
const
std
::
string
&
attrName
,
const
ValuePtr
&
attr
)
{
attrs_
[
attrName
]
=
attr
;
}
void
EraseAttr
(
const
std
::
string
&
attrName
)
{
(
void
)
attrs_
.
erase
(
attrName
);
}
virtual
BaseRef
RunComputeFunction
(
const
VectorRef
&
args
)
const
{
return
nullptr
;
}
ValuePtr
GetAttr
(
const
std
::
string
&
attrName
)
const
{
auto
iter
=
attrs_
.
find
(
attrName
);
...
...
mindspore/core/ir/primitive_py.cc
浏览文件 @
6566b383
...
...
@@ -79,13 +79,7 @@ py::function PrimitivePy::GetBpropFunction() {
}
BaseRef
PrimitivePy
::
RunHookFunction
(
const
VectorRef
&
args
)
const
{
auto
py_args
=
py
::
tuple
(
args
.
size
());
size_t
i
=
0
;
for
(
auto
&
arg
:
args
)
{
py_args
[
i
]
=
BaseRefToPyData
(
arg
);
MS_LOG
(
DEBUG
)
<<
"arg:"
<<
i
<<
":"
;
i
++
;
}
auto
py_args
=
ConvertDatatoPyTuple
(
args
);
py
::
object
obj
;
bool
is_bprop
=
this
->
HasAttr
(
kBpropAttrName
);
if
(
is_bprop
)
{
...
...
@@ -123,7 +117,7 @@ BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const {
return
std
::
make_shared
<
PyObjectRef
>
(
obj
);
}
py
::
function
PrimitivePy
::
GetComputeFunction
()
{
py
::
function
PrimitivePy
::
GetComputeFunction
()
const
{
static
const
char
*
const
compute_func_name
=
"vm_impl"
;
if
(
py
::
hasattr
(
python_obj_
,
compute_func_name
))
{
...
...
@@ -176,6 +170,32 @@ void PrimitivePy::CopyHookFunction(const PrimitivePtr &primitive) {
this
->
set_hook
(
primitive_py
->
hook
());
}
BaseRef
PrimitivePy
::
RunComputeFunction
(
const
VectorRef
&
args
)
const
{
auto
py_args
=
ConvertDatatoPyTuple
(
args
);
auto
result
=
this
->
RunPyComputeFunction
(
py_args
);
if
(
py
::
isinstance
<
py
::
none
>
(
result
))
{
return
std
::
make_shared
<
BaseRef
>
(
nullptr
);
}
return
std
::
make_shared
<
PyObjectRef
>
(
result
);
}
py
::
object
PrimitivePy
::
RunPyComputeFunction
(
const
py
::
tuple
&
py_args
)
const
{
auto
func
=
this
->
GetComputeFunction
();
if
(
py
::
isinstance
<
py
::
none
>
(
func
))
{
return
py
::
none
();
}
auto
result
=
func
(
*
py_args
);
return
result
;
}
bool
PrimitivePy
::
HasComputeFunction
()
const
{
auto
func
=
GetComputeFunction
();
if
(
py
::
isinstance
<
py
::
none
>
(
func
))
{
return
false
;
}
return
true
;
}
REGISTER_PYBIND_DEFINE
(
Primitive_
,
([](
const
py
::
module
*
m
)
{
(
void
)
py
::
enum_
<
PrimType
>
(
*
m
,
"prim_type"
,
py
::
arithmetic
())
.
value
(
"unknown"
,
PrimType
::
kPrimTypeUnknown
)
...
...
mindspore/core/ir/primitive_py.h
浏览文件 @
6566b383
...
...
@@ -41,7 +41,6 @@ class PrimitivePy : public Primitive {
~
PrimitivePy
()
override
=
default
;
MS_DECLARE_PARENT
(
PrimitivePy
,
Primitive
);
py
::
function
GetBpropFunction
();
py
::
function
GetComputeFunction
();
void
set_signatures
(
std
::
vector
<
std
::
tuple
<
std
::
string
,
SignatureEnumRW
,
SignatureEnumKind
,
py
::
object
,
SignatureEnumDType
>>
...
...
@@ -57,11 +56,15 @@ class PrimitivePy : public Primitive {
void
set_hook
(
const
py
::
function
&
hook
)
{
hook_
=
hook
;
}
py
::
function
hook
()
const
{
return
hook_
;
}
BaseRef
RunHookFunction
(
const
VectorRef
&
args
)
const
override
;
BaseRef
RunComputeFunction
(
const
VectorRef
&
args
)
const
override
;
py
::
object
RunPyComputeFunction
(
const
py
::
tuple
&
py_args
)
const
;
bool
HasComputeFunction
()
const
;
const
bool
parse_info_
=
true
;
const
py
::
object
&
GetPyObj
()
const
{
return
python_obj_
;
}
bool
is_tuple_input_
=
false
;
private:
py
::
function
GetComputeFunction
()
const
;
py
::
object
python_obj_
;
py
::
function
hook_
;
std
::
vector
<
Signature
>
signatures_
;
...
...
tests/ut/cpp/operator/ops_test.cc
浏览文件 @
6566b383
...
...
@@ -454,8 +454,7 @@ TEST_F(TestOps, GetConv2DPrimPyTest) {
ASSERT_TRUE
(
conv2d_ptr
);
if
(
nullptr
!=
conv2d_ptr
)
{
MS_LOG
(
INFO
)
<<
"Get PrimitivePyPtr: "
<<
conv2d_ptr
->
name
();
auto
func
=
conv2d_ptr
->
GetComputeFunction
();
if
(
py
::
isinstance
<
py
::
none
>
(
func
))
{
if
(
!
conv2d_ptr
->
HasComputeFunction
()){
MS_LOG
(
EXCEPTION
)
<<
""
<<
conv2d_ptr
->
name
()
<<
"'s compute function is not implemented"
;
}
...
...
tests/ut/cpp/parallel/step_parallel_test.cc
浏览文件 @
6566b383
...
...
@@ -294,8 +294,7 @@ TEST_F(TestStepParallel, CreatOpInstance) {
ASSERT_TRUE
(
allreduce_ptr
);
if
(
nullptr
!=
allreduce_ptr
)
{
MS_LOG
(
INFO
)
<<
"Get PrimitivePyPtr: "
<<
allreduce_ptr
->
name
();
auto
func
=
allreduce_ptr
->
GetComputeFunction
();
if
(
py
::
isinstance
<
py
::
none
>
(
func
))
{
if
(
!
allreduce_ptr
->
HasComputeFunction
())
{
MS_LOG
(
EXCEPTION
)
<<
""
<<
allreduce_ptr
->
name
()
<<
"'s compute function is not implemented"
;
}
...
...
tests/ut/cpp/vm/segment_runner_test.cc
浏览文件 @
6566b383
...
...
@@ -57,11 +57,11 @@ TEST_F(TestCompileSegmentRunner, test_MsVmConvert1) {
std
::
vector
<
BaseRef
>
todos
(
splits
.
size
());
auto
it
=
std
::
copy_if
(
std
::
begin
(
splits
),
std
::
end
(
splits
),
std
::
begin
(
todos
),
[](
const
BaseRef
&
seg
)
->
bool
{
return
utils
::
isa
<
VectorRef
>
(
seg
);
});
[](
const
BaseRef
&
seg
)
->
bool
{
return
utils
::
isa
<
VectorRef
>
(
seg
);
});
todos
.
resize
(
std
::
distance
(
todos
.
begin
(),
it
));
ASSERT_EQ
(
todos
.
size
(),
1
);
AnfNodePtrList
anf_list
;
AnfNodePtrList
anf_list
;
for
(
auto
&
item
:
utils
::
cast
<
VectorRef
>
(
todos
[
0
]))
{
anf_list
.
push_back
(
utils
::
cast
<
AnfNodePtr
>
(
item
));
}
...
...
@@ -81,11 +81,11 @@ TEST_F(TestCompileSegmentRunner, test_MsVmConvert2) {
std
::
vector
<
BaseRef
>
todos
(
splits
.
size
());
auto
it
=
std
::
copy_if
(
std
::
begin
(
splits
),
std
::
end
(
splits
),
std
::
begin
(
todos
),
[](
const
BaseRef
&
seg
)
->
bool
{
return
utils
::
isa
<
VectorRef
>
(
seg
);
});
[](
const
BaseRef
&
seg
)
->
bool
{
return
utils
::
isa
<
VectorRef
>
(
seg
);
});
todos
.
resize
(
std
::
distance
(
todos
.
begin
(),
it
));
ASSERT_EQ
(
todos
.
size
(),
1
);
AnfNodePtrList
anf_list
;
AnfNodePtrList
anf_list
;
for
(
auto
&
item
:
utils
::
cast
<
VectorRef
>
(
todos
[
0
]))
{
anf_list
.
push_back
(
utils
::
cast
<
AnfNodePtr
>
(
item
));
}
...
...
@@ -105,11 +105,11 @@ TEST_F(TestCompileSegmentRunner, test_if) {
std
::
vector
<
BaseRef
>
todos
(
splits
.
size
());
auto
it
=
std
::
copy_if
(
std
::
begin
(
splits
),
std
::
end
(
splits
),
std
::
begin
(
todos
),
[](
const
BaseRef
&
seg
)
->
bool
{
return
utils
::
isa
<
VectorRef
>
(
seg
);
});
[](
const
BaseRef
&
seg
)
->
bool
{
return
utils
::
isa
<
VectorRef
>
(
seg
);
});
todos
.
resize
(
std
::
distance
(
todos
.
begin
(),
it
));
ASSERT_EQ
(
todos
.
size
(),
1
);
AnfNodePtrList
anf_list
;
AnfNodePtrList
anf_list
;
for
(
auto
&
item
:
utils
::
cast
<
VectorRef
>
(
todos
[
0
]))
{
anf_list
.
push_back
(
utils
::
cast
<
AnfNodePtr
>
(
item
));
}
...
...
@@ -122,13 +122,13 @@ TEST_F(TestCompileSegmentRunner, test_if) {
TEST_F
(
TestCompileSegmentRunner
,
test_RunOperation1
)
{
VectorRef
args
({
1
});
auto
res
=
RunOperation
(
prim
::
kPrimIdentity
,
args
);
auto
res
=
RunOperation
(
std
::
make_shared
<
PrimitivePy
>
(
py
::
str
(
prim
::
kPrimIdentity
->
name
()),
py
::
none
())
,
args
);
ASSERT_EQ
(
py
::
cast
<
int
>
(
BaseRefToPyData
(
res
)),
1
);
}
TEST_F
(
TestCompileSegmentRunner
,
test_RunOperation2
)
{
VectorRef
args
({
1
,
2
});
auto
res
=
RunOperation
(
prim
::
kPrimScalarGt
,
args
);
auto
res
=
RunOperation
(
std
::
make_shared
<
PrimitivePy
>
(
py
::
str
(
prim
::
kPrimScalarGt
->
name
()),
py
::
none
())
,
args
);
ASSERT_EQ
(
py
::
cast
<
bool
>
(
BaseRefToPyData
(
res
)),
false
);
}
}
// namespace compile
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录