未验证 提交 df0a22d9 编写于 作者: A Aurelius84 提交者: GitHub

[Dy2stat]Fix Python Version compatibility of dict.iteritems (#26778)

* Fix Python Version compatibility

* add import six
上级 1f6df878
......@@ -296,7 +296,7 @@ def convert_to_input_spec(inputs, input_spec):
elif isinstance(input_spec, dict):
input_with_spec = {}
check_type_and_len(inputs, input_spec, True)
for name, input in inputs.items():
for name, input in six.iteritems(inputs):
if name in input_spec:
input_with_spec[name] = convert_to_input_spec(input,
input_spec[name])
......
......@@ -14,6 +14,7 @@
from __future__ import print_function
import six
import copy
from collections import defaultdict
......@@ -230,7 +231,7 @@ class NameVisitor(gast.NodeVisitor):
return False
def _update_name_ids(self, new_name_ids):
for name_id, ctxs in new_name_ids.items():
for name_id, ctxs in six.iteritems(new_name_ids):
self.name_ids[name_id] = ctxs + self.name_ids[name_id]
......@@ -250,7 +251,7 @@ def parse_cond_args(var_ids_dict, return_ids=None, ctx=gast.Load):
"""
name_ids = [
var_id for var_id, var_ctx in var_ids_dict.items()
var_id for var_id, var_ctx in six.iteritems(var_ids_dict)
if isinstance(var_ctx[0], ctx)
]
if return_ids:
......@@ -341,7 +342,7 @@ def parse_cond_return(parent_vars_dict, if_vars_dict, else_vars_dict,
def _vars_with_store(ids_dict):
vars = []
for k, ctxs in ids_dict.items():
for k, ctxs in six.iteritems(ids_dict):
if _is_return_var(ctxs):
vars.append(k)
return vars
......@@ -353,7 +354,7 @@ def parse_cond_return(parent_vars_dict, if_vars_dict, else_vars_dict,
def _vars_loaded_before_store(ids_dict):
new_dict = defaultdict(list)
for k, ctxs in ids_dict.items():
for k, ctxs in six.iteritems(ids_dict):
for ctx in ctxs:
if isinstance(ctx, gast.Load):
new_dict[k].append(ctx)
......
......@@ -15,6 +15,7 @@
from __future__ import print_function
import numpy as np
import logging
import six
from paddle.fluid import log_helper
from paddle.fluid import framework, backward, core
......@@ -334,7 +335,7 @@ class PartialProgramLayer(layers.Layer):
param_and_buffer_names_set.add(var.name)
for block in main_program.blocks:
for name, var in block.vars.items():
for name, var in six.iteritems(block.vars):
if isinstance(var, framework.Parameter):
if name not in param_and_buffer_names_set:
raise ValueError(
......
......@@ -617,7 +617,7 @@ class ProgramCache(object):
return len(self._caches)
def concrete_programs(self):
return [cp for key, (cp, _) in self._caches.iteritems()]
return [cp for key, (cp, _) in six.iteritems(self._caches)]
def synchronized(func):
......
......@@ -493,7 +493,7 @@ def recover_globals_attribute(src_obj, dst_obj):
src_globals = getattr(src_obj, attr_name, {})
dst_globals = getattr(dst_obj, attr_name, {})
for k, v in src_globals.items():
for k, v in six.iteritems(src_globals):
# ignore builtin attribute.
if not (k.startswith('__') and k.endswith('__')):
dst_globals[k] = v
......
......@@ -754,7 +754,7 @@ def save(layer, model_path, input_spec=None, configs=None):
# saved to inference program may not need by dygraph Layer,
# we only record the state_dict variable's structured name
state_names_dict = dict()
for structured_name, var in layer.state_dict().items():
for structured_name, var in six.iteritems(layer.state_dict()):
state_names_dict[var.name] = structured_name
# 3. share parameters from Layer to scope & record var info
......
......@@ -191,6 +191,7 @@ class TestDifferentInputSpecCacheProgram(unittest.TestCase):
out_1 = foo(to_variable(x_data), to_variable(y_data))
self.assertTrue(np.allclose(x_data + y_data, out_1.numpy()))
self.assertTrue(len(foo.program_cache) == 1)
self.assertTrue(len(foo.program_cache.concrete_programs()) == 1)
# [16, 10] + [10] (numpy)
out_2 = foo(to_variable(x_data), y_data)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册