Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
22a9d02e
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
22a9d02e
编写于
8月 18, 2020
作者:
P
panyifeng
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
switch_layer incorporate env_get and tuple_get
上级
4f46c427
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
406 addition
and
3 deletion
+406
-3
mindspore/ccsrc/frontend/optimizer/irpass.cc
mindspore/ccsrc/frontend/optimizer/irpass.cc
+4
-0
mindspore/ccsrc/frontend/optimizer/irpass.h
mindspore/ccsrc/frontend/optimizer/irpass.h
+1
-0
mindspore/ccsrc/frontend/optimizer/irpass/env_item_eliminate.h
...pore/ccsrc/frontend/optimizer/irpass/env_item_eliminate.h
+135
-0
mindspore/ccsrc/frontend/optimizer/irpass/incorporate_getitem.h
...ore/ccsrc/frontend/optimizer/irpass/incorporate_getitem.h
+235
-2
mindspore/ccsrc/pipeline/jit/pass.cc
mindspore/ccsrc/pipeline/jit/pass.cc
+1
-1
tests/ut/python/ops/test_control_ops.py
tests/ut/python/ops/test_control_ops.py
+30
-0
未找到文件。
mindspore/ccsrc/frontend/optimizer/irpass.cc
浏览文件 @
22a9d02e
...
...
@@ -95,6 +95,10 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
incorporate_env_getitem_
=
MakeSubstitution
(
std
::
make_shared
<
IncorporateEnvGetitem
>
(),
"incorporate_env_get_item"
,
prim
::
kPrimEnvGetItem
);
incorporate_env_getitem_switch_layer_
=
MakeSubstitution
(
std
::
make_shared
<
IncorporateEnvGetitemSwitchLayer
>
(),
"incorporate_env_getitem_switch_layer"
,
prim
::
kPrimEnvGetItem
);
// Ref eliminate
make_ref_eliminate_
=
MakeSubstitution
(
std
::
make_shared
<
MakeRefEliminater
>
(),
"make_ref_eliminate"
,
prim
::
kPrimMakeRef
);
...
...
mindspore/ccsrc/frontend/optimizer/irpass.h
浏览文件 @
22a9d02e
...
...
@@ -58,6 +58,7 @@ class OptimizeIRPassLib {
SubstitutionPtr
incorporate_env_getitem_
;
SubstitutionPtr
incorporate_env_getitem_bypass_recursive_
;
SubstitutionPtr
incorporate_env_getitem_switch_
;
SubstitutionPtr
incorporate_env_getitem_switch_layer_
;
// Ref eliminate
SubstitutionPtr
make_ref_eliminate_
;
...
...
mindspore/ccsrc/frontend/optimizer/irpass/env_item_eliminate.h
浏览文件 @
22a9d02e
...
...
@@ -91,6 +91,69 @@ class EnvGetitemTransform {
std
::
unordered_map
<
std
::
pair
<
SymbolicKeyInstancePtr
,
AnfNodePtr
>
,
FuncGraphPtr
,
PairHasher
>>
cache_
;
};
class
EnvGetitemTransformACrossGraph
{
public:
EnvGetitemTransformACrossGraph
()
:
cache_
()
{}
~
EnvGetitemTransformACrossGraph
()
=
default
;
FuncGraphPtr
operator
()(
const
FuncGraphPtr
&
fg
,
const
SymbolicKeyInstancePtr
&
key
,
const
AnfNodePtr
&
default_node
)
{
if
(
cache_
.
find
(
fg
)
==
cache_
.
end
())
{
cache_
[
fg
]
=
{};
}
auto
&
cache
=
cache_
[
fg
];
auto
hash_key
=
std
::
make_pair
(
key
,
default_node
);
if
(
cache
.
find
(
hash_key
)
==
cache
.
end
())
{
std
::
ostringstream
ss
(
"env"
,
std
::
ostringstream
::
app
);
if
(
key
->
node
()
!=
nullptr
)
{
ss
<<
key
->
node
()
->
ToString
();
}
auto
new_fg_outer
=
TransformableClone
(
fg
,
std
::
make_shared
<
TraceTransform
>
(
ss
.
str
()));
auto
output_outer
=
new_fg_outer
->
output
();
if
(
!
IsValueNode
<
FuncGraph
>
(
output_outer
))
{
MS_LOG
(
WARNING
)
<<
"Output of outer graph should be a func_graph"
;
return
nullptr
;
}
auto
fg_inner
=
GetValueNode
<
FuncGraphPtr
>
(
output_outer
);
auto
new_fg
=
TransformableClone
(
fg_inner
,
std
::
make_shared
<
TraceTransform
>
(
ss
.
str
()));
new_fg_outer
->
set_output
(
NewValueNode
(
new_fg
));
auto
env
=
new_fg
->
output
();
while
(
IsPrimitiveCNode
(
env
,
prim
::
kPrimEnvSetItem
))
{
// {prim::kPrimEnvSetItem, env, symbolickey, value}
auto
&
inputs
=
env
->
cast
<
CNodePtr
>
()
->
inputs
();
if
(
inputs
.
size
()
!=
4
)
{
MS_LOG
(
WARNING
)
<<
"Input size should be 4"
;
return
nullptr
;
}
if
(
!
IsValueNode
<
SymbolicKeyInstance
>
(
inputs
[
2
]))
{
MS_LOG
(
DEBUG
)
<<
"Input 2 is not a SymbolicKeyInstance?"
;
return
nullptr
;
}
env
=
inputs
[
1
];
auto
value
=
inputs
[
3
];
auto
key2
=
GetValueNode
<
SymbolicKeyInstancePtr
>
(
inputs
[
2
]);
if
(
*
key2
==
*
key
)
{
new_fg
->
set_output
(
value
);
cache
[
hash_key
]
=
new_fg_outer
;
return
new_fg_outer
;
}
}
new_fg
->
set_output
(
new_fg
->
NewCNode
({
NewValueNode
(
prim
::
kPrimEnvGetItem
),
env
,
NewValueNode
(
key
),
default_node
}));
cache
[
hash_key
]
=
new_fg_outer
;
}
return
cache
[
hash_key
];
}
private:
std
::
unordered_map
<
FuncGraphPtr
,
std
::
unordered_map
<
std
::
pair
<
SymbolicKeyInstancePtr
,
AnfNodePtr
>
,
FuncGraphPtr
,
PairHasher
>>
cache_
;
};
}
// namespace internal
// {prim::kPrimEnvGetItem, C1, C2, Y} -> Y
...
...
@@ -358,6 +421,78 @@ class IncorporateEnvGetitemSwitch : public AnfVisitor {
bool
is_match_
{
false
};
internal
::
EnvGetitemTransform
env_get_item_transform_
;
};
// {prim::kPrimEnvGetItem, {{{prim::kPrimSwitchLayer, X, {prim::kPrimMakeTuple, G1, G2...}}, Xs}, Ys}, C, Y}
class
IncorporateEnvGetitemSwitchLayer
:
public
AnfVisitor
{
public:
IncorporateEnvGetitemSwitchLayer
()
:
env_get_item_transform_
()
{}
~
IncorporateEnvGetitemSwitchLayer
()
override
=
default
;
AnfNodePtr
operator
()(
const
OptimizerPtr
&
,
const
AnfNodePtr
&
node
)
override
{
is_match_
=
false
;
AnfVisitor
::
Match
(
prim
::
kPrimEnvGetItem
,
{
IsCNode
,
IsValueNode
<
SymbolicKeyInstance
>
,
IsNode
})(
node
);
if
(
!
is_match_
||
node
->
func_graph
()
==
nullptr
)
{
return
nullptr
;
}
// {prim::kPrimEnvGetItem, {...}, C, Y}
auto
cnode
=
node
->
cast
<
CNodePtr
>
();
auto
inp1
=
cnode
->
input
(
1
)
->
cast
<
CNodePtr
>
();
auto
key
=
GetValueNode
<
SymbolicKeyInstancePtr
>
(
cnode
->
input
(
2
));
auto
default_v
=
cnode
->
input
(
3
);
// {{prim::kPrimSwitchLayer, X, {prim::kPrimMakeTuple, G1, G2...}}, Xs}, Ys}
auto
&
inputs_outer
=
inp1
->
inputs
();
if
(
!
inputs_outer
[
0
]
->
isa
<
CNode
>
())
{
return
nullptr
;
}
std
::
vector
<
AnfNodePtr
>
args_outer
;
args_outer
.
insert
(
args_outer
.
end
(),
inputs_outer
.
begin
()
+
1
,
inputs_outer
.
end
());
auto
&
input_switch_layer
=
inputs_outer
[
0
]
->
cast
<
CNodePtr
>
()
->
inputs
();
is_match_
=
false
;
AnfVisitor
::
Match
(
prim
::
kPrimSwitchLayer
,
{
IsNode
,
IsCNode
})(
input_switch_layer
[
0
]);
if
(
!
is_match_
)
{
return
nullptr
;
}
std
::
vector
<
AnfNodePtr
>
args
;
(
void
)
args
.
insert
(
args
.
end
(),
input_switch_layer
.
begin
()
+
1
,
input_switch_layer
.
end
());
// {prim::kPrimSwitchLayers, X, {prim::kPrimMakeTuple, G1, G2...}}
auto
sw
=
input_switch_layer
[
0
]
->
cast
<
CNodePtr
>
();
std
::
vector
<
FuncGraphPtr
>
graphs
{};
auto
graphs_cnode
=
sw
->
input
(
2
)
->
cast
<
CNodePtr
>
();
auto
&
graphs_inputs
=
graphs_cnode
->
inputs
();
if
(
IsPrimitiveCNode
(
graphs_cnode
,
prim
::
kPrimMakeTuple
)
&&
IsValueNode
<
FuncGraph
>
(
graphs_inputs
[
1
]))
{
(
void
)
std
::
transform
(
graphs_inputs
.
begin
()
+
1
,
graphs_inputs
.
end
(),
std
::
back_inserter
(
graphs
),
[](
const
AnfNodePtr
&
vnode
)
{
return
GetValueNode
<
FuncGraphPtr
>
(
vnode
);
});
}
if
(
graphs
.
empty
())
{
return
nullptr
;
}
auto
fg
=
node
->
func_graph
();
std
::
vector
<
AnfNodePtr
>
layers
;
for
(
auto
&
graph
:
graphs
)
{
auto
fg_transform
=
env_get_item_transform_
(
graph
,
key
,
default_v
);
if
(
fg_transform
==
nullptr
)
{
return
nullptr
;
}
layers
.
push_back
(
NewValueNode
(
fg_transform
));
}
auto
layers_node
=
fg
->
NewCNode
(
prim
::
kPrimMakeTuple
,
layers
);
auto
new_sw
=
fg
->
NewCNode
({
NewValueNode
(
prim
::
kPrimSwitchLayer
),
sw
->
input
(
1
),
layers_node
});
args
.
insert
(
args
.
begin
(),
new_sw
);
auto
inner_call
=
fg
->
NewCNode
(
args
);
args_outer
.
insert
(
args_outer
.
begin
(),
inner_call
);
return
fg
->
NewCNode
(
args_outer
);
}
void
Visit
(
const
AnfNodePtr
&
)
override
{
is_match_
=
true
;
}
private:
bool
is_match_
{
false
};
internal
::
EnvGetitemTransformACrossGraph
env_get_item_transform_
;
};
}
// namespace irpass
}
// namespace opt
}
// namespace mindspore
...
...
mindspore/ccsrc/frontend/optimizer/irpass/incorporate_getitem.h
浏览文件 @
22a9d02e
...
...
@@ -72,6 +72,52 @@ class GetitemTransform {
private:
std
::
unordered_map
<
FuncGraphPtr
,
std
::
unordered_map
<
int
,
FuncGraphPtr
>>
cache_
;
};
class
GetItemTransformACrossGraph
{
public:
GetItemTransformACrossGraph
()
:
cache_
()
{}
~
GetItemTransformACrossGraph
()
=
default
;
FuncGraphPtr
operator
()(
const
FuncGraphPtr
&
fg
,
int
idx
)
{
if
(
cache_
.
find
(
fg
)
==
cache_
.
end
())
{
cache_
[
fg
]
=
{};
}
auto
&
cache
=
cache_
[
fg
];
if
(
cache
.
find
(
idx
)
==
cache
.
end
())
{
std
::
ostringstream
ss
(
"tp"
,
std
::
ostringstream
::
app
);
ss
<<
idx
;
auto
new_fg_outer
=
TransformableClone
(
fg
,
std
::
make_shared
<
TraceTransform
>
(
ss
.
str
()));
auto
output_outer
=
new_fg_outer
->
output
();
if
(
!
IsValueNode
<
FuncGraph
>
(
output_outer
))
{
MS_LOG
(
WARNING
)
<<
"Output of outer graph should be a func_graph"
;
return
nullptr
;
}
auto
fg_inner
=
GetValueNode
<
FuncGraphPtr
>
(
output_outer
);
auto
new_fg
=
TransformableClone
(
fg_inner
,
std
::
make_shared
<
TraceTransform
>
(
ss
.
str
()));
new_fg_outer
->
set_output
(
NewValueNode
(
new_fg
));
auto
output
=
new_fg
->
output
();
if
(
IsPrimitiveCNode
(
output
,
prim
::
kPrimMakeTuple
))
{
auto
cnode
=
output
->
cast
<
CNodePtr
>
();
auto
ids
=
IntToSize
(
idx
+
1
);
// Inputs should be [make_tuple, item1, item2, ...], so have to offset idx in tuple_getitem by 1.
if
(
ids
>=
cnode
->
size
())
{
MS_LOG
(
EXCEPTION
)
<<
"index "
<<
ids
<<
" is out of inputs length "
<<
cnode
->
size
();
}
new_fg
->
set_output
(
cnode
->
input
(
ids
));
}
else
{
new_fg
->
set_output
(
new_fg
->
NewCNode
({
NewValueNode
(
prim
::
kPrimTupleGetItem
),
output
,
NewValueNode
(
idx
)}));
}
cache
[
idx
]
=
new_fg_outer
;
}
return
cache
[
idx
];
}
private:
std
::
unordered_map
<
FuncGraphPtr
,
std
::
unordered_map
<
int
,
FuncGraphPtr
>>
cache_
;
};
}
// namespace internal
// {prim::kPrimTupleGetItem, {G, Xs}, C}
...
...
@@ -385,13 +431,199 @@ class IncorporateGetitemSwitch : public AnfVisitor {
internal
::
GetitemTransform
getitem_transform_
;
};
// {prim::kPrimTupleGetItem, {{prim::kPrimSwitchLayer, X, {prim::kPrimMakeTuple, G1, G2...}}, Xs}, C}
class
IncorporateGetitemSwitchLayerA
:
public
AnfVisitor
{
public:
IncorporateGetitemSwitchLayerA
()
:
getitem_transform_
()
{}
~
IncorporateGetitemSwitchLayerA
()
override
=
default
;
AnfNodePtr
operator
()(
const
OptimizerPtr
&
,
const
AnfNodePtr
&
node
)
override
{
Reset
();
is_in_get_
=
true
;
AnfVisitor
::
Match
(
prim
::
kPrimTupleGetItem
,
{
IsCNode
,
IsValueNode
<
Int32Imm
>
})(
node
);
is_in_get_
=
false
;
auto
fg
=
node
->
func_graph
();
if
(
idx_
==
-
1
||
switch_layer_
==
nullptr
||
fg
==
nullptr
)
{
return
nullptr
;
}
is_in_switch_
=
true
;
AnfVisitor
::
Match
(
prim
::
kPrimSwitchLayer
,
{
IsNode
,
IsCNode
})(
switch_layer_
);
is_in_switch_
=
false
;
if
(
graphs_
.
empty
())
{
return
nullptr
;
}
std
::
vector
<
AnfNodePtr
>
layers
;
for
(
auto
&
graph
:
graphs_
)
{
auto
fg_transform
=
getitem_transform_
(
graph
,
idx_
);
if
(
fg_transform
==
nullptr
)
{
return
nullptr
;
}
layers
.
push_back
(
NewValueNode
(
fg_transform
));
}
auto
layers_node
=
fg
->
NewCNode
(
prim
::
kPrimMakeTuple
,
layers
);
std
::
vector
<
AnfNodePtr
>
sw_args
{
NewValueNode
(
prim
::
kPrimSwitchLayer
),
x_
,
layers_node
};
auto
sw_node
=
fg
->
NewCNode
(
sw_args
);
(
void
)
args_
.
insert
(
args_
.
begin
(),
sw_node
);
return
fg
->
NewCNode
(
args_
);
}
void
Visit
(
const
AnfNodePtr
&
node
)
override
{
if
(
is_in_switch_
&&
x_
==
nullptr
)
{
x_
=
node
;
return
;
}
AnfVisitor
::
Visit
(
node
);
}
void
Visit
(
const
CNodePtr
&
cnode
)
override
{
if
(
is_in_get_
&&
cnode
->
size
()
!=
0
)
{
auto
&
inputs
=
cnode
->
inputs
();
switch_layer_
=
inputs
[
0
];
(
void
)
std
::
copy
(
inputs
.
begin
()
+
1
,
inputs
.
end
(),
std
::
back_inserter
(
args_
));
}
if
(
is_in_switch_
&&
cnode
->
size
()
>
2
)
{
auto
&
inputs
=
cnode
->
inputs
();
if
(
IsPrimitiveCNode
(
cnode
,
prim
::
kPrimMakeTuple
)
&&
IsValueNode
<
FuncGraph
>
(
inputs
[
1
]))
{
(
void
)
std
::
transform
(
inputs
.
begin
()
+
1
,
inputs
.
end
(),
std
::
back_inserter
(
graphs_
),
[](
const
AnfNodePtr
&
vnode
)
{
return
GetValueNode
<
FuncGraphPtr
>
(
vnode
);
});
}
}
}
void
Visit
(
const
ValueNodePtr
&
vnode
)
override
{
if
(
is_in_get_
)
{
idx_
=
GetValue
<
int
>
(
vnode
->
value
());
}
}
void
Reset
()
{
x_
=
nullptr
;
graphs_
.
clear
();
switch_layer_
=
nullptr
;
args_
.
clear
();
is_in_get_
=
false
;
is_in_switch_
=
false
;
}
private:
int
idx_
{
-
1
};
AnfNodePtr
switch_layer_
{
nullptr
},
x_
{
nullptr
};
std
::
vector
<
FuncGraphPtr
>
graphs_
{};
bool
is_in_get_
{
false
},
is_in_switch_
{
false
};
std
::
vector
<
AnfNodePtr
>
args_
{};
internal
::
GetitemTransform
getitem_transform_
;
};
// {prim::kPrimTupleGetItem, {{{prim::kPrimSwitchLayer, X, {prim::kPrimMakeTuple, G1, G2...}}, Xs}, Ys}, C}
class
IncorporateGetitemSwitchLayerB
:
public
AnfVisitor
{
public:
IncorporateGetitemSwitchLayerB
()
:
getitem_transform_
()
{}
~
IncorporateGetitemSwitchLayerB
()
override
=
default
;
AnfNodePtr
operator
()(
const
OptimizerPtr
&
,
const
AnfNodePtr
&
node
)
override
{
Reset
();
is_in_get_
=
true
;
AnfVisitor
::
Match
(
prim
::
kPrimTupleGetItem
,
{
IsCNode
,
IsValueNode
<
Int32Imm
>
})(
node
);
is_in_get_
=
false
;
auto
fg
=
node
->
func_graph
();
if
(
idx_
==
-
1
||
switch_layer_call_
==
nullptr
||
!
switch_layer_call_
->
isa
<
CNode
>
()
||
fg
==
nullptr
)
{
return
nullptr
;
}
auto
&
switch_layer_call_inputs
=
switch_layer_call_
->
cast
<
CNodePtr
>
()
->
inputs
();
(
void
)
std
::
copy
(
switch_layer_call_inputs
.
begin
()
+
1
,
switch_layer_call_inputs
.
end
(),
std
::
back_inserter
(
args_
));
is_in_switch_
=
true
;
AnfVisitor
::
Match
(
prim
::
kPrimSwitchLayer
,
{
IsNode
,
IsCNode
})(
switch_layer_call_inputs
[
0
]);
is_in_switch_
=
false
;
if
(
graphs_
.
empty
())
{
return
nullptr
;
}
std
::
vector
<
AnfNodePtr
>
layers
;
for
(
auto
&
graph
:
graphs_
)
{
auto
fg_transform
=
getitem_transform_
(
graph
,
idx_
);
if
(
fg_transform
==
nullptr
)
{
return
nullptr
;
}
layers
.
push_back
(
NewValueNode
(
fg_transform
));
}
auto
layers_node
=
fg
->
NewCNode
(
prim
::
kPrimMakeTuple
,
layers
);
std
::
vector
<
AnfNodePtr
>
sw_args
{
NewValueNode
(
prim
::
kPrimSwitchLayer
),
x_
,
layers_node
};
auto
sw_node
=
fg
->
NewCNode
(
sw_args
);
(
void
)
args_
.
insert
(
args_
.
begin
(),
sw_node
);
auto
call_switch_layer
=
fg
->
NewCNode
(
args_
);
(
void
)
outer_call_args_
.
insert
(
outer_call_args_
.
begin
(),
call_switch_layer
);
return
fg
->
NewCNode
(
outer_call_args_
);
}
void
Visit
(
const
AnfNodePtr
&
node
)
override
{
if
(
is_in_switch_
&&
x_
==
nullptr
)
{
x_
=
node
;
return
;
}
AnfVisitor
::
Visit
(
node
);
}
void
Visit
(
const
CNodePtr
&
cnode
)
override
{
if
(
is_in_get_
&&
cnode
->
size
()
!=
0
)
{
auto
&
inputs
=
cnode
->
inputs
();
switch_layer_call_
=
inputs
[
0
];
(
void
)
std
::
copy
(
inputs
.
begin
()
+
1
,
inputs
.
end
(),
std
::
back_inserter
(
outer_call_args_
));
}
if
(
is_in_switch_
&&
cnode
->
size
()
>
2
)
{
auto
&
inputs
=
cnode
->
inputs
();
if
(
IsPrimitiveCNode
(
cnode
,
prim
::
kPrimMakeTuple
)
&&
IsValueNode
<
FuncGraph
>
(
inputs
[
1
]))
{
(
void
)
std
::
transform
(
inputs
.
begin
()
+
1
,
inputs
.
end
(),
std
::
back_inserter
(
graphs_
),
[](
const
AnfNodePtr
&
vnode
)
{
return
GetValueNode
<
FuncGraphPtr
>
(
vnode
);
});
}
}
}
void
Visit
(
const
ValueNodePtr
&
vnode
)
override
{
if
(
is_in_get_
)
{
idx_
=
GetValue
<
int
>
(
vnode
->
value
());
}
}
void
Reset
()
{
x_
=
nullptr
;
graphs_
.
clear
();
switch_layer_call_
=
nullptr
;
args_
.
clear
();
outer_call_args_
.
clear
();
is_in_get_
=
false
;
is_in_switch_
=
false
;
}
private:
int
idx_
{
-
1
};
AnfNodePtr
switch_layer_call_
{
nullptr
},
x_
{
nullptr
};
std
::
vector
<
FuncGraphPtr
>
graphs_
{};
bool
is_in_get_
{
false
},
is_in_switch_
{
false
};
std
::
vector
<
AnfNodePtr
>
args_
{};
std
::
vector
<
AnfNodePtr
>
outer_call_args_
{};
internal
::
GetItemTransformACrossGraph
getitem_transform_
;
};
class
IncorporateGetitemSet
:
public
OptimizerCaller
{
public:
IncorporateGetitemSet
()
:
incorporate_getitem_
(
std
::
make_shared
<
IncorporateGetitem
>
()),
incorporate_getitem_switch_
(
std
::
make_shared
<
IncorporateGetitemSwitch
>
())
{
incorporate_getitem_switch_
(
std
::
make_shared
<
IncorporateGetitemSwitch
>
()),
incorporate_getitem_switch_layer_a_
(
std
::
make_shared
<
IncorporateGetitemSwitchLayerA
>
()),
incorporate_getitem_switch_layer_b_
(
std
::
make_shared
<
IncorporateGetitemSwitchLayerB
>
())
{
eliminaters_
.
emplace_back
(
incorporate_getitem_
);
eliminaters_
.
emplace_back
(
incorporate_getitem_switch_
);
eliminaters_
.
emplace_back
(
incorporate_getitem_switch_layer_a_
);
eliminaters_
.
emplace_back
(
incorporate_getitem_switch_layer_b_
);
}
~
IncorporateGetitemSet
()
=
default
;
...
...
@@ -407,7 +639,8 @@ class IncorporateGetitemSet : public OptimizerCaller {
}
private:
OptimizerCallerPtr
incorporate_getitem_
,
incorporate_getitem_switch_
;
OptimizerCallerPtr
incorporate_getitem_
,
incorporate_getitem_switch_
,
incorporate_getitem_switch_layer_a_
,
incorporate_getitem_switch_layer_b_
;
std
::
vector
<
OptimizerCallerPtr
>
eliminaters_
{};
};
}
// namespace irpass
...
...
mindspore/ccsrc/pipeline/jit/pass.cc
浏览文件 @
22a9d02e
...
...
@@ -180,7 +180,7 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
{
irpass
.
zero_like_fill_zero_
,
irpass
.
item_tuple_eliminate_
,
irpass
.
float_tuple_getitem_switch_
,
irpass
.
reset_defer_inline_
,
irpass
.
inline_
,
irpass
.
special_op_eliminate_
,
irpass
.
get_make_ref_eliminate_
,
irpass
.
incorporate_env_getitem_
,
irpass
.
incorporate_env_getitem_switch_
,
irpass
.
env_get_item_eliminate_
,
irpass
.
value_based_eliminate_
});
irpass
.
incorporate_env_getitem_switch_layer_
,
irpass
.
value_based_eliminate_
});
opt
::
OptPassConfig
b_2
=
opt
::
OptPassConfig
({
irpass
.
replace_refkey_by_param_
,
irpass
.
make_ref_eliminate_
,
...
...
tests/ut/python/ops/test_control_ops.py
浏览文件 @
22a9d02e
...
...
@@ -464,6 +464,36 @@ def test_switch_layer_with_single_prim():
C
.
grad_all
(
net
)(
index
,
Tensor
(
np
.
full
([
128
,
96
],
0.6
,
dtype
=
np
.
float32
)))
def
test_switch_layer_env_eliminate
():
class
Net
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
Net
,
self
).
__init__
()
self
.
conv
=
nn
.
Conv2d
(
1
,
1
,
3
,
pad_mode
=
'same'
)
self
.
conv2
=
nn
.
Conv2d
(
1
,
1
,
5
,
pad_mode
=
'same'
)
self
.
funs
=
(
self
.
conv
,
self
.
conv2
)
def
construct
(
self
,
x
,
index
):
x
=
self
.
funs
[
index
](
x
)
return
x
class
NetGrad
(
nn
.
Cell
):
def
__init__
(
self
,
net
):
super
(
NetGrad
,
self
).
__init__
()
self
.
grad_op
=
C
.
GradOperation
(
'grad'
,
get_by_list
=
True
,
sens_param
=
False
)
self
.
net
=
net
self
.
weights
=
ParameterTuple
(
self
.
net
.
trainable_params
())
def
construct
(
self
,
x
,
index
):
weights
=
self
.
weights
grad
=
self
.
grad_op
(
self
.
net
,
weights
)(
x
,
index
)
return
grad
net
=
Net
()
net2
=
NetGrad
(
net
)
x
=
Tensor
(
np
.
ones
((
3
,
1
,
12
,
12
)),
ms
.
float32
)
i
=
Tensor
(
1
,
ms
.
int32
)
net2
(
x
,
i
)
def
test_control_depend_check
():
with
pytest
.
raises
(
TypeError
)
as
e
:
P
.
ControlDepend
(
0.0
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录