Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
db7d8136
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
db7d8136
编写于
8月 09, 2018
作者:
M
minqiyang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix CI issue
上级
ee1d08ab
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
16 addition
and
16 deletion
+16
-16
python/paddle/dataset/imdb.py
python/paddle/dataset/imdb.py
+6
-6
python/paddle/fluid/framework.py
python/paddle/fluid/framework.py
+3
-4
python/paddle/fluid/tests/book_memory_optimization/test_memopt_image_classification_train.py
...ry_optimization/test_memopt_image_classification_train.py
+1
-1
python/paddle/fluid/transpiler/memory_optimization_transpiler.py
...paddle/fluid/transpiler/memory_optimization_transpiler.py
+6
-5
未找到文件。
python/paddle/dataset/imdb.py
浏览文件 @
db7d8136
...
...
@@ -25,7 +25,7 @@ import collections
import
tarfile
import
re
import
string
from
six.moves
import
range
import
six
__all__
=
[
'build_dict'
,
'train'
,
'test'
,
'convert'
]
...
...
@@ -43,13 +43,13 @@ def tokenize(pattern):
# sequential access of member files, other than
# tarfile.extractfile, which does random access and might
# destroy hard disks.
tf
=
next
(
tarf
)
tf
=
tarf
.
next
(
)
while
tf
!=
None
:
if
bool
(
pattern
.
match
(
tf
.
name
)):
# newline and punctuations removal and ad-hoc tokenization.
yield
tarf
.
extractfile
(
tf
).
read
().
rstrip
(
"
\n\r
"
).
translate
(
None
,
s
tring
.
punctuation
).
lower
().
split
()
tf
=
next
(
tarf
)
yield
tarf
.
extractfile
(
tf
).
read
().
rstrip
(
six
.
b
(
"
\n\r
"
)
).
translate
(
None
,
s
ix
.
b
(
string
.
punctuation
)
).
lower
().
split
()
tf
=
tarf
.
next
(
)
def
build_dict
(
pattern
,
cutoff
):
...
...
@@ -67,7 +67,7 @@ def build_dict(pattern, cutoff):
dictionary
=
sorted
(
word_freq
,
key
=
lambda
x
:
(
-
x
[
1
],
x
[
0
]))
words
,
_
=
list
(
zip
(
*
dictionary
))
word_idx
=
dict
(
list
(
zip
(
words
,
range
(
len
(
words
)))))
word_idx
=
dict
(
list
(
zip
(
words
,
six
.
moves
.
range
(
len
(
words
)))))
word_idx
[
'<unk>'
]
=
len
(
words
)
return
word_idx
...
...
python/paddle/fluid/framework.py
浏览文件 @
db7d8136
...
...
@@ -905,10 +905,9 @@ class Block(object):
Variable: the Variable with the giving name.
"""
if
not
isinstance
(
name
,
six
.
string_types
):
if
not
isinstance
(
name
,
six
.
binary_type
):
raise
TypeError
(
"var require string as parameter, but get %s instead."
%
(
type
(
name
)))
raise
TypeError
(
"var require string as parameter, but get %s instead."
%
(
type
(
name
)))
v
=
self
.
vars
.
get
(
name
,
None
)
if
v
is
None
:
raise
ValueError
(
"var %s not in this block"
%
name
)
...
...
python/paddle/fluid/tests/book_memory_optimization/test_memopt_image_classification_train.py
浏览文件 @
db7d8136
...
...
@@ -56,7 +56,7 @@ def resnet_cifar10(input, depth=32):
return
tmp
assert
(
depth
-
2
)
%
6
==
0
n
=
(
depth
-
2
)
/
6
n
=
(
depth
-
2
)
/
/
6
conv1
=
conv_bn_layer
(
input
=
input
,
ch_out
=
16
,
filter_size
=
3
,
stride
=
1
,
padding
=
1
)
res1
=
layer_warp
(
basicblock
,
conv1
,
16
,
16
,
n
,
1
)
...
...
python/paddle/fluid/transpiler/memory_optimization_transpiler.py
浏览文件 @
db7d8136
...
...
@@ -14,6 +14,7 @@
from
collections
import
defaultdict
from
..
import
core
from
..
import
compat
from
..framework
import
Program
,
default_main_program
,
Parameter
from
..backward
import
_rename_arg_
from
functools
import
reduce
...
...
@@ -125,15 +126,15 @@ class ControlFlowGraph(object):
def
_has_var
(
self
,
block_desc
,
var_name
,
is_forward
):
if
is_forward
:
return
block_desc
.
has_var
(
str
(
var_name
))
return
block_desc
.
has_var
(
cpt
.
to_bytes
(
var_name
))
else
:
return
block_desc
.
has_var_recursive
(
str
(
var_name
))
return
block_desc
.
has_var_recursive
(
cpt
.
to_bytes
(
var_name
))
def
_find_var
(
self
,
block_desc
,
var_name
,
is_forward
):
if
is_forward
:
return
block_desc
.
find_var
(
str
(
var_name
))
return
block_desc
.
find_var
(
cpt
.
to_bytes
(
var_name
))
else
:
return
block_desc
.
find_var_recursive
(
str
(
var_name
))
return
block_desc
.
find_var_recursive
(
cpt
.
to_bytes
(
var_name
))
def
_check_var_validity
(
self
,
block_desc
,
x
,
is_forward
):
if
str
(
x
)
==
"@EMPTY@"
:
...
...
@@ -258,7 +259,7 @@ class ControlFlowGraph(object):
# Rename the var to the cache var already with
# memory allocated in order to reuse the memory.
_rename_arg_
(
self
.
_ops
,
x
,
cache_var
,
begin_idx
=
i
)
self
.
_program
.
block
(
block_desc
.
id
).
var
(
str
(
self
.
_program
.
block
(
block_desc
.
id
).
var
(
cpt
.
to_literal_
str
(
x
)).
desc
=
self
.
_find_var
(
block_desc
,
cache_var
,
is_forward
)
self
.
_update_graph
(
x
,
cache_var
,
begin_idx
=
i
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录