提交 6dc07e7f 编写于 作者: M minqiyang

Replace items() with six.moves.iteritems() to improve memory usage

上级 9cd59990
...@@ -47,7 +47,8 @@ def tokenize(pattern): ...@@ -47,7 +47,8 @@ def tokenize(pattern):
while tf != None: while tf != None:
if bool(pattern.match(tf.name)): if bool(pattern.match(tf.name)):
# newline and punctuations removal and ad-hoc tokenization. # newline and punctuations removal and ad-hoc tokenization.
yield tarf.extractfile(tf).read().rstrip(six.b("\n\r")).translate( yield tarf.extractfile(tf).read().rstrip(six.b(
"\n\r")).translate(
None, six.b(string.punctuation)).lower().split() None, six.b(string.punctuation)).lower().split()
tf = tarf.next() tf = tarf.next()
...@@ -63,7 +64,7 @@ def build_dict(pattern, cutoff): ...@@ -63,7 +64,7 @@ def build_dict(pattern, cutoff):
word_freq[word] += 1 word_freq[word] += 1
# Not sure if we should prune less-frequent words here. # Not sure if we should prune less-frequent words here.
word_freq = [x for x in list(word_freq.items()) if x[1] > cutoff] word_freq = [x for x in six.moves.iteritems(word_freq) if x[1] > cutoff]
dictionary = sorted(word_freq, key=lambda x: (-x[1], x[0])) dictionary = sorted(word_freq, key=lambda x: (-x[1], x[0]))
words, _ = list(zip(*dictionary)) words, _ = list(zip(*dictionary))
......
...@@ -21,7 +21,7 @@ into paddle reader creators. ...@@ -21,7 +21,7 @@ into paddle reader creators.
import paddle.dataset.common import paddle.dataset.common
import collections import collections
import tarfile import tarfile
from six.moves import range import six
__all__ = ['train', 'test', 'build_dict', 'convert'] __all__ = ['train', 'test', 'build_dict', 'convert']
...@@ -65,11 +65,13 @@ def build_dict(min_word_freq=50): ...@@ -65,11 +65,13 @@ def build_dict(min_word_freq=50):
# remove <unk> for now, since we will set it as last index # remove <unk> for now, since we will set it as last index
del word_freq['<unk>'] del word_freq['<unk>']
word_freq = [x for x in list(word_freq.items()) if x[1] > min_word_freq] word_freq = [
x for x in six.moves.iteritems(word_freq) if x[1] > min_word_freq
]
word_freq_sorted = sorted(word_freq, key=lambda x: (-x[1], x[0])) word_freq_sorted = sorted(word_freq, key=lambda x: (-x[1], x[0]))
words, _ = list(zip(*word_freq_sorted)) words, _ = list(zip(*word_freq_sorted))
word_idx = dict(list(zip(words, range(len(words))))) word_idx = dict(list(zip(words, six.moves.range(len(words)))))
word_idx['<unk>'] = len(words) word_idx['<unk>'] = len(words)
return word_idx return word_idx
...@@ -90,7 +92,7 @@ def reader_creator(filename, word_idx, n, data_type): ...@@ -90,7 +92,7 @@ def reader_creator(filename, word_idx, n, data_type):
l = ['<s>'] + l.strip().split() + ['<e>'] l = ['<s>'] + l.strip().split() + ['<e>']
if len(l) >= n: if len(l) >= n:
l = [word_idx.get(w, UNK) for w in l] l = [word_idx.get(w, UNK) for w in l]
for i in range(n, len(l) + 1): for i in six.moves.range(n, len(l) + 1):
yield tuple(l[i - n:i]) yield tuple(l[i - n:i])
elif DataType.SEQ == data_type: elif DataType.SEQ == data_type:
l = l.strip().split() l = l.strip().split()
......
...@@ -20,6 +20,7 @@ The script fetch and preprocess movie_reviews data set that provided by NLTK ...@@ -20,6 +20,7 @@ The script fetch and preprocess movie_reviews data set that provided by NLTK
TODO(yuyang18): Complete dataset. TODO(yuyang18): Complete dataset.
""" """
import six
import collections import collections
from itertools import chain from itertools import chain
...@@ -64,7 +65,7 @@ def get_word_dict(): ...@@ -64,7 +65,7 @@ def get_word_dict():
for field in movie_reviews.fileids(category): for field in movie_reviews.fileids(category):
for words in movie_reviews.words(field): for words in movie_reviews.words(field):
word_freq_dict[words] += 1 word_freq_dict[words] += 1
words_sort_list = list(word_freq_dict.items()) words_sort_list = six.moves.iteritems(word_freq_dict)
words_sort_list.sort(cmp=lambda a, b: b[1] - a[1]) words_sort_list.sort(cmp=lambda a, b: b[1] - a[1])
for index, word in enumerate(words_sort_list): for index, word in enumerate(words_sort_list):
words_freq_sorted.append((word[0], index)) words_freq_sorted.append((word[0], index))
......
...@@ -156,8 +156,8 @@ def get_dict(dict_size, reverse=True): ...@@ -156,8 +156,8 @@ def get_dict(dict_size, reverse=True):
tar_file = paddle.dataset.common.download(URL_TRAIN, 'wmt14', MD5_TRAIN) tar_file = paddle.dataset.common.download(URL_TRAIN, 'wmt14', MD5_TRAIN)
src_dict, trg_dict = __read_to_dict(tar_file, dict_size) src_dict, trg_dict = __read_to_dict(tar_file, dict_size)
if reverse: if reverse:
src_dict = {v: k for k, v in list(src_dict.items())} src_dict = {v: k for k, v in six.moves.iteritems(src_dict)}
trg_dict = {v: k for k, v in list(trg_dict.items())} trg_dict = {v: k for k, v in six.moves.iteritems(trg_dict)}
return src_dict, trg_dict return src_dict, trg_dict
......
...@@ -72,7 +72,7 @@ def __build_dict(tar_file, dict_size, save_path, lang): ...@@ -72,7 +72,7 @@ def __build_dict(tar_file, dict_size, save_path, lang):
fout.write("%s\n%s\n%s\n" % (START_MARK, END_MARK, UNK_MARK)) fout.write("%s\n%s\n%s\n" % (START_MARK, END_MARK, UNK_MARK))
for idx, word in enumerate( for idx, word in enumerate(
sorted( sorted(
iter(list(word_dict.items())), six.moves.iteritems(word_dict),
key=lambda x: x[1], key=lambda x: x[1],
reverse=True)): reverse=True)):
if idx + 3 == dict_size: break if idx + 3 == dict_size: break
......
...@@ -46,13 +46,13 @@ def _create_op_desc_(op_type, inputs, outputs, attrs): ...@@ -46,13 +46,13 @@ def _create_op_desc_(op_type, inputs, outputs, attrs):
""" """
op_desc = core.OpDesc() op_desc = core.OpDesc()
op_desc.set_type(op_type) op_desc.set_type(op_type)
for para, args in list(inputs.items()): for para, args in six.moves.iteritems(inputs):
op_desc.set_input( op_desc.set_input(
para, para,
list( list(
map(lambda arg: arg.decode() if isinstance(arg, six.binary_type) else arg, map(lambda arg: arg.decode() if isinstance(arg, six.binary_type) else arg,
args))) args)))
for para, args in list(outputs.items()): for para, args in six.moves.iteritems(outputs):
op_desc.set_output( op_desc.set_output(
para, para,
list( list(
...@@ -64,7 +64,7 @@ def _create_op_desc_(op_type, inputs, outputs, attrs): ...@@ -64,7 +64,7 @@ def _create_op_desc_(op_type, inputs, outputs, attrs):
if op_role_attr_name not in attrs: if op_role_attr_name not in attrs:
attrs[ attrs[
op_role_attr_name] = core.op_proto_and_checker_maker.OpRole.Backward op_role_attr_name] = core.op_proto_and_checker_maker.OpRole.Backward
for name, val in list(attrs.items()): for name, val in six.moves.iteritems(attrs):
if isinstance(val, framework.Block): if isinstance(val, framework.Block):
op_desc.set_block_attr(name, val.desc) op_desc.set_block_attr(name, val.desc)
else: else:
...@@ -187,7 +187,7 @@ def _addup_repetitive_outputs_(op_descs): ...@@ -187,7 +187,7 @@ def _addup_repetitive_outputs_(op_descs):
op_desc.set_output(param_name, arg_names) op_desc.set_output(param_name, arg_names)
renamed_vars[var_name].append(new_name) renamed_vars[var_name].append(new_name)
for var_name, inputs in list(renamed_vars.items()): for var_name, inputs in six.moves.iteritems(renamed_vars):
if len(inputs) > 1: if len(inputs) > 1:
pending_sum_ops.append( pending_sum_ops.append(
(_create_op_desc_("sum", {"X": inputs}, {"Out": [var_name]}, (_create_op_desc_("sum", {"X": inputs}, {"Out": [var_name]},
...@@ -445,7 +445,7 @@ def _rename_grad_(block, start_op_idx, grad_to_var, target_grad_map): ...@@ -445,7 +445,7 @@ def _rename_grad_(block, start_op_idx, grad_to_var, target_grad_map):
op_desc.rename_output(name, new_name) op_desc.rename_output(name, new_name)
var_map[name] = new_name var_map[name] = new_name
for g, ng in list(var_map.items()): for g, ng in six.moves.iteritems(var_map):
if g in grad_to_var: if g in grad_to_var:
grad_to_var[ng] = grad_to_var[g] grad_to_var[ng] = grad_to_var[g]
grad_to_var.pop(g) grad_to_var.pop(g)
......
...@@ -958,7 +958,7 @@ class Block(object): ...@@ -958,7 +958,7 @@ class Block(object):
return list(self.iter_parameters()) return list(self.iter_parameters())
def iter_parameters(self): def iter_parameters(self):
return (item[1] for item in list(self.vars.items()) return (item[1] for item in six.moves.iteritems(self.vars)
if isinstance(item[1], Parameter)) if isinstance(item[1], Parameter))
def create_var(self, *args, **kwargs): def create_var(self, *args, **kwargs):
......
...@@ -106,7 +106,7 @@ class Graph(object): ...@@ -106,7 +106,7 @@ class Graph(object):
def _rank_repr(self): def _rank_repr(self):
ranks = sorted( ranks = sorted(
list(self.rank_groups.items()), six.moves.iteritems(self.rank_groups),
key=functools.cmp_to_key( key=functools.cmp_to_key(
lambda a, b: a[1].priority > b[1].priority)) lambda a, b: a[1].priority > b[1].priority))
repr = [] repr = []
...@@ -150,8 +150,9 @@ class Node(object): ...@@ -150,8 +150,9 @@ class Node(object):
reprs = '{name} [label={label} {extra} ];'.format( reprs = '{name} [label={label} {extra} ];'.format(
name=self.name, name=self.name,
label=self.label, label=self.label,
extra=',' + ','.join("%s=%s" % (key, crepr(value)) extra=',' + ','.join(
for key, value in list(self.attrs.items())) "%s=%s" % (key, crepr(value))
for key, value in six.moves.iteritems(self.attrs))
if self.attrs else "") if self.attrs else "")
return reprs return reprs
...@@ -175,7 +176,7 @@ class Edge(object): ...@@ -175,7 +176,7 @@ class Edge(object):
target=self.target.name, target=self.target.name,
extra="" if not self.attrs else extra="" if not self.attrs else
"[" + ','.join("{}={}".format(attr[0], crepr(attr[1])) "[" + ','.join("{}={}".format(attr[0], crepr(attr[1]))
for attr in list(self.attrs.items())) + "]") for attr in six.moves.iteritems(self.attrs)) + "]")
return repr return repr
......
...@@ -22,6 +22,7 @@ from ..initializer import force_init_on_cpu ...@@ -22,6 +22,7 @@ from ..initializer import force_init_on_cpu
from .ops import logical_and, logical_not, logical_or from .ops import logical_and, logical_not, logical_or
import numpy import numpy
import warnings import warnings
import six
from functools import reduce from functools import reduce
__all__ = [ __all__ = [
...@@ -602,7 +603,7 @@ class StaticRNN(object): ...@@ -602,7 +603,7 @@ class StaticRNN(object):
boot_memories = [] boot_memories = []
pre_memories = [] pre_memories = []
memories = [] memories = []
for _, mem in list(self.memories.items()): for _, mem in six.moves.iteritems(self.memories):
boot_memories.append(mem.init) boot_memories.append(mem.init)
pre_memories.append(mem.pre_mem.name) pre_memories.append(mem.pre_mem.name)
mem_var = rnn_block.var(mem.mem.name) mem_var = rnn_block.var(mem.mem.name)
......
...@@ -19,6 +19,7 @@ The metrics are accomplished via Python natively. ...@@ -19,6 +19,7 @@ The metrics are accomplished via Python natively.
import numpy as np import numpy as np
import copy import copy
import warnings import warnings
import six
__all__ = [ __all__ = [
'MetricBase', 'MetricBase',
...@@ -79,10 +80,10 @@ class MetricBase(object): ...@@ -79,10 +80,10 @@ class MetricBase(object):
""" """
states = { states = {
attr: value attr: value
for attr, value in list(self.__dict__.items()) for attr, value in six.moves.iteritems(self.__dict__)
if not attr.startswith("_") if not attr.startswith("_")
} }
for attr, value in list(states.items()): for attr, value in six.moves.iteritems(states):
if isinstance(value, int): if isinstance(value, int):
setattr(self, attr, 0) setattr(self, attr, 0)
elif isinstance(value, float): elif isinstance(value, float):
...@@ -105,7 +106,7 @@ class MetricBase(object): ...@@ -105,7 +106,7 @@ class MetricBase(object):
""" """
states = { states = {
attr: value attr: value
for attr, value in list(self.__dict__.items()) for attr, value in six.moves.iteritems(self.__dict__)
if not attr.startswith("_") if not attr.startswith("_")
} }
config = {} config = {}
......
...@@ -54,7 +54,7 @@ class BenchmarkSuite(OpTest): ...@@ -54,7 +54,7 @@ class BenchmarkSuite(OpTest):
def _get_input_names(self): def _get_input_names(self):
inputs = [] inputs = []
for name, value in list(self.inputs.items()): for name, value in six.moves.iteritems(self.inputs):
if isinstance(value, list): if isinstance(value, list):
inputs.extend([sub_name for sub_name, _ in value]) inputs.extend([sub_name for sub_name, _ in value])
inputs.append(name) inputs.append(name)
...@@ -62,7 +62,7 @@ class BenchmarkSuite(OpTest): ...@@ -62,7 +62,7 @@ class BenchmarkSuite(OpTest):
def _get_output_names(self): def _get_output_names(self):
outputs = [] outputs = []
for var_name, var in list(self.outputs.items()): for var_name, var in six.moves.iteritems(self.outputs):
if isinstance(var, list): if isinstance(var, list):
for sub_var_name, sub_var in var: for sub_var_name, sub_var in var:
outputs.append(sub_var_name) outputs.append(sub_var_name)
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
import unittest import unittest
import numpy as np import numpy as np
import six
import sys import sys
import collections import collections
import math import math
...@@ -176,7 +177,7 @@ class TestDetectionMAPOp(OpTest): ...@@ -176,7 +177,7 @@ class TestDetectionMAPOp(OpTest):
true_pos[label].append([score, tp]) true_pos[label].append([score, tp])
false_pos[label].append([score, fp]) false_pos[label].append([score, fp])
for (label, label_pos_num) in list(label_count.items()): for (label, label_pos_num) in six.moves.iteritems(label_count):
if label_pos_num == 0 or label not in true_pos: continue if label_pos_num == 0 or label not in true_pos: continue
label_true_pos = true_pos[label] label_true_pos = true_pos[label]
label_false_pos = false_pos[label] label_false_pos = false_pos[label]
......
...@@ -18,6 +18,7 @@ from paddle.fluid.executor import Executor ...@@ -18,6 +18,7 @@ from paddle.fluid.executor import Executor
import paddle.fluid.core as core import paddle.fluid.core as core
import numpy import numpy
import unittest import unittest
import six
class TestLoDRankTable(unittest.TestCase): class TestLoDRankTable(unittest.TestCase):
...@@ -36,7 +37,7 @@ class TestLoDRankTable(unittest.TestCase): ...@@ -36,7 +37,7 @@ class TestLoDRankTable(unittest.TestCase):
exe.run(scope=scope, feed={'x': tensor}) exe.run(scope=scope, feed={'x': tensor})
var = scope.find_var(rank_table.name) var = scope.find_var(rank_table.name)
table = var.get_lod_rank_table() table = var.get_lod_rank_table()
self.assertEqual([(0, 5), (1, 1), (2, 1)], list(table.items())) self.assertEqual([(0, 5), (1, 1), (2, 1)], six.moves.iteritems(table))
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import unittest import unittest
import itertools import itertools
import numpy as np import numpy as np
import six
from op_test import OpTest from op_test import OpTest
...@@ -32,7 +33,7 @@ def py_pnpair_op(score, label, query, column=-1, weight=None): ...@@ -32,7 +33,7 @@ def py_pnpair_op(score, label, query, column=-1, weight=None):
# accumulate statistics # accumulate statistics
pos, neg, neu = 0, 0, 0 pos, neg, neu = 0, 0, 0
for _, ranks in list(predictions.items()): for _, ranks in six.moves.iteritems(predictions):
for e1, e2 in itertools.combinations(ranks, 2): for e1, e2 in itertools.combinations(ranks, 2):
s1, s2, l1, l2, w1, w2 = e1[0], e2[0], e1[1], e2[1], e1[2], e2[2] s1, s2, l1, l2, w1, w2 = e1[0], e2[0], e1[1], e2[1], e1[2], e2[2]
w = (w1 + w2) * 0.5 w = (w1 + w2) * 0.5
......
...@@ -16,6 +16,7 @@ import contextlib ...@@ -16,6 +16,7 @@ import contextlib
import os import os
import errno import errno
import shutil import shutil
import six
import time import time
from . import core from . import core
...@@ -618,7 +619,7 @@ def build_feed_var_list(program, feed_order): ...@@ -618,7 +619,7 @@ def build_feed_var_list(program, feed_order):
"The values of 'feed_order' should be a permutation of [0, len(feed_order))" "The values of 'feed_order' should be a permutation of [0, len(feed_order))"
) )
sorted_pair_list = sorted( sorted_pair_list = sorted(
list(feed_order.items()), key=lambda item: item[1]) six.moves.iteritems(feed_order), key=lambda item: item[1])
feed_var_list = [ feed_var_list = [
program.global_block().var(pair[0]) for pair in sorted_pair_list program.global_block().var(pair[0]) for pair in sorted_pair_list
] ]
...@@ -1036,7 +1037,7 @@ def _save_trainer_args(dirname, trainer_id, trainer_args): ...@@ -1036,7 +1037,7 @@ def _save_trainer_args(dirname, trainer_id, trainer_args):
cur_dir = _get_trainer_dir(dirname, trainer_id) cur_dir = _get_trainer_dir(dirname, trainer_id)
for name, value in list(trainer_args.items()): for name, value in six.moves.iteritems(trainer_args):
args_file = os.path.join(cur_dir, name) args_file = os.path.join(cur_dir, name)
with open(args_file, 'w') as f: with open(args_file, 'w') as f:
f.write(str(value)) f.write(str(value))
......
...@@ -218,7 +218,8 @@ class DistributeTranspiler(object): ...@@ -218,7 +218,8 @@ class DistributeTranspiler(object):
# fc_w@GRAD_trainer_0, fc_w@GRAD_trainer_1 --> pserver1 # fc_w@GRAD_trainer_0, fc_w@GRAD_trainer_1 --> pserver1
# fc_b@GRAD_trainer_0, fc_b@GRAD_trainer_1 --> pserver2 # fc_b@GRAD_trainer_0, fc_b@GRAD_trainer_1 --> pserver2
# shuffle the map will avoid the uneven distribution above # shuffle the map will avoid the uneven distribution above
grad_var_mapping_items = list(self.grad_var_mapping.items()) grad_var_mapping_items = list(
six.moves.iteritems(self.grad_var_mapping))
if not self.config.slice_var_up: if not self.config.slice_var_up:
random.seed(self.origin_program.random_seed) random.seed(self.origin_program.random_seed)
...@@ -279,7 +280,7 @@ class DistributeTranspiler(object): ...@@ -279,7 +280,7 @@ class DistributeTranspiler(object):
self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i]) self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i])
# step4: Concat the parameters splits together after recv. # step4: Concat the parameters splits together after recv.
for varname, splited_var in list(self.param_var_mapping.items()): for varname, splited_var in six.moves.iteritems(self.param_var_mapping):
eps = [] eps = []
for var in splited_var: for var in splited_var:
index = [v.name for v in recv_vars].index(var.name) index = [v.name for v in recv_vars].index(var.name)
...@@ -303,7 +304,7 @@ class DistributeTranspiler(object): ...@@ -303,7 +304,7 @@ class DistributeTranspiler(object):
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
}) })
for varname, splited_var in list(self.param_var_mapping.items()): for varname, splited_var in six.moves.iteritems(self.param_var_mapping):
if len(splited_var) <= 1: if len(splited_var) <= 1:
continue continue
orig_param = program.global_block().vars[varname] orig_param = program.global_block().vars[varname]
...@@ -560,7 +561,7 @@ class DistributeTranspiler(object): ...@@ -560,7 +561,7 @@ class DistributeTranspiler(object):
# 1. create vars in pserver program to startup program # 1. create vars in pserver program to startup program
pserver_vars = pserver_program.global_block().vars pserver_vars = pserver_program.global_block().vars
created_var_map = collections.OrderedDict() created_var_map = collections.OrderedDict()
for _, var in list(pserver_vars.items()): for _, var in six.moves.iteritems(pserver_vars):
tmpvar = s_prog.global_block()._clone_variable(var) tmpvar = s_prog.global_block()._clone_variable(var)
created_var_map[var.name] = tmpvar created_var_map[var.name] = tmpvar
...@@ -997,7 +998,7 @@ class DistributeTranspiler(object): ...@@ -997,7 +998,7 @@ class DistributeTranspiler(object):
block_map[varname] = [] block_map[varname] = []
block_map[varname].append((int(offset), int(size))) block_map[varname].append((int(offset), int(size)))
for varname, splited in list(block_map.items()): for varname, splited in six.moves.iteritems(block_map):
orig_var = program.global_block().var(varname) orig_var = program.global_block().var(varname)
if len(splited) == 1: if len(splited) == 1:
if self.sync_mode and add_trainer_suffix: if self.sync_mode and add_trainer_suffix:
...@@ -1248,9 +1249,7 @@ class DistributeTranspiler(object): ...@@ -1248,9 +1249,7 @@ class DistributeTranspiler(object):
def _is_splited_grad_var(self, var, var_dict): def _is_splited_grad_var(self, var, var_dict):
grad_block = None grad_block = None
# TODO(minqiyang): replace these items() with six.iteritems() to for _, g in six.moves.iteritems(var_dict):
# improve memory
for _, g in list(var_dict.items()):
if self._orig_varname(g.name) == self._orig_varname(var.name): if self._orig_varname(g.name) == self._orig_varname(var.name):
if g.name.find(".trainer_") == -1: if g.name.find(".trainer_") == -1:
grad_block = g grad_block = g
...@@ -1260,7 +1259,7 @@ class DistributeTranspiler(object): ...@@ -1260,7 +1259,7 @@ class DistributeTranspiler(object):
def _clone_lr_op(self, program, block, op): def _clone_lr_op(self, program, block, op):
inputs = self._get_input_map_from_op( inputs = self._get_input_map_from_op(
self.origin_program.global_block().vars, op) self.origin_program.global_block().vars, op)
for key, varlist in list(inputs.items()): for key, varlist in six.moves.iteritems(inputs):
if not isinstance(varlist, list): if not isinstance(varlist, list):
varlist = [varlist] varlist = [varlist]
for var in varlist: for var in varlist:
...@@ -1269,7 +1268,7 @@ class DistributeTranspiler(object): ...@@ -1269,7 +1268,7 @@ class DistributeTranspiler(object):
outputs = self._get_output_map_from_op( outputs = self._get_output_map_from_op(
self.origin_program.global_block().vars, op) self.origin_program.global_block().vars, op)
for key, varlist in list(outputs.items()): for key, varlist in six.moves.iteritems(outputs):
if not isinstance(varlist, list): if not isinstance(varlist, list):
varlist = [varlist] varlist = [varlist]
for var in varlist: for var in varlist:
...@@ -1284,7 +1283,7 @@ class DistributeTranspiler(object): ...@@ -1284,7 +1283,7 @@ class DistributeTranspiler(object):
# Append the ops for parameters that do not need to be optimized/updated # Append the ops for parameters that do not need to be optimized/updated
inputs = self._get_input_map_from_op( inputs = self._get_input_map_from_op(
self.origin_program.global_block().vars, opt_op) self.origin_program.global_block().vars, opt_op)
for key, varlist in list(inputs.items()): for key, varlist in six.moves.iteritems(inputs):
if not isinstance(varlist, list): if not isinstance(varlist, list):
varlist = [varlist] varlist = [varlist]
for var in varlist: for var in varlist:
...@@ -1303,7 +1302,7 @@ class DistributeTranspiler(object): ...@@ -1303,7 +1302,7 @@ class DistributeTranspiler(object):
outputs = self._get_output_map_from_op( outputs = self._get_output_map_from_op(
self.origin_program.global_block().vars, opt_op) self.origin_program.global_block().vars, opt_op)
for key, varlist in list(outputs.items()): for key, varlist in six.moves.iteritems(outputs):
if not isinstance(varlist, list): if not isinstance(varlist, list):
varlist = [varlist] varlist = [varlist]
for var in varlist: for var in varlist:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册