Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
0f7187af
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
0f7187af
编写于
6月 21, 2021
作者:
T
tianshuo78520a
提交者:
GitHub
6月 21, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Del six.PY code2 (#33607)
* del py2 code2 * fix test timeout
上级
79cbc8ea
变更
38
隐藏空白更改
内联
并排
Showing
38 changed file
with
337 addition
and
813 deletion
+337
-813
python/paddle/compat.py
python/paddle/compat.py
+3
-10
python/paddle/dataset/cifar.py
python/paddle/dataset/cifar.py
+1
-5
python/paddle/dataset/common.py
python/paddle/dataset/common.py
+0
-2
python/paddle/dataset/flowers.py
python/paddle/dataset/flowers.py
+1
-4
python/paddle/distributed/fleet/utils/http_server.py
python/paddle/distributed/fleet/utils/http_server.py
+2
-6
python/paddle/fluid/contrib/slim/tests/quant2_int8_image_classification_comparison.py
...slim/tests/quant2_int8_image_classification_comparison.py
+1
-4
python/paddle/fluid/contrib/slim/tests/quant_int8_image_classification_comparison.py
.../slim/tests/quant_int8_image_classification_comparison.py
+1
-4
python/paddle/fluid/dataloader/dataloader_iter.py
python/paddle/fluid/dataloader/dataloader_iter.py
+1
-4
python/paddle/fluid/dataloader/worker.py
python/paddle/fluid/dataloader/worker.py
+1
-4
python/paddle/fluid/dygraph/checkpoint.py
python/paddle/fluid/dygraph/checkpoint.py
+2
-5
python/paddle/fluid/dygraph/dygraph_to_static/utils.py
python/paddle/fluid/dygraph/dygraph_to_static/utils.py
+4
-14
python/paddle/fluid/dygraph/dygraph_to_static/variable_trans_func.py
...le/fluid/dygraph/dygraph_to_static/variable_trans_func.py
+3
-11
python/paddle/fluid/dygraph/math_op_patch.py
python/paddle/fluid/dygraph/math_op_patch.py
+2
-9
python/paddle/fluid/io.py
python/paddle/fluid/io.py
+5
-10
python/paddle/fluid/layers/tensor.py
python/paddle/fluid/layers/tensor.py
+6
-18
python/paddle/fluid/multiprocess_utils.py
python/paddle/fluid/multiprocess_utils.py
+1
-5
python/paddle/fluid/reader.py
python/paddle/fluid/reader.py
+1
-4
python/paddle/fluid/tests/unittests/dist_save_load.py
python/paddle/fluid/tests/unittests/dist_save_load.py
+2
-8
python/paddle/fluid/tests/unittests/dist_sharding_save.py
python/paddle/fluid/tests/unittests/dist_sharding_save.py
+1
-5
python/paddle/fluid/tests/unittests/dist_text_classification.py
.../paddle/fluid/tests/unittests/dist_text_classification.py
+3
-8
python/paddle/fluid/tests/unittests/dygraph_to_static/test_logging_utils.py
...d/tests/unittests/dygraph_to_static/test_logging_utils.py
+20
-31
python/paddle/fluid/tests/unittests/dygraph_to_static/test_variable_trans_func.py
...s/unittests/dygraph_to_static/test_variable_trans_func.py
+3
-13
python/paddle/fluid/tests/unittests/npu/test_collective_base_npu.py
...dle/fluid/tests/unittests/npu/test_collective_base_npu.py
+1
-6
python/paddle/fluid/tests/unittests/test_collective_api_base.py
.../paddle/fluid/tests/unittests/test_collective_api_base.py
+1
-6
python/paddle/fluid/tests/unittests/test_collective_base.py
python/paddle/fluid/tests/unittests/test_collective_base.py
+1
-7
python/paddle/fluid/tests/unittests/test_compat.py
python/paddle/fluid/tests/unittests/test_compat.py
+227
-482
python/paddle/fluid/tests/unittests/test_dist_base.py
python/paddle/fluid/tests/unittests/test_dist_base.py
+5
-20
python/paddle/fluid/tests/unittests/test_math_op_patch_var_base.py
...ddle/fluid/tests/unittests/test_math_op_patch_var_base.py
+1
-5
python/paddle/fluid/tests/unittests/test_paddle_save_load.py
python/paddle/fluid/tests/unittests/test_paddle_save_load.py
+2
-12
python/paddle/fluid/tests/unittests/test_static_save_load_large.py
...ddle/fluid/tests/unittests/test_static_save_load_large.py
+2
-10
python/paddle/fluid/tests/unittests/test_traced_layer_err_msg.py
...paddle/fluid/tests/unittests/test_traced_layer_err_msg.py
+1
-5
python/paddle/fluid/tests/unittests/test_var_base.py
python/paddle/fluid/tests/unittests/test_var_base.py
+1
-1
python/paddle/framework/io.py
python/paddle/framework/io.py
+12
-29
python/paddle/hapi/model.py
python/paddle/hapi/model.py
+1
-2
python/paddle/tensor/manipulation.py
python/paddle/tensor/manipulation.py
+3
-13
python/paddle/utils/cpp_extension/extension_utils.py
python/paddle/utils/cpp_extension/extension_utils.py
+13
-25
python/paddle/vision/datasets/cifar.py
python/paddle/vision/datasets/cifar.py
+1
-4
tools/count_api_without_core_ops.py
tools/count_api_without_core_ops.py
+1
-2
未找到文件。
python/paddle/compat.py
浏览文件 @
0f7187af
...
@@ -17,12 +17,8 @@ import math
...
@@ -17,12 +17,8 @@ import math
__all__
=
[]
__all__
=
[]
if
six
.
PY2
:
int_type
=
int
int_type
=
int
long_type
=
int
long_type
=
long
# noqa: F821
else
:
int_type
=
int
long_type
=
int
# str and bytes related functions
# str and bytes related functions
...
@@ -262,7 +258,4 @@ def get_exception_message(exc):
...
@@ -262,7 +258,4 @@ def get_exception_message(exc):
"""
"""
assert
exc
is
not
None
assert
exc
is
not
None
if
six
.
PY2
:
return
str
(
exc
)
return
exc
.
message
else
:
return
str
(
exc
)
python/paddle/dataset/cifar.py
浏览文件 @
0f7187af
...
@@ -62,11 +62,7 @@ def reader_creator(filename, sub_name, cycle=False):
...
@@ -62,11 +62,7 @@ def reader_creator(filename, sub_name, cycle=False):
if
sub_name
in
each_item
.
name
)
if
sub_name
in
each_item
.
name
)
for
name
in
names
:
for
name
in
names
:
if
six
.
PY2
:
batch
=
pickle
.
load
(
f
.
extractfile
(
name
),
encoding
=
'bytes'
)
batch
=
pickle
.
load
(
f
.
extractfile
(
name
))
else
:
batch
=
pickle
.
load
(
f
.
extractfile
(
name
),
encoding
=
'bytes'
)
for
item
in
read_batch
(
batch
):
for
item
in
read_batch
(
batch
):
yield
item
yield
item
...
...
python/paddle/dataset/common.py
浏览文件 @
0f7187af
...
@@ -101,8 +101,6 @@ def download(url, module_name, md5sum, save_name=None):
...
@@ -101,8 +101,6 @@ def download(url, module_name, md5sum, save_name=None):
bar
=
paddle
.
hapi
.
progressbar
.
ProgressBar
(
bar
=
paddle
.
hapi
.
progressbar
.
ProgressBar
(
total_iter
,
name
=
'item'
)
total_iter
,
name
=
'item'
)
for
data
in
r
.
iter_content
(
chunk_size
=
chunk_size
):
for
data
in
r
.
iter_content
(
chunk_size
=
chunk_size
):
if
six
.
PY2
:
data
=
six
.
b
(
data
)
f
.
write
(
data
)
f
.
write
(
data
)
log_index
+=
1
log_index
+=
1
bar
.
update
(
log_index
,
{})
bar
.
update
(
log_index
,
{})
...
...
python/paddle/dataset/flowers.py
浏览文件 @
0f7187af
...
@@ -132,10 +132,7 @@ def reader_creator(data_file,
...
@@ -132,10 +132,7 @@ def reader_creator(data_file,
file
=
file
.
strip
()
file
=
file
.
strip
()
batch
=
None
batch
=
None
with
open
(
file
,
'rb'
)
as
f
:
with
open
(
file
,
'rb'
)
as
f
:
if
six
.
PY2
:
batch
=
pickle
.
load
(
f
,
encoding
=
'bytes'
)
batch
=
pickle
.
load
(
f
)
else
:
batch
=
pickle
.
load
(
f
,
encoding
=
'bytes'
)
if
six
.
PY3
:
if
six
.
PY3
:
batch
=
cpt
.
to_text
(
batch
)
batch
=
cpt
.
to_text
(
batch
)
...
...
python/paddle/distributed/fleet/utils/http_server.py
浏览文件 @
0f7187af
...
@@ -17,12 +17,8 @@ import logging
...
@@ -17,12 +17,8 @@ import logging
import
six
import
six
# NOTE: HTTPServer has a different name in python2 and python3
# NOTE: HTTPServer has a different name in python2 and python3
if
six
.
PY2
:
from
http.server
import
HTTPServer
from
BaseHTTPServer
import
HTTPServer
import
http.server
as
SimpleHTTPServer
import
SimpleHTTPServer
else
:
from
http.server
import
HTTPServer
import
http.server
as
SimpleHTTPServer
import
time
import
time
import
threading
import
threading
...
...
python/paddle/fluid/contrib/slim/tests/quant2_int8_image_classification_comparison.py
浏览文件 @
0f7187af
...
@@ -226,10 +226,7 @@ class Quant2Int8ImageClassificationComparisonTest(unittest.TestCase):
...
@@ -226,10 +226,7 @@ class Quant2Int8ImageClassificationComparisonTest(unittest.TestCase):
if
iters
==
skip_batch_num
:
if
iters
==
skip_batch_num
:
total_samples
=
0
total_samples
=
0
infer_start_time
=
time
.
time
()
infer_start_time
=
time
.
time
()
if
six
.
PY2
:
images
=
list
(
map
(
lambda
x
:
x
[
0
].
reshape
(
dshape
),
data
))
images
=
map
(
lambda
x
:
x
[
0
].
reshape
(
dshape
),
data
)
if
six
.
PY3
:
images
=
list
(
map
(
lambda
x
:
x
[
0
].
reshape
(
dshape
),
data
))
images
=
np
.
array
(
images
).
astype
(
'float32'
)
images
=
np
.
array
(
images
).
astype
(
'float32'
)
labels
=
np
.
array
([
x
[
1
]
for
x
in
data
]).
astype
(
'int64'
)
labels
=
np
.
array
([
x
[
1
]
for
x
in
data
]).
astype
(
'int64'
)
...
...
python/paddle/fluid/contrib/slim/tests/quant_int8_image_classification_comparison.py
浏览文件 @
0f7187af
...
@@ -196,10 +196,7 @@ class QuantInt8ImageClassificationComparisonTest(unittest.TestCase):
...
@@ -196,10 +196,7 @@ class QuantInt8ImageClassificationComparisonTest(unittest.TestCase):
if
iters
==
skip_batch_num
:
if
iters
==
skip_batch_num
:
total_samples
=
0
total_samples
=
0
infer_start_time
=
time
.
time
()
infer_start_time
=
time
.
time
()
if
six
.
PY2
:
images
=
list
(
map
(
lambda
x
:
x
[
0
].
reshape
(
dshape
),
data
))
images
=
map
(
lambda
x
:
x
[
0
].
reshape
(
dshape
),
data
)
if
six
.
PY3
:
images
=
list
(
map
(
lambda
x
:
x
[
0
].
reshape
(
dshape
),
data
))
images
=
np
.
array
(
images
).
astype
(
'float32'
)
images
=
np
.
array
(
images
).
astype
(
'float32'
)
labels
=
np
.
array
([
x
[
1
]
for
x
in
data
]).
astype
(
'int64'
)
labels
=
np
.
array
([
x
[
1
]
for
x
in
data
]).
astype
(
'int64'
)
...
...
python/paddle/fluid/dataloader/dataloader_iter.py
浏览文件 @
0f7187af
...
@@ -27,10 +27,7 @@ from collections import namedtuple
...
@@ -27,10 +27,7 @@ from collections import namedtuple
from
paddle.fluid.framework
import
_set_expected_place
,
_current_expected_place
from
paddle.fluid.framework
import
_set_expected_place
,
_current_expected_place
# NOTE: queue has a different name in python2 and python3
# NOTE: queue has a different name in python2 and python3
if
six
.
PY2
:
import
queue
import
Queue
as
queue
else
:
import
queue
import
paddle
import
paddle
from
..
import
core
,
layers
from
..
import
core
,
layers
...
...
python/paddle/fluid/dataloader/worker.py
浏览文件 @
0f7187af
...
@@ -26,10 +26,7 @@ from ..framework import in_dygraph_mode
...
@@ -26,10 +26,7 @@ from ..framework import in_dygraph_mode
from
.flat
import
_flatten_batch
from
.flat
import
_flatten_batch
# NOTE: queue has a different name in python2 and python3
# NOTE: queue has a different name in python2 and python3
if
six
.
PY2
:
import
queue
import
Queue
as
queue
else
:
import
queue
__all__
=
[
'get_worker_info'
]
__all__
=
[
'get_worker_info'
]
...
...
python/paddle/fluid/dygraph/checkpoint.py
浏览文件 @
0f7187af
...
@@ -19,7 +19,6 @@ import collections
...
@@ -19,7 +19,6 @@ import collections
import
functools
import
functools
from
..framework
import
Variable
,
default_main_program
,
in_dygraph_mode
,
dygraph_only
,
Parameter
,
ParamBase
,
_varbase_creator
,
_dygraph_tracer
from
..framework
import
Variable
,
default_main_program
,
in_dygraph_mode
,
dygraph_only
,
Parameter
,
ParamBase
,
_varbase_creator
,
_dygraph_tracer
import
pickle
import
pickle
import
six
from
.
import
learning_rate_scheduler
from
.
import
learning_rate_scheduler
import
warnings
import
warnings
from
..
import
core
from
..
import
core
...
@@ -194,16 +193,14 @@ def load_dygraph(model_path, **configs):
...
@@ -194,16 +193,14 @@ def load_dygraph(model_path, **configs):
para_dict
=
{}
para_dict
=
{}
if
os
.
path
.
exists
(
params_file_path
):
if
os
.
path
.
exists
(
params_file_path
):
with
open
(
params_file_path
,
'rb'
)
as
f
:
with
open
(
params_file_path
,
'rb'
)
as
f
:
para_dict
=
pickle
.
load
(
f
)
if
six
.
PY2
else
pickle
.
load
(
para_dict
=
pickle
.
load
(
f
,
encoding
=
'latin1'
)
f
,
encoding
=
'latin1'
)
if
not
config
.
keep_name_table
and
"StructuredToParameterName@@"
in
para_dict
:
if
not
config
.
keep_name_table
and
"StructuredToParameterName@@"
in
para_dict
:
del
para_dict
[
"StructuredToParameterName@@"
]
del
para_dict
[
"StructuredToParameterName@@"
]
if
os
.
path
.
exists
(
opti_file_path
):
if
os
.
path
.
exists
(
opti_file_path
):
with
open
(
opti_file_path
,
'rb'
)
as
f
:
with
open
(
opti_file_path
,
'rb'
)
as
f
:
opti_dict
=
pickle
.
load
(
f
)
if
six
.
PY2
else
pickle
.
load
(
opti_dict
=
pickle
.
load
(
f
,
encoding
=
'latin1'
)
f
,
encoding
=
'latin1'
)
else
:
else
:
# check model path
# check model path
if
not
os
.
path
.
isdir
(
model_prefix
):
if
not
os
.
path
.
isdir
(
model_prefix
):
...
...
python/paddle/fluid/dygraph/dygraph_to_static/utils.py
浏览文件 @
0f7187af
...
@@ -60,10 +60,7 @@ class BaseNodeVisitor(gast.NodeVisitor):
...
@@ -60,10 +60,7 @@ class BaseNodeVisitor(gast.NodeVisitor):
# imp is deprecated in python3
# imp is deprecated in python3
if
six
.
PY2
:
from
importlib.machinery
import
SourceFileLoader
import
imp
else
:
from
importlib.machinery
import
SourceFileLoader
dygraph_class_to_static_api
=
{
dygraph_class_to_static_api
=
{
"CosineDecay"
:
"cosine_decay"
,
"CosineDecay"
:
"cosine_decay"
,
...
@@ -491,12 +488,8 @@ def ast_to_func(ast_root, dyfunc, delete_on_exit=True):
...
@@ -491,12 +488,8 @@ def ast_to_func(ast_root, dyfunc, delete_on_exit=True):
import_fluid
=
"import paddle
\n
import paddle.fluid as fluid
\n
"
import_fluid
=
"import paddle
\n
import paddle.fluid as fluid
\n
"
source
=
import_fluid
+
source
source
=
import_fluid
+
source
if
six
.
PY2
:
f
=
tempfile
.
NamedTemporaryFile
(
source
=
source
.
encode
(
'utf-8'
)
mode
=
'w'
,
suffix
=
'.py'
,
delete
=
False
,
encoding
=
'utf-8'
)
f
=
tempfile
.
NamedTemporaryFile
(
mode
=
'w'
,
suffix
=
'.py'
,
delete
=
False
)
else
:
f
=
tempfile
.
NamedTemporaryFile
(
mode
=
'w'
,
suffix
=
'.py'
,
delete
=
False
,
encoding
=
'utf-8'
)
with
f
:
with
f
:
module_name
=
os
.
path
.
basename
(
f
.
name
[:
-
3
])
module_name
=
os
.
path
.
basename
(
f
.
name
[:
-
3
])
f
.
write
(
source
)
f
.
write
(
source
)
...
@@ -505,10 +498,7 @@ def ast_to_func(ast_root, dyfunc, delete_on_exit=True):
...
@@ -505,10 +498,7 @@ def ast_to_func(ast_root, dyfunc, delete_on_exit=True):
atexit
.
register
(
lambda
:
remove_if_exit
(
f
.
name
))
atexit
.
register
(
lambda
:
remove_if_exit
(
f
.
name
))
atexit
.
register
(
lambda
:
remove_if_exit
(
f
.
name
[:
-
3
]
+
".pyc"
))
atexit
.
register
(
lambda
:
remove_if_exit
(
f
.
name
[:
-
3
]
+
".pyc"
))
if
six
.
PY2
:
module
=
SourceFileLoader
(
module_name
,
f
.
name
).
load_module
()
module
=
imp
.
load_source
(
module_name
,
f
.
name
)
else
:
module
=
SourceFileLoader
(
module_name
,
f
.
name
).
load_module
()
func_name
=
dyfunc
.
__name__
func_name
=
dyfunc
.
__name__
# The 'forward' or 'another_forward' of 'TranslatedLayer' cannot be obtained
# The 'forward' or 'another_forward' of 'TranslatedLayer' cannot be obtained
# through 'func_name'. So set the special function name '__i_m_p_l__'.
# through 'func_name'. So set the special function name '__i_m_p_l__'.
...
...
python/paddle/fluid/dygraph/dygraph_to_static/variable_trans_func.py
浏览文件 @
0f7187af
...
@@ -98,17 +98,9 @@ def create_fill_constant_node(name, value):
...
@@ -98,17 +98,9 @@ def create_fill_constant_node(name, value):
func_code
+=
"dtype='float64', value={})"
.
format
(
value
)
func_code
+=
"dtype='float64', value={})"
.
format
(
value
)
return
gast
.
parse
(
func_code
).
body
[
0
]
return
gast
.
parse
(
func_code
).
body
[
0
]
if
six
.
PY2
:
if
isinstance
(
value
,
int
):
if
isinstance
(
value
,
int
):
func_code
+=
"dtype='int64', value={})"
.
format
(
value
)
func_code
+=
"dtype='int32', value={})"
.
format
(
value
)
return
gast
.
parse
(
func_code
).
body
[
0
]
return
gast
.
parse
(
func_code
).
body
[
0
]
if
isinstance
(
value
,
long
):
func_code
+=
"dtype='int64', value={})"
.
format
(
value
)
return
gast
.
parse
(
func_code
).
body
[
0
]
else
:
if
isinstance
(
value
,
int
):
func_code
+=
"dtype='int64', value={})"
.
format
(
value
)
return
gast
.
parse
(
func_code
).
body
[
0
]
def
to_static_variable
(
x
):
def
to_static_variable
(
x
):
...
...
python/paddle/fluid/dygraph/math_op_patch.py
浏览文件 @
0f7187af
...
@@ -20,7 +20,6 @@ from ..layers.layer_function_generator import OpProtoHolder
...
@@ -20,7 +20,6 @@ from ..layers.layer_function_generator import OpProtoHolder
from
.
import
no_grad
from
.
import
no_grad
import
numpy
as
np
import
numpy
as
np
import
six
import
warnings
import
warnings
_supported_int_dtype_
=
[
_supported_int_dtype_
=
[
...
@@ -121,10 +120,7 @@ def monkey_patch_math_varbase():
...
@@ -121,10 +120,7 @@ def monkey_patch_math_varbase():
assert
numel
==
1
,
"only one element variable can be converted to long."
assert
numel
==
1
,
"only one element variable can be converted to long."
tensor
=
var
.
value
().
get_tensor
()
tensor
=
var
.
value
().
get_tensor
()
assert
tensor
.
_is_initialized
(),
"variable's tensor is not initialized"
assert
tensor
.
_is_initialized
(),
"variable's tensor is not initialized"
if
six
.
PY2
:
return
int
(
var
.
numpy
().
flatten
()[
0
])
return
long
(
var
.
numpy
().
flatten
()[
0
])
else
:
return
int
(
var
.
numpy
().
flatten
()[
0
])
def
_int_
(
var
):
def
_int_
(
var
):
numel
=
np
.
prod
(
var
.
shape
)
numel
=
np
.
prod
(
var
.
shape
)
...
@@ -141,10 +137,7 @@ def monkey_patch_math_varbase():
...
@@ -141,10 +137,7 @@ def monkey_patch_math_varbase():
assert
numel
==
1
,
"only one element variable can be converted to python index."
assert
numel
==
1
,
"only one element variable can be converted to python index."
tensor
=
var
.
value
().
get_tensor
()
tensor
=
var
.
value
().
get_tensor
()
assert
tensor
.
_is_initialized
(),
"variable's tensor is not initialized"
assert
tensor
.
_is_initialized
(),
"variable's tensor is not initialized"
if
six
.
PY2
:
return
int
(
var
.
numpy
().
flatten
()[
0
])
return
long
(
var
.
numpy
().
flatten
()[
0
])
else
:
return
int
(
var
.
numpy
().
flatten
()[
0
])
@
property
@
property
def
_ndim_
(
var
):
def
_ndim_
(
var
):
...
...
python/paddle/fluid/io.py
浏览文件 @
0f7187af
...
@@ -1940,8 +1940,7 @@ def _pickle_loads_mac(path, f):
...
@@ -1940,8 +1940,7 @@ def _pickle_loads_mac(path, f):
max_bytes
=
2
**
30
max_bytes
=
2
**
30
for
_
in
range
(
0
,
file_size
,
max_bytes
):
for
_
in
range
(
0
,
file_size
,
max_bytes
):
pickle_bytes
+=
f
.
read
(
max_bytes
)
pickle_bytes
+=
f
.
read
(
max_bytes
)
load_result
=
pickle
.
loads
(
pickle_bytes
)
if
six
.
PY2
else
pickle
.
loads
(
load_result
=
pickle
.
loads
(
pickle_bytes
,
encoding
=
'latin1'
)
pickle_bytes
,
encoding
=
'latin1'
)
return
load_result
return
load_result
...
@@ -2113,8 +2112,7 @@ def load(program, model_path, executor=None, var_list=None):
...
@@ -2113,8 +2112,7 @@ def load(program, model_path, executor=None, var_list=None):
if
sys
.
platform
==
'darwin'
and
sys
.
version_info
.
major
==
3
:
if
sys
.
platform
==
'darwin'
and
sys
.
version_info
.
major
==
3
:
load_dict
=
_pickle_loads_mac
(
parameter_file_name
,
f
)
load_dict
=
_pickle_loads_mac
(
parameter_file_name
,
f
)
else
:
else
:
load_dict
=
pickle
.
load
(
f
)
if
six
.
PY2
else
pickle
.
load
(
load_dict
=
pickle
.
load
(
f
,
encoding
=
'latin1'
)
f
,
encoding
=
'latin1'
)
load_dict
=
_pack_loaded_dict
(
load_dict
)
load_dict
=
_pack_loaded_dict
(
load_dict
)
for
v
in
parameter_list
:
for
v
in
parameter_list
:
assert
v
.
name
in
load_dict
,
\
assert
v
.
name
in
load_dict
,
\
...
@@ -2135,8 +2133,7 @@ def load(program, model_path, executor=None, var_list=None):
...
@@ -2135,8 +2133,7 @@ def load(program, model_path, executor=None, var_list=None):
optimizer_var_list
,
global_scope
(),
executor
.
_default_executor
)
optimizer_var_list
,
global_scope
(),
executor
.
_default_executor
)
with
open
(
opt_file_name
,
'rb'
)
as
f
:
with
open
(
opt_file_name
,
'rb'
)
as
f
:
load_dict
=
pickle
.
load
(
f
)
if
six
.
PY2
else
pickle
.
load
(
load_dict
=
pickle
.
load
(
f
,
encoding
=
'latin1'
)
f
,
encoding
=
'latin1'
)
for
v
in
optimizer_var_list
:
for
v
in
optimizer_var_list
:
assert
v
.
name
in
load_dict
,
\
assert
v
.
name
in
load_dict
,
\
"Can not find [{}] in model file [{}]"
.
format
(
"Can not find [{}] in model file [{}]"
.
format
(
...
@@ -2297,15 +2294,13 @@ def load_program_state(model_path, var_list=None):
...
@@ -2297,15 +2294,13 @@ def load_program_state(model_path, var_list=None):
if
sys
.
platform
==
'darwin'
and
sys
.
version_info
.
major
==
3
:
if
sys
.
platform
==
'darwin'
and
sys
.
version_info
.
major
==
3
:
para_dict
=
_pickle_loads_mac
(
parameter_file_name
,
f
)
para_dict
=
_pickle_loads_mac
(
parameter_file_name
,
f
)
else
:
else
:
para_dict
=
pickle
.
load
(
f
)
if
six
.
PY2
else
pickle
.
load
(
para_dict
=
pickle
.
load
(
f
,
encoding
=
'latin1'
)
f
,
encoding
=
'latin1'
)
para_dict
=
_pack_loaded_dict
(
para_dict
)
para_dict
=
_pack_loaded_dict
(
para_dict
)
opt_file_name
=
model_prefix
+
".pdopt"
opt_file_name
=
model_prefix
+
".pdopt"
if
os
.
path
.
exists
(
opt_file_name
):
if
os
.
path
.
exists
(
opt_file_name
):
with
open
(
opt_file_name
,
'rb'
)
as
f
:
with
open
(
opt_file_name
,
'rb'
)
as
f
:
opti_dict
=
pickle
.
load
(
f
)
if
six
.
PY2
else
pickle
.
load
(
opti_dict
=
pickle
.
load
(
f
,
encoding
=
'latin1'
)
f
,
encoding
=
'latin1'
)
para_dict
.
update
(
opti_dict
)
para_dict
.
update
(
opti_dict
)
...
...
python/paddle/fluid/layers/tensor.py
浏览文件 @
0f7187af
...
@@ -16,9 +16,7 @@ from __future__ import print_function
...
@@ -16,9 +16,7 @@ from __future__ import print_function
import
math
import
math
import
numpy
import
numpy
import
six
import
warnings
import
warnings
from
six.moves
import
reduce
from
..layer_helper
import
LayerHelper
from
..layer_helper
import
LayerHelper
from
..param_attr
import
ParamAttr
from
..param_attr
import
ParamAttr
...
@@ -134,14 +132,9 @@ def create_parameter(shape,
...
@@ -134,14 +132,9 @@ def create_parameter(shape,
"""
"""
check_type
(
shape
,
'shape'
,
(
list
,
tuple
,
numpy
.
ndarray
),
'create_parameter'
)
check_type
(
shape
,
'shape'
,
(
list
,
tuple
,
numpy
.
ndarray
),
'create_parameter'
)
for
item
in
shape
:
for
item
in
shape
:
if
six
.
PY2
:
check_type
(
item
,
'item of shape'
,
check_type
(
item
,
'item of shape'
,
(
int
,
numpy
.
uint8
,
numpy
.
int8
,
numpy
.
int16
,
numpy
.
int32
,
(
int
,
long
,
numpy
.
uint8
,
numpy
.
int8
,
numpy
.
int16
,
numpy
.
int64
),
'create_parameter'
)
numpy
.
int32
,
numpy
.
int64
),
'create_parameter'
)
else
:
check_type
(
item
,
'item of shape'
,
(
int
,
numpy
.
uint8
,
numpy
.
int8
,
numpy
.
int16
,
numpy
.
int32
,
numpy
.
int64
),
'create_parameter'
)
check_dtype
(
dtype
,
'dtype'
,
[
check_dtype
(
dtype
,
'dtype'
,
[
'bool'
,
'float16'
,
'float32'
,
'float64'
,
'int8'
,
'int16'
,
'int32'
,
'bool'
,
'float16'
,
'float32'
,
'float64'
,
'int8'
,
'int16'
,
'int32'
,
...
@@ -194,14 +187,9 @@ def create_global_var(shape,
...
@@ -194,14 +187,9 @@ def create_global_var(shape,
check_type
(
shape
,
'shape'
,
(
list
,
tuple
,
numpy
.
ndarray
),
check_type
(
shape
,
'shape'
,
(
list
,
tuple
,
numpy
.
ndarray
),
'create_global_var'
)
'create_global_var'
)
for
item
in
shape
:
for
item
in
shape
:
if
six
.
PY2
:
check_type
(
item
,
'item of shape'
,
check_type
(
item
,
'item of shape'
,
(
int
,
numpy
.
uint8
,
numpy
.
int8
,
numpy
.
int16
,
numpy
.
int32
,
(
int
,
long
,
numpy
.
uint8
,
numpy
.
int8
,
numpy
.
int16
,
numpy
.
int64
),
'create_global_var'
)
numpy
.
int32
,
numpy
.
int64
),
'create_global_var'
)
else
:
check_type
(
item
,
'item of shape'
,
(
int
,
numpy
.
uint8
,
numpy
.
int8
,
numpy
.
int16
,
numpy
.
int32
,
numpy
.
int64
),
'create_global_var'
)
check_dtype
(
dtype
,
'dtype'
,
[
check_dtype
(
dtype
,
'dtype'
,
[
'bool'
,
'float16'
,
'float32'
,
'float64'
,
'int8'
,
'int16'
,
'int32'
,
'bool'
,
'float16'
,
'float32'
,
'float64'
,
'int8'
,
'int16'
,
'int32'
,
...
...
python/paddle/fluid/multiprocess_utils.py
浏览文件 @
0f7187af
...
@@ -12,7 +12,6 @@
...
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
six
import
sys
import
sys
import
signal
import
signal
import
atexit
import
atexit
...
@@ -20,10 +19,7 @@ import atexit
...
@@ -20,10 +19,7 @@ import atexit
from
.
import
core
from
.
import
core
# NOTE: queue has a different name in python2 and python3
# NOTE: queue has a different name in python2 and python3
if
six
.
PY2
:
import
queue
import
Queue
as
queue
else
:
import
queue
# multi-process worker check indices queue interval, avoid
# multi-process worker check indices queue interval, avoid
# hanging in subprocess data loading
# hanging in subprocess data loading
...
...
python/paddle/fluid/reader.py
浏览文件 @
0f7187af
...
@@ -38,10 +38,7 @@ import multiprocessing
...
@@ -38,10 +38,7 @@ import multiprocessing
import
signal
import
signal
# NOTE: queue has a different name in python2 and python3
# NOTE: queue has a different name in python2 and python3
if
six
.
PY2
:
import
queue
import
Queue
as
queue
else
:
import
queue
# NOTE: [ avoid hanging & failed quickly ] These value is used in getting data from another process
# NOTE: [ avoid hanging & failed quickly ] These value is used in getting data from another process
QUEUE_GET_TIMEOUT
=
60
QUEUE_GET_TIMEOUT
=
60
...
...
python/paddle/fluid/tests/unittests/dist_save_load.py
浏览文件 @
0f7187af
...
@@ -169,10 +169,7 @@ class TestDistSaveLoad2x2(TestDistSimnetBow2x2):
...
@@ -169,10 +169,7 @@ class TestDistSaveLoad2x2(TestDistSimnetBow2x2):
var
=
np
.
array
(
fluid
.
global_scope
().
find_var
(
'__fc_b__'
).
get_tensor
(
var
=
np
.
array
(
fluid
.
global_scope
().
find_var
(
'__fc_b__'
).
get_tensor
(
))
))
if
six
.
PY2
:
sys
.
stdout
.
buffer
.
write
(
pickle
.
dumps
(
np
.
ravel
(
var
).
tolist
()))
print
(
pickle
.
dumps
(
np
.
ravel
(
var
).
tolist
()))
else
:
sys
.
stdout
.
buffer
.
write
(
pickle
.
dumps
(
np
.
ravel
(
var
).
tolist
()))
elif
save_mode
==
"DIST"
:
elif
save_mode
==
"DIST"
:
skip_steps
=
int
(
os
.
getenv
(
"SKIP_STEPS"
))
skip_steps
=
int
(
os
.
getenv
(
"SKIP_STEPS"
))
...
@@ -191,10 +188,7 @@ class TestDistSaveLoad2x2(TestDistSimnetBow2x2):
...
@@ -191,10 +188,7 @@ class TestDistSaveLoad2x2(TestDistSimnetBow2x2):
continue
continue
loss
,
=
exe
.
run
(
fetch_list
=
[
avg_cost
.
name
],
loss
,
=
exe
.
run
(
fetch_list
=
[
avg_cost
.
name
],
feed
=
feeder
.
feed
(
data
))
feed
=
feeder
.
feed
(
data
))
if
six
.
PY2
:
sys
.
stdout
.
buffer
.
write
(
pickle
.
dumps
(
loss
.
tolist
()))
print
(
pickle
.
dumps
(
loss
.
tolist
()))
else
:
sys
.
stdout
.
buffer
.
write
(
pickle
.
dumps
(
loss
.
tolist
()))
else
:
else
:
raise
Exception
(
"save_mode must be LOCAL or DIST"
)
raise
Exception
(
"save_mode must be LOCAL or DIST"
)
...
...
python/paddle/fluid/tests/unittests/dist_sharding_save.py
浏览文件 @
0f7187af
...
@@ -24,7 +24,6 @@ import paddle.distributed.fleet.base.role_maker as role_maker
...
@@ -24,7 +24,6 @@ import paddle.distributed.fleet.base.role_maker as role_maker
import
paddle.distributed.fleet.meta_optimizers.sharding
as
sharding
import
paddle.distributed.fleet.meta_optimizers.sharding
as
sharding
import
os
import
os
import
six
import
sys
import
sys
import
pickle
import
pickle
...
@@ -81,10 +80,7 @@ def runtime_main():
...
@@ -81,10 +80,7 @@ def runtime_main():
exe
,
dirname
,
main_program
=
train_prog
,
filename
=
None
)
exe
,
dirname
,
main_program
=
train_prog
,
filename
=
None
)
out_losses
=
[]
out_losses
=
[]
if
six
.
PY2
:
sys
.
stdout
.
buffer
.
write
(
pickle
.
dumps
(
out_losses
))
print
(
pickle
.
dumps
(
out_losses
))
else
:
sys
.
stdout
.
buffer
.
write
(
pickle
.
dumps
(
out_losses
))
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
python/paddle/fluid/tests/unittests/dist_text_classification.py
浏览文件 @
0f7187af
...
@@ -44,14 +44,9 @@ DATA_MD5 = '29ebfc94f11aea9362bbb7f5e9d86b8a'
...
@@ -44,14 +44,9 @@ DATA_MD5 = '29ebfc94f11aea9362bbb7f5e9d86b8a'
# Load dictionary.
# Load dictionary.
def
load_vocab
(
filename
):
def
load_vocab
(
filename
):
vocab
=
{}
vocab
=
{}
if
six
.
PY2
:
with
open
(
filename
,
'r'
,
encoding
=
"utf-8"
)
as
f
:
with
open
(
filename
,
'r'
)
as
f
:
for
idx
,
line
in
enumerate
(
f
):
for
idx
,
line
in
enumerate
(
f
):
vocab
[
line
.
strip
()]
=
idx
vocab
[
line
.
strip
()]
=
idx
else
:
with
open
(
filename
,
'r'
,
encoding
=
"utf-8"
)
as
f
:
for
idx
,
line
in
enumerate
(
f
):
vocab
[
line
.
strip
()]
=
idx
return
vocab
return
vocab
...
...
python/paddle/fluid/tests/unittests/dygraph_to_static/test_logging_utils.py
浏览文件 @
0f7187af
...
@@ -21,18 +21,10 @@ import sys
...
@@ -21,18 +21,10 @@ import sys
import
unittest
import
unittest
import
gast
import
gast
import
six
import
paddle
import
paddle
from
paddle.fluid.dygraph.dygraph_to_static
import
logging_utils
from
paddle.fluid.dygraph.dygraph_to_static
import
logging_utils
from
unittest
import
mock
# TODO(liym27): library mock needs to be installed separately in PY2,
# but CI environment has not installed mock yet.
# After discuss with Tian Shuo, now use mock only in PY3, and use it in PY2 after CI installs it.
if
six
.
PY3
:
from
unittest
import
mock
# else:
# import mock
class
TestLoggingUtils
(
unittest
.
TestCase
):
class
TestLoggingUtils
(
unittest
.
TestCase
):
...
@@ -112,7 +104,7 @@ class TestLoggingUtils(unittest.TestCase):
...
@@ -112,7 +104,7 @@ class TestLoggingUtils(unittest.TestCase):
ast_code
,
"TestTransformer"
)
ast_code
,
"TestTransformer"
)
def
test_log_message
(
self
):
def
test_log_message
(
self
):
stream
=
io
.
BytesIO
()
if
six
.
PY2
else
io
.
StringIO
()
stream
=
io
.
StringIO
()
log
=
self
.
translator_logger
.
logger
log
=
self
.
translator_logger
.
logger
stdout_handler
=
logging
.
StreamHandler
(
stream
)
stdout_handler
=
logging
.
StreamHandler
(
stream
)
log
.
addHandler
(
stdout_handler
)
log
.
addHandler
(
stdout_handler
)
...
@@ -122,39 +114,36 @@ class TestLoggingUtils(unittest.TestCase):
...
@@ -122,39 +114,36 @@ class TestLoggingUtils(unittest.TestCase):
log_msg_1
=
"test_log_1"
log_msg_1
=
"test_log_1"
log_msg_2
=
"test_log_2"
log_msg_2
=
"test_log_2"
if
six
.
PY3
:
with
mock
.
patch
.
object
(
sys
,
'stdout'
,
stream
):
with
mock
.
patch
.
object
(
sys
,
'stdout'
,
stream
):
logging_utils
.
set_verbosity
(
1
,
False
)
logging_utils
.
set_verbosity
(
1
,
False
)
logging_utils
.
warn
(
warn_msg
)
logging_utils
.
warn
(
warn_msg
)
logging_utils
.
error
(
error_msg
)
logging_utils
.
error
(
error_msg
)
logging_utils
.
log
(
1
,
log_msg_1
)
logging_utils
.
log
(
1
,
log_msg_1
)
logging_utils
.
log
(
2
,
log_msg_2
)
logging_utils
.
log
(
2
,
log_msg_2
)
result_msg
=
'
\n
'
.
join
(
result_msg
=
'
\n
'
.
join
(
[
warn_msg
,
error_msg
,
"(Level 1) "
+
log_msg_1
,
""
])
[
warn_msg
,
error_msg
,
"(Level 1) "
+
log_msg_1
,
""
])
self
.
assertEqual
(
result_msg
,
stream
.
getvalue
())
self
.
assertEqual
(
result_msg
,
stream
.
getvalue
())
def
test_log_transformed_code
(
self
):
def
test_log_transformed_code
(
self
):
source_code
=
"x = 3"
source_code
=
"x = 3"
ast_code
=
gast
.
parse
(
source_code
)
ast_code
=
gast
.
parse
(
source_code
)
stream
=
io
.
BytesIO
()
if
six
.
PY2
else
io
.
StringIO
()
stream
=
io
.
StringIO
()
log
=
self
.
translator_logger
.
logger
log
=
self
.
translator_logger
.
logger
stdout_handler
=
logging
.
StreamHandler
(
stream
)
stdout_handler
=
logging
.
StreamHandler
(
stream
)
log
.
addHandler
(
stdout_handler
)
log
.
addHandler
(
stdout_handler
)
if
six
.
PY3
:
with
mock
.
patch
.
object
(
sys
,
'stdout'
,
stream
):
with
mock
.
patch
.
object
(
sys
,
'stdout'
,
stream
):
paddle
.
jit
.
set_code_level
(
1
)
paddle
.
jit
.
set_code_level
(
1
)
logging_utils
.
log_transformed_code
(
1
,
ast_code
,
logging_utils
.
log_transformed_code
(
1
,
ast_code
,
"BasicApiTransformer"
)
"BasicApiTransformer"
)
paddle
.
jit
.
set_code_level
()
paddle
.
jit
.
set_code_level
()
logging_utils
.
log_transformed_code
(
logging_utils
.
log_transformed_code
(
logging_utils
.
LOG_AllTransformer
,
logging_utils
.
LOG_AllTransformer
,
ast_code
,
ast_code
,
"All Transformers"
)
"All Transformers"
)
self
.
assertIn
(
source_code
,
stream
.
getvalue
())
self
.
assertIn
(
source_code
,
stream
.
getvalue
())
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
python/paddle/fluid/tests/unittests/dygraph_to_static/test_variable_trans_func.py
浏览文件 @
0f7187af
...
@@ -15,7 +15,6 @@
...
@@ -15,7 +15,6 @@
from
__future__
import
print_function
from
__future__
import
print_function
import
gast
import
gast
import
six
import
unittest
import
unittest
import
numpy
as
np
import
numpy
as
np
...
@@ -58,18 +57,9 @@ class TestVariableTransFunc(unittest.TestCase):
...
@@ -58,18 +57,9 @@ class TestVariableTransFunc(unittest.TestCase):
source
=
"b = paddle.fluid.layers.fill_constant(shape=[1], dtype='bool', value=True)"
source
=
"b = paddle.fluid.layers.fill_constant(shape=[1], dtype='bool', value=True)"
self
.
assertEqual
(
ast_to_source_code
(
node
).
strip
(),
source
)
self
.
assertEqual
(
ast_to_source_code
(
node
).
strip
(),
source
)
if
six
.
PY2
:
node
=
create_fill_constant_node
(
"c"
,
4293
)
node
=
create_fill_constant_node
(
"c"
,
214
)
source
=
"c = paddle.fluid.layers.fill_constant(shape=[1], dtype='int64', value=4293)"
source
=
"c = paddle.fluid.layers.fill_constant(shape=[1], dtype='int32', value=214)"
self
.
assertEqual
(
ast_to_source_code
(
node
).
strip
(),
source
)
self
.
assertEqual
(
ast_to_source_code
(
node
).
strip
(),
source
)
node
=
create_fill_constant_node
(
"d"
,
long
(
10086
))
source
=
"d = paddle.fluid.layers.fill_constant(shape=[1], dtype='int64', value=10086)"
self
.
assertEqual
(
ast_to_source_code
(
node
).
strip
(),
source
)
else
:
node
=
create_fill_constant_node
(
"c"
,
4293
)
source
=
"c = paddle.fluid.layers.fill_constant(shape=[1], dtype='int64', value=4293)"
self
.
assertEqual
(
ast_to_source_code
(
node
).
strip
(),
source
)
self
.
assertIsNone
(
create_fill_constant_node
(
"e"
,
None
))
self
.
assertIsNone
(
create_fill_constant_node
(
"e"
,
None
))
self
.
assertIsNone
(
create_fill_constant_node
(
"e"
,
[]))
self
.
assertIsNone
(
create_fill_constant_node
(
"e"
,
[]))
...
...
python/paddle/fluid/tests/unittests/npu/test_collective_base_npu.py
浏览文件 @
0f7187af
...
@@ -18,14 +18,12 @@ import unittest
...
@@ -18,14 +18,12 @@ import unittest
import
time
import
time
import
argparse
import
argparse
import
os
import
os
import
six
import
sys
import
sys
import
subprocess
import
subprocess
import
traceback
import
traceback
import
functools
import
functools
import
pickle
import
pickle
from
contextlib
import
closing
from
contextlib
import
closing
from
six
import
string_types
import
paddle.fluid
as
fluid
import
paddle.fluid
as
fluid
import
paddle.fluid.unique_name
as
nameGen
import
paddle.fluid.unique_name
as
nameGen
from
paddle.fluid
import
core
from
paddle.fluid
import
core
...
@@ -113,10 +111,7 @@ class TestCollectiveRunnerBase(object):
...
@@ -113,10 +111,7 @@ class TestCollectiveRunnerBase(object):
out
=
exe
.
run
(
train_prog
,
out
=
exe
.
run
(
train_prog
,
feed
=
{
'tindata'
:
indata
},
feed
=
{
'tindata'
:
indata
},
fetch_list
=
[
result
.
name
])
fetch_list
=
[
result
.
name
])
if
six
.
PY2
:
sys
.
stdout
.
buffer
.
write
(
pickle
.
dumps
(
out
))
print
(
pickle
.
dumps
(
out
))
else
:
sys
.
stdout
.
buffer
.
write
(
pickle
.
dumps
(
out
))
def
runtime_main
(
test_class
,
col_type
,
sub_type
):
def
runtime_main
(
test_class
,
col_type
,
sub_type
):
...
...
python/paddle/fluid/tests/unittests/test_collective_api_base.py
浏览文件 @
0f7187af
...
@@ -18,14 +18,12 @@ import unittest
...
@@ -18,14 +18,12 @@ import unittest
import
time
import
time
import
argparse
import
argparse
import
os
import
os
import
six
import
sys
import
sys
import
subprocess
import
subprocess
import
traceback
import
traceback
import
functools
import
functools
import
pickle
import
pickle
from
contextlib
import
closing
from
contextlib
import
closing
from
six
import
string_types
import
paddle
import
paddle
import
paddle.fluid
as
fluid
import
paddle.fluid
as
fluid
import
paddle.fluid.unique_name
as
nameGen
import
paddle.fluid.unique_name
as
nameGen
...
@@ -69,10 +67,7 @@ class TestCollectiveAPIRunnerBase(object):
...
@@ -69,10 +67,7 @@ class TestCollectiveAPIRunnerBase(object):
else
:
else
:
out
=
self
.
get_model
(
train_prog
,
startup_prog
,
rank
,
indata
)
out
=
self
.
get_model
(
train_prog
,
startup_prog
,
rank
,
indata
)
#print(out, sys.stderr)
#print(out, sys.stderr)
if
six
.
PY2
:
sys
.
stdout
.
buffer
.
write
(
pickle
.
dumps
(
out
))
print
(
pickle
.
dumps
(
out
))
else
:
sys
.
stdout
.
buffer
.
write
(
pickle
.
dumps
(
out
))
def
runtime_main
(
test_class
,
col_type
):
def
runtime_main
(
test_class
,
col_type
):
...
...
python/paddle/fluid/tests/unittests/test_collective_base.py
浏览文件 @
0f7187af
...
@@ -18,14 +18,12 @@ import unittest
...
@@ -18,14 +18,12 @@ import unittest
import
time
import
time
import
argparse
import
argparse
import
os
import
os
import
six
import
sys
import
sys
import
subprocess
import
subprocess
import
traceback
import
traceback
import
functools
import
functools
import
pickle
import
pickle
from
contextlib
import
closing
from
contextlib
import
closing
from
six
import
string_types
import
paddle.fluid
as
fluid
import
paddle.fluid
as
fluid
import
paddle.fluid.unique_name
as
nameGen
import
paddle.fluid.unique_name
as
nameGen
from
paddle.fluid
import
core
from
paddle.fluid
import
core
...
@@ -37,7 +35,6 @@ class TestCollectiveRunnerBase(object):
...
@@ -37,7 +35,6 @@ class TestCollectiveRunnerBase(object):
"get model should be implemented by child class."
)
"get model should be implemented by child class."
)
def
wait_server_ready
(
self
,
endpoints
):
def
wait_server_ready
(
self
,
endpoints
):
assert
not
isinstance
(
endpoints
,
string_types
)
while
True
:
while
True
:
all_ok
=
True
all_ok
=
True
not_ready_endpoints
=
[]
not_ready_endpoints
=
[]
...
@@ -115,10 +112,7 @@ class TestCollectiveRunnerBase(object):
...
@@ -115,10 +112,7 @@ class TestCollectiveRunnerBase(object):
out
=
exe
.
run
(
train_prog
,
out
=
exe
.
run
(
train_prog
,
feed
=
{
'tindata'
:
indata
},
feed
=
{
'tindata'
:
indata
},
fetch_list
=
[
result
.
name
])
fetch_list
=
[
result
.
name
])
if
six
.
PY2
:
sys
.
stdout
.
buffer
.
write
(
pickle
.
dumps
(
out
))
print
(
pickle
.
dumps
(
out
))
else
:
sys
.
stdout
.
buffer
.
write
(
pickle
.
dumps
(
out
))
def
runtime_main
(
test_class
,
col_type
,
sub_type
):
def
runtime_main
(
test_class
,
col_type
,
sub_type
):
...
...
python/paddle/fluid/tests/unittests/test_compat.py
浏览文件 @
0f7187af
...
@@ -16,465 +16,230 @@ from __future__ import print_function
...
@@ -16,465 +16,230 @@ from __future__ import print_function
import
unittest
import
unittest
import
paddle.compat
as
cpt
import
paddle.compat
as
cpt
import
six
class
TestCompatible
(
unittest
.
TestCase
):
class
TestCompatible
(
unittest
.
TestCase
):
def
test_type
(
self
):
def
test_type
(
self
):
if
six
.
PY2
:
self
.
assertEqual
(
cpt
.
int_type
,
int
)
self
.
assertEqual
(
cpt
.
int_type
,
int
)
self
.
assertEqual
(
cpt
.
long_type
,
int
)
self
.
assertEqual
(
cpt
.
long_type
,
long
)
else
:
self
.
assertEqual
(
cpt
.
int_type
,
int
)
self
.
assertEqual
(
cpt
.
long_type
,
int
)
def
test_to_text
(
self
):
def
test_to_text
(
self
):
# Only support python2.x and python3.x now
self
.
assertIsNone
(
cpt
.
to_text
(
None
))
self
.
assertTrue
(
six
.
PY2
|
six
.
PY3
)
self
.
assertTrue
(
isinstance
(
cpt
.
to_text
(
str
(
""
)),
str
))
if
six
.
PY2
:
self
.
assertTrue
(
isinstance
(
cpt
.
to_text
(
str
(
"123"
)),
str
))
# check None
self
.
assertTrue
(
isinstance
(
cpt
.
to_text
(
b
""
),
str
))
self
.
assertIsNone
(
cpt
.
to_text
(
None
))
self
.
assertTrue
(
isinstance
(
cpt
.
to_text
(
b
""
),
str
))
self
.
assertTrue
(
isinstance
(
cpt
.
to_text
(
u
""
),
str
))
# check all string related types
self
.
assertTrue
(
isinstance
(
cpt
.
to_text
(
u
""
),
str
))
self
.
assertTrue
(
isinstance
(
cpt
.
to_text
(
str
(
""
)),
unicode
))
self
.
assertTrue
(
isinstance
(
cpt
.
to_text
(
str
(
"123"
)),
unicode
))
self
.
assertEqual
(
""
,
cpt
.
to_text
(
str
(
""
)))
self
.
assertTrue
(
isinstance
(
cpt
.
to_text
(
b
""
),
unicode
))
self
.
assertEqual
(
"123"
,
cpt
.
to_text
(
str
(
"123"
)))
self
.
assertTrue
(
isinstance
(
cpt
.
to_text
(
b
""
),
unicode
))
self
.
assertEqual
(
""
,
cpt
.
to_text
(
b
""
))
self
.
assertTrue
(
isinstance
(
cpt
.
to_text
(
u
""
),
unicode
))
self
.
assertEqual
(
"123"
,
cpt
.
to_text
(
b
"123"
))
self
.
assertTrue
(
isinstance
(
cpt
.
to_text
(
u
""
),
unicode
))
self
.
assertEqual
(
""
,
cpt
.
to_text
(
u
""
))
self
.
assertEqual
(
"123"
,
cpt
.
to_text
(
u
"123"
))
self
.
assertEqual
(
u
""
,
cpt
.
to_text
(
str
(
""
)))
self
.
assertEqual
(
u
"123"
,
cpt
.
to_text
(
str
(
"123"
)))
# check list types, not inplace
self
.
assertEqual
(
u
""
,
cpt
.
to_text
(
b
""
))
l
=
[
""
]
self
.
assertEqual
(
u
"123"
,
cpt
.
to_text
(
b
"123"
))
l2
=
cpt
.
to_text
(
l
)
self
.
assertEqual
(
u
""
,
cpt
.
to_text
(
u
""
))
self
.
assertTrue
(
isinstance
(
l2
,
list
))
self
.
assertEqual
(
u
"123"
,
cpt
.
to_text
(
u
"123"
))
self
.
assertFalse
(
l
is
l2
)
self
.
assertEqual
(
l
,
l2
)
# check list types, not inplace
self
.
assertEqual
([
""
],
l2
)
l
=
[
""
]
l
=
[
""
,
"123"
]
l2
=
cpt
.
to_text
(
l
)
l2
=
cpt
.
to_text
(
l
)
self
.
assertTrue
(
isinstance
(
l2
,
list
))
self
.
assertTrue
(
isinstance
(
l2
,
list
))
self
.
assertFalse
(
l
is
l2
)
self
.
assertFalse
(
l
is
l2
)
self
.
assertEqual
(
l
,
l2
)
self
.
assertEqual
(
l
,
l2
)
self
.
assertEqual
([
u
""
],
l2
)
self
.
assertEqual
([
""
,
"123"
],
l2
)
l
=
[
""
,
"123"
]
l
=
[
""
,
b
"123"
,
u
"321"
]
l2
=
cpt
.
to_text
(
l
)
l2
=
cpt
.
to_text
(
l
)
self
.
assertTrue
(
isinstance
(
l2
,
list
))
self
.
assertTrue
(
isinstance
(
l2
,
list
))
self
.
assertFalse
(
l
is
l2
)
self
.
assertFalse
(
l
is
l2
)
self
.
assertEqual
(
l
,
l2
)
self
.
assertNotEqual
(
l
,
l2
)
self
.
assertEqual
([
u
""
,
u
"123"
],
l2
)
self
.
assertEqual
([
""
,
"123"
,
"321"
],
l2
)
l
=
[
""
,
b
'123'
,
u
"321"
]
l2
=
cpt
.
to_text
(
l
)
# check list types, inplace
self
.
assertTrue
(
isinstance
(
l2
,
list
))
l
=
[
""
]
self
.
assertFalse
(
l
is
l2
)
l2
=
cpt
.
to_text
(
l
,
inplace
=
True
)
self
.
assertEqual
(
l
,
l2
)
self
.
assertTrue
(
isinstance
(
l2
,
list
))
self
.
assertEqual
([
u
""
,
u
"123"
,
u
"321"
],
l2
)
self
.
assertTrue
(
l
is
l2
)
for
i
in
l2
:
self
.
assertEqual
(
l
,
l2
)
self
.
assertTrue
(
isinstance
(
i
,
unicode
))
self
.
assertEqual
([
""
],
l2
)
l
=
[
""
,
b
"123"
]
# check list types, inplace
l2
=
cpt
.
to_text
(
l
,
inplace
=
True
)
l
=
[
""
]
self
.
assertTrue
(
isinstance
(
l2
,
list
))
l2
=
cpt
.
to_text
(
l
,
inplace
=
True
)
self
.
assertTrue
(
l
is
l2
)
self
.
assertTrue
(
isinstance
(
l2
,
list
))
self
.
assertEqual
(
l
,
l2
)
self
.
assertTrue
(
l
is
l2
)
self
.
assertEqual
([
""
,
"123"
],
l2
)
self
.
assertEqual
(
l
,
l2
)
l
=
[
""
,
b
"123"
,
u
"321"
]
self
.
assertEqual
([
u
""
],
l2
)
l2
=
cpt
.
to_text
(
l
,
inplace
=
True
)
l
=
[
""
,
"123"
]
self
.
assertTrue
(
isinstance
(
l2
,
list
))
l2
=
cpt
.
to_text
(
l
,
inplace
=
True
)
self
.
assertTrue
(
l
is
l2
)
self
.
assertTrue
(
isinstance
(
l2
,
list
))
self
.
assertEqual
(
l
,
l2
)
self
.
assertTrue
(
l
is
l2
)
self
.
assertEqual
([
""
,
"123"
,
"321"
],
l2
)
self
.
assertEqual
(
l
,
l2
)
for
i
in
l2
:
self
.
assertEqual
([
u
""
,
u
"123"
],
l2
)
self
.
assertTrue
(
isinstance
(
i
,
str
))
l
=
[
""
,
b
"123"
,
u
"321"
]
l2
=
cpt
.
to_text
(
l
,
inplace
=
True
)
# check set types, not inplace
self
.
assertTrue
(
isinstance
(
l2
,
list
))
l
=
set
(
""
)
self
.
assertTrue
(
l
is
l2
)
l2
=
cpt
.
to_text
(
l
,
inplace
=
False
)
self
.
assertEqual
(
l
,
l2
)
self
.
assertTrue
(
isinstance
(
l2
,
set
))
self
.
assertEqual
([
u
""
,
u
"123"
,
u
"321"
],
l2
)
self
.
assertFalse
(
l
is
l2
)
self
.
assertEqual
(
l
,
l2
)
# check set types, not inplace
self
.
assertEqual
(
set
(
""
),
l2
)
l
=
set
(
""
)
l
=
set
([
b
""
,
b
"123"
])
l2
=
cpt
.
to_text
(
l
,
inplace
=
False
)
l2
=
cpt
.
to_text
(
l
,
inplace
=
False
)
self
.
assertTrue
(
isinstance
(
l2
,
set
))
self
.
assertTrue
(
isinstance
(
l2
,
set
))
self
.
assertFalse
(
l
is
l2
)
self
.
assertFalse
(
l
is
l2
)
self
.
assertEqual
(
l
,
l2
)
self
.
assertNotEqual
(
l
,
l2
)
self
.
assertEqual
(
set
(
u
""
),
l2
)
self
.
assertEqual
(
set
([
""
,
"123"
]),
l2
)
l
=
set
([
b
""
,
b
"123"
])
l
=
set
([
""
,
b
"123"
,
u
"321"
])
l2
=
cpt
.
to_text
(
l
,
inplace
=
False
)
l2
=
cpt
.
to_text
(
l
,
inplace
=
False
)
self
.
assertTrue
(
isinstance
(
l2
,
set
))
self
.
assertTrue
(
isinstance
(
l2
,
set
))
self
.
assertFalse
(
l
is
l2
)
self
.
assertFalse
(
l
is
l2
)
self
.
assertEqual
(
l
,
l2
)
self
.
assertNotEqual
(
l
,
l2
)
self
.
assertEqual
(
set
([
u
""
,
u
"123"
]),
l2
)
self
.
assertEqual
(
set
([
""
,
"123"
,
"321"
]),
l2
)
l
=
set
([
""
,
b
"123"
,
u
"321"
])
l2
=
cpt
.
to_text
(
l
,
inplace
=
False
)
# check set types, inplace
self
.
assertTrue
(
isinstance
(
l2
,
set
))
l
=
set
(
""
)
self
.
assertFalse
(
l
is
l2
)
l2
=
cpt
.
to_text
(
l
,
inplace
=
True
)
self
.
assertEqual
(
l
,
l2
)
self
.
assertTrue
(
isinstance
(
l2
,
set
))
self
.
assertEqual
(
set
([
u
""
,
u
"123"
,
u
"321"
]),
l2
)
self
.
assertTrue
(
l
is
l2
)
for
i
in
l2
:
self
.
assertEqual
(
l
,
l2
)
self
.
assertTrue
(
isinstance
(
i
,
unicode
))
self
.
assertEqual
(
set
(
""
),
l2
)
l
=
set
([
b
""
,
b
"123"
])
# check set types, inplace
l2
=
cpt
.
to_text
(
l
,
inplace
=
True
)
l
=
set
(
""
)
self
.
assertTrue
(
isinstance
(
l2
,
set
))
l2
=
cpt
.
to_text
(
l
,
inplace
=
True
)
self
.
assertTrue
(
l
is
l2
)
self
.
assertTrue
(
isinstance
(
l2
,
set
))
self
.
assertEqual
(
l
,
l2
)
self
.
assertTrue
(
l
is
l2
)
self
.
assertEqual
(
set
([
""
,
"123"
]),
l2
)
self
.
assertEqual
(
l
,
l2
)
l
=
set
([
""
,
b
"123"
,
u
"321"
])
self
.
assertEqual
(
set
(
u
""
),
l2
)
l2
=
cpt
.
to_text
(
l
,
inplace
=
True
)
l
=
set
([
b
""
,
b
"123"
])
self
.
assertTrue
(
isinstance
(
l2
,
set
))
l2
=
cpt
.
to_text
(
l
,
inplace
=
True
)
self
.
assertTrue
(
l
is
l2
)
self
.
assertTrue
(
isinstance
(
l2
,
set
))
self
.
assertEqual
(
l
,
l2
)
self
.
assertTrue
(
l
is
l2
)
self
.
assertEqual
(
set
([
""
,
"123"
,
"321"
]),
l2
)
self
.
assertEqual
(
l
,
l2
)
for
i
in
l2
:
self
.
assertEqual
(
set
([
u
""
,
u
"123"
]),
l2
)
self
.
assertTrue
(
isinstance
(
i
,
str
))
l
=
set
([
""
,
b
"123"
,
u
"321"
])
l2
=
cpt
.
to_text
(
l
,
inplace
=
True
)
# check dict types, not inplace
self
.
assertTrue
(
isinstance
(
l2
,
set
))
l
=
{
""
:
""
}
self
.
assertTrue
(
l
is
l2
)
l2
=
cpt
.
to_text
(
l
,
inplace
=
False
)
self
.
assertEqual
(
l
,
l2
)
self
.
assertTrue
(
isinstance
(
l2
,
dict
))
self
.
assertEqual
(
set
([
u
""
,
u
"123"
,
u
"321"
]),
l2
)
self
.
assertFalse
(
l
is
l2
)
self
.
assertEqual
(
l
,
l2
)
# check dict types, not inplace
self
.
assertEqual
({
""
:
""
},
l2
)
l
=
{
""
:
""
}
l2
=
cpt
.
to_text
(
l
,
inplace
=
False
)
# check dict types, inplace
self
.
assertTrue
(
isinstance
(
l2
,
dict
))
l
=
{
""
:
""
}
self
.
assertFalse
(
l
is
l2
)
l2
=
cpt
.
to_text
(
l
,
inplace
=
True
)
self
.
assertEqual
(
l
,
l2
)
self
.
assertTrue
(
isinstance
(
l2
,
dict
))
self
.
assertEqual
({
""
:
""
},
l2
)
self
.
assertTrue
(
l
is
l2
)
self
.
assertEqual
(
l
,
l2
)
# check dict types, inplace
self
.
assertEqual
({
""
:
""
},
l2
)
l
=
{
""
:
""
}
l2
=
cpt
.
to_text
(
l
,
inplace
=
True
)
self
.
assertTrue
(
isinstance
(
l2
,
dict
))
self
.
assertTrue
(
l
is
l2
)
self
.
assertEqual
(
l
,
l2
)
self
.
assertEqual
({
""
:
""
},
l2
)
elif
six
.
PY3
:
self
.
assertIsNone
(
cpt
.
to_text
(
None
))
self
.
assertTrue
(
isinstance
(
cpt
.
to_text
(
str
(
""
)),
str
))
self
.
assertTrue
(
isinstance
(
cpt
.
to_text
(
str
(
"123"
)),
str
))
self
.
assertTrue
(
isinstance
(
cpt
.
to_text
(
b
""
),
str
))
self
.
assertTrue
(
isinstance
(
cpt
.
to_text
(
b
""
),
str
))
self
.
assertTrue
(
isinstance
(
cpt
.
to_text
(
u
""
),
str
))
self
.
assertTrue
(
isinstance
(
cpt
.
to_text
(
u
""
),
str
))
self
.
assertEqual
(
""
,
cpt
.
to_text
(
str
(
""
)))
self
.
assertEqual
(
"123"
,
cpt
.
to_text
(
str
(
"123"
)))
self
.
assertEqual
(
""
,
cpt
.
to_text
(
b
""
))
self
.
assertEqual
(
"123"
,
cpt
.
to_text
(
b
"123"
))
self
.
assertEqual
(
""
,
cpt
.
to_text
(
u
""
))
self
.
assertEqual
(
"123"
,
cpt
.
to_text
(
u
"123"
))
# check list types, not inplace
l
=
[
""
]
l2
=
cpt
.
to_text
(
l
)
self
.
assertTrue
(
isinstance
(
l2
,
list
))
self
.
assertFalse
(
l
is
l2
)
self
.
assertEqual
(
l
,
l2
)
self
.
assertEqual
([
""
],
l2
)
l
=
[
""
,
"123"
]
l2
=
cpt
.
to_text
(
l
)
self
.
assertTrue
(
isinstance
(
l2
,
list
))
self
.
assertFalse
(
l
is
l2
)
self
.
assertEqual
(
l
,
l2
)
self
.
assertEqual
([
""
,
"123"
],
l2
)
l
=
[
""
,
b
"123"
,
u
"321"
]
l2
=
cpt
.
to_text
(
l
)
self
.
assertTrue
(
isinstance
(
l2
,
list
))
self
.
assertFalse
(
l
is
l2
)
self
.
assertNotEqual
(
l
,
l2
)
self
.
assertEqual
([
""
,
"123"
,
"321"
],
l2
)
# check list types, inplace
l
=
[
""
]
l2
=
cpt
.
to_text
(
l
,
inplace
=
True
)
self
.
assertTrue
(
isinstance
(
l2
,
list
))
self
.
assertTrue
(
l
is
l2
)
self
.
assertEqual
(
l
,
l2
)
self
.
assertEqual
([
""
],
l2
)
l
=
[
""
,
b
"123"
]
l2
=
cpt
.
to_text
(
l
,
inplace
=
True
)
self
.
assertTrue
(
isinstance
(
l2
,
list
))
self
.
assertTrue
(
l
is
l2
)
self
.
assertEqual
(
l
,
l2
)
self
.
assertEqual
([
""
,
"123"
],
l2
)
l
=
[
""
,
b
"123"
,
u
"321"
]
l2
=
cpt
.
to_text
(
l
,
inplace
=
True
)
self
.
assertTrue
(
isinstance
(
l2
,
list
))
self
.
assertTrue
(
l
is
l2
)
self
.
assertEqual
(
l
,
l2
)
self
.
assertEqual
([
""
,
"123"
,
"321"
],
l2
)
for
i
in
l2
:
self
.
assertTrue
(
isinstance
(
i
,
str
))
# check set types, not inplace
l
=
set
(
""
)
l2
=
cpt
.
to_text
(
l
,
inplace
=
False
)
self
.
assertTrue
(
isinstance
(
l2
,
set
))
self
.
assertFalse
(
l
is
l2
)
self
.
assertEqual
(
l
,
l2
)
self
.
assertEqual
(
set
(
""
),
l2
)
l
=
set
([
b
""
,
b
"123"
])
l2
=
cpt
.
to_text
(
l
,
inplace
=
False
)
self
.
assertTrue
(
isinstance
(
l2
,
set
))
self
.
assertFalse
(
l
is
l2
)
self
.
assertNotEqual
(
l
,
l2
)
self
.
assertEqual
(
set
([
""
,
"123"
]),
l2
)
l
=
set
([
""
,
b
"123"
,
u
"321"
])
l2
=
cpt
.
to_text
(
l
,
inplace
=
False
)
self
.
assertTrue
(
isinstance
(
l2
,
set
))
self
.
assertFalse
(
l
is
l2
)
self
.
assertNotEqual
(
l
,
l2
)
self
.
assertEqual
(
set
([
""
,
"123"
,
"321"
]),
l2
)
# check set types, inplace
l
=
set
(
""
)
l2
=
cpt
.
to_text
(
l
,
inplace
=
True
)
self
.
assertTrue
(
isinstance
(
l2
,
set
))
self
.
assertTrue
(
l
is
l2
)
self
.
assertEqual
(
l
,
l2
)
self
.
assertEqual
(
set
(
""
),
l2
)
l
=
set
([
b
""
,
b
"123"
])
l2
=
cpt
.
to_text
(
l
,
inplace
=
True
)
self
.
assertTrue
(
isinstance
(
l2
,
set
))
self
.
assertTrue
(
l
is
l2
)
self
.
assertEqual
(
l
,
l2
)
self
.
assertEqual
(
set
([
""
,
"123"
]),
l2
)
l
=
set
([
""
,
b
"123"
,
u
"321"
])
l2
=
cpt
.
to_text
(
l
,
inplace
=
True
)
self
.
assertTrue
(
isinstance
(
l2
,
set
))
self
.
assertTrue
(
l
is
l2
)
self
.
assertEqual
(
l
,
l2
)
self
.
assertEqual
(
set
([
""
,
"123"
,
"321"
]),
l2
)
for
i
in
l2
:
self
.
assertTrue
(
isinstance
(
i
,
str
))
# check dict types, not inplace
l
=
{
""
:
""
}
l2
=
cpt
.
to_text
(
l
,
inplace
=
False
)
self
.
assertTrue
(
isinstance
(
l2
,
dict
))
self
.
assertFalse
(
l
is
l2
)
self
.
assertEqual
(
l
,
l2
)
self
.
assertEqual
({
""
:
""
},
l2
)
# check dict types, inplace
l
=
{
""
:
""
}
l2
=
cpt
.
to_text
(
l
,
inplace
=
True
)
self
.
assertTrue
(
isinstance
(
l2
,
dict
))
self
.
assertTrue
(
l
is
l2
)
self
.
assertEqual
(
l
,
l2
)
self
.
assertEqual
({
""
:
""
},
l2
)
def
test_to_bytes
(
self
):
def
test_to_bytes
(
self
):
# Only support python2.x and python3.x now
self
.
assertIsNone
(
cpt
.
to_bytes
(
None
))
self
.
assertTrue
(
six
.
PY2
|
six
.
PY3
)
self
.
assertTrue
(
isinstance
(
cpt
.
to_bytes
(
str
(
""
)),
bytes
))
if
six
.
PY2
:
self
.
assertTrue
(
isinstance
(
cpt
.
to_bytes
(
str
(
"123"
)),
bytes
))
# check None
self
.
assertTrue
(
isinstance
(
cpt
.
to_bytes
(
b
""
),
bytes
))
self
.
assertIsNone
(
cpt
.
to_bytes
(
None
))
self
.
assertTrue
(
isinstance
(
cpt
.
to_bytes
(
b
""
),
bytes
))
self
.
assertTrue
(
isinstance
(
cpt
.
to_bytes
(
u
""
),
bytes
))
# check all string related types
self
.
assertTrue
(
isinstance
(
cpt
.
to_bytes
(
u
""
),
bytes
))
self
.
assertTrue
(
isinstance
(
cpt
.
to_bytes
(
str
(
""
)),
bytes
))
self
.
assertTrue
(
isinstance
(
cpt
.
to_bytes
(
str
(
"123"
)),
bytes
))
self
.
assertEqual
(
b
""
,
cpt
.
to_bytes
(
str
(
""
)))
self
.
assertTrue
(
isinstance
(
cpt
.
to_bytes
(
b
""
),
bytes
))
self
.
assertEqual
(
b
"123"
,
cpt
.
to_bytes
(
str
(
"123"
)))
self
.
assertTrue
(
isinstance
(
cpt
.
to_bytes
(
b
""
),
bytes
))
self
.
assertEqual
(
b
""
,
cpt
.
to_bytes
(
b
""
))
self
.
assertTrue
(
isinstance
(
cpt
.
to_bytes
(
u
""
),
bytes
))
self
.
assertEqual
(
b
"123"
,
cpt
.
to_bytes
(
b
"123"
))
self
.
assertTrue
(
isinstance
(
cpt
.
to_bytes
(
u
""
),
bytes
))
self
.
assertEqual
(
b
""
,
cpt
.
to_bytes
(
u
""
))
self
.
assertEqual
(
b
"123"
,
cpt
.
to_bytes
(
u
"123"
))
self
.
assertEqual
(
b
""
,
cpt
.
to_bytes
(
str
(
""
)))
self
.
assertEqual
(
b
"123"
,
cpt
.
to_bytes
(
str
(
"123"
)))
# check list types, not inplace
self
.
assertEqual
(
b
""
,
cpt
.
to_bytes
(
b
""
))
l
=
[
""
]
self
.
assertEqual
(
b
"123"
,
cpt
.
to_bytes
(
b
"123"
))
l2
=
cpt
.
to_bytes
(
l
)
self
.
assertEqual
(
b
""
,
cpt
.
to_bytes
(
u
""
))
self
.
assertTrue
(
isinstance
(
l2
,
list
))
self
.
assertEqual
(
b
"123"
,
cpt
.
to_bytes
(
u
"123"
))
self
.
assertFalse
(
l
is
l2
)
self
.
assertNotEqual
(
l
,
l2
)
# check list types, not inplace
self
.
assertEqual
([
b
""
],
l2
)
l
=
[
""
]
l
=
[
""
,
"123"
]
l2
=
cpt
.
to_bytes
(
l
)
l2
=
cpt
.
to_bytes
(
l
)
self
.
assertTrue
(
isinstance
(
l2
,
list
))
self
.
assertTrue
(
isinstance
(
l2
,
list
))
self
.
assertFalse
(
l
is
l2
)
self
.
assertFalse
(
l
is
l2
)
self
.
assertEqual
(
l
,
l2
)
self
.
assertNotEqual
(
l
,
l2
)
self
.
assertEqual
([
b
""
],
l2
)
self
.
assertEqual
([
b
""
,
b
"123"
],
l2
)
l
=
[
""
,
"123"
]
l
=
[
""
,
b
"123"
,
u
"321"
]
l2
=
cpt
.
to_bytes
(
l
)
l2
=
cpt
.
to_bytes
(
l
)
self
.
assertTrue
(
isinstance
(
l2
,
list
))
self
.
assertTrue
(
isinstance
(
l2
,
list
))
self
.
assertFalse
(
l
is
l2
)
self
.
assertFalse
(
l
is
l2
)
self
.
assertEqual
(
l
,
l2
)
self
.
assertNotEqual
(
l
,
l2
)
self
.
assertEqual
([
b
""
,
b
"123"
],
l2
)
self
.
assertEqual
([
b
""
,
b
"123"
,
b
"321"
],
l2
)
l
=
[
""
,
b
'123'
,
u
"321"
]
l2
=
cpt
.
to_bytes
(
l
)
# check list types, inplace
self
.
assertTrue
(
isinstance
(
l2
,
list
))
l
=
[
""
]
self
.
assertFalse
(
l
is
l2
)
l2
=
cpt
.
to_bytes
(
l
,
inplace
=
True
)
self
.
assertEqual
(
l
,
l2
)
self
.
assertTrue
(
isinstance
(
l2
,
list
))
self
.
assertEqual
([
b
""
,
b
"123"
,
b
"321"
],
l2
)
self
.
assertTrue
(
l
is
l2
)
for
i
in
l2
:
self
.
assertEqual
(
l
,
l2
)
self
.
assertTrue
(
isinstance
(
i
,
bytes
))
self
.
assertEqual
([
b
""
],
l2
)
l
=
[
""
,
b
"123"
]
# check list types, inplace
l2
=
cpt
.
to_bytes
(
l
,
inplace
=
True
)
l
=
[
""
]
self
.
assertTrue
(
isinstance
(
l2
,
list
))
l2
=
cpt
.
to_bytes
(
l
,
inplace
=
True
)
self
.
assertTrue
(
l
is
l2
)
self
.
assertTrue
(
isinstance
(
l2
,
list
))
self
.
assertEqual
(
l
,
l2
)
self
.
assertTrue
(
l
is
l2
)
self
.
assertEqual
([
b
""
,
b
"123"
],
l2
)
self
.
assertEqual
(
l
,
l2
)
l
=
[
""
,
b
"123"
,
u
"321"
]
self
.
assertEqual
([
b
""
],
l2
)
l2
=
cpt
.
to_bytes
(
l
,
inplace
=
True
)
l
=
[
""
,
"123"
]
self
.
assertTrue
(
isinstance
(
l2
,
list
))
l2
=
cpt
.
to_bytes
(
l
,
inplace
=
True
)
self
.
assertTrue
(
l
is
l2
)
self
.
assertTrue
(
isinstance
(
l2
,
list
))
self
.
assertEqual
(
l
,
l2
)
self
.
assertTrue
(
l
is
l2
)
self
.
assertEqual
([
b
""
,
b
"123"
,
b
"321"
],
l2
)
self
.
assertEqual
(
l
,
l2
)
for
i
in
l2
:
self
.
assertEqual
([
b
""
,
b
"123"
],
l2
)
self
.
assertTrue
(
isinstance
(
i
,
bytes
))
l
=
[
""
,
b
"123"
,
u
"321"
]
l2
=
cpt
.
to_bytes
(
l
,
inplace
=
True
)
# check set types, not inplace
self
.
assertTrue
(
isinstance
(
l2
,
list
))
l
=
set
([
""
])
self
.
assertTrue
(
l
is
l2
)
l2
=
cpt
.
to_bytes
(
l
,
inplace
=
False
)
self
.
assertEqual
(
l
,
l2
)
self
.
assertTrue
(
isinstance
(
l2
,
set
))
self
.
assertEqual
([
b
""
,
b
"123"
,
b
"321"
],
l2
)
self
.
assertFalse
(
l
is
l2
)
self
.
assertNotEqual
(
l
,
l2
)
# check set types, not inplace
self
.
assertEqual
(
set
([
b
""
]),
l2
)
l
=
set
(
""
)
l
=
set
([
u
""
,
u
"123"
])
l2
=
cpt
.
to_bytes
(
l
,
inplace
=
False
)
l2
=
cpt
.
to_bytes
(
l
,
inplace
=
False
)
self
.
assertTrue
(
isinstance
(
l2
,
set
))
self
.
assertTrue
(
isinstance
(
l2
,
set
))
self
.
assertFalse
(
l
is
l2
)
self
.
assertFalse
(
l
is
l2
)
self
.
assertEqual
(
l
,
l2
)
self
.
assertNotEqual
(
l
,
l2
)
self
.
assertEqual
(
set
(
b
""
),
l2
)
self
.
assertEqual
(
set
([
b
""
,
b
"123"
]),
l2
)
l
=
set
([
b
""
,
b
"123"
])
l
=
set
([
""
,
b
"123"
,
u
"321"
])
l2
=
cpt
.
to_bytes
(
l
,
inplace
=
False
)
l2
=
cpt
.
to_bytes
(
l
,
inplace
=
False
)
self
.
assertTrue
(
isinstance
(
l2
,
set
))
self
.
assertTrue
(
isinstance
(
l2
,
set
))
self
.
assertFalse
(
l
is
l2
)
self
.
assertFalse
(
l
is
l2
)
self
.
assertEqual
(
l
,
l2
)
self
.
assertNotEqual
(
l
,
l2
)
self
.
assertEqual
(
set
([
b
""
,
b
"123"
]),
l2
)
self
.
assertEqual
(
set
([
b
""
,
b
"123"
,
b
"321"
]),
l2
)
l
=
set
([
""
,
b
"123"
,
u
"321"
])
l2
=
cpt
.
to_bytes
(
l
,
inplace
=
False
)
# check set types, inplace
self
.
assertTrue
(
isinstance
(
l2
,
set
))
l
=
set
(
""
)
self
.
assertFalse
(
l
is
l2
)
l2
=
cpt
.
to_bytes
(
l
,
inplace
=
True
)
self
.
assertEqual
(
l
,
l2
)
self
.
assertTrue
(
isinstance
(
l2
,
set
))
self
.
assertEqual
(
set
([
b
""
,
b
"123"
,
b
"321"
]),
l2
)
self
.
assertTrue
(
l
is
l2
)
for
i
in
l2
:
self
.
assertEqual
(
l
,
l2
)
self
.
assertTrue
(
isinstance
(
i
,
bytes
))
self
.
assertEqual
(
set
(
b
""
),
l2
)
l
=
set
([
u
""
,
u
"123"
])
# check set types, inplace
l2
=
cpt
.
to_bytes
(
l
,
inplace
=
True
)
l
=
set
(
""
)
self
.
assertTrue
(
isinstance
(
l2
,
set
))
l2
=
cpt
.
to_bytes
(
l
,
inplace
=
True
)
self
.
assertTrue
(
l
is
l2
)
self
.
assertTrue
(
isinstance
(
l2
,
set
))
self
.
assertEqual
(
l
,
l2
)
self
.
assertTrue
(
l
is
l2
)
self
.
assertEqual
(
set
([
b
""
,
b
"123"
]),
l2
)
self
.
assertEqual
(
l
,
l2
)
l
=
set
([
""
,
b
"123"
,
u
"321"
])
self
.
assertEqual
(
set
(
b
""
),
l2
)
l2
=
cpt
.
to_bytes
(
l
,
inplace
=
True
)
l
=
set
([
b
""
,
b
"123"
])
self
.
assertTrue
(
isinstance
(
l2
,
set
))
l2
=
cpt
.
to_bytes
(
l
,
inplace
=
True
)
self
.
assertTrue
(
l
is
l2
)
self
.
assertTrue
(
isinstance
(
l2
,
set
))
self
.
assertEqual
(
l
,
l2
)
self
.
assertTrue
(
l
is
l2
)
self
.
assertEqual
(
set
([
b
""
,
b
"123"
,
b
"321"
]),
l2
)
self
.
assertEqual
(
l
,
l2
)
for
i
in
l2
:
self
.
assertEqual
(
set
([
b
""
,
b
"123"
]),
l2
)
self
.
assertTrue
(
isinstance
(
i
,
bytes
))
l
=
set
([
""
,
b
"123"
,
u
"321"
])
l2
=
cpt
.
to_bytes
(
l
,
inplace
=
True
)
self
.
assertTrue
(
isinstance
(
l2
,
set
))
self
.
assertTrue
(
l
is
l2
)
self
.
assertEqual
(
l
,
l2
)
self
.
assertEqual
(
set
([
b
""
,
b
"123"
,
b
"321"
]),
l2
)
elif
six
.
PY3
:
self
.
assertIsNone
(
cpt
.
to_bytes
(
None
))
self
.
assertTrue
(
isinstance
(
cpt
.
to_bytes
(
str
(
""
)),
bytes
))
self
.
assertTrue
(
isinstance
(
cpt
.
to_bytes
(
str
(
"123"
)),
bytes
))
self
.
assertTrue
(
isinstance
(
cpt
.
to_bytes
(
b
""
),
bytes
))
self
.
assertTrue
(
isinstance
(
cpt
.
to_bytes
(
b
""
),
bytes
))
self
.
assertTrue
(
isinstance
(
cpt
.
to_bytes
(
u
""
),
bytes
))
self
.
assertTrue
(
isinstance
(
cpt
.
to_bytes
(
u
""
),
bytes
))
self
.
assertEqual
(
b
""
,
cpt
.
to_bytes
(
str
(
""
)))
self
.
assertEqual
(
b
"123"
,
cpt
.
to_bytes
(
str
(
"123"
)))
self
.
assertEqual
(
b
""
,
cpt
.
to_bytes
(
b
""
))
self
.
assertEqual
(
b
"123"
,
cpt
.
to_bytes
(
b
"123"
))
self
.
assertEqual
(
b
""
,
cpt
.
to_bytes
(
u
""
))
self
.
assertEqual
(
b
"123"
,
cpt
.
to_bytes
(
u
"123"
))
# check list types, not inplace
l
=
[
""
]
l2
=
cpt
.
to_bytes
(
l
)
self
.
assertTrue
(
isinstance
(
l2
,
list
))
self
.
assertFalse
(
l
is
l2
)
self
.
assertNotEqual
(
l
,
l2
)
self
.
assertEqual
([
b
""
],
l2
)
l
=
[
""
,
"123"
]
l2
=
cpt
.
to_bytes
(
l
)
self
.
assertTrue
(
isinstance
(
l2
,
list
))
self
.
assertFalse
(
l
is
l2
)
self
.
assertNotEqual
(
l
,
l2
)
self
.
assertEqual
([
b
""
,
b
"123"
],
l2
)
l
=
[
""
,
b
"123"
,
u
"321"
]
l2
=
cpt
.
to_bytes
(
l
)
self
.
assertTrue
(
isinstance
(
l2
,
list
))
self
.
assertFalse
(
l
is
l2
)
self
.
assertNotEqual
(
l
,
l2
)
self
.
assertEqual
([
b
""
,
b
"123"
,
b
"321"
],
l2
)
# check list types, inplace
l
=
[
""
]
l2
=
cpt
.
to_bytes
(
l
,
inplace
=
True
)
self
.
assertTrue
(
isinstance
(
l2
,
list
))
self
.
assertTrue
(
l
is
l2
)
self
.
assertEqual
(
l
,
l2
)
self
.
assertEqual
([
b
""
],
l2
)
l
=
[
""
,
b
"123"
]
l2
=
cpt
.
to_bytes
(
l
,
inplace
=
True
)
self
.
assertTrue
(
isinstance
(
l2
,
list
))
self
.
assertTrue
(
l
is
l2
)
self
.
assertEqual
(
l
,
l2
)
self
.
assertEqual
([
b
""
,
b
"123"
],
l2
)
l
=
[
""
,
b
"123"
,
u
"321"
]
l2
=
cpt
.
to_bytes
(
l
,
inplace
=
True
)
self
.
assertTrue
(
isinstance
(
l2
,
list
))
self
.
assertTrue
(
l
is
l2
)
self
.
assertEqual
(
l
,
l2
)
self
.
assertEqual
([
b
""
,
b
"123"
,
b
"321"
],
l2
)
for
i
in
l2
:
self
.
assertTrue
(
isinstance
(
i
,
bytes
))
# check set types, not inplace
l
=
set
([
""
])
l2
=
cpt
.
to_bytes
(
l
,
inplace
=
False
)
self
.
assertTrue
(
isinstance
(
l2
,
set
))
self
.
assertFalse
(
l
is
l2
)
self
.
assertNotEqual
(
l
,
l2
)
self
.
assertEqual
(
set
([
b
""
]),
l2
)
l
=
set
([
u
""
,
u
"123"
])
l2
=
cpt
.
to_bytes
(
l
,
inplace
=
False
)
self
.
assertTrue
(
isinstance
(
l2
,
set
))
self
.
assertFalse
(
l
is
l2
)
self
.
assertNotEqual
(
l
,
l2
)
self
.
assertEqual
(
set
([
b
""
,
b
"123"
]),
l2
)
l
=
set
([
""
,
b
"123"
,
u
"321"
])
l2
=
cpt
.
to_bytes
(
l
,
inplace
=
False
)
self
.
assertTrue
(
isinstance
(
l2
,
set
))
self
.
assertFalse
(
l
is
l2
)
self
.
assertNotEqual
(
l
,
l2
)
self
.
assertEqual
(
set
([
b
""
,
b
"123"
,
b
"321"
]),
l2
)
# check set types, inplace
l
=
set
(
""
)
l2
=
cpt
.
to_bytes
(
l
,
inplace
=
True
)
self
.
assertTrue
(
isinstance
(
l2
,
set
))
self
.
assertTrue
(
l
is
l2
)
self
.
assertEqual
(
l
,
l2
)
self
.
assertEqual
(
set
(
b
""
),
l2
)
l
=
set
([
u
""
,
u
"123"
])
l2
=
cpt
.
to_bytes
(
l
,
inplace
=
True
)
self
.
assertTrue
(
isinstance
(
l2
,
set
))
self
.
assertTrue
(
l
is
l2
)
self
.
assertEqual
(
l
,
l2
)
self
.
assertEqual
(
set
([
b
""
,
b
"123"
]),
l2
)
l
=
set
([
""
,
b
"123"
,
u
"321"
])
l2
=
cpt
.
to_bytes
(
l
,
inplace
=
True
)
self
.
assertTrue
(
isinstance
(
l2
,
set
))
self
.
assertTrue
(
l
is
l2
)
self
.
assertEqual
(
l
,
l2
)
self
.
assertEqual
(
set
([
b
""
,
b
"123"
,
b
"321"
]),
l2
)
for
i
in
l2
:
self
.
assertTrue
(
isinstance
(
i
,
bytes
))
def
test_round
(
self
):
def
test_round
(
self
):
self
.
assertEqual
(
3.0
,
cpt
.
round
(
3.4
))
self
.
assertEqual
(
3.0
,
cpt
.
round
(
3.4
))
...
@@ -500,37 +265,17 @@ class TestCompatible(unittest.TestCase):
...
@@ -500,37 +265,17 @@ class TestCompatible(unittest.TestCase):
def
test_get_exception_message
(
self
):
def
test_get_exception_message
(
self
):
exception_message
=
"test_message"
exception_message
=
"test_message"
self
.
assertRaises
(
AssertionError
,
cpt
.
get_exception_message
,
None
)
self
.
assertRaises
(
AssertionError
,
cpt
.
get_exception_message
,
None
)
if
six
.
PY2
:
try
:
self
.
assertRaises
(
AttributeError
,
cpt
.
get_exception_message
,
raise
RuntimeError
(
exception_message
)
exception_message
)
except
Exception
as
e
:
try
:
self
.
assertEqual
(
exception_message
,
cpt
.
get_exception_message
(
e
))
raise
RuntimeError
(
exception_message
)
self
.
assertIsNotNone
(
e
)
except
Exception
as
e
:
self
.
assertEqual
(
exception_message
,
try
:
cpt
.
get_exception_message
(
e
))
raise
Exception
(
exception_message
)
self
.
assertIsNotNone
(
e
)
except
Exception
as
e
:
self
.
assertEqual
(
exception_message
,
cpt
.
get_exception_message
(
e
))
try
:
self
.
assertIsNotNone
(
e
)
raise
Exception
(
exception_message
)
except
Exception
as
e
:
self
.
assertEqual
(
exception_message
,
cpt
.
get_exception_message
(
e
))
self
.
assertIsNotNone
(
e
)
if
six
.
PY3
:
try
:
raise
RuntimeError
(
exception_message
)
except
Exception
as
e
:
self
.
assertEqual
(
exception_message
,
cpt
.
get_exception_message
(
e
))
self
.
assertIsNotNone
(
e
)
try
:
raise
Exception
(
exception_message
)
except
Exception
as
e
:
self
.
assertEqual
(
exception_message
,
cpt
.
get_exception_message
(
e
))
self
.
assertIsNotNone
(
e
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
python/paddle/fluid/tests/unittests/test_dist_base.py
浏览文件 @
0f7187af
...
@@ -44,19 +44,13 @@ DIST_UT_PORT = 0
...
@@ -44,19 +44,13 @@ DIST_UT_PORT = 0
def
print_to_out
(
out_losses
):
def
print_to_out
(
out_losses
):
if
six
.
PY2
:
sys
.
stdout
.
buffer
.
write
(
pickle
.
dumps
(
out_losses
))
print
(
pickle
.
dumps
(
out_losses
))
else
:
sys
.
stdout
.
buffer
.
write
(
pickle
.
dumps
(
out_losses
))
def
print_to_err
(
class_name
,
log_str
):
def
print_to_err
(
class_name
,
log_str
):
localtime
=
time
.
asctime
(
time
.
localtime
(
time
.
time
()))
localtime
=
time
.
asctime
(
time
.
localtime
(
time
.
time
()))
print_str
=
localtime
+
"
\t
"
+
class_name
+
"
\t
"
+
log_str
print_str
=
localtime
+
"
\t
"
+
class_name
+
"
\t
"
+
log_str
if
six
.
PY2
:
sys
.
stderr
.
buffer
.
write
(
pickle
.
dumps
(
print_str
))
sys
.
stderr
.
write
(
pickle
.
dumps
(
print_str
))
else
:
sys
.
stderr
.
buffer
.
write
(
pickle
.
dumps
(
print_str
))
def
eprint
(
*
args
,
**
kwargs
):
def
eprint
(
*
args
,
**
kwargs
):
...
@@ -151,10 +145,7 @@ class TestDistRunnerBase(object):
...
@@ -151,10 +145,7 @@ class TestDistRunnerBase(object):
print_to_err
(
type
(
self
).
__name__
,
"run step %d finished"
%
i
)
print_to_err
(
type
(
self
).
__name__
,
"run step %d finished"
%
i
)
print_to_err
(
type
(
self
).
__name__
,
"trainer run finished"
)
print_to_err
(
type
(
self
).
__name__
,
"trainer run finished"
)
if
six
.
PY2
:
sys
.
stdout
.
buffer
.
write
(
pickle
.
dumps
(
out_losses
))
print
(
pickle
.
dumps
(
out_losses
))
else
:
sys
.
stdout
.
buffer
.
write
(
pickle
.
dumps
(
out_losses
))
if
args
.
save_model
:
if
args
.
save_model
:
model_save_dir
=
"/tmp"
model_save_dir
=
"/tmp"
...
@@ -251,10 +242,7 @@ class TestDistRunnerBase(object):
...
@@ -251,10 +242,7 @@ class TestDistRunnerBase(object):
print_to_err
(
type
(
self
).
__name__
,
"trainer run finished"
)
print_to_err
(
type
(
self
).
__name__
,
"trainer run finished"
)
print_to_err
(
type
(
self
).
__name__
,
"dist losses: {}"
.
format
(
out_losses
))
print_to_err
(
type
(
self
).
__name__
,
"dist losses: {}"
.
format
(
out_losses
))
if
six
.
PY2
:
sys
.
stdout
.
buffer
.
write
(
pickle
.
dumps
(
out_losses
))
print
(
pickle
.
dumps
(
out_losses
))
else
:
sys
.
stdout
.
buffer
.
write
(
pickle
.
dumps
(
out_losses
))
def
run_use_fleet_api_trainer
(
self
,
args
):
def
run_use_fleet_api_trainer
(
self
,
args
):
assert
args
.
update_method
==
"nccl2"
or
"bkcl"
assert
args
.
update_method
==
"nccl2"
or
"bkcl"
...
@@ -338,10 +326,7 @@ class TestDistRunnerBase(object):
...
@@ -338,10 +326,7 @@ class TestDistRunnerBase(object):
print_to_err
(
type
(
self
).
__name__
,
"run step %d finished"
%
i
)
print_to_err
(
type
(
self
).
__name__
,
"run step %d finished"
%
i
)
print_to_err
(
type
(
self
).
__name__
,
"trainer run finished"
)
print_to_err
(
type
(
self
).
__name__
,
"trainer run finished"
)
if
six
.
PY2
:
sys
.
stdout
.
buffer
.
write
(
pickle
.
dumps
(
out_losses
))
print
(
pickle
.
dumps
(
out_losses
))
else
:
sys
.
stdout
.
buffer
.
write
(
pickle
.
dumps
(
out_losses
))
if
args
.
save_model
:
if
args
.
save_model
:
model_save_dir
=
"/tmp"
model_save_dir
=
"/tmp"
...
...
python/paddle/fluid/tests/unittests/test_math_op_patch_var_base.py
浏览文件 @
0f7187af
...
@@ -18,7 +18,6 @@ import unittest
...
@@ -18,7 +18,6 @@ import unittest
import
paddle
import
paddle
import
paddle.fluid
as
fluid
import
paddle.fluid
as
fluid
import
numpy
as
np
import
numpy
as
np
import
six
import
inspect
import
inspect
...
@@ -241,10 +240,7 @@ class TestMathOpPatchesVarBase(unittest.TestCase):
...
@@ -241,10 +240,7 @@ class TestMathOpPatchesVarBase(unittest.TestCase):
a
=
fluid
.
dygraph
.
to_variable
(
np
.
array
([
100.1
]))
a
=
fluid
.
dygraph
.
to_variable
(
np
.
array
([
100.1
]))
self
.
assertTrue
(
float
(
a
)
==
100.1
)
self
.
assertTrue
(
float
(
a
)
==
100.1
)
self
.
assertTrue
(
int
(
a
)
==
100
)
self
.
assertTrue
(
int
(
a
)
==
100
)
if
six
.
PY2
:
self
.
assertTrue
(
int
(
a
)
==
100
)
self
.
assertTrue
(
long
(
a
)
==
100
)
else
:
self
.
assertTrue
(
int
(
a
)
==
100
)
def
test_len
(
self
):
def
test_len
(
self
):
a_np
=
np
.
random
.
uniform
(
-
1
,
1
,
self
.
shape
).
astype
(
self
.
dtype
)
a_np
=
np
.
random
.
uniform
(
-
1
,
1
,
self
.
shape
).
astype
(
self
.
dtype
)
...
...
python/paddle/fluid/tests/unittests/test_paddle_save_load.py
浏览文件 @
0f7187af
...
@@ -18,7 +18,6 @@ import unittest
...
@@ -18,7 +18,6 @@ import unittest
import
numpy
as
np
import
numpy
as
np
import
os
import
os
import
sys
import
sys
import
six
from
io
import
BytesIO
from
io
import
BytesIO
import
paddle
import
paddle
...
@@ -38,10 +37,7 @@ SEED = 10
...
@@ -38,10 +37,7 @@ SEED = 10
IMAGE_SIZE
=
784
IMAGE_SIZE
=
784
CLASS_NUM
=
10
CLASS_NUM
=
10
if
six
.
PY2
:
LARGE_PARAM
=
2
**
26
LARGE_PARAM
=
2
**
2
else
:
LARGE_PARAM
=
2
**
26
def
random_batch_reader
():
def
random_batch_reader
():
...
@@ -105,10 +101,7 @@ class TestSaveLoadLargeParameters(unittest.TestCase):
...
@@ -105,10 +101,7 @@ class TestSaveLoadLargeParameters(unittest.TestCase):
path
=
os
.
path
.
join
(
"test_paddle_save_load_large_param_save"
,
path
=
os
.
path
.
join
(
"test_paddle_save_load_large_param_save"
,
"layer.pdparams"
)
"layer.pdparams"
)
if
six
.
PY2
:
protocol
=
4
protocol
=
2
else
:
protocol
=
4
paddle
.
save
(
save_dict
,
path
,
protocol
=
protocol
)
paddle
.
save
(
save_dict
,
path
,
protocol
=
protocol
)
dict_load
=
paddle
.
load
(
path
)
dict_load
=
paddle
.
load
(
path
)
# compare results before and after saving
# compare results before and after saving
...
@@ -926,9 +919,6 @@ class TestSaveLoadProgram(unittest.TestCase):
...
@@ -926,9 +919,6 @@ class TestSaveLoadProgram(unittest.TestCase):
class
TestSaveLoadLayer
(
unittest
.
TestCase
):
class
TestSaveLoadLayer
(
unittest
.
TestCase
):
def
test_save_load_layer
(
self
):
def
test_save_load_layer
(
self
):
if
six
.
PY2
:
return
paddle
.
disable_static
()
paddle
.
disable_static
()
inps
=
paddle
.
randn
([
1
,
IMAGE_SIZE
],
dtype
=
'float32'
)
inps
=
paddle
.
randn
([
1
,
IMAGE_SIZE
],
dtype
=
'float32'
)
layer1
=
LinearNet
()
layer1
=
LinearNet
()
...
...
python/paddle/fluid/tests/unittests/test_static_save_load_large.py
浏览文件 @
0f7187af
...
@@ -21,15 +21,10 @@ import paddle.fluid.framework as framework
...
@@ -21,15 +21,10 @@ import paddle.fluid.framework as framework
from
test_imperative_base
import
new_program_scope
from
test_imperative_base
import
new_program_scope
import
numpy
as
np
import
numpy
as
np
import
six
import
pickle
import
pickle
import
os
import
os
# Python2.x no longer supports saving and loading large parameters.
LARGE_PARAM
=
2
**
26
if
six
.
PY2
:
LARGE_PARAM
=
2
else
:
LARGE_PARAM
=
2
**
26
class
TestStaticSaveLoadLargeParameters
(
unittest
.
TestCase
):
class
TestStaticSaveLoadLargeParameters
(
unittest
.
TestCase
):
...
@@ -59,10 +54,7 @@ class TestStaticSaveLoadLargeParameters(unittest.TestCase):
...
@@ -59,10 +54,7 @@ class TestStaticSaveLoadLargeParameters(unittest.TestCase):
path
=
os
.
path
.
join
(
"test_static_save_load_large_param"
,
path
=
os
.
path
.
join
(
"test_static_save_load_large_param"
,
"static_save"
)
"static_save"
)
if
six
.
PY2
:
protocol
=
4
protocol
=
2
else
:
protocol
=
4
paddle
.
fluid
.
save
(
prog
,
path
,
pickle_protocol
=
protocol
)
paddle
.
fluid
.
save
(
prog
,
path
,
pickle_protocol
=
protocol
)
# set var to zero
# set var to zero
for
var
in
prog
.
list_vars
():
for
var
in
prog
.
list_vars
():
...
...
python/paddle/fluid/tests/unittests/test_traced_layer_err_msg.py
浏览文件 @
0f7187af
...
@@ -15,7 +15,6 @@
...
@@ -15,7 +15,6 @@
import
numpy
as
np
import
numpy
as
np
import
paddle
import
paddle
import
paddle.fluid
as
fluid
import
paddle.fluid
as
fluid
import
six
import
unittest
import
unittest
import
paddle.nn
as
nn
import
paddle.nn
as
nn
import
os
import
os
...
@@ -50,10 +49,7 @@ class TestTracedLayerErrMsg(unittest.TestCase):
...
@@ -50,10 +49,7 @@ class TestTracedLayerErrMsg(unittest.TestCase):
self
.
feature_size
=
3
self
.
feature_size
=
3
self
.
fc_size
=
2
self
.
fc_size
=
2
self
.
layer
=
self
.
_train_simple_net
()
self
.
layer
=
self
.
_train_simple_net
()
if
six
.
PY2
:
self
.
type_str
=
'class'
self
.
type_str
=
'type'
else
:
self
.
type_str
=
'class'
def
test_trace_err
(
self
):
def
test_trace_err
(
self
):
with
fluid
.
dygraph
.
guard
():
with
fluid
.
dygraph
.
guard
():
...
...
python/paddle/fluid/tests/unittests/test_var_base.py
浏览文件 @
0f7187af
...
@@ -192,7 +192,7 @@ class TestVarBase(unittest.TestCase):
...
@@ -192,7 +192,7 @@ class TestVarBase(unittest.TestCase):
x
=
paddle
.
to_tensor
(
1
,
dtype
=
'int64'
)
x
=
paddle
.
to_tensor
(
1
,
dtype
=
'int64'
)
self
.
assertEqual
(
x
.
item
(),
1
)
self
.
assertEqual
(
x
.
item
(),
1
)
self
.
assertTrue
(
isinstance
(
x
.
item
(),
long
if
six
.
PY2
else
int
))
self
.
assertTrue
(
isinstance
(
x
.
item
(),
int
))
x
=
paddle
.
to_tensor
(
True
)
x
=
paddle
.
to_tensor
(
True
)
self
.
assertEqual
(
x
.
item
(),
True
)
self
.
assertEqual
(
x
.
item
(),
True
)
...
...
python/paddle/framework/io.py
浏览文件 @
0f7187af
...
@@ -17,14 +17,10 @@ from __future__ import print_function
...
@@ -17,14 +17,10 @@ from __future__ import print_function
import
os
import
os
import
collections
import
collections
import
pickle
import
pickle
import
six
import
warnings
import
warnings
import
sys
import
sys
import
numpy
as
np
import
numpy
as
np
import
copyreg
if
not
six
.
PY2
:
import
copyreg
import
paddle
import
paddle
# deprecated module import
# deprecated module import
...
@@ -296,19 +292,14 @@ def _pickle_save(obj, f, protocol):
...
@@ -296,19 +292,14 @@ def _pickle_save(obj, f, protocol):
for
i
in
range
(
0
,
len
(
pickle_bytes
),
max_bytes
):
for
i
in
range
(
0
,
len
(
pickle_bytes
),
max_bytes
):
f
.
write
(
pickle_bytes
[
i
:
i
+
max_bytes
])
f
.
write
(
pickle_bytes
[
i
:
i
+
max_bytes
])
else
:
else
:
if
six
.
PY2
:
pickler
=
pickle
.
Pickler
(
f
,
protocol
)
add_dispatch_table
()
pickler
.
dispatch_table
=
copyreg
.
dispatch_table
.
copy
()
pickle_bytes
=
pickle
.
dump
(
obj
,
f
,
protocol
)
pop_dispatch_table
()
else
:
pickler
=
pickle
.
Pickler
(
f
,
protocol
)
pickler
.
dispatch_table
=
copyreg
.
dispatch_table
.
copy
()
pickler
.
dispatch_table
[
core
.
VarBase
]
=
reduce_varbase
pickler
.
dispatch_table
[
core
.
VarBase
]
=
reduce_varbase
pickler
.
dispatch_table
[
core
.
LoDTensor
]
=
reduce_LoDTensor
pickler
.
dispatch_table
[
core
.
LoDTensor
]
=
reduce_LoDTensor
pickler
.
dispatch_table
[
ParamBase
]
=
reduce_varbase
pickler
.
dispatch_table
[
ParamBase
]
=
reduce_varbase
pickler
.
dispatch_table
.
update
(
dispatch_table_layer
)
pickler
.
dispatch_table
.
update
(
dispatch_table_layer
)
pickler
.
dump
(
obj
)
pickler
.
dump
(
obj
)
def
_contain_x
(
obj
,
condition_func
):
def
_contain_x
(
obj
,
condition_func
):
...
@@ -359,10 +350,7 @@ def _transformed_from_varbase(obj):
...
@@ -359,10 +350,7 @@ def _transformed_from_varbase(obj):
# In paddle2.1 version, VarBase is saved as tuple(tensor.name, tensor.numpy()).
# In paddle2.1 version, VarBase is saved as tuple(tensor.name, tensor.numpy()).
# When executing paddle.load, use this function to determine whether to restore to VarBase/LoDTensor.
# When executing paddle.load, use this function to determine whether to restore to VarBase/LoDTensor.
if
isinstance
(
obj
,
tuple
)
and
len
(
obj
)
==
2
:
if
isinstance
(
obj
,
tuple
)
and
len
(
obj
)
==
2
:
if
six
.
PY2
:
name_types
=
str
name_types
=
(
str
,
unicode
)
else
:
name_types
=
str
if
isinstance
(
obj
[
0
],
name_types
)
and
isinstance
(
obj
[
1
],
np
.
ndarray
):
if
isinstance
(
obj
[
0
],
name_types
)
and
isinstance
(
obj
[
1
],
np
.
ndarray
):
return
True
return
True
return
False
return
False
...
@@ -947,10 +935,7 @@ def load(path, **configs):
...
@@ -947,10 +935,7 @@ def load(path, **configs):
if
_is_memory_buffer
(
path
)
or
os
.
path
.
isfile
(
path
):
if
_is_memory_buffer
(
path
)
or
os
.
path
.
isfile
(
path
):
config
=
_parse_load_config
(
configs
)
config
=
_parse_load_config
(
configs
)
if
six
.
PY2
:
exception_type
=
pickle
.
UnpicklingError
exception_type
=
KeyError
else
:
exception_type
=
pickle
.
UnpicklingError
try
:
try
:
with
_open_file_buffer
(
path
,
'rb'
)
as
f
:
with
_open_file_buffer
(
path
,
'rb'
)
as
f
:
# When value of dict is lager than 4GB ,there is a Bug on 'MAC python3'
# When value of dict is lager than 4GB ,there is a Bug on 'MAC python3'
...
@@ -959,8 +944,7 @@ def load(path, **configs):
...
@@ -959,8 +944,7 @@ def load(path, **configs):
)
and
sys
.
platform
==
'darwin'
and
sys
.
version_info
.
major
==
3
:
)
and
sys
.
platform
==
'darwin'
and
sys
.
version_info
.
major
==
3
:
load_result
=
_pickle_loads_mac
(
path
,
f
)
load_result
=
_pickle_loads_mac
(
path
,
f
)
else
:
else
:
load_result
=
pickle
.
load
(
f
)
if
six
.
PY2
else
pickle
.
load
(
load_result
=
pickle
.
load
(
f
,
encoding
=
'latin1'
)
f
,
encoding
=
'latin1'
)
# TODO(weixin):If `obj` is any object, the judgment condition should be more precise.
# TODO(weixin):If `obj` is any object, the judgment condition should be more precise.
if
isinstance
(
load_result
,
dict
):
if
isinstance
(
load_result
,
dict
):
...
@@ -1021,8 +1005,7 @@ def _legacy_load(path, **configs):
...
@@ -1021,8 +1005,7 @@ def _legacy_load(path, **configs):
if
os
.
path
.
isfile
(
path
)
or
_is_memory_buffer
(
path
):
if
os
.
path
.
isfile
(
path
)
or
_is_memory_buffer
(
path
):
# we think path is file means this file is created by paddle.save
# we think path is file means this file is created by paddle.save
with
_open_file_buffer
(
path
,
'rb'
)
as
f
:
with
_open_file_buffer
(
path
,
'rb'
)
as
f
:
load_result
=
pickle
.
load
(
f
)
if
six
.
PY2
else
pickle
.
load
(
load_result
=
pickle
.
load
(
f
,
encoding
=
'latin1'
)
f
,
encoding
=
'latin1'
)
load_result
=
_pack_loaded_dict
(
load_result
)
load_result
=
_pack_loaded_dict
(
load_result
)
if
not
config
.
keep_name_table
and
"StructuredToParameterName@@"
in
load_result
:
if
not
config
.
keep_name_table
and
"StructuredToParameterName@@"
in
load_result
:
del
load_result
[
"StructuredToParameterName@@"
]
del
load_result
[
"StructuredToParameterName@@"
]
...
...
python/paddle/hapi/model.py
浏览文件 @
0f7187af
...
@@ -1296,8 +1296,7 @@ class Model(object):
...
@@ -1296,8 +1296,7 @@ class Model(object):
if
not
os
.
path
.
exists
(
path
):
if
not
os
.
path
.
exists
(
path
):
return
return
with
open
(
path
,
'rb'
)
as
f
:
with
open
(
path
,
'rb'
)
as
f
:
return
pickle
.
load
(
f
)
if
six
.
PY2
else
pickle
.
load
(
return
pickle
.
load
(
f
,
encoding
=
'latin1'
)
f
,
encoding
=
'latin1'
)
def
_check_match
(
key
,
param
):
def
_check_match
(
key
,
param
):
state
=
param_state
.
get
(
key
,
None
)
state
=
param_state
.
get
(
key
,
None
)
...
...
python/paddle/tensor/manipulation.py
浏览文件 @
0f7187af
...
@@ -21,7 +21,6 @@ from ..fluid.data_feeder import convert_dtype, check_variable_and_dtype, check_t
...
@@ -21,7 +21,6 @@ from ..fluid.data_feeder import convert_dtype, check_variable_and_dtype, check_t
from
..fluid.layers.tensor
import
fill_constant
from
..fluid.layers.tensor
import
fill_constant
from
..fluid.layers
import
utils
from
..fluid.layers
import
utils
import
numpy
as
np
import
numpy
as
np
import
six
# TODO: define functions to manipulate a tensor
# TODO: define functions to manipulate a tensor
from
..fluid.layers
import
cast
# noqa: F401
from
..fluid.layers
import
cast
# noqa: F401
from
..fluid.layers
import
slice
# noqa: F401
from
..fluid.layers
import
slice
# noqa: F401
...
@@ -1218,10 +1217,7 @@ def tile(x, repeat_times, name=None):
...
@@ -1218,10 +1217,7 @@ def tile(x, repeat_times, name=None):
assert
len
(
elem
.
shape
)
==
1
,
(
assert
len
(
elem
.
shape
)
==
1
,
(
'Elements in repeat_times must be 1-D Tensors or integers.'
)
'Elements in repeat_times must be 1-D Tensors or integers.'
)
else
:
else
:
if
six
.
PY3
:
type_tuple
=
(
int
,
np
.
int32
,
np
.
int64
)
type_tuple
=
(
int
,
np
.
int32
,
np
.
int64
)
elif
six
.
PY2
:
type_tuple
=
(
int
,
long
,
np
.
int32
,
np
.
int64
)
assert
isinstance
(
elem
,
type_tuple
),
(
assert
isinstance
(
elem
,
type_tuple
),
(
'Elements in repeat_times must be 1-D Tensors or integers.'
)
'Elements in repeat_times must be 1-D Tensors or integers.'
)
...
@@ -1357,10 +1353,7 @@ def broadcast_to(x, shape, name=None):
...
@@ -1357,10 +1353,7 @@ def broadcast_to(x, shape, name=None):
assert
len
(
elem
.
shape
)
==
1
,
(
assert
len
(
elem
.
shape
)
==
1
,
(
'Elements in shape must be 1-D Tensors or integers.'
)
'Elements in shape must be 1-D Tensors or integers.'
)
else
:
else
:
if
six
.
PY3
:
type_tuple
=
(
int
,
np
.
int32
,
np
.
int64
)
type_tuple
=
(
int
,
np
.
int32
,
np
.
int64
)
elif
six
.
PY2
:
type_tuple
=
(
int
,
long
,
np
.
int32
,
np
.
int64
)
assert
isinstance
(
elem
,
type_tuple
),
(
assert
isinstance
(
elem
,
type_tuple
),
(
'Elements in shape must be 1-D Tensors or integers.'
)
'Elements in shape must be 1-D Tensors or integers.'
)
...
@@ -1447,10 +1440,7 @@ def expand(x, shape, name=None):
...
@@ -1447,10 +1440,7 @@ def expand(x, shape, name=None):
assert
len
(
elem
.
shape
)
==
1
,
(
assert
len
(
elem
.
shape
)
==
1
,
(
'Elements in shape must be 1-D Tensors or integers.'
)
'Elements in shape must be 1-D Tensors or integers.'
)
else
:
else
:
if
six
.
PY3
:
type_tuple
=
(
int
,
np
.
int32
,
np
.
int64
)
type_tuple
=
(
int
,
np
.
int32
,
np
.
int64
)
elif
six
.
PY2
:
type_tuple
=
(
int
,
long
,
np
.
int32
,
np
.
int64
)
assert
isinstance
(
elem
,
type_tuple
),
(
assert
isinstance
(
elem
,
type_tuple
),
(
'Elements in shape must be 1-D Tensors or integers.'
)
'Elements in shape must be 1-D Tensors or integers.'
)
...
...
python/paddle/utils/cpp_extension/extension_utils.py
浏览文件 @
0f7187af
...
@@ -14,7 +14,6 @@
...
@@ -14,7 +14,6 @@
import
os
import
os
import
re
import
re
import
six
import
sys
import
sys
import
json
import
json
import
glob
import
glob
...
@@ -541,8 +540,7 @@ def find_cuda_home():
...
@@ -541,8 +540,7 @@ def find_cuda_home():
with
open
(
os
.
devnull
,
'w'
)
as
devnull
:
with
open
(
os
.
devnull
,
'w'
)
as
devnull
:
nvcc_path
=
subprocess
.
check_output
(
nvcc_path
=
subprocess
.
check_output
(
[
which_cmd
,
'nvcc'
],
stderr
=
devnull
)
[
which_cmd
,
'nvcc'
],
stderr
=
devnull
)
if
six
.
PY3
:
nvcc_path
=
nvcc_path
.
decode
()
nvcc_path
=
nvcc_path
.
decode
()
# Multi CUDA, select the first
# Multi CUDA, select the first
nvcc_path
=
nvcc_path
.
split
(
'
\r\n
'
)[
0
]
nvcc_path
=
nvcc_path
.
split
(
'
\r\n
'
)[
0
]
...
@@ -580,8 +578,7 @@ def find_rocm_home():
...
@@ -580,8 +578,7 @@ def find_rocm_home():
with
open
(
os
.
devnull
,
'w'
)
as
devnull
:
with
open
(
os
.
devnull
,
'w'
)
as
devnull
:
hipcc_path
=
subprocess
.
check_output
(
hipcc_path
=
subprocess
.
check_output
(
[
which_cmd
,
'hipcc'
],
stderr
=
devnull
)
[
which_cmd
,
'hipcc'
],
stderr
=
devnull
)
if
six
.
PY3
:
hipcc_path
=
hipcc_path
.
decode
()
hipcc_path
=
hipcc_path
.
decode
()
hipcc_path
=
hipcc_path
.
rstrip
(
'
\r\n
'
)
hipcc_path
=
hipcc_path
.
rstrip
(
'
\r\n
'
)
# for example: /opt/rocm/bin/hipcc
# for example: /opt/rocm/bin/hipcc
...
@@ -652,8 +649,7 @@ def find_clang_cpp_include(compiler='clang'):
...
@@ -652,8 +649,7 @@ def find_clang_cpp_include(compiler='clang'):
std_v1_includes
=
None
std_v1_includes
=
None
try
:
try
:
compiler_version
=
subprocess
.
check_output
([
compiler
,
"--version"
])
compiler_version
=
subprocess
.
check_output
([
compiler
,
"--version"
])
if
six
.
PY3
:
compiler_version
=
compiler_version
.
decode
()
compiler_version
=
compiler_version
.
decode
()
infos
=
compiler_version
.
split
(
"
\n
"
)
infos
=
compiler_version
.
split
(
"
\n
"
)
for
info
in
infos
:
for
info
in
infos
:
if
"InstalledDir"
in
info
:
if
"InstalledDir"
in
info
:
...
@@ -895,13 +891,9 @@ def _load_module_from_file(api_file_path, verbose=False):
...
@@ -895,13 +891,9 @@ def _load_module_from_file(api_file_path, verbose=False):
# Unique readable module name to place custom api.
# Unique readable module name to place custom api.
log_v
(
'import module from file: {}'
.
format
(
api_file_path
),
verbose
)
log_v
(
'import module from file: {}'
.
format
(
api_file_path
),
verbose
)
ext_name
=
"_paddle_cpp_extension_"
ext_name
=
"_paddle_cpp_extension_"
if
six
.
PY2
:
from
importlib
import
machinery
import
imp
loader
=
machinery
.
SourceFileLoader
(
ext_name
,
api_file_path
)
module
=
imp
.
load_source
(
ext_name
,
api_file_path
)
module
=
loader
.
load_module
()
else
:
from
importlib
import
machinery
loader
=
machinery
.
SourceFileLoader
(
ext_name
,
api_file_path
)
module
=
loader
.
load_module
()
return
module
return
module
...
@@ -1005,8 +997,7 @@ def _jit_compile(file_path, verbose=False):
...
@@ -1005,8 +997,7 @@ def _jit_compile(file_path, verbose=False):
try
:
try
:
py_version
=
subprocess
.
check_output
([
interpreter
,
'-V'
])
py_version
=
subprocess
.
check_output
([
interpreter
,
'-V'
])
if
six
.
PY3
:
py_version
=
py_version
.
decode
()
py_version
=
py_version
.
decode
()
log_v
(
"Using Python interpreter: {}, version: {}"
.
format
(
log_v
(
"Using Python interpreter: {}, version: {}"
.
format
(
interpreter
,
py_version
.
strip
()),
verbose
)
interpreter
,
py_version
.
strip
()),
verbose
)
except
Exception
:
except
Exception
:
...
@@ -1083,8 +1074,7 @@ def check_abi_compatibility(compiler, verbose=False):
...
@@ -1083,8 +1074,7 @@ def check_abi_compatibility(compiler, verbose=False):
if
not
IS_WINDOWS
:
if
not
IS_WINDOWS
:
cmd_out
=
subprocess
.
check_output
(
cmd_out
=
subprocess
.
check_output
(
[
'which'
,
compiler
],
stderr
=
subprocess
.
STDOUT
)
[
'which'
,
compiler
],
stderr
=
subprocess
.
STDOUT
)
compiler_path
=
os
.
path
.
realpath
(
cmd_out
.
decode
()
compiler_path
=
os
.
path
.
realpath
(
cmd_out
.
decode
()).
strip
()
if
six
.
PY3
else
cmd_out
).
strip
()
# if not found any suitable compiler, raise warning
# if not found any suitable compiler, raise warning
if
not
any
(
name
in
compiler_path
if
not
any
(
name
in
compiler_path
for
name
in
_expected_compiler_current_platform
()):
for
name
in
_expected_compiler_current_platform
()):
...
@@ -1104,18 +1094,16 @@ def check_abi_compatibility(compiler, verbose=False):
...
@@ -1104,18 +1094,16 @@ def check_abi_compatibility(compiler, verbose=False):
mini_required_version
=
GCC_MINI_VERSION
mini_required_version
=
GCC_MINI_VERSION
version_info
=
subprocess
.
check_output
(
version_info
=
subprocess
.
check_output
(
[
compiler
,
'-dumpfullversion'
,
'-dumpversion'
])
[
compiler
,
'-dumpfullversion'
,
'-dumpversion'
])
if
six
.
PY3
:
version_info
=
version_info
.
decode
()
version_info
=
version_info
.
decode
()
version
=
version_info
.
strip
().
split
(
'.'
)
version
=
version_info
.
strip
().
split
(
'.'
)
elif
IS_WINDOWS
:
elif
IS_WINDOWS
:
mini_required_version
=
MSVC_MINI_VERSION
mini_required_version
=
MSVC_MINI_VERSION
compiler_info
=
subprocess
.
check_output
(
compiler_info
=
subprocess
.
check_output
(
compiler
,
stderr
=
subprocess
.
STDOUT
)
compiler
,
stderr
=
subprocess
.
STDOUT
)
if
six
.
PY3
:
try
:
try
:
compiler_info
=
compiler_info
.
decode
(
'UTF-8'
)
compiler_info
=
compiler_info
.
decode
(
'UTF-8'
)
except
UnicodeDecodeError
:
except
UnicodeDecodeError
:
compiler_info
=
compiler_info
.
decode
(
'gbk'
)
compiler_info
=
compiler_info
.
decode
(
'gbk'
)
match
=
re
.
search
(
r
'(\d+)\.(\d+)\.(\d+)'
,
compiler_info
.
strip
())
match
=
re
.
search
(
r
'(\d+)\.(\d+)\.(\d+)'
,
compiler_info
.
strip
())
if
match
is
not
None
:
if
match
is
not
None
:
version
=
match
.
groups
()
version
=
match
.
groups
()
...
...
python/paddle/vision/datasets/cifar.py
浏览文件 @
0f7187af
...
@@ -141,10 +141,7 @@ class Cifar10(Dataset):
...
@@ -141,10 +141,7 @@ class Cifar10(Dataset):
if
self
.
flag
in
each_item
.
name
)
if
self
.
flag
in
each_item
.
name
)
for
name
in
names
:
for
name
in
names
:
if
six
.
PY2
:
batch
=
pickle
.
load
(
f
.
extractfile
(
name
),
encoding
=
'bytes'
)
batch
=
pickle
.
load
(
f
.
extractfile
(
name
))
else
:
batch
=
pickle
.
load
(
f
.
extractfile
(
name
),
encoding
=
'bytes'
)
data
=
batch
[
six
.
b
(
'data'
)]
data
=
batch
[
six
.
b
(
'data'
)]
labels
=
batch
.
get
(
labels
=
batch
.
get
(
...
...
tools/count_api_without_core_ops.py
浏览文件 @
0f7187af
...
@@ -20,7 +20,6 @@ import collections
...
@@ -20,7 +20,6 @@ import collections
import
sys
import
sys
import
pydoc
import
pydoc
import
hashlib
import
hashlib
import
six
import
functools
import
functools
import
platform
import
platform
...
@@ -104,7 +103,7 @@ def visit_member(parent_name, member, func):
...
@@ -104,7 +103,7 @@ def visit_member(parent_name, member, func):
def
is_primitive
(
instance
):
def
is_primitive
(
instance
):
int_types
=
(
int
,
long
)
if
six
.
PY2
else
(
int
,
)
int_types
=
(
int
,
)
pritimitive_types
=
int_types
+
(
float
,
str
)
pritimitive_types
=
int_types
+
(
float
,
str
)
if
isinstance
(
instance
,
pritimitive_types
):
if
isinstance
(
instance
,
pritimitive_types
):
return
True
return
True
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录