Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
0d12aa64
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
0d12aa64
编写于
12月 20, 2021
作者:
S
sneaxiy
提交者:
GitHub
12月 20, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add check pass conflict tools (#38276)
上级
ac696941
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
188 addition
and
56 deletion
+188
-56
python/paddle/distributed/passes/pass_base.py
python/paddle/distributed/passes/pass_base.py
+5
-1
python/paddle/fluid/tests/unittests/distributed_passes/check_pass_conflict_example.py
...ittests/distributed_passes/check_pass_conflict_example.py
+45
-0
python/paddle/fluid/tests/unittests/distributed_passes/dist_pass_test_base.py
...tests/unittests/distributed_passes/dist_pass_test_base.py
+64
-14
python/paddle/fluid/tests/unittests/distributed_passes/model_zoo.py
...dle/fluid/tests/unittests/distributed_passes/model_zoo.py
+61
-0
python/paddle/fluid/tests/unittests/distributed_passes/pass_run_main.py
...fluid/tests/unittests/distributed_passes/pass_run_main.py
+10
-1
python/paddle/fluid/tests/unittests/distributed_passes/test_dist_fuse_all_reduce_pass.py
...ests/distributed_passes/test_dist_fuse_all_reduce_pass.py
+3
-40
未找到文件。
python/paddle/distributed/passes/pass_base.py
浏览文件 @
0d12aa64
...
@@ -315,4 +315,8 @@ class PassManager:
...
@@ -315,4 +315,8 @@ class PassManager:
@
property
@
property
def
names
(
self
):
def
names
(
self
):
return
[
p
.
name
for
p
in
self
.
_passes
]
return
[
p
.
name
for
p
in
self
.
passes
]
@
property
def
passes
(
self
):
return
tuple
(
self
.
_passes
)
python/paddle/fluid/tests/unittests/distributed_passes/check_pass_conflict_example.py
0 → 100644
浏览文件 @
0d12aa64
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
from
dist_pass_test_base
import
PassConflictChecker
from
paddle.distributed.passes
import
new_pass
from
model_zoo
import
resnet_model
class
CheckPassConflictTest1
(
PassConflictChecker
):
def
pass_config
(
self
):
return
[
new_pass
(
"fuse_all_reduce"
,
{
"max_memory_size"
:
1024
*
1024
}),
new_pass
(
"fuse_elewise_add_act"
),
]
def
test_resnet
(
self
):
self
.
check_main
(
resnet_model
,
batch_size
=
32
)
class
CheckPassConflictTest2
(
PassConflictChecker
):
def
pass_config
(
self
):
return
[
new_pass
(
"fuse_elewise_add_act"
),
new_pass
(
"fuse_all_reduce"
,
{
"max_memory_size"
:
1024
*
1024
}),
]
def
test_resnet
(
self
):
with
self
.
assertRaises
(
Exception
):
self
.
check_main
(
resnet_model
,
batch_size
=
32
)
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/distributed_passes/dist_pass_test_base.py
浏览文件 @
0d12aa64
...
@@ -15,7 +15,6 @@
...
@@ -15,7 +15,6 @@
import
unittest
import
unittest
import
paddle
import
paddle
import
os
import
os
import
random
import
sys
import
sys
import
pickle
import
pickle
import
shlex
import
shlex
...
@@ -24,6 +23,7 @@ import inspect
...
@@ -24,6 +23,7 @@ import inspect
import
numpy
as
np
import
numpy
as
np
from
collections
import
OrderedDict
from
collections
import
OrderedDict
from
paddle.distributed.fleet.launch_utils
import
run_with_coverage
from
paddle.distributed.fleet.launch_utils
import
run_with_coverage
from
paddle.distributed.passes.pass_base
import
new_pass
,
PassBase
,
PassManager
def
prepare_python_path_and_return_module
(
path
):
def
prepare_python_path_and_return_module
(
path
):
...
@@ -58,6 +58,9 @@ def remove_path_if_exists(path):
...
@@ -58,6 +58,9 @@ def remove_path_if_exists(path):
class
DistPassTestBase
(
unittest
.
TestCase
):
class
DistPassTestBase
(
unittest
.
TestCase
):
def
setUp
(
self
):
def
setUp
(
self
):
paddle
.
enable_static
()
paddle
.
enable_static
()
if
paddle
.
is_compiled_with_cuda
():
paddle
.
set_flags
({
'FLAGS_cudnn_deterministic'
:
1
})
seed
=
int
(
os
.
environ
.
get
(
'SEED'
,
-
1
))
seed
=
int
(
os
.
environ
.
get
(
'SEED'
,
-
1
))
if
seed
<=
0
:
if
seed
<=
0
:
seed
=
np
.
random
.
randint
(
low
=
1
,
high
=
1000000
,
size
=
[
1
])[
0
]
seed
=
np
.
random
.
randint
(
low
=
1
,
high
=
1000000
,
size
=
[
1
])[
0
]
...
@@ -80,11 +83,11 @@ class DistPassTestBase(unittest.TestCase):
...
@@ -80,11 +83,11 @@ class DistPassTestBase(unittest.TestCase):
def
apply_passes
(
self
,
main_prog
,
startup_prog
):
def
apply_passes
(
self
,
main_prog
,
startup_prog
):
raise
NotImplementedError
()
raise
NotImplementedError
()
def
check_main
(
self
,
gpus
=
None
,
**
kwargs
):
def
check_main
(
self
,
model
=
None
,
gpus
=
None
,
**
kwargs
):
no_pass_rets
=
self
.
_distributed_launch
(
no_pass_rets
=
self
.
_distributed_launch
(
apply_pass
=
Fals
e
,
gpus
=
gpus
,
**
kwargs
)
model
=
model
,
apply_pass
=
Tru
e
,
gpus
=
gpus
,
**
kwargs
)
pass_rets
=
self
.
_distributed_launch
(
pass_rets
=
self
.
_distributed_launch
(
apply_pass
=
Tru
e
,
gpus
=
gpus
,
**
kwargs
)
model
=
model
,
apply_pass
=
Fals
e
,
gpus
=
gpus
,
**
kwargs
)
self
.
check_results
(
no_pass_rets
,
pass_rets
)
self
.
check_results
(
no_pass_rets
,
pass_rets
)
def
check_results
(
self
,
no_pass_rets
,
pass_rets
):
def
check_results
(
self
,
no_pass_rets
,
pass_rets
):
...
@@ -105,7 +108,7 @@ class DistPassTestBase(unittest.TestCase):
...
@@ -105,7 +108,7 @@ class DistPassTestBase(unittest.TestCase):
equal_nan
=
self
.
equal_nan
))
equal_nan
=
self
.
equal_nan
))
@
classmethod
@
classmethod
def
_to_var_names
(
cls
,
program
,
names_or_vars
):
def
_to_var_names
(
cls
,
names_or_vars
):
if
not
isinstance
(
names_or_vars
,
(
list
,
tuple
)):
if
not
isinstance
(
names_or_vars
,
(
list
,
tuple
)):
names_or_vars
=
[
names_or_vars
]
names_or_vars
=
[
names_or_vars
]
ret_var_names
=
[]
ret_var_names
=
[]
...
@@ -116,18 +119,20 @@ class DistPassTestBase(unittest.TestCase):
...
@@ -116,18 +119,20 @@ class DistPassTestBase(unittest.TestCase):
ret_var_names
.
append
(
name_or_var
.
name
)
ret_var_names
.
append
(
name_or_var
.
name
)
return
ret_var_names
return
ret_var_names
def
_run_gpu_main
(
self
,
apply_pass
,
dump_file
,
**
kwargs
):
def
_run_gpu_main
(
self
,
model
,
apply_pass
,
dump_file
,
**
kwargs
):
gpu_id
=
int
(
os
.
environ
.
get
(
'FLAGS_selected_gpus'
,
0
))
gpu_id
=
int
(
os
.
environ
.
get
(
'FLAGS_selected_gpus'
,
0
))
place
=
paddle
.
CUDAPlace
(
gpu_id
)
place
=
paddle
.
CUDAPlace
(
gpu_id
)
scope
=
paddle
.
static
.
Scope
()
scope
=
paddle
.
static
.
Scope
()
if
model
is
None
:
model
=
self
.
get_model
with
paddle
.
static
.
program_guard
(
paddle
.
static
.
Program
(),
with
paddle
.
static
.
program_guard
(
paddle
.
static
.
Program
(),
paddle
.
static
.
Program
()):
paddle
.
static
.
Program
()):
with
paddle
.
static
.
scope_guard
(
scope
):
with
paddle
.
static
.
scope_guard
(
scope
):
with
paddle
.
fluid
.
unique_name
.
guard
():
with
paddle
.
fluid
.
unique_name
.
guard
():
main_prog
,
startup_prog
,
inputs
,
outputs
,
reader
=
self
.
get_
model
(
main_prog
,
startup_prog
,
inputs
,
outputs
,
reader
=
model
(
place
,
**
kwargs
)
place
,
**
kwargs
)
inputs
=
self
.
_to_var_names
(
main_prog
,
inputs
)
inputs
=
self
.
_to_var_names
(
inputs
)
outputs
=
self
.
_to_var_names
(
main_prog
,
outputs
)
outputs
=
self
.
_to_var_names
(
outputs
)
if
apply_pass
:
if
apply_pass
:
self
.
apply_passes
(
main_prog
,
startup_prog
)
self
.
apply_passes
(
main_prog
,
startup_prog
)
...
@@ -161,7 +166,7 @@ class DistPassTestBase(unittest.TestCase):
...
@@ -161,7 +166,7 @@ class DistPassTestBase(unittest.TestCase):
int
(
s
.
strip
())
for
s
in
visible_devices
.
split
(
","
)
if
s
.
strip
()
int
(
s
.
strip
())
for
s
in
visible_devices
.
split
(
","
)
if
s
.
strip
()
]
]
def
_distributed_launch
(
self
,
apply_pass
,
gpus
=
None
,
**
kwargs
):
def
_distributed_launch
(
self
,
model
,
apply_pass
,
gpus
=
None
,
**
kwargs
):
if
gpus
is
None
:
if
gpus
is
None
:
gpus
=
self
.
_get_default_gpu_lists
()
gpus
=
self
.
_get_default_gpu_lists
()
...
@@ -176,7 +181,9 @@ class DistPassTestBase(unittest.TestCase):
...
@@ -176,7 +181,9 @@ class DistPassTestBase(unittest.TestCase):
remove_path_if_exists
(
output_dir
)
remove_path_if_exists
(
output_dir
)
os
.
makedirs
(
output_dir
,
mode
=
777
)
os
.
makedirs
(
output_dir
,
mode
=
777
)
input_dump_file
=
os
.
path
.
join
(
output_dir
,
'inputs'
)
input_dump_file
=
os
.
path
.
join
(
output_dir
,
'inputs.bin'
)
model_dump_file
=
os
.
path
.
join
(
output_dir
,
'model.bin'
)
if
os
.
environ
.
get
(
"WITH_COVERAGE"
,
"OFF"
)
==
"ON"
:
if
os
.
environ
.
get
(
"WITH_COVERAGE"
,
"OFF"
)
==
"ON"
:
run_with_coverage
(
True
)
run_with_coverage
(
True
)
coverage_args
=
[
"-m"
,
"coverage"
,
"run"
,
"--branch"
,
"-p"
]
coverage_args
=
[
"-m"
,
"coverage"
,
"run"
,
"--branch"
,
"-p"
]
...
@@ -189,6 +196,10 @@ class DistPassTestBase(unittest.TestCase):
...
@@ -189,6 +196,10 @@ class DistPassTestBase(unittest.TestCase):
with
open
(
input_dump_file
,
'wb'
)
as
f
:
with
open
(
input_dump_file
,
'wb'
)
as
f
:
pickle
.
dump
(
kwargs
,
f
)
pickle
.
dump
(
kwargs
,
f
)
if
model
is
not
None
:
with
open
(
model_dump_file
,
'wb'
)
as
f
:
pickle
.
dump
(
model
,
f
)
cmd
=
[
cmd
=
[
sys
.
executable
,
sys
.
executable
,
"-u"
,
"-u"
,
...
@@ -208,23 +219,62 @@ class DistPassTestBase(unittest.TestCase):
...
@@ -208,23 +219,62 @@ class DistPassTestBase(unittest.TestCase):
input_dump_file
,
input_dump_file
,
"--output_dir"
,
"--output_dir"
,
output_dir
,
output_dir
,
]
+
([
"--apply_pass"
]
if
apply_pass
else
[])
]
if
apply_pass
:
cmd
+=
[
"--apply_pass"
]
if
model
is
not
None
:
cmd
+=
[
"--model_file"
,
model_dump_file
]
cmd
=
[
shlex
.
quote
(
c
)
for
c
in
cmd
]
cmd
=
[
shlex
.
quote
(
c
)
for
c
in
cmd
]
prepare_python_path_and_return_module
(
__file__
)
prepare_python_path_and_return_module
(
__file__
)
exitcode
=
os
.
system
(
' '
.
join
(
cmd
))
exitcode
=
os
.
system
(
' '
.
join
(
cmd
))
self
.
assertEqual
(
self
.
assertEqual
(
exitcode
,
0
,
exitcode
,
0
,
"Pass failed with apply_pass = {}"
.
format
(
apply_pass
))
"Pass test failed with apply_pass = {}, please view log in {}"
.
format
(
apply_pass
,
output_dir
))
results
=
[]
results
=
[]
for
i
in
range
(
num_gpus
):
for
i
in
range
(
num_gpus
):
dump_file
=
'{0}/{1}.bin'
.
format
(
output_dir
,
i
)
dump_file
=
'{0}/{1}.bin'
.
format
(
output_dir
,
i
)
self
.
assertTrue
(
self
.
assertTrue
(
os
.
path
.
exists
(
dump_file
),
os
.
path
.
exists
(
dump_file
),
"Pass failed with apply_pass = {}"
.
format
(
apply_pass
))
"Pass test failed with apply_pass = {}, please view log in {}"
.
format
(
apply_pass
,
output_dir
))
with
open
(
dump_file
,
"rb"
)
as
f
:
with
open
(
dump_file
,
"rb"
)
as
f
:
results
.
append
(
pickle
.
load
(
f
))
results
.
append
(
pickle
.
load
(
f
))
return
results
return
results
finally
:
finally
:
if
int
(
os
.
environ
.
get
(
"DEBUG"
,
0
))
==
0
:
if
int
(
os
.
environ
.
get
(
"DEBUG"
,
0
))
==
0
:
remove_path_if_exists
(
output_dir
)
remove_path_if_exists
(
output_dir
)
class
PassConflictChecker
(
DistPassTestBase
):
def
setUp
(
self
):
os
.
environ
[
'DEBUG'
]
=
'1'
# to save the debug directory
super
(
PassConflictChecker
,
self
).
setUp
()
def
pass_config
(
self
):
raise
NotImplementedError
()
def
apply_passes
(
self
,
main_prog
,
startup_prog
):
passes
=
self
.
pass_config
()
if
not
isinstance
(
passes
,
(
list
,
tuple
)):
passes
=
[
passes
]
for
p
in
passes
:
self
.
assertTrue
(
isinstance
(
p
,
PassBase
))
auto_pass_manager
=
PassManager
(
passes
,
auto_solve_conflict
=
True
)
new_passes
=
auto_pass_manager
.
passes
self
.
assertEqual
(
len
(
passes
),
len
(
new_passes
),
"After solving conflicts, the left passes are: {}"
.
format
(
auto_pass_manager
.
names
))
for
i
,
(
p1
,
p2
)
in
enumerate
(
zip
(
passes
,
new_passes
)):
self
.
assertEqual
(
id
(
p1
),
id
(
p2
),
"After solving conflicts, the {}-th pass is different: {} vs {}"
.
format
(
i
,
p1
.
name
,
p2
.
name
))
auto_pass_manager
.
apply
([
main_prog
],
[
startup_prog
])
python/paddle/fluid/tests/unittests/distributed_passes/model_zoo.py
0 → 100644
浏览文件 @
0d12aa64
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
paddle
import
paddle.distributed.fleet
as
fleet
from
paddle.vision.models
import
resnet50
as
resnet
import
numpy
as
np
import
paddle.nn
as
nn
__all__
=
[
'resnet_model'
,
]
def
get_seed_from_env
():
return
int
(
os
.
environ
.
get
(
"SEED"
,
0
))
def
resnet_model
(
place
,
batch_size
,
image_shape
=
[
3
,
224
,
224
],
num_classes
=
1000
):
image
=
paddle
.
static
.
data
(
shape
=
[
batch_size
]
+
image_shape
,
dtype
=
'float32'
,
name
=
'image'
)
label
=
paddle
.
static
.
data
(
shape
=
[
batch_size
,
1
],
dtype
=
'int64'
,
name
=
'label'
)
model
=
resnet
(
pretrained
=
False
)
loss_fn
=
nn
.
loss
.
CrossEntropyLoss
()
pred_out
=
model
(
image
)
loss
=
loss_fn
(
pred_out
,
label
)
optimizer
=
paddle
.
optimizer
.
Adam
(
learning_rate
=
1e-3
)
dist_strategy
=
fleet
.
DistributedStrategy
()
dist_strategy
.
fuse_all_reduce_ops
=
False
dist_strategy
.
without_graph_optimization
=
True
fleet
.
init
(
is_collective
=
True
,
strategy
=
dist_strategy
)
optimizer
=
fleet
.
distributed_optimizer
(
optimizer
)
optimizer
.
minimize
(
loss
)
rank
=
paddle
.
distributed
.
get_rank
()
def
reader
():
seed
=
get_seed_from_env
()
np
.
random
.
seed
(
seed
+
rank
)
for
_
in
range
(
10
):
image_np
=
np
.
random
.
random
(
size
=
image
.
shape
).
astype
(
'float32'
)
label_np
=
np
.
random
.
randint
(
low
=
0
,
high
=
num_classes
,
size
=
label
.
shape
).
astype
(
'int64'
)
yield
image_np
,
label_np
main_program
=
paddle
.
static
.
default_main_program
()
startup_program
=
paddle
.
static
.
default_startup_program
()
return
main_program
,
startup_program
,
[
image
,
label
],
[
loss
],
reader
python/paddle/fluid/tests/unittests/distributed_passes/pass_run_main.py
浏览文件 @
0d12aa64
...
@@ -44,6 +44,10 @@ def parse_args():
...
@@ -44,6 +44,10 @@ def parse_args():
'--output_dir'
,
'--output_dir'
,
type
=
str
,
type
=
str
,
help
=
'The output directory to save the logs and output results.'
)
help
=
'The output directory to save the logs and output results.'
)
parser
.
add_argument
(
'--model_file'
,
type
=
str
,
help
=
'The input model file which contains the dumped model function.'
)
return
parser
.
parse_args
()
return
parser
.
parse_args
()
...
@@ -60,11 +64,16 @@ def run_main(args):
...
@@ -60,11 +64,16 @@ def run_main(args):
kwargs
=
pickle
.
load
(
f
)
kwargs
=
pickle
.
load
(
f
)
output_file
=
"{}/{}.bin"
.
format
(
args
.
output_dir
,
rank
)
output_file
=
"{}/{}.bin"
.
format
(
args
.
output_dir
,
rank
)
if
args
.
model_file
:
with
open
(
args
.
model_file
,
"rb"
)
as
f
:
model
=
pickle
.
load
(
f
)
else
:
model
=
None
try
:
try
:
test_obj
.
setUpClass
()
test_obj
.
setUpClass
()
test_obj
.
setUp
()
test_obj
.
setUp
()
test_obj
.
_run_gpu_main
(
args
.
apply_pass
,
output_file
,
**
kwargs
)
test_obj
.
_run_gpu_main
(
model
,
args
.
apply_pass
,
output_file
,
**
kwargs
)
finally
:
finally
:
test_obj
.
tearDown
()
test_obj
.
tearDown
()
test_obj
.
tearDownClass
()
test_obj
.
tearDownClass
()
...
...
python/paddle/fluid/tests/unittests/distributed_passes/test_dist_fuse_all_reduce_pass.py
浏览文件 @
0d12aa64
...
@@ -12,20 +12,14 @@
...
@@ -12,20 +12,14 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
paddle
from
paddle.distributed.passes
import
new_pass
,
PassManager
from
paddle.distributed.passes
import
new_pass
,
PassManager
import
paddle.distributed.fleet
as
fleet
from
paddle.vision.models
import
resnet50
as
resnet
import
unittest
import
unittest
from
dist_pass_test_base
import
DistPassTestBase
from
dist_pass_test_base
import
DistPassTestBase
import
paddle.nn
as
nn
from
model_zoo
import
resnet_model
import
numpy
as
np
class
TestFuseAllReducePass
(
DistPassTestBase
):
class
TestFuseAllReducePass
(
DistPassTestBase
):
def
init
(
self
):
def
init
(
self
):
if
paddle
.
is_compiled_with_cuda
():
paddle
.
set_flags
({
'FLAGS_cudnn_deterministic'
:
1
})
self
.
atol
=
0.0
self
.
atol
=
0.0
self
.
rtol
=
0.0
self
.
rtol
=
0.0
...
@@ -35,41 +29,10 @@ class TestFuseAllReducePass(DistPassTestBase):
...
@@ -35,41 +29,10 @@ class TestFuseAllReducePass(DistPassTestBase):
new_pass
(
"fuse_all_reduce"
,
{
"max_memory_size"
:
1024
*
1024
})
new_pass
(
"fuse_all_reduce"
,
{
"max_memory_size"
:
1024
*
1024
})
])
])
pass_manager
.
apply
([
main_prog
],
[
startup_prog
])
pass_manager
.
apply
([
main_prog
],
[
startup_prog
])
print
(
pass_manager
.
names
)
def
test_bs_32
(
self
):
def
test_bs_32
(
self
):
self
.
check_main
(
batch_size
=
32
)
self
.
check_main
(
resnet_model
,
batch_size
=
32
)
def
get_model
(
self
,
place
,
batch_size
):
image
=
paddle
.
static
.
data
(
shape
=
[
batch_size
,
3
,
224
,
224
],
dtype
=
'float32'
,
name
=
'image'
)
label
=
paddle
.
static
.
data
(
shape
=
[
batch_size
,
1
],
dtype
=
'int64'
,
name
=
'label'
)
model
=
resnet
(
pretrained
=
False
)
loss_fn
=
nn
.
loss
.
CrossEntropyLoss
()
pred_out
=
model
(
image
)
loss
=
loss_fn
(
pred_out
,
label
)
optimizer
=
paddle
.
optimizer
.
Adam
(
learning_rate
=
1e-3
)
dist_strategy
=
fleet
.
DistributedStrategy
()
dist_strategy
.
fuse_all_reduce_ops
=
False
dist_strategy
.
without_graph_optimization
=
True
fleet
.
init
(
is_collective
=
True
,
strategy
=
dist_strategy
)
optimizer
=
fleet
.
distributed_optimizer
(
optimizer
)
optimizer
.
minimize
(
loss
)
rank
=
paddle
.
distributed
.
get_rank
()
def
reader
():
np
.
random
.
seed
(
self
.
seed
+
rank
)
for
_
in
range
(
10
):
image_np
=
np
.
random
.
random
(
size
=
image
.
shape
).
astype
(
'float32'
)
label_np
=
np
.
random
.
randint
(
low
=
0
,
high
=
1000
,
size
=
label
.
shape
).
astype
(
'int64'
)
yield
image_np
,
label_np
main_program
=
paddle
.
static
.
default_main_program
()
startup_program
=
paddle
.
static
.
default_startup_program
()
return
main_program
,
startup_program
,
[
image
,
label
],
[
loss
],
reader
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录