提交 5d4238cd 编写于 作者: M minqiyang

Fix six.iteritems problem

上级 e4e9450e
......@@ -64,7 +64,7 @@ def build_dict(pattern, cutoff):
word_freq[word] += 1
# Not sure if we should prune less-frequent words here.
word_freq = [x for x in six.moves.iteritems(word_freq) if x[1] > cutoff]
word_freq = [x for x in six.iteritems(word_freq) if x[1] > cutoff]
dictionary = sorted(word_freq, key=lambda x: (-x[1], x[0]))
words, _ = list(zip(*dictionary))
......
......@@ -66,7 +66,7 @@ def build_dict(min_word_freq=50):
del word_freq['<unk>']
word_freq = [
x for x in six.moves.iteritems(word_freq) if x[1] > min_word_freq
x for x in six.iteritems(word_freq) if x[1] > min_word_freq
]
word_freq_sorted = sorted(word_freq, key=lambda x: (-x[1], x[0]))
......
......@@ -65,7 +65,7 @@ def get_word_dict():
for field in movie_reviews.fileids(category):
for words in movie_reviews.words(field):
word_freq_dict[words] += 1
words_sort_list = six.moves.iteritems(word_freq_dict)
words_sort_list = six.iteritems(word_freq_dict)
words_sort_list.sort(cmp=lambda a, b: b[1] - a[1])
for index, word in enumerate(words_sort_list):
words_freq_sorted.append((word[0], index))
......
......@@ -156,8 +156,8 @@ def get_dict(dict_size, reverse=True):
tar_file = paddle.dataset.common.download(URL_TRAIN, 'wmt14', MD5_TRAIN)
src_dict, trg_dict = __read_to_dict(tar_file, dict_size)
if reverse:
src_dict = {v: k for k, v in six.moves.iteritems(src_dict)}
trg_dict = {v: k for k, v in six.moves.iteritems(trg_dict)}
src_dict = {v: k for k, v in six.iteritems(src_dict)}
trg_dict = {v: k for k, v in six.iteritems(trg_dict)}
return src_dict, trg_dict
......
......@@ -72,8 +72,7 @@ def __build_dict(tar_file, dict_size, save_path, lang):
fout.write("%s\n%s\n%s\n" % (START_MARK, END_MARK, UNK_MARK))
for idx, word in enumerate(
sorted(
six.moves.iteritems(word_dict),
key=lambda x: x[1],
six.iteritems(word_dict), key=lambda x: x[1],
reverse=True)):
if idx + 3 == dict_size: break
fout.write("%s\n" % (word[0]))
......
......@@ -46,13 +46,13 @@ def _create_op_desc_(op_type, inputs, outputs, attrs):
"""
op_desc = core.OpDesc()
op_desc.set_type(op_type)
for para, args in six.moves.iteritems(inputs):
for para, args in six.iteritems(inputs):
op_desc.set_input(
para,
list(
map(lambda arg: arg.decode() if isinstance(arg, six.binary_type) else arg,
args)))
for para, args in six.moves.iteritems(outputs):
for para, args in six.iteritems(outputs):
op_desc.set_output(
para,
list(
......@@ -64,7 +64,7 @@ def _create_op_desc_(op_type, inputs, outputs, attrs):
if op_role_attr_name not in attrs:
attrs[
op_role_attr_name] = core.op_proto_and_checker_maker.OpRole.Backward
for name, val in six.moves.iteritems(attrs):
for name, val in six.iteritems(attrs):
if isinstance(val, framework.Block):
op_desc.set_block_attr(name, val.desc)
else:
......@@ -187,7 +187,7 @@ def _addup_repetitive_outputs_(op_descs):
op_desc.set_output(param_name, arg_names)
renamed_vars[var_name].append(new_name)
for var_name, inputs in six.moves.iteritems(renamed_vars):
for var_name, inputs in six.iteritems(renamed_vars):
if len(inputs) > 1:
pending_sum_ops.append(
(_create_op_desc_("sum", {"X": inputs}, {"Out": [var_name]},
......@@ -445,7 +445,7 @@ def _rename_grad_(block, start_op_idx, grad_to_var, target_grad_map):
op_desc.rename_output(name, new_name)
var_map[name] = new_name
for g, ng in six.moves.iteritems(var_map):
for g, ng in six.iteritems(var_map):
if g in grad_to_var:
grad_to_var[ng] = grad_to_var[g]
grad_to_var.pop(g)
......
......@@ -958,7 +958,7 @@ class Block(object):
return list(self.iter_parameters())
def iter_parameters(self):
return (item[1] for item in six.moves.iteritems(self.vars)
return (item[1] for item in six.iteritems(self.vars)
if isinstance(item[1], Parameter))
def create_var(self, *args, **kwargs):
......
......@@ -106,7 +106,7 @@ class Graph(object):
def _rank_repr(self):
ranks = sorted(
six.moves.iteritems(self.rank_groups),
six.iteritems(self.rank_groups),
key=functools.cmp_to_key(
lambda a, b: a[1].priority > b[1].priority))
repr = []
......@@ -150,9 +150,8 @@ class Node(object):
reprs = '{name} [label={label} {extra} ];'.format(
name=self.name,
label=self.label,
extra=',' + ','.join(
"%s=%s" % (key, crepr(value))
for key, value in six.moves.iteritems(self.attrs))
extra=',' + ','.join("%s=%s" % (key, crepr(value))
for key, value in six.iteritems(self.attrs))
if self.attrs else "")
return reprs
......@@ -176,7 +175,7 @@ class Edge(object):
target=self.target.name,
extra="" if not self.attrs else
"[" + ','.join("{}={}".format(attr[0], crepr(attr[1]))
for attr in six.moves.iteritems(self.attrs)) + "]")
for attr in six.iteritems(self.attrs)) + "]")
return repr
......
......@@ -603,7 +603,7 @@ class StaticRNN(object):
boot_memories = []
pre_memories = []
memories = []
for _, mem in six.moves.iteritems(self.memories):
for _, mem in six.iteritems(self.memories):
boot_memories.append(mem.init)
pre_memories.append(mem.pre_mem.name)
mem_var = rnn_block.var(mem.mem.name)
......
......@@ -80,10 +80,10 @@ class MetricBase(object):
"""
states = {
attr: value
for attr, value in six.moves.iteritems(self.__dict__)
for attr, value in six.iteritems(self.__dict__)
if not attr.startswith("_")
}
for attr, value in six.moves.iteritems(states):
for attr, value in six.iteritems(states):
if isinstance(value, int):
setattr(self, attr, 0)
elif isinstance(value, float):
......@@ -106,7 +106,7 @@ class MetricBase(object):
"""
states = {
attr: value
for attr, value in six.moves.iteritems(self.__dict__)
for attr, value in six.iteritems(self.__dict__)
if not attr.startswith("_")
}
config = {}
......
......@@ -54,7 +54,7 @@ class BenchmarkSuite(OpTest):
def _get_input_names(self):
inputs = []
for name, value in six.moves.iteritems(self.inputs):
for name, value in six.iteritems(self.inputs):
if isinstance(value, list):
inputs.extend([sub_name for sub_name, _ in value])
inputs.append(name)
......@@ -62,7 +62,7 @@ class BenchmarkSuite(OpTest):
def _get_output_names(self):
outputs = []
for var_name, var in six.moves.iteritems(self.outputs):
for var_name, var in six.iteritems(self.outputs):
if isinstance(var, list):
for sub_var_name, sub_var in var:
outputs.append(sub_var_name)
......
......@@ -177,7 +177,7 @@ class TestDetectionMAPOp(OpTest):
true_pos[label].append([score, tp])
false_pos[label].append([score, fp])
for (label, label_pos_num) in six.moves.iteritems(label_count):
for (label, label_pos_num) in six.iteritems(label_count):
if label_pos_num == 0 or label not in true_pos: continue
label_true_pos = true_pos[label]
label_false_pos = false_pos[label]
......
......@@ -37,7 +37,7 @@ class TestLoDRankTable(unittest.TestCase):
exe.run(scope=scope, feed={'x': tensor})
var = scope.find_var(rank_table.name)
table = var.get_lod_rank_table()
self.assertEqual([(0, 5), (1, 1), (2, 1)], six.moves.iteritems(table))
self.assertEqual([(0, 5), (1, 1), (2, 1)], six.iteritems(table))
if __name__ == '__main__':
......
......@@ -33,7 +33,7 @@ def py_pnpair_op(score, label, query, column=-1, weight=None):
# accumulate statistics
pos, neg, neu = 0, 0, 0
for _, ranks in six.moves.iteritems(predictions):
for _, ranks in six.iteritems(predictions):
for e1, e2 in itertools.combinations(ranks, 2):
s1, s2, l1, l2, w1, w2 = e1[0], e2[0], e1[1], e2[1], e1[2], e2[2]
w = (w1 + w2) * 0.5
......
......@@ -619,7 +619,7 @@ def build_feed_var_list(program, feed_order):
"The values of 'feed_order' should be a permutation of [0, len(feed_order))"
)
sorted_pair_list = sorted(
six.moves.iteritems(feed_order), key=lambda item: item[1])
six.iteritems(feed_order), key=lambda item: item[1])
feed_var_list = [
program.global_block().var(pair[0]) for pair in sorted_pair_list
]
......@@ -1037,7 +1037,7 @@ def _save_trainer_args(dirname, trainer_id, trainer_args):
cur_dir = _get_trainer_dir(dirname, trainer_id)
for name, value in six.moves.iteritems(trainer_args):
for name, value in six.iteritems(trainer_args):
args_file = os.path.join(cur_dir, name)
with open(args_file, 'w') as f:
f.write(str(value))
......
......@@ -218,8 +218,7 @@ class DistributeTranspiler(object):
# fc_w@GRAD_trainer_0, fc_w@GRAD_trainer_1 --> pserver1
# fc_b@GRAD_trainer_0, fc_b@GRAD_trainer_1 --> pserver2
# shuffle the map will avoid the uneven distribution above
grad_var_mapping_items = list(
six.moves.iteritems(self.grad_var_mapping))
grad_var_mapping_items = list(six.iteritems(self.grad_var_mapping))
if not self.config.slice_var_up:
random.seed(self.origin_program.random_seed)
......@@ -280,7 +279,7 @@ class DistributeTranspiler(object):
self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i])
# step4: Concat the parameters splits together after recv.
for varname, splited_var in six.moves.iteritems(self.param_var_mapping):
for varname, splited_var in six.iteritems(self.param_var_mapping):
eps = []
for var in splited_var:
index = [v.name for v in recv_vars].index(var.name)
......@@ -304,7 +303,7 @@ class DistributeTranspiler(object):
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
})
for varname, splited_var in six.moves.iteritems(self.param_var_mapping):
for varname, splited_var in six.iteritems(self.param_var_mapping):
if len(splited_var) <= 1:
continue
orig_param = program.global_block().vars[varname]
......@@ -561,7 +560,7 @@ class DistributeTranspiler(object):
# 1. create vars in pserver program to startup program
pserver_vars = pserver_program.global_block().vars
created_var_map = collections.OrderedDict()
for _, var in six.moves.iteritems(pserver_vars):
for _, var in six.iteritems(pserver_vars):
tmpvar = s_prog.global_block()._clone_variable(var)
created_var_map[var.name] = tmpvar
......@@ -998,7 +997,7 @@ class DistributeTranspiler(object):
block_map[varname] = []
block_map[varname].append((int(offset), int(size)))
for varname, splited in six.moves.iteritems(block_map):
for varname, splited in six.iteritems(block_map):
orig_var = program.global_block().var(varname)
if len(splited) == 1:
if self.sync_mode and add_trainer_suffix:
......@@ -1249,7 +1248,7 @@ class DistributeTranspiler(object):
def _is_splited_grad_var(self, var, var_dict):
grad_block = None
for _, g in six.moves.iteritems(var_dict):
for _, g in six.iteritems(var_dict):
if self._orig_varname(g.name) == self._orig_varname(var.name):
if g.name.find(".trainer_") == -1:
grad_block = g
......@@ -1259,7 +1258,7 @@ class DistributeTranspiler(object):
def _clone_lr_op(self, program, block, op):
inputs = self._get_input_map_from_op(
self.origin_program.global_block().vars, op)
for key, varlist in six.moves.iteritems(inputs):
for key, varlist in six.iteritems(inputs):
if not isinstance(varlist, list):
varlist = [varlist]
for var in varlist:
......@@ -1268,7 +1267,7 @@ class DistributeTranspiler(object):
outputs = self._get_output_map_from_op(
self.origin_program.global_block().vars, op)
for key, varlist in six.moves.iteritems(outputs):
for key, varlist in six.iteritems(outputs):
if not isinstance(varlist, list):
varlist = [varlist]
for var in varlist:
......@@ -1283,7 +1282,7 @@ class DistributeTranspiler(object):
# Append the ops for parameters that do not need to be optimized/updated
inputs = self._get_input_map_from_op(
self.origin_program.global_block().vars, opt_op)
for key, varlist in six.moves.iteritems(inputs):
for key, varlist in six.iteritems(inputs):
if not isinstance(varlist, list):
varlist = [varlist]
for var in varlist:
......@@ -1302,7 +1301,7 @@ class DistributeTranspiler(object):
outputs = self._get_output_map_from_op(
self.origin_program.global_block().vars, opt_op)
for key, varlist in six.moves.iteritems(outputs):
for key, varlist in six.iteritems(outputs):
if not isinstance(varlist, list):
varlist = [varlist]
for var in varlist:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册