Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
0c1a5d87
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
0c1a5d87
编写于
9月 11, 2018
作者:
D
dzhwinter
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
"debug version"
上级
ca973139
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
63 addition
and
19 deletion
+63
-19
python/paddle/fluid/transpiler/memory_optimization_transpiler.py
...paddle/fluid/transpiler/memory_optimization_transpiler.py
+63
-19
未找到文件。
python/paddle/fluid/transpiler/memory_optimization_transpiler.py
浏览文件 @
0c1a5d87
...
@@ -77,6 +77,9 @@ class ControlFlowGraph(object):
...
@@ -77,6 +77,9 @@ class ControlFlowGraph(object):
for
i
in
range
(
self
.
op_size
):
for
i
in
range
(
self
.
op_size
):
self
.
_uses
[
i
].
update
(
self
.
_ops
[
i
].
input_arg_names
())
self
.
_uses
[
i
].
update
(
self
.
_ops
[
i
].
input_arg_names
())
self
.
_defs
[
i
].
update
(
self
.
_ops
[
i
].
output_arg_names
())
self
.
_defs
[
i
].
update
(
self
.
_ops
[
i
].
output_arg_names
())
self
.
_live_in
[
i
]
=
self
.
_uses
[
i
]
# print(self._successors)
# print(self._presuccessors)
def
_update_graph
(
self
,
old_name
,
new_name
,
begin_idx
=
0
):
def
_update_graph
(
self
,
old_name
,
new_name
,
begin_idx
=
0
):
for
i
in
range
(
begin_idx
,
self
.
op_size
):
for
i
in
range
(
begin_idx
,
self
.
op_size
):
...
@@ -86,12 +89,18 @@ class ControlFlowGraph(object):
...
@@ -86,12 +89,18 @@ class ControlFlowGraph(object):
if
old_name
in
self
.
_defs
[
i
]:
if
old_name
in
self
.
_defs
[
i
]:
self
.
_defs
[
i
].
remove
(
old_name
)
self
.
_defs
[
i
].
remove
(
old_name
)
self
.
_defs
[
i
].
add
(
new_name
)
self
.
_defs
[
i
].
add
(
new_name
)
# for i in range(begin_idx, -1, -1):
if
old_name
in
self
.
_live_in
[
i
]:
if
old_name
in
self
.
_live_in
[
i
]:
self
.
_live_in
[
i
].
remove
(
old_name
)
self
.
_live_in
[
i
].
remove
(
old_name
)
self
.
_live_out
[
i
].
add
(
new_name
)
self
.
_live_in
[
i
].
add
(
new_name
)
# if old_name == "concat_3.tmp_0@GRAD":
# print("new_name", new_name)
# print("live_in ", i , self._live_in[i])
if
old_name
in
self
.
_live_out
[
i
]:
if
old_name
in
self
.
_live_out
[
i
]:
self
.
_live_out
[
i
].
remove
(
old_name
)
self
.
_live_out
[
i
].
remove
(
old_name
)
self
.
_live_out
[
i
].
add
(
new_name
)
self
.
_live_out
[
i
].
add
(
new_name
)
# if old_name == "concat_3.tmp_0@GRAD":
# print("live_out ", i , self._live_out[i])
def
_reach_fixed_point
(
self
,
live_in
,
live_out
):
def
_reach_fixed_point
(
self
,
live_in
,
live_out
):
"""Check if the liveness set has stablized."""
"""Check if the liveness set has stablized."""
...
@@ -105,22 +114,40 @@ class ControlFlowGraph(object):
...
@@ -105,22 +114,40 @@ class ControlFlowGraph(object):
return
False
return
False
return
True
return
True
# def _dataflow_analyze(self):
# self._build_graph()
# live_in = defaultdict(set)
# live_out = defaultdict(set)
# # Repeatedly apply liveness updates until the algorithm stablize
# # on a complete set live input vars and live output vars.
# counter = 0
# print(self._successors)
# while True:
# counter += 1
# for i in reversed(list(range(self.op_size))):
# live_in[i] = set(self._live_in[i])
# live_out[i] = set(self._live_out[i])
# for s in self._successors[i]:
# self._live_out[i] |= self._live_in[s]
# self._live_in[i] = self._uses[i] | (
# self._live_out[i] - self._defs[i])
# if self._reach_fixed_point(live_in, live_out):
# break
def
_dataflow_analyze
(
self
):
def
_dataflow_analyze
(
self
):
self
.
_build_graph
()
self
.
_build_graph
()
live_in
=
defaultdict
(
set
)
live_in
=
defaultdict
(
set
)
live_out
=
defaultdict
(
set
)
worklist
=
list
(
range
(
len
(
self
.
_ops
)
-
1
,
-
1
,
-
1
))
# Repeatedly apply liveness updates until the algorithm stablize
while
worklist
:
# on a complete set live input vars and live output vars.
i
=
worklist
.
pop
(
0
)
while
True
:
for
i
in
reversed
(
list
(
range
(
self
.
op_size
))):
live_in
[
i
]
=
set
(
self
.
_live_in
[
i
])
live_in
[
i
]
=
set
(
self
.
_live_in
[
i
])
live_out
[
i
]
=
set
(
self
.
_live_out
[
i
])
for
s
in
self
.
_successors
[
i
]:
for
s
in
self
.
_successors
[
i
]:
self
.
_live_out
[
i
]
|=
self
.
_live_in
[
s
]
self
.
_live_out
[
i
]
|=
self
.
_live_in
[
s
]
self
.
_live_in
[
i
]
=
self
.
_uses
[
i
]
|
(
self
.
_live_in
[
i
]
=
self
.
_uses
[
i
]
|
(
self
.
_live_out
[
i
]
-
self
.
_defs
[
i
])
self
.
_live_out
[
i
]
-
self
.
_defs
[
i
])
if
self
.
_reach_fixed_point
(
live_in
,
live_out
):
if
live_in
[
i
]
!=
self
.
_live_in
[
i
]:
break
for
d
in
self
.
_presuccessors
[
i
]:
worklist
.
append
(
d
)
def
_get_diff
(
self
,
a
,
b
):
def
_get_diff
(
self
,
a
,
b
):
u
=
a
&
b
u
=
a
&
b
...
@@ -218,6 +245,17 @@ class ControlFlowGraph(object):
...
@@ -218,6 +245,17 @@ class ControlFlowGraph(object):
continue
continue
block_desc
=
op
.
block
()
block_desc
=
op
.
block
()
is_forward
=
i
<
self
.
_forward_num
is_forward
=
i
<
self
.
_forward_num
in_diff
,
_
=
self
.
_get_diff
(
self
.
_live_in
[
i
],
self
.
_live_out
[
i
])
can_optimize
=
[
x
for
x
in
in_diff
if
self
.
_check_var_validity
(
block_desc
,
x
,
is_forward
)
]
if
can_optimize
:
for
var_name
in
can_optimize
:
self
.
pool
.
append
((
var_name
,
self
.
_find_var
(
block_desc
,
var_name
,
is_forward
).
shape
()))
# print(op.type(), i, self.pool)
# print(self._live_in[i])
if
self
.
pool
:
if
self
.
pool
:
defs_can_optimize
=
[
defs_can_optimize
=
[
x
for
x
in
self
.
_defs
[
i
]
x
for
x
in
self
.
_defs
[
i
]
...
@@ -249,21 +287,24 @@ class ControlFlowGraph(object):
...
@@ -249,21 +287,24 @@ class ControlFlowGraph(object):
if
x_dtype
!=
cache_dtype
:
if
x_dtype
!=
cache_dtype
:
continue
continue
self
.
pool
.
pop
(
index
)
if
x
==
cache_var
:
break
if
PRINT_LOG
:
if
PRINT_LOG
:
print
((
"Hit Cache !!!! cache pool index "
print
((
"Hit Cache !!!! cache pool index "
"is %d, var name is %s, "
"is %d, var name is %s, "
"cached var name is %s, "
"cached var name is %s, "
"var shape is %s "
)
%
(
index
,
x
,
cache_var
,
"var shape is %s "
)
%
(
index
,
x
,
cache_var
,
str
(
cache_shape
)))
str
(
cache_shape
)))
self
.
pool
.
pop
(
index
)
if
x
==
cache_var
:
break
# Rename the var to the cache var already with
# Rename the var to the cache var already with
# memory allocated in order to reuse the memory.
# memory allocated in order to reuse the memory.
_rename_arg_
(
self
.
_ops
,
x
,
cache_var
,
begin_idx
=
i
)
_rename_arg_
(
self
.
_ops
,
x
,
cache_var
,
begin_idx
=
i
)
self
.
_program
.
block
(
block_desc
.
id
).
var
(
cpt
.
to_text
(
self
.
_program
.
block
(
block_desc
.
id
).
var
(
cpt
.
to_text
(
x
)).
desc
=
self
.
_find_var
(
block_desc
,
cache_var
,
x
)).
desc
=
self
.
_find_var
(
block_desc
,
cache_var
,
is_forward
)
is_forward
)
if
x
==
"concat_3.tmp_0@GRAD"
:
print
(
"Update Graph"
,
i
)
self
.
_update_graph
(
x
,
cache_var
,
begin_idx
=
i
)
self
.
_update_graph
(
x
,
cache_var
,
begin_idx
=
i
)
break
break
...
@@ -272,10 +313,13 @@ class ControlFlowGraph(object):
...
@@ -272,10 +313,13 @@ class ControlFlowGraph(object):
x
for
x
in
in_diff
x
for
x
in
in_diff
if
self
.
_check_var_validity
(
block_desc
,
x
,
is_forward
)
if
self
.
_check_var_validity
(
block_desc
,
x
,
is_forward
)
]
]
keys
=
set
([
key
for
key
,
shape
in
self
.
pool
])
if
can_optimize
:
if
can_optimize
:
for
var_name
in
can_optimize
:
for
var_name
in
can_optimize
:
if
var_name
not
in
keys
:
self
.
pool
.
append
((
var_name
,
self
.
_find_var
(
self
.
pool
.
append
((
var_name
,
self
.
_find_var
(
block_desc
,
var_name
,
is_forward
).
shape
()))
block_desc
,
var_name
,
is_forward
).
shape
()))
# print(op.type(), i, self.pool)
def
_process_sub_block_pair
(
pdesc
,
sub_block_pair
):
def
_process_sub_block_pair
(
pdesc
,
sub_block_pair
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录