Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
ef60a654
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ef60a654
编写于
9月 14, 2018
作者:
D
dzhwinter
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
"add test"
上级
0c1a5d87
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
65 addition
and
75 deletion
+65
-75
python/paddle/fluid/tests/unittests/test_memory_optimization_transpiler.py
...id/tests/unittests/test_memory_optimization_transpiler.py
+25
-0
python/paddle/fluid/transpiler/memory_optimization_transpiler.py
...paddle/fluid/transpiler/memory_optimization_transpiler.py
+40
-75
未找到文件。
python/paddle/fluid/tests/unittests/test_memory_optimization_transpiler.py
浏览文件 @
ef60a654
...
...
@@ -66,6 +66,31 @@ class TestMemoryTranspiler2(unittest.TestCase):
print
(
"after optimization"
)
print
(
str
(
result_program
))
class
TestMemoryTranspiler3
(
unittest
.
TestCase
):
def
setUp
(
self
):
program
=
Program
()
with
program_guard
(
program
,
startup_program
=
Program
()):
word
=
fluid
.
layers
.
data
(
name
=
'word'
,
shape
=
[
1
],
dtype
=
'int64'
)
emb
=
[
fluid
.
layers
.
embedding
(
word
,
size
=
[
65536
,
256
],
param_attr
=
'emb'
)
for
_
in
range
(
6
)]
left
=
emb
.
pop
(
0
)
while
len
(
emb
)
!=
0
:
right
=
emb
.
pop
(
0
)
left
=
fluid
.
layers
.
concat
([
left
,
right
])
emb
=
fluid
.
layers
.
mean
(
left
)
fluid
.
backward
.
append_backward
(
emb
)
self
.
program
=
program
def
test_cascade_reuse
(
self
):
block
=
self
.
program
.
block
(
0
)
# variable reuse in programdesc
self
.
assertTrue
(
"concat_4.tmp_0@GRAD"
in
block
.
vars
)
self
.
assertTrue
(
"concat_3.tmp_0@GRAD"
not
in
block
.
vars
)
self
.
assertTrue
(
"concat_2.tmp_0@GRAD"
not
in
block
.
vars
)
self
.
assertTrue
(
"concat_1.tmp_0@GRAD"
not
in
block
.
vars
)
self
.
assertTrue
(
"concat_0.tmp_0@GRAD"
not
in
block
.
vars
)
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/fluid/transpiler/memory_optimization_transpiler.py
浏览文件 @
ef60a654
...
...
@@ -47,6 +47,7 @@ PRINT_LOG = False
class
ControlFlowGraph
(
object
):
def
__init__
(
self
,
program
,
ops
,
forward_num
,
skip_opt
):
self
.
_program
=
program
self
.
_dup_program
=
program
.
clone
()
self
.
_ops
=
ops
self
.
_forward_num
=
forward_num
self
.
_successors
=
defaultdict
(
set
)
...
...
@@ -56,6 +57,7 @@ class ControlFlowGraph(object):
self
.
_live_in
=
defaultdict
(
set
)
self
.
_live_out
=
defaultdict
(
set
)
self
.
_skip_opt
=
skip_opt
self
.
pool
=
[]
def
_add_connections
(
self
,
connections
):
"""Populates _successors and _presuccessors for two neighbor nodes."""
...
...
@@ -78,8 +80,6 @@ class ControlFlowGraph(object):
self
.
_uses
[
i
].
update
(
self
.
_ops
[
i
].
input_arg_names
())
self
.
_defs
[
i
].
update
(
self
.
_ops
[
i
].
output_arg_names
())
self
.
_live_in
[
i
]
=
self
.
_uses
[
i
]
# print(self._successors)
# print(self._presuccessors)
def
_update_graph
(
self
,
old_name
,
new_name
,
begin_idx
=
0
):
for
i
in
range
(
begin_idx
,
self
.
op_size
):
...
...
@@ -89,50 +89,13 @@ class ControlFlowGraph(object):
if
old_name
in
self
.
_defs
[
i
]:
self
.
_defs
[
i
].
remove
(
old_name
)
self
.
_defs
[
i
].
add
(
new_name
)
# for i in range(begin_idx, -1, -1):
if
old_name
in
self
.
_live_in
[
i
]:
self
.
_live_in
[
i
].
remove
(
old_name
)
self
.
_live_in
[
i
].
add
(
new_name
)
# if old_name == "concat_3.tmp_0@GRAD":
# print("new_name", new_name)
# print("live_in ", i , self._live_in[i])
if
old_name
in
self
.
_live_out
[
i
]:
self
.
_live_out
[
i
].
remove
(
old_name
)
self
.
_live_out
[
i
].
add
(
new_name
)
# if old_name == "concat_3.tmp_0@GRAD":
# print("live_out ", i , self._live_out[i])
def
_reach_fixed_point
(
self
,
live_in
,
live_out
):
"""Check if the liveness set has stablized."""
if
len
(
live_in
)
!=
len
(
self
.
_live_in
):
return
False
if
len
(
live_out
)
!=
len
(
self
.
_live_out
):
return
False
for
i
in
range
(
self
.
op_size
):
if
(
live_in
[
i
]
!=
self
.
_live_in
[
i
]
or
live_out
[
i
]
!=
self
.
_live_out
[
i
]):
return
False
return
True
# def _dataflow_analyze(self):
# self._build_graph()
# live_in = defaultdict(set)
# live_out = defaultdict(set)
# # Repeatedly apply liveness updates until the algorithm stablize
# # on a complete set live input vars and live output vars.
# counter = 0
# print(self._successors)
# while True:
# counter += 1
# for i in reversed(list(range(self.op_size))):
# live_in[i] = set(self._live_in[i])
# live_out[i] = set(self._live_out[i])
# for s in self._successors[i]:
# self._live_out[i] |= self._live_in[s]
# self._live_in[i] = self._uses[i] | (
# self._live_out[i] - self._defs[i])
# if self._reach_fixed_point(live_in, live_out):
# break
def
_dataflow_analyze
(
self
):
self
.
_build_graph
()
...
...
@@ -149,6 +112,20 @@ class ControlFlowGraph(object):
for
d
in
self
.
_presuccessors
[
i
]:
worklist
.
append
(
d
)
def
_fill_pool
(
self
,
i
,
is_forward
):
block_desc
=
self
.
_ops
[
i
].
block
()
in_diff
,
_
=
self
.
_get_diff
(
self
.
_live_in
[
i
],
self
.
_live_out
[
i
])
can_optimize
=
[
x
for
x
in
in_diff
if
self
.
_check_var_validity
(
block_desc
,
x
,
is_forward
)
]
if
can_optimize
:
for
var_name
in
can_optimize
:
cache
=
(
var_name
,
self
.
_find_var
(
block_desc
,
var_name
,
is_forward
).
shape
())
if
cache
not
in
self
.
pool
:
self
.
pool
.
append
(
cache
)
def
_get_diff
(
self
,
a
,
b
):
u
=
a
&
b
return
a
-
u
,
b
-
u
...
...
@@ -238,24 +215,15 @@ class ControlFlowGraph(object):
# update skip set to meet users' demand
if
skip_opt_set
:
self
.
_skip_opt
.
update
(
skip_opt_set
)
self
.
pool
=
[]
#
self.pool = []
for
i
in
range
(
self
.
op_size
):
op
=
self
.
_ops
[
i
]
if
op
.
type
()
in
SUB_BLOCK_OPS
:
continue
block_desc
=
op
.
block
()
is_forward
=
i
<
self
.
_forward_num
in_diff
,
_
=
self
.
_get_diff
(
self
.
_live_in
[
i
],
self
.
_live_out
[
i
])
can_optimize
=
[
x
for
x
in
in_diff
if
self
.
_check_var_validity
(
block_desc
,
x
,
is_forward
)
]
if
can_optimize
:
for
var_name
in
can_optimize
:
self
.
pool
.
append
((
var_name
,
self
.
_find_var
(
block_desc
,
var_name
,
is_forward
).
shape
()))
self
.
_fill_pool
(
i
,
is_forward
)
# print(op.type(), i, self.pool)
# print(self._live_in[i])
if
self
.
pool
:
defs_can_optimize
=
[
x
for
x
in
self
.
_defs
[
i
]
...
...
@@ -266,60 +234,57 @@ class ControlFlowGraph(object):
for
x
in
defs_can_optimize
]
for
x
,
x_shape
in
out_pair
:
if
(
x
,
x_shape
)
in
self
.
pool
:
raise
ValueError
(
"x in pool"
)
# If x is both in uses and defs, it can not be optimized!
if
x
in
self
.
_uses
[
i
]:
# print(self.pool, op.type(), cpt.to_text(x))
# raise ValueError("x in use!", cpt.to_text(x))
continue
for
index
,
cache_pair
in
enumerate
(
self
.
pool
):
cache_var
=
cache_pair
[
0
]
cache_shape
=
cache_pair
[
1
]
if
not
compare_shape
(
x_shape
,
cache_shape
,
level
):
continue
if
not
self
.
_has_var
(
block_desc
,
cache_var
,
is_forward
):
continue
raise
ValueError
(
"cache"
,
cpt
.
to_text
(
cache_var
),
" Not exists!"
)
if
x
==
cache_var
:
raise
ValueError
(
"x : "
,
cpt
.
to_text
(
x
),
" cache : "
,
cpt
.
to_text
(
cache_var
),
" is same var!"
)
x_dtype
=
self
.
_find_var
(
block_desc
,
x
,
is_forward
).
dtype
()
cache_dtype
=
self
.
_find_var
(
block_desc
,
cache_var
,
is_forward
).
dtype
()
if
not
compare_shape
(
x_shape
,
cache_shape
,
level
):
continue
# TODO(qijun): actually, we should compare
# dtype_to_size[x_dtype] and dtype_to_size[cache_dtype]
if
x_dtype
!=
cache_dtype
:
continue
self
.
pool
.
pop
(
index
)
if
x
==
cache_var
:
break
if
PRINT_LOG
:
print
((
"Hit Cache !!!! cache pool index "
"is %d, var name is %s, "
"cached var name is %s, "
"var shape is %s "
)
%
(
index
,
x
,
cache_var
,
str
(
cache_shape
)))
self
.
pool
.
pop
(
index
)
# Rename the var to the cache var already with
# memory allocated in order to reuse the memory.
_rename_arg_
(
self
.
_ops
,
x
,
cache_var
,
begin_idx
=
i
)
self
.
_program
.
block
(
block_desc
.
id
).
var
(
cpt
.
to_text
(
x
)).
desc
=
self
.
_find_var
(
block_desc
,
cache_var
,
is_forward
)
if
x
==
"concat_3.tmp_0@GRAD"
:
print
(
"Update Graph"
,
i
)
self
.
_program
.
block
(
block_desc
.
id
).
_remove_var
(
cpt
.
to_text
(
x
))
# if str(self._program) != str(self._dup_program):
# with open("./program_middle", "w") as f:
# f.write(str(self._program))
# f.flush()
# exit(0)
# self._program.block(block_desc.id).var(cpt.to_text(
# x)).desc = self._find_var(block_desc, cache_var,
# is_forward)
self
.
_update_graph
(
x
,
cache_var
,
begin_idx
=
i
)
break
# self._fill_pool(i, is_forward)
in_diff
,
_
=
self
.
_get_diff
(
self
.
_live_in
[
i
],
self
.
_live_out
[
i
])
can_optimize
=
[
x
for
x
in
in_diff
if
self
.
_check_var_validity
(
block_desc
,
x
,
is_forward
)
]
keys
=
set
([
key
for
key
,
shape
in
self
.
pool
])
if
can_optimize
:
for
var_name
in
can_optimize
:
if
var_name
not
in
keys
:
self
.
pool
.
append
((
var_name
,
self
.
_find_var
(
block_desc
,
var_name
,
is_forward
).
shape
()))
# print(op.type(), i, self.pool)
def
_process_sub_block_pair
(
pdesc
,
sub_block_pair
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录