Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
正统之独孤求败
mindspore
提交
70abe362
M
mindspore
项目概览
正统之独孤求败
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
70abe362
编写于
7月 06, 2020
作者:
L
leilei_snow
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add case process
上级
b8292613
变更
8
隐藏空白更改
内联
并排
Showing
8 changed files
with
103 additions
and
6 deletions
+103
-6
mindspore/ccsrc/utils/convert_utils.cc
mindspore/ccsrc/utils/convert_utils.cc
+14
-0
mindspore/ccsrc/utils/convert_utils.h
mindspore/ccsrc/utils/convert_utils.h
+1
-0
mindspore/ccsrc/vm/backend.cc
mindspore/ccsrc/vm/backend.cc
+1
-0
mindspore/ccsrc/vm/backend.h
mindspore/ccsrc/vm/backend.h
+1
-0
mindspore/ccsrc/vm/transform.cc
mindspore/ccsrc/vm/transform.cc
+49
-2
mindspore/ccsrc/vm/transform.h
mindspore/ccsrc/vm/transform.h
+1
-0
mindspore/ccsrc/vm/vm.cc
mindspore/ccsrc/vm/vm.cc
+29
-0
mindspore/ccsrc/vm/vm.h
mindspore/ccsrc/vm/vm.h
+7
-4
未找到文件。
mindspore/ccsrc/utils/convert_utils.cc
浏览文件 @
70abe362
...
@@ -230,6 +230,20 @@ bool ValueToBool(const ValuePtr &v, bool *value) {
...
@@ -230,6 +230,20 @@ bool ValueToBool(const ValuePtr &v, bool *value) {
return
true
;
return
true
;
}
}
bool
BaseRefToInt
(
const
ValuePtr
&
v
,
int
*
value
)
{
MS_EXCEPTION_IF_NULL
(
v
);
if
(
v
->
isa
<
tensor
::
Tensor
>
())
{
auto
tensor
=
v
->
cast
<
tensor
::
TensorPtr
>
();
(
void
)
tensor
->
data_sync
();
int
*
tensor_data
=
static_cast
<
int
*>
(
tensor
->
data_c
());
auto
vb
=
tensor_data
[
0
];
*
value
=
vb
;
return
true
;
}
MS_LOG
(
ERROR
)
<<
"Index must be tensor type."
;
return
false
;
}
bool
BaseRefToBool
(
const
BaseRef
&
v
,
bool
*
value
)
{
bool
BaseRefToBool
(
const
BaseRef
&
v
,
bool
*
value
)
{
if
(
utils
::
isa
<
ValuePtr
>
(
v
))
{
if
(
utils
::
isa
<
ValuePtr
>
(
v
))
{
return
ValueToBool
(
utils
::
cast
<
ValuePtr
>
(
v
),
value
);
return
ValueToBool
(
utils
::
cast
<
ValuePtr
>
(
v
),
value
);
...
...
mindspore/ccsrc/utils/convert_utils.h
浏览文件 @
70abe362
...
@@ -42,6 +42,7 @@ using TensorPtr = std::shared_ptr<Tensor>;
...
@@ -42,6 +42,7 @@ using TensorPtr = std::shared_ptr<Tensor>;
py
::
object
AnyToPyData
(
const
Any
&
value
);
py
::
object
AnyToPyData
(
const
Any
&
value
);
py
::
object
BaseRefToPyData
(
const
BaseRef
&
value
);
py
::
object
BaseRefToPyData
(
const
BaseRef
&
value
);
bool
BaseRefToBool
(
const
BaseRef
&
in
,
bool
*
out
);
bool
BaseRefToBool
(
const
BaseRef
&
in
,
bool
*
out
);
bool
BaseRefToInt
(
const
ValuePtr
&
v
,
int
*
value
);
bool
ValueToBool
(
const
ValuePtr
&
in
,
bool
*
out
);
bool
ValueToBool
(
const
ValuePtr
&
in
,
bool
*
out
);
py
::
object
ValuePtrToPyData
(
const
ValuePtr
&
value
);
py
::
object
ValuePtrToPyData
(
const
ValuePtr
&
value
);
...
...
mindspore/ccsrc/vm/backend.cc
浏览文件 @
70abe362
...
@@ -32,6 +32,7 @@
...
@@ -32,6 +32,7 @@
namespace
mindspore
{
namespace
mindspore
{
namespace
compile
{
namespace
compile
{
bool
Backend
::
GetCond
(
const
BaseRef
&
c
,
bool
*
const
value
)
{
return
BaseRefToBool
(
c
,
value
);
}
bool
Backend
::
GetCond
(
const
BaseRef
&
c
,
bool
*
const
value
)
{
return
BaseRefToBool
(
c
,
value
);
}
bool
Backend
::
GetIndex
(
const
BaseRef
&
c
,
int
*
const
value
)
{
return
BaseRefToInt
(
utils
::
cast
<
ValuePtr
>
(
c
),
value
);
}
LinConvertResult
MsBackend
::
GetMultiGraphRun
(
const
FuncGraphPtr
&
g
)
{
LinConvertResult
MsBackend
::
GetMultiGraphRun
(
const
FuncGraphPtr
&
g
)
{
// multi_graph merged into one; the big graph has parameters at the beginning and only one output
// multi_graph merged into one; the big graph has parameters at the beginning and only one output
...
...
mindspore/ccsrc/vm/backend.h
浏览文件 @
70abe362
...
@@ -46,6 +46,7 @@ class Backend {
...
@@ -46,6 +46,7 @@ class Backend {
virtual
void
SimulateRun
(
FinalVMPtr
,
FuncGraphPtr
)
{}
virtual
void
SimulateRun
(
FinalVMPtr
,
FuncGraphPtr
)
{}
virtual
SwitchCondStatus
SetSimuCond
(
const
BaseRef
&
,
bool
)
{
return
kCondOk
;
}
virtual
SwitchCondStatus
SetSimuCond
(
const
BaseRef
&
,
bool
)
{
return
kCondOk
;
}
virtual
bool
GetCond
(
const
BaseRef
&
c
,
bool
*
value
);
virtual
bool
GetCond
(
const
BaseRef
&
c
,
bool
*
value
);
virtual
bool
GetIndex
(
const
BaseRef
&
c
,
int
*
value
);
virtual
void
SetSwitchGraph
()
{}
virtual
void
SetSwitchGraph
()
{}
virtual
void
SetSwitchActive
(
const
BaseRef
&
,
bool
)
{}
virtual
void
SetSwitchActive
(
const
BaseRef
&
,
bool
)
{}
virtual
void
RecallGraphInput
(
const
FuncGraphPtr
&
,
const
VectorRef
&
,
const
BaseRef
&
)
{}
virtual
void
RecallGraphInput
(
const
FuncGraphPtr
&
,
const
VectorRef
&
,
const
BaseRef
&
)
{}
...
...
mindspore/ccsrc/vm/transform.cc
浏览文件 @
70abe362
...
@@ -46,8 +46,9 @@ using TypedPrimitiveAbstractClosurePtr = std::shared_ptr<abstract::TypedPrimitiv
...
@@ -46,8 +46,9 @@ using TypedPrimitiveAbstractClosurePtr = std::shared_ptr<abstract::TypedPrimitiv
std
::
vector
<
PrimitivePtr
>
nonlinear_ops
=
{
prim
::
kPrimReturn
,
prim
::
kPrimPartial
,
prim
::
kPrimSwitch
,
std
::
vector
<
PrimitivePtr
>
nonlinear_ops
=
{
prim
::
kPrimReturn
,
prim
::
kPrimPartial
,
prim
::
kPrimSwitch
,
prim
::
kPrimMakeTuple
,
prim
::
kPrimBpropCut
};
prim
::
kPrimMakeTuple
,
prim
::
kPrimBpropCut
};
const
std
::
vector
<
PrimitivePtr
>
&
GetMsNonlinearOps
()
{
const
std
::
vector
<
PrimitivePtr
>
&
GetMsNonlinearOps
()
{
static
const
std
::
vector
<
PrimitivePtr
>
ms_nonlinear_ops
=
{
prim
::
kPrimReturn
,
prim
::
kPrimPartial
,
prim
::
kPrimSwitch
,
static
const
std
::
vector
<
PrimitivePtr
>
ms_nonlinear_ops
=
{
prim
::
kPrimReturn
,
prim
::
kPrimPartial
,
prim
::
kPrimBpropCut
};
prim
::
kPrimSwitch
,
prim
::
kPrimMakeTuple
,
prim
::
kPrimBpropCut
,
prim
::
kPrimSwitchLayer
};
return
ms_nonlinear_ops
;
return
ms_nonlinear_ops
;
}
}
...
@@ -187,6 +188,30 @@ std::vector<AnfNodePtr> SplitSort(const FuncGraphPtr &graph, const std::string &
...
@@ -187,6 +188,30 @@ std::vector<AnfNodePtr> SplitSort(const FuncGraphPtr &graph, const std::string &
std
::
reverse
(
result
.
begin
(),
result
.
end
());
std
::
reverse
(
result
.
begin
(),
result
.
end
());
return
result
;
return
result
;
}
}
bool
IsSubGraph
(
const
AnfNodePtr
&
node
)
{
MS_EXCEPTION_IF_NULL
(
node
);
if
(
node
->
isa
<
CNode
>
())
{
auto
cnode
=
node
->
cast
<
CNodePtr
>
();
auto
&
inputs
=
cnode
->
inputs
();
if
(
inputs
.
empty
())
{
MS_LOG
(
EXCEPTION
)
<<
"Inputs of apply node is empty"
;
}
AnfNodePtr
fn
=
inputs
[
0
];
MS_EXCEPTION_IF_NULL
(
fn
);
if
(
!
IsValueNode
<
Primitive
>
(
fn
))
{
return
false
;
}
auto
node_prim
=
GetValueNode
<
PrimitivePtr
>
(
fn
);
if
(
node_prim
->
name
()
==
prim
::
kPrimPartial
->
name
())
{
return
true
;
}
}
else
if
(
IsValueNode
<
FuncGraph
>
(
node
))
{
return
true
;
}
return
false
;
}
}
// namespace
}
// namespace
CompileGraph
::
CompileGraph
(
const
BackendPtr
&
backend
,
const
std
::
vector
<
PrimitivePtr
>
&
cut_list
)
CompileGraph
::
CompileGraph
(
const
BackendPtr
&
backend
,
const
std
::
vector
<
PrimitivePtr
>
&
cut_list
)
...
@@ -235,6 +260,15 @@ bool CompileGraph::IsCut(const AnfNodePtr &node) {
...
@@ -235,6 +260,15 @@ bool CompileGraph::IsCut(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL
(
ms_context
);
MS_EXCEPTION_IF_NULL
(
ms_context
);
ms_context
->
set_enable_pynative_hook
(
true
);
ms_context
->
set_enable_pynative_hook
(
true
);
}
}
if
(
backend_
->
name
()
==
kMsConvert
&&
prim
->
name
()
==
prim
::
kPrimMakeTuple
->
name
())
{
if
(
inputs
.
size
()
<
2
)
{
return
false
;
}
auto
ret
=
IsSubGraph
(
inputs
[
1
]);
return
ret
;
}
return
true
;
return
true
;
}
}
}
}
...
@@ -466,6 +500,8 @@ int CompileGraph::InterpretNode(const FuncGraphPtr &graph, const CNodePtr &node)
...
@@ -466,6 +500,8 @@ int CompileGraph::InterpretNode(const FuncGraphPtr &graph, const CNodePtr &node)
}
else
if
(
IsPrimitive
(
fn
,
prim
::
kPrimSwitch
))
{
}
else
if
(
IsPrimitive
(
fn
,
prim
::
kPrimSwitch
))
{
AddSwitch
(
node
);
AddSwitch
(
node
);
AddSinkSwitch
(
node
);
AddSinkSwitch
(
node
);
}
else
if
(
IsPrimitive
(
fn
,
prim
::
kPrimSwitchLayer
))
{
AddSwitchLayer
(
node
);
}
else
if
(
IsPrimitive
(
fn
,
prim
::
kPrimMakeTuple
))
{
}
else
if
(
IsPrimitive
(
fn
,
prim
::
kPrimMakeTuple
))
{
AddMakeTuple
(
node
);
AddMakeTuple
(
node
);
}
else
{
}
else
{
...
@@ -622,6 +658,17 @@ void CompileGraph::AddSwitch(const CNodePtr &node) {
...
@@ -622,6 +658,17 @@ void CompileGraph::AddSwitch(const CNodePtr &node) {
AddInst
(
Instruction
::
kSwitch
,
args
);
AddInst
(
Instruction
::
kSwitch
,
args
);
}
}
void
CompileGraph
::
AddSwitchLayer
(
const
CNodePtr
&
node
)
{
auto
inputs
=
node
->
inputs
();
if
(
inputs
.
size
()
!=
3
)
{
MS_LOG
(
EXCEPTION
)
<<
"Switch layer must have index and branches."
;
}
VectorRef
args
;
args
.
emplace_back
(
Ref
(
inputs
[
1
]));
args
.
emplace_back
(
Ref
(
inputs
[
2
]));
AddInst
(
Instruction
::
kSwitchLayer
,
args
);
}
void
CompileGraph
::
AddReturn
(
const
CNodePtr
&
node
)
{
void
CompileGraph
::
AddReturn
(
const
CNodePtr
&
node
)
{
VectorRef
args
;
VectorRef
args
;
if
(
backend_
->
simu_flag
())
{
if
(
backend_
->
simu_flag
())
{
...
...
mindspore/ccsrc/vm/transform.h
浏览文件 @
70abe362
...
@@ -90,6 +90,7 @@ class CompileGraph {
...
@@ -90,6 +90,7 @@ class CompileGraph {
void
AddPartial
(
const
CNodePtr
&
node
);
void
AddPartial
(
const
CNodePtr
&
node
);
void
AddMakeTuple
(
const
CNodePtr
&
node
);
void
AddMakeTuple
(
const
CNodePtr
&
node
);
void
AddSwitch
(
const
CNodePtr
&
node
);
void
AddSwitch
(
const
CNodePtr
&
node
);
void
AddSwitchLayer
(
const
CNodePtr
&
node
);
void
AddReturn
(
const
CNodePtr
&
node
);
void
AddReturn
(
const
CNodePtr
&
node
);
void
AddPrimitive
(
const
CNodePtr
&
node
,
const
PrimitivePtr
&
prim
);
void
AddPrimitive
(
const
CNodePtr
&
node
,
const
PrimitivePtr
&
prim
);
void
AddInput
(
const
AnfNodePtr
&
node
);
void
AddInput
(
const
AnfNodePtr
&
node
);
...
...
mindspore/ccsrc/vm/vm.cc
浏览文件 @
70abe362
...
@@ -480,6 +480,35 @@ void FinalVM::InstSwitch(const VectorRef &args) {
...
@@ -480,6 +480,35 @@ void FinalVM::InstSwitch(const VectorRef &args) {
MS_LOG
(
DEBUG
)
<<
"End"
;
MS_LOG
(
DEBUG
)
<<
"End"
;
}
}
void
FinalVM
::
InstSwitchLayer
(
const
VectorRef
&
args
)
{
MS_LOG
(
DEBUG
)
<<
"Start"
;
const
size_t
args_size
=
2
;
if
(
args
.
size
()
!=
args_size
)
{
MS_LOG
(
ERROR
)
<<
__FUNCTION__
<<
" requires "
<<
args_size
<<
" parameters, while the input size is "
<<
args
.
size
()
<<
"."
;
return
;
}
int
idx
=
utils
::
cast
<
int
>
(
args
[
0
]);
VectorRef
branches
=
utils
::
cast
<
VectorRef
>
(
Ref
(
utils
::
cast
<
int
>
(
args
[
1
])));
int
size
=
static_cast
<
int
>
(
branches
.
size
());
BaseRef
index
=
Ref
(
idx
);
int
idx_value
=
0
;
if
(
!
backend_
->
GetIndex
(
index
,
&
idx_value
))
{
MS_LOG
(
EXCEPTION
)
<<
"Not supported type to be casted to int."
;
}
if
(
idx_value
<
0
)
{
// Add support negative index range [-size, -1].
idx_value
+=
size
;
}
if
(
idx_value
<
0
||
idx_value
>=
size
)
{
MS_LOG
(
EXCEPTION
)
<<
__FUNCTION__
<<
" given index "
<<
idx_value
<<
" out of range."
;
}
Push
(
branches
[
idx_value
]);
MS_LOG
(
DEBUG
)
<<
"End"
;
}
void
FinalVM
::
InstTuple
(
const
VectorRef
&
args
)
{
void
FinalVM
::
InstTuple
(
const
VectorRef
&
args
)
{
MS_LOG
(
DEBUG
)
<<
"Start"
;
MS_LOG
(
DEBUG
)
<<
"Start"
;
VectorRef
tuple
;
VectorRef
tuple
;
...
...
mindspore/ccsrc/vm/vm.h
浏览文件 @
70abe362
...
@@ -51,15 +51,17 @@ enum Instruction {
...
@@ -51,15 +51,17 @@ enum Instruction {
kPush
,
kPush
,
kPrim
,
kPrim
,
kGraph
,
kGraph
,
kPadStack
kPadStack
,
kSwitchLayer
};
};
using
InstType
=
std
::
pair
<
Instruction
,
VectorRef
>
;
using
InstType
=
std
::
pair
<
Instruction
,
VectorRef
>
;
using
InstSet
=
std
::
vector
<
InstType
>
;
using
InstSet
=
std
::
vector
<
InstType
>
;
using
InstFunctionMap
=
std
::
map
<
Instruction
,
std
::
function
<
void
(
const
VectorRef
&
)
>>
;
using
InstFunctionMap
=
std
::
map
<
Instruction
,
std
::
function
<
void
(
const
VectorRef
&
)
>>
;
const
std
::
vector
<
std
::
string
>
inst_str
{
"call"
,
"tail_call"
,
"return"
,
"partial"
,
"switch"
,
"switch_return"
,
"tuple"
,
const
std
::
vector
<
std
::
string
>
inst_str
{
"call"
,
"tail_call"
,
"return"
,
"partial"
,
"switch"
,
"input"
,
"external"
,
"push"
,
"primitive"
,
"graph"
,
"pad_stack"
};
"switch_return"
,
"tuple"
,
"input"
,
"external"
,
"push"
,
"primitive"
,
"graph"
,
"pad_stack"
,
"switch_layer"
};
class
StructPartial
:
public
Base
{
class
StructPartial
:
public
Base
{
public:
public:
// Initialize StructPartial.
// Initialize StructPartial.
...
@@ -114,6 +116,7 @@ class FinalVM {
...
@@ -114,6 +116,7 @@ class FinalVM {
void
InstExternal
(
const
VectorRef
&
args
);
void
InstExternal
(
const
VectorRef
&
args
);
void
InstPushPrim
(
const
VectorRef
&
args
);
void
InstPushPrim
(
const
VectorRef
&
args
);
void
InstSwitchReturn
(
const
VectorRef
&
args
);
void
InstSwitchReturn
(
const
VectorRef
&
args
);
void
InstSwitchLayer
(
const
VectorRef
&
args
);
void
set_insts
(
const
InstSet
&
value
)
{
insts_
=
value
;
}
void
set_insts
(
const
InstSet
&
value
)
{
insts_
=
value
;
}
BaseRef
RunHook
(
const
PrimitivePtr
&
prim
,
const
VectorRef
&
arg
);
BaseRef
RunHook
(
const
PrimitivePtr
&
prim
,
const
VectorRef
&
arg
);
...
@@ -157,7 +160,7 @@ class FinalVM {
...
@@ -157,7 +160,7 @@ class FinalVM {
{
Instruction
::
kExternal
,
[
this
](
const
VectorRef
&
args
)
{
InstExternal
(
args
);
}},
{
Instruction
::
kExternal
,
[
this
](
const
VectorRef
&
args
)
{
InstExternal
(
args
);
}},
{
Instruction
::
kPrim
,
[
this
](
const
VectorRef
&
args
)
{
InstPushPrim
(
args
);
}},
{
Instruction
::
kPrim
,
[
this
](
const
VectorRef
&
args
)
{
InstPushPrim
(
args
);
}},
{
Instruction
::
kSwitchReturn
,
[
this
](
const
VectorRef
&
args
)
{
InstSwitchReturn
(
args
);
}},
{
Instruction
::
kSwitchReturn
,
[
this
](
const
VectorRef
&
args
)
{
InstSwitchReturn
(
args
);
}},
};
{
Instruction
::
kSwitchLayer
,
[
this
](
const
VectorRef
&
args
)
{
InstSwitchLayer
(
args
);
}}
};
std
::
map
<
std
::
string
,
py
::
object
>
_hook_grad
;
std
::
map
<
std
::
string
,
py
::
object
>
_hook_grad
;
};
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录