Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
c8874f23
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
c8874f23
编写于
9月 01, 2022
作者:
K
kuizhiqing
提交者:
GitHub
9月 01, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[CINN] add fetch and prune for build cinn pass (#45531)
* add fetch and prune for build cinn pass * add prune flag
上级
13d62e12
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
56 addition
and
10 deletion
+56
-10
python/paddle/distributed/passes/cpp_pass.py
python/paddle/distributed/passes/cpp_pass.py
+37
-6
python/paddle/fluid/executor.py
python/paddle/fluid/executor.py
+19
-4
未找到文件。
python/paddle/distributed/passes/cpp_pass.py
浏览文件 @
c8874f23
...
@@ -12,6 +12,7 @@
...
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
paddle.static
import
Executor
from
.pass_base
import
PassType
,
CPPPassWrapper
,
register_pass
from
.pass_base
import
PassType
,
CPPPassWrapper
,
register_pass
from
paddle.fluid.framework
import
core
,
_apply_pass
as
_apply_cpp_pass
from
paddle.fluid.framework
import
core
,
_apply_pass
as
_apply_cpp_pass
...
@@ -102,6 +103,13 @@ class InplaceAddtoOpPass(CPPPassWrapper):
...
@@ -102,6 +103,13 @@ class InplaceAddtoOpPass(CPPPassWrapper):
return
PassType
.
CALC_OPT
return
PassType
.
CALC_OPT
def
_set_cinn_op_flag
(
flag_name
,
extra_ops
):
values
=
core
.
globals
()[
flag_name
]
values
=
[
v
.
strip
()
for
v
in
values
.
split
(
";"
)
if
v
.
strip
()]
values
.
extend
(
extra_ops
)
core
.
globals
()[
flag_name
]
=
";"
.
join
(
values
)
@
register_pass
(
"build_cinn"
)
@
register_pass
(
"build_cinn"
)
class
BuildCINNPass
(
CPPPassWrapper
):
class
BuildCINNPass
(
CPPPassWrapper
):
...
@@ -118,18 +126,41 @@ class BuildCINNPass(CPPPassWrapper):
...
@@ -118,18 +126,41 @@ class BuildCINNPass(CPPPassWrapper):
return
PassType
.
CALC_OPT
return
PassType
.
CALC_OPT
def
_apply_single_impl
(
self
,
main_program
,
startup_program
,
context
):
def
_apply_single_impl
(
self
,
main_program
,
startup_program
,
context
):
allow_ops
=
";"
.
join
(
self
.
get_attr
(
"allow_ops"
))
deny_ops
=
";"
.
join
(
self
.
get_attr
(
"deny_ops"
))
assert
'FLAGS_allow_cinn_ops'
in
core
.
globals
(
assert
'FLAGS_allow_cinn_ops'
in
core
.
globals
(
),
"PaddlePaddle is not compiled with CINN support"
),
"PaddlePaddle is not compiled with CINN support"
old_allow_ops
=
core
.
globals
()[
'FLAGS_allow_cinn_ops'
]
old_allow_ops
=
core
.
globals
()[
'FLAGS_allow_cinn_ops'
]
old_deny_ops
=
core
.
globals
()[
'FLAGS_deny_cinn_ops'
]
old_deny_ops
=
core
.
globals
()[
'FLAGS_deny_cinn_ops'
]
try
:
try
:
core
.
globals
()[
'FLAGS_allow_cinn_ops'
]
=
allow_ops
_set_cinn_op_flag
(
'FLAGS_allow_cinn_ops'
,
core
.
globals
()[
'FLAGS_deny_cinn_ops'
]
=
deny_ops
self
.
get_attr
(
"allow_ops"
))
_apply_cpp_pass
(
main_program
,
startup_program
,
self
.
cpp_name
,
{},
_set_cinn_op_flag
(
'FLAGS_deny_cinn_ops'
,
self
.
get_attr
(
"deny_ops"
))
self
.
cpp_attr_types
)
feed
=
self
.
get_attr
(
'feed'
,
[])
fetch_list
=
self
.
get_attr
(
'fetch_list'
,
[])
prune_program
=
self
.
get_attr
(
'prune_program'
,
True
)
if
prune_program
:
tmp_main_program
=
Executor
.
_prune_program
(
main_program
,
feed
,
fetch_list
,
[])
tmp_main_program
=
Executor
.
_add_fetch_ops
(
tmp_main_program
,
fetch_list
,
'fetch'
)
else
:
tmp_main_program
=
Executor
.
_add_fetch_ops
(
main_program
,
fetch_list
,
'fetch'
)
_apply_cpp_pass
(
tmp_main_program
,
startup_program
,
self
.
cpp_name
,
{},
self
.
cpp_attr_types
)
tmp_main_program
=
Executor
.
_remove_fetch_ops
(
tmp_main_program
)
tmp_main_program
=
core
.
ProgramDesc
(
tmp_main_program
.
desc
)
main_program
.
_rebuild_from_desc
(
tmp_main_program
)
finally
:
finally
:
core
.
globals
()[
'FLAGS_allow_cinn_ops'
]
=
old_allow_ops
core
.
globals
()[
'FLAGS_allow_cinn_ops'
]
=
old_allow_ops
core
.
globals
()[
'FLAGS_deny_cinn_ops'
]
=
old_deny_ops
core
.
globals
()[
'FLAGS_deny_cinn_ops'
]
=
old_deny_ops
python/paddle/fluid/executor.py
浏览文件 @
c8874f23
...
@@ -978,7 +978,8 @@ class Executor(object):
...
@@ -978,7 +978,8 @@ class Executor(object):
]
]
return
outs
return
outs
def
_split_optimize_ops_in_fetch_list
(
self
,
fetch_list
):
@
classmethod
def
_split_optimize_ops_in_fetch_list
(
cls
,
fetch_list
):
"""
"""
Split optimize_ops from fetch_list, which provided to specify program prunning.
Split optimize_ops from fetch_list, which provided to specify program prunning.
Args:
Args:
...
@@ -1030,7 +1031,8 @@ class Executor(object):
...
@@ -1030,7 +1031,8 @@ class Executor(object):
return
_fetch_list
,
_optimize_ops
return
_fetch_list
,
_optimize_ops
def
_prune_program
(
self
,
@
classmethod
def
_prune_program
(
cls
,
program
,
program
,
feed
=
None
,
feed
=
None
,
fetch_list
=
None
,
fetch_list
=
None
,
...
@@ -1093,7 +1095,8 @@ class Executor(object):
...
@@ -1093,7 +1095,8 @@ class Executor(object):
return
program
return
program
def
_update_feed
(
self
,
program
,
feed
):
@
classmethod
def
_update_feed
(
cls
,
program
,
feed
):
"""
"""
Update the feed dict, remove the feed item which is pruned in program.
Update the feed dict, remove the feed item which is pruned in program.
...
@@ -2379,7 +2382,8 @@ class Executor(object):
...
@@ -2379,7 +2382,8 @@ class Executor(object):
return
tmp_program
return
tmp_program
def
_add_fetch_ops
(
self
,
@
classmethod
def
_add_fetch_ops
(
cls
,
program
,
program
,
fetch_list
,
fetch_list
,
fetch_var_name
,
fetch_var_name
,
...
@@ -2416,6 +2420,17 @@ class Executor(object):
...
@@ -2416,6 +2420,17 @@ class Executor(object):
return
tmp_program
return
tmp_program
@
classmethod
def
_remove_fetch_ops
(
cls
,
program
,
fetch_op_name
=
'fetch'
):
tmp_program
=
program
.
clone
()
global_block
=
tmp_program
.
global_block
()
op_num
=
len
(
global_block
.
ops
)
for
idx
in
reversed
(
range
(
op_num
)):
if
global_block
.
ops
[
idx
].
type
==
fetch_op_name
:
global_block
.
_remove_op
(
idx
)
return
tmp_program
def
_run_pipeline
(
self
,
def
_run_pipeline
(
self
,
program
=
None
,
program
=
None
,
dataset
=
None
,
dataset
=
None
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录