Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
b11f0b7a
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
b11f0b7a
编写于
6月 14, 2023
作者:
S
sneaxiy
提交者:
GitHub
6月 14, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
cherry-pick
https://github.com/PaddlePaddle/Paddle/pull/54487
to release/2.5 (#54609)
上级
96564faf
变更
15
隐藏空白更改
内联
并排
Showing
15 changed file
with
313 addition
and
128 deletion
+313
-128
paddle/phi/kernels/coalesce_tensor_kernel.cc
paddle/phi/kernels/coalesce_tensor_kernel.cc
+16
-1
paddle/phi/kernels/gpu/sgd_kernel.cu
paddle/phi/kernels/gpu/sgd_kernel.cu
+18
-0
python/paddle/static/nn/common.py
python/paddle/static/nn/common.py
+6
-3
test/auto_parallel/test_pass_sharding.py
test/auto_parallel/test_pass_sharding.py
+0
-2
test/collective/collective_allgather_api.py
test/collective/collective_allgather_api.py
+1
-3
test/collective/collective_global_gather.py
test/collective/collective_global_gather.py
+2
-3
test/collective/collective_global_scatter.py
test/collective/collective_global_scatter.py
+2
-3
test/collective/fleet/parallel_dygraph_no_sync.py
test/collective/fleet/parallel_dygraph_no_sync.py
+3
-3
test/collective/test_collective_reduce_api.py
test/collective/test_collective_reduce_api.py
+0
-2
test/legacy_test/dist_sharding_save.py
test/legacy_test/dist_sharding_save.py
+5
-3
test/legacy_test/test_collective_api_base.py
test/legacy_test/test_collective_api_base.py
+24
-4
test/legacy_test/test_collective_base.py
test/legacy_test/test_collective_base.py
+20
-3
test/legacy_test/test_dist_base.py
test/legacy_test/test_dist_base.py
+60
-27
test/legacy_test/test_sync_batch_norm_op.py
test/legacy_test/test_sync_batch_norm_op.py
+68
-71
test/legacy_test/test_sync_batch_norm_op_convert.py
test/legacy_test/test_sync_batch_norm_op_convert.py
+88
-0
未找到文件。
paddle/phi/kernels/coalesce_tensor_kernel.cc
浏览文件 @
b11f0b7a
...
@@ -277,7 +277,22 @@ PD_REGISTER_KERNEL(coalesce_tensor,
...
@@ -277,7 +277,22 @@ PD_REGISTER_KERNEL(coalesce_tensor,
kernel
->
OutputAt
(
1
).
SetDataType
(
phi
::
DataType
::
UNDEFINED
);
kernel
->
OutputAt
(
1
).
SetDataType
(
phi
::
DataType
::
UNDEFINED
);
}
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#ifdef PADDLE_WITH_CUDA
PD_REGISTER_KERNEL
(
coalesce_tensor
,
GPU
,
ALL_LAYOUT
,
phi
::
CoalesceTensorKernel
,
phi
::
dtype
::
float16
,
phi
::
dtype
::
bfloat16
,
int
,
float
,
double
)
{
kernel
->
InputAt
(
0
).
SetBackend
(
phi
::
Backend
::
ALL_BACKEND
);
kernel
->
OutputAt
(
1
).
SetDataType
(
phi
::
DataType
::
UNDEFINED
);
}
#endif
#ifdef PADDLE_WITH_HIP
PD_REGISTER_KERNEL
(
coalesce_tensor
,
PD_REGISTER_KERNEL
(
coalesce_tensor
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
...
...
paddle/phi/kernels/gpu/sgd_kernel.cu
浏览文件 @
b11f0b7a
...
@@ -181,6 +181,23 @@ void SGDSparseParamSparseGradKernel(
...
@@ -181,6 +181,23 @@ void SGDSparseParamSparseGradKernel(
}
// namespace phi
}
// namespace phi
#ifdef PADDLE_WITH_CUDA
PD_REGISTER_KERNEL
(
sgd
,
GPU
,
ALL_LAYOUT
,
phi
::
SGDDenseKernel
,
phi
::
dtype
::
float16
,
phi
::
dtype
::
bfloat16
,
float
,
double
)
{
if
(
kernel_key
.
dtype
()
==
phi
::
DataType
::
FLOAT16
||
kernel_key
.
dtype
()
==
phi
::
DataType
::
BFLOAT16
)
{
kernel
->
OutputAt
(
1
).
SetDataType
(
phi
::
DataType
::
FLOAT32
);
}
}
#endif
#ifdef PADDLE_WITH_HIP
PD_REGISTER_KERNEL
(
sgd
,
PD_REGISTER_KERNEL
(
sgd
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
...
@@ -192,6 +209,7 @@ PD_REGISTER_KERNEL(sgd,
...
@@ -192,6 +209,7 @@ PD_REGISTER_KERNEL(sgd,
kernel
->
OutputAt
(
1
).
SetDataType
(
phi
::
DataType
::
FLOAT32
);
kernel
->
OutputAt
(
1
).
SetDataType
(
phi
::
DataType
::
FLOAT32
);
}
}
}
}
#endif
PD_REGISTER_KERNEL
(
sgd_dense_param_sparse_grad
,
PD_REGISTER_KERNEL
(
sgd_dense_param_sparse_grad
,
GPU
,
GPU
,
...
...
python/paddle/static/nn/common.py
浏览文件 @
b11f0b7a
...
@@ -888,7 +888,7 @@ def conv2d(
...
@@ -888,7 +888,7 @@ def conv2d(
"""
"""
check_variable_and_dtype
(
check_variable_and_dtype
(
input
,
'input'
,
[
'float16'
,
'float32'
,
'float64'
],
'conv2d'
input
,
'input'
,
[
'
uint16'
,
'
float16'
,
'float32'
,
'float64'
],
'conv2d'
)
)
if
len
(
input
.
shape
)
!=
4
:
if
len
(
input
.
shape
)
!=
4
:
raise
ValueError
(
raise
ValueError
(
...
@@ -2742,12 +2742,15 @@ def batch_norm(
...
@@ -2742,12 +2742,15 @@ def batch_norm(
helper
=
LayerHelper
(
'batch_norm'
,
**
locals
())
helper
=
LayerHelper
(
'batch_norm'
,
**
locals
())
check_variable_and_dtype
(
check_variable_and_dtype
(
input
,
'input'
,
[
'float16'
,
'float32'
,
'float64'
],
'batch_norm'
input
,
'input'
,
[
'uint16'
,
'float16'
,
'float32'
,
'float64'
],
'batch_norm'
,
)
)
dtype
=
helper
.
input_dtype
()
dtype
=
helper
.
input_dtype
()
# use fp32 for bn parameter
# use fp32 for bn parameter
if
dtype
==
core
.
VarDesc
.
VarType
.
FP16
:
if
dtype
==
core
.
VarDesc
.
VarType
.
FP16
or
dtype
==
core
.
VarDesc
.
VarType
.
BF16
:
dtype
=
core
.
VarDesc
.
VarType
.
FP32
dtype
=
core
.
VarDesc
.
VarType
.
FP32
input_shape
=
input
.
shape
input_shape
=
input
.
shape
...
...
test/auto_parallel/test_pass_sharding.py
浏览文件 @
b11f0b7a
...
@@ -36,8 +36,6 @@ class TestShardingPass(unittest.TestCase):
...
@@ -36,8 +36,6 @@ class TestShardingPass(unittest.TestCase):
+
[
+
[
"-m"
,
"-m"
,
"paddle.distributed.launch"
,
"paddle.distributed.launch"
,
"--devices"
,
"0,1"
,
"--log_dir"
,
"--log_dir"
,
tmp_dir
.
name
,
tmp_dir
.
name
,
launch_model_path
,
launch_model_path
,
...
...
test/collective/collective_allgather_api.py
浏览文件 @
b11f0b7a
...
@@ -13,8 +13,6 @@
...
@@ -13,8 +13,6 @@
# limitations under the License.
# limitations under the License.
import
os
import
os
import
pickle
import
sys
import
test_collective_api_base
as
test_base
import
test_collective_api_base
as
test_base
...
@@ -148,7 +146,7 @@ class TestCollectiveAllgatherAPI(test_base.TestCollectiveAPIRunnerBase):
...
@@ -148,7 +146,7 @@ class TestCollectiveAllgatherAPI(test_base.TestCollectiveAPIRunnerBase):
out
=
exe
.
run
(
out
=
exe
.
run
(
train_prog
,
feed
=
{
'tindata'
:
indata
},
fetch_list
=
fetch_list
train_prog
,
feed
=
{
'tindata'
:
indata
},
fetch_list
=
fetch_list
)
)
sys
.
stdout
.
buffer
.
write
(
pickle
.
dumps
(
out
)
)
test_base
.
dump_output
(
out
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
test/collective/collective_global_gather.py
浏览文件 @
b11f0b7a
...
@@ -13,12 +13,11 @@
...
@@ -13,12 +13,11 @@
# limitations under the License.
# limitations under the License.
import
os
import
os
import
pickle
import
sys
import
numpy
as
np
import
numpy
as
np
from
legacy_test.test_collective_api_base
import
(
from
legacy_test.test_collective_api_base
import
(
TestCollectiveAPIRunnerBase
,
TestCollectiveAPIRunnerBase
,
dump_output
,
runtime_main
,
runtime_main
,
)
)
...
@@ -124,7 +123,7 @@ class TestCollectiveGlobalGatherAPI(TestCollectiveAPIRunnerBase):
...
@@ -124,7 +123,7 @@ class TestCollectiveGlobalGatherAPI(TestCollectiveAPIRunnerBase):
fetch_list
=
fetch_list
,
fetch_list
=
fetch_list
,
)
)
sys
.
stdout
.
buffer
.
write
(
pickle
.
dumps
(
out
)
)
dump_output
(
out
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
test/collective/collective_global_scatter.py
浏览文件 @
b11f0b7a
...
@@ -13,12 +13,11 @@
...
@@ -13,12 +13,11 @@
# limitations under the License.
# limitations under the License.
import
os
import
os
import
pickle
import
sys
import
numpy
as
np
import
numpy
as
np
from
legacy_test.test_collective_api_base
import
(
from
legacy_test.test_collective_api_base
import
(
TestCollectiveAPIRunnerBase
,
TestCollectiveAPIRunnerBase
,
dump_output
,
runtime_main
,
runtime_main
,
)
)
...
@@ -103,7 +102,7 @@ class TestCollectiveGlobalScatterAPI(TestCollectiveAPIRunnerBase):
...
@@ -103,7 +102,7 @@ class TestCollectiveGlobalScatterAPI(TestCollectiveAPIRunnerBase):
fetch_list
=
fetch_list
,
fetch_list
=
fetch_list
,
)
)
sys
.
stdout
.
buffer
.
write
(
pickle
.
dumps
(
out
)
)
dump_output
(
out
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
test/collective/fleet/parallel_dygraph_no_sync.py
浏览文件 @
b11f0b7a
...
@@ -18,8 +18,8 @@ import random
...
@@ -18,8 +18,8 @@ import random
import
numpy
as
np
import
numpy
as
np
from
legacy_test.test_dist_base
import
(
from
legacy_test.test_dist_base
import
(
TestParallelDyGraphRunnerBase
,
TestParallelDyGraphRunnerBase
,
dump_output
,
print_to_err
,
print_to_err
,
print_to_out
,
runtime_main
,
runtime_main
,
)
)
...
@@ -92,7 +92,7 @@ class TestNoSync(TestParallelDyGraphRunnerBase):
...
@@ -92,7 +92,7 @@ class TestNoSync(TestParallelDyGraphRunnerBase):
)
)
print_to_err
(
type
(
self
).
__name__
,
"model built in dygraph"
)
print_to_err
(
type
(
self
).
__name__
,
"model built in dygraph"
)
out_losses
=
self
.
model_train
(
args
,
model
,
opt
,
train_reader
)
out_losses
=
self
.
model_train
(
args
,
model
,
opt
,
train_reader
)
print_to_o
ut
(
out_losses
)
dump_outp
ut
(
out_losses
)
return
out_losses
return
out_losses
def
run_trainer_with_spawn_func
(
self
,
args
):
def
run_trainer_with_spawn_func
(
self
,
args
):
...
@@ -120,7 +120,7 @@ class TestNoSync(TestParallelDyGraphRunnerBase):
...
@@ -120,7 +120,7 @@ class TestNoSync(TestParallelDyGraphRunnerBase):
)
)
out_losses
=
self
.
model_train
(
args
,
model
,
opt
,
train_reader
)
out_losses
=
self
.
model_train
(
args
,
model
,
opt
,
train_reader
)
print_to_o
ut
(
out_losses
)
dump_outp
ut
(
out_losses
)
return
out_losses
return
out_losses
def
model_train
(
self
,
args
,
model
,
opt
,
train_reader
):
def
model_train
(
self
,
args
,
model
,
opt
,
train_reader
):
...
...
test/collective/test_collective_reduce_api.py
浏览文件 @
b11f0b7a
...
@@ -67,7 +67,6 @@ class TestCollectiveReduceAPI(TestDistBase):
...
@@ -67,7 +67,6 @@ class TestCollectiveReduceAPI(TestDistBase):
def
test_reduce_gloo_with_comm_context
(
self
):
def
test_reduce_gloo_with_comm_context
(
self
):
dtypes_to_test
=
[
dtypes_to_test
=
[
"float16"
,
"float32"
,
"float32"
,
"float64"
,
"float64"
,
"int32"
,
"int32"
,
...
@@ -115,7 +114,6 @@ class TestCollectiveReduceAPI(TestDistBase):
...
@@ -115,7 +114,6 @@ class TestCollectiveReduceAPI(TestDistBase):
def
test_reduce_gloo_dygraph
(
self
):
def
test_reduce_gloo_dygraph
(
self
):
dtypes_to_test
=
[
dtypes_to_test
=
[
"float16"
,
"float32"
,
"float32"
,
"float64"
,
"float64"
,
"int32"
,
"int32"
,
...
...
test/legacy_test/dist_sharding_save.py
浏览文件 @
b11f0b7a
...
@@ -13,8 +13,6 @@
...
@@ -13,8 +13,6 @@
# limitations under the License.
# limitations under the License.
import
os
import
os
import
pickle
import
sys
from
dist_mnist
import
cnn_model
# noqa: F401
from
dist_mnist
import
cnn_model
# noqa: F401
...
@@ -29,8 +27,12 @@ fluid.default_main_program().random_seed = 1
...
@@ -29,8 +27,12 @@ fluid.default_main_program().random_seed = 1
def
runtime_main
():
def
runtime_main
():
from
test_dist_base
import
dump_output
from
paddle.distributed
import
fleet
from
paddle.distributed
import
fleet
paddle
.
enable_static
()
# model definition
# model definition
train_prog
=
paddle
.
fluid
.
Program
()
train_prog
=
paddle
.
fluid
.
Program
()
startup_prog
=
paddle
.
fluid
.
Program
()
startup_prog
=
paddle
.
fluid
.
Program
()
...
@@ -83,7 +85,7 @@ def runtime_main():
...
@@ -83,7 +85,7 @@ def runtime_main():
)
)
out_losses
=
[]
out_losses
=
[]
sys
.
stdout
.
buffer
.
write
(
pickle
.
dumps
(
out_losses
)
)
dump_output
(
out_losses
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
test/legacy_test/test_collective_api_base.py
浏览文件 @
b11f0b7a
...
@@ -80,6 +80,12 @@ def create_pyobject_test_data(shape=None, seed=None):
...
@@ -80,6 +80,12 @@ def create_pyobject_test_data(shape=None, seed=None):
return
[
list_data
,
dict_data
]
return
[
list_data
,
dict_data
]
def
dump_output
(
x
):
dump_file
=
os
.
environ
[
'DUMP_FILE'
]
with
open
(
dump_file
,
'wb'
)
as
f
:
pickle
.
dump
(
x
,
f
)
def
create_test_data
(
shape
=
None
,
dtype
=
None
,
seed
=
None
):
def
create_test_data
(
shape
=
None
,
dtype
=
None
,
seed
=
None
):
assert
shape
,
"Shape should be specified"
assert
shape
,
"Shape should be specified"
if
dtype
==
"float32"
or
dtype
==
"float16"
or
dtype
==
"float64"
:
if
dtype
==
"float32"
or
dtype
==
"float16"
or
dtype
==
"float64"
:
...
@@ -160,7 +166,7 @@ class TestCollectiveAPIRunnerBase:
...
@@ -160,7 +166,7 @@ class TestCollectiveAPIRunnerBase:
else
:
else
:
out
=
self
.
get_model
(
train_prog
,
startup_prog
,
rank
,
indata
)
out
=
self
.
get_model
(
train_prog
,
startup_prog
,
rank
,
indata
)
# print(out, sys.stderr)
# print(out, sys.stderr)
sys
.
stdout
.
buffer
.
write
(
pickle
.
dumps
(
out
)
)
dump_output
(
out
)
def
runtime_main
(
test_class
,
col_type
):
def
runtime_main
(
test_class
,
col_type
):
...
@@ -255,6 +261,13 @@ class TestDistBase(unittest.TestCase):
...
@@ -255,6 +261,13 @@ class TestDistBase(unittest.TestCase):
# update environment
# update environment
env0
.
update
(
envs
)
env0
.
update
(
envs
)
env1
.
update
(
envs
)
env1
.
update
(
envs
)
cur_pid
=
os
.
getpid
()
dump_file_0
=
f
'./out_data_0_
{
cur_pid
}
.pickled'
dump_file_1
=
f
'./out_data_1_
{
cur_pid
}
.pickled'
env0
[
'DUMP_FILE'
]
=
dump_file_0
env1
[
'DUMP_FILE'
]
=
dump_file_1
if
os
.
getenv
(
'WITH_COVERAGE'
,
'OFF'
)
==
'ON'
:
if
os
.
getenv
(
'WITH_COVERAGE'
,
'OFF'
)
==
'ON'
:
tr_cmd
=
"%s -m coverage run --branch -p %s"
tr_cmd
=
"%s -m coverage run --branch -p %s"
else
:
else
:
...
@@ -295,9 +308,16 @@ class TestDistBase(unittest.TestCase):
...
@@ -295,9 +308,16 @@ class TestDistBase(unittest.TestCase):
sys
.
stderr
.
write
(
'trainer 0 stderr file: %s
\n
'
%
f
.
read
())
sys
.
stderr
.
write
(
'trainer 0 stderr file: %s
\n
'
%
f
.
read
())
with
open
(
path1
,
"r"
)
as
f
:
with
open
(
path1
,
"r"
)
as
f
:
sys
.
stderr
.
write
(
'trainer 1 stderr file: %s
\n
'
%
f
.
read
())
sys
.
stderr
.
write
(
'trainer 1 stderr file: %s
\n
'
%
f
.
read
())
def
load_and_remove
(
path
):
with
open
(
path
,
'rb'
)
as
f
:
out
=
pickle
.
load
(
f
)
os
.
remove
(
path
)
return
out
return
(
return
(
pickle
.
loads
(
tr0_out
),
load_and_remove
(
dump_file_0
),
pickle
.
loads
(
tr1_out
),
load_and_remove
(
dump_file_1
),
tr0_proc
.
pid
,
tr0_proc
.
pid
,
tr1_proc
.
pid
,
tr1_proc
.
pid
,
)
)
...
@@ -469,7 +489,7 @@ class TestDistBase(unittest.TestCase):
...
@@ -469,7 +489,7 @@ class TestDistBase(unittest.TestCase):
elif
col_type
==
"column_parallel_linear"
:
elif
col_type
==
"column_parallel_linear"
:
result_data
=
tr0_out
[
0
]
result_data
=
tr0_out
[
0
]
np
.
random
.
seed
(
2020
)
np
.
random
.
seed
(
2020
)
weight
=
np
.
random
.
rand
(
1000
,
16
)
weight
=
np
.
random
.
rand
(
1000
,
16
)
.
astype
(
np
.
float32
)
need_result
=
np
.
matmul
(
input1
,
weight
)
need_result
=
np
.
matmul
(
input1
,
weight
)
np
.
testing
.
assert_allclose
(
np
.
testing
.
assert_allclose
(
result_data
,
need_result
,
rtol
=
1e-05
,
atol
=
1e-05
result_data
,
need_result
,
rtol
=
1e-05
,
atol
=
1e-05
...
...
test/legacy_test/test_collective_base.py
浏览文件 @
b11f0b7a
...
@@ -126,7 +126,9 @@ class TestCollectiveRunnerBase:
...
@@ -126,7 +126,9 @@ class TestCollectiveRunnerBase:
out
=
exe
.
run
(
out
=
exe
.
run
(
train_prog
,
feed
=
{
'tindata'
:
indata
},
fetch_list
=
[
result
.
name
]
train_prog
,
feed
=
{
'tindata'
:
indata
},
fetch_list
=
[
result
.
name
]
)
)
sys
.
stdout
.
buffer
.
write
(
pickle
.
dumps
(
out
))
dump_file
=
os
.
environ
[
'DUMP_FILE'
]
with
open
(
dump_file
,
'wb'
)
as
f
:
pickle
.
dump
(
out
,
f
)
def
runtime_main
(
test_class
,
col_type
,
sub_type
):
def
runtime_main
(
test_class
,
col_type
,
sub_type
):
...
@@ -189,9 +191,17 @@ class TestDistBase(unittest.TestCase):
...
@@ -189,9 +191,17 @@ class TestDistBase(unittest.TestCase):
"PADDLE_TRAINER_ENDPOINTS"
:
self
.
_ps_endpoints
,
"PADDLE_TRAINER_ENDPOINTS"
:
self
.
_ps_endpoints
,
"PADDLE_CURRENT_ENDPOINT"
:
w1_ep
,
"PADDLE_CURRENT_ENDPOINT"
:
w1_ep
,
}
}
cur_pid
=
os
.
getpid
()
dump_file_0
=
f
'./out_data_0_
{
cur_pid
}
.pickled'
dump_file_1
=
f
'./out_data_1_
{
cur_pid
}
.pickled'
# update environment
# update environment
env0
.
update
(
envs
)
env0
.
update
(
envs
)
env1
.
update
(
envs
)
env1
.
update
(
envs
)
env0
[
'DUMP_FILE'
]
=
dump_file_0
env1
[
'DUMP_FILE'
]
=
dump_file_1
tr_cmd
=
"%s %s"
tr_cmd
=
"%s %s"
tr0_cmd
=
tr_cmd
%
(
self
.
_python_interp
,
model_file
)
tr0_cmd
=
tr_cmd
%
(
self
.
_python_interp
,
model_file
)
tr1_cmd
=
tr_cmd
%
(
self
.
_python_interp
,
model_file
)
tr1_cmd
=
tr_cmd
%
(
self
.
_python_interp
,
model_file
)
...
@@ -221,9 +231,16 @@ class TestDistBase(unittest.TestCase):
...
@@ -221,9 +231,16 @@ class TestDistBase(unittest.TestCase):
# close trainer file
# close trainer file
tr0_pipe
.
close
()
tr0_pipe
.
close
()
tr1_pipe
.
close
()
tr1_pipe
.
close
()
def
load_and_remove
(
path
):
with
open
(
path
,
'rb'
)
as
f
:
out
=
pickle
.
load
(
f
)
os
.
remove
(
path
)
return
out
return
(
return
(
pickle
.
loads
(
tr0_out
),
load_and_remove
(
dump_file_0
),
pickle
.
loads
(
tr1_out
),
load_and_remove
(
dump_file_1
),
tr0_proc
.
pid
,
tr0_proc
.
pid
,
tr1_proc
.
pid
,
tr1_proc
.
pid
,
)
)
...
...
test/legacy_test/test_dist_base.py
浏览文件 @
b11f0b7a
...
@@ -44,8 +44,45 @@ DEFAULT_BATCH_SIZE = 2
...
@@ -44,8 +44,45 @@ DEFAULT_BATCH_SIZE = 2
DIST_UT_PORT
=
0
DIST_UT_PORT
=
0
def
print_to_out
(
out_losses
):
def
remove_glog_envs
(
envs
):
sys
.
stdout
.
buffer
.
write
(
pickle
.
dumps
(
out_losses
))
if
not
envs
:
return
envs
glog_envs
=
[
'GLOG_v'
,
'GLOG_logtostderr'
,
'GLOG_vmodule'
]
envs
=
dict
(
envs
)
for
env
in
glog_envs
:
if
env
in
envs
:
del
envs
[
env
]
return
envs
def
get_dump_file
(
rank
):
return
f
"./out_dump_
{
os
.
getpid
()
}
_
{
rank
}
.pickled"
def
modify_envs
(
envs
,
rank
=
0
):
if
not
envs
:
envs
=
{}
envs
=
remove_glog_envs
(
envs
)
dump_file
=
get_dump_file
(
rank
)
envs
[
'DUMP_FILE'
]
=
dump_file
if
os
.
path
.
exists
(
dump_file
):
os
.
remove
(
dump_file
)
return
envs
def
dump_output
(
x
):
path
=
os
.
environ
[
'DUMP_FILE'
]
with
open
(
path
,
'wb'
)
as
f
:
pickle
.
dump
(
x
,
f
)
def
load_and_remove_dump_file
(
rank
=
0
):
path
=
get_dump_file
(
rank
)
with
open
(
path
,
'rb'
)
as
f
:
out
=
pickle
.
load
(
f
)
os
.
remove
(
path
)
return
out
def
print_to_err
(
class_name
,
log_str
):
def
print_to_err
(
class_name
,
log_str
):
...
@@ -210,7 +247,7 @@ class TestDistRunnerBase:
...
@@ -210,7 +247,7 @@ class TestDistRunnerBase:
data_loader
.
reset
()
data_loader
.
reset
()
print_to_err
(
type
(
self
).
__name__
,
"trainer run finished"
)
print_to_err
(
type
(
self
).
__name__
,
"trainer run finished"
)
sys
.
stdout
.
buffer
.
write
(
pickle
.
dumps
(
out_losses
)
)
dump_output
(
out_losses
)
def
run_use_fleet_api_20_trainer
(
self
,
args
):
def
run_use_fleet_api_20_trainer
(
self
,
args
):
"""
"""
...
@@ -291,7 +328,7 @@ class TestDistRunnerBase:
...
@@ -291,7 +328,7 @@ class TestDistRunnerBase:
print_to_err
(
type
(
self
).
__name__
,
"trainer run finished"
)
print_to_err
(
type
(
self
).
__name__
,
"trainer run finished"
)
print_to_err
(
type
(
self
).
__name__
,
f
"dist losses:
{
out_losses
}
"
)
print_to_err
(
type
(
self
).
__name__
,
f
"dist losses:
{
out_losses
}
"
)
sys
.
stdout
.
buffer
.
write
(
pickle
.
dumps
(
out_losses
)
)
dump_output
(
out_losses
)
def
run_use_fleet_api_trainer
(
self
,
args
):
def
run_use_fleet_api_trainer
(
self
,
args
):
assert
args
.
update_method
==
"nccl2"
or
"bkcl"
assert
args
.
update_method
==
"nccl2"
or
"bkcl"
...
@@ -386,7 +423,7 @@ class TestDistRunnerBase:
...
@@ -386,7 +423,7 @@ class TestDistRunnerBase:
print_to_err
(
type
(
self
).
__name__
,
"run step %d finished"
%
i
)
print_to_err
(
type
(
self
).
__name__
,
"run step %d finished"
%
i
)
print_to_err
(
type
(
self
).
__name__
,
"trainer run finished"
)
print_to_err
(
type
(
self
).
__name__
,
"trainer run finished"
)
sys
.
stdout
.
buffer
.
write
(
pickle
.
dumps
(
out_losses
)
)
dump_output
(
out_losses
)
if
args
.
save_model
:
if
args
.
save_model
:
model_save_dir
=
"/tmp"
model_save_dir
=
"/tmp"
...
@@ -628,7 +665,7 @@ class TestDistRunnerBase:
...
@@ -628,7 +665,7 @@ class TestDistRunnerBase:
# print_to_err(type(self).__name__, "out_losses")
# print_to_err(type(self).__name__, "out_losses")
sys
.
stdout
=
old_stdout
sys
.
stdout
=
old_stdout
print_to_o
ut
(
out_losses
)
dump_outp
ut
(
out_losses
)
class
TestParallelDyGraphRunnerBase
:
class
TestParallelDyGraphRunnerBase
:
...
@@ -751,7 +788,7 @@ class TestParallelDyGraphRunnerBase:
...
@@ -751,7 +788,7 @@ class TestParallelDyGraphRunnerBase:
opt
.
minimize
(
loss
)
opt
.
minimize
(
loss
)
if
not
args
.
accumulate_gradient
:
if
not
args
.
accumulate_gradient
:
model
.
clear_gradients
()
model
.
clear_gradients
()
print_to_o
ut
(
out_losses
)
dump_outp
ut
(
out_losses
)
def
run_trainer_with_spawn
(
self
,
args
):
def
run_trainer_with_spawn
(
self
,
args
):
# 1. enable dygraph
# 1. enable dygraph
...
@@ -836,7 +873,7 @@ class TestParallelDyGraphRunnerBase:
...
@@ -836,7 +873,7 @@ class TestParallelDyGraphRunnerBase:
opt
.
step
()
opt
.
step
()
if
not
args
.
accumulate_gradient
:
if
not
args
.
accumulate_gradient
:
opt
.
clear_grad
()
opt
.
clear_grad
()
print_to_o
ut
(
out_losses
)
dump_outp
ut
(
out_losses
)
def
runtime_main
(
test_class
):
def
runtime_main
(
test_class
):
...
@@ -1071,14 +1108,14 @@ class TestDistBase(unittest.TestCase):
...
@@ -1071,14 +1108,14 @@ class TestDistBase(unittest.TestCase):
ps0_cmd
.
strip
().
split
(
" "
),
ps0_cmd
.
strip
().
split
(
" "
),
stdout
=
subprocess
.
PIPE
,
stdout
=
subprocess
.
PIPE
,
stderr
=
ps0_pipe
,
stderr
=
ps0_pipe
,
env
=
required_envs
,
env
=
modify_envs
(
required_envs
)
,
)
)
print_to_err
(
type
(
self
).
__name__
,
"going to start pserver process 1"
)
print_to_err
(
type
(
self
).
__name__
,
"going to start pserver process 1"
)
ps1_proc
=
subprocess
.
Popen
(
ps1_proc
=
subprocess
.
Popen
(
ps1_cmd
.
strip
().
split
(
" "
),
ps1_cmd
.
strip
().
split
(
" "
),
stdout
=
subprocess
.
PIPE
,
stdout
=
subprocess
.
PIPE
,
stderr
=
ps1_pipe
,
stderr
=
ps1_pipe
,
env
=
required_envs
,
env
=
modify_envs
(
required_envs
)
,
)
)
return
ps0_proc
,
ps1_proc
,
ps0_pipe
,
ps1_pipe
return
ps0_proc
,
ps1_proc
,
ps0_pipe
,
ps1_pipe
...
@@ -1093,7 +1130,6 @@ class TestDistBase(unittest.TestCase):
...
@@ -1093,7 +1130,6 @@ class TestDistBase(unittest.TestCase):
log_name
=
""
,
log_name
=
""
,
devices
=
"1"
,
devices
=
"1"
,
):
):
cmd
=
self
.
_python_interp
cmd
=
self
.
_python_interp
if
os
.
getenv
(
'WITH_COVERAGE'
,
'OFF'
)
==
'ON'
:
if
os
.
getenv
(
'WITH_COVERAGE'
,
'OFF'
)
==
'ON'
:
...
@@ -1149,14 +1185,14 @@ class TestDistBase(unittest.TestCase):
...
@@ -1149,14 +1185,14 @@ class TestDistBase(unittest.TestCase):
cmd
.
split
(
" "
),
cmd
.
split
(
" "
),
stdout
=
subprocess
.
PIPE
,
stdout
=
subprocess
.
PIPE
,
stderr
=
err_log
,
stderr
=
err_log
,
env
=
env_local
,
env
=
modify_envs
(
env_local
)
,
)
)
else
:
else
:
local_proc
=
subprocess
.
Popen
(
local_proc
=
subprocess
.
Popen
(
cmd
.
split
(
" "
),
cmd
.
split
(
" "
),
stdout
=
subprocess
.
PIPE
,
stdout
=
subprocess
.
PIPE
,
stderr
=
subprocess
.
PIPE
,
stderr
=
subprocess
.
PIPE
,
env
=
env_local
,
env
=
modify_envs
(
env_local
)
,
)
)
local_out
,
local_err
=
local_proc
.
communicate
()
local_out
,
local_err
=
local_proc
.
communicate
()
...
@@ -1165,9 +1201,8 @@ class TestDistBase(unittest.TestCase):
...
@@ -1165,9 +1201,8 @@ class TestDistBase(unittest.TestCase):
err_log
.
close
()
err_log
.
close
()
sys
.
stderr
.
write
(
'local_stderr: %s
\n
'
%
local_err
)
sys
.
stderr
.
write
(
'local_stderr: %s
\n
'
%
local_err
)
sys
.
stderr
.
write
(
'local_stdout: %s
\n
'
%
pickle
.
loads
(
local_out
))
return
pickle
.
loads
(
local_out
)
return
load_and_remove_dump_file
(
)
def
_run_local_gloo
(
def
_run_local_gloo
(
self
,
self
,
...
@@ -1259,14 +1294,14 @@ class TestDistBase(unittest.TestCase):
...
@@ -1259,14 +1294,14 @@ class TestDistBase(unittest.TestCase):
tr0_cmd
.
strip
().
split
(
" "
),
tr0_cmd
.
strip
().
split
(
" "
),
stdout
=
subprocess
.
PIPE
,
stdout
=
subprocess
.
PIPE
,
stderr
=
tr0_pipe
,
stderr
=
tr0_pipe
,
env
=
env0
,
env
=
modify_envs
(
env0
,
0
)
,
)
)
print_to_err
(
type
(
self
).
__name__
,
"going to start trainer process 1"
)
print_to_err
(
type
(
self
).
__name__
,
"going to start trainer process 1"
)
tr1_proc
=
subprocess
.
Popen
(
tr1_proc
=
subprocess
.
Popen
(
tr1_cmd
.
strip
().
split
(
" "
),
tr1_cmd
.
strip
().
split
(
" "
),
stdout
=
subprocess
.
PIPE
,
stdout
=
subprocess
.
PIPE
,
stderr
=
tr1_pipe
,
stderr
=
tr1_pipe
,
env
=
env1
,
env
=
modify_envs
(
env1
,
1
)
,
)
)
# Wait until trainer process terminate
# Wait until trainer process terminate
...
@@ -1293,7 +1328,7 @@ class TestDistBase(unittest.TestCase):
...
@@ -1293,7 +1328,7 @@ class TestDistBase(unittest.TestCase):
ps0
.
terminate
()
ps0
.
terminate
()
ps1
.
terminate
()
ps1
.
terminate
()
return
pickle
.
loads
(
tr0_out
),
pickle
.
loads
(
tr1_out
)
return
load_and_remove_dump_file
(
0
),
load_and_remove_dump_file
(
1
)
def
_get_gloo_trainer_cmd
(
def
_get_gloo_trainer_cmd
(
self
,
model
,
ep
,
update_method
,
trainer_id
,
trainer_num
self
,
model
,
ep
,
update_method
,
trainer_id
,
trainer_num
...
@@ -1337,7 +1372,6 @@ class TestDistBase(unittest.TestCase):
...
@@ -1337,7 +1372,6 @@ class TestDistBase(unittest.TestCase):
"PADDLE_TRAINER_ID"
:
f
"
{
trainer_id
}
"
,
"PADDLE_TRAINER_ID"
:
f
"
{
trainer_id
}
"
,
"PADDLE_TRAINER_ENDPOINTS"
:
self
.
_ps_endpoints
,
"PADDLE_TRAINER_ENDPOINTS"
:
self
.
_ps_endpoints
,
"PADDLE_CURRENT_ENDPOINT"
:
ep
,
"PADDLE_CURRENT_ENDPOINT"
:
ep
,
"PADDLE_CURRENT_ENDPOINT"
:
ep
,
"PADDLE_DISTRI_BACKEND"
:
"gloo"
,
"PADDLE_DISTRI_BACKEND"
:
"gloo"
,
"GLOG_v"
:
"2"
,
"GLOG_v"
:
"2"
,
}
}
...
@@ -1507,7 +1541,7 @@ class TestDistBase(unittest.TestCase):
...
@@ -1507,7 +1541,7 @@ class TestDistBase(unittest.TestCase):
tr_cmd
.
strip
().
split
(
" "
),
tr_cmd
.
strip
().
split
(
" "
),
stdout
=
subprocess
.
PIPE
,
stdout
=
subprocess
.
PIPE
,
stderr
=
tr_pipe
,
stderr
=
tr_pipe
,
env
=
tr_env
,
env
=
modify_envs
(
tr_env
,
i
)
,
)
)
procs
.
append
(
tr_proc
)
procs
.
append
(
tr_proc
)
...
@@ -1523,13 +1557,13 @@ class TestDistBase(unittest.TestCase):
...
@@ -1523,13 +1557,13 @@ class TestDistBase(unittest.TestCase):
if
trainer_num
==
1
:
if
trainer_num
==
1
:
if
check_error_log
:
if
check_error_log
:
print
(
"outs[0]:"
,
outs
[
0
])
print
(
"outs[0]:"
,
outs
[
0
])
return
pickle
.
loads
(
outs
[
0
]
)
return
load_and_remove_dump_file
(
0
)
else
:
else
:
if
check_error_log
:
if
check_error_log
:
print
(
"outs[0]:"
,
outs
[
0
])
print
(
"outs[0]:"
,
outs
[
0
])
print
(
"outs[1]:"
,
outs
[
1
])
print
(
"outs[1]:"
,
outs
[
1
])
return
pickle
.
loads
(
outs
[
0
]),
pickle
.
loads
(
outs
[
1
]
)
return
load_and_remove_dump_file
(
0
),
load_and_remove_dump_file
(
1
)
def
_run_cluster_nccl2
(
def
_run_cluster_nccl2
(
self
,
model
,
envs
,
update_method
,
check_error_log
,
log_name
self
,
model
,
envs
,
update_method
,
check_error_log
,
log_name
...
@@ -1581,7 +1615,7 @@ class TestDistBase(unittest.TestCase):
...
@@ -1581,7 +1615,7 @@ class TestDistBase(unittest.TestCase):
tr_cmd
.
strip
().
split
(
" "
),
tr_cmd
.
strip
().
split
(
" "
),
stdout
=
subprocess
.
PIPE
,
stdout
=
subprocess
.
PIPE
,
stderr
=
tr_pipe
,
stderr
=
tr_pipe
,
env
=
tr_env
,
env
=
modify_envs
(
tr_env
,
i
)
,
)
)
procs
.
append
(
tr_proc
)
procs
.
append
(
tr_proc
)
...
@@ -1598,7 +1632,7 @@ class TestDistBase(unittest.TestCase):
...
@@ -1598,7 +1632,7 @@ class TestDistBase(unittest.TestCase):
print
(
"outs[0]:"
,
outs
[
0
])
print
(
"outs[0]:"
,
outs
[
0
])
print
(
"outs[1]:"
,
outs
[
1
])
print
(
"outs[1]:"
,
outs
[
1
])
return
pickle
.
loads
(
outs
[
0
]),
pickle
.
loads
(
outs
[
1
]
)
return
load_and_remove_dump_file
(
0
),
load_and_remove_dump_file
(
1
)
def
_run_pipeline
(
self
,
model
,
envs
,
check_error_log
,
log_name
):
def
_run_pipeline
(
self
,
model
,
envs
,
check_error_log
,
log_name
):
# NOTE: we reuse ps_endpoints as nccl2 worker endpoints
# NOTE: we reuse ps_endpoints as nccl2 worker endpoints
...
@@ -1631,7 +1665,7 @@ class TestDistBase(unittest.TestCase):
...
@@ -1631,7 +1665,7 @@ class TestDistBase(unittest.TestCase):
tr_cmd
.
strip
().
split
(
" "
),
tr_cmd
.
strip
().
split
(
" "
),
stdout
=
subprocess
.
PIPE
,
stdout
=
subprocess
.
PIPE
,
stderr
=
tr_pipe
,
stderr
=
tr_pipe
,
env
=
tr_env
,
env
=
modify_envs
(
tr_env
,
i
)
,
)
)
procs
.
append
(
tr_proc
)
procs
.
append
(
tr_proc
)
...
@@ -1647,7 +1681,7 @@ class TestDistBase(unittest.TestCase):
...
@@ -1647,7 +1681,7 @@ class TestDistBase(unittest.TestCase):
if
check_error_log
:
if
check_error_log
:
print
(
"outs[0]:"
,
outs
[
0
])
print
(
"outs[0]:"
,
outs
[
0
])
print
(
"outs[1]:"
,
outs
[
1
])
print
(
"outs[1]:"
,
outs
[
1
])
return
pickle
.
loads
(
outs
[
0
]),
pickle
.
loads
(
outs
[
1
]
)
return
load_and_remove_dump_file
(
0
),
load_and_remove_dump_file
(
1
)
def
_get_required_envs
(
self
,
check_error_log
=
False
,
need_envs
=
{}):
def
_get_required_envs
(
self
,
check_error_log
=
False
,
need_envs
=
{}):
# TODO(typhoonzero): should auto adapt GPU count on the machine.
# TODO(typhoonzero): should auto adapt GPU count on the machine.
...
@@ -1789,7 +1823,6 @@ class TestDistBase(unittest.TestCase):
...
@@ -1789,7 +1823,6 @@ class TestDistBase(unittest.TestCase):
need_envs
=
{},
need_envs
=
{},
log_name
=
""
,
log_name
=
""
,
):
):
# need open p2p or shm otherwise multi cards mode will hang
# need open p2p or shm otherwise multi cards mode will hang
need_envs
.
update
({
"NCCL_P2P_DISABLE"
:
"0"
,
"NCCL_SHM_DISABLE"
:
"0"
})
need_envs
.
update
({
"NCCL_P2P_DISABLE"
:
"0"
,
"NCCL_SHM_DISABLE"
:
"0"
})
...
...
test/legacy_test/test_sync_batch_norm_op.py
浏览文件 @
b11f0b7a
...
@@ -18,9 +18,11 @@ for both FP64 and FP16 input.
...
@@ -18,9 +18,11 @@ for both FP64 and FP16 input.
import
os
import
os
import
random
import
random
import
subprocess
import
shutil
import
sys
import
tempfile
import
tempfile
import
unittest
import
unittest
from
shlex
import
quote
import
numpy
as
np
import
numpy
as
np
from
decorator_helper
import
prog_scope
from
decorator_helper
import
prog_scope
...
@@ -33,10 +35,41 @@ from eager_op_test import (
...
@@ -33,10 +35,41 @@ from eager_op_test import (
import
paddle
import
paddle
from
paddle
import
fluid
,
nn
from
paddle
import
fluid
,
nn
from
paddle.fluid
import
Program
,
core
,
program_guard
from
paddle.fluid
import
Program
,
core
,
program_guard
from
paddle.fluid.framework
import
in_dygraph_mode
_set_use_system_allocator
(
True
)
_set_use_system_allocator
(
True
)
def
enable_static
():
if
in_dygraph_mode
():
paddle
.
enable_static
()
def
cleanup
():
paddle
.
disable_static
()
else
:
def
cleanup
():
pass
return
cleanup
def
convert_numpy_array
(
array
):
if
array
.
dtype
!=
np
.
uint16
:
return
array
cleanup
=
None
if
not
in_dygraph_mode
():
paddle
.
disable_static
()
cleanup
=
lambda
:
paddle
.
enable_static
()
out
=
paddle
.
to_tensor
(
array
).
astype
(
paddle
.
float32
).
numpy
()
if
cleanup
is
not
None
:
cleanup
()
return
out
def
create_or_get_tensor
(
scope
,
var_name
,
var
,
place
):
def
create_or_get_tensor
(
scope
,
var_name
,
var
,
place
):
"""Get tensor, if not found, create a new one."""
"""Get tensor, if not found, create a new one."""
tensor
=
scope
.
var
(
var_name
).
get_tensor
()
tensor
=
scope
.
var
(
var_name
).
get_tensor
()
...
@@ -47,6 +80,24 @@ def create_or_get_tensor(scope, var_name, var, place):
...
@@ -47,6 +80,24 @@ def create_or_get_tensor(scope, var_name, var, place):
return
tensor
return
tensor
def
clean_dir
(
path
):
if
isinstance
(
path
,
tempfile
.
TemporaryDirectory
):
path
=
path
.
name
for
f
in
os
.
listdir
(
path
):
f
=
os
.
path
.
join
(
path
,
f
)
if
os
.
path
.
isdir
(
f
):
shutil
.
rmtree
(
f
)
else
:
os
.
remove
(
f
)
def
concat_cmd
(
cmd
):
if
isinstance
(
cmd
,
str
):
return
cmd
return
' '
.
join
([
quote
(
c
)
for
c
in
cmd
])
class
TestSyncBatchNormOpTraining
(
unittest
.
TestCase
):
class
TestSyncBatchNormOpTraining
(
unittest
.
TestCase
):
"""sync_batch_norm op test."""
"""sync_batch_norm op test."""
...
@@ -69,7 +120,7 @@ class TestSyncBatchNormOpTraining(unittest.TestCase):
...
@@ -69,7 +120,7 @@ class TestSyncBatchNormOpTraining(unittest.TestCase):
def
multi_device_run
(
self
,
layout
,
fetch_list
,
only_forward
=
False
):
def
multi_device_run
(
self
,
layout
,
fetch_list
,
only_forward
=
False
):
cmds
=
[
cmds
=
[
"python"
,
sys
.
executable
,
"-m"
,
"-m"
,
"paddle.distributed.launch"
,
"paddle.distributed.launch"
,
]
]
...
@@ -91,8 +142,8 @@ class TestSyncBatchNormOpTraining(unittest.TestCase):
...
@@ -91,8 +142,8 @@ class TestSyncBatchNormOpTraining(unittest.TestCase):
cmds
+=
[
"--only_forward"
]
cmds
+=
[
"--only_forward"
]
if
self
.
dtype
==
np
.
float16
or
self
.
dtype
==
np
.
uint16
:
if
self
.
dtype
==
np
.
float16
or
self
.
dtype
==
np
.
uint16
:
cmds
+=
[
"--use_cudnn"
]
cmds
+=
[
"--use_cudnn"
]
p
=
subprocess
.
run
(
cmds
)
cmd
=
concat_cmd
(
cmds
)
assert
p
.
returncode
==
0
,
f
"Fleet train: Failed:
{
p
}
"
assert
os
.
system
(
cmd
)
==
0
,
cmd
def
_build_program
(
def
_build_program
(
self
,
place
,
layout
,
seed
,
sync_bn
=
False
,
only_forward
=
False
self
,
place
,
layout
,
seed
,
sync_bn
=
False
,
only_forward
=
False
...
@@ -143,9 +194,18 @@ class TestSyncBatchNormOpTraining(unittest.TestCase):
...
@@ -143,9 +194,18 @@ class TestSyncBatchNormOpTraining(unittest.TestCase):
@
prog_scope
()
@
prog_scope
()
def
_compare
(
self
,
place
,
layout
,
only_forward
):
def
_compare
(
self
,
place
,
layout
,
only_forward
):
try
:
with
paddle
.
utils
.
unique_name
.
guard
():
self
.
_compare_impl
(
place
,
layout
,
only_forward
)
finally
:
clean_dir
(
self
.
data_dir
)
clean_dir
(
self
.
fleet_log_dir
)
def
_compare_impl
(
self
,
place
,
layout
,
only_forward
):
"""Compare results."""
"""Compare results."""
seed
=
10
seed
=
10
os
.
environ
[
'FLAGS_cudnn_deterministic'
]
=
"1"
os
.
environ
[
'FLAGS_cudnn_deterministic'
]
=
"1"
paddle
.
set_flags
({
'FLAGS_cudnn_deterministic'
:
1
})
paddle
.
enable_static
()
paddle
.
enable_static
()
scope
=
core
.
Scope
()
scope
=
core
.
Scope
()
if
self
.
dtype
==
np
.
uint16
:
if
self
.
dtype
==
np
.
uint16
:
...
@@ -234,8 +294,8 @@ class TestSyncBatchNormOpTraining(unittest.TestCase):
...
@@ -234,8 +294,8 @@ class TestSyncBatchNormOpTraining(unittest.TestCase):
if
sync_bn_val
.
shape
!=
bn_val
.
shape
:
if
sync_bn_val
.
shape
!=
bn_val
.
shape
:
bn_val
=
bn_val
[:
stride
]
bn_val
=
bn_val
[:
stride
]
np
.
testing
.
assert_allclose
(
np
.
testing
.
assert_allclose
(
bn_val
,
convert_numpy_array
(
bn_val
)
,
sync_bn_val
,
convert_numpy_array
(
sync_bn_val
)
,
rtol
=
1e-05
,
rtol
=
1e-05
,
atol
=
self
.
atol
,
atol
=
self
.
atol
,
err_msg
=
'Output ('
err_msg
=
'Output ('
...
@@ -311,6 +371,7 @@ class TestDygraphSyncBatchNormAPIError(unittest.TestCase):
...
@@ -311,6 +371,7 @@ class TestDygraphSyncBatchNormAPIError(unittest.TestCase):
if
not
core
.
is_compiled_with_cuda
():
if
not
core
.
is_compiled_with_cuda
():
return
return
cleanup
=
enable_static
()
with
program_guard
(
Program
(),
Program
()):
with
program_guard
(
Program
(),
Program
()):
my_sync_batch_norm
=
paddle
.
nn
.
SyncBatchNorm
(
10
)
my_sync_batch_norm
=
paddle
.
nn
.
SyncBatchNorm
(
10
)
x1
=
fluid
.
create_lod_tensor
(
x1
=
fluid
.
create_lod_tensor
(
...
@@ -325,6 +386,7 @@ class TestDygraphSyncBatchNormAPIError(unittest.TestCase):
...
@@ -325,6 +386,7 @@ class TestDygraphSyncBatchNormAPIError(unittest.TestCase):
)
)
x2
.
desc
.
set_need_check_feed
(
False
)
x2
.
desc
.
set_need_check_feed
(
False
)
self
.
assertRaises
(
TypeError
,
my_sync_batch_norm
,
x2
)
self
.
assertRaises
(
TypeError
,
my_sync_batch_norm
,
x2
)
cleanup
()
class
TestConvertSyncBatchNorm
(
unittest
.
TestCase
):
class
TestConvertSyncBatchNorm
(
unittest
.
TestCase
):
...
@@ -384,71 +446,6 @@ class TestConvertSyncBatchNormCast1(unittest.TestCase):
...
@@ -384,71 +446,6 @@ class TestConvertSyncBatchNormCast1(unittest.TestCase):
self
.
assertEqual
(
len
(
compare_model
.
sublayers
()),
len
(
model
.
sublayers
()))
self
.
assertEqual
(
len
(
compare_model
.
sublayers
()),
len
(
model
.
sublayers
()))
class
TestConvertSyncBatchNormCase2
(
unittest
.
TestCase
):
def
test_convert
(
self
):
if
not
core
.
is_compiled_with_cuda
():
return
with
fluid
.
dygraph
.
guard
(
fluid
.
CUDAPlace
(
0
)):
class
SyBNNet
(
paddle
.
nn
.
Layer
):
def
__init__
(
self
,
in_ch
=
3
,
out_ch
=
3
,
dirate
=
1
):
super
().
__init__
()
self
.
bn_s1
=
paddle
.
nn
.
SyncBatchNorm
.
convert_sync_batchnorm
(
paddle
.
nn
.
BatchNorm3D
(
out_ch
,
weight_attr
=
paddle
.
ParamAttr
(
regularizer
=
paddle
.
regularizer
.
L2Decay
(
0.0
)
),
)
)
self
.
bn_s2
=
paddle
.
nn
.
SyncBatchNorm
.
convert_sync_batchnorm
(
paddle
.
nn
.
BatchNorm3D
(
out_ch
,
data_format
=
'NDHWC'
)
)
def
forward
(
self
,
x
):
x
=
self
.
bn_s1
(
x
)
out
=
paddle
.
sum
(
paddle
.
abs
(
self
.
bn_s2
(
x
)))
return
out
class
BNNet
(
paddle
.
nn
.
Layer
):
def
__init__
(
self
,
in_ch
=
3
,
out_ch
=
3
,
dirate
=
1
):
super
().
__init__
()
self
.
bn_s1
=
paddle
.
nn
.
BatchNorm3D
(
out_ch
,
weight_attr
=
paddle
.
ParamAttr
(
regularizer
=
paddle
.
regularizer
.
L2Decay
(
0.0
)
),
)
self
.
bn_s2
=
paddle
.
nn
.
SyncBatchNorm
.
convert_sync_batchnorm
(
paddle
.
nn
.
BatchNorm3D
(
out_ch
,
data_format
=
'NDHWC'
)
)
def
forward
(
self
,
x
):
x
=
self
.
bn_s1
(
x
)
out
=
paddle
.
sum
(
paddle
.
abs
(
self
.
bn_s2
(
x
)))
return
out
bn_model
=
BNNet
()
sybn_model
=
SyBNNet
()
np
.
random
.
seed
(
10
)
data
=
np
.
random
.
random
([
3
,
3
,
3
,
3
,
3
]).
astype
(
'float32'
)
x
=
paddle
.
to_tensor
(
data
)
bn_out
=
bn_model
(
x
)
sybn_out
=
sybn_model
(
x
)
np
.
testing
.
assert_allclose
(
bn_out
.
numpy
(),
sybn_out
.
numpy
(),
rtol
=
1e-05
,
err_msg
=
'Output has diff.
\n
'
+
'
\n
BN '
+
str
(
bn_out
.
numpy
())
+
'
\n
'
+
'Sync BN '
+
str
(
sybn_out
.
numpy
()),
)
class
TestDygraphSyncBatchNormDataFormatError
(
unittest
.
TestCase
):
class
TestDygraphSyncBatchNormDataFormatError
(
unittest
.
TestCase
):
def
test_errors
(
self
):
def
test_errors
(
self
):
if
not
core
.
is_compiled_with_cuda
():
if
not
core
.
is_compiled_with_cuda
():
...
...
test/legacy_test/test_sync_batch_norm_op_convert.py
0 → 100644
浏览文件 @
b11f0b7a
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
numpy
as
np
import
paddle
class
SyBNNet
(
paddle
.
nn
.
Layer
):
def
__init__
(
self
,
in_ch
=
3
,
out_ch
=
3
,
dirate
=
1
):
super
().
__init__
()
self
.
bn_s1
=
paddle
.
nn
.
SyncBatchNorm
.
convert_sync_batchnorm
(
paddle
.
nn
.
BatchNorm3D
(
out_ch
,
weight_attr
=
paddle
.
ParamAttr
(
regularizer
=
paddle
.
regularizer
.
L2Decay
(
0.0
)
),
)
)
self
.
bn_s2
=
paddle
.
nn
.
SyncBatchNorm
.
convert_sync_batchnorm
(
paddle
.
nn
.
BatchNorm3D
(
out_ch
,
data_format
=
'NDHWC'
)
)
def
forward
(
self
,
x
):
x
=
self
.
bn_s1
(
x
)
out
=
paddle
.
sum
(
paddle
.
abs
(
self
.
bn_s2
(
x
)))
return
out
class
BNNet
(
paddle
.
nn
.
Layer
):
def
__init__
(
self
,
in_ch
=
3
,
out_ch
=
3
,
dirate
=
1
):
super
().
__init__
()
self
.
bn_s1
=
paddle
.
nn
.
BatchNorm3D
(
out_ch
,
weight_attr
=
paddle
.
ParamAttr
(
regularizer
=
paddle
.
regularizer
.
L2Decay
(
0.0
)
),
)
self
.
bn_s2
=
paddle
.
nn
.
SyncBatchNorm
.
convert_sync_batchnorm
(
paddle
.
nn
.
BatchNorm3D
(
out_ch
,
data_format
=
'NDHWC'
)
)
def
forward
(
self
,
x
):
x
=
self
.
bn_s1
(
x
)
out
=
paddle
.
sum
(
paddle
.
abs
(
self
.
bn_s2
(
x
)))
return
out
class
TestConvertSyncBatchNormCase
(
unittest
.
TestCase
):
def
test_convert
(
self
):
if
not
paddle
.
is_compiled_with_cuda
():
return
bn_model
=
BNNet
()
sybn_model
=
SyBNNet
()
np
.
random
.
seed
(
10
)
data
=
np
.
random
.
random
([
3
,
3
,
3
,
3
,
3
]).
astype
(
'float32'
)
x
=
paddle
.
to_tensor
(
data
)
bn_out
=
bn_model
(
x
)
sybn_out
=
sybn_model
(
x
)
np
.
testing
.
assert_allclose
(
bn_out
.
numpy
(),
sybn_out
.
numpy
(),
rtol
=
1e-05
,
err_msg
=
'Output has diff.
\n
'
+
'
\n
BN '
+
str
(
bn_out
.
numpy
())
+
'
\n
'
+
'Sync BN '
+
str
(
sybn_out
.
numpy
()),
)
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录