Commit 0443b480 (unverified) — authored Sep 07, 2020 by Dong Daxiang; committed via GitHub on Sep 07, 2020.
【paddle.fleet】add auto parallel L1 implementations (#27090)
* add auto parallel L1 implementation test=develop
Parent: 5af81f83
Showing 16 changed files with 257 additions and 4 deletions (+257 −4).
python/paddle/distributed/fleet/base/distributed_strategy.py  (+127 −0)
python/paddle/distributed/fleet/base/fleet_base.py  (+13 −0)
python/paddle/distributed/fleet/meta_optimizers/amp_optimizer.py  (+11 −0)
python/paddle/distributed/fleet/meta_optimizers/dgc_optimizer.py  (+4 −0)
python/paddle/distributed/fleet/meta_optimizers/gradient_merge_optimizer.py  (+4 −0)
python/paddle/distributed/fleet/meta_optimizers/graph_execution_optimizer.py  (+5 −4)
python/paddle/distributed/fleet/meta_optimizers/lamb_optimizer.py  (+7 −0)
python/paddle/distributed/fleet/meta_optimizers/lars_optimizer.py  (+7 −0)
python/paddle/distributed/fleet/meta_optimizers/localsgd_optimizer.py  (+4 −0)
python/paddle/distributed/fleet/meta_optimizers/meta_optimizer_base.py  (+4 −0)
python/paddle/distributed/fleet/meta_optimizers/parameter_server_graph_optimizer.py  (+5 −0)
python/paddle/distributed/fleet/meta_optimizers/parameter_server_optimizer.py  (+5 −0)
python/paddle/distributed/fleet/meta_optimizers/pipeline_optimizer.py  (+4 −0)
python/paddle/distributed/fleet/meta_optimizers/recompute_optimizer.py  (+4 −0)
python/paddle/fluid/tests/unittests/CMakeLists.txt  (+2 −0)
python/paddle/fluid/tests/unittests/test_fleet_auto.py  (+51 −0, new file)
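Before the per-file diffs, a minimal usage sketch of the switch this commit introduces. It mirrors the `auto` docstring and the new `test_fleet_auto.py` unit test further down; the SGD optimizer, learning rate, and the assumption of an already-launched collective job are illustrative only, not part of the patch.

```python
import paddle
import paddle.distributed.fleet as fleet

# Assumes the script is started through a distributed launcher so the trainer
# environment variables are already set (see test_fleet_auto.py below).
fleet.init(is_collective=True)

# Strict auto mode: `auto` is the only strategy field the user touches, so the
# framework is free to turn on and configure the meta optimizers itself.
strategy = fleet.DistributedStrategy()
strategy.auto = True

optimizer = paddle.optimizer.SGD(learning_rate=0.01)
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
# build a network, compute `avg_cost`, then call: optimizer.minimize(avg_cost)
```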
python/paddle/distributed/fleet/base/distributed_strategy.py  (+127 −0)

The file gains a module-level flag and a decorator that record whether any strategy setter has been called explicitly:

```
@@ -15,10 +15,25 @@
 import paddle
 from paddle.distributed.fleet.proto import distributed_strategy_pb2
 from paddle.fluid.framework import Variable, set_flags, core
+from paddle.fluid.wrapped_decorator import wrap_decorator
 import google.protobuf.text_format
 
 __all__ = ["DistributedStrategy"]
 
+non_auto_func_called = True
+
+
+def __non_auto_func_called__(func):
+    def __impl__(*args, **kwargs):
+        global non_auto_func_called
+        non_auto_func_called = False
+        return func(*args, **kwargs)
+
+    return __impl__
+
+
+is_strict_auto = wrap_decorator(__non_auto_func_called__)
+
 
 def get_msg_dict(msg):
     res_dict = {}
```

Every strategy setter then receives the same one-line change: an `@is_strict_auto` decorator inserted between the `@<property>.setter` decorator and the method definition, for example:

```
@@ -164,6 +179,7 @@ class DistributedStrategy(object):
         return execution_strategy
 
     @execution_strategy.setter
+    @is_strict_auto
     def execution_strategy(self, strategy):
         fields = self.strategy.execution_strategy.DESCRIPTOR.fields
         for f in fields:
```

The remaining hunks of this file (@@ -203,6 +219,7 @@ through @@ -971,9 +1075,26 @@) apply the identical decorator to the setters of build_strategy, a_sync, a_sync_configs, amp, amp_configs, sync_nccl_allreduce, use_hierarchical_allreduce, hierarchical_allreduce_inter_nranks, sync_batch_norm, fuse_all_reduce_ops, fuse_grad_size_in_MB, _fuse_grad_size_in_TFLOPS, nccl_comm_num, recompute, recompute_configs, pipeline, pipeline_configs, localsgd, localsgd_configs, dgc, dgc_configs, gradient_merge, gradient_merge_configs, lars, lars_configs, lamb, lamb_configs, elastic, cudnn_exhaustive_search, conv_workspace_size_limit, and cudnn_batchnorm_spatial_persistent. The `auto` setter itself is not decorated.

The same hunks also add docstrings to the elastic, auto, cudnn_exhaustive_search, conv_workspace_size_limit, and cudnn_batchnorm_spatial_persistent properties:

```
@@ -917,15 +959,21 @@ class DistributedStrategy(object):
     @property
     def elastic(self):
+        """
+        Indicating whether we want to do current distributed training on clusters with elastic resources.
+        Currently, this configuration is not valid.
+        """
         return self.strategy.elastic

@@ -934,6 +982,25 @@ class DistributedStrategy(object):
     @property
     def auto(self):
+        """
+        Indicating whether we are using auto-parallel configuration.
+        This feature is currently experimental. Auto-parallelism can be used
+        only when a user does not set any other strategy config except auto.
+        For details, please refer to the following code example.
+        Default Value: False
+
+        Examples:
+          .. code-block:: python
+
+            import paddle
+            import paddle.distributed.fleet as fleet
+
+            strategy = fleet.DistributedStrategy()
+            strategy.auto = True
+
+            optimizer = paddle.optimizer.SGD(learning_rate=0.01)
+            optimizer = fleet.distributed_optimizer(optimizer, strategy)
+        """
         return self.strategy.auto
 
     @auto.setter

@@ -945,9 +1012,27 @@ class DistributedStrategy(object):
     @property
     def cudnn_exhaustive_search(self):
+        """
+        Indicating whether to use exhaustive search to choose convolution algorithms.
+        Exhaustive search attempts all cuDNN algorithms to choose the fastest one.
+        This method is time-consuming; the chosen algorithm is cached for the given layer specifications.
+        Once the layer specifications (like batch size or feature map size) change, it will search again.
+        Default Value: True
+
+        Examples:
+          .. code-block:: python
+
+            import paddle.distributed.fleet as fleet
+
+            strategy = fleet.DistributedStrategy()
+            strategy.cudnn_exhaustive_search = False
+
+            optimizer = paddle.optimizer.SGD(learning_rate=0.01)
+            optimizer = fleet.distributed_optimizer(optimizer, strategy)
+        """
         return self.strategy.cudnn_exhaustive_search

@@ -958,9 +1043,28 @@ class DistributedStrategy(object):
     @property
     def conv_workspace_size_limit(self):
+        """
+        The workspace limit size (in MB) for choosing cuDNN convolution algorithms.
+        The inner function of cuDNN picks the fastest suited algorithm that fits within this memory limit.
+        Usually, a larger workspace allows faster algorithms to be chosen, but it
+        significantly increases memory usage; users need to trade off memory against speed.
+        Default Value: 4000
+
+        Examples:
+          .. code-block:: python
+
+            import paddle.distributed.fleet as fleet
+
+            strategy = fleet.DistributedStrategy()
+            strategy.conv_workspace_size_limit = 1024
+
+            optimizer = paddle.optimizer.SGD(learning_rate=0.01)
+            optimizer = fleet.distributed_optimizer(optimizer, strategy)
+        """
         return self.strategy.conv_workspace_size_limit

@@ -971,9 +1075,26 @@ class DistributedStrategy(object):
     @property
     def cudnn_batchnorm_spatial_persistent(self):
+        """
+        Indicates whether to use the CUDNN_BATCHNORM_SPATIAL_PERSISTENT mode in batchnorm.
+        This is only useful with cuDNN.
+        Default Value: True
+
+        Examples:
+          .. code-block:: python
+
+            import paddle.distributed.fleet as fleet
+
+            strategy = fleet.DistributedStrategy()
+            strategy.cudnn_batchnorm_spatial_persistent = True
+
+            optimizer = paddle.optimizer.SGD(learning_rate=0.01)
+            optimizer = fleet.distributed_optimizer(optimizer, strategy)
+        """
         return self.strategy.cudnn_batchnorm_spatial_persistent
```

Finally, the class gains a helper that reports whether strict auto mode is in effect:

```
@@ -1005,6 +1126,12 @@ class DistributedStrategy(object):
             if core.globals().is_public(key):
                 core.globals()[key] = values[i]
 
+    def _is_strict_auto(self):
+        global non_auto_func_called
+        if self.strategy.auto and non_auto_func_called:
+            return True
+        return False
+
     def __repr__(self):
         fields = self.strategy.DESCRIPTOR.fields
         for f in fields:
```
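A short behavioral sketch of the machinery above (illustrative only, not part of the patch): `non_auto_func_called` is a module-level flag, so calling any `@is_strict_auto`-decorated setter anywhere in the process disables strict auto mode, while the undecorated `auto` setter leaves it intact.

```python
# Illustration of the strict-auto check added above; assumes no strategy
# setter has been called earlier in the process, since the flag is global.
import paddle.distributed.fleet as fleet

s = fleet.DistributedStrategy()
s.auto = True                    # the `auto` setter is not wrapped by @is_strict_auto
assert s._is_strict_auto()       # auto requested and no other setter called yet

s.amp = True                     # any decorated setter flips non_auto_func_called
assert not s._is_strict_auto()   # strict auto is now off, process-wide
```

Because the flag is global rather than per-instance, this really is the "very strict condition" referred to in fleet_base.py below: auto-parallelism only kicks in when `auto` is the sole strategy field the user ever sets.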
python/paddle/distributed/fleet/base/fleet_base.py  (+13 −0)

```
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 from __future__ import print_function
+import copy
 import warnings
 import paddle
 from paddle.fluid.framework import dygraph_only

@@ -1008,6 +1009,18 @@ class Fleet(object):
             MetaOptimizerFactory()._get_valid_meta_optimizers(
                 self.user_defined_optimizer)
 
+        context["user_defined_strategy"] = copy.copy(self.user_defined_strategy)
+
+        # trigger the auto-parallel in very strict condition
+        # strategy = DistributedStrategy()
+        # strategy.auto = True
+        # optimizer = paddle.optimizer.SGD(learning_rate=0.1)
+        # optimizer = fleet.distributed_optimizer(optimizer, strategy)
+        if self.user_defined_strategy._is_strict_auto():
+            # turn on all the strategy for each optimizer
+            for opt in distributed_optimizer_list:
+                opt._enable_strategy(self.user_defined_strategy)
+
         valid_optimizer_list = []
         valid_graph_optimizer_list = []
         can_not_apply_optimizer_list = []
```
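To make the effect of the branch above concrete, the sketch below calls one meta optimizer's `_enable_strategy` by hand. This is illustrative only: constructing `AMPOptimizer` directly with a `None` inner optimizer is an assumption made for the sketch; inside Fleet the loop runs over every valid meta optimizer during `minimize`. The resulting values come from the `_enable_strategy` implementations shown in the files below.

```python
# Illustration only: what opt._enable_strategy(strategy) does for one meta
# optimizer under strict auto mode.
import paddle.distributed.fleet as fleet
from paddle.distributed.fleet.meta_optimizers.amp_optimizer import AMPOptimizer

strategy = fleet.DistributedStrategy()
strategy.auto = True

# Fleet performs this call for each valid meta optimizer; done by hand here.
AMPOptimizer(None)._enable_strategy(strategy)

print(strategy.amp)                               # True
print(strategy.amp_configs["init_loss_scaling"])  # 32768.0
```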
python/paddle/distributed/fleet/meta_optimizers/amp_optimizer.py  (+11 −0)

```
@@ -42,6 +42,17 @@ class AMPOptimizer(MetaOptimizerBase):
         dist_strategy.amp = False
         dist_strategy.amp_configs = {}
 
+    def _enable_strategy(self, dist_strategy):
+        dist_strategy.amp = True
+        dist_strategy.amp_configs = {
+            "init_loss_scaling": 32768.0,
+            "incr_every_n_steps": 1000,
+            "decr_every_n_nan_or_inf": 2,
+            "incr_ratio": 2.0,
+            "decr_ratio": 8.0,
+            "use_dynamic_loss_scaling": True
+        }
+
     def minimize_impl(self,
                       loss,
                       startup_program=None,
```
python/paddle/distributed/fleet/meta_optimizers/dgc_optimizer.py  (+4 −0)

```
@@ -69,6 +69,10 @@ class DGCOptimizer(MetaOptimizerBase):
         dist_strategy.dgc = False
         dist_strategy.dgc_configs = {}
 
+    def _enable_strategy(self, dist_strategy):
+        dist_strategy.dgc = True
+        dist_strategy.dgc_configs = {"rampup_begin_step": 0, "rampup_step": 1}
+
     def backward(self,
                  loss,
                  startup_program=None,
```
python/paddle/distributed/fleet/meta_optimizers/gradient_merge_optimizer.py  (+4 −0)

```
@@ -45,6 +45,10 @@ class GradientMergeOptimizer(MetaOptimizerBase):
         dist_strategy.gradient_merge = False
         dist_strategy.gradient_merge_configs = {}
 
+    def _enable_strategy(self, dist_strategy):
+        # we currently do not support auto-enable gradient merge
+        return
+
     def minimize_impl(self,
                       loss,
                       startup_program=None,
```
python/paddle/distributed/fleet/meta_optimizers/graph_execution_optimizer.py  (+5 −4)

```
@@ -148,9 +148,6 @@ class GraphExecutionOptimizer(MetaOptimizerBase):
         sync_allreduce = dist_strategy.sync_nccl_allreduce
         if sync_allreduce:
-            paddle.fluid.framework.set_flags({"FLAGS_sync_nccl_allreduce": True})
             exe_strategy.num_threads = local_build_strategy.nccl_comm_num + 1
             if local_build_strategy.use_hierarchical_allreduce:
                 exe_strategy.num_threads = 2 * local_build_strategy.nccl_comm_num + 1

@@ -191,7 +188,11 @@ class GraphExecutionOptimizer(MetaOptimizerBase):
     def _disable_strategy(self, dist_strategy):
         # TODO(guru4elephant): should close all PE related flags here
-        pass
+        return
+
+    def _enable_strategy(self, dist_strategy):
+        # by default, graph execution strategy is enabled
+        return
 
     def minimize(self,
                  loss,
```
python/paddle/distributed/fleet/meta_optimizers/lamb_optimizer.py  (+7 −0)

```
@@ -75,6 +75,13 @@ class LambOptimizer(MetaOptimizerBase):
         dist_strategy.lamb = False
         dist_strategy.lamb_configs = {}
 
+    def _enable_strategy(self, dist_strategy):
+        dist_strategy.lamb = True
+        dist_strategy.lamb_configs = {
+            "lamb_weight_decay": 0.01,
+            "exclude_from_weight_decay": []
+        }
+
     def backward(self,
                  loss,
                  startup_program=None,
```
python/paddle/distributed/fleet/meta_optimizers/lars_optimizer.py  (+7 −0)

```
@@ -59,6 +59,13 @@ class LarsOptimizer(MetaOptimizerBase):
         dist_strategy.lars = False
         dist_strategy.lars_configs = {}
 
+    def _enable_strategy(self, dist_strategy):
+        dist_strategy.lars = True
+        dist_strategy.lars_configs = {
+            "lars_coeff": 0.01,
+            "lars_weight_decay": 0.0005,
+        }
+
     def backward(self,
                  loss,
                  startup_program=None,
```
python/paddle/distributed/fleet/meta_optimizers/localsgd_optimizer.py  (+4 −0)

```
@@ -42,6 +42,10 @@ class LocalSGDOptimizer(MetaOptimizerBase):
         dist_strategy.localsgd = False
         dist_strategy.localsgd_configs = {}
 
+    def _enable_strategy(self, dist_strategy):
+        dist_strategy.localsgd = True
+        dist_strategy.localsgd_configs = {"k_steps": 1}
+
     def snapshot_name(self, param_name):
         return param_name + self.snapshot_key
```
python/paddle/distributed/fleet/meta_optimizers/meta_optimizer_base.py  (+4 −0)

```
@@ -48,6 +48,10 @@ class MetaOptimizerBase(Optimizer):
         raise NotImplementedError("you should implement disable strategy in {}".
                                   format(type(self).__name__))
 
+    def _enable_strategy(self, dist_strategy):
+        raise NotImplementedError("you should implement enable strategy in {}".
+                                  format(type(self).__name__))
+
     def apply_gradients(self, params_grads):
         return self.inner_opt.apply_gradients(params_grads=params_grads)
```
python/paddle/distributed/fleet/meta_optimizers/parameter_server_graph_optimizer.py  (+5 −0)

```
@@ -39,6 +39,11 @@ class ParameterServerGraphOptimizer(ParameterServerOptimizer):
     def _disable_strategy(self, dist_strategy):
         dist_strategy.a_sync_configs = {}
 
+    def _enable_strategy(self, dist_strategy):
+        # only open up the async mode for auto-parallel
+        dist_strategy.a_sync = True
+        dist_strategy.a_sync_configs = {}
+
     def _is_graph_out(self):
         return True
```
python/paddle/distributed/fleet/meta_optimizers/parameter_server_optimizer.py  (+5 −0)

```
@@ -157,4 +157,9 @@ class ParameterServerOptimizer(MetaOptimizerBase):
         return None, None
 
     def _disable_strategy(self, dist_strategy):
+        dist_strategy.a_sync_configs = {}
         self.user_defined_strategy.a_sync_configs = {}
+
+    def _enable_strategy(self, dist_strategy):
+        dist_strategy.a_sync = True
+        dist_strategy.a_sync_configs = {}
```
python/paddle/distributed/fleet/meta_optimizers/pipeline_optimizer.py  (+4 −0)

```
@@ -111,6 +111,10 @@ class PipelineOptimizer(MetaOptimizerBase):
         dist_strategy.pipeline = False
         dist_strategy.pipeline_configs = {}
 
+    def _enable_strategy(self, dist_strategy):
+        # we do not support enable pipeline automatically right now
+        return
+
     def minimize_impl(self,
                       loss,
                       startup_program=None,
```
python/paddle/distributed/fleet/meta_optimizers/recompute_optimizer.py  (+4 −0)

```
@@ -49,6 +49,10 @@ class RecomputeOptimizer(MetaOptimizerBase):
         dist_strategy.recompute = False
         dist_strategy.recompute_configs = {}
 
+    def _enable_strategy(self, dist_strategy):
+        # we do not support automatically recompute checkpoints currently
+        return
+
     def backward(self,
                  loss,
                  startup_program=None,
```
python/paddle/fluid/tests/unittests/CMakeLists.txt  (+2 −0)

```
@@ -47,6 +47,7 @@ list(APPEND MIXED_DIST_TEST_OPS test_fleet_dgc_meta_optimizer)
 list(APPEND MIXED_DIST_TEST_OPS test_fleet_private_function)
 list(APPEND MIXED_DIST_TEST_OPS test_fleet_graph_executor)
 list(APPEND MIXED_DIST_TEST_OPS test_fleet_meta_optimizer_base)
+list(APPEND MIXED_DIST_TEST_OPS test_fleet_auto)
 foreach(TEST_OP ${MIXED_DIST_TEST_OPS})
     list(REMOVE_ITEM TEST_OPS ${TEST_OP})
 endforeach()

@@ -458,6 +459,7 @@ if(WITH_DISTRIBUTE)
     py_test_modules(test_fleet_pipeline_meta_optimizer MODULES test_fleet_pipeline_meta_optimizer ENVS ${dist_ENVS})
     py_test_modules(test_fleet_private_function MODULES test_fleet_private_function ENVS ${dist_ENVS})
     py_test_modules(test_fleet_meta_optimizer_base MODULES test_fleet_meta_optimizer_base ENVS ${dist_ENVS})
+    py_test_modules(test_fleet_auto MODULES test_fleet_auto ENVS ${dist_ENVS})
     if(NOT WIN32)
         py_test_modules(test_fleet_localsgd_meta_optimizer MODULES test_fleet_localsgd_meta_optimizer ENVS ${dist_ENVS})
         py_test_modules(test_fleet_lars_meta_optimizer MODULES test_fleet_lars_meta_optimizer ENVS ${dist_ENVS})
```
python/paddle/fluid/tests/unittests/test_fleet_auto.py  (new file, mode 100644, +51 −0)

```python
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import paddle
import os
import paddle.distributed.fleet as fleet
import paddle.distributed.fleet.base.role_maker as role_maker


class TestDistributedStrategyAuto(unittest.TestCase):
    def setUp(self):
        os.environ["POD_IP"] = "127.0.0.1"
        os.environ["PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001"
        os.environ["PADDLE_TRAINERS_NUM"] = "2"
        os.environ["PADDLE_PSERVERS_IP_PORT_LIST"] = \
            "127.0.0.1:36001,127.0.0.2:36001"

    def test_distributed_strategy_auto(self):
        fleet.init(is_collective=True)

        input_x = paddle.fluid.layers.data(name="x", shape=[32], dtype='float32')
        input_y = paddle.fluid.layers.data(name="y", shape=[1], dtype='int64')

        fc_1 = paddle.fluid.layers.fc(input=input_x, size=64, act='tanh')
        fc_2 = paddle.fluid.layers.fc(input=fc_1, size=64, act='tanh')
        prediction = paddle.fluid.layers.fc(input=[fc_2], size=2, act='softmax')
        cost = paddle.fluid.layers.cross_entropy(input=prediction, label=input_y)
        avg_cost = paddle.fluid.layers.mean(x=cost)

        strategy = paddle.distributed.fleet.DistributedStrategy()
        strategy.auto = True

        optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.01)
        optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
        optimizer.minimize(avg_cost)


if __name__ == "__main__":
    unittest.main()
```