未验证 提交 0f7187af 编写于 作者: T tianshuo78520a 提交者: GitHub

Del six.PY code2 (#33607)

* del py2 code2

* fix test timeout
上级 79cbc8ea
...@@ -17,12 +17,8 @@ import math ...@@ -17,12 +17,8 @@ import math
__all__ = [] __all__ = []
if six.PY2: int_type = int
int_type = int long_type = int
long_type = long # noqa: F821
else:
int_type = int
long_type = int
# str and bytes related functions # str and bytes related functions
...@@ -262,7 +258,4 @@ def get_exception_message(exc): ...@@ -262,7 +258,4 @@ def get_exception_message(exc):
""" """
assert exc is not None assert exc is not None
if six.PY2: return str(exc)
return exc.message
else:
return str(exc)
...@@ -62,11 +62,7 @@ def reader_creator(filename, sub_name, cycle=False): ...@@ -62,11 +62,7 @@ def reader_creator(filename, sub_name, cycle=False):
if sub_name in each_item.name) if sub_name in each_item.name)
for name in names: for name in names:
if six.PY2: batch = pickle.load(f.extractfile(name), encoding='bytes')
batch = pickle.load(f.extractfile(name))
else:
batch = pickle.load(
f.extractfile(name), encoding='bytes')
for item in read_batch(batch): for item in read_batch(batch):
yield item yield item
......
...@@ -101,8 +101,6 @@ def download(url, module_name, md5sum, save_name=None): ...@@ -101,8 +101,6 @@ def download(url, module_name, md5sum, save_name=None):
bar = paddle.hapi.progressbar.ProgressBar( bar = paddle.hapi.progressbar.ProgressBar(
total_iter, name='item') total_iter, name='item')
for data in r.iter_content(chunk_size=chunk_size): for data in r.iter_content(chunk_size=chunk_size):
if six.PY2:
data = six.b(data)
f.write(data) f.write(data)
log_index += 1 log_index += 1
bar.update(log_index, {}) bar.update(log_index, {})
......
...@@ -132,10 +132,7 @@ def reader_creator(data_file, ...@@ -132,10 +132,7 @@ def reader_creator(data_file,
file = file.strip() file = file.strip()
batch = None batch = None
with open(file, 'rb') as f: with open(file, 'rb') as f:
if six.PY2: batch = pickle.load(f, encoding='bytes')
batch = pickle.load(f)
else:
batch = pickle.load(f, encoding='bytes')
if six.PY3: if six.PY3:
batch = cpt.to_text(batch) batch = cpt.to_text(batch)
......
...@@ -17,12 +17,8 @@ import logging ...@@ -17,12 +17,8 @@ import logging
import six import six
# NOTE: HTTPServer has a different name in python2 and python3 # NOTE: HTTPServer has a different name in python2 and python3
if six.PY2: from http.server import HTTPServer
from BaseHTTPServer import HTTPServer import http.server as SimpleHTTPServer
import SimpleHTTPServer
else:
from http.server import HTTPServer
import http.server as SimpleHTTPServer
import time import time
import threading import threading
......
...@@ -226,10 +226,7 @@ class Quant2Int8ImageClassificationComparisonTest(unittest.TestCase): ...@@ -226,10 +226,7 @@ class Quant2Int8ImageClassificationComparisonTest(unittest.TestCase):
if iters == skip_batch_num: if iters == skip_batch_num:
total_samples = 0 total_samples = 0
infer_start_time = time.time() infer_start_time = time.time()
if six.PY2: images = list(map(lambda x: x[0].reshape(dshape), data))
images = map(lambda x: x[0].reshape(dshape), data)
if six.PY3:
images = list(map(lambda x: x[0].reshape(dshape), data))
images = np.array(images).astype('float32') images = np.array(images).astype('float32')
labels = np.array([x[1] for x in data]).astype('int64') labels = np.array([x[1] for x in data]).astype('int64')
......
...@@ -196,10 +196,7 @@ class QuantInt8ImageClassificationComparisonTest(unittest.TestCase): ...@@ -196,10 +196,7 @@ class QuantInt8ImageClassificationComparisonTest(unittest.TestCase):
if iters == skip_batch_num: if iters == skip_batch_num:
total_samples = 0 total_samples = 0
infer_start_time = time.time() infer_start_time = time.time()
if six.PY2: images = list(map(lambda x: x[0].reshape(dshape), data))
images = map(lambda x: x[0].reshape(dshape), data)
if six.PY3:
images = list(map(lambda x: x[0].reshape(dshape), data))
images = np.array(images).astype('float32') images = np.array(images).astype('float32')
labels = np.array([x[1] for x in data]).astype('int64') labels = np.array([x[1] for x in data]).astype('int64')
......
...@@ -27,10 +27,7 @@ from collections import namedtuple ...@@ -27,10 +27,7 @@ from collections import namedtuple
from paddle.fluid.framework import _set_expected_place, _current_expected_place from paddle.fluid.framework import _set_expected_place, _current_expected_place
# NOTE: queue has a different name in python2 and python3 # NOTE: queue has a different name in python2 and python3
if six.PY2: import queue
import Queue as queue
else:
import queue
import paddle import paddle
from .. import core, layers from .. import core, layers
......
...@@ -26,10 +26,7 @@ from ..framework import in_dygraph_mode ...@@ -26,10 +26,7 @@ from ..framework import in_dygraph_mode
from .flat import _flatten_batch from .flat import _flatten_batch
# NOTE: queue has a different name in python2 and python3 # NOTE: queue has a different name in python2 and python3
if six.PY2: import queue
import Queue as queue
else:
import queue
__all__ = ['get_worker_info'] __all__ = ['get_worker_info']
......
...@@ -19,7 +19,6 @@ import collections ...@@ -19,7 +19,6 @@ import collections
import functools import functools
from ..framework import Variable, default_main_program, in_dygraph_mode, dygraph_only, Parameter, ParamBase, _varbase_creator, _dygraph_tracer from ..framework import Variable, default_main_program, in_dygraph_mode, dygraph_only, Parameter, ParamBase, _varbase_creator, _dygraph_tracer
import pickle import pickle
import six
from . import learning_rate_scheduler from . import learning_rate_scheduler
import warnings import warnings
from .. import core from .. import core
...@@ -194,16 +193,14 @@ def load_dygraph(model_path, **configs): ...@@ -194,16 +193,14 @@ def load_dygraph(model_path, **configs):
para_dict = {} para_dict = {}
if os.path.exists(params_file_path): if os.path.exists(params_file_path):
with open(params_file_path, 'rb') as f: with open(params_file_path, 'rb') as f:
para_dict = pickle.load(f) if six.PY2 else pickle.load( para_dict = pickle.load(f, encoding='latin1')
f, encoding='latin1')
if not config.keep_name_table and "StructuredToParameterName@@" in para_dict: if not config.keep_name_table and "StructuredToParameterName@@" in para_dict:
del para_dict["StructuredToParameterName@@"] del para_dict["StructuredToParameterName@@"]
if os.path.exists(opti_file_path): if os.path.exists(opti_file_path):
with open(opti_file_path, 'rb') as f: with open(opti_file_path, 'rb') as f:
opti_dict = pickle.load(f) if six.PY2 else pickle.load( opti_dict = pickle.load(f, encoding='latin1')
f, encoding='latin1')
else: else:
# check model path # check model path
if not os.path.isdir(model_prefix): if not os.path.isdir(model_prefix):
......
...@@ -60,10 +60,7 @@ class BaseNodeVisitor(gast.NodeVisitor): ...@@ -60,10 +60,7 @@ class BaseNodeVisitor(gast.NodeVisitor):
# imp is deprecated in python3 # imp is deprecated in python3
if six.PY2: from importlib.machinery import SourceFileLoader
import imp
else:
from importlib.machinery import SourceFileLoader
dygraph_class_to_static_api = { dygraph_class_to_static_api = {
"CosineDecay": "cosine_decay", "CosineDecay": "cosine_decay",
...@@ -491,12 +488,8 @@ def ast_to_func(ast_root, dyfunc, delete_on_exit=True): ...@@ -491,12 +488,8 @@ def ast_to_func(ast_root, dyfunc, delete_on_exit=True):
import_fluid = "import paddle\nimport paddle.fluid as fluid\n" import_fluid = "import paddle\nimport paddle.fluid as fluid\n"
source = import_fluid + source source = import_fluid + source
if six.PY2: f = tempfile.NamedTemporaryFile(
source = source.encode('utf-8') mode='w', suffix='.py', delete=False, encoding='utf-8')
f = tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False)
else:
f = tempfile.NamedTemporaryFile(
mode='w', suffix='.py', delete=False, encoding='utf-8')
with f: with f:
module_name = os.path.basename(f.name[:-3]) module_name = os.path.basename(f.name[:-3])
f.write(source) f.write(source)
...@@ -505,10 +498,7 @@ def ast_to_func(ast_root, dyfunc, delete_on_exit=True): ...@@ -505,10 +498,7 @@ def ast_to_func(ast_root, dyfunc, delete_on_exit=True):
atexit.register(lambda: remove_if_exit(f.name)) atexit.register(lambda: remove_if_exit(f.name))
atexit.register(lambda: remove_if_exit(f.name[:-3] + ".pyc")) atexit.register(lambda: remove_if_exit(f.name[:-3] + ".pyc"))
if six.PY2: module = SourceFileLoader(module_name, f.name).load_module()
module = imp.load_source(module_name, f.name)
else:
module = SourceFileLoader(module_name, f.name).load_module()
func_name = dyfunc.__name__ func_name = dyfunc.__name__
# The 'forward' or 'another_forward' of 'TranslatedLayer' cannot be obtained # The 'forward' or 'another_forward' of 'TranslatedLayer' cannot be obtained
# through 'func_name'. So set the special function name '__i_m_p_l__'. # through 'func_name'. So set the special function name '__i_m_p_l__'.
......
...@@ -98,17 +98,9 @@ def create_fill_constant_node(name, value): ...@@ -98,17 +98,9 @@ def create_fill_constant_node(name, value):
func_code += "dtype='float64', value={})".format(value) func_code += "dtype='float64', value={})".format(value)
return gast.parse(func_code).body[0] return gast.parse(func_code).body[0]
if six.PY2: if isinstance(value, int):
if isinstance(value, int): func_code += "dtype='int64', value={})".format(value)
func_code += "dtype='int32', value={})".format(value) return gast.parse(func_code).body[0]
return gast.parse(func_code).body[0]
if isinstance(value, long):
func_code += "dtype='int64', value={})".format(value)
return gast.parse(func_code).body[0]
else:
if isinstance(value, int):
func_code += "dtype='int64', value={})".format(value)
return gast.parse(func_code).body[0]
def to_static_variable(x): def to_static_variable(x):
......
...@@ -20,7 +20,6 @@ from ..layers.layer_function_generator import OpProtoHolder ...@@ -20,7 +20,6 @@ from ..layers.layer_function_generator import OpProtoHolder
from . import no_grad from . import no_grad
import numpy as np import numpy as np
import six
import warnings import warnings
_supported_int_dtype_ = [ _supported_int_dtype_ = [
...@@ -121,10 +120,7 @@ def monkey_patch_math_varbase(): ...@@ -121,10 +120,7 @@ def monkey_patch_math_varbase():
assert numel == 1, "only one element variable can be converted to long." assert numel == 1, "only one element variable can be converted to long."
tensor = var.value().get_tensor() tensor = var.value().get_tensor()
assert tensor._is_initialized(), "variable's tensor is not initialized" assert tensor._is_initialized(), "variable's tensor is not initialized"
if six.PY2: return int(var.numpy().flatten()[0])
return long(var.numpy().flatten()[0])
else:
return int(var.numpy().flatten()[0])
def _int_(var): def _int_(var):
numel = np.prod(var.shape) numel = np.prod(var.shape)
...@@ -141,10 +137,7 @@ def monkey_patch_math_varbase(): ...@@ -141,10 +137,7 @@ def monkey_patch_math_varbase():
assert numel == 1, "only one element variable can be converted to python index." assert numel == 1, "only one element variable can be converted to python index."
tensor = var.value().get_tensor() tensor = var.value().get_tensor()
assert tensor._is_initialized(), "variable's tensor is not initialized" assert tensor._is_initialized(), "variable's tensor is not initialized"
if six.PY2: return int(var.numpy().flatten()[0])
return long(var.numpy().flatten()[0])
else:
return int(var.numpy().flatten()[0])
@property @property
def _ndim_(var): def _ndim_(var):
......
...@@ -1940,8 +1940,7 @@ def _pickle_loads_mac(path, f): ...@@ -1940,8 +1940,7 @@ def _pickle_loads_mac(path, f):
max_bytes = 2**30 max_bytes = 2**30
for _ in range(0, file_size, max_bytes): for _ in range(0, file_size, max_bytes):
pickle_bytes += f.read(max_bytes) pickle_bytes += f.read(max_bytes)
load_result = pickle.loads(pickle_bytes) if six.PY2 else pickle.loads( load_result = pickle.loads(pickle_bytes, encoding='latin1')
pickle_bytes, encoding='latin1')
return load_result return load_result
...@@ -2113,8 +2112,7 @@ def load(program, model_path, executor=None, var_list=None): ...@@ -2113,8 +2112,7 @@ def load(program, model_path, executor=None, var_list=None):
if sys.platform == 'darwin' and sys.version_info.major == 3: if sys.platform == 'darwin' and sys.version_info.major == 3:
load_dict = _pickle_loads_mac(parameter_file_name, f) load_dict = _pickle_loads_mac(parameter_file_name, f)
else: else:
load_dict = pickle.load(f) if six.PY2 else pickle.load( load_dict = pickle.load(f, encoding='latin1')
f, encoding='latin1')
load_dict = _pack_loaded_dict(load_dict) load_dict = _pack_loaded_dict(load_dict)
for v in parameter_list: for v in parameter_list:
assert v.name in load_dict, \ assert v.name in load_dict, \
...@@ -2135,8 +2133,7 @@ def load(program, model_path, executor=None, var_list=None): ...@@ -2135,8 +2133,7 @@ def load(program, model_path, executor=None, var_list=None):
optimizer_var_list, global_scope(), executor._default_executor) optimizer_var_list, global_scope(), executor._default_executor)
with open(opt_file_name, 'rb') as f: with open(opt_file_name, 'rb') as f:
load_dict = pickle.load(f) if six.PY2 else pickle.load( load_dict = pickle.load(f, encoding='latin1')
f, encoding='latin1')
for v in optimizer_var_list: for v in optimizer_var_list:
assert v.name in load_dict, \ assert v.name in load_dict, \
"Can not find [{}] in model file [{}]".format( "Can not find [{}] in model file [{}]".format(
...@@ -2297,15 +2294,13 @@ def load_program_state(model_path, var_list=None): ...@@ -2297,15 +2294,13 @@ def load_program_state(model_path, var_list=None):
if sys.platform == 'darwin' and sys.version_info.major == 3: if sys.platform == 'darwin' and sys.version_info.major == 3:
para_dict = _pickle_loads_mac(parameter_file_name, f) para_dict = _pickle_loads_mac(parameter_file_name, f)
else: else:
para_dict = pickle.load(f) if six.PY2 else pickle.load( para_dict = pickle.load(f, encoding='latin1')
f, encoding='latin1')
para_dict = _pack_loaded_dict(para_dict) para_dict = _pack_loaded_dict(para_dict)
opt_file_name = model_prefix + ".pdopt" opt_file_name = model_prefix + ".pdopt"
if os.path.exists(opt_file_name): if os.path.exists(opt_file_name):
with open(opt_file_name, 'rb') as f: with open(opt_file_name, 'rb') as f:
opti_dict = pickle.load(f) if six.PY2 else pickle.load( opti_dict = pickle.load(f, encoding='latin1')
f, encoding='latin1')
para_dict.update(opti_dict) para_dict.update(opti_dict)
......
...@@ -16,9 +16,7 @@ from __future__ import print_function ...@@ -16,9 +16,7 @@ from __future__ import print_function
import math import math
import numpy import numpy
import six
import warnings import warnings
from six.moves import reduce
from ..layer_helper import LayerHelper from ..layer_helper import LayerHelper
from ..param_attr import ParamAttr from ..param_attr import ParamAttr
...@@ -134,14 +132,9 @@ def create_parameter(shape, ...@@ -134,14 +132,9 @@ def create_parameter(shape,
""" """
check_type(shape, 'shape', (list, tuple, numpy.ndarray), 'create_parameter') check_type(shape, 'shape', (list, tuple, numpy.ndarray), 'create_parameter')
for item in shape: for item in shape:
if six.PY2: check_type(item, 'item of shape',
check_type(item, 'item of shape', (int, numpy.uint8, numpy.int8, numpy.int16, numpy.int32,
(int, long, numpy.uint8, numpy.int8, numpy.int16, numpy.int64), 'create_parameter')
numpy.int32, numpy.int64), 'create_parameter')
else:
check_type(item, 'item of shape',
(int, numpy.uint8, numpy.int8, numpy.int16, numpy.int32,
numpy.int64), 'create_parameter')
check_dtype(dtype, 'dtype', [ check_dtype(dtype, 'dtype', [
'bool', 'float16', 'float32', 'float64', 'int8', 'int16', 'int32', 'bool', 'float16', 'float32', 'float64', 'int8', 'int16', 'int32',
...@@ -194,14 +187,9 @@ def create_global_var(shape, ...@@ -194,14 +187,9 @@ def create_global_var(shape,
check_type(shape, 'shape', (list, tuple, numpy.ndarray), check_type(shape, 'shape', (list, tuple, numpy.ndarray),
'create_global_var') 'create_global_var')
for item in shape: for item in shape:
if six.PY2: check_type(item, 'item of shape',
check_type(item, 'item of shape', (int, numpy.uint8, numpy.int8, numpy.int16, numpy.int32,
(int, long, numpy.uint8, numpy.int8, numpy.int16, numpy.int64), 'create_global_var')
numpy.int32, numpy.int64), 'create_global_var')
else:
check_type(item, 'item of shape',
(int, numpy.uint8, numpy.int8, numpy.int16, numpy.int32,
numpy.int64), 'create_global_var')
check_dtype(dtype, 'dtype', [ check_dtype(dtype, 'dtype', [
'bool', 'float16', 'float32', 'float64', 'int8', 'int16', 'int32', 'bool', 'float16', 'float32', 'float64', 'int8', 'int16', 'int32',
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import six
import sys import sys
import signal import signal
import atexit import atexit
...@@ -20,10 +19,7 @@ import atexit ...@@ -20,10 +19,7 @@ import atexit
from . import core from . import core
# NOTE: queue has a different name in python2 and python3 # NOTE: queue has a different name in python2 and python3
if six.PY2: import queue
import Queue as queue
else:
import queue
# multi-process worker check indices queue interval, avoid # multi-process worker check indices queue interval, avoid
# hanging in subprocess data loading # hanging in subprocess data loading
......
...@@ -38,10 +38,7 @@ import multiprocessing ...@@ -38,10 +38,7 @@ import multiprocessing
import signal import signal
# NOTE: queue has a different name in python2 and python3 # NOTE: queue has a different name in python2 and python3
if six.PY2: import queue
import Queue as queue
else:
import queue
# NOTE: [ avoid hanging & failed quickly ] These value is used in getting data from another process # NOTE: [ avoid hanging & failed quickly ] These value is used in getting data from another process
QUEUE_GET_TIMEOUT = 60 QUEUE_GET_TIMEOUT = 60
......
...@@ -169,10 +169,7 @@ class TestDistSaveLoad2x2(TestDistSimnetBow2x2): ...@@ -169,10 +169,7 @@ class TestDistSaveLoad2x2(TestDistSimnetBow2x2):
var = np.array(fluid.global_scope().find_var('__fc_b__').get_tensor( var = np.array(fluid.global_scope().find_var('__fc_b__').get_tensor(
)) ))
if six.PY2: sys.stdout.buffer.write(pickle.dumps(np.ravel(var).tolist()))
print(pickle.dumps(np.ravel(var).tolist()))
else:
sys.stdout.buffer.write(pickle.dumps(np.ravel(var).tolist()))
elif save_mode == "DIST": elif save_mode == "DIST":
skip_steps = int(os.getenv("SKIP_STEPS")) skip_steps = int(os.getenv("SKIP_STEPS"))
...@@ -191,10 +188,7 @@ class TestDistSaveLoad2x2(TestDistSimnetBow2x2): ...@@ -191,10 +188,7 @@ class TestDistSaveLoad2x2(TestDistSimnetBow2x2):
continue continue
loss, = exe.run(fetch_list=[avg_cost.name], loss, = exe.run(fetch_list=[avg_cost.name],
feed=feeder.feed(data)) feed=feeder.feed(data))
if six.PY2: sys.stdout.buffer.write(pickle.dumps(loss.tolist()))
print(pickle.dumps(loss.tolist()))
else:
sys.stdout.buffer.write(pickle.dumps(loss.tolist()))
else: else:
raise Exception("save_mode must be LOCAL or DIST") raise Exception("save_mode must be LOCAL or DIST")
......
...@@ -24,7 +24,6 @@ import paddle.distributed.fleet.base.role_maker as role_maker ...@@ -24,7 +24,6 @@ import paddle.distributed.fleet.base.role_maker as role_maker
import paddle.distributed.fleet.meta_optimizers.sharding as sharding import paddle.distributed.fleet.meta_optimizers.sharding as sharding
import os import os
import six
import sys import sys
import pickle import pickle
...@@ -81,10 +80,7 @@ def runtime_main(): ...@@ -81,10 +80,7 @@ def runtime_main():
exe, dirname, main_program=train_prog, filename=None) exe, dirname, main_program=train_prog, filename=None)
out_losses = [] out_losses = []
if six.PY2: sys.stdout.buffer.write(pickle.dumps(out_losses))
print(pickle.dumps(out_losses))
else:
sys.stdout.buffer.write(pickle.dumps(out_losses))
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -44,14 +44,9 @@ DATA_MD5 = '29ebfc94f11aea9362bbb7f5e9d86b8a' ...@@ -44,14 +44,9 @@ DATA_MD5 = '29ebfc94f11aea9362bbb7f5e9d86b8a'
# Load dictionary. # Load dictionary.
def load_vocab(filename): def load_vocab(filename):
vocab = {} vocab = {}
if six.PY2: with open(filename, 'r', encoding="utf-8") as f:
with open(filename, 'r') as f: for idx, line in enumerate(f):
for idx, line in enumerate(f): vocab[line.strip()] = idx
vocab[line.strip()] = idx
else:
with open(filename, 'r', encoding="utf-8") as f:
for idx, line in enumerate(f):
vocab[line.strip()] = idx
return vocab return vocab
......
...@@ -21,18 +21,10 @@ import sys ...@@ -21,18 +21,10 @@ import sys
import unittest import unittest
import gast import gast
import six
import paddle import paddle
from paddle.fluid.dygraph.dygraph_to_static import logging_utils from paddle.fluid.dygraph.dygraph_to_static import logging_utils
from unittest import mock
# TODO(liym27): library mock needs to be installed separately in PY2,
# but CI environment has not installed mock yet.
# After discuss with Tian Shuo, now use mock only in PY3, and use it in PY2 after CI installs it.
if six.PY3:
from unittest import mock
# else:
# import mock
class TestLoggingUtils(unittest.TestCase): class TestLoggingUtils(unittest.TestCase):
...@@ -112,7 +104,7 @@ class TestLoggingUtils(unittest.TestCase): ...@@ -112,7 +104,7 @@ class TestLoggingUtils(unittest.TestCase):
ast_code, "TestTransformer") ast_code, "TestTransformer")
def test_log_message(self): def test_log_message(self):
stream = io.BytesIO() if six.PY2 else io.StringIO() stream = io.StringIO()
log = self.translator_logger.logger log = self.translator_logger.logger
stdout_handler = logging.StreamHandler(stream) stdout_handler = logging.StreamHandler(stream)
log.addHandler(stdout_handler) log.addHandler(stdout_handler)
...@@ -122,39 +114,36 @@ class TestLoggingUtils(unittest.TestCase): ...@@ -122,39 +114,36 @@ class TestLoggingUtils(unittest.TestCase):
log_msg_1 = "test_log_1" log_msg_1 = "test_log_1"
log_msg_2 = "test_log_2" log_msg_2 = "test_log_2"
if six.PY3: with mock.patch.object(sys, 'stdout', stream):
with mock.patch.object(sys, 'stdout', stream): logging_utils.set_verbosity(1, False)
logging_utils.set_verbosity(1, False) logging_utils.warn(warn_msg)
logging_utils.warn(warn_msg) logging_utils.error(error_msg)
logging_utils.error(error_msg) logging_utils.log(1, log_msg_1)
logging_utils.log(1, log_msg_1) logging_utils.log(2, log_msg_2)
logging_utils.log(2, log_msg_2)
result_msg = '\n'.join( result_msg = '\n'.join(
[warn_msg, error_msg, "(Level 1) " + log_msg_1, ""]) [warn_msg, error_msg, "(Level 1) " + log_msg_1, ""])
self.assertEqual(result_msg, stream.getvalue()) self.assertEqual(result_msg, stream.getvalue())
def test_log_transformed_code(self): def test_log_transformed_code(self):
source_code = "x = 3" source_code = "x = 3"
ast_code = gast.parse(source_code) ast_code = gast.parse(source_code)
stream = io.BytesIO() if six.PY2 else io.StringIO() stream = io.StringIO()
log = self.translator_logger.logger log = self.translator_logger.logger
stdout_handler = logging.StreamHandler(stream) stdout_handler = logging.StreamHandler(stream)
log.addHandler(stdout_handler) log.addHandler(stdout_handler)
if six.PY3: with mock.patch.object(sys, 'stdout', stream):
with mock.patch.object(sys, 'stdout', stream): paddle.jit.set_code_level(1)
paddle.jit.set_code_level(1) logging_utils.log_transformed_code(1, ast_code,
logging_utils.log_transformed_code(1, ast_code, "BasicApiTransformer")
"BasicApiTransformer")
paddle.jit.set_code_level() paddle.jit.set_code_level()
logging_utils.log_transformed_code( logging_utils.log_transformed_code(logging_utils.LOG_AllTransformer,
logging_utils.LOG_AllTransformer, ast_code, ast_code, "All Transformers")
"All Transformers")
self.assertIn(source_code, stream.getvalue()) self.assertIn(source_code, stream.getvalue())
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
from __future__ import print_function from __future__ import print_function
import gast import gast
import six
import unittest import unittest
import numpy as np import numpy as np
...@@ -58,18 +57,9 @@ class TestVariableTransFunc(unittest.TestCase): ...@@ -58,18 +57,9 @@ class TestVariableTransFunc(unittest.TestCase):
source = "b = paddle.fluid.layers.fill_constant(shape=[1], dtype='bool', value=True)" source = "b = paddle.fluid.layers.fill_constant(shape=[1], dtype='bool', value=True)"
self.assertEqual(ast_to_source_code(node).strip(), source) self.assertEqual(ast_to_source_code(node).strip(), source)
if six.PY2: node = create_fill_constant_node("c", 4293)
node = create_fill_constant_node("c", 214) source = "c = paddle.fluid.layers.fill_constant(shape=[1], dtype='int64', value=4293)"
source = "c = paddle.fluid.layers.fill_constant(shape=[1], dtype='int32', value=214)" self.assertEqual(ast_to_source_code(node).strip(), source)
self.assertEqual(ast_to_source_code(node).strip(), source)
node = create_fill_constant_node("d", long(10086))
source = "d = paddle.fluid.layers.fill_constant(shape=[1], dtype='int64', value=10086)"
self.assertEqual(ast_to_source_code(node).strip(), source)
else:
node = create_fill_constant_node("c", 4293)
source = "c = paddle.fluid.layers.fill_constant(shape=[1], dtype='int64', value=4293)"
self.assertEqual(ast_to_source_code(node).strip(), source)
self.assertIsNone(create_fill_constant_node("e", None)) self.assertIsNone(create_fill_constant_node("e", None))
self.assertIsNone(create_fill_constant_node("e", [])) self.assertIsNone(create_fill_constant_node("e", []))
......
...@@ -18,14 +18,12 @@ import unittest ...@@ -18,14 +18,12 @@ import unittest
import time import time
import argparse import argparse
import os import os
import six
import sys import sys
import subprocess import subprocess
import traceback import traceback
import functools import functools
import pickle import pickle
from contextlib import closing from contextlib import closing
from six import string_types
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.unique_name as nameGen import paddle.fluid.unique_name as nameGen
from paddle.fluid import core from paddle.fluid import core
...@@ -113,10 +111,7 @@ class TestCollectiveRunnerBase(object): ...@@ -113,10 +111,7 @@ class TestCollectiveRunnerBase(object):
out = exe.run(train_prog, out = exe.run(train_prog,
feed={'tindata': indata}, feed={'tindata': indata},
fetch_list=[result.name]) fetch_list=[result.name])
if six.PY2: sys.stdout.buffer.write(pickle.dumps(out))
print(pickle.dumps(out))
else:
sys.stdout.buffer.write(pickle.dumps(out))
def runtime_main(test_class, col_type, sub_type): def runtime_main(test_class, col_type, sub_type):
......
...@@ -18,14 +18,12 @@ import unittest ...@@ -18,14 +18,12 @@ import unittest
import time import time
import argparse import argparse
import os import os
import six
import sys import sys
import subprocess import subprocess
import traceback import traceback
import functools import functools
import pickle import pickle
from contextlib import closing from contextlib import closing
from six import string_types
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.unique_name as nameGen import paddle.fluid.unique_name as nameGen
...@@ -69,10 +67,7 @@ class TestCollectiveAPIRunnerBase(object): ...@@ -69,10 +67,7 @@ class TestCollectiveAPIRunnerBase(object):
else: else:
out = self.get_model(train_prog, startup_prog, rank, indata) out = self.get_model(train_prog, startup_prog, rank, indata)
#print(out, sys.stderr) #print(out, sys.stderr)
if six.PY2: sys.stdout.buffer.write(pickle.dumps(out))
print(pickle.dumps(out))
else:
sys.stdout.buffer.write(pickle.dumps(out))
def runtime_main(test_class, col_type): def runtime_main(test_class, col_type):
......
...@@ -18,14 +18,12 @@ import unittest ...@@ -18,14 +18,12 @@ import unittest
import time import time
import argparse import argparse
import os import os
import six
import sys import sys
import subprocess import subprocess
import traceback import traceback
import functools import functools
import pickle import pickle
from contextlib import closing from contextlib import closing
from six import string_types
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.unique_name as nameGen import paddle.fluid.unique_name as nameGen
from paddle.fluid import core from paddle.fluid import core
...@@ -37,7 +35,6 @@ class TestCollectiveRunnerBase(object): ...@@ -37,7 +35,6 @@ class TestCollectiveRunnerBase(object):
"get model should be implemented by child class.") "get model should be implemented by child class.")
def wait_server_ready(self, endpoints): def wait_server_ready(self, endpoints):
assert not isinstance(endpoints, string_types)
while True: while True:
all_ok = True all_ok = True
not_ready_endpoints = [] not_ready_endpoints = []
...@@ -115,10 +112,7 @@ class TestCollectiveRunnerBase(object): ...@@ -115,10 +112,7 @@ class TestCollectiveRunnerBase(object):
out = exe.run(train_prog, out = exe.run(train_prog,
feed={'tindata': indata}, feed={'tindata': indata},
fetch_list=[result.name]) fetch_list=[result.name])
if six.PY2: sys.stdout.buffer.write(pickle.dumps(out))
print(pickle.dumps(out))
else:
sys.stdout.buffer.write(pickle.dumps(out))
def runtime_main(test_class, col_type, sub_type): def runtime_main(test_class, col_type, sub_type):
......
...@@ -16,465 +16,230 @@ from __future__ import print_function ...@@ -16,465 +16,230 @@ from __future__ import print_function
import unittest import unittest
import paddle.compat as cpt import paddle.compat as cpt
import six
class TestCompatible(unittest.TestCase): class TestCompatible(unittest.TestCase):
def test_type(self): def test_type(self):
if six.PY2: self.assertEqual(cpt.int_type, int)
self.assertEqual(cpt.int_type, int) self.assertEqual(cpt.long_type, int)
self.assertEqual(cpt.long_type, long)
else:
self.assertEqual(cpt.int_type, int)
self.assertEqual(cpt.long_type, int)
def test_to_text(self): def test_to_text(self):
# Only support python2.x and python3.x now self.assertIsNone(cpt.to_text(None))
self.assertTrue(six.PY2 | six.PY3)
self.assertTrue(isinstance(cpt.to_text(str("")), str))
if six.PY2: self.assertTrue(isinstance(cpt.to_text(str("123")), str))
# check None self.assertTrue(isinstance(cpt.to_text(b""), str))
self.assertIsNone(cpt.to_text(None)) self.assertTrue(isinstance(cpt.to_text(b""), str))
self.assertTrue(isinstance(cpt.to_text(u""), str))
# check all string related types self.assertTrue(isinstance(cpt.to_text(u""), str))
self.assertTrue(isinstance(cpt.to_text(str("")), unicode))
self.assertTrue(isinstance(cpt.to_text(str("123")), unicode)) self.assertEqual("", cpt.to_text(str("")))
self.assertTrue(isinstance(cpt.to_text(b""), unicode)) self.assertEqual("123", cpt.to_text(str("123")))
self.assertTrue(isinstance(cpt.to_text(b""), unicode)) self.assertEqual("", cpt.to_text(b""))
self.assertTrue(isinstance(cpt.to_text(u""), unicode)) self.assertEqual("123", cpt.to_text(b"123"))
self.assertTrue(isinstance(cpt.to_text(u""), unicode)) self.assertEqual("", cpt.to_text(u""))
self.assertEqual("123", cpt.to_text(u"123"))
self.assertEqual(u"", cpt.to_text(str("")))
self.assertEqual(u"123", cpt.to_text(str("123"))) # check list types, not inplace
self.assertEqual(u"", cpt.to_text(b"")) l = [""]
self.assertEqual(u"123", cpt.to_text(b"123")) l2 = cpt.to_text(l)
self.assertEqual(u"", cpt.to_text(u"")) self.assertTrue(isinstance(l2, list))
self.assertEqual(u"123", cpt.to_text(u"123")) self.assertFalse(l is l2)
self.assertEqual(l, l2)
# check list types, not inplace self.assertEqual([""], l2)
l = [""] l = ["", "123"]
l2 = cpt.to_text(l) l2 = cpt.to_text(l)
self.assertTrue(isinstance(l2, list)) self.assertTrue(isinstance(l2, list))
self.assertFalse(l is l2) self.assertFalse(l is l2)
self.assertEqual(l, l2) self.assertEqual(l, l2)
self.assertEqual([u""], l2) self.assertEqual(["", "123"], l2)
l = ["", "123"] l = ["", b"123", u"321"]
l2 = cpt.to_text(l) l2 = cpt.to_text(l)
self.assertTrue(isinstance(l2, list)) self.assertTrue(isinstance(l2, list))
self.assertFalse(l is l2) self.assertFalse(l is l2)
self.assertEqual(l, l2) self.assertNotEqual(l, l2)
self.assertEqual([u"", u"123"], l2) self.assertEqual(["", "123", "321"], l2)
l = ["", b'123', u"321"]
l2 = cpt.to_text(l) # check list types, inplace
self.assertTrue(isinstance(l2, list)) l = [""]
self.assertFalse(l is l2) l2 = cpt.to_text(l, inplace=True)
self.assertEqual(l, l2) self.assertTrue(isinstance(l2, list))
self.assertEqual([u"", u"123", u"321"], l2) self.assertTrue(l is l2)
for i in l2: self.assertEqual(l, l2)
self.assertTrue(isinstance(i, unicode)) self.assertEqual([""], l2)
l = ["", b"123"]
# check list types, inplace l2 = cpt.to_text(l, inplace=True)
l = [""] self.assertTrue(isinstance(l2, list))
l2 = cpt.to_text(l, inplace=True) self.assertTrue(l is l2)
self.assertTrue(isinstance(l2, list)) self.assertEqual(l, l2)
self.assertTrue(l is l2) self.assertEqual(["", "123"], l2)
self.assertEqual(l, l2) l = ["", b"123", u"321"]
self.assertEqual([u""], l2) l2 = cpt.to_text(l, inplace=True)
l = ["", "123"] self.assertTrue(isinstance(l2, list))
l2 = cpt.to_text(l, inplace=True) self.assertTrue(l is l2)
self.assertTrue(isinstance(l2, list)) self.assertEqual(l, l2)
self.assertTrue(l is l2) self.assertEqual(["", "123", "321"], l2)
self.assertEqual(l, l2) for i in l2:
self.assertEqual([u"", u"123"], l2) self.assertTrue(isinstance(i, str))
l = ["", b"123", u"321"]
l2 = cpt.to_text(l, inplace=True) # check set types, not inplace
self.assertTrue(isinstance(l2, list)) l = set("")
self.assertTrue(l is l2) l2 = cpt.to_text(l, inplace=False)
self.assertEqual(l, l2) self.assertTrue(isinstance(l2, set))
self.assertEqual([u"", u"123", u"321"], l2) self.assertFalse(l is l2)
self.assertEqual(l, l2)
# check set types, not inplace self.assertEqual(set(""), l2)
l = set("") l = set([b"", b"123"])
l2 = cpt.to_text(l, inplace=False) l2 = cpt.to_text(l, inplace=False)
self.assertTrue(isinstance(l2, set)) self.assertTrue(isinstance(l2, set))
self.assertFalse(l is l2) self.assertFalse(l is l2)
self.assertEqual(l, l2) self.assertNotEqual(l, l2)
self.assertEqual(set(u""), l2) self.assertEqual(set(["", "123"]), l2)
l = set([b"", b"123"]) l = set(["", b"123", u"321"])
l2 = cpt.to_text(l, inplace=False) l2 = cpt.to_text(l, inplace=False)
self.assertTrue(isinstance(l2, set)) self.assertTrue(isinstance(l2, set))
self.assertFalse(l is l2) self.assertFalse(l is l2)
self.assertEqual(l, l2) self.assertNotEqual(l, l2)
self.assertEqual(set([u"", u"123"]), l2) self.assertEqual(set(["", "123", "321"]), l2)
l = set(["", b"123", u"321"])
l2 = cpt.to_text(l, inplace=False) # check set types, inplace
self.assertTrue(isinstance(l2, set)) l = set("")
self.assertFalse(l is l2) l2 = cpt.to_text(l, inplace=True)
self.assertEqual(l, l2) self.assertTrue(isinstance(l2, set))
self.assertEqual(set([u"", u"123", u"321"]), l2) self.assertTrue(l is l2)
for i in l2: self.assertEqual(l, l2)
self.assertTrue(isinstance(i, unicode)) self.assertEqual(set(""), l2)
l = set([b"", b"123"])
# check set types, inplace l2 = cpt.to_text(l, inplace=True)
l = set("") self.assertTrue(isinstance(l2, set))
l2 = cpt.to_text(l, inplace=True) self.assertTrue(l is l2)
self.assertTrue(isinstance(l2, set)) self.assertEqual(l, l2)
self.assertTrue(l is l2) self.assertEqual(set(["", "123"]), l2)
self.assertEqual(l, l2) l = set(["", b"123", u"321"])
self.assertEqual(set(u""), l2) l2 = cpt.to_text(l, inplace=True)
l = set([b"", b"123"]) self.assertTrue(isinstance(l2, set))
l2 = cpt.to_text(l, inplace=True) self.assertTrue(l is l2)
self.assertTrue(isinstance(l2, set)) self.assertEqual(l, l2)
self.assertTrue(l is l2) self.assertEqual(set(["", "123", "321"]), l2)
self.assertEqual(l, l2) for i in l2:
self.assertEqual(set([u"", u"123"]), l2) self.assertTrue(isinstance(i, str))
l = set(["", b"123", u"321"])
l2 = cpt.to_text(l, inplace=True) # check dict types, not inplace
self.assertTrue(isinstance(l2, set)) l = {"": ""}
self.assertTrue(l is l2) l2 = cpt.to_text(l, inplace=False)
self.assertEqual(l, l2) self.assertTrue(isinstance(l2, dict))
self.assertEqual(set([u"", u"123", u"321"]), l2) self.assertFalse(l is l2)
self.assertEqual(l, l2)
# check dict types, not inplace self.assertEqual({"": ""}, l2)
l = {"": ""}
l2 = cpt.to_text(l, inplace=False) # check dict types, inplace
self.assertTrue(isinstance(l2, dict)) l = {"": ""}
self.assertFalse(l is l2) l2 = cpt.to_text(l, inplace=True)
self.assertEqual(l, l2) self.assertTrue(isinstance(l2, dict))
self.assertEqual({"": ""}, l2) self.assertTrue(l is l2)
self.assertEqual(l, l2)
# check dict types, inplace self.assertEqual({"": ""}, l2)
l = {"": ""}
l2 = cpt.to_text(l, inplace=True)
self.assertTrue(isinstance(l2, dict))
self.assertTrue(l is l2)
self.assertEqual(l, l2)
self.assertEqual({"": ""}, l2)
elif six.PY3:
self.assertIsNone(cpt.to_text(None))
self.assertTrue(isinstance(cpt.to_text(str("")), str))
self.assertTrue(isinstance(cpt.to_text(str("123")), str))
self.assertTrue(isinstance(cpt.to_text(b""), str))
self.assertTrue(isinstance(cpt.to_text(b""), str))
self.assertTrue(isinstance(cpt.to_text(u""), str))
self.assertTrue(isinstance(cpt.to_text(u""), str))
self.assertEqual("", cpt.to_text(str("")))
self.assertEqual("123", cpt.to_text(str("123")))
self.assertEqual("", cpt.to_text(b""))
self.assertEqual("123", cpt.to_text(b"123"))
self.assertEqual("", cpt.to_text(u""))
self.assertEqual("123", cpt.to_text(u"123"))
# check list types, not inplace
l = [""]
l2 = cpt.to_text(l)
self.assertTrue(isinstance(l2, list))
self.assertFalse(l is l2)
self.assertEqual(l, l2)
self.assertEqual([""], l2)
l = ["", "123"]
l2 = cpt.to_text(l)
self.assertTrue(isinstance(l2, list))
self.assertFalse(l is l2)
self.assertEqual(l, l2)
self.assertEqual(["", "123"], l2)
l = ["", b"123", u"321"]
l2 = cpt.to_text(l)
self.assertTrue(isinstance(l2, list))
self.assertFalse(l is l2)
self.assertNotEqual(l, l2)
self.assertEqual(["", "123", "321"], l2)
# check list types, inplace
l = [""]
l2 = cpt.to_text(l, inplace=True)
self.assertTrue(isinstance(l2, list))
self.assertTrue(l is l2)
self.assertEqual(l, l2)
self.assertEqual([""], l2)
l = ["", b"123"]
l2 = cpt.to_text(l, inplace=True)
self.assertTrue(isinstance(l2, list))
self.assertTrue(l is l2)
self.assertEqual(l, l2)
self.assertEqual(["", "123"], l2)
l = ["", b"123", u"321"]
l2 = cpt.to_text(l, inplace=True)
self.assertTrue(isinstance(l2, list))
self.assertTrue(l is l2)
self.assertEqual(l, l2)
self.assertEqual(["", "123", "321"], l2)
for i in l2:
self.assertTrue(isinstance(i, str))
# check set types, not inplace
l = set("")
l2 = cpt.to_text(l, inplace=False)
self.assertTrue(isinstance(l2, set))
self.assertFalse(l is l2)
self.assertEqual(l, l2)
self.assertEqual(set(""), l2)
l = set([b"", b"123"])
l2 = cpt.to_text(l, inplace=False)
self.assertTrue(isinstance(l2, set))
self.assertFalse(l is l2)
self.assertNotEqual(l, l2)
self.assertEqual(set(["", "123"]), l2)
l = set(["", b"123", u"321"])
l2 = cpt.to_text(l, inplace=False)
self.assertTrue(isinstance(l2, set))
self.assertFalse(l is l2)
self.assertNotEqual(l, l2)
self.assertEqual(set(["", "123", "321"]), l2)
# check set types, inplace
l = set("")
l2 = cpt.to_text(l, inplace=True)
self.assertTrue(isinstance(l2, set))
self.assertTrue(l is l2)
self.assertEqual(l, l2)
self.assertEqual(set(""), l2)
l = set([b"", b"123"])
l2 = cpt.to_text(l, inplace=True)
self.assertTrue(isinstance(l2, set))
self.assertTrue(l is l2)
self.assertEqual(l, l2)
self.assertEqual(set(["", "123"]), l2)
l = set(["", b"123", u"321"])
l2 = cpt.to_text(l, inplace=True)
self.assertTrue(isinstance(l2, set))
self.assertTrue(l is l2)
self.assertEqual(l, l2)
self.assertEqual(set(["", "123", "321"]), l2)
for i in l2:
self.assertTrue(isinstance(i, str))
# check dict types, not inplace
l = {"": ""}
l2 = cpt.to_text(l, inplace=False)
self.assertTrue(isinstance(l2, dict))
self.assertFalse(l is l2)
self.assertEqual(l, l2)
self.assertEqual({"": ""}, l2)
# check dict types, inplace
l = {"": ""}
l2 = cpt.to_text(l, inplace=True)
self.assertTrue(isinstance(l2, dict))
self.assertTrue(l is l2)
self.assertEqual(l, l2)
self.assertEqual({"": ""}, l2)
def test_to_bytes(self): def test_to_bytes(self):
# Only support python2.x and python3.x now self.assertIsNone(cpt.to_bytes(None))
self.assertTrue(six.PY2 | six.PY3)
self.assertTrue(isinstance(cpt.to_bytes(str("")), bytes))
if six.PY2: self.assertTrue(isinstance(cpt.to_bytes(str("123")), bytes))
# check None self.assertTrue(isinstance(cpt.to_bytes(b""), bytes))
self.assertIsNone(cpt.to_bytes(None)) self.assertTrue(isinstance(cpt.to_bytes(b""), bytes))
self.assertTrue(isinstance(cpt.to_bytes(u""), bytes))
# check all string related types self.assertTrue(isinstance(cpt.to_bytes(u""), bytes))
self.assertTrue(isinstance(cpt.to_bytes(str("")), bytes))
self.assertTrue(isinstance(cpt.to_bytes(str("123")), bytes)) self.assertEqual(b"", cpt.to_bytes(str("")))
self.assertTrue(isinstance(cpt.to_bytes(b""), bytes)) self.assertEqual(b"123", cpt.to_bytes(str("123")))
self.assertTrue(isinstance(cpt.to_bytes(b""), bytes)) self.assertEqual(b"", cpt.to_bytes(b""))
self.assertTrue(isinstance(cpt.to_bytes(u""), bytes)) self.assertEqual(b"123", cpt.to_bytes(b"123"))
self.assertTrue(isinstance(cpt.to_bytes(u""), bytes)) self.assertEqual(b"", cpt.to_bytes(u""))
self.assertEqual(b"123", cpt.to_bytes(u"123"))
self.assertEqual(b"", cpt.to_bytes(str("")))
self.assertEqual(b"123", cpt.to_bytes(str("123"))) # check list types, not inplace
self.assertEqual(b"", cpt.to_bytes(b"")) l = [""]
self.assertEqual(b"123", cpt.to_bytes(b"123")) l2 = cpt.to_bytes(l)
self.assertEqual(b"", cpt.to_bytes(u"")) self.assertTrue(isinstance(l2, list))
self.assertEqual(b"123", cpt.to_bytes(u"123")) self.assertFalse(l is l2)
self.assertNotEqual(l, l2)
# check list types, not inplace self.assertEqual([b""], l2)
l = [""] l = ["", "123"]
l2 = cpt.to_bytes(l) l2 = cpt.to_bytes(l)
self.assertTrue(isinstance(l2, list)) self.assertTrue(isinstance(l2, list))
self.assertFalse(l is l2) self.assertFalse(l is l2)
self.assertEqual(l, l2) self.assertNotEqual(l, l2)
self.assertEqual([b""], l2) self.assertEqual([b"", b"123"], l2)
l = ["", "123"] l = ["", b"123", u"321"]
l2 = cpt.to_bytes(l) l2 = cpt.to_bytes(l)
self.assertTrue(isinstance(l2, list)) self.assertTrue(isinstance(l2, list))
self.assertFalse(l is l2) self.assertFalse(l is l2)
self.assertEqual(l, l2) self.assertNotEqual(l, l2)
self.assertEqual([b"", b"123"], l2) self.assertEqual([b"", b"123", b"321"], l2)
l = ["", b'123', u"321"]
l2 = cpt.to_bytes(l) # check list types, inplace
self.assertTrue(isinstance(l2, list)) l = [""]
self.assertFalse(l is l2) l2 = cpt.to_bytes(l, inplace=True)
self.assertEqual(l, l2) self.assertTrue(isinstance(l2, list))
self.assertEqual([b"", b"123", b"321"], l2) self.assertTrue(l is l2)
for i in l2: self.assertEqual(l, l2)
self.assertTrue(isinstance(i, bytes)) self.assertEqual([b""], l2)
l = ["", b"123"]
# check list types, inplace l2 = cpt.to_bytes(l, inplace=True)
l = [""] self.assertTrue(isinstance(l2, list))
l2 = cpt.to_bytes(l, inplace=True) self.assertTrue(l is l2)
self.assertTrue(isinstance(l2, list)) self.assertEqual(l, l2)
self.assertTrue(l is l2) self.assertEqual([b"", b"123"], l2)
self.assertEqual(l, l2) l = ["", b"123", u"321"]
self.assertEqual([b""], l2) l2 = cpt.to_bytes(l, inplace=True)
l = ["", "123"] self.assertTrue(isinstance(l2, list))
l2 = cpt.to_bytes(l, inplace=True) self.assertTrue(l is l2)
self.assertTrue(isinstance(l2, list)) self.assertEqual(l, l2)
self.assertTrue(l is l2) self.assertEqual([b"", b"123", b"321"], l2)
self.assertEqual(l, l2) for i in l2:
self.assertEqual([b"", b"123"], l2) self.assertTrue(isinstance(i, bytes))
l = ["", b"123", u"321"]
l2 = cpt.to_bytes(l, inplace=True) # check set types, not inplace
self.assertTrue(isinstance(l2, list)) l = set([""])
self.assertTrue(l is l2) l2 = cpt.to_bytes(l, inplace=False)
self.assertEqual(l, l2) self.assertTrue(isinstance(l2, set))
self.assertEqual([b"", b"123", b"321"], l2) self.assertFalse(l is l2)
self.assertNotEqual(l, l2)
# check set types, not inplace self.assertEqual(set([b""]), l2)
l = set("") l = set([u"", u"123"])
l2 = cpt.to_bytes(l, inplace=False) l2 = cpt.to_bytes(l, inplace=False)
self.assertTrue(isinstance(l2, set)) self.assertTrue(isinstance(l2, set))
self.assertFalse(l is l2) self.assertFalse(l is l2)
self.assertEqual(l, l2) self.assertNotEqual(l, l2)
self.assertEqual(set(b""), l2) self.assertEqual(set([b"", b"123"]), l2)
l = set([b"", b"123"]) l = set(["", b"123", u"321"])
l2 = cpt.to_bytes(l, inplace=False) l2 = cpt.to_bytes(l, inplace=False)
self.assertTrue(isinstance(l2, set)) self.assertTrue(isinstance(l2, set))
self.assertFalse(l is l2) self.assertFalse(l is l2)
self.assertEqual(l, l2) self.assertNotEqual(l, l2)
self.assertEqual(set([b"", b"123"]), l2) self.assertEqual(set([b"", b"123", b"321"]), l2)
l = set(["", b"123", u"321"])
l2 = cpt.to_bytes(l, inplace=False) # check set types, inplace
self.assertTrue(isinstance(l2, set)) l = set("")
self.assertFalse(l is l2) l2 = cpt.to_bytes(l, inplace=True)
self.assertEqual(l, l2) self.assertTrue(isinstance(l2, set))
self.assertEqual(set([b"", b"123", b"321"]), l2) self.assertTrue(l is l2)
for i in l2: self.assertEqual(l, l2)
self.assertTrue(isinstance(i, bytes)) self.assertEqual(set(b""), l2)
l = set([u"", u"123"])
# check set types, inplace l2 = cpt.to_bytes(l, inplace=True)
l = set("") self.assertTrue(isinstance(l2, set))
l2 = cpt.to_bytes(l, inplace=True) self.assertTrue(l is l2)
self.assertTrue(isinstance(l2, set)) self.assertEqual(l, l2)
self.assertTrue(l is l2) self.assertEqual(set([b"", b"123"]), l2)
self.assertEqual(l, l2) l = set(["", b"123", u"321"])
self.assertEqual(set(b""), l2) l2 = cpt.to_bytes(l, inplace=True)
l = set([b"", b"123"]) self.assertTrue(isinstance(l2, set))
l2 = cpt.to_bytes(l, inplace=True) self.assertTrue(l is l2)
self.assertTrue(isinstance(l2, set)) self.assertEqual(l, l2)
self.assertTrue(l is l2) self.assertEqual(set([b"", b"123", b"321"]), l2)
self.assertEqual(l, l2) for i in l2:
self.assertEqual(set([b"", b"123"]), l2) self.assertTrue(isinstance(i, bytes))
l = set(["", b"123", u"321"])
l2 = cpt.to_bytes(l, inplace=True)
self.assertTrue(isinstance(l2, set))
self.assertTrue(l is l2)
self.assertEqual(l, l2)
self.assertEqual(set([b"", b"123", b"321"]), l2)
elif six.PY3:
self.assertIsNone(cpt.to_bytes(None))
self.assertTrue(isinstance(cpt.to_bytes(str("")), bytes))
self.assertTrue(isinstance(cpt.to_bytes(str("123")), bytes))
self.assertTrue(isinstance(cpt.to_bytes(b""), bytes))
self.assertTrue(isinstance(cpt.to_bytes(b""), bytes))
self.assertTrue(isinstance(cpt.to_bytes(u""), bytes))
self.assertTrue(isinstance(cpt.to_bytes(u""), bytes))
self.assertEqual(b"", cpt.to_bytes(str("")))
self.assertEqual(b"123", cpt.to_bytes(str("123")))
self.assertEqual(b"", cpt.to_bytes(b""))
self.assertEqual(b"123", cpt.to_bytes(b"123"))
self.assertEqual(b"", cpt.to_bytes(u""))
self.assertEqual(b"123", cpt.to_bytes(u"123"))
# check list types, not inplace
l = [""]
l2 = cpt.to_bytes(l)
self.assertTrue(isinstance(l2, list))
self.assertFalse(l is l2)
self.assertNotEqual(l, l2)
self.assertEqual([b""], l2)
l = ["", "123"]
l2 = cpt.to_bytes(l)
self.assertTrue(isinstance(l2, list))
self.assertFalse(l is l2)
self.assertNotEqual(l, l2)
self.assertEqual([b"", b"123"], l2)
l = ["", b"123", u"321"]
l2 = cpt.to_bytes(l)
self.assertTrue(isinstance(l2, list))
self.assertFalse(l is l2)
self.assertNotEqual(l, l2)
self.assertEqual([b"", b"123", b"321"], l2)
# check list types, inplace
l = [""]
l2 = cpt.to_bytes(l, inplace=True)
self.assertTrue(isinstance(l2, list))
self.assertTrue(l is l2)
self.assertEqual(l, l2)
self.assertEqual([b""], l2)
l = ["", b"123"]
l2 = cpt.to_bytes(l, inplace=True)
self.assertTrue(isinstance(l2, list))
self.assertTrue(l is l2)
self.assertEqual(l, l2)
self.assertEqual([b"", b"123"], l2)
l = ["", b"123", u"321"]
l2 = cpt.to_bytes(l, inplace=True)
self.assertTrue(isinstance(l2, list))
self.assertTrue(l is l2)
self.assertEqual(l, l2)
self.assertEqual([b"", b"123", b"321"], l2)
for i in l2:
self.assertTrue(isinstance(i, bytes))
# check set types, not inplace
l = set([""])
l2 = cpt.to_bytes(l, inplace=False)
self.assertTrue(isinstance(l2, set))
self.assertFalse(l is l2)
self.assertNotEqual(l, l2)
self.assertEqual(set([b""]), l2)
l = set([u"", u"123"])
l2 = cpt.to_bytes(l, inplace=False)
self.assertTrue(isinstance(l2, set))
self.assertFalse(l is l2)
self.assertNotEqual(l, l2)
self.assertEqual(set([b"", b"123"]), l2)
l = set(["", b"123", u"321"])
l2 = cpt.to_bytes(l, inplace=False)
self.assertTrue(isinstance(l2, set))
self.assertFalse(l is l2)
self.assertNotEqual(l, l2)
self.assertEqual(set([b"", b"123", b"321"]), l2)
# check set types, inplace
l = set("")
l2 = cpt.to_bytes(l, inplace=True)
self.assertTrue(isinstance(l2, set))
self.assertTrue(l is l2)
self.assertEqual(l, l2)
self.assertEqual(set(b""), l2)
l = set([u"", u"123"])
l2 = cpt.to_bytes(l, inplace=True)
self.assertTrue(isinstance(l2, set))
self.assertTrue(l is l2)
self.assertEqual(l, l2)
self.assertEqual(set([b"", b"123"]), l2)
l = set(["", b"123", u"321"])
l2 = cpt.to_bytes(l, inplace=True)
self.assertTrue(isinstance(l2, set))
self.assertTrue(l is l2)
self.assertEqual(l, l2)
self.assertEqual(set([b"", b"123", b"321"]), l2)
for i in l2:
self.assertTrue(isinstance(i, bytes))
def test_round(self): def test_round(self):
self.assertEqual(3.0, cpt.round(3.4)) self.assertEqual(3.0, cpt.round(3.4))
...@@ -500,37 +265,17 @@ class TestCompatible(unittest.TestCase): ...@@ -500,37 +265,17 @@ class TestCompatible(unittest.TestCase):
def test_get_exception_message(self): def test_get_exception_message(self):
exception_message = "test_message" exception_message = "test_message"
self.assertRaises(AssertionError, cpt.get_exception_message, None) self.assertRaises(AssertionError, cpt.get_exception_message, None)
if six.PY2: try:
self.assertRaises(AttributeError, cpt.get_exception_message, raise RuntimeError(exception_message)
exception_message) except Exception as e:
try: self.assertEqual(exception_message, cpt.get_exception_message(e))
raise RuntimeError(exception_message) self.assertIsNotNone(e)
except Exception as e:
self.assertEqual(exception_message, try:
cpt.get_exception_message(e)) raise Exception(exception_message)
self.assertIsNotNone(e) except Exception as e:
self.assertEqual(exception_message, cpt.get_exception_message(e))
try: self.assertIsNotNone(e)
raise Exception(exception_message)
except Exception as e:
self.assertEqual(exception_message,
cpt.get_exception_message(e))
self.assertIsNotNone(e)
if six.PY3:
try:
raise RuntimeError(exception_message)
except Exception as e:
self.assertEqual(exception_message,
cpt.get_exception_message(e))
self.assertIsNotNone(e)
try:
raise Exception(exception_message)
except Exception as e:
self.assertEqual(exception_message,
cpt.get_exception_message(e))
self.assertIsNotNone(e)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -44,19 +44,13 @@ DIST_UT_PORT = 0 ...@@ -44,19 +44,13 @@ DIST_UT_PORT = 0
def print_to_out(out_losses): def print_to_out(out_losses):
if six.PY2: sys.stdout.buffer.write(pickle.dumps(out_losses))
print(pickle.dumps(out_losses))
else:
sys.stdout.buffer.write(pickle.dumps(out_losses))
def print_to_err(class_name, log_str): def print_to_err(class_name, log_str):
localtime = time.asctime(time.localtime(time.time())) localtime = time.asctime(time.localtime(time.time()))
print_str = localtime + "\t" + class_name + "\t" + log_str print_str = localtime + "\t" + class_name + "\t" + log_str
if six.PY2: sys.stderr.buffer.write(pickle.dumps(print_str))
sys.stderr.write(pickle.dumps(print_str))
else:
sys.stderr.buffer.write(pickle.dumps(print_str))
def eprint(*args, **kwargs): def eprint(*args, **kwargs):
...@@ -151,10 +145,7 @@ class TestDistRunnerBase(object): ...@@ -151,10 +145,7 @@ class TestDistRunnerBase(object):
print_to_err(type(self).__name__, "run step %d finished" % i) print_to_err(type(self).__name__, "run step %d finished" % i)
print_to_err(type(self).__name__, "trainer run finished") print_to_err(type(self).__name__, "trainer run finished")
if six.PY2: sys.stdout.buffer.write(pickle.dumps(out_losses))
print(pickle.dumps(out_losses))
else:
sys.stdout.buffer.write(pickle.dumps(out_losses))
if args.save_model: if args.save_model:
model_save_dir = "/tmp" model_save_dir = "/tmp"
...@@ -251,10 +242,7 @@ class TestDistRunnerBase(object): ...@@ -251,10 +242,7 @@ class TestDistRunnerBase(object):
print_to_err(type(self).__name__, "trainer run finished") print_to_err(type(self).__name__, "trainer run finished")
print_to_err(type(self).__name__, "dist losses: {}".format(out_losses)) print_to_err(type(self).__name__, "dist losses: {}".format(out_losses))
if six.PY2: sys.stdout.buffer.write(pickle.dumps(out_losses))
print(pickle.dumps(out_losses))
else:
sys.stdout.buffer.write(pickle.dumps(out_losses))
def run_use_fleet_api_trainer(self, args): def run_use_fleet_api_trainer(self, args):
assert args.update_method == "nccl2" or "bkcl" assert args.update_method == "nccl2" or "bkcl"
...@@ -338,10 +326,7 @@ class TestDistRunnerBase(object): ...@@ -338,10 +326,7 @@ class TestDistRunnerBase(object):
print_to_err(type(self).__name__, "run step %d finished" % i) print_to_err(type(self).__name__, "run step %d finished" % i)
print_to_err(type(self).__name__, "trainer run finished") print_to_err(type(self).__name__, "trainer run finished")
if six.PY2: sys.stdout.buffer.write(pickle.dumps(out_losses))
print(pickle.dumps(out_losses))
else:
sys.stdout.buffer.write(pickle.dumps(out_losses))
if args.save_model: if args.save_model:
model_save_dir = "/tmp" model_save_dir = "/tmp"
......
...@@ -18,7 +18,6 @@ import unittest ...@@ -18,7 +18,6 @@ import unittest
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import numpy as np import numpy as np
import six
import inspect import inspect
...@@ -241,10 +240,7 @@ class TestMathOpPatchesVarBase(unittest.TestCase): ...@@ -241,10 +240,7 @@ class TestMathOpPatchesVarBase(unittest.TestCase):
a = fluid.dygraph.to_variable(np.array([100.1])) a = fluid.dygraph.to_variable(np.array([100.1]))
self.assertTrue(float(a) == 100.1) self.assertTrue(float(a) == 100.1)
self.assertTrue(int(a) == 100) self.assertTrue(int(a) == 100)
if six.PY2: self.assertTrue(int(a) == 100)
self.assertTrue(long(a) == 100)
else:
self.assertTrue(int(a) == 100)
def test_len(self): def test_len(self):
a_np = np.random.uniform(-1, 1, self.shape).astype(self.dtype) a_np = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
......
...@@ -18,7 +18,6 @@ import unittest ...@@ -18,7 +18,6 @@ import unittest
import numpy as np import numpy as np
import os import os
import sys import sys
import six
from io import BytesIO from io import BytesIO
import paddle import paddle
...@@ -38,10 +37,7 @@ SEED = 10 ...@@ -38,10 +37,7 @@ SEED = 10
IMAGE_SIZE = 784 IMAGE_SIZE = 784
CLASS_NUM = 10 CLASS_NUM = 10
if six.PY2: LARGE_PARAM = 2**26
LARGE_PARAM = 2**2
else:
LARGE_PARAM = 2**26
def random_batch_reader(): def random_batch_reader():
...@@ -105,10 +101,7 @@ class TestSaveLoadLargeParameters(unittest.TestCase): ...@@ -105,10 +101,7 @@ class TestSaveLoadLargeParameters(unittest.TestCase):
path = os.path.join("test_paddle_save_load_large_param_save", path = os.path.join("test_paddle_save_load_large_param_save",
"layer.pdparams") "layer.pdparams")
if six.PY2: protocol = 4
protocol = 2
else:
protocol = 4
paddle.save(save_dict, path, protocol=protocol) paddle.save(save_dict, path, protocol=protocol)
dict_load = paddle.load(path) dict_load = paddle.load(path)
# compare results before and after saving # compare results before and after saving
...@@ -926,9 +919,6 @@ class TestSaveLoadProgram(unittest.TestCase): ...@@ -926,9 +919,6 @@ class TestSaveLoadProgram(unittest.TestCase):
class TestSaveLoadLayer(unittest.TestCase): class TestSaveLoadLayer(unittest.TestCase):
def test_save_load_layer(self): def test_save_load_layer(self):
if six.PY2:
return
paddle.disable_static() paddle.disable_static()
inps = paddle.randn([1, IMAGE_SIZE], dtype='float32') inps = paddle.randn([1, IMAGE_SIZE], dtype='float32')
layer1 = LinearNet() layer1 = LinearNet()
......
...@@ -21,15 +21,10 @@ import paddle.fluid.framework as framework ...@@ -21,15 +21,10 @@ import paddle.fluid.framework as framework
from test_imperative_base import new_program_scope from test_imperative_base import new_program_scope
import numpy as np import numpy as np
import six
import pickle import pickle
import os import os
# Python2.x no longer supports saving and loading large parameters. LARGE_PARAM = 2**26
if six.PY2:
LARGE_PARAM = 2
else:
LARGE_PARAM = 2**26
class TestStaticSaveLoadLargeParameters(unittest.TestCase): class TestStaticSaveLoadLargeParameters(unittest.TestCase):
...@@ -59,10 +54,7 @@ class TestStaticSaveLoadLargeParameters(unittest.TestCase): ...@@ -59,10 +54,7 @@ class TestStaticSaveLoadLargeParameters(unittest.TestCase):
path = os.path.join("test_static_save_load_large_param", path = os.path.join("test_static_save_load_large_param",
"static_save") "static_save")
if six.PY2: protocol = 4
protocol = 2
else:
protocol = 4
paddle.fluid.save(prog, path, pickle_protocol=protocol) paddle.fluid.save(prog, path, pickle_protocol=protocol)
# set var to zero # set var to zero
for var in prog.list_vars(): for var in prog.list_vars():
......
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
import numpy as np import numpy as np
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import six
import unittest import unittest
import paddle.nn as nn import paddle.nn as nn
import os import os
...@@ -50,10 +49,7 @@ class TestTracedLayerErrMsg(unittest.TestCase): ...@@ -50,10 +49,7 @@ class TestTracedLayerErrMsg(unittest.TestCase):
self.feature_size = 3 self.feature_size = 3
self.fc_size = 2 self.fc_size = 2
self.layer = self._train_simple_net() self.layer = self._train_simple_net()
if six.PY2: self.type_str = 'class'
self.type_str = 'type'
else:
self.type_str = 'class'
def test_trace_err(self): def test_trace_err(self):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
......
...@@ -192,7 +192,7 @@ class TestVarBase(unittest.TestCase): ...@@ -192,7 +192,7 @@ class TestVarBase(unittest.TestCase):
x = paddle.to_tensor(1, dtype='int64') x = paddle.to_tensor(1, dtype='int64')
self.assertEqual(x.item(), 1) self.assertEqual(x.item(), 1)
self.assertTrue(isinstance(x.item(), long if six.PY2 else int)) self.assertTrue(isinstance(x.item(), int))
x = paddle.to_tensor(True) x = paddle.to_tensor(True)
self.assertEqual(x.item(), True) self.assertEqual(x.item(), True)
......
...@@ -17,14 +17,10 @@ from __future__ import print_function ...@@ -17,14 +17,10 @@ from __future__ import print_function
import os import os
import collections import collections
import pickle import pickle
import six
import warnings import warnings
import sys import sys
import numpy as np import numpy as np
import copyreg
if not six.PY2:
import copyreg
import paddle import paddle
# deprecated module import # deprecated module import
...@@ -296,19 +292,14 @@ def _pickle_save(obj, f, protocol): ...@@ -296,19 +292,14 @@ def _pickle_save(obj, f, protocol):
for i in range(0, len(pickle_bytes), max_bytes): for i in range(0, len(pickle_bytes), max_bytes):
f.write(pickle_bytes[i:i + max_bytes]) f.write(pickle_bytes[i:i + max_bytes])
else: else:
if six.PY2: pickler = pickle.Pickler(f, protocol)
add_dispatch_table() pickler.dispatch_table = copyreg.dispatch_table.copy()
pickle_bytes = pickle.dump(obj, f, protocol)
pop_dispatch_table()
else:
pickler = pickle.Pickler(f, protocol)
pickler.dispatch_table = copyreg.dispatch_table.copy()
pickler.dispatch_table[core.VarBase] = reduce_varbase pickler.dispatch_table[core.VarBase] = reduce_varbase
pickler.dispatch_table[core.LoDTensor] = reduce_LoDTensor pickler.dispatch_table[core.LoDTensor] = reduce_LoDTensor
pickler.dispatch_table[ParamBase] = reduce_varbase pickler.dispatch_table[ParamBase] = reduce_varbase
pickler.dispatch_table.update(dispatch_table_layer) pickler.dispatch_table.update(dispatch_table_layer)
pickler.dump(obj) pickler.dump(obj)
def _contain_x(obj, condition_func): def _contain_x(obj, condition_func):
...@@ -359,10 +350,7 @@ def _transformed_from_varbase(obj): ...@@ -359,10 +350,7 @@ def _transformed_from_varbase(obj):
# In paddle2.1 version, VarBase is saved as tuple(tensor.name, tensor.numpy()). # In paddle2.1 version, VarBase is saved as tuple(tensor.name, tensor.numpy()).
# When executing paddle.load, use this function to determine whether to restore to VarBase/LoDTensor. # When executing paddle.load, use this function to determine whether to restore to VarBase/LoDTensor.
if isinstance(obj, tuple) and len(obj) == 2: if isinstance(obj, tuple) and len(obj) == 2:
if six.PY2: name_types = str
name_types = (str, unicode)
else:
name_types = str
if isinstance(obj[0], name_types) and isinstance(obj[1], np.ndarray): if isinstance(obj[0], name_types) and isinstance(obj[1], np.ndarray):
return True return True
return False return False
...@@ -947,10 +935,7 @@ def load(path, **configs): ...@@ -947,10 +935,7 @@ def load(path, **configs):
if _is_memory_buffer(path) or os.path.isfile(path): if _is_memory_buffer(path) or os.path.isfile(path):
config = _parse_load_config(configs) config = _parse_load_config(configs)
if six.PY2: exception_type = pickle.UnpicklingError
exception_type = KeyError
else:
exception_type = pickle.UnpicklingError
try: try:
with _open_file_buffer(path, 'rb') as f: with _open_file_buffer(path, 'rb') as f:
# When value of dict is lager than 4GB ,there is a Bug on 'MAC python3' # When value of dict is lager than 4GB ,there is a Bug on 'MAC python3'
...@@ -959,8 +944,7 @@ def load(path, **configs): ...@@ -959,8 +944,7 @@ def load(path, **configs):
) and sys.platform == 'darwin' and sys.version_info.major == 3: ) and sys.platform == 'darwin' and sys.version_info.major == 3:
load_result = _pickle_loads_mac(path, f) load_result = _pickle_loads_mac(path, f)
else: else:
load_result = pickle.load(f) if six.PY2 else pickle.load( load_result = pickle.load(f, encoding='latin1')
f, encoding='latin1')
# TODO(weixin):If `obj` is any object, the judgment condition should be more precise. # TODO(weixin):If `obj` is any object, the judgment condition should be more precise.
if isinstance(load_result, dict): if isinstance(load_result, dict):
...@@ -1021,8 +1005,7 @@ def _legacy_load(path, **configs): ...@@ -1021,8 +1005,7 @@ def _legacy_load(path, **configs):
if os.path.isfile(path) or _is_memory_buffer(path): if os.path.isfile(path) or _is_memory_buffer(path):
# we think path is file means this file is created by paddle.save # we think path is file means this file is created by paddle.save
with _open_file_buffer(path, 'rb') as f: with _open_file_buffer(path, 'rb') as f:
load_result = pickle.load(f) if six.PY2 else pickle.load( load_result = pickle.load(f, encoding='latin1')
f, encoding='latin1')
load_result = _pack_loaded_dict(load_result) load_result = _pack_loaded_dict(load_result)
if not config.keep_name_table and "StructuredToParameterName@@" in load_result: if not config.keep_name_table and "StructuredToParameterName@@" in load_result:
del load_result["StructuredToParameterName@@"] del load_result["StructuredToParameterName@@"]
......
...@@ -1296,8 +1296,7 @@ class Model(object): ...@@ -1296,8 +1296,7 @@ class Model(object):
if not os.path.exists(path): if not os.path.exists(path):
return return
with open(path, 'rb') as f: with open(path, 'rb') as f:
return pickle.load(f) if six.PY2 else pickle.load( return pickle.load(f, encoding='latin1')
f, encoding='latin1')
def _check_match(key, param): def _check_match(key, param):
state = param_state.get(key, None) state = param_state.get(key, None)
......
...@@ -21,7 +21,6 @@ from ..fluid.data_feeder import convert_dtype, check_variable_and_dtype, check_t ...@@ -21,7 +21,6 @@ from ..fluid.data_feeder import convert_dtype, check_variable_and_dtype, check_t
from ..fluid.layers.tensor import fill_constant from ..fluid.layers.tensor import fill_constant
from ..fluid.layers import utils from ..fluid.layers import utils
import numpy as np import numpy as np
import six
# TODO: define functions to manipulate a tensor # TODO: define functions to manipulate a tensor
from ..fluid.layers import cast # noqa: F401 from ..fluid.layers import cast # noqa: F401
from ..fluid.layers import slice # noqa: F401 from ..fluid.layers import slice # noqa: F401
...@@ -1218,10 +1217,7 @@ def tile(x, repeat_times, name=None): ...@@ -1218,10 +1217,7 @@ def tile(x, repeat_times, name=None):
assert len(elem.shape) == 1, ( assert len(elem.shape) == 1, (
'Elements in repeat_times must be 1-D Tensors or integers.') 'Elements in repeat_times must be 1-D Tensors or integers.')
else: else:
if six.PY3: type_tuple = (int, np.int32, np.int64)
type_tuple = (int, np.int32, np.int64)
elif six.PY2:
type_tuple = (int, long, np.int32, np.int64)
assert isinstance(elem, type_tuple), ( assert isinstance(elem, type_tuple), (
'Elements in repeat_times must be 1-D Tensors or integers.') 'Elements in repeat_times must be 1-D Tensors or integers.')
...@@ -1357,10 +1353,7 @@ def broadcast_to(x, shape, name=None): ...@@ -1357,10 +1353,7 @@ def broadcast_to(x, shape, name=None):
assert len(elem.shape) == 1, ( assert len(elem.shape) == 1, (
'Elements in shape must be 1-D Tensors or integers.') 'Elements in shape must be 1-D Tensors or integers.')
else: else:
if six.PY3: type_tuple = (int, np.int32, np.int64)
type_tuple = (int, np.int32, np.int64)
elif six.PY2:
type_tuple = (int, long, np.int32, np.int64)
assert isinstance(elem, type_tuple), ( assert isinstance(elem, type_tuple), (
'Elements in shape must be 1-D Tensors or integers.') 'Elements in shape must be 1-D Tensors or integers.')
...@@ -1447,10 +1440,7 @@ def expand(x, shape, name=None): ...@@ -1447,10 +1440,7 @@ def expand(x, shape, name=None):
assert len(elem.shape) == 1, ( assert len(elem.shape) == 1, (
'Elements in shape must be 1-D Tensors or integers.') 'Elements in shape must be 1-D Tensors or integers.')
else: else:
if six.PY3: type_tuple = (int, np.int32, np.int64)
type_tuple = (int, np.int32, np.int64)
elif six.PY2:
type_tuple = (int, long, np.int32, np.int64)
assert isinstance(elem, type_tuple), ( assert isinstance(elem, type_tuple), (
'Elements in shape must be 1-D Tensors or integers.') 'Elements in shape must be 1-D Tensors or integers.')
......
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
import os import os
import re import re
import six
import sys import sys
import json import json
import glob import glob
...@@ -541,8 +540,7 @@ def find_cuda_home(): ...@@ -541,8 +540,7 @@ def find_cuda_home():
with open(os.devnull, 'w') as devnull: with open(os.devnull, 'w') as devnull:
nvcc_path = subprocess.check_output( nvcc_path = subprocess.check_output(
[which_cmd, 'nvcc'], stderr=devnull) [which_cmd, 'nvcc'], stderr=devnull)
if six.PY3: nvcc_path = nvcc_path.decode()
nvcc_path = nvcc_path.decode()
# Multi CUDA, select the first # Multi CUDA, select the first
nvcc_path = nvcc_path.split('\r\n')[0] nvcc_path = nvcc_path.split('\r\n')[0]
...@@ -580,8 +578,7 @@ def find_rocm_home(): ...@@ -580,8 +578,7 @@ def find_rocm_home():
with open(os.devnull, 'w') as devnull: with open(os.devnull, 'w') as devnull:
hipcc_path = subprocess.check_output( hipcc_path = subprocess.check_output(
[which_cmd, 'hipcc'], stderr=devnull) [which_cmd, 'hipcc'], stderr=devnull)
if six.PY3: hipcc_path = hipcc_path.decode()
hipcc_path = hipcc_path.decode()
hipcc_path = hipcc_path.rstrip('\r\n') hipcc_path = hipcc_path.rstrip('\r\n')
# for example: /opt/rocm/bin/hipcc # for example: /opt/rocm/bin/hipcc
...@@ -652,8 +649,7 @@ def find_clang_cpp_include(compiler='clang'): ...@@ -652,8 +649,7 @@ def find_clang_cpp_include(compiler='clang'):
std_v1_includes = None std_v1_includes = None
try: try:
compiler_version = subprocess.check_output([compiler, "--version"]) compiler_version = subprocess.check_output([compiler, "--version"])
if six.PY3: compiler_version = compiler_version.decode()
compiler_version = compiler_version.decode()
infos = compiler_version.split("\n") infos = compiler_version.split("\n")
for info in infos: for info in infos:
if "InstalledDir" in info: if "InstalledDir" in info:
...@@ -895,13 +891,9 @@ def _load_module_from_file(api_file_path, verbose=False): ...@@ -895,13 +891,9 @@ def _load_module_from_file(api_file_path, verbose=False):
# Unique readable module name to place custom api. # Unique readable module name to place custom api.
log_v('import module from file: {}'.format(api_file_path), verbose) log_v('import module from file: {}'.format(api_file_path), verbose)
ext_name = "_paddle_cpp_extension_" ext_name = "_paddle_cpp_extension_"
if six.PY2: from importlib import machinery
import imp loader = machinery.SourceFileLoader(ext_name, api_file_path)
module = imp.load_source(ext_name, api_file_path) module = loader.load_module()
else:
from importlib import machinery
loader = machinery.SourceFileLoader(ext_name, api_file_path)
module = loader.load_module()
return module return module
...@@ -1005,8 +997,7 @@ def _jit_compile(file_path, verbose=False): ...@@ -1005,8 +997,7 @@ def _jit_compile(file_path, verbose=False):
try: try:
py_version = subprocess.check_output([interpreter, '-V']) py_version = subprocess.check_output([interpreter, '-V'])
if six.PY3: py_version = py_version.decode()
py_version = py_version.decode()
log_v("Using Python interpreter: {}, version: {}".format( log_v("Using Python interpreter: {}, version: {}".format(
interpreter, py_version.strip()), verbose) interpreter, py_version.strip()), verbose)
except Exception: except Exception:
...@@ -1083,8 +1074,7 @@ def check_abi_compatibility(compiler, verbose=False): ...@@ -1083,8 +1074,7 @@ def check_abi_compatibility(compiler, verbose=False):
if not IS_WINDOWS: if not IS_WINDOWS:
cmd_out = subprocess.check_output( cmd_out = subprocess.check_output(
['which', compiler], stderr=subprocess.STDOUT) ['which', compiler], stderr=subprocess.STDOUT)
compiler_path = os.path.realpath(cmd_out.decode() compiler_path = os.path.realpath(cmd_out.decode()).strip()
if six.PY3 else cmd_out).strip()
# if not found any suitable compiler, raise warning # if not found any suitable compiler, raise warning
if not any(name in compiler_path if not any(name in compiler_path
for name in _expected_compiler_current_platform()): for name in _expected_compiler_current_platform()):
...@@ -1104,18 +1094,16 @@ def check_abi_compatibility(compiler, verbose=False): ...@@ -1104,18 +1094,16 @@ def check_abi_compatibility(compiler, verbose=False):
mini_required_version = GCC_MINI_VERSION mini_required_version = GCC_MINI_VERSION
version_info = subprocess.check_output( version_info = subprocess.check_output(
[compiler, '-dumpfullversion', '-dumpversion']) [compiler, '-dumpfullversion', '-dumpversion'])
if six.PY3: version_info = version_info.decode()
version_info = version_info.decode()
version = version_info.strip().split('.') version = version_info.strip().split('.')
elif IS_WINDOWS: elif IS_WINDOWS:
mini_required_version = MSVC_MINI_VERSION mini_required_version = MSVC_MINI_VERSION
compiler_info = subprocess.check_output( compiler_info = subprocess.check_output(
compiler, stderr=subprocess.STDOUT) compiler, stderr=subprocess.STDOUT)
if six.PY3: try:
try: compiler_info = compiler_info.decode('UTF-8')
compiler_info = compiler_info.decode('UTF-8') except UnicodeDecodeError:
except UnicodeDecodeError: compiler_info = compiler_info.decode('gbk')
compiler_info = compiler_info.decode('gbk')
match = re.search(r'(\d+)\.(\d+)\.(\d+)', compiler_info.strip()) match = re.search(r'(\d+)\.(\d+)\.(\d+)', compiler_info.strip())
if match is not None: if match is not None:
version = match.groups() version = match.groups()
......
...@@ -141,10 +141,7 @@ class Cifar10(Dataset): ...@@ -141,10 +141,7 @@ class Cifar10(Dataset):
if self.flag in each_item.name) if self.flag in each_item.name)
for name in names: for name in names:
if six.PY2: batch = pickle.load(f.extractfile(name), encoding='bytes')
batch = pickle.load(f.extractfile(name))
else:
batch = pickle.load(f.extractfile(name), encoding='bytes')
data = batch[six.b('data')] data = batch[six.b('data')]
labels = batch.get( labels = batch.get(
......
...@@ -20,7 +20,6 @@ import collections ...@@ -20,7 +20,6 @@ import collections
import sys import sys
import pydoc import pydoc
import hashlib import hashlib
import six
import functools import functools
import platform import platform
...@@ -104,7 +103,7 @@ def visit_member(parent_name, member, func): ...@@ -104,7 +103,7 @@ def visit_member(parent_name, member, func):
def is_primitive(instance): def is_primitive(instance):
int_types = (int, long) if six.PY2 else (int, ) int_types = (int, )
pritimitive_types = int_types + (float, str) pritimitive_types = int_types + (float, str)
if isinstance(instance, pritimitive_types): if isinstance(instance, pritimitive_types):
return True return True
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册