提交 7f794ea5 编写于 作者: M minqiyang

Replace the overfix of 2to3 with six.string_types

上级 ce4eba3b
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import copy import copy
import six
import functools import functools
from . import layers from . import layers
...@@ -246,8 +247,8 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr): ...@@ -246,8 +247,8 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr):
""" """
def __init__(self, clip_norm, group_name="default_group"): def __init__(self, clip_norm, group_name="default_group"):
if not isinstance(group_name, str): if not isinstance(group_name, six.string_types):
raise TypeError("'group_name' must be a basestring.") raise TypeError("'group_name' must be a %s." % (six.string_types))
self.clip_norm = clip_norm self.clip_norm = clip_norm
self.group_name = group_name self.group_name = group_name
...@@ -312,7 +313,7 @@ def set_gradient_clip(clip, param_list=None, program=None): ...@@ -312,7 +313,7 @@ def set_gradient_clip(clip, param_list=None, program=None):
program = framework.default_main_program() program = framework.default_main_program()
if param_list is None: if param_list is None:
param_list = program.block(0).all_parameters() param_list = program.block(0).all_parameters()
if all(isinstance(elem, str) for elem in param_list): if all(isinstance(elem, six.string_types) for elem in param_list):
param_list = [program.block(0).var(elem) for elem in param_list] param_list = [program.block(0).var(elem) for elem in param_list]
if not all(isinstance(elem, framework.Parameter) for elem in param_list): if not all(isinstance(elem, framework.Parameter) for elem in param_list):
raise TypeError( raise TypeError(
......
...@@ -15,7 +15,8 @@ ...@@ -15,7 +15,8 @@
from . import core from . import core
import numpy import numpy
import os import os
import six.moves as six import six
from six.moves import zip, range, xrange
import multiprocessing import multiprocessing
from .framework import Variable, default_main_program from .framework import Variable, default_main_program
...@@ -52,7 +53,7 @@ class DataToLoDTensorConverter(object): ...@@ -52,7 +53,7 @@ class DataToLoDTensorConverter(object):
self.data = [] self.data = []
self.lod = [] self.lod = []
for i in six.range(lod_level): for i in six.moves.range(lod_level):
self.lod.append([]) self.lod.append([])
def feed(self, data): def feed(self, data):
...@@ -141,7 +142,7 @@ class DataFeeder(object): ...@@ -141,7 +142,7 @@ class DataFeeder(object):
if program is None: if program is None:
program = default_main_program() program = default_main_program()
for each_var in feed_list: for each_var in feed_list:
if isinstance(each_var, str): if isinstance(each_var, six.string_types):
each_var = program.block(0).var(each_var) each_var = program.block(0).var(each_var)
if not isinstance(each_var, Variable): if not isinstance(each_var, Variable):
raise TypeError("Feed list should contain a list of variable") raise TypeError("Feed list should contain a list of variable")
...@@ -173,7 +174,7 @@ class DataFeeder(object): ...@@ -173,7 +174,7 @@ class DataFeeder(object):
dict: the result of conversion. dict: the result of conversion.
""" """
converter = [] converter = []
for lod_level, shape, dtype in six.zip( for lod_level, shape, dtype in six.moves.zip(
self.feed_lod_level, self.feed_shapes, self.feed_dtypes): self.feed_lod_level, self.feed_shapes, self.feed_dtypes):
converter.append( converter.append(
DataToLoDTensorConverter( DataToLoDTensorConverter(
...@@ -186,10 +187,12 @@ class DataFeeder(object): ...@@ -186,10 +187,12 @@ class DataFeeder(object):
assert len(each_sample) == len(converter), ( assert len(each_sample) == len(converter), (
"The number of fields in data (%s) does not match " + "The number of fields in data (%s) does not match " +
"len(feed_list) (%s)") % (len(each_sample), len(converter)) "len(feed_list) (%s)") % (len(each_sample), len(converter))
for each_converter, each_slot in six.zip(converter, each_sample): for each_converter, each_slot in six.moves.zip(converter,
each_sample):
each_converter.feed(each_slot) each_converter.feed(each_slot)
ret_dict = {} ret_dict = {}
for each_name, each_converter in six.zip(self.feed_names, converter): for each_name, each_converter in six.moves.zip(self.feed_names,
converter):
ret_dict[each_name] = each_converter.done() ret_dict[each_name] = each_converter.done()
return ret_dict return ret_dict
...@@ -211,12 +214,14 @@ class DataFeeder(object): ...@@ -211,12 +214,14 @@ class DataFeeder(object):
if isinstance(self.place, core.CUDAPlace): if isinstance(self.place, core.CUDAPlace):
places = [ places = [
core.CUDAPlace(i) core.CUDAPlace(i)
for i in six.xrange(self._get_number_of_places_(num_places)) for i in six.moves.xrange(
self._get_number_of_places_(num_places))
] ]
else: else:
places = [ places = [
core.CPUPlace() core.CPUPlace()
for _ in six.xrange(self._get_number_of_places_(num_places)) for _ in six.moves.xrange(
self._get_number_of_places_(num_places))
] ]
if len(iterable) != len(places): if len(iterable) != len(places):
...@@ -226,7 +231,7 @@ class DataFeeder(object): ...@@ -226,7 +231,7 @@ class DataFeeder(object):
"must be same.") "must be same.")
place = self.place place = self.place
for p, batch in six.zip(places, iterable): for p, batch in six.moves.zip(places, iterable):
self.place = p self.place = p
yield self.feed(batch) yield self.feed(batch)
self.place = place self.place = place
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
import numpy as np import numpy as np
import contextlib import contextlib
import six
from .framework import Program, default_main_program, Variable from .framework import Program, default_main_program, Variable
from . import core from . import core
...@@ -211,7 +212,7 @@ def _get_program_cache_key(feed, fetch_list): ...@@ -211,7 +212,7 @@ def _get_program_cache_key(feed, fetch_list):
return var.desc.name() return var.desc.name()
elif isinstance(var, str): elif isinstance(var, str):
return var return var
elif isinstance(var, str): elif isinstance(var, six.string_types):
return str(var) return str(var)
else: else:
raise TypeError(str(var) + " should be Variable or str") raise TypeError(str(var) + " should be Variable or str")
...@@ -229,8 +230,8 @@ class Executor(object): ...@@ -229,8 +230,8 @@ class Executor(object):
to feed map and fetch_list. Feed map provides input data for the program. fetch_list provides to feed map and fetch_list. Feed map provides input data for the program. fetch_list provides
the variables(or names) that user want to get after program run. Note: the executor will run all the variables(or names) that user want to get after program run. Note: the executor will run all
operators in the program but not only the operators dependent by the fetch_list. operators in the program but not only the operators dependent by the fetch_list.
It store the global variables into the global scope, and create a local scope for the temporary It store the global variables into the global scope, and create a local scope for the temporary
variables. The local scope contents will be discarded after every minibatch forward/backward finished. variables. The local scope contents will be discarded after every minibatch forward/backward finished.
But the global scope variables will be persistent through different runs. But the global scope variables will be persistent through different runs.
All of ops in program will be running in sequence. All of ops in program will be running in sequence.
......
...@@ -524,12 +524,12 @@ class Operator(object): ...@@ -524,12 +524,12 @@ class Operator(object):
% (in_proto.name, len(in_args))) % (in_proto.name, len(in_args)))
in_arg_names = [] in_arg_names = []
for arg in in_args: for arg in in_args:
if issubclass(arg.__class__, six.string_types): if isinstance(arg, six.string_types):
in_arg_names.append(arg) in_arg_names.append(arg)
elif isinstance(arg, six.binary_type): elif isinstance(arg, six.binary_type):
in_arg_names.append(arg.decode()) in_arg_names.append(arg.decode())
else: else:
if issubclass(arg.name.__class__, six.string_types): if isinstance(arg.name, six.string_types):
in_arg_names.append(arg.name) in_arg_names.append(arg.name)
elif isinstance(arg.name, six.binary_type): elif isinstance(arg.name, six.binary_type):
in_arg_names.append(arg.name.decode()) in_arg_names.append(arg.name.decode())
...@@ -561,7 +561,7 @@ class Operator(object): ...@@ -561,7 +561,7 @@ class Operator(object):
(out_proto.name, len(out_args))) (out_proto.name, len(out_args)))
out_arg_names = [] out_arg_names = []
for arg in out_args: for arg in out_args:
if issubclass(arg.name.__class__, six.string_types): if isinstance(arg.name, six.string_types):
out_arg_names.append(arg.name) out_arg_names.append(arg.name)
elif isinstance(arg.name, six.binary_type): elif isinstance(arg.name, six.binary_type):
out_arg_names.append(arg.name.decode()) out_arg_names.append(arg.name.decode())
...@@ -911,7 +911,7 @@ class Block(object): ...@@ -911,7 +911,7 @@ class Block(object):
Returns: Returns:
Variable: the Variable with the giving name. Variable: the Variable with the giving name.
""" """
if not issubclass(name.__class__, six.string_types): if not isinstance(name, six.string_types):
if not isinstance(name, six.binary_type): if not isinstance(name, six.binary_type):
raise TypeError( raise TypeError(
"var require string as parameter, but get %s instead." % "var require string as parameter, but get %s instead." %
......
...@@ -14,12 +14,13 @@ ...@@ -14,12 +14,13 @@
import os import os
import random import random
import six
import subprocess import subprocess
import logging import logging
def crepr(v): def crepr(v):
if type(v) is str or type(v) is str: if isinstance(v, six.string_types):
return '"%s"' % v return '"%s"' % v
return str(v) return str(v)
......
...@@ -612,9 +612,6 @@ def save_inference_model(dirname, ...@@ -612,9 +612,6 @@ def save_inference_model(dirname,
if not (all( if not (all(
isinstance(name, six.text_type) isinstance(name, six.text_type)
for name in feeded_var_names)): for name in feeded_var_names)):
import sys
print([type(name) for name in feeded_var_names])
sys.stdout.flush()
raise ValueError( raise ValueError(
"'feed_var_names' should be a list of str.") "'feed_var_names' should be a list of str.")
else: else:
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
import copy import copy
import itertools import itertools
import six
from .framework import Variable, Parameter, default_main_program, default_startup_program, dtype_is_floating from .framework import Variable, Parameter, default_main_program, default_startup_program, dtype_is_floating
from . import unique_name from . import unique_name
...@@ -398,7 +399,7 @@ class LayerHelper(object): ...@@ -398,7 +399,7 @@ class LayerHelper(object):
act = self.kwargs.get('act', None) act = self.kwargs.get('act', None)
if act is None: if act is None:
return input_var return input_var
if isinstance(act, str): if isinstance(act, six.string_types):
act = {'type': act} act = {'type': act}
if 'use_cudnn' in self.kwargs and self.kwargs.get('use_cudnn'): if 'use_cudnn' in self.kwargs and self.kwargs.get('use_cudnn'):
......
...@@ -32,7 +32,7 @@ def get_all_op_protos(): ...@@ -32,7 +32,7 @@ def get_all_op_protos():
def is_str(s): def is_str(s):
return isinstance(s, str) or isinstance(s, str) return isinstance(s, six.string_types)
class OpDescCreationMethod(object): class OpDescCreationMethod(object):
......
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import six
from .initializer import Initializer, Xavier, Constant from .initializer import Initializer, Xavier, Constant
from .regularizer import WeightDecayRegularizer from .regularizer import WeightDecayRegularizer
...@@ -134,7 +136,7 @@ class ParamAttr(object): ...@@ -134,7 +136,7 @@ class ParamAttr(object):
return [ParamAttr._to_attr(a) for a in arg] return [ParamAttr._to_attr(a) for a in arg]
elif isinstance(arg, ParamAttr): elif isinstance(arg, ParamAttr):
return arg return arg
elif isinstance(arg, str) or isinstance(arg, str): elif isinstance(arg, six.string_types):
return ParamAttr(name=arg) return ParamAttr(name=arg)
elif isinstance(arg, Initializer): elif isinstance(arg, Initializer):
return ParamAttr(initializer=arg) return ParamAttr(initializer=arg)
......
...@@ -16,6 +16,7 @@ import numpy as np ...@@ -16,6 +16,7 @@ import numpy as np
import unittest import unittest
import time import time
import itertools import itertools
import six
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.core as core import paddle.fluid.core as core
...@@ -40,7 +41,8 @@ class BenchmarkSuite(OpTest): ...@@ -40,7 +41,8 @@ class BenchmarkSuite(OpTest):
expect_t = np.array(item_cpu_out) expect_t = np.array(item_cpu_out)
actual = item_gpu_out actual = item_gpu_out
actual_t = np.array(item_gpu_out) actual_t = np.array(item_gpu_out)
var_name = variable if isinstance(variable, str) else variable.name var_name = variable if isinstance(
variable, six.string_types) else variable.name
self.assertTrue( self.assertTrue(
np.allclose( np.allclose(
actual_t, expect_t, atol=atol), actual_t, expect_t, atol=atol),
......
...@@ -18,6 +18,7 @@ import paddle.fluid as fluid ...@@ -18,6 +18,7 @@ import paddle.fluid as fluid
from paddle.fluid.layers.device import get_places from paddle.fluid.layers.device import get_places
import paddle.fluid.profiler as profiler import paddle.fluid.profiler as profiler
import numpy import numpy
import six
class BaseParallelForTest(unittest.TestCase): class BaseParallelForTest(unittest.TestCase):
...@@ -25,20 +26,20 @@ class BaseParallelForTest(unittest.TestCase): ...@@ -25,20 +26,20 @@ class BaseParallelForTest(unittest.TestCase):
""" """
Run the unittest for parallel.for Run the unittest for parallel.for
Args: Args:
callback(callable): A callable function returns a generator. There callback(callable): A callable function returns a generator. There
are two yields in the generator function. The first yield are two yields in the generator function. The first yield
returns the data layers, and the second yield returns the loss. returns the data layers, and the second yield returns the loss.
The modified data variables will be sent back during the first The modified data variables will be sent back during the first
yield. yield.
feed(dict): The executor feeding dictionary. feed(dict): The executor feeding dictionary.
fetch(list|basestr): The fetch name lists. fetch(list|basestr): The fetch name lists.
Returns: Returns:
None None
Raises: Raises:
AssertionError when the computation of cpu, parallel.for in cpu, AssertionError when the computation of cpu, parallel.for in cpu,
gpu, parallel.for in gpu are different. gpu, parallel.for in gpu are different.
""" """
...@@ -95,14 +96,14 @@ class BaseParallelForTest(unittest.TestCase): ...@@ -95,14 +96,14 @@ class BaseParallelForTest(unittest.TestCase):
""" """
Run a single test, returns the fetch values Run a single test, returns the fetch values
Args: Args:
place(Place): the computation place. place(Place): the computation place.
use_parallel(bool): Whether use parallel.for or not. use_parallel(bool): Whether use parallel.for or not.
Returns: Returns:
Fetched numpy arrays. Fetched numpy arrays.
""" """
if isinstance(fetch, str): if isinstance(fetch, six.string_types):
fetch = [fetch] fetch = [fetch]
main = fluid.Program() main = fluid.Program()
startup = fluid.Program() startup = fluid.Program()
...@@ -156,7 +157,7 @@ class BaseParallelForTest(unittest.TestCase): ...@@ -156,7 +157,7 @@ class BaseParallelForTest(unittest.TestCase):
Returns: Returns:
None None
Raises: Raises:
AssertionError AssertionError
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
import collections import collections
import contextlib import contextlib
import six
import sys import sys
__all__ = ['generate', 'switch', 'guard'] __all__ = ['generate', 'switch', 'guard']
...@@ -67,7 +68,7 @@ def switch(new_generator=None): ...@@ -67,7 +68,7 @@ def switch(new_generator=None):
@contextlib.contextmanager @contextlib.contextmanager
def guard(new_generator=None): def guard(new_generator=None):
if isinstance(new_generator, str): if isinstance(new_generator, six.string_types):
new_generator = UniqueNameGenerator(new_generator) new_generator = UniqueNameGenerator(new_generator)
old = switch(new_generator) old = switch(new_generator)
yield yield
......
...@@ -67,10 +67,11 @@ def recordio(paths, buf_size=100): ...@@ -67,10 +67,11 @@ def recordio(paths, buf_size=100):
import recordio as rec import recordio as rec
import paddle.reader.decorator as dec import paddle.reader.decorator as dec
import six
import six.moves.cPickle as pickle import six.moves.cPickle as pickle
def reader(): def reader():
if isinstance(paths, str): if isinstance(paths, six.string_types):
path = paths path = paths
else: else:
path = ",".join(paths) path = ",".join(paths)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册