提交 6dc07e7f 编写于 作者: M minqiyang

Replace items() with six.moves.iteritems() to improve memory usage

上级 9cd59990
...@@ -47,8 +47,9 @@ def tokenize(pattern): ...@@ -47,8 +47,9 @@ def tokenize(pattern):
while tf != None: while tf != None:
if bool(pattern.match(tf.name)): if bool(pattern.match(tf.name)):
# newline and punctuations removal and ad-hoc tokenization. # newline and punctuations removal and ad-hoc tokenization.
yield tarf.extractfile(tf).read().rstrip(six.b("\n\r")).translate( yield tarf.extractfile(tf).read().rstrip(six.b(
None, six.b(string.punctuation)).lower().split() "\n\r")).translate(
None, six.b(string.punctuation)).lower().split()
tf = tarf.next() tf = tarf.next()
...@@ -63,7 +64,7 @@ def build_dict(pattern, cutoff): ...@@ -63,7 +64,7 @@ def build_dict(pattern, cutoff):
word_freq[word] += 1 word_freq[word] += 1
# Not sure if we should prune less-frequent words here. # Not sure if we should prune less-frequent words here.
word_freq = [x for x in list(word_freq.items()) if x[1] > cutoff] word_freq = [x for x in six.moves.iteritems(word_freq) if x[1] > cutoff]
dictionary = sorted(word_freq, key=lambda x: (-x[1], x[0])) dictionary = sorted(word_freq, key=lambda x: (-x[1], x[0]))
words, _ = list(zip(*dictionary)) words, _ = list(zip(*dictionary))
......
...@@ -21,7 +21,7 @@ into paddle reader creators. ...@@ -21,7 +21,7 @@ into paddle reader creators.
import paddle.dataset.common import paddle.dataset.common
import collections import collections
import tarfile import tarfile
from six.moves import range import six
__all__ = ['train', 'test', 'build_dict', 'convert'] __all__ = ['train', 'test', 'build_dict', 'convert']
...@@ -65,11 +65,13 @@ def build_dict(min_word_freq=50): ...@@ -65,11 +65,13 @@ def build_dict(min_word_freq=50):
# remove <unk> for now, since we will set it as last index # remove <unk> for now, since we will set it as last index
del word_freq['<unk>'] del word_freq['<unk>']
word_freq = [x for x in list(word_freq.items()) if x[1] > min_word_freq] word_freq = [
x for x in six.moves.iteritems(word_freq) if x[1] > min_word_freq
]
word_freq_sorted = sorted(word_freq, key=lambda x: (-x[1], x[0])) word_freq_sorted = sorted(word_freq, key=lambda x: (-x[1], x[0]))
words, _ = list(zip(*word_freq_sorted)) words, _ = list(zip(*word_freq_sorted))
word_idx = dict(list(zip(words, range(len(words))))) word_idx = dict(list(zip(words, six.moves.range(len(words)))))
word_idx['<unk>'] = len(words) word_idx['<unk>'] = len(words)
return word_idx return word_idx
...@@ -90,7 +92,7 @@ def reader_creator(filename, word_idx, n, data_type): ...@@ -90,7 +92,7 @@ def reader_creator(filename, word_idx, n, data_type):
l = ['<s>'] + l.strip().split() + ['<e>'] l = ['<s>'] + l.strip().split() + ['<e>']
if len(l) >= n: if len(l) >= n:
l = [word_idx.get(w, UNK) for w in l] l = [word_idx.get(w, UNK) for w in l]
for i in range(n, len(l) + 1): for i in six.moves.range(n, len(l) + 1):
yield tuple(l[i - n:i]) yield tuple(l[i - n:i])
elif DataType.SEQ == data_type: elif DataType.SEQ == data_type:
l = l.strip().split() l = l.strip().split()
......
...@@ -20,6 +20,7 @@ The script fetch and preprocess movie_reviews data set that provided by NLTK ...@@ -20,6 +20,7 @@ The script fetch and preprocess movie_reviews data set that provided by NLTK
TODO(yuyang18): Complete dataset. TODO(yuyang18): Complete dataset.
""" """
import six
import collections import collections
from itertools import chain from itertools import chain
...@@ -64,7 +65,7 @@ def get_word_dict(): ...@@ -64,7 +65,7 @@ def get_word_dict():
for field in movie_reviews.fileids(category): for field in movie_reviews.fileids(category):
for words in movie_reviews.words(field): for words in movie_reviews.words(field):
word_freq_dict[words] += 1 word_freq_dict[words] += 1
words_sort_list = list(word_freq_dict.items()) words_sort_list = six.moves.iteritems(word_freq_dict)
words_sort_list.sort(cmp=lambda a, b: b[1] - a[1]) words_sort_list.sort(cmp=lambda a, b: b[1] - a[1])
for index, word in enumerate(words_sort_list): for index, word in enumerate(words_sort_list):
words_freq_sorted.append((word[0], index)) words_freq_sorted.append((word[0], index))
......
...@@ -156,8 +156,8 @@ def get_dict(dict_size, reverse=True): ...@@ -156,8 +156,8 @@ def get_dict(dict_size, reverse=True):
tar_file = paddle.dataset.common.download(URL_TRAIN, 'wmt14', MD5_TRAIN) tar_file = paddle.dataset.common.download(URL_TRAIN, 'wmt14', MD5_TRAIN)
src_dict, trg_dict = __read_to_dict(tar_file, dict_size) src_dict, trg_dict = __read_to_dict(tar_file, dict_size)
if reverse: if reverse:
src_dict = {v: k for k, v in list(src_dict.items())} src_dict = {v: k for k, v in six.moves.iteritems(src_dict)}
trg_dict = {v: k for k, v in list(trg_dict.items())} trg_dict = {v: k for k, v in six.moves.iteritems(trg_dict)}
return src_dict, trg_dict return src_dict, trg_dict
......
...@@ -72,7 +72,7 @@ def __build_dict(tar_file, dict_size, save_path, lang): ...@@ -72,7 +72,7 @@ def __build_dict(tar_file, dict_size, save_path, lang):
fout.write("%s\n%s\n%s\n" % (START_MARK, END_MARK, UNK_MARK)) fout.write("%s\n%s\n%s\n" % (START_MARK, END_MARK, UNK_MARK))
for idx, word in enumerate( for idx, word in enumerate(
sorted( sorted(
iter(list(word_dict.items())), six.moves.iteritems(word_dict),
key=lambda x: x[1], key=lambda x: x[1],
reverse=True)): reverse=True)):
if idx + 3 == dict_size: break if idx + 3 == dict_size: break
......
...@@ -46,13 +46,13 @@ def _create_op_desc_(op_type, inputs, outputs, attrs): ...@@ -46,13 +46,13 @@ def _create_op_desc_(op_type, inputs, outputs, attrs):
""" """
op_desc = core.OpDesc() op_desc = core.OpDesc()
op_desc.set_type(op_type) op_desc.set_type(op_type)
for para, args in list(inputs.items()): for para, args in six.moves.iteritems(inputs):
op_desc.set_input( op_desc.set_input(
para, para,
list( list(
map(lambda arg: arg.decode() if isinstance(arg, six.binary_type) else arg, map(lambda arg: arg.decode() if isinstance(arg, six.binary_type) else arg,
args))) args)))
for para, args in list(outputs.items()): for para, args in six.moves.iteritems(outputs):
op_desc.set_output( op_desc.set_output(
para, para,
list( list(
...@@ -64,7 +64,7 @@ def _create_op_desc_(op_type, inputs, outputs, attrs): ...@@ -64,7 +64,7 @@ def _create_op_desc_(op_type, inputs, outputs, attrs):
if op_role_attr_name not in attrs: if op_role_attr_name not in attrs:
attrs[ attrs[
op_role_attr_name] = core.op_proto_and_checker_maker.OpRole.Backward op_role_attr_name] = core.op_proto_and_checker_maker.OpRole.Backward
for name, val in list(attrs.items()): for name, val in six.moves.iteritems(attrs):
if isinstance(val, framework.Block): if isinstance(val, framework.Block):
op_desc.set_block_attr(name, val.desc) op_desc.set_block_attr(name, val.desc)
else: else:
...@@ -187,7 +187,7 @@ def _addup_repetitive_outputs_(op_descs): ...@@ -187,7 +187,7 @@ def _addup_repetitive_outputs_(op_descs):
op_desc.set_output(param_name, arg_names) op_desc.set_output(param_name, arg_names)
renamed_vars[var_name].append(new_name) renamed_vars[var_name].append(new_name)
for var_name, inputs in list(renamed_vars.items()): for var_name, inputs in six.moves.iteritems(renamed_vars):
if len(inputs) > 1: if len(inputs) > 1:
pending_sum_ops.append( pending_sum_ops.append(
(_create_op_desc_("sum", {"X": inputs}, {"Out": [var_name]}, (_create_op_desc_("sum", {"X": inputs}, {"Out": [var_name]},
...@@ -445,7 +445,7 @@ def _rename_grad_(block, start_op_idx, grad_to_var, target_grad_map): ...@@ -445,7 +445,7 @@ def _rename_grad_(block, start_op_idx, grad_to_var, target_grad_map):
op_desc.rename_output(name, new_name) op_desc.rename_output(name, new_name)
var_map[name] = new_name var_map[name] = new_name
for g, ng in list(var_map.items()): for g, ng in six.moves.iteritems(var_map):
if g in grad_to_var: if g in grad_to_var:
grad_to_var[ng] = grad_to_var[g] grad_to_var[ng] = grad_to_var[g]
grad_to_var.pop(g) grad_to_var.pop(g)
......
...@@ -958,7 +958,7 @@ class Block(object): ...@@ -958,7 +958,7 @@ class Block(object):
return list(self.iter_parameters()) return list(self.iter_parameters())
def iter_parameters(self): def iter_parameters(self):
return (item[1] for item in list(self.vars.items()) return (item[1] for item in six.moves.iteritems(self.vars)
if isinstance(item[1], Parameter)) if isinstance(item[1], Parameter))
def create_var(self, *args, **kwargs): def create_var(self, *args, **kwargs):
......
...@@ -106,7 +106,7 @@ class Graph(object): ...@@ -106,7 +106,7 @@ class Graph(object):
def _rank_repr(self): def _rank_repr(self):
ranks = sorted( ranks = sorted(
list(self.rank_groups.items()), six.moves.iteritems(self.rank_groups),
key=functools.cmp_to_key( key=functools.cmp_to_key(
lambda a, b: a[1].priority > b[1].priority)) lambda a, b: a[1].priority > b[1].priority))
repr = [] repr = []
...@@ -150,8 +150,9 @@ class Node(object): ...@@ -150,8 +150,9 @@ class Node(object):
reprs = '{name} [label={label} {extra} ];'.format( reprs = '{name} [label={label} {extra} ];'.format(
name=self.name, name=self.name,
label=self.label, label=self.label,
extra=',' + ','.join("%s=%s" % (key, crepr(value)) extra=',' + ','.join(
for key, value in list(self.attrs.items())) "%s=%s" % (key, crepr(value))
for key, value in six.moves.iteritems(self.attrs))
if self.attrs else "") if self.attrs else "")
return reprs return reprs
...@@ -175,7 +176,7 @@ class Edge(object): ...@@ -175,7 +176,7 @@ class Edge(object):
target=self.target.name, target=self.target.name,
extra="" if not self.attrs else extra="" if not self.attrs else
"[" + ','.join("{}={}".format(attr[0], crepr(attr[1])) "[" + ','.join("{}={}".format(attr[0], crepr(attr[1]))
for attr in list(self.attrs.items())) + "]") for attr in six.moves.iteritems(self.attrs)) + "]")
return repr return repr
......
...@@ -22,6 +22,7 @@ from ..initializer import force_init_on_cpu ...@@ -22,6 +22,7 @@ from ..initializer import force_init_on_cpu
from .ops import logical_and, logical_not, logical_or from .ops import logical_and, logical_not, logical_or
import numpy import numpy
import warnings import warnings
import six
from functools import reduce from functools import reduce
__all__ = [ __all__ = [
...@@ -602,7 +603,7 @@ class StaticRNN(object): ...@@ -602,7 +603,7 @@ class StaticRNN(object):
boot_memories = [] boot_memories = []
pre_memories = [] pre_memories = []
memories = [] memories = []
for _, mem in list(self.memories.items()): for _, mem in six.moves.iteritems(self.memories):
boot_memories.append(mem.init) boot_memories.append(mem.init)
pre_memories.append(mem.pre_mem.name) pre_memories.append(mem.pre_mem.name)
mem_var = rnn_block.var(mem.mem.name) mem_var = rnn_block.var(mem.mem.name)
......
...@@ -14,11 +14,12 @@ ...@@ -14,11 +14,12 @@
""" """
Fluid Metrics Fluid Metrics
The metrics are accomplished via Python natively. The metrics are accomplished via Python natively.
""" """
import numpy as np import numpy as np
import copy import copy
import warnings import warnings
import six
__all__ = [ __all__ = [
'MetricBase', 'MetricBase',
...@@ -79,10 +80,10 @@ class MetricBase(object): ...@@ -79,10 +80,10 @@ class MetricBase(object):
""" """
states = { states = {
attr: value attr: value
for attr, value in list(self.__dict__.items()) for attr, value in six.moves.iteritems(self.__dict__)
if not attr.startswith("_") if not attr.startswith("_")
} }
for attr, value in list(states.items()): for attr, value in six.moves.iteritems(states):
if isinstance(value, int): if isinstance(value, int):
setattr(self, attr, 0) setattr(self, attr, 0)
elif isinstance(value, float): elif isinstance(value, float):
...@@ -105,7 +106,7 @@ class MetricBase(object): ...@@ -105,7 +106,7 @@ class MetricBase(object):
""" """
states = { states = {
attr: value attr: value
for attr, value in list(self.__dict__.items()) for attr, value in six.moves.iteritems(self.__dict__)
if not attr.startswith("_") if not attr.startswith("_")
} }
config = {} config = {}
...@@ -141,10 +142,10 @@ class CompositeMetric(MetricBase): ...@@ -141,10 +142,10 @@ class CompositeMetric(MetricBase):
""" """
Composite multiple metrics in one instance. Composite multiple metrics in one instance.
for example, merge F1, accuracy, recall into one Metric. for example, merge F1, accuracy, recall into one Metric.
Examples: Examples:
.. code-block:: python .. code-block:: python
labels = fluid.layers.data(name="data", shape=[1], dtype="int32") labels = fluid.layers.data(name="data", shape=[1], dtype="int32")
data = fluid.layers.data(name="data", shape=[32, 32], dtype="int32") data = fluid.layers.data(name="data", shape=[32, 32], dtype="int32")
pred = fluid.layers.fc(input=data, size=1000, act="tanh") pred = fluid.layers.fc(input=data, size=1000, act="tanh")
......
...@@ -54,7 +54,7 @@ class BenchmarkSuite(OpTest): ...@@ -54,7 +54,7 @@ class BenchmarkSuite(OpTest):
def _get_input_names(self): def _get_input_names(self):
inputs = [] inputs = []
for name, value in list(self.inputs.items()): for name, value in six.moves.iteritems(self.inputs):
if isinstance(value, list): if isinstance(value, list):
inputs.extend([sub_name for sub_name, _ in value]) inputs.extend([sub_name for sub_name, _ in value])
inputs.append(name) inputs.append(name)
...@@ -62,7 +62,7 @@ class BenchmarkSuite(OpTest): ...@@ -62,7 +62,7 @@ class BenchmarkSuite(OpTest):
def _get_output_names(self): def _get_output_names(self):
outputs = [] outputs = []
for var_name, var in list(self.outputs.items()): for var_name, var in six.moves.iteritems(self.outputs):
if isinstance(var, list): if isinstance(var, list):
for sub_var_name, sub_var in var: for sub_var_name, sub_var in var:
outputs.append(sub_var_name) outputs.append(sub_var_name)
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
import unittest import unittest
import numpy as np import numpy as np
import six
import sys import sys
import collections import collections
import math import math
...@@ -176,7 +177,7 @@ class TestDetectionMAPOp(OpTest): ...@@ -176,7 +177,7 @@ class TestDetectionMAPOp(OpTest):
true_pos[label].append([score, tp]) true_pos[label].append([score, tp])
false_pos[label].append([score, fp]) false_pos[label].append([score, fp])
for (label, label_pos_num) in list(label_count.items()): for (label, label_pos_num) in six.moves.iteritems(label_count):
if label_pos_num == 0 or label not in true_pos: continue if label_pos_num == 0 or label not in true_pos: continue
label_true_pos = true_pos[label] label_true_pos = true_pos[label]
label_false_pos = false_pos[label] label_false_pos = false_pos[label]
......
...@@ -18,6 +18,7 @@ from paddle.fluid.executor import Executor ...@@ -18,6 +18,7 @@ from paddle.fluid.executor import Executor
import paddle.fluid.core as core import paddle.fluid.core as core
import numpy import numpy
import unittest import unittest
import six
class TestLoDRankTable(unittest.TestCase): class TestLoDRankTable(unittest.TestCase):
...@@ -36,7 +37,7 @@ class TestLoDRankTable(unittest.TestCase): ...@@ -36,7 +37,7 @@ class TestLoDRankTable(unittest.TestCase):
exe.run(scope=scope, feed={'x': tensor}) exe.run(scope=scope, feed={'x': tensor})
var = scope.find_var(rank_table.name) var = scope.find_var(rank_table.name)
table = var.get_lod_rank_table() table = var.get_lod_rank_table()
self.assertEqual([(0, 5), (1, 1), (2, 1)], list(table.items())) self.assertEqual([(0, 5), (1, 1), (2, 1)], six.moves.iteritems(table))
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import unittest import unittest
import itertools import itertools
import numpy as np import numpy as np
import six
from op_test import OpTest from op_test import OpTest
...@@ -32,7 +33,7 @@ def py_pnpair_op(score, label, query, column=-1, weight=None): ...@@ -32,7 +33,7 @@ def py_pnpair_op(score, label, query, column=-1, weight=None):
# accumulate statistics # accumulate statistics
pos, neg, neu = 0, 0, 0 pos, neg, neu = 0, 0, 0
for _, ranks in list(predictions.items()): for _, ranks in six.moves.iteritems(predictions):
for e1, e2 in itertools.combinations(ranks, 2): for e1, e2 in itertools.combinations(ranks, 2):
s1, s2, l1, l2, w1, w2 = e1[0], e2[0], e1[1], e2[1], e1[2], e2[2] s1, s2, l1, l2, w1, w2 = e1[0], e2[0], e1[1], e2[1], e1[2], e2[2]
w = (w1 + w2) * 0.5 w = (w1 + w2) * 0.5
......
...@@ -16,6 +16,7 @@ import contextlib ...@@ -16,6 +16,7 @@ import contextlib
import os import os
import errno import errno
import shutil import shutil
import six
import time import time
from . import core from . import core
...@@ -618,7 +619,7 @@ def build_feed_var_list(program, feed_order): ...@@ -618,7 +619,7 @@ def build_feed_var_list(program, feed_order):
"The values of 'feed_order' should be a permutation of [0, len(feed_order))" "The values of 'feed_order' should be a permutation of [0, len(feed_order))"
) )
sorted_pair_list = sorted( sorted_pair_list = sorted(
list(feed_order.items()), key=lambda item: item[1]) six.moves.iteritems(feed_order), key=lambda item: item[1])
feed_var_list = [ feed_var_list = [
program.global_block().var(pair[0]) for pair in sorted_pair_list program.global_block().var(pair[0]) for pair in sorted_pair_list
] ]
...@@ -1036,7 +1037,7 @@ def _save_trainer_args(dirname, trainer_id, trainer_args): ...@@ -1036,7 +1037,7 @@ def _save_trainer_args(dirname, trainer_id, trainer_args):
cur_dir = _get_trainer_dir(dirname, trainer_id) cur_dir = _get_trainer_dir(dirname, trainer_id)
for name, value in list(trainer_args.items()): for name, value in six.moves.iteritems(trainer_args):
args_file = os.path.join(cur_dir, name) args_file = os.path.join(cur_dir, name)
with open(args_file, 'w') as f: with open(args_file, 'w') as f:
f.write(str(value)) f.write(str(value))
......
...@@ -218,7 +218,8 @@ class DistributeTranspiler(object): ...@@ -218,7 +218,8 @@ class DistributeTranspiler(object):
# fc_w@GRAD_trainer_0, fc_w@GRAD_trainer_1 --> pserver1 # fc_w@GRAD_trainer_0, fc_w@GRAD_trainer_1 --> pserver1
# fc_b@GRAD_trainer_0, fc_b@GRAD_trainer_1 --> pserver2 # fc_b@GRAD_trainer_0, fc_b@GRAD_trainer_1 --> pserver2
# shuffle the map will avoid the uneven distribution above # shuffle the map will avoid the uneven distribution above
grad_var_mapping_items = list(self.grad_var_mapping.items()) grad_var_mapping_items = list(
six.moves.iteritems(self.grad_var_mapping))
if not self.config.slice_var_up: if not self.config.slice_var_up:
random.seed(self.origin_program.random_seed) random.seed(self.origin_program.random_seed)
...@@ -279,7 +280,7 @@ class DistributeTranspiler(object): ...@@ -279,7 +280,7 @@ class DistributeTranspiler(object):
self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i]) self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i])
# step4: Concat the parameters splits together after recv. # step4: Concat the parameters splits together after recv.
for varname, splited_var in list(self.param_var_mapping.items()): for varname, splited_var in six.moves.iteritems(self.param_var_mapping):
eps = [] eps = []
for var in splited_var: for var in splited_var:
index = [v.name for v in recv_vars].index(var.name) index = [v.name for v in recv_vars].index(var.name)
...@@ -303,7 +304,7 @@ class DistributeTranspiler(object): ...@@ -303,7 +304,7 @@ class DistributeTranspiler(object):
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
}) })
for varname, splited_var in list(self.param_var_mapping.items()): for varname, splited_var in six.moves.iteritems(self.param_var_mapping):
if len(splited_var) <= 1: if len(splited_var) <= 1:
continue continue
orig_param = program.global_block().vars[varname] orig_param = program.global_block().vars[varname]
...@@ -560,7 +561,7 @@ class DistributeTranspiler(object): ...@@ -560,7 +561,7 @@ class DistributeTranspiler(object):
# 1. create vars in pserver program to startup program # 1. create vars in pserver program to startup program
pserver_vars = pserver_program.global_block().vars pserver_vars = pserver_program.global_block().vars
created_var_map = collections.OrderedDict() created_var_map = collections.OrderedDict()
for _, var in list(pserver_vars.items()): for _, var in six.moves.iteritems(pserver_vars):
tmpvar = s_prog.global_block()._clone_variable(var) tmpvar = s_prog.global_block()._clone_variable(var)
created_var_map[var.name] = tmpvar created_var_map[var.name] = tmpvar
...@@ -997,7 +998,7 @@ class DistributeTranspiler(object): ...@@ -997,7 +998,7 @@ class DistributeTranspiler(object):
block_map[varname] = [] block_map[varname] = []
block_map[varname].append((int(offset), int(size))) block_map[varname].append((int(offset), int(size)))
for varname, splited in list(block_map.items()): for varname, splited in six.moves.iteritems(block_map):
orig_var = program.global_block().var(varname) orig_var = program.global_block().var(varname)
if len(splited) == 1: if len(splited) == 1:
if self.sync_mode and add_trainer_suffix: if self.sync_mode and add_trainer_suffix:
...@@ -1248,9 +1249,7 @@ class DistributeTranspiler(object): ...@@ -1248,9 +1249,7 @@ class DistributeTranspiler(object):
def _is_splited_grad_var(self, var, var_dict): def _is_splited_grad_var(self, var, var_dict):
grad_block = None grad_block = None
# TODO(minqiyang): replace these items() with six.iteritems() to for _, g in six.moves.iteritems(var_dict):
# improve memory
for _, g in list(var_dict.items()):
if self._orig_varname(g.name) == self._orig_varname(var.name): if self._orig_varname(g.name) == self._orig_varname(var.name):
if g.name.find(".trainer_") == -1: if g.name.find(".trainer_") == -1:
grad_block = g grad_block = g
...@@ -1260,7 +1259,7 @@ class DistributeTranspiler(object): ...@@ -1260,7 +1259,7 @@ class DistributeTranspiler(object):
def _clone_lr_op(self, program, block, op): def _clone_lr_op(self, program, block, op):
inputs = self._get_input_map_from_op( inputs = self._get_input_map_from_op(
self.origin_program.global_block().vars, op) self.origin_program.global_block().vars, op)
for key, varlist in list(inputs.items()): for key, varlist in six.moves.iteritems(inputs):
if not isinstance(varlist, list): if not isinstance(varlist, list):
varlist = [varlist] varlist = [varlist]
for var in varlist: for var in varlist:
...@@ -1269,7 +1268,7 @@ class DistributeTranspiler(object): ...@@ -1269,7 +1268,7 @@ class DistributeTranspiler(object):
outputs = self._get_output_map_from_op( outputs = self._get_output_map_from_op(
self.origin_program.global_block().vars, op) self.origin_program.global_block().vars, op)
for key, varlist in list(outputs.items()): for key, varlist in six.moves.iteritems(outputs):
if not isinstance(varlist, list): if not isinstance(varlist, list):
varlist = [varlist] varlist = [varlist]
for var in varlist: for var in varlist:
...@@ -1284,7 +1283,7 @@ class DistributeTranspiler(object): ...@@ -1284,7 +1283,7 @@ class DistributeTranspiler(object):
# Append the ops for parameters that do not need to be optimized/updated # Append the ops for parameters that do not need to be optimized/updated
inputs = self._get_input_map_from_op( inputs = self._get_input_map_from_op(
self.origin_program.global_block().vars, opt_op) self.origin_program.global_block().vars, opt_op)
for key, varlist in list(inputs.items()): for key, varlist in six.moves.iteritems(inputs):
if not isinstance(varlist, list): if not isinstance(varlist, list):
varlist = [varlist] varlist = [varlist]
for var in varlist: for var in varlist:
...@@ -1303,7 +1302,7 @@ class DistributeTranspiler(object): ...@@ -1303,7 +1302,7 @@ class DistributeTranspiler(object):
outputs = self._get_output_map_from_op( outputs = self._get_output_map_from_op(
self.origin_program.global_block().vars, opt_op) self.origin_program.global_block().vars, opt_op)
for key, varlist in list(outputs.items()): for key, varlist in six.moves.iteritems(outputs):
if not isinstance(varlist, list): if not isinstance(varlist, list):
varlist = [varlist] varlist = [varlist]
for var in varlist: for var in varlist:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册