Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
2617ac9d
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
2617ac9d
编写于
9月 17, 2018
作者:
D
dzhwinter
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
"add doc string"
上级
bddd4bc0
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
20 addition
and
12 deletion
+20
-12
python/paddle/fluid/tests/unittests/test_memory_optimization_transpiler.py
...id/tests/unittests/test_memory_optimization_transpiler.py
+6
-2
python/paddle/fluid/transpiler/memory_optimization_transpiler.py
...paddle/fluid/transpiler/memory_optimization_transpiler.py
+14
-10
未找到文件。
python/paddle/fluid/tests/unittests/test_memory_optimization_transpiler.py
浏览文件 @
2617ac9d
...
...
@@ -15,6 +15,7 @@
from
__future__
import
print_function
import
unittest
import
paddle.fluid
as
fluid
import
paddle.fluid.layers
as
layers
import
paddle.fluid.optimizer
as
optimizer
from
paddle.fluid.framework
import
Program
,
program_guard
...
...
@@ -66,13 +67,16 @@ class TestMemoryTranspiler2(unittest.TestCase):
print
(
"after optimization"
)
print
(
str
(
result_program
))
class
TestMemoryTranspiler3
(
unittest
.
TestCase
):
def
setUp
(
self
):
program
=
Program
()
with
program_guard
(
program
,
startup_program
=
Program
()):
word
=
fluid
.
layers
.
data
(
name
=
'word'
,
shape
=
[
1
],
dtype
=
'int64'
)
emb
=
[
fluid
.
layers
.
embedding
(
word
,
size
=
[
65536
,
256
],
param_attr
=
'emb'
)
for
_
in
range
(
6
)]
emb
=
[
fluid
.
layers
.
embedding
(
word
,
size
=
[
65536
,
256
],
param_attr
=
'emb'
)
for
_
in
range
(
6
)
]
left
=
emb
.
pop
(
0
)
while
len
(
emb
)
!=
0
:
...
...
python/paddle/fluid/transpiler/memory_optimization_transpiler.py
100644 → 100755
浏览文件 @
2617ac9d
...
...
@@ -96,7 +96,6 @@ class ControlFlowGraph(object):
self
.
_live_out
[
i
].
remove
(
old_name
)
self
.
_live_out
[
i
].
add
(
new_name
)
def
_dataflow_analyze
(
self
):
self
.
_build_graph
()
live_in
=
defaultdict
(
set
)
...
...
@@ -121,8 +120,8 @@ class ControlFlowGraph(object):
]
if
can_optimize
:
for
var_name
in
can_optimize
:
cache
=
(
var_name
,
self
.
_find_var
(
block_desc
,
var_name
,
is_forward
).
shape
())
cache
=
(
var_name
,
self
.
_find_var
(
block_desc
,
var_name
,
is_forward
).
shape
())
if
cache
not
in
self
.
pool
:
self
.
pool
.
append
(
cache
)
...
...
@@ -232,7 +231,7 @@ class ControlFlowGraph(object):
]
for
x
,
x_shape
in
out_pair
:
if
(
x
,
x_shape
)
in
self
.
pool
:
raise
ValueError
(
"x in pool
"
)
raise
ValueError
(
"x in pool
, %s, %s"
%
(
x
,
x_shape
)
)
# If x is both in uses and defs, it can not be optimized!
if
x
in
self
.
_uses
[
i
]:
continue
...
...
@@ -240,9 +239,14 @@ class ControlFlowGraph(object):
cache_var
=
cache_pair
[
0
]
cache_shape
=
cache_pair
[
1
]
if
not
self
.
_has_var
(
block_desc
,
cache_var
,
is_forward
):
raise
ValueError
(
"cache"
,
cpt
.
to_text
(
cache_var
),
" Not exists!"
)
raise
ValueError
(
"cache"
,
cpt
.
to_text
(
cache_var
),
" Not exists!"
)
if
x
==
cache_var
:
raise
ValueError
(
"x : "
,
cpt
.
to_text
(
x
),
" cache : "
,
cpt
.
to_text
(
cache_var
),
" is same var!"
)
raise
ValueError
(
"x : "
,
cpt
.
to_text
(
x
),
" cache : "
,
cpt
.
to_text
(
cache_var
),
" is same var!"
)
x_dtype
=
self
.
_find_var
(
block_desc
,
x
,
is_forward
).
dtype
()
...
...
@@ -266,14 +270,14 @@ class ControlFlowGraph(object):
# Rename the var to the cache var already with
# memory allocated in order to reuse the memory.
_rename_arg_
(
self
.
_ops
,
x
,
cache_var
,
begin_idx
=
i
)
self
.
_program
.
block
(
block_desc
.
id
).
_remove_var
(
cpt
.
to_text
(
x
))
self
.
_program
.
block
(
block_desc
.
id
).
var
(
cpt
.
to_text
(
x
)).
desc
=
self
.
_find_var
(
block_desc
,
cache_var
,
is_forward
)
self
.
_update_graph
(
x
,
cache_var
,
begin_idx
=
i
)
break
self
.
_fill_pool
(
i
,
is_forward
)
def
_process_sub_block_pair
(
pdesc
,
sub_block_pair
):
"""Creates a list of tuple each of which tracks info of a subblock.
...
...
@@ -379,7 +383,7 @@ def memory_optimize(input_program, skip_opt_set=None, print_log=False, level=0):
Note: it doesn't not support subblock nested in subblock.
:param input_program: Input Program
:param input_program
(str)
: Input Program
:param print_log: whether to print debug log.
:param level: If level=0, reuse if the shape is completely equal, o
:return:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录