提交 7f794ea5 编写于 作者: M minqiyang

Replace the overfix of 2to3 with six.string_types

上级 ce4eba3b
......@@ -13,6 +13,7 @@
# limitations under the License.
import copy
import six
import functools
from . import layers
......@@ -246,8 +247,8 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr):
"""
def __init__(self, clip_norm, group_name="default_group"):
if not isinstance(group_name, str):
raise TypeError("'group_name' must be a basestring.")
if not isinstance(group_name, six.string_types):
raise TypeError("'group_name' must be a %s." % (six.string_types))
self.clip_norm = clip_norm
self.group_name = group_name
......@@ -312,7 +313,7 @@ def set_gradient_clip(clip, param_list=None, program=None):
program = framework.default_main_program()
if param_list is None:
param_list = program.block(0).all_parameters()
if all(isinstance(elem, str) for elem in param_list):
if all(isinstance(elem, six.string_types) for elem in param_list):
param_list = [program.block(0).var(elem) for elem in param_list]
if not all(isinstance(elem, framework.Parameter) for elem in param_list):
raise TypeError(
......
......@@ -15,7 +15,8 @@
from . import core
import numpy
import os
import six.moves as six
import six
from six.moves import zip, range, xrange
import multiprocessing
from .framework import Variable, default_main_program
......@@ -52,7 +53,7 @@ class DataToLoDTensorConverter(object):
self.data = []
self.lod = []
for i in six.range(lod_level):
for i in six.moves.range(lod_level):
self.lod.append([])
def feed(self, data):
......@@ -141,7 +142,7 @@ class DataFeeder(object):
if program is None:
program = default_main_program()
for each_var in feed_list:
if isinstance(each_var, str):
if isinstance(each_var, six.string_types):
each_var = program.block(0).var(each_var)
if not isinstance(each_var, Variable):
raise TypeError("Feed list should contain a list of variable")
......@@ -173,7 +174,7 @@ class DataFeeder(object):
dict: the result of conversion.
"""
converter = []
for lod_level, shape, dtype in six.zip(
for lod_level, shape, dtype in six.moves.zip(
self.feed_lod_level, self.feed_shapes, self.feed_dtypes):
converter.append(
DataToLoDTensorConverter(
......@@ -186,10 +187,12 @@ class DataFeeder(object):
assert len(each_sample) == len(converter), (
"The number of fields in data (%s) does not match " +
"len(feed_list) (%s)") % (len(each_sample), len(converter))
for each_converter, each_slot in six.zip(converter, each_sample):
for each_converter, each_slot in six.moves.zip(converter,
each_sample):
each_converter.feed(each_slot)
ret_dict = {}
for each_name, each_converter in six.zip(self.feed_names, converter):
for each_name, each_converter in six.moves.zip(self.feed_names,
converter):
ret_dict[each_name] = each_converter.done()
return ret_dict
......@@ -211,12 +214,14 @@ class DataFeeder(object):
if isinstance(self.place, core.CUDAPlace):
places = [
core.CUDAPlace(i)
for i in six.xrange(self._get_number_of_places_(num_places))
for i in six.moves.xrange(
self._get_number_of_places_(num_places))
]
else:
places = [
core.CPUPlace()
for _ in six.xrange(self._get_number_of_places_(num_places))
for _ in six.moves.xrange(
self._get_number_of_places_(num_places))
]
if len(iterable) != len(places):
......@@ -226,7 +231,7 @@ class DataFeeder(object):
"must be same.")
place = self.place
for p, batch in six.zip(places, iterable):
for p, batch in six.moves.zip(places, iterable):
self.place = p
yield self.feed(batch)
self.place = place
......
......@@ -14,6 +14,7 @@
import numpy as np
import contextlib
import six
from .framework import Program, default_main_program, Variable
from . import core
......@@ -211,7 +212,7 @@ def _get_program_cache_key(feed, fetch_list):
return var.desc.name()
elif isinstance(var, str):
return var
elif isinstance(var, str):
elif isinstance(var, six.string_types):
return str(var)
else:
raise TypeError(str(var) + " should be Variable or str")
......
......@@ -524,12 +524,12 @@ class Operator(object):
% (in_proto.name, len(in_args)))
in_arg_names = []
for arg in in_args:
if issubclass(arg.__class__, six.string_types):
if isinstance(arg, six.string_types):
in_arg_names.append(arg)
elif isinstance(arg, six.binary_type):
in_arg_names.append(arg.decode())
else:
if issubclass(arg.name.__class__, six.string_types):
if isinstance(arg.name, six.string_types):
in_arg_names.append(arg.name)
elif isinstance(arg.name, six.binary_type):
in_arg_names.append(arg.name.decode())
......@@ -561,7 +561,7 @@ class Operator(object):
(out_proto.name, len(out_args)))
out_arg_names = []
for arg in out_args:
if issubclass(arg.name.__class__, six.string_types):
if isinstance(arg.name, six.string_types):
out_arg_names.append(arg.name)
elif isinstance(arg.name, six.binary_type):
out_arg_names.append(arg.name.decode())
......@@ -911,7 +911,7 @@ class Block(object):
Returns:
Variable: the Variable with the giving name.
"""
if not issubclass(name.__class__, six.string_types):
if not isinstance(name, six.string_types):
if not isinstance(name, six.binary_type):
raise TypeError(
"var require string as parameter, but get %s instead." %
......
......@@ -14,12 +14,13 @@
import os
import random
import six
import subprocess
import logging
def crepr(v):
if type(v) is str or type(v) is str:
if isinstance(v, six.string_types):
return '"%s"' % v
return str(v)
......
......@@ -612,9 +612,6 @@ def save_inference_model(dirname,
if not (all(
isinstance(name, six.text_type)
for name in feeded_var_names)):
import sys
print([type(name) for name in feeded_var_names])
sys.stdout.flush()
raise ValueError(
"'feed_var_names' should be a list of str.")
else:
......
......@@ -14,6 +14,7 @@
import copy
import itertools
import six
from .framework import Variable, Parameter, default_main_program, default_startup_program, dtype_is_floating
from . import unique_name
......@@ -398,7 +399,7 @@ class LayerHelper(object):
act = self.kwargs.get('act', None)
if act is None:
return input_var
if isinstance(act, str):
if isinstance(act, six.string_types):
act = {'type': act}
if 'use_cudnn' in self.kwargs and self.kwargs.get('use_cudnn'):
......
......@@ -32,7 +32,7 @@ def get_all_op_protos():
def is_str(s):
return isinstance(s, str) or isinstance(s, str)
return isinstance(s, six.string_types)
class OpDescCreationMethod(object):
......
......@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import six
from .initializer import Initializer, Xavier, Constant
from .regularizer import WeightDecayRegularizer
......@@ -134,7 +136,7 @@ class ParamAttr(object):
return [ParamAttr._to_attr(a) for a in arg]
elif isinstance(arg, ParamAttr):
return arg
elif isinstance(arg, str) or isinstance(arg, str):
elif isinstance(arg, six.string_types):
return ParamAttr(name=arg)
elif isinstance(arg, Initializer):
return ParamAttr(initializer=arg)
......
......@@ -16,6 +16,7 @@ import numpy as np
import unittest
import time
import itertools
import six
import paddle.fluid as fluid
import paddle.fluid.core as core
......@@ -40,7 +41,8 @@ class BenchmarkSuite(OpTest):
expect_t = np.array(item_cpu_out)
actual = item_gpu_out
actual_t = np.array(item_gpu_out)
var_name = variable if isinstance(variable, str) else variable.name
var_name = variable if isinstance(
variable, six.string_types) else variable.name
self.assertTrue(
np.allclose(
actual_t, expect_t, atol=atol),
......
......@@ -18,6 +18,7 @@ import paddle.fluid as fluid
from paddle.fluid.layers.device import get_places
import paddle.fluid.profiler as profiler
import numpy
import six
class BaseParallelForTest(unittest.TestCase):
......@@ -102,7 +103,7 @@ class BaseParallelForTest(unittest.TestCase):
Fetched numpy arrays.
"""
if isinstance(fetch, str):
if isinstance(fetch, six.string_types):
fetch = [fetch]
main = fluid.Program()
startup = fluid.Program()
......
......@@ -14,6 +14,7 @@
import collections
import contextlib
import six
import sys
__all__ = ['generate', 'switch', 'guard']
......@@ -67,7 +68,7 @@ def switch(new_generator=None):
@contextlib.contextmanager
def guard(new_generator=None):
if isinstance(new_generator, str):
if isinstance(new_generator, six.string_types):
new_generator = UniqueNameGenerator(new_generator)
old = switch(new_generator)
yield
......
......@@ -67,10 +67,11 @@ def recordio(paths, buf_size=100):
import recordio as rec
import paddle.reader.decorator as dec
import six
import six.moves.cPickle as pickle
def reader():
if isinstance(paths, str):
if isinstance(paths, six.string_types):
path = paths
else:
path = ",".join(paths)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册