未验证 提交 df0a22d9 编写于 作者: A Aurelius84 提交者: GitHub

[Dy2stat]Fix Python Version compatibility of dict.iteritems (#26778)

* Fix Python Version compatibility

* add import six
上级 1f6df878
...@@ -296,7 +296,7 @@ def convert_to_input_spec(inputs, input_spec): ...@@ -296,7 +296,7 @@ def convert_to_input_spec(inputs, input_spec):
elif isinstance(input_spec, dict): elif isinstance(input_spec, dict):
input_with_spec = {} input_with_spec = {}
check_type_and_len(inputs, input_spec, True) check_type_and_len(inputs, input_spec, True)
for name, input in inputs.items(): for name, input in six.iteritems(inputs):
if name in input_spec: if name in input_spec:
input_with_spec[name] = convert_to_input_spec(input, input_with_spec[name] = convert_to_input_spec(input,
input_spec[name]) input_spec[name])
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
from __future__ import print_function from __future__ import print_function
import six
import copy import copy
from collections import defaultdict from collections import defaultdict
...@@ -230,7 +231,7 @@ class NameVisitor(gast.NodeVisitor): ...@@ -230,7 +231,7 @@ class NameVisitor(gast.NodeVisitor):
return False return False
def _update_name_ids(self, new_name_ids): def _update_name_ids(self, new_name_ids):
for name_id, ctxs in new_name_ids.items(): for name_id, ctxs in six.iteritems(new_name_ids):
self.name_ids[name_id] = ctxs + self.name_ids[name_id] self.name_ids[name_id] = ctxs + self.name_ids[name_id]
...@@ -250,7 +251,7 @@ def parse_cond_args(var_ids_dict, return_ids=None, ctx=gast.Load): ...@@ -250,7 +251,7 @@ def parse_cond_args(var_ids_dict, return_ids=None, ctx=gast.Load):
""" """
name_ids = [ name_ids = [
var_id for var_id, var_ctx in var_ids_dict.items() var_id for var_id, var_ctx in six.iteritems(var_ids_dict)
if isinstance(var_ctx[0], ctx) if isinstance(var_ctx[0], ctx)
] ]
if return_ids: if return_ids:
...@@ -341,7 +342,7 @@ def parse_cond_return(parent_vars_dict, if_vars_dict, else_vars_dict, ...@@ -341,7 +342,7 @@ def parse_cond_return(parent_vars_dict, if_vars_dict, else_vars_dict,
def _vars_with_store(ids_dict): def _vars_with_store(ids_dict):
vars = [] vars = []
for k, ctxs in ids_dict.items(): for k, ctxs in six.iteritems(ids_dict):
if _is_return_var(ctxs): if _is_return_var(ctxs):
vars.append(k) vars.append(k)
return vars return vars
...@@ -353,7 +354,7 @@ def parse_cond_return(parent_vars_dict, if_vars_dict, else_vars_dict, ...@@ -353,7 +354,7 @@ def parse_cond_return(parent_vars_dict, if_vars_dict, else_vars_dict,
def _vars_loaded_before_store(ids_dict): def _vars_loaded_before_store(ids_dict):
new_dict = defaultdict(list) new_dict = defaultdict(list)
for k, ctxs in ids_dict.items(): for k, ctxs in six.iteritems(ids_dict):
for ctx in ctxs: for ctx in ctxs:
if isinstance(ctx, gast.Load): if isinstance(ctx, gast.Load):
new_dict[k].append(ctx) new_dict[k].append(ctx)
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
from __future__ import print_function from __future__ import print_function
import numpy as np import numpy as np
import logging import logging
import six
from paddle.fluid import log_helper from paddle.fluid import log_helper
from paddle.fluid import framework, backward, core from paddle.fluid import framework, backward, core
...@@ -334,7 +335,7 @@ class PartialProgramLayer(layers.Layer): ...@@ -334,7 +335,7 @@ class PartialProgramLayer(layers.Layer):
param_and_buffer_names_set.add(var.name) param_and_buffer_names_set.add(var.name)
for block in main_program.blocks: for block in main_program.blocks:
for name, var in block.vars.items(): for name, var in six.iteritems(block.vars):
if isinstance(var, framework.Parameter): if isinstance(var, framework.Parameter):
if name not in param_and_buffer_names_set: if name not in param_and_buffer_names_set:
raise ValueError( raise ValueError(
......
...@@ -617,7 +617,7 @@ class ProgramCache(object): ...@@ -617,7 +617,7 @@ class ProgramCache(object):
return len(self._caches) return len(self._caches)
def concrete_programs(self): def concrete_programs(self):
return [cp for key, (cp, _) in self._caches.iteritems()] return [cp for key, (cp, _) in six.iteritems(self._caches)]
def synchronized(func): def synchronized(func):
......
...@@ -493,7 +493,7 @@ def recover_globals_attribute(src_obj, dst_obj): ...@@ -493,7 +493,7 @@ def recover_globals_attribute(src_obj, dst_obj):
src_globals = getattr(src_obj, attr_name, {}) src_globals = getattr(src_obj, attr_name, {})
dst_globals = getattr(dst_obj, attr_name, {}) dst_globals = getattr(dst_obj, attr_name, {})
for k, v in src_globals.items(): for k, v in six.iteritems(src_globals):
# ignore builtin attribute. # ignore builtin attribute.
if not (k.startswith('__') and k.endswith('__')): if not (k.startswith('__') and k.endswith('__')):
dst_globals[k] = v dst_globals[k] = v
......
...@@ -754,7 +754,7 @@ def save(layer, model_path, input_spec=None, configs=None): ...@@ -754,7 +754,7 @@ def save(layer, model_path, input_spec=None, configs=None):
# saved to inference program may not need by dygraph Layer, # saved to inference program may not need by dygraph Layer,
# we only record the state_dict variable's structured name # we only record the state_dict variable's structured name
state_names_dict = dict() state_names_dict = dict()
for structured_name, var in layer.state_dict().items(): for structured_name, var in six.iteritems(layer.state_dict()):
state_names_dict[var.name] = structured_name state_names_dict[var.name] = structured_name
# 3. share parameters from Layer to scope & record var info # 3. share parameters from Layer to scope & record var info
......
...@@ -191,6 +191,7 @@ class TestDifferentInputSpecCacheProgram(unittest.TestCase): ...@@ -191,6 +191,7 @@ class TestDifferentInputSpecCacheProgram(unittest.TestCase):
out_1 = foo(to_variable(x_data), to_variable(y_data)) out_1 = foo(to_variable(x_data), to_variable(y_data))
self.assertTrue(np.allclose(x_data + y_data, out_1.numpy())) self.assertTrue(np.allclose(x_data + y_data, out_1.numpy()))
self.assertTrue(len(foo.program_cache) == 1) self.assertTrue(len(foo.program_cache) == 1)
self.assertTrue(len(foo.program_cache.concrete_programs()) == 1)
# [16, 10] + [10] (numpy) # [16, 10] + [10] (numpy)
out_2 = foo(to_variable(x_data), y_data) out_2 = foo(to_variable(x_data), y_data)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册