未验证 提交 0f7187af 编写于 作者: T tianshuo78520a 提交者: GitHub

Del six.PY code2 (#33607)

* del py2 code2

* fix test timeout
上级 79cbc8ea
......@@ -17,12 +17,8 @@ import math
__all__ = []
if six.PY2:
int_type = int
long_type = long # noqa: F821
else:
int_type = int
long_type = int
int_type = int
long_type = int
# str and bytes related functions
......@@ -262,7 +258,4 @@ def get_exception_message(exc):
"""
assert exc is not None
if six.PY2:
return exc.message
else:
return str(exc)
......@@ -62,11 +62,7 @@ def reader_creator(filename, sub_name, cycle=False):
if sub_name in each_item.name)
for name in names:
if six.PY2:
batch = pickle.load(f.extractfile(name))
else:
batch = pickle.load(
f.extractfile(name), encoding='bytes')
batch = pickle.load(f.extractfile(name), encoding='bytes')
for item in read_batch(batch):
yield item
......
......@@ -101,8 +101,6 @@ def download(url, module_name, md5sum, save_name=None):
bar = paddle.hapi.progressbar.ProgressBar(
total_iter, name='item')
for data in r.iter_content(chunk_size=chunk_size):
if six.PY2:
data = six.b(data)
f.write(data)
log_index += 1
bar.update(log_index, {})
......
......@@ -132,9 +132,6 @@ def reader_creator(data_file,
file = file.strip()
batch = None
with open(file, 'rb') as f:
if six.PY2:
batch = pickle.load(f)
else:
batch = pickle.load(f, encoding='bytes')
if six.PY3:
......
......@@ -17,12 +17,8 @@ import logging
import six
# NOTE: HTTPServer has a different name in python2 and python3
if six.PY2:
from BaseHTTPServer import HTTPServer
import SimpleHTTPServer
else:
from http.server import HTTPServer
import http.server as SimpleHTTPServer
from http.server import HTTPServer
import http.server as SimpleHTTPServer
import time
import threading
......
......@@ -226,9 +226,6 @@ class Quant2Int8ImageClassificationComparisonTest(unittest.TestCase):
if iters == skip_batch_num:
total_samples = 0
infer_start_time = time.time()
if six.PY2:
images = map(lambda x: x[0].reshape(dshape), data)
if six.PY3:
images = list(map(lambda x: x[0].reshape(dshape), data))
images = np.array(images).astype('float32')
labels = np.array([x[1] for x in data]).astype('int64')
......
......@@ -196,9 +196,6 @@ class QuantInt8ImageClassificationComparisonTest(unittest.TestCase):
if iters == skip_batch_num:
total_samples = 0
infer_start_time = time.time()
if six.PY2:
images = map(lambda x: x[0].reshape(dshape), data)
if six.PY3:
images = list(map(lambda x: x[0].reshape(dshape), data))
images = np.array(images).astype('float32')
labels = np.array([x[1] for x in data]).astype('int64')
......
......@@ -27,10 +27,7 @@ from collections import namedtuple
from paddle.fluid.framework import _set_expected_place, _current_expected_place
# NOTE: queue has a different name in python2 and python3
if six.PY2:
import Queue as queue
else:
import queue
import queue
import paddle
from .. import core, layers
......
......@@ -26,10 +26,7 @@ from ..framework import in_dygraph_mode
from .flat import _flatten_batch
# NOTE: queue has a different name in python2 and python3
if six.PY2:
import Queue as queue
else:
import queue
import queue
__all__ = ['get_worker_info']
......
......@@ -19,7 +19,6 @@ import collections
import functools
from ..framework import Variable, default_main_program, in_dygraph_mode, dygraph_only, Parameter, ParamBase, _varbase_creator, _dygraph_tracer
import pickle
import six
from . import learning_rate_scheduler
import warnings
from .. import core
......@@ -194,16 +193,14 @@ def load_dygraph(model_path, **configs):
para_dict = {}
if os.path.exists(params_file_path):
with open(params_file_path, 'rb') as f:
para_dict = pickle.load(f) if six.PY2 else pickle.load(
f, encoding='latin1')
para_dict = pickle.load(f, encoding='latin1')
if not config.keep_name_table and "StructuredToParameterName@@" in para_dict:
del para_dict["StructuredToParameterName@@"]
if os.path.exists(opti_file_path):
with open(opti_file_path, 'rb') as f:
opti_dict = pickle.load(f) if six.PY2 else pickle.load(
f, encoding='latin1')
opti_dict = pickle.load(f, encoding='latin1')
else:
# check model path
if not os.path.isdir(model_prefix):
......
......@@ -60,10 +60,7 @@ class BaseNodeVisitor(gast.NodeVisitor):
# imp is deprecated in python3
if six.PY2:
import imp
else:
from importlib.machinery import SourceFileLoader
from importlib.machinery import SourceFileLoader
dygraph_class_to_static_api = {
"CosineDecay": "cosine_decay",
......@@ -491,10 +488,6 @@ def ast_to_func(ast_root, dyfunc, delete_on_exit=True):
import_fluid = "import paddle\nimport paddle.fluid as fluid\n"
source = import_fluid + source
if six.PY2:
source = source.encode('utf-8')
f = tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False)
else:
f = tempfile.NamedTemporaryFile(
mode='w', suffix='.py', delete=False, encoding='utf-8')
with f:
......@@ -505,9 +498,6 @@ def ast_to_func(ast_root, dyfunc, delete_on_exit=True):
atexit.register(lambda: remove_if_exit(f.name))
atexit.register(lambda: remove_if_exit(f.name[:-3] + ".pyc"))
if six.PY2:
module = imp.load_source(module_name, f.name)
else:
module = SourceFileLoader(module_name, f.name).load_module()
func_name = dyfunc.__name__
# The 'forward' or 'another_forward' of 'TranslatedLayer' cannot be obtained
......
......@@ -98,14 +98,6 @@ def create_fill_constant_node(name, value):
func_code += "dtype='float64', value={})".format(value)
return gast.parse(func_code).body[0]
if six.PY2:
if isinstance(value, int):
func_code += "dtype='int32', value={})".format(value)
return gast.parse(func_code).body[0]
if isinstance(value, long):
func_code += "dtype='int64', value={})".format(value)
return gast.parse(func_code).body[0]
else:
if isinstance(value, int):
func_code += "dtype='int64', value={})".format(value)
return gast.parse(func_code).body[0]
......
......@@ -20,7 +20,6 @@ from ..layers.layer_function_generator import OpProtoHolder
from . import no_grad
import numpy as np
import six
import warnings
_supported_int_dtype_ = [
......@@ -121,9 +120,6 @@ def monkey_patch_math_varbase():
assert numel == 1, "only one element variable can be converted to long."
tensor = var.value().get_tensor()
assert tensor._is_initialized(), "variable's tensor is not initialized"
if six.PY2:
return long(var.numpy().flatten()[0])
else:
return int(var.numpy().flatten()[0])
def _int_(var):
......@@ -141,9 +137,6 @@ def monkey_patch_math_varbase():
assert numel == 1, "only one element variable can be converted to python index."
tensor = var.value().get_tensor()
assert tensor._is_initialized(), "variable's tensor is not initialized"
if six.PY2:
return long(var.numpy().flatten()[0])
else:
return int(var.numpy().flatten()[0])
@property
......
......@@ -1940,8 +1940,7 @@ def _pickle_loads_mac(path, f):
max_bytes = 2**30
for _ in range(0, file_size, max_bytes):
pickle_bytes += f.read(max_bytes)
load_result = pickle.loads(pickle_bytes) if six.PY2 else pickle.loads(
pickle_bytes, encoding='latin1')
load_result = pickle.loads(pickle_bytes, encoding='latin1')
return load_result
......@@ -2113,8 +2112,7 @@ def load(program, model_path, executor=None, var_list=None):
if sys.platform == 'darwin' and sys.version_info.major == 3:
load_dict = _pickle_loads_mac(parameter_file_name, f)
else:
load_dict = pickle.load(f) if six.PY2 else pickle.load(
f, encoding='latin1')
load_dict = pickle.load(f, encoding='latin1')
load_dict = _pack_loaded_dict(load_dict)
for v in parameter_list:
assert v.name in load_dict, \
......@@ -2135,8 +2133,7 @@ def load(program, model_path, executor=None, var_list=None):
optimizer_var_list, global_scope(), executor._default_executor)
with open(opt_file_name, 'rb') as f:
load_dict = pickle.load(f) if six.PY2 else pickle.load(
f, encoding='latin1')
load_dict = pickle.load(f, encoding='latin1')
for v in optimizer_var_list:
assert v.name in load_dict, \
"Can not find [{}] in model file [{}]".format(
......@@ -2297,15 +2294,13 @@ def load_program_state(model_path, var_list=None):
if sys.platform == 'darwin' and sys.version_info.major == 3:
para_dict = _pickle_loads_mac(parameter_file_name, f)
else:
para_dict = pickle.load(f) if six.PY2 else pickle.load(
f, encoding='latin1')
para_dict = pickle.load(f, encoding='latin1')
para_dict = _pack_loaded_dict(para_dict)
opt_file_name = model_prefix + ".pdopt"
if os.path.exists(opt_file_name):
with open(opt_file_name, 'rb') as f:
opti_dict = pickle.load(f) if six.PY2 else pickle.load(
f, encoding='latin1')
opti_dict = pickle.load(f, encoding='latin1')
para_dict.update(opti_dict)
......
......@@ -16,9 +16,7 @@ from __future__ import print_function
import math
import numpy
import six
import warnings
from six.moves import reduce
from ..layer_helper import LayerHelper
from ..param_attr import ParamAttr
......@@ -134,11 +132,6 @@ def create_parameter(shape,
"""
check_type(shape, 'shape', (list, tuple, numpy.ndarray), 'create_parameter')
for item in shape:
if six.PY2:
check_type(item, 'item of shape',
(int, long, numpy.uint8, numpy.int8, numpy.int16,
numpy.int32, numpy.int64), 'create_parameter')
else:
check_type(item, 'item of shape',
(int, numpy.uint8, numpy.int8, numpy.int16, numpy.int32,
numpy.int64), 'create_parameter')
......@@ -194,11 +187,6 @@ def create_global_var(shape,
check_type(shape, 'shape', (list, tuple, numpy.ndarray),
'create_global_var')
for item in shape:
if six.PY2:
check_type(item, 'item of shape',
(int, long, numpy.uint8, numpy.int8, numpy.int16,
numpy.int32, numpy.int64), 'create_global_var')
else:
check_type(item, 'item of shape',
(int, numpy.uint8, numpy.int8, numpy.int16, numpy.int32,
numpy.int64), 'create_global_var')
......
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import six
import sys
import signal
import atexit
......@@ -20,10 +19,7 @@ import atexit
from . import core
# NOTE: queue has a different name in python2 and python3
if six.PY2:
import Queue as queue
else:
import queue
import queue
# multi-process worker check indices queue interval, avoid
# hanging in subprocess data loading
......
......@@ -38,10 +38,7 @@ import multiprocessing
import signal
# NOTE: queue has a different name in python2 and python3
if six.PY2:
import Queue as queue
else:
import queue
import queue
# NOTE: [ avoid hanging & failed quickly ] These value is used in getting data from another process
QUEUE_GET_TIMEOUT = 60
......
......@@ -169,9 +169,6 @@ class TestDistSaveLoad2x2(TestDistSimnetBow2x2):
var = np.array(fluid.global_scope().find_var('__fc_b__').get_tensor(
))
if six.PY2:
print(pickle.dumps(np.ravel(var).tolist()))
else:
sys.stdout.buffer.write(pickle.dumps(np.ravel(var).tolist()))
elif save_mode == "DIST":
......@@ -191,9 +188,6 @@ class TestDistSaveLoad2x2(TestDistSimnetBow2x2):
continue
loss, = exe.run(fetch_list=[avg_cost.name],
feed=feeder.feed(data))
if six.PY2:
print(pickle.dumps(loss.tolist()))
else:
sys.stdout.buffer.write(pickle.dumps(loss.tolist()))
else:
raise Exception("save_mode must be LOCAL or DIST")
......
......@@ -24,7 +24,6 @@ import paddle.distributed.fleet.base.role_maker as role_maker
import paddle.distributed.fleet.meta_optimizers.sharding as sharding
import os
import six
import sys
import pickle
......@@ -81,9 +80,6 @@ def runtime_main():
exe, dirname, main_program=train_prog, filename=None)
out_losses = []
if six.PY2:
print(pickle.dumps(out_losses))
else:
sys.stdout.buffer.write(pickle.dumps(out_losses))
......
......@@ -44,11 +44,6 @@ DATA_MD5 = '29ebfc94f11aea9362bbb7f5e9d86b8a'
# Load dictionary.
def load_vocab(filename):
vocab = {}
if six.PY2:
with open(filename, 'r') as f:
for idx, line in enumerate(f):
vocab[line.strip()] = idx
else:
with open(filename, 'r', encoding="utf-8") as f:
for idx, line in enumerate(f):
vocab[line.strip()] = idx
......
......@@ -21,18 +21,10 @@ import sys
import unittest
import gast
import six
import paddle
from paddle.fluid.dygraph.dygraph_to_static import logging_utils
# TODO(liym27): library mock needs to be installed separately in PY2,
# but CI environment has not installed mock yet.
# After discuss with Tian Shuo, now use mock only in PY3, and use it in PY2 after CI installs it.
if six.PY3:
from unittest import mock
# else:
# import mock
from unittest import mock
class TestLoggingUtils(unittest.TestCase):
......@@ -112,7 +104,7 @@ class TestLoggingUtils(unittest.TestCase):
ast_code, "TestTransformer")
def test_log_message(self):
stream = io.BytesIO() if six.PY2 else io.StringIO()
stream = io.StringIO()
log = self.translator_logger.logger
stdout_handler = logging.StreamHandler(stream)
log.addHandler(stdout_handler)
......@@ -122,7 +114,6 @@ class TestLoggingUtils(unittest.TestCase):
log_msg_1 = "test_log_1"
log_msg_2 = "test_log_2"
if six.PY3:
with mock.patch.object(sys, 'stdout', stream):
logging_utils.set_verbosity(1, False)
logging_utils.warn(warn_msg)
......@@ -138,21 +129,19 @@ class TestLoggingUtils(unittest.TestCase):
source_code = "x = 3"
ast_code = gast.parse(source_code)
stream = io.BytesIO() if six.PY2 else io.StringIO()
stream = io.StringIO()
log = self.translator_logger.logger
stdout_handler = logging.StreamHandler(stream)
log.addHandler(stdout_handler)
if six.PY3:
with mock.patch.object(sys, 'stdout', stream):
paddle.jit.set_code_level(1)
logging_utils.log_transformed_code(1, ast_code,
"BasicApiTransformer")
paddle.jit.set_code_level()
logging_utils.log_transformed_code(
logging_utils.LOG_AllTransformer, ast_code,
"All Transformers")
logging_utils.log_transformed_code(logging_utils.LOG_AllTransformer,
ast_code, "All Transformers")
self.assertIn(source_code, stream.getvalue())
......
......@@ -15,7 +15,6 @@
from __future__ import print_function
import gast
import six
import unittest
import numpy as np
......@@ -58,15 +57,6 @@ class TestVariableTransFunc(unittest.TestCase):
source = "b = paddle.fluid.layers.fill_constant(shape=[1], dtype='bool', value=True)"
self.assertEqual(ast_to_source_code(node).strip(), source)
if six.PY2:
node = create_fill_constant_node("c", 214)
source = "c = paddle.fluid.layers.fill_constant(shape=[1], dtype='int32', value=214)"
self.assertEqual(ast_to_source_code(node).strip(), source)
node = create_fill_constant_node("d", long(10086))
source = "d = paddle.fluid.layers.fill_constant(shape=[1], dtype='int64', value=10086)"
self.assertEqual(ast_to_source_code(node).strip(), source)
else:
node = create_fill_constant_node("c", 4293)
source = "c = paddle.fluid.layers.fill_constant(shape=[1], dtype='int64', value=4293)"
self.assertEqual(ast_to_source_code(node).strip(), source)
......
......@@ -18,14 +18,12 @@ import unittest
import time
import argparse
import os
import six
import sys
import subprocess
import traceback
import functools
import pickle
from contextlib import closing
from six import string_types
import paddle.fluid as fluid
import paddle.fluid.unique_name as nameGen
from paddle.fluid import core
......@@ -113,9 +111,6 @@ class TestCollectiveRunnerBase(object):
out = exe.run(train_prog,
feed={'tindata': indata},
fetch_list=[result.name])
if six.PY2:
print(pickle.dumps(out))
else:
sys.stdout.buffer.write(pickle.dumps(out))
......
......@@ -18,14 +18,12 @@ import unittest
import time
import argparse
import os
import six
import sys
import subprocess
import traceback
import functools
import pickle
from contextlib import closing
from six import string_types
import paddle
import paddle.fluid as fluid
import paddle.fluid.unique_name as nameGen
......@@ -69,9 +67,6 @@ class TestCollectiveAPIRunnerBase(object):
else:
out = self.get_model(train_prog, startup_prog, rank, indata)
#print(out, sys.stderr)
if six.PY2:
print(pickle.dumps(out))
else:
sys.stdout.buffer.write(pickle.dumps(out))
......
......@@ -18,14 +18,12 @@ import unittest
import time
import argparse
import os
import six
import sys
import subprocess
import traceback
import functools
import pickle
from contextlib import closing
from six import string_types
import paddle.fluid as fluid
import paddle.fluid.unique_name as nameGen
from paddle.fluid import core
......@@ -37,7 +35,6 @@ class TestCollectiveRunnerBase(object):
"get model should be implemented by child class.")
def wait_server_ready(self, endpoints):
assert not isinstance(endpoints, string_types)
while True:
all_ok = True
not_ready_endpoints = []
......@@ -115,9 +112,6 @@ class TestCollectiveRunnerBase(object):
out = exe.run(train_prog,
feed={'tindata': indata},
fetch_list=[result.name])
if six.PY2:
print(pickle.dumps(out))
else:
sys.stdout.buffer.write(pickle.dumps(out))
......
......@@ -16,142 +16,14 @@ from __future__ import print_function
import unittest
import paddle.compat as cpt
import six
class TestCompatible(unittest.TestCase):
def test_type(self):
if six.PY2:
self.assertEqual(cpt.int_type, int)
self.assertEqual(cpt.long_type, long)
else:
self.assertEqual(cpt.int_type, int)
self.assertEqual(cpt.long_type, int)
def test_to_text(self):
# Only support python2.x and python3.x now
self.assertTrue(six.PY2 | six.PY3)
if six.PY2:
# check None
self.assertIsNone(cpt.to_text(None))
# check all string related types
self.assertTrue(isinstance(cpt.to_text(str("")), unicode))
self.assertTrue(isinstance(cpt.to_text(str("123")), unicode))
self.assertTrue(isinstance(cpt.to_text(b""), unicode))
self.assertTrue(isinstance(cpt.to_text(b""), unicode))
self.assertTrue(isinstance(cpt.to_text(u""), unicode))
self.assertTrue(isinstance(cpt.to_text(u""), unicode))
self.assertEqual(u"", cpt.to_text(str("")))
self.assertEqual(u"123", cpt.to_text(str("123")))
self.assertEqual(u"", cpt.to_text(b""))
self.assertEqual(u"123", cpt.to_text(b"123"))
self.assertEqual(u"", cpt.to_text(u""))
self.assertEqual(u"123", cpt.to_text(u"123"))
# check list types, not inplace
l = [""]
l2 = cpt.to_text(l)
self.assertTrue(isinstance(l2, list))
self.assertFalse(l is l2)
self.assertEqual(l, l2)
self.assertEqual([u""], l2)
l = ["", "123"]
l2 = cpt.to_text(l)
self.assertTrue(isinstance(l2, list))
self.assertFalse(l is l2)
self.assertEqual(l, l2)
self.assertEqual([u"", u"123"], l2)
l = ["", b'123', u"321"]
l2 = cpt.to_text(l)
self.assertTrue(isinstance(l2, list))
self.assertFalse(l is l2)
self.assertEqual(l, l2)
self.assertEqual([u"", u"123", u"321"], l2)
for i in l2:
self.assertTrue(isinstance(i, unicode))
# check list types, inplace
l = [""]
l2 = cpt.to_text(l, inplace=True)
self.assertTrue(isinstance(l2, list))
self.assertTrue(l is l2)
self.assertEqual(l, l2)
self.assertEqual([u""], l2)
l = ["", "123"]
l2 = cpt.to_text(l, inplace=True)
self.assertTrue(isinstance(l2, list))
self.assertTrue(l is l2)
self.assertEqual(l, l2)
self.assertEqual([u"", u"123"], l2)
l = ["", b"123", u"321"]
l2 = cpt.to_text(l, inplace=True)
self.assertTrue(isinstance(l2, list))
self.assertTrue(l is l2)
self.assertEqual(l, l2)
self.assertEqual([u"", u"123", u"321"], l2)
# check set types, not inplace
l = set("")
l2 = cpt.to_text(l, inplace=False)
self.assertTrue(isinstance(l2, set))
self.assertFalse(l is l2)
self.assertEqual(l, l2)
self.assertEqual(set(u""), l2)
l = set([b"", b"123"])
l2 = cpt.to_text(l, inplace=False)
self.assertTrue(isinstance(l2, set))
self.assertFalse(l is l2)
self.assertEqual(l, l2)
self.assertEqual(set([u"", u"123"]), l2)
l = set(["", b"123", u"321"])
l2 = cpt.to_text(l, inplace=False)
self.assertTrue(isinstance(l2, set))
self.assertFalse(l is l2)
self.assertEqual(l, l2)
self.assertEqual(set([u"", u"123", u"321"]), l2)
for i in l2:
self.assertTrue(isinstance(i, unicode))
# check set types, inplace
l = set("")
l2 = cpt.to_text(l, inplace=True)
self.assertTrue(isinstance(l2, set))
self.assertTrue(l is l2)
self.assertEqual(l, l2)
self.assertEqual(set(u""), l2)
l = set([b"", b"123"])
l2 = cpt.to_text(l, inplace=True)
self.assertTrue(isinstance(l2, set))
self.assertTrue(l is l2)
self.assertEqual(l, l2)
self.assertEqual(set([u"", u"123"]), l2)
l = set(["", b"123", u"321"])
l2 = cpt.to_text(l, inplace=True)
self.assertTrue(isinstance(l2, set))
self.assertTrue(l is l2)
self.assertEqual(l, l2)
self.assertEqual(set([u"", u"123", u"321"]), l2)
# check dict types, not inplace
l = {"": ""}
l2 = cpt.to_text(l, inplace=False)
self.assertTrue(isinstance(l2, dict))
self.assertFalse(l is l2)
self.assertEqual(l, l2)
self.assertEqual({"": ""}, l2)
# check dict types, inplace
l = {"": ""}
l2 = cpt.to_text(l, inplace=True)
self.assertTrue(isinstance(l2, dict))
self.assertTrue(l is l2)
self.assertEqual(l, l2)
self.assertEqual({"": ""}, l2)
elif six.PY3:
self.assertIsNone(cpt.to_text(None))
self.assertTrue(isinstance(cpt.to_text(str("")), str))
......@@ -269,113 +141,6 @@ class TestCompatible(unittest.TestCase):
self.assertEqual({"": ""}, l2)
def test_to_bytes(self):
# Only support python2.x and python3.x now
self.assertTrue(six.PY2 | six.PY3)
if six.PY2:
# check None
self.assertIsNone(cpt.to_bytes(None))
# check all string related types
self.assertTrue(isinstance(cpt.to_bytes(str("")), bytes))
self.assertTrue(isinstance(cpt.to_bytes(str("123")), bytes))
self.assertTrue(isinstance(cpt.to_bytes(b""), bytes))
self.assertTrue(isinstance(cpt.to_bytes(b""), bytes))
self.assertTrue(isinstance(cpt.to_bytes(u""), bytes))
self.assertTrue(isinstance(cpt.to_bytes(u""), bytes))
self.assertEqual(b"", cpt.to_bytes(str("")))
self.assertEqual(b"123", cpt.to_bytes(str("123")))
self.assertEqual(b"", cpt.to_bytes(b""))
self.assertEqual(b"123", cpt.to_bytes(b"123"))
self.assertEqual(b"", cpt.to_bytes(u""))
self.assertEqual(b"123", cpt.to_bytes(u"123"))
# check list types, not inplace
l = [""]
l2 = cpt.to_bytes(l)
self.assertTrue(isinstance(l2, list))
self.assertFalse(l is l2)
self.assertEqual(l, l2)
self.assertEqual([b""], l2)
l = ["", "123"]
l2 = cpt.to_bytes(l)
self.assertTrue(isinstance(l2, list))
self.assertFalse(l is l2)
self.assertEqual(l, l2)
self.assertEqual([b"", b"123"], l2)
l = ["", b'123', u"321"]
l2 = cpt.to_bytes(l)
self.assertTrue(isinstance(l2, list))
self.assertFalse(l is l2)
self.assertEqual(l, l2)
self.assertEqual([b"", b"123", b"321"], l2)
for i in l2:
self.assertTrue(isinstance(i, bytes))
# check list types, inplace
l = [""]
l2 = cpt.to_bytes(l, inplace=True)
self.assertTrue(isinstance(l2, list))
self.assertTrue(l is l2)
self.assertEqual(l, l2)
self.assertEqual([b""], l2)
l = ["", "123"]
l2 = cpt.to_bytes(l, inplace=True)
self.assertTrue(isinstance(l2, list))
self.assertTrue(l is l2)
self.assertEqual(l, l2)
self.assertEqual([b"", b"123"], l2)
l = ["", b"123", u"321"]
l2 = cpt.to_bytes(l, inplace=True)
self.assertTrue(isinstance(l2, list))
self.assertTrue(l is l2)
self.assertEqual(l, l2)
self.assertEqual([b"", b"123", b"321"], l2)
# check set types, not inplace
l = set("")
l2 = cpt.to_bytes(l, inplace=False)
self.assertTrue(isinstance(l2, set))
self.assertFalse(l is l2)
self.assertEqual(l, l2)
self.assertEqual(set(b""), l2)
l = set([b"", b"123"])
l2 = cpt.to_bytes(l, inplace=False)
self.assertTrue(isinstance(l2, set))
self.assertFalse(l is l2)
self.assertEqual(l, l2)
self.assertEqual(set([b"", b"123"]), l2)
l = set(["", b"123", u"321"])
l2 = cpt.to_bytes(l, inplace=False)
self.assertTrue(isinstance(l2, set))
self.assertFalse(l is l2)
self.assertEqual(l, l2)
self.assertEqual(set([b"", b"123", b"321"]), l2)
for i in l2:
self.assertTrue(isinstance(i, bytes))
# check set types, inplace
l = set("")
l2 = cpt.to_bytes(l, inplace=True)
self.assertTrue(isinstance(l2, set))
self.assertTrue(l is l2)
self.assertEqual(l, l2)
self.assertEqual(set(b""), l2)
l = set([b"", b"123"])
l2 = cpt.to_bytes(l, inplace=True)
self.assertTrue(isinstance(l2, set))
self.assertTrue(l is l2)
self.assertEqual(l, l2)
self.assertEqual(set([b"", b"123"]), l2)
l = set(["", b"123", u"321"])
l2 = cpt.to_bytes(l, inplace=True)
self.assertTrue(isinstance(l2, set))
self.assertTrue(l is l2)
self.assertEqual(l, l2)
self.assertEqual(set([b"", b"123", b"321"]), l2)
elif six.PY3:
self.assertIsNone(cpt.to_bytes(None))
self.assertTrue(isinstance(cpt.to_bytes(str("")), bytes))
......@@ -500,36 +265,16 @@ class TestCompatible(unittest.TestCase):
def test_get_exception_message(self):
exception_message = "test_message"
self.assertRaises(AssertionError, cpt.get_exception_message, None)
if six.PY2:
self.assertRaises(AttributeError, cpt.get_exception_message,
exception_message)
try:
raise RuntimeError(exception_message)
except Exception as e:
self.assertEqual(exception_message,
cpt.get_exception_message(e))
self.assertIsNotNone(e)
try:
raise Exception(exception_message)
except Exception as e:
self.assertEqual(exception_message,
cpt.get_exception_message(e))
self.assertIsNotNone(e)
if six.PY3:
try:
raise RuntimeError(exception_message)
except Exception as e:
self.assertEqual(exception_message,
cpt.get_exception_message(e))
self.assertEqual(exception_message, cpt.get_exception_message(e))
self.assertIsNotNone(e)
try:
raise Exception(exception_message)
except Exception as e:
self.assertEqual(exception_message,
cpt.get_exception_message(e))
self.assertEqual(exception_message, cpt.get_exception_message(e))
self.assertIsNotNone(e)
......
......@@ -44,18 +44,12 @@ DIST_UT_PORT = 0
def print_to_out(out_losses):
if six.PY2:
print(pickle.dumps(out_losses))
else:
sys.stdout.buffer.write(pickle.dumps(out_losses))
def print_to_err(class_name, log_str):
localtime = time.asctime(time.localtime(time.time()))
print_str = localtime + "\t" + class_name + "\t" + log_str
if six.PY2:
sys.stderr.write(pickle.dumps(print_str))
else:
sys.stderr.buffer.write(pickle.dumps(print_str))
......@@ -151,9 +145,6 @@ class TestDistRunnerBase(object):
print_to_err(type(self).__name__, "run step %d finished" % i)
print_to_err(type(self).__name__, "trainer run finished")
if six.PY2:
print(pickle.dumps(out_losses))
else:
sys.stdout.buffer.write(pickle.dumps(out_losses))
if args.save_model:
......@@ -251,9 +242,6 @@ class TestDistRunnerBase(object):
print_to_err(type(self).__name__, "trainer run finished")
print_to_err(type(self).__name__, "dist losses: {}".format(out_losses))
if six.PY2:
print(pickle.dumps(out_losses))
else:
sys.stdout.buffer.write(pickle.dumps(out_losses))
def run_use_fleet_api_trainer(self, args):
......@@ -338,9 +326,6 @@ class TestDistRunnerBase(object):
print_to_err(type(self).__name__, "run step %d finished" % i)
print_to_err(type(self).__name__, "trainer run finished")
if six.PY2:
print(pickle.dumps(out_losses))
else:
sys.stdout.buffer.write(pickle.dumps(out_losses))
if args.save_model:
......
......@@ -18,7 +18,6 @@ import unittest
import paddle
import paddle.fluid as fluid
import numpy as np
import six
import inspect
......@@ -241,9 +240,6 @@ class TestMathOpPatchesVarBase(unittest.TestCase):
a = fluid.dygraph.to_variable(np.array([100.1]))
self.assertTrue(float(a) == 100.1)
self.assertTrue(int(a) == 100)
if six.PY2:
self.assertTrue(long(a) == 100)
else:
self.assertTrue(int(a) == 100)
def test_len(self):
......
......@@ -18,7 +18,6 @@ import unittest
import numpy as np
import os
import sys
import six
from io import BytesIO
import paddle
......@@ -38,10 +37,7 @@ SEED = 10
IMAGE_SIZE = 784
CLASS_NUM = 10
if six.PY2:
LARGE_PARAM = 2**2
else:
LARGE_PARAM = 2**26
LARGE_PARAM = 2**26
def random_batch_reader():
......@@ -105,9 +101,6 @@ class TestSaveLoadLargeParameters(unittest.TestCase):
path = os.path.join("test_paddle_save_load_large_param_save",
"layer.pdparams")
if six.PY2:
protocol = 2
else:
protocol = 4
paddle.save(save_dict, path, protocol=protocol)
dict_load = paddle.load(path)
......@@ -926,9 +919,6 @@ class TestSaveLoadProgram(unittest.TestCase):
class TestSaveLoadLayer(unittest.TestCase):
def test_save_load_layer(self):
if six.PY2:
return
paddle.disable_static()
inps = paddle.randn([1, IMAGE_SIZE], dtype='float32')
layer1 = LinearNet()
......
......@@ -21,15 +21,10 @@ import paddle.fluid.framework as framework
from test_imperative_base import new_program_scope
import numpy as np
import six
import pickle
import os
# Python2.x no longer supports saving and loading large parameters.
if six.PY2:
LARGE_PARAM = 2
else:
LARGE_PARAM = 2**26
LARGE_PARAM = 2**26
class TestStaticSaveLoadLargeParameters(unittest.TestCase):
......@@ -59,9 +54,6 @@ class TestStaticSaveLoadLargeParameters(unittest.TestCase):
path = os.path.join("test_static_save_load_large_param",
"static_save")
if six.PY2:
protocol = 2
else:
protocol = 4
paddle.fluid.save(prog, path, pickle_protocol=protocol)
# set var to zero
......
......@@ -15,7 +15,6 @@
import numpy as np
import paddle
import paddle.fluid as fluid
import six
import unittest
import paddle.nn as nn
import os
......@@ -50,9 +49,6 @@ class TestTracedLayerErrMsg(unittest.TestCase):
self.feature_size = 3
self.fc_size = 2
self.layer = self._train_simple_net()
if six.PY2:
self.type_str = 'type'
else:
self.type_str = 'class'
def test_trace_err(self):
......
......@@ -192,7 +192,7 @@ class TestVarBase(unittest.TestCase):
x = paddle.to_tensor(1, dtype='int64')
self.assertEqual(x.item(), 1)
self.assertTrue(isinstance(x.item(), long if six.PY2 else int))
self.assertTrue(isinstance(x.item(), int))
x = paddle.to_tensor(True)
self.assertEqual(x.item(), True)
......
......@@ -17,14 +17,10 @@ from __future__ import print_function
import os
import collections
import pickle
import six
import warnings
import sys
import numpy as np
if not six.PY2:
import copyreg
import copyreg
import paddle
# deprecated module import
......@@ -295,11 +291,6 @@ def _pickle_save(obj, f, protocol):
max_bytes = 2**30
for i in range(0, len(pickle_bytes), max_bytes):
f.write(pickle_bytes[i:i + max_bytes])
else:
if six.PY2:
add_dispatch_table()
pickle_bytes = pickle.dump(obj, f, protocol)
pop_dispatch_table()
else:
pickler = pickle.Pickler(f, protocol)
pickler.dispatch_table = copyreg.dispatch_table.copy()
......@@ -359,9 +350,6 @@ def _transformed_from_varbase(obj):
# In paddle2.1 version, VarBase is saved as tuple(tensor.name, tensor.numpy()).
# When executing paddle.load, use this function to determine whether to restore to VarBase/LoDTensor.
if isinstance(obj, tuple) and len(obj) == 2:
if six.PY2:
name_types = (str, unicode)
else:
name_types = str
if isinstance(obj[0], name_types) and isinstance(obj[1], np.ndarray):
return True
......@@ -947,9 +935,6 @@ def load(path, **configs):
if _is_memory_buffer(path) or os.path.isfile(path):
config = _parse_load_config(configs)
if six.PY2:
exception_type = KeyError
else:
exception_type = pickle.UnpicklingError
try:
with _open_file_buffer(path, 'rb') as f:
......@@ -959,8 +944,7 @@ def load(path, **configs):
) and sys.platform == 'darwin' and sys.version_info.major == 3:
load_result = _pickle_loads_mac(path, f)
else:
load_result = pickle.load(f) if six.PY2 else pickle.load(
f, encoding='latin1')
load_result = pickle.load(f, encoding='latin1')
# TODO(weixin):If `obj` is any object, the judgment condition should be more precise.
if isinstance(load_result, dict):
......@@ -1021,8 +1005,7 @@ def _legacy_load(path, **configs):
if os.path.isfile(path) or _is_memory_buffer(path):
# we think path is file means this file is created by paddle.save
with _open_file_buffer(path, 'rb') as f:
load_result = pickle.load(f) if six.PY2 else pickle.load(
f, encoding='latin1')
load_result = pickle.load(f, encoding='latin1')
load_result = _pack_loaded_dict(load_result)
if not config.keep_name_table and "StructuredToParameterName@@" in load_result:
del load_result["StructuredToParameterName@@"]
......
......@@ -1296,8 +1296,7 @@ class Model(object):
if not os.path.exists(path):
return
with open(path, 'rb') as f:
return pickle.load(f) if six.PY2 else pickle.load(
f, encoding='latin1')
return pickle.load(f, encoding='latin1')
def _check_match(key, param):
state = param_state.get(key, None)
......
......@@ -21,7 +21,6 @@ from ..fluid.data_feeder import convert_dtype, check_variable_and_dtype, check_t
from ..fluid.layers.tensor import fill_constant
from ..fluid.layers import utils
import numpy as np
import six
# TODO: define functions to manipulate a tensor
from ..fluid.layers import cast # noqa: F401
from ..fluid.layers import slice # noqa: F401
......@@ -1218,10 +1217,7 @@ def tile(x, repeat_times, name=None):
assert len(elem.shape) == 1, (
'Elements in repeat_times must be 1-D Tensors or integers.')
else:
if six.PY3:
type_tuple = (int, np.int32, np.int64)
elif six.PY2:
type_tuple = (int, long, np.int32, np.int64)
assert isinstance(elem, type_tuple), (
'Elements in repeat_times must be 1-D Tensors or integers.')
......@@ -1357,10 +1353,7 @@ def broadcast_to(x, shape, name=None):
assert len(elem.shape) == 1, (
'Elements in shape must be 1-D Tensors or integers.')
else:
if six.PY3:
type_tuple = (int, np.int32, np.int64)
elif six.PY2:
type_tuple = (int, long, np.int32, np.int64)
assert isinstance(elem, type_tuple), (
'Elements in shape must be 1-D Tensors or integers.')
......@@ -1447,10 +1440,7 @@ def expand(x, shape, name=None):
assert len(elem.shape) == 1, (
'Elements in shape must be 1-D Tensors or integers.')
else:
if six.PY3:
type_tuple = (int, np.int32, np.int64)
elif six.PY2:
type_tuple = (int, long, np.int32, np.int64)
assert isinstance(elem, type_tuple), (
'Elements in shape must be 1-D Tensors or integers.')
......
......@@ -14,7 +14,6 @@
import os
import re
import six
import sys
import json
import glob
......@@ -541,7 +540,6 @@ def find_cuda_home():
with open(os.devnull, 'w') as devnull:
nvcc_path = subprocess.check_output(
[which_cmd, 'nvcc'], stderr=devnull)
if six.PY3:
nvcc_path = nvcc_path.decode()
# Multi CUDA, select the first
nvcc_path = nvcc_path.split('\r\n')[0]
......@@ -580,7 +578,6 @@ def find_rocm_home():
with open(os.devnull, 'w') as devnull:
hipcc_path = subprocess.check_output(
[which_cmd, 'hipcc'], stderr=devnull)
if six.PY3:
hipcc_path = hipcc_path.decode()
hipcc_path = hipcc_path.rstrip('\r\n')
......@@ -652,7 +649,6 @@ def find_clang_cpp_include(compiler='clang'):
std_v1_includes = None
try:
compiler_version = subprocess.check_output([compiler, "--version"])
if six.PY3:
compiler_version = compiler_version.decode()
infos = compiler_version.split("\n")
for info in infos:
......@@ -895,10 +891,6 @@ def _load_module_from_file(api_file_path, verbose=False):
# Unique readable module name to place custom api.
log_v('import module from file: {}'.format(api_file_path), verbose)
ext_name = "_paddle_cpp_extension_"
if six.PY2:
import imp
module = imp.load_source(ext_name, api_file_path)
else:
from importlib import machinery
loader = machinery.SourceFileLoader(ext_name, api_file_path)
module = loader.load_module()
......@@ -1005,7 +997,6 @@ def _jit_compile(file_path, verbose=False):
try:
py_version = subprocess.check_output([interpreter, '-V'])
if six.PY3:
py_version = py_version.decode()
log_v("Using Python interpreter: {}, version: {}".format(
interpreter, py_version.strip()), verbose)
......@@ -1083,8 +1074,7 @@ def check_abi_compatibility(compiler, verbose=False):
if not IS_WINDOWS:
cmd_out = subprocess.check_output(
['which', compiler], stderr=subprocess.STDOUT)
compiler_path = os.path.realpath(cmd_out.decode()
if six.PY3 else cmd_out).strip()
compiler_path = os.path.realpath(cmd_out.decode()).strip()
# if not found any suitable compiler, raise warning
if not any(name in compiler_path
for name in _expected_compiler_current_platform()):
......@@ -1104,14 +1094,12 @@ def check_abi_compatibility(compiler, verbose=False):
mini_required_version = GCC_MINI_VERSION
version_info = subprocess.check_output(
[compiler, '-dumpfullversion', '-dumpversion'])
if six.PY3:
version_info = version_info.decode()
version = version_info.strip().split('.')
elif IS_WINDOWS:
mini_required_version = MSVC_MINI_VERSION
compiler_info = subprocess.check_output(
compiler, stderr=subprocess.STDOUT)
if six.PY3:
try:
compiler_info = compiler_info.decode('UTF-8')
except UnicodeDecodeError:
......
......@@ -141,9 +141,6 @@ class Cifar10(Dataset):
if self.flag in each_item.name)
for name in names:
if six.PY2:
batch = pickle.load(f.extractfile(name))
else:
batch = pickle.load(f.extractfile(name), encoding='bytes')
data = batch[six.b('data')]
......
......@@ -20,7 +20,6 @@ import collections
import sys
import pydoc
import hashlib
import six
import functools
import platform
......@@ -104,7 +103,7 @@ def visit_member(parent_name, member, func):
def is_primitive(instance):
int_types = (int, long) if six.PY2 else (int, )
int_types = (int, )
pritimitive_types = int_types + (float, str)
if isinstance(instance, pritimitive_types):
return True
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册