Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
6dc07e7f
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
6dc07e7f
编写于
8月 10, 2018
作者:
M
minqiyang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Replace items() with six.moves.iteritems() to improve memory usage
上级
9cd59990
变更
16
隐藏空白更改
内联
并排
Showing
16 changed file
with
57 addition
and
47 deletion
+57
-47
python/paddle/dataset/imdb.py
python/paddle/dataset/imdb.py
+4
-3
python/paddle/dataset/imikolov.py
python/paddle/dataset/imikolov.py
+6
-4
python/paddle/dataset/sentiment.py
python/paddle/dataset/sentiment.py
+2
-1
python/paddle/dataset/wmt14.py
python/paddle/dataset/wmt14.py
+2
-2
python/paddle/dataset/wmt16.py
python/paddle/dataset/wmt16.py
+1
-1
python/paddle/fluid/backward.py
python/paddle/fluid/backward.py
+5
-5
python/paddle/fluid/framework.py
python/paddle/fluid/framework.py
+1
-1
python/paddle/fluid/graphviz.py
python/paddle/fluid/graphviz.py
+5
-4
python/paddle/fluid/layers/control_flow.py
python/paddle/fluid/layers/control_flow.py
+2
-1
python/paddle/fluid/metrics.py
python/paddle/fluid/metrics.py
+7
-6
python/paddle/fluid/tests/unittests/benchmark.py
python/paddle/fluid/tests/unittests/benchmark.py
+2
-2
python/paddle/fluid/tests/unittests/test_detection_map_op.py
python/paddle/fluid/tests/unittests/test_detection_map_op.py
+2
-1
python/paddle/fluid/tests/unittests/test_lod_rank_table.py
python/paddle/fluid/tests/unittests/test_lod_rank_table.py
+2
-1
python/paddle/fluid/tests/unittests/test_positive_negative_pair_op.py
...e/fluid/tests/unittests/test_positive_negative_pair_op.py
+2
-1
python/paddle/fluid/trainer.py
python/paddle/fluid/trainer.py
+3
-2
python/paddle/fluid/transpiler/distribute_transpiler.py
python/paddle/fluid/transpiler/distribute_transpiler.py
+11
-12
未找到文件。
python/paddle/dataset/imdb.py
浏览文件 @
6dc07e7f
...
...
@@ -47,8 +47,9 @@ def tokenize(pattern):
while
tf
!=
None
:
if
bool
(
pattern
.
match
(
tf
.
name
)):
# newline and punctuations removal and ad-hoc tokenization.
yield
tarf
.
extractfile
(
tf
).
read
().
rstrip
(
six
.
b
(
"
\n\r
"
)).
translate
(
None
,
six
.
b
(
string
.
punctuation
)).
lower
().
split
()
yield
tarf
.
extractfile
(
tf
).
read
().
rstrip
(
six
.
b
(
"
\n\r
"
)).
translate
(
None
,
six
.
b
(
string
.
punctuation
)).
lower
().
split
()
tf
=
tarf
.
next
()
...
...
@@ -63,7 +64,7 @@ def build_dict(pattern, cutoff):
word_freq
[
word
]
+=
1
# Not sure if we should prune less-frequent words here.
word_freq
=
[
x
for
x
in
list
(
word_freq
.
items
()
)
if
x
[
1
]
>
cutoff
]
word_freq
=
[
x
for
x
in
six
.
moves
.
iteritems
(
word_freq
)
if
x
[
1
]
>
cutoff
]
dictionary
=
sorted
(
word_freq
,
key
=
lambda
x
:
(
-
x
[
1
],
x
[
0
]))
words
,
_
=
list
(
zip
(
*
dictionary
))
...
...
python/paddle/dataset/imikolov.py
浏览文件 @
6dc07e7f
...
...
@@ -21,7 +21,7 @@ into paddle reader creators.
import
paddle.dataset.common
import
collections
import
tarfile
from
six.moves
import
range
import
six
__all__
=
[
'train'
,
'test'
,
'build_dict'
,
'convert'
]
...
...
@@ -65,11 +65,13 @@ def build_dict(min_word_freq=50):
# remove <unk> for now, since we will set it as last index
del
word_freq
[
'<unk>'
]
word_freq
=
[
x
for
x
in
list
(
word_freq
.
items
())
if
x
[
1
]
>
min_word_freq
]
word_freq
=
[
x
for
x
in
six
.
moves
.
iteritems
(
word_freq
)
if
x
[
1
]
>
min_word_freq
]
word_freq_sorted
=
sorted
(
word_freq
,
key
=
lambda
x
:
(
-
x
[
1
],
x
[
0
]))
words
,
_
=
list
(
zip
(
*
word_freq_sorted
))
word_idx
=
dict
(
list
(
zip
(
words
,
range
(
len
(
words
)))))
word_idx
=
dict
(
list
(
zip
(
words
,
six
.
moves
.
range
(
len
(
words
)))))
word_idx
[
'<unk>'
]
=
len
(
words
)
return
word_idx
...
...
@@ -90,7 +92,7 @@ def reader_creator(filename, word_idx, n, data_type):
l
=
[
'<s>'
]
+
l
.
strip
().
split
()
+
[
'<e>'
]
if
len
(
l
)
>=
n
:
l
=
[
word_idx
.
get
(
w
,
UNK
)
for
w
in
l
]
for
i
in
range
(
n
,
len
(
l
)
+
1
):
for
i
in
six
.
moves
.
range
(
n
,
len
(
l
)
+
1
):
yield
tuple
(
l
[
i
-
n
:
i
])
elif
DataType
.
SEQ
==
data_type
:
l
=
l
.
strip
().
split
()
...
...
python/paddle/dataset/sentiment.py
浏览文件 @
6dc07e7f
...
...
@@ -20,6 +20,7 @@ The script fetch and preprocess movie_reviews data set that provided by NLTK
TODO(yuyang18): Complete dataset.
"""
import
six
import
collections
from
itertools
import
chain
...
...
@@ -64,7 +65,7 @@ def get_word_dict():
for
field
in
movie_reviews
.
fileids
(
category
):
for
words
in
movie_reviews
.
words
(
field
):
word_freq_dict
[
words
]
+=
1
words_sort_list
=
list
(
word_freq_dict
.
items
()
)
words_sort_list
=
six
.
moves
.
iteritems
(
word_freq_dict
)
words_sort_list
.
sort
(
cmp
=
lambda
a
,
b
:
b
[
1
]
-
a
[
1
])
for
index
,
word
in
enumerate
(
words_sort_list
):
words_freq_sorted
.
append
((
word
[
0
],
index
))
...
...
python/paddle/dataset/wmt14.py
浏览文件 @
6dc07e7f
...
...
@@ -156,8 +156,8 @@ def get_dict(dict_size, reverse=True):
tar_file
=
paddle
.
dataset
.
common
.
download
(
URL_TRAIN
,
'wmt14'
,
MD5_TRAIN
)
src_dict
,
trg_dict
=
__read_to_dict
(
tar_file
,
dict_size
)
if
reverse
:
src_dict
=
{
v
:
k
for
k
,
v
in
list
(
src_dict
.
items
()
)}
trg_dict
=
{
v
:
k
for
k
,
v
in
list
(
trg_dict
.
items
()
)}
src_dict
=
{
v
:
k
for
k
,
v
in
six
.
moves
.
iteritems
(
src_dict
)}
trg_dict
=
{
v
:
k
for
k
,
v
in
six
.
moves
.
iteritems
(
trg_dict
)}
return
src_dict
,
trg_dict
...
...
python/paddle/dataset/wmt16.py
浏览文件 @
6dc07e7f
...
...
@@ -72,7 +72,7 @@ def __build_dict(tar_file, dict_size, save_path, lang):
fout
.
write
(
"%s
\n
%s
\n
%s
\n
"
%
(
START_MARK
,
END_MARK
,
UNK_MARK
))
for
idx
,
word
in
enumerate
(
sorted
(
iter
(
list
(
word_dict
.
items
())
),
six
.
moves
.
iteritems
(
word_dict
),
key
=
lambda
x
:
x
[
1
],
reverse
=
True
)):
if
idx
+
3
==
dict_size
:
break
...
...
python/paddle/fluid/backward.py
浏览文件 @
6dc07e7f
...
...
@@ -46,13 +46,13 @@ def _create_op_desc_(op_type, inputs, outputs, attrs):
"""
op_desc
=
core
.
OpDesc
()
op_desc
.
set_type
(
op_type
)
for
para
,
args
in
list
(
inputs
.
items
()
):
for
para
,
args
in
six
.
moves
.
iteritems
(
inputs
):
op_desc
.
set_input
(
para
,
list
(
map
(
lambda
arg
:
arg
.
decode
()
if
isinstance
(
arg
,
six
.
binary_type
)
else
arg
,
args
)))
for
para
,
args
in
list
(
outputs
.
items
()
):
for
para
,
args
in
six
.
moves
.
iteritems
(
outputs
):
op_desc
.
set_output
(
para
,
list
(
...
...
@@ -64,7 +64,7 @@ def _create_op_desc_(op_type, inputs, outputs, attrs):
if
op_role_attr_name
not
in
attrs
:
attrs
[
op_role_attr_name
]
=
core
.
op_proto_and_checker_maker
.
OpRole
.
Backward
for
name
,
val
in
list
(
attrs
.
items
()
):
for
name
,
val
in
six
.
moves
.
iteritems
(
attrs
):
if
isinstance
(
val
,
framework
.
Block
):
op_desc
.
set_block_attr
(
name
,
val
.
desc
)
else
:
...
...
@@ -187,7 +187,7 @@ def _addup_repetitive_outputs_(op_descs):
op_desc
.
set_output
(
param_name
,
arg_names
)
renamed_vars
[
var_name
].
append
(
new_name
)
for
var_name
,
inputs
in
list
(
renamed_vars
.
items
()
):
for
var_name
,
inputs
in
six
.
moves
.
iteritems
(
renamed_vars
):
if
len
(
inputs
)
>
1
:
pending_sum_ops
.
append
(
(
_create_op_desc_
(
"sum"
,
{
"X"
:
inputs
},
{
"Out"
:
[
var_name
]},
...
...
@@ -445,7 +445,7 @@ def _rename_grad_(block, start_op_idx, grad_to_var, target_grad_map):
op_desc
.
rename_output
(
name
,
new_name
)
var_map
[
name
]
=
new_name
for
g
,
ng
in
list
(
var_map
.
items
()
):
for
g
,
ng
in
six
.
moves
.
iteritems
(
var_map
):
if
g
in
grad_to_var
:
grad_to_var
[
ng
]
=
grad_to_var
[
g
]
grad_to_var
.
pop
(
g
)
...
...
python/paddle/fluid/framework.py
浏览文件 @
6dc07e7f
...
...
@@ -958,7 +958,7 @@ class Block(object):
return
list
(
self
.
iter_parameters
())
def
iter_parameters
(
self
):
return
(
item
[
1
]
for
item
in
list
(
self
.
vars
.
items
()
)
return
(
item
[
1
]
for
item
in
six
.
moves
.
iteritems
(
self
.
vars
)
if
isinstance
(
item
[
1
],
Parameter
))
def
create_var
(
self
,
*
args
,
**
kwargs
):
...
...
python/paddle/fluid/graphviz.py
浏览文件 @
6dc07e7f
...
...
@@ -106,7 +106,7 @@ class Graph(object):
def
_rank_repr
(
self
):
ranks
=
sorted
(
list
(
self
.
rank_groups
.
items
()
),
six
.
moves
.
iteritems
(
self
.
rank_groups
),
key
=
functools
.
cmp_to_key
(
lambda
a
,
b
:
a
[
1
].
priority
>
b
[
1
].
priority
))
repr
=
[]
...
...
@@ -150,8 +150,9 @@ class Node(object):
reprs
=
'{name} [label={label} {extra} ];'
.
format
(
name
=
self
.
name
,
label
=
self
.
label
,
extra
=
','
+
','
.
join
(
"%s=%s"
%
(
key
,
crepr
(
value
))
for
key
,
value
in
list
(
self
.
attrs
.
items
()))
extra
=
','
+
','
.
join
(
"%s=%s"
%
(
key
,
crepr
(
value
))
for
key
,
value
in
six
.
moves
.
iteritems
(
self
.
attrs
))
if
self
.
attrs
else
""
)
return
reprs
...
...
@@ -175,7 +176,7 @@ class Edge(object):
target
=
self
.
target
.
name
,
extra
=
""
if
not
self
.
attrs
else
"["
+
','
.
join
(
"{}={}"
.
format
(
attr
[
0
],
crepr
(
attr
[
1
]))
for
attr
in
list
(
self
.
attrs
.
items
()
))
+
"]"
)
for
attr
in
six
.
moves
.
iteritems
(
self
.
attrs
))
+
"]"
)
return
repr
...
...
python/paddle/fluid/layers/control_flow.py
浏览文件 @
6dc07e7f
...
...
@@ -22,6 +22,7 @@ from ..initializer import force_init_on_cpu
from
.ops
import
logical_and
,
logical_not
,
logical_or
import
numpy
import
warnings
import
six
from
functools
import
reduce
__all__
=
[
...
...
@@ -602,7 +603,7 @@ class StaticRNN(object):
boot_memories
=
[]
pre_memories
=
[]
memories
=
[]
for
_
,
mem
in
list
(
self
.
memories
.
items
()
):
for
_
,
mem
in
six
.
moves
.
iteritems
(
self
.
memories
):
boot_memories
.
append
(
mem
.
init
)
pre_memories
.
append
(
mem
.
pre_mem
.
name
)
mem_var
=
rnn_block
.
var
(
mem
.
mem
.
name
)
...
...
python/paddle/fluid/metrics.py
浏览文件 @
6dc07e7f
...
...
@@ -14,11 +14,12 @@
"""
Fluid Metrics
The metrics are accomplished via Python natively.
The metrics are accomplished via Python natively.
"""
import
numpy
as
np
import
copy
import
warnings
import
six
__all__
=
[
'MetricBase'
,
...
...
@@ -79,10 +80,10 @@ class MetricBase(object):
"""
states
=
{
attr
:
value
for
attr
,
value
in
list
(
self
.
__dict__
.
items
()
)
for
attr
,
value
in
six
.
moves
.
iteritems
(
self
.
__dict__
)
if
not
attr
.
startswith
(
"_"
)
}
for
attr
,
value
in
list
(
states
.
items
()
):
for
attr
,
value
in
six
.
moves
.
iteritems
(
states
):
if
isinstance
(
value
,
int
):
setattr
(
self
,
attr
,
0
)
elif
isinstance
(
value
,
float
):
...
...
@@ -105,7 +106,7 @@ class MetricBase(object):
"""
states
=
{
attr
:
value
for
attr
,
value
in
list
(
self
.
__dict__
.
items
()
)
for
attr
,
value
in
six
.
moves
.
iteritems
(
self
.
__dict__
)
if
not
attr
.
startswith
(
"_"
)
}
config
=
{}
...
...
@@ -141,10 +142,10 @@ class CompositeMetric(MetricBase):
"""
Composite multiple metrics in one instance.
for example, merge F1, accuracy, recall into one Metric.
Examples:
.. code-block:: python
labels = fluid.layers.data(name="data", shape=[1], dtype="int32")
data = fluid.layers.data(name="data", shape=[32, 32], dtype="int32")
pred = fluid.layers.fc(input=data, size=1000, act="tanh")
...
...
python/paddle/fluid/tests/unittests/benchmark.py
浏览文件 @
6dc07e7f
...
...
@@ -54,7 +54,7 @@ class BenchmarkSuite(OpTest):
def
_get_input_names
(
self
):
inputs
=
[]
for
name
,
value
in
list
(
self
.
inputs
.
items
()
):
for
name
,
value
in
six
.
moves
.
iteritems
(
self
.
inputs
):
if
isinstance
(
value
,
list
):
inputs
.
extend
([
sub_name
for
sub_name
,
_
in
value
])
inputs
.
append
(
name
)
...
...
@@ -62,7 +62,7 @@ class BenchmarkSuite(OpTest):
def
_get_output_names
(
self
):
outputs
=
[]
for
var_name
,
var
in
list
(
self
.
outputs
.
items
()
):
for
var_name
,
var
in
six
.
moves
.
iteritems
(
self
.
outputs
):
if
isinstance
(
var
,
list
):
for
sub_var_name
,
sub_var
in
var
:
outputs
.
append
(
sub_var_name
)
...
...
python/paddle/fluid/tests/unittests/test_detection_map_op.py
浏览文件 @
6dc07e7f
...
...
@@ -14,6 +14,7 @@
import
unittest
import
numpy
as
np
import
six
import
sys
import
collections
import
math
...
...
@@ -176,7 +177,7 @@ class TestDetectionMAPOp(OpTest):
true_pos
[
label
].
append
([
score
,
tp
])
false_pos
[
label
].
append
([
score
,
fp
])
for
(
label
,
label_pos_num
)
in
list
(
label_count
.
items
()
):
for
(
label
,
label_pos_num
)
in
six
.
moves
.
iteritems
(
label_count
):
if
label_pos_num
==
0
or
label
not
in
true_pos
:
continue
label_true_pos
=
true_pos
[
label
]
label_false_pos
=
false_pos
[
label
]
...
...
python/paddle/fluid/tests/unittests/test_lod_rank_table.py
浏览文件 @
6dc07e7f
...
...
@@ -18,6 +18,7 @@ from paddle.fluid.executor import Executor
import
paddle.fluid.core
as
core
import
numpy
import
unittest
import
six
class
TestLoDRankTable
(
unittest
.
TestCase
):
...
...
@@ -36,7 +37,7 @@ class TestLoDRankTable(unittest.TestCase):
exe
.
run
(
scope
=
scope
,
feed
=
{
'x'
:
tensor
})
var
=
scope
.
find_var
(
rank_table
.
name
)
table
=
var
.
get_lod_rank_table
()
self
.
assertEqual
([(
0
,
5
),
(
1
,
1
),
(
2
,
1
)],
list
(
table
.
items
()
))
self
.
assertEqual
([(
0
,
5
),
(
1
,
1
),
(
2
,
1
)],
six
.
moves
.
iteritems
(
table
))
if
__name__
==
'__main__'
:
...
...
python/paddle/fluid/tests/unittests/test_positive_negative_pair_op.py
浏览文件 @
6dc07e7f
...
...
@@ -15,6 +15,7 @@
import
unittest
import
itertools
import
numpy
as
np
import
six
from
op_test
import
OpTest
...
...
@@ -32,7 +33,7 @@ def py_pnpair_op(score, label, query, column=-1, weight=None):
# accumulate statistics
pos
,
neg
,
neu
=
0
,
0
,
0
for
_
,
ranks
in
list
(
predictions
.
items
()
):
for
_
,
ranks
in
six
.
moves
.
iteritems
(
predictions
):
for
e1
,
e2
in
itertools
.
combinations
(
ranks
,
2
):
s1
,
s2
,
l1
,
l2
,
w1
,
w2
=
e1
[
0
],
e2
[
0
],
e1
[
1
],
e2
[
1
],
e1
[
2
],
e2
[
2
]
w
=
(
w1
+
w2
)
*
0.5
...
...
python/paddle/fluid/trainer.py
浏览文件 @
6dc07e7f
...
...
@@ -16,6 +16,7 @@ import contextlib
import
os
import
errno
import
shutil
import
six
import
time
from
.
import
core
...
...
@@ -618,7 +619,7 @@ def build_feed_var_list(program, feed_order):
"The values of 'feed_order' should be a permutation of [0, len(feed_order))"
)
sorted_pair_list
=
sorted
(
list
(
feed_order
.
items
()
),
key
=
lambda
item
:
item
[
1
])
six
.
moves
.
iteritems
(
feed_order
),
key
=
lambda
item
:
item
[
1
])
feed_var_list
=
[
program
.
global_block
().
var
(
pair
[
0
])
for
pair
in
sorted_pair_list
]
...
...
@@ -1036,7 +1037,7 @@ def _save_trainer_args(dirname, trainer_id, trainer_args):
cur_dir
=
_get_trainer_dir
(
dirname
,
trainer_id
)
for
name
,
value
in
list
(
trainer_args
.
items
()
):
for
name
,
value
in
six
.
moves
.
iteritems
(
trainer_args
):
args_file
=
os
.
path
.
join
(
cur_dir
,
name
)
with
open
(
args_file
,
'w'
)
as
f
:
f
.
write
(
str
(
value
))
...
...
python/paddle/fluid/transpiler/distribute_transpiler.py
浏览文件 @
6dc07e7f
...
...
@@ -218,7 +218,8 @@ class DistributeTranspiler(object):
# fc_w@GRAD_trainer_0, fc_w@GRAD_trainer_1 --> pserver1
# fc_b@GRAD_trainer_0, fc_b@GRAD_trainer_1 --> pserver2
# shuffle the map will avoid the uneven distribution above
grad_var_mapping_items
=
list
(
self
.
grad_var_mapping
.
items
())
grad_var_mapping_items
=
list
(
six
.
moves
.
iteritems
(
self
.
grad_var_mapping
))
if
not
self
.
config
.
slice_var_up
:
random
.
seed
(
self
.
origin_program
.
random_seed
)
...
...
@@ -279,7 +280,7 @@ class DistributeTranspiler(object):
self
.
param_grad_ep_mapping
[
ep
][
"grads"
].
append
(
send_vars
[
i
])
# step4: Concat the parameters splits together after recv.
for
varname
,
splited_var
in
list
(
self
.
param_var_mapping
.
items
()
):
for
varname
,
splited_var
in
six
.
moves
.
iteritems
(
self
.
param_var_mapping
):
eps
=
[]
for
var
in
splited_var
:
index
=
[
v
.
name
for
v
in
recv_vars
].
index
(
var
.
name
)
...
...
@@ -303,7 +304,7 @@ class DistributeTranspiler(object):
RPC_OP_ROLE_ATTR_NAME
:
RPC_OP_ROLE_ATTR_VALUE
})
for
varname
,
splited_var
in
list
(
self
.
param_var_mapping
.
items
()
):
for
varname
,
splited_var
in
six
.
moves
.
iteritems
(
self
.
param_var_mapping
):
if
len
(
splited_var
)
<=
1
:
continue
orig_param
=
program
.
global_block
().
vars
[
varname
]
...
...
@@ -560,7 +561,7 @@ class DistributeTranspiler(object):
# 1. create vars in pserver program to startup program
pserver_vars
=
pserver_program
.
global_block
().
vars
created_var_map
=
collections
.
OrderedDict
()
for
_
,
var
in
list
(
pserver_vars
.
items
()
):
for
_
,
var
in
six
.
moves
.
iteritems
(
pserver_vars
):
tmpvar
=
s_prog
.
global_block
().
_clone_variable
(
var
)
created_var_map
[
var
.
name
]
=
tmpvar
...
...
@@ -997,7 +998,7 @@ class DistributeTranspiler(object):
block_map
[
varname
]
=
[]
block_map
[
varname
].
append
((
int
(
offset
),
int
(
size
)))
for
varname
,
splited
in
list
(
block_map
.
items
()
):
for
varname
,
splited
in
six
.
moves
.
iteritems
(
block_map
):
orig_var
=
program
.
global_block
().
var
(
varname
)
if
len
(
splited
)
==
1
:
if
self
.
sync_mode
and
add_trainer_suffix
:
...
...
@@ -1248,9 +1249,7 @@ class DistributeTranspiler(object):
def
_is_splited_grad_var
(
self
,
var
,
var_dict
):
grad_block
=
None
# TODO(minqiyang): replace these items() with six.iteritems() to
# improve memory
for
_
,
g
in
list
(
var_dict
.
items
()):
for
_
,
g
in
six
.
moves
.
iteritems
(
var_dict
):
if
self
.
_orig_varname
(
g
.
name
)
==
self
.
_orig_varname
(
var
.
name
):
if
g
.
name
.
find
(
".trainer_"
)
==
-
1
:
grad_block
=
g
...
...
@@ -1260,7 +1259,7 @@ class DistributeTranspiler(object):
def
_clone_lr_op
(
self
,
program
,
block
,
op
):
inputs
=
self
.
_get_input_map_from_op
(
self
.
origin_program
.
global_block
().
vars
,
op
)
for
key
,
varlist
in
list
(
inputs
.
items
()
):
for
key
,
varlist
in
six
.
moves
.
iteritems
(
inputs
):
if
not
isinstance
(
varlist
,
list
):
varlist
=
[
varlist
]
for
var
in
varlist
:
...
...
@@ -1269,7 +1268,7 @@ class DistributeTranspiler(object):
outputs
=
self
.
_get_output_map_from_op
(
self
.
origin_program
.
global_block
().
vars
,
op
)
for
key
,
varlist
in
list
(
outputs
.
items
()
):
for
key
,
varlist
in
six
.
moves
.
iteritems
(
outputs
):
if
not
isinstance
(
varlist
,
list
):
varlist
=
[
varlist
]
for
var
in
varlist
:
...
...
@@ -1284,7 +1283,7 @@ class DistributeTranspiler(object):
# Append the ops for parameters that do not need to be optimized/updated
inputs
=
self
.
_get_input_map_from_op
(
self
.
origin_program
.
global_block
().
vars
,
opt_op
)
for
key
,
varlist
in
list
(
inputs
.
items
()
):
for
key
,
varlist
in
six
.
moves
.
iteritems
(
inputs
):
if
not
isinstance
(
varlist
,
list
):
varlist
=
[
varlist
]
for
var
in
varlist
:
...
...
@@ -1303,7 +1302,7 @@ class DistributeTranspiler(object):
outputs
=
self
.
_get_output_map_from_op
(
self
.
origin_program
.
global_block
().
vars
,
opt_op
)
for
key
,
varlist
in
list
(
outputs
.
items
()
):
for
key
,
varlist
in
six
.
moves
.
iteritems
(
outputs
):
if
not
isinstance
(
varlist
,
list
):
varlist
=
[
varlist
]
for
var
in
varlist
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录