Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
4f6e63fc
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
4f6e63fc
编写于
4年前
作者:
M
mindspore-ci-bot
提交者:
Gitee
4年前
浏览文件
操作
浏览文件
下载
差异文件
!4576 Support if by if not inline
Merge pull request !4576 from amongo/SupportIfByIfNotInline
上级
bd955c75
1bd9fefd
变更
5
展开全部
显示空白变更内容
内联
并排
Showing
5 changed file
with
1462 addition
and
29 deletion
+1462
-29
mindspore/ccsrc/frontend/optimizer/irpass/inline.h
mindspore/ccsrc/frontend/optimizer/irpass/inline.h
+93
-5
mindspore/ccsrc/pipeline/jit/parse/parse.cc
mindspore/ccsrc/pipeline/jit/parse/parse.cc
+6
-0
mindspore/core/ir/func_graph.h
mindspore/core/ir/func_graph.h
+1
-0
tests/st/control/test_cont_grad.py
tests/st/control/test_cont_grad.py
+356
-24
tests/ut/python/pynative_mode/test_cont_cases.py
tests/ut/python/pynative_mode/test_cont_cases.py
+1006
-0
未找到文件。
mindspore/ccsrc/frontend/optimizer/irpass/inline.h
浏览文件 @
4f6e63fc
...
...
@@ -20,12 +20,14 @@
#include <vector>
#include <utility>
#include <algorithm>
#include <unordered_map>
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/anf_visitor.h"
#include "ir/func_graph.h"
#include "ir/func_graph_cloner.h"
#include "ir/tensor.h"
#include "frontend/operator/ops.h"
namespace
mindspore
{
...
...
@@ -153,23 +155,31 @@ class InlinerBase : public AnfVisitor {
return
nullptr
;
}
std
::
vector
<
AnfNodePtr
>
param
s
;
(
void
)
std
::
copy
(
inputs
.
begin
()
+
1
,
inputs
.
end
(),
std
::
back_inserter
(
param
s
));
std
::
vector
<
AnfNodePtr
>
arg
s
;
(
void
)
std
::
copy
(
inputs
.
begin
()
+
1
,
inputs
.
end
(),
std
::
back_inserter
(
arg
s
));
// compare size to avoid the case that the function has default value after grad.
// for which after renormalize, the function default value will be an input
if
(
fg
->
parameters
().
size
()
!=
param
s
.
size
())
{
if
(
fg
->
parameters
().
size
()
!=
arg
s
.
size
())
{
return
nullptr
;
}
// Not to inline after block if it has switch call inside, to avoid switch expansion.
if
(
fg
->
has_flag
(
FUNC_GRAPH_FLAG_AFTER_BLOCK
))
{
auto
has_branch_call
=
GraphHasBranch
(
fg
);
if
(
has_branch_call
)
{
return
TransformBranchCall
(
fg
,
node
,
args
);
}
}
if
(
use_move_
&&
IsUniqueUse
(
fg
,
nullptr
))
{
auto
mng
=
fg
->
manager
();
MS_EXCEPTION_IF_NULL
(
mng
);
ReplaceParams
(
mng
,
param
s
,
fg
);
ReplaceParams
(
mng
,
arg
s
,
fg
);
auto
out_node
=
fg
->
output
();
mng
->
MoveAllCNodeDropGraph
(
fg
,
node
->
func_graph
(),
inputs
[
0
]
->
scope
());
return
out_node
;
}
return
InlineClone
(
fg
,
node
->
func_graph
(),
param
s
,
inputs
[
0
]
->
scope
());
return
InlineClone
(
fg
,
node
->
func_graph
(),
arg
s
,
inputs
[
0
]
->
scope
());
}
void
ReplaceParams
(
const
FuncGraphManagerPtr
&
mng
,
const
std
::
vector
<
AnfNodePtr
>
&
new_params
,
...
...
@@ -197,11 +207,89 @@ class InlinerBase : public AnfVisitor {
is_checked_
=
false
;
is_recursive_
=
false
;
}
// For after block which contains branch call, delete the parameters which is not used.
// In most cases, it may be a `Module` or other constant input.
AnfNodePtr
TransformBranchCall
(
const
FuncGraphPtr
&
fg
,
const
AnfNodePtr
&
node
,
const
std
::
vector
<
AnfNodePtr
>
&
args
)
{
auto
&
fg_params
=
fg
->
parameters
();
std
::
vector
<
int
>
used_param_index
;
auto
mng
=
fg
->
manager
();
for
(
size_t
i
=
0
;
i
<
fg_params
.
size
();
i
++
)
{
if
(
mng
->
node_users
()[
fg_params
[
i
]].
size
()
!=
0
)
{
used_param_index
.
emplace_back
(
i
);
}
}
if
(
used_param_index
.
size
()
!=
fg_params
.
size
())
{
MS_LOG
(
DEBUG
)
<<
"Parameter not used found for graph :"
<<
fg
->
ToString
();
// clone a new graph and ignore the not used parameters
FuncGraphPtr
new_fg
=
TransformableClone
(
fg
);
auto
&
new_fg_params
=
new_fg
->
parameters
();
std
::
vector
<
AnfNodePtr
>
new_params
;
std
::
transform
(
used_param_index
.
begin
(),
used_param_index
.
end
(),
std
::
back_inserter
(
new_params
),
[
&
new_fg_params
](
size_t
i
)
{
return
new_fg_params
[
i
];
});
new_fg
->
set_parameters
(
new_params
);
std
::
vector
<
AnfNodePtr
>
node_inputs
;
node_inputs
.
push_back
(
NewValueNode
(
new_fg
));
std
::
transform
(
used_param_index
.
begin
(),
used_param_index
.
end
(),
std
::
back_inserter
(
node_inputs
),
[
&
args
](
size_t
i
)
{
return
args
[
i
];
});
return
node
->
func_graph
()
->
NewCNode
(
node_inputs
);
}
return
nullptr
;
}
// This is a try-best algorithm to find a graph which may generate branch call.
// It does not handle high-order function call. For high-orderer call branch, it still may be inlined.
bool
GraphHasBranch
(
FuncGraphPtr
fg
)
{
if
(
graph_branch_cache_
.
find
(
fg
)
!=
graph_branch_cache_
.
end
())
{
return
graph_branch_cache_
[
fg
];
}
bool
has_branch
=
false
;
auto
nodes
=
fg
->
nodes
();
for
(
auto
&
item
:
nodes
)
{
if
(
IsPrimitiveCNode
(
item
,
prim
::
kPrimSwitch
))
{
auto
sw_inputs
=
item
->
cast
<
CNodePtr
>
()
->
inputs
();
if
(
sw_inputs
.
size
()
!=
4
)
{
MS_LOG
(
EXCEPTION
)
<<
"switch inputs should be 4"
;
}
if
(
!
sw_inputs
[
1
]
->
isa
<
ValueNode
>
()
||
IsValueNode
<
tensor
::
Tensor
>
(
sw_inputs
[
1
]))
{
has_branch
=
true
;
break
;
}
}
else
if
(
IsCNodeGraph
(
item
))
{
auto
cinputs
=
item
->
cast
<
CNodePtr
>
()
->
inputs
();
if
(
cinputs
.
size
()
<
1
)
{
MS_LOG
(
EXCEPTION
)
<<
"graph call inputs should greater than 1"
;
}
FuncGraphPtr
call_fg
=
GetValueNode
<
FuncGraphPtr
>
(
cinputs
[
0
]);
bool
call_fg_has_branch
=
GraphHasBranch
(
call_fg
);
if
(
call_fg_has_branch
)
{
has_branch
=
true
;
break
;
}
}
else
if
(
IsPrimitiveCNode
(
item
,
prim
::
kPrimPartial
))
{
auto
cinputs
=
item
->
cast
<
CNodePtr
>
()
->
inputs
();
if
(
cinputs
.
size
()
<
2
)
{
MS_LOG
(
EXCEPTION
)
<<
"partial call inputs should greater than 2"
;
}
FuncGraphPtr
call_fg
=
GetValueNode
<
FuncGraphPtr
>
(
cinputs
[
1
]);
if
(
call_fg
==
nullptr
)
{
continue
;
}
bool
call_fg_has_branch
=
GraphHasBranch
(
call_fg
);
if
(
call_fg_has_branch
)
{
has_branch
=
true
;
break
;
}
}
}
graph_branch_cache_
[
fg
]
=
has_branch
;
return
has_branch
;
}
private:
bool
is_checked_
{
false
},
is_recursive_
{
false
};
bool
use_move_
;
std
::
vector
<
std
::
pair
<
CriterionFuncType
,
bool
>>
criterions_
;
std
::
unordered_map
<
FuncGraphPtr
,
bool
>
graph_branch_cache_
;
};
class
Inliner
:
public
InlinerBase
{
...
...
This diff is collapsed.
Click to expand it.
mindspore/ccsrc/pipeline/jit/parse/parse.cc
浏览文件 @
4f6e63fc
...
...
@@ -1029,6 +1029,12 @@ FunctionBlockPtr Parser::ParseIf(const FunctionBlockPtr &block, const py::object
FunctionBlockPtr
after_block
=
MakeFunctionBlock
(
*
this
);
TraceManager
::
EndTrace
();
if
(
MsContext
::
GetInstance
()
->
backend_policy
()
!=
"ge"
)
{
// for backends excludes 'ge', it can handle multi graph call, use this flag to
// generate call not inline `after_block` graph to reduce if by if switch expansion.
after_block
->
func_graph
()
->
set_flag
(
FUNC_GRAPH_FLAG_AFTER_BLOCK
,
true
);
}
// process the if-true branch
py
::
object
bodyNode
=
python_adapter
::
GetPyObjAttr
(
node
,
"body"
);
FunctionBlockPtr
true_end
=
ParseStatements
(
true_block
,
bodyNode
);
...
...
This diff is collapsed.
Click to expand it.
mindspore/core/ir/func_graph.h
浏览文件 @
4f6e63fc
...
...
@@ -74,6 +74,7 @@ using FuncGraphMap = OrderedMap<FuncGraphPtr, int>;
const
char
FUNC_GRAPH_FLAG_IGNORE_VALUES
[]
=
"ignore_values"
;
const
char
FUNC_GRAPH_FLAG_DEFER_INLINE
[]
=
"defer_inline"
;
const
char
FUNC_GRAPH_FLAG_AFTER_BLOCK
[]
=
"after_block"
;
const
char
FUNC_GRAPH_FLAG_CORE
[]
=
"core"
;
const
char
FUNC_GRAPH_ATTR_GRAPH_KERNEL
[]
=
"graph_kernel"
;
const
char
FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER
[]
=
"spec_param"
;
...
...
This diff is collapsed.
Click to expand it.
tests/st/control/test_cont_grad.py
浏览文件 @
4f6e63fc
...
...
@@ -42,7 +42,7 @@ def test_while_forward():
idx
=
idx
+
1
return
x
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
,
save_graphs
=
True
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
net
=
MyWhileNet
()
idx
=
Tensor
(
np
.
array
(
0
),
dtype
=
ms
.
int32
)
end
=
Tensor
(
np
.
array
(
2
),
dtype
=
ms
.
int32
)
...
...
@@ -72,7 +72,7 @@ def test_while_grad():
def
construct
(
self
,
*
inputs
):
return
C
.
grad_all
(
self
.
net
)(
*
inputs
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
,
save_graphs
=
True
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
while_net
=
MyWhileNet
()
net
=
GradNet
(
while_net
)
idx
=
Tensor
(
np
.
array
(
0
),
dtype
=
ms
.
int32
)
...
...
@@ -99,7 +99,7 @@ def test_while_with_param_forward():
idx
=
idx
+
1
return
out
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
,
save_graphs
=
True
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
net
=
MyWhileNet
()
idx
=
Tensor
(
np
.
array
(
0
),
dtype
=
ms
.
int32
)
end
=
Tensor
(
np
.
array
(
2
),
dtype
=
ms
.
int32
)
...
...
@@ -124,7 +124,7 @@ def test_while_endless_case():
idx
=
idx
+
1
return
out
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
,
save_graphs
=
True
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
net
=
MyWhileNet
()
idx
=
Tensor
(
np
.
array
(
0
),
dtype
=
ms
.
int32
)
end
=
Tensor
(
np
.
array
(
2
),
dtype
=
ms
.
int32
)
...
...
@@ -159,7 +159,7 @@ def test_while_with_param_grad():
def
construct
(
self
,
a
,
b
,
c
):
return
C
.
grad_by_list
(
self
.
net
,
self
.
weights
)(
a
,
b
,
c
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
,
save_graphs
=
True
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
while_net
=
MyWhileNet
()
net
=
GradNet
(
while_net
)
idx
=
Tensor
(
np
.
array
(
0
),
dtype
=
ms
.
int32
)
...
...
@@ -187,7 +187,7 @@ def test_while_with_param_forward_with_const_branch():
idx
=
idx
+
1
return
out
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
,
save_graphs
=
True
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
while_net
=
MyWhileNet
()
net
=
while_net
idx
=
Tensor
(
np
.
array
(
0
),
dtype
=
ms
.
int32
)
...
...
@@ -224,7 +224,7 @@ def test_while_opt_endless():
def
construct
(
self
,
*
inputs
):
return
C
.
grad_all
(
self
.
net
)(
*
inputs
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
,
save_graphs
=
True
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
while_net
=
MyWhileNet
()
net
=
GradNet
(
while_net
)
idx
=
Tensor
(
np
.
array
(
0
),
dtype
=
ms
.
int32
)
...
...
@@ -250,7 +250,7 @@ def test_no_while_call():
out
=
out
+
idx
+
self
.
param
return
out
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
,
save_graphs
=
True
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
while_net
=
MyWhileNet
()
net
=
while_net
idx
=
Tensor
(
np
.
array
(
0
),
dtype
=
ms
.
int32
)
...
...
@@ -287,7 +287,7 @@ def test_while_with_param_grad_with_const_branch():
def
construct
(
self
,
a
,
b
,
c
):
return
C
.
grad_by_list
(
self
.
net
,
self
.
weights
)(
a
,
b
,
c
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
,
save_graphs
=
True
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
while_net
=
MyWhileNet
()
net
=
GradNet
(
while_net
)
idx
=
Tensor
(
np
.
array
(
0
),
dtype
=
ms
.
int32
)
...
...
@@ -327,7 +327,7 @@ def test_for_while_with_param_grad_with_const_branch():
def
construct
(
self
,
a
,
b
,
c
):
return
C
.
grad_by_list
(
self
.
net
,
self
.
weights
)(
a
,
b
,
c
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
,
save_graphs
=
True
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
while_net
=
MyWhileNet
()
net
=
GradNet
(
while_net
)
idx
=
Tensor
(
np
.
array
(
0
),
dtype
=
ms
.
int32
)
...
...
@@ -364,7 +364,7 @@ def test_for_while_with_param_grad_basic():
def
construct
(
self
,
a
,
b
,
c
):
return
C
.
grad_by_list
(
self
.
net
,
self
.
weights
)(
a
,
b
,
c
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
,
save_graphs
=
True
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
while_net
=
MyWhileNet
()
net
=
GradNet
(
while_net
)
idx
=
Tensor
(
np
.
array
(
0
),
dtype
=
ms
.
int32
)
...
...
@@ -401,7 +401,7 @@ def test_for_while_with_param_grad_normal():
def
construct
(
self
,
a
,
b
,
c
):
return
C
.
grad_by_list
(
self
.
net
,
self
.
weights
)(
a
,
b
,
c
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
,
save_graphs
=
True
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
while_net
=
MyWhileNet
()
net
=
GradNet
(
while_net
)
idx
=
Tensor
(
np
.
array
(
0
),
dtype
=
ms
.
int32
)
...
...
@@ -435,7 +435,7 @@ def test_while_with_param_basic_grad():
def
construct
(
self
,
a
,
b
,
c
):
return
C
.
grad_by_list
(
self
.
net
,
self
.
weights
)(
a
,
b
,
c
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
,
save_graphs
=
True
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
while_net
=
MyWhileNet
()
net
=
GradNet
(
while_net
)
idx
=
Tensor
(
np
.
array
(
0
),
dtype
=
ms
.
int32
)
...
...
@@ -469,7 +469,7 @@ def test_while_with_param_basic_grad_mul():
def
construct
(
self
,
a
,
b
,
c
):
return
C
.
grad_by_list
(
self
.
net
,
self
.
weights
)(
a
,
b
,
c
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
,
save_graphs
=
True
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
while_net
=
MyWhileNet
()
net
=
GradNet
(
while_net
)
idx
=
Tensor
(
np
.
array
(
0
),
dtype
=
ms
.
int32
)
...
...
@@ -504,7 +504,7 @@ def test_while_with_param_basic_grad_two():
def
construct
(
self
,
a
,
b
,
c
):
return
C
.
grad_by_list
(
self
.
net
,
self
.
weights
)(
a
,
b
,
c
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
,
save_graphs
=
True
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
while_net
=
MyWhileNet
()
net
=
GradNet
(
while_net
)
idx
=
Tensor
(
np
.
array
(
0
),
dtype
=
ms
.
int32
)
...
...
@@ -540,7 +540,7 @@ def test_while_with_param_basic_grad_three():
def
construct
(
self
,
a
,
b
,
c
):
return
C
.
grad_by_list
(
self
.
net
,
self
.
weights
)(
a
,
b
,
c
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
,
save_graphs
=
True
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
while_net
=
MyWhileNet
()
net
=
GradNet
(
while_net
)
idx
=
Tensor
(
np
.
array
(
0
),
dtype
=
ms
.
int32
)
...
...
@@ -577,7 +577,7 @@ def test_while_if_with_param_grad():
def
construct
(
self
,
a
,
b
,
c
):
return
C
.
grad_by_list
(
self
.
net
,
self
.
weights
)(
a
,
b
,
c
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
,
save_graphs
=
True
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
while_net
=
MyWhileNet
()
net
=
GradNet
(
while_net
)
idx
=
Tensor
(
np
.
array
(
0
),
dtype
=
ms
.
int32
)
...
...
@@ -610,7 +610,7 @@ def test_while_with_param_grad_not_enter_while():
def
construct
(
self
,
a
,
b
,
c
):
return
C
.
grad_by_list
(
self
.
net
,
self
.
weights
)(
a
,
b
,
c
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
,
save_graphs
=
True
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
while_net
=
MyWhileNet
()
net
=
GradNet
(
while_net
)
idx
=
Tensor
(
np
.
array
(
3
),
dtype
=
ms
.
int32
)
...
...
@@ -639,7 +639,7 @@ def test_with_param_if_by_if_forward():
out
=
out
+
x
*
2
return
out
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
,
save_graphs
=
True
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
if_net
=
MyIfByIfNet
()
net
=
if_net
idx
=
Tensor
(
np
.
array
(
0
),
dtype
=
ms
.
int32
)
...
...
@@ -672,7 +672,7 @@ def test_with_param_if_by_if_grad_inputs():
def
construct
(
self
,
*
inputs
):
return
C
.
grad_all
(
self
.
net
)(
*
inputs
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
,
save_graphs
=
True
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
if_net
=
MyIfByIfNet
()
net
=
GradNet
(
if_net
)
idx
=
Tensor
(
np
.
array
(
0
),
dtype
=
ms
.
int32
)
...
...
@@ -706,7 +706,7 @@ def test_with_param_if_by_if_grad_parameter():
def
construct
(
self
,
*
inputs
):
return
C
.
grad_by_list
(
self
.
net
,
self
.
weights
)(
*
inputs
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
,
save_graphs
=
True
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
if_net
=
MyIfByIfNet
()
net
=
GradNet
(
if_net
)
idx
=
Tensor
(
np
.
array
(
0
),
dtype
=
ms
.
int32
)
...
...
@@ -738,7 +738,7 @@ def test_with_param_if_by_if_grad_param_excute_null():
def
construct
(
self
,
*
inputs
):
return
C
.
grad_by_list
(
self
.
net
,
self
.
weights
)(
*
inputs
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
,
save_graphs
=
True
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
if_net
=
MyIfByIfNet
()
net
=
GradNet
(
if_net
)
idx
=
Tensor
(
np
.
array
(
4
),
dtype
=
ms
.
int32
)
...
...
@@ -772,7 +772,7 @@ def test_if_by_if_return_inside_grad():
def
construct
(
self
,
*
inputs
):
return
C
.
grad_by_list
(
self
.
net
,
self
.
weights
)(
*
inputs
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
,
save_graphs
=
True
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
if_net
=
MyIfByIfNet
()
net
=
GradNet
(
if_net
)
idx
=
Tensor
(
np
.
array
(
1
),
dtype
=
ms
.
int32
)
...
...
@@ -807,10 +807,342 @@ def test_if_by_if_forward():
out
=
a
+
b
+
x
return
out
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
,
save_graphs
=
True
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
if_net
=
MyIfByIfNet
()
net
=
if_net
idx
=
Tensor
(
np
.
array
(
2
),
dtype
=
ms
.
float32
)
end
=
Tensor
(
np
.
array
(
3
),
dtype
=
ms
.
float32
)
x
=
Tensor
(
np
.
array
(
4
),
dtype
=
ms
.
float32
)
net
(
idx
,
end
,
x
)
def
test_if_by_if_forward_control_tuple_switch
():
"""tuple_get from swtich op will generate new switch inside to eliminate tuple_get"""
class
Branch3Net
(
nn
.
Cell
):
def
__init__
(
self
):
super
().
__init__
()
self
.
add
=
P
.
TensorAdd
()
self
.
sub
=
P
.
Sub
()
self
.
mul
=
P
.
Mul
()
self
.
div
=
P
.
RealDiv
()
def
construct
(
self
,
a
,
b
,
x
):
if
b
==
x
:
b
=
self
.
add
(
a
,
b
)
else
:
b
=
self
.
add
(
a
,
x
)
return
a
,
b
,
x
class
Branch2Net
(
nn
.
Cell
):
def
__init__
(
self
):
super
().
__init__
()
self
.
add
=
P
.
TensorAdd
()
self
.
sub
=
P
.
Sub
()
self
.
mul
=
P
.
Mul
()
self
.
div
=
P
.
RealDiv
()
self
.
net
=
Branch3Net
()
def
construct
(
self
,
a
,
b
,
x
):
if
a
==
x
:
a
=
self
.
mul
(
a
,
b
)
else
:
a
=
self
.
div
(
a
,
b
)
return
self
.
net
(
a
,
b
,
x
)
class
MyIfByIfNet
(
nn
.
Cell
):
def
__init__
(
self
):
super
().
__init__
()
self
.
add
=
P
.
TensorAdd
()
self
.
sub
=
P
.
Sub
()
self
.
mul
=
P
.
Mul
()
self
.
div
=
P
.
RealDiv
()
self
.
net
=
Branch2Net
()
def
construct
(
self
,
a
,
b
,
x
):
if
a
<
b
:
a
=
self
.
add
(
a
,
b
)
else
:
a
=
self
.
sub
(
a
,
b
)
a
,
b
,
x
=
self
.
net
(
a
,
b
,
x
)
a
=
a
*
b
out
=
a
+
b
+
x
return
out
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
if_net
=
MyIfByIfNet
()
net
=
if_net
idx
=
Tensor
(
np
.
array
(
2
),
dtype
=
ms
.
float32
)
end
=
Tensor
(
np
.
array
(
3
),
dtype
=
ms
.
float32
)
x
=
Tensor
(
np
.
array
(
0
),
dtype
=
ms
.
float32
)
net
(
idx
,
end
,
x
)
def
test_if_by_if_forward_control_inside_net
():
class
Branch3Net
(
nn
.
Cell
):
def
__init__
(
self
):
super
().
__init__
()
self
.
add
=
P
.
TensorAdd
()
self
.
sub
=
P
.
Sub
()
self
.
mul
=
P
.
Mul
()
self
.
div
=
P
.
RealDiv
()
def
construct
(
self
,
a
,
b
,
x
):
if
b
==
x
:
b
=
self
.
add
(
a
,
b
)
else
:
b
=
self
.
add
(
a
,
x
)
a
=
a
*
b
out
=
a
+
b
+
x
return
out
class
Branch2Net
(
nn
.
Cell
):
def
__init__
(
self
):
super
().
__init__
()
self
.
add
=
P
.
TensorAdd
()
self
.
sub
=
P
.
Sub
()
self
.
mul
=
P
.
Mul
()
self
.
div
=
P
.
RealDiv
()
self
.
net
=
Branch3Net
()
def
construct
(
self
,
a
,
b
,
x
):
if
a
==
x
:
a
=
self
.
mul
(
a
,
b
)
else
:
a
=
self
.
div
(
a
,
b
)
return
self
.
net
(
a
,
b
,
x
)
class
MyIfByIfNet
(
nn
.
Cell
):
def
__init__
(
self
):
super
().
__init__
()
self
.
add
=
P
.
TensorAdd
()
self
.
sub
=
P
.
Sub
()
self
.
mul
=
P
.
Mul
()
self
.
div
=
P
.
RealDiv
()
self
.
net
=
Branch2Net
()
def
construct
(
self
,
a
,
b
,
x
):
if
a
<
b
:
a
=
self
.
add
(
a
,
b
)
else
:
a
=
self
.
sub
(
a
,
b
)
out
=
self
.
net
(
a
,
b
,
x
)
return
out
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
if_net
=
MyIfByIfNet
()
net
=
if_net
idx
=
Tensor
(
np
.
array
(
2
),
dtype
=
ms
.
float32
)
end
=
Tensor
(
np
.
array
(
3
),
dtype
=
ms
.
float32
)
x
=
Tensor
(
np
.
array
(
0
),
dtype
=
ms
.
float32
)
net
(
idx
,
end
,
x
)
def
test_if_by_if_forward_use_namespace
():
class
MyIfByIfNet
(
nn
.
Cell
):
def
__init__
(
self
):
super
().
__init__
()
self
.
add
=
P
.
TensorAdd
()
self
.
sub
=
P
.
Sub
()
self
.
mul
=
P
.
Mul
()
self
.
div
=
P
.
RealDiv
()
def
construct
(
self
,
a
,
b
,
x
):
if
a
<
b
:
a
=
P
.
TensorAdd
()(
a
,
b
)
else
:
a
=
P
.
Sub
()(
a
,
b
)
if
a
==
x
:
a
=
P
.
Mul
()(
a
,
b
)
else
:
a
=
P
.
RealDiv
()(
a
,
b
)
if
b
==
x
:
b
=
P
.
TensorAdd
()(
a
,
b
)
else
:
b
=
P
.
TensorAdd
()(
a
,
x
)
a
=
a
*
b
out
=
a
+
b
+
x
return
out
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
if_net
=
MyIfByIfNet
()
net
=
if_net
idx
=
Tensor
(
np
.
array
(
2
),
dtype
=
ms
.
float32
)
end
=
Tensor
(
np
.
array
(
3
),
dtype
=
ms
.
float32
)
x
=
Tensor
(
np
.
array
(
0
),
dtype
=
ms
.
float32
)
net
(
idx
,
end
,
x
)
def
test_if_by_if_forward_use_global_op
():
class
MyIfByIfNet
(
nn
.
Cell
):
def
__init__
(
self
):
super
().
__init__
()
self
.
add
=
P
.
TensorAdd
()
self
.
sub
=
P
.
Sub
()
self
.
mul
=
P
.
Mul
()
self
.
div
=
P
.
RealDiv
()
def
construct
(
self
,
a
,
b
,
x
):
add
=
P
.
TensorAdd
()
sub
=
P
.
Sub
()
mul
=
P
.
Mul
()
div
=
P
.
RealDiv
()
if
a
<
b
:
a
=
add
(
a
,
b
)
else
:
a
=
sub
(
a
,
b
)
if
a
==
x
:
a
=
mul
(
a
,
b
)
else
:
a
=
div
(
a
,
b
)
if
b
==
x
:
b
=
add
(
a
,
b
)
else
:
b
=
add
(
a
,
x
)
a
=
a
*
b
out
=
a
+
b
+
x
return
out
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
if_net
=
MyIfByIfNet
()
net
=
if_net
idx
=
Tensor
(
np
.
array
(
2
),
dtype
=
ms
.
float32
)
end
=
Tensor
(
np
.
array
(
3
),
dtype
=
ms
.
float32
)
x
=
Tensor
(
np
.
array
(
0
),
dtype
=
ms
.
float32
)
net
(
idx
,
end
,
x
)
def
test_for_with_if_by_if_forward
():
class
MyIfByIfNet
(
nn
.
Cell
):
def
__init__
(
self
):
super
().
__init__
()
self
.
add
=
P
.
TensorAdd
()
self
.
sub
=
P
.
Sub
()
def
construct
(
self
,
a
,
b
,
x
):
for
_
in
range
(
0
,
4
):
if
a
<
b
:
a
=
self
.
add
(
a
,
b
)
else
:
b
=
self
.
sub
(
b
,
x
)
a
=
a
*
b
out
=
a
+
b
+
x
return
out
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
if_net
=
MyIfByIfNet
()
net
=
if_net
idx
=
Tensor
(
np
.
array
(
2
),
dtype
=
ms
.
float32
)
end
=
Tensor
(
np
.
array
(
3
),
dtype
=
ms
.
float32
)
x
=
Tensor
(
np
.
array
(
0
),
dtype
=
ms
.
float32
)
net
(
idx
,
end
,
x
)
def
test_for_with_if_by_if_forward_namespace
():
class
MyIfByIfNet
(
nn
.
Cell
):
def
__init__
(
self
):
super
().
__init__
()
self
.
add
=
P
.
TensorAdd
()
self
.
sub
=
P
.
Sub
()
self
.
mul
=
P
.
Mul
()
self
.
div
=
P
.
RealDiv
()
def
construct
(
self
,
a
,
b
,
x
):
for
_
in
range
(
0
,
6
):
if
a
<
b
:
a
=
P
.
TensorAdd
()(
a
,
b
)
else
:
b
=
P
.
Sub
()(
b
,
x
)
a
=
a
*
b
out
=
a
+
b
+
x
return
out
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
if_net
=
MyIfByIfNet
()
net
=
if_net
idx
=
Tensor
(
np
.
array
(
2
),
dtype
=
ms
.
float32
)
end
=
Tensor
(
np
.
array
(
3
),
dtype
=
ms
.
float32
)
x
=
Tensor
(
np
.
array
(
0
),
dtype
=
ms
.
float32
)
net
(
idx
,
end
,
x
)
def
test_if_by_if_forward_const_branch_inner
():
class
MyIfByIfNet
(
nn
.
Cell
):
def
__init__
(
self
):
super
().
__init__
()
self
.
add
=
P
.
TensorAdd
()
self
.
sub
=
P
.
Sub
()
self
.
mul
=
P
.
Mul
()
self
.
div
=
P
.
RealDiv
()
def
construct
(
self
,
a
,
b
,
x
):
add
=
P
.
TensorAdd
()
sub
=
P
.
Sub
()
mul
=
P
.
Mul
()
div
=
P
.
RealDiv
()
if
a
<
b
:
a
=
add
(
a
,
b
)
else
:
a
=
sub
(
a
,
b
)
if
2
>
1
:
a
=
mul
(
a
,
b
)
else
:
a
=
div
(
a
,
b
)
if
b
==
x
:
b
=
add
(
a
,
b
)
else
:
b
=
add
(
a
,
x
)
a
=
a
*
b
out
=
a
+
b
+
x
return
out
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
if_net
=
MyIfByIfNet
()
net
=
if_net
idx
=
Tensor
(
np
.
array
(
2
),
dtype
=
ms
.
float32
)
end
=
Tensor
(
np
.
array
(
3
),
dtype
=
ms
.
float32
)
x
=
Tensor
(
np
.
array
(
0
),
dtype
=
ms
.
float32
)
net
(
idx
,
end
,
x
)
def
test_if_by_if_forward_all_const_branch
():
class
MyIfByIfNet
(
nn
.
Cell
):
def
__init__
(
self
):
super
().
__init__
()
self
.
add
=
P
.
TensorAdd
()
self
.
sub
=
P
.
Sub
()
self
.
mul
=
P
.
Mul
()
self
.
div
=
P
.
RealDiv
()
def
construct
(
self
,
a
,
b
,
x
):
add
=
P
.
TensorAdd
()
sub
=
P
.
Sub
()
mul
=
P
.
Mul
()
div
=
P
.
RealDiv
()
if
2
<
12
:
a
=
add
(
a
,
b
)
else
:
a
=
sub
(
a
,
b
)
if
2
>
1
:
a
=
mul
(
a
,
b
)
else
:
a
=
div
(
a
,
b
)
if
2
==
1
:
b
=
add
(
a
,
b
)
else
:
b
=
add
(
a
,
x
)
a
=
a
*
b
out
=
a
+
b
+
x
return
out
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
if_net
=
MyIfByIfNet
()
net
=
if_net
idx
=
Tensor
(
np
.
array
(
2
),
dtype
=
ms
.
float32
)
end
=
Tensor
(
np
.
array
(
3
),
dtype
=
ms
.
float32
)
x
=
Tensor
(
np
.
array
(
0
),
dtype
=
ms
.
float32
)
net
(
idx
,
end
,
x
)
This diff is collapsed.
Click to expand it.
tests/ut/python/pynative_mode/test_cont_cases.py
0 → 100644
浏览文件 @
4f6e63fc
此差异已折叠。
点击以展开。
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录
新手
引导
客服
返回
顶部