Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
0a1862d1
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
0a1862d1
编写于
10月 12, 2020
作者:
W
WangXi
提交者:
GitHub
10月 12, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fleet combine amp dgc recompute meta optimizer (#27643)
上级
8fabb1c3
变更
15
隐藏空白更改
内联
并排
Showing
15 changed file
with
644 addition
and
255 deletion
+644
-255
python/paddle/distributed/fleet/base/distributed_strategy.py
python/paddle/distributed/fleet/base/distributed_strategy.py
+2
-2
python/paddle/distributed/fleet/meta_optimizers/amp_optimizer.py
...paddle/distributed/fleet/meta_optimizers/amp_optimizer.py
+39
-18
python/paddle/distributed/fleet/meta_optimizers/dgc_optimizer.py
...paddle/distributed/fleet/meta_optimizers/dgc_optimizer.py
+7
-0
python/paddle/distributed/fleet/meta_optimizers/lamb_optimizer.py
...addle/distributed/fleet/meta_optimizers/lamb_optimizer.py
+4
-0
python/paddle/distributed/fleet/meta_optimizers/lars_optimizer.py
...addle/distributed/fleet/meta_optimizers/lars_optimizer.py
+4
-0
python/paddle/distributed/fleet/meta_optimizers/localsgd_optimizer.py
...e/distributed/fleet/meta_optimizers/localsgd_optimizer.py
+2
-2
python/paddle/distributed/fleet/meta_optimizers/recompute_optimizer.py
.../distributed/fleet/meta_optimizers/recompute_optimizer.py
+21
-5
python/paddle/fluid/contrib/mixed_precision/decorator.py
python/paddle/fluid/contrib/mixed_precision/decorator.py
+60
-40
python/paddle/fluid/optimizer.py
python/paddle/fluid/optimizer.py
+5
-10
python/paddle/fluid/tests/unittests/fleet_meta_optimizer_base.py
...paddle/fluid/tests/unittests/fleet_meta_optimizer_base.py
+122
-0
python/paddle/fluid/tests/unittests/test_dgc_optimizer.py
python/paddle/fluid/tests/unittests/test_dgc_optimizer.py
+13
-4
python/paddle/fluid/tests/unittests/test_fleet_amp_meta_optimizer.py
...le/fluid/tests/unittests/test_fleet_amp_meta_optimizer.py
+75
-35
python/paddle/fluid/tests/unittests/test_fleet_dgc_meta_optimizer.py
...le/fluid/tests/unittests/test_fleet_dgc_meta_optimizer.py
+83
-52
python/paddle/fluid/tests/unittests/test_fleet_localsgd_meta_optimizer.py
...uid/tests/unittests/test_fleet_localsgd_meta_optimizer.py
+74
-58
python/paddle/fluid/tests/unittests/test_fleet_recompute_meta_optimizer.py
...id/tests/unittests/test_fleet_recompute_meta_optimizer.py
+133
-29
未找到文件。
python/paddle/distributed/fleet/base/distributed_strategy.py
浏览文件 @
0a1862d1
...
@@ -744,13 +744,13 @@ class DistributedStrategy(object):
...
@@ -744,13 +744,13 @@ class DistributedStrategy(object):
strategy.adaptive_localsgd = True # by default this is false
strategy.adaptive_localsgd = True # by default this is false
"""
"""
return
self
.
strategy
.
localsgd
return
self
.
strategy
.
adaptive_
localsgd
@
adaptive_localsgd
.
setter
@
adaptive_localsgd
.
setter
@
is_strict_auto
@
is_strict_auto
def
adaptive_localsgd
(
self
,
flag
):
def
adaptive_localsgd
(
self
,
flag
):
if
isinstance
(
flag
,
bool
):
if
isinstance
(
flag
,
bool
):
self
.
strategy
.
localsgd
=
flag
self
.
strategy
.
adaptive_
localsgd
=
flag
else
:
else
:
print
(
"WARNING: adaptive_localsgd should have value of bool type"
)
print
(
"WARNING: adaptive_localsgd should have value of bool type"
)
...
...
python/paddle/distributed/fleet/meta_optimizers/amp_optimizer.py
浏览文件 @
0a1862d1
...
@@ -19,16 +19,14 @@ class AMPOptimizer(MetaOptimizerBase):
...
@@ -19,16 +19,14 @@ class AMPOptimizer(MetaOptimizerBase):
def
__init__
(
self
,
optimizer
):
def
__init__
(
self
,
optimizer
):
super
(
AMPOptimizer
,
self
).
__init__
(
optimizer
)
super
(
AMPOptimizer
,
self
).
__init__
(
optimizer
)
self
.
inner_opt
=
optimizer
self
.
inner_opt
=
optimizer
self
.
amp
_opt
=
None
self
.
wrapped
_opt
=
None
# we do not allow meta optimizer to be inner optimizer currently
# we do not allow meta optimizer to be inner optimizer currently
self
.
meta_optimizers_white_list
=
[
self
.
meta_optimizers_white_list
=
[
"LarsOptimizer"
,
"LarsOptimizer"
,
"LambOptimizer"
,
"LambOptimizer"
,
"RecomputeOptimizer"
,
"RecomputeOptimizer"
,
"LocalSGDOptimizer"
,
"GradientMergeOptimizer"
,
"GradientMergeOptimizer"
,
"GraphExecutionOptimizer"
,
"GraphExecutionOptimizer"
,
"AdaptiveLocalSGDOptimizer"
,
]
]
self
.
meta_optimizers_black_list
=
[
"DGCOptimizer"
]
self
.
meta_optimizers_black_list
=
[
"DGCOptimizer"
]
...
@@ -37,6 +35,24 @@ class AMPOptimizer(MetaOptimizerBase):
...
@@ -37,6 +35,24 @@ class AMPOptimizer(MetaOptimizerBase):
super
(
AMPOptimizer
,
self
).
_set_basic_info
(
super
(
AMPOptimizer
,
self
).
_set_basic_info
(
loss
,
role_maker
,
user_defined_optimizer
,
user_defined_strategy
)
loss
,
role_maker
,
user_defined_optimizer
,
user_defined_strategy
)
def
_init_wrapped_opt
(
self
):
if
self
.
wrapped_opt
is
not
None
:
return
config
=
self
.
user_defined_strategy
.
amp_configs
custom_white_list
=
set
(
config
[
'custom_white_list'
])
custom_black_list
=
set
(
config
[
'custom_black_list'
])
custom_black_varnames
=
set
(
config
[
'custom_black_varnames'
])
amp_lists
=
mixed_precision
.
AutoMixedPrecisionLists
(
custom_white_list
,
custom_black_list
,
custom_black_varnames
)
self
.
wrapped_opt
=
mixed_precision
.
decorate
(
self
.
inner_opt
,
amp_lists
,
config
[
'init_loss_scaling'
],
config
[
'incr_every_n_steps'
],
config
[
'decr_every_n_nan_or_inf'
],
config
[
'incr_ratio'
],
config
[
'decr_ratio'
],
config
[
'use_dynamic_loss_scaling'
])
def
_can_apply
(
self
):
def
_can_apply
(
self
):
if
not
self
.
role_maker
.
_is_collective
:
if
not
self
.
role_maker
.
_is_collective
:
return
False
return
False
...
@@ -60,26 +76,31 @@ class AMPOptimizer(MetaOptimizerBase):
...
@@ -60,26 +76,31 @@ class AMPOptimizer(MetaOptimizerBase):
"use_dynamic_loss_scaling"
:
True
"use_dynamic_loss_scaling"
:
True
}
}
def
backward
(
self
,
loss
,
startup_program
=
None
,
parameter_list
=
None
,
no_grad_set
=
None
,
callbacks
=
None
):
# maybe inner_opt of other meta optimizer
self
.
_init_wrapped_opt
()
return
self
.
wrapped_opt
.
backward
(
loss
,
startup_program
,
parameter_list
,
no_grad_set
,
callbacks
)
def
apply_gradients
(
self
,
params_grads
):
return
self
.
wrapped_opt
.
apply_gradients
(
params_grads
=
params_grads
)
def
apply_optimize
(
self
,
loss
,
startup_program
,
params_grads
):
return
self
.
wrapped_opt
.
apply_optimize
(
loss
,
startup_program
=
startup_program
,
params_grads
=
params_grads
)
def
minimize_impl
(
self
,
def
minimize_impl
(
self
,
loss
,
loss
,
startup_program
=
None
,
startup_program
=
None
,
parameter_list
=
None
,
parameter_list
=
None
,
no_grad_set
=
None
):
no_grad_set
=
None
):
if
self
.
amp_opt
is
None
:
self
.
_init_wrapped_opt
()
config
=
self
.
user_defined_strategy
.
amp_configs
custom_white_list
=
set
(
config
[
'custom_white_list'
])
custom_black_list
=
set
(
config
[
'custom_black_list'
])
custom_black_varnames
=
set
(
config
[
'custom_black_varnames'
])
amp_lists
=
mixed_precision
.
AutoMixedPrecisionLists
(
custom_white_list
,
custom_black_list
,
custom_black_varnames
)
self
.
amp_opt
=
mixed_precision
.
decorate
(
self
.
inner_opt
,
amp_lists
,
config
[
'init_loss_scaling'
],
config
[
'incr_every_n_steps'
],
config
[
'decr_every_n_nan_or_inf'
],
config
[
'incr_ratio'
],
config
[
'decr_ratio'
],
config
[
'use_dynamic_loss_scaling'
])
optimize_ops
,
params_grads
=
\
optimize_ops
,
params_grads
=
\
self
.
amp
_opt
.
minimize
(
loss
,
startup_program
,
self
.
wrapped
_opt
.
minimize
(
loss
,
startup_program
,
parameter_list
,
no_grad_set
)
parameter_list
,
no_grad_set
)
return
optimize_ops
,
params_grads
return
optimize_ops
,
params_grads
python/paddle/distributed/fleet/meta_optimizers/dgc_optimizer.py
浏览文件 @
0a1862d1
...
@@ -85,6 +85,13 @@ class DGCOptimizer(MetaOptimizerBase):
...
@@ -85,6 +85,13 @@ class DGCOptimizer(MetaOptimizerBase):
return
self
.
dgc_opt
.
backward
(
loss
,
startup_program
,
parameter_list
,
return
self
.
dgc_opt
.
backward
(
loss
,
startup_program
,
parameter_list
,
no_grad_set
,
callbacks
)
no_grad_set
,
callbacks
)
def
apply_gradients
(
self
,
params_grads
):
return
self
.
dgc_opt
.
apply_gradients
(
params_grads
=
params_grads
)
def
apply_optimize
(
self
,
loss
,
startup_program
,
params_grads
):
return
self
.
dgc_opt
.
apply_optimize
(
loss
,
startup_program
=
startup_program
,
params_grads
=
params_grads
)
def
minimize_impl
(
self
,
def
minimize_impl
(
self
,
loss
,
loss
,
startup_program
=
None
,
startup_program
=
None
,
...
...
python/paddle/distributed/fleet/meta_optimizers/lamb_optimizer.py
浏览文件 @
0a1862d1
...
@@ -98,6 +98,10 @@ class LambOptimizer(MetaOptimizerBase):
...
@@ -98,6 +98,10 @@ class LambOptimizer(MetaOptimizerBase):
def
apply_gradients
(
self
,
params_grads
):
def
apply_gradients
(
self
,
params_grads
):
return
self
.
lamb_opt
.
apply_gradients
(
params_grads
=
params_grads
)
return
self
.
lamb_opt
.
apply_gradients
(
params_grads
=
params_grads
)
def
apply_optimize
(
self
,
loss
,
startup_program
,
params_grads
):
return
self
.
lamb_opt
.
apply_optimize
(
loss
,
startup_program
=
startup_program
,
params_grads
=
params_grads
)
def
minimize_impl
(
self
,
def
minimize_impl
(
self
,
loss
,
loss
,
startup_program
=
None
,
startup_program
=
None
,
...
...
python/paddle/distributed/fleet/meta_optimizers/lars_optimizer.py
浏览文件 @
0a1862d1
...
@@ -85,6 +85,10 @@ class LarsOptimizer(MetaOptimizerBase):
...
@@ -85,6 +85,10 @@ class LarsOptimizer(MetaOptimizerBase):
def
apply_gradients
(
self
,
params_grads
):
def
apply_gradients
(
self
,
params_grads
):
return
self
.
lars_opt
.
apply_gradients
(
params_grads
=
params_grads
)
return
self
.
lars_opt
.
apply_gradients
(
params_grads
=
params_grads
)
def
apply_optimize
(
self
,
loss
,
startup_program
,
params_grads
):
return
self
.
lars_opt
.
apply_optimize
(
loss
,
startup_program
=
startup_program
,
params_grads
=
params_grads
)
def
minimize_impl
(
self
,
def
minimize_impl
(
self
,
loss
,
loss
,
startup_program
=
None
,
startup_program
=
None
,
...
...
python/paddle/distributed/fleet/meta_optimizers/localsgd_optimizer.py
浏览文件 @
0a1862d1
...
@@ -24,7 +24,7 @@ class LocalSGDOptimizer(MetaOptimizerBase):
...
@@ -24,7 +24,7 @@ class LocalSGDOptimizer(MetaOptimizerBase):
def
__init__
(
self
,
optimizer
):
def
__init__
(
self
,
optimizer
):
super
(
LocalSGDOptimizer
,
self
).
__init__
(
optimizer
)
super
(
LocalSGDOptimizer
,
self
).
__init__
(
optimizer
)
self
.
inner_opt
=
optimizer
self
.
inner_opt
=
optimizer
self
.
meta_optimizers_white_list
=
[]
self
.
meta_optimizers_white_list
=
[
'AMPOptimizer'
]
self
.
meta_optimizers_black_list
=
[
self
.
meta_optimizers_black_list
=
[
"GraphExecutionOptimizer"
,
"GraphExecutionOptimizer"
,
"AdaptiveLocalSGDOptimizer"
,
"AdaptiveLocalSGDOptimizer"
,
...
@@ -195,7 +195,7 @@ class AdaptiveLocalSGDOptimizer(MetaOptimizerBase):
...
@@ -195,7 +195,7 @@ class AdaptiveLocalSGDOptimizer(MetaOptimizerBase):
def
__init__
(
self
,
optimizer
):
def
__init__
(
self
,
optimizer
):
super
(
AdaptiveLocalSGDOptimizer
,
self
).
__init__
(
optimizer
)
super
(
AdaptiveLocalSGDOptimizer
,
self
).
__init__
(
optimizer
)
self
.
inner_opt
=
optimizer
self
.
inner_opt
=
optimizer
self
.
meta_optimizers_white_list
=
[]
self
.
meta_optimizers_white_list
=
[
'AMPOptimizer'
]
self
.
meta_optimizers_black_list
=
[
self
.
meta_optimizers_black_list
=
[
"GraphExecutionOptimizer"
,
"LocalSGDOptimizer"
"GraphExecutionOptimizer"
,
"LocalSGDOptimizer"
]
]
...
...
python/paddle/distributed/fleet/meta_optimizers/recompute_optimizer.py
浏览文件 @
0a1862d1
...
@@ -18,15 +18,14 @@ from .meta_optimizer_base import MetaOptimizerBase
...
@@ -18,15 +18,14 @@ from .meta_optimizer_base import MetaOptimizerBase
class
RecomputeOptimizer
(
MetaOptimizerBase
):
class
RecomputeOptimizer
(
MetaOptimizerBase
):
def
__init__
(
self
,
optimizer
):
def
__init__
(
self
,
optimizer
):
super
(
RecomputeOptimizer
,
self
).
__init__
(
optimizer
)
super
(
RecomputeOptimizer
,
self
).
__init__
(
optimizer
)
#self.inner_opt = RO(optimizer)
self
.
inner_opt
=
optimizer
self
.
inner_opt
=
optimizer
self
.
wrapped_opt
=
RO
(
optimizer
)
self
.
wrapped_opt
=
None
# we do not allow meta optimizer to be inner optimizer currently
# we do not allow meta optimizer to be inner optimizer currently
self
.
meta_optimizers_white_list
=
[
self
.
meta_optimizers_white_list
=
[
"LarsOptimizer"
,
"LarsOptimizer"
,
"LambOptimizer"
,
"LambOptimizer"
,
"GradientMergeOptimizer"
,
"GraphExecutionOptimizer"
,
"GraphExecutionOptimizer"
,
"DGCOptimizer"
,
]
]
self
.
meta_optimizers_black_list
=
[]
self
.
meta_optimizers_black_list
=
[]
...
@@ -34,8 +33,15 @@ class RecomputeOptimizer(MetaOptimizerBase):
...
@@ -34,8 +33,15 @@ class RecomputeOptimizer(MetaOptimizerBase):
user_defined_strategy
):
user_defined_strategy
):
super
(
RecomputeOptimizer
,
self
).
_set_basic_info
(
super
(
RecomputeOptimizer
,
self
).
_set_basic_info
(
loss
,
role_maker
,
user_defined_optimizer
,
user_defined_strategy
)
loss
,
role_maker
,
user_defined_optimizer
,
user_defined_strategy
)
self
.
wrapped_opt
.
_set_checkpoints
(
list
(
user_defined_strategy
.
recompute_configs
[
"checkpoints"
]))
def
_init_wrapped_opt
(
self
):
if
self
.
wrapped_opt
is
not
None
:
return
configs
=
self
.
user_defined_strategy
.
recompute_configs
self
.
wrapped_opt
=
RO
(
self
.
inner_opt
)
self
.
wrapped_opt
.
_set_checkpoints
(
list
(
configs
[
"checkpoints"
]))
def
_can_apply
(
self
):
def
_can_apply
(
self
):
if
not
self
.
role_maker
.
_is_collective
:
if
not
self
.
role_maker
.
_is_collective
:
...
@@ -62,14 +68,24 @@ class RecomputeOptimizer(MetaOptimizerBase):
...
@@ -62,14 +68,24 @@ class RecomputeOptimizer(MetaOptimizerBase):
parameter_list
=
None
,
parameter_list
=
None
,
no_grad_set
=
None
,
no_grad_set
=
None
,
callbacks
=
None
):
callbacks
=
None
):
# maybe inner_opt of other meta optimizer
self
.
_init_wrapped_opt
()
return
self
.
wrapped_opt
.
backward
(
loss
,
startup_program
,
parameter_list
,
return
self
.
wrapped_opt
.
backward
(
loss
,
startup_program
,
parameter_list
,
no_grad_set
,
callbacks
)
no_grad_set
,
callbacks
)
def
apply_gradients
(
self
,
params_grads
):
return
self
.
wrapped_opt
.
apply_gradients
(
params_grads
=
params_grads
)
def
apply_optimize
(
self
,
loss
,
startup_program
,
params_grads
):
return
self
.
wrapped_opt
.
apply_optimize
(
loss
,
startup_program
=
startup_program
,
params_grads
=
params_grads
)
def
minimize_impl
(
self
,
def
minimize_impl
(
self
,
loss
,
loss
,
startup_program
=
None
,
startup_program
=
None
,
parameter_list
=
None
,
parameter_list
=
None
,
no_grad_set
=
None
):
no_grad_set
=
None
):
self
.
_init_wrapped_opt
()
optimize_ops
,
params_grads
=
\
optimize_ops
,
params_grads
=
\
self
.
wrapped_opt
.
minimize
(
loss
,
startup_program
,
self
.
wrapped_opt
.
minimize
(
loss
,
startup_program
,
parameter_list
,
no_grad_set
)
parameter_list
,
no_grad_set
)
...
...
python/paddle/fluid/contrib/mixed_precision/decorator.py
浏览文件 @
0a1862d1
...
@@ -16,6 +16,7 @@ from ... import default_main_program
...
@@ -16,6 +16,7 @@ from ... import default_main_program
from
...
import
default_startup_program
from
...
import
default_startup_program
from
...
import
layers
from
...
import
layers
from
...
import
unique_name
from
...
import
unique_name
from
...
import
program_guard
from
.
import
fp16_utils
from
.
import
fp16_utils
from
.fp16_utils
import
rewrite_program
from
.fp16_utils
import
rewrite_program
from
.fp16_utils
import
update_role_var_grad
from
.fp16_utils
import
update_role_var_grad
...
@@ -58,21 +59,40 @@ class OptimizerWithMixedPrecision(object):
...
@@ -58,21 +59,40 @@ class OptimizerWithMixedPrecision(object):
self
.
_optimizer
=
optimizer
self
.
_optimizer
=
optimizer
self
.
_amp_lists
=
amp_lists
self
.
_amp_lists
=
amp_lists
self
.
_param_grads
=
None
self
.
_param_grads
=
None
self
.
_train_program
=
default_main_program
()
self
.
_train_program
=
None
self
.
_startup_prog
=
default_startup_program
()
self
.
_scaled_loss
=
None
self
.
_scaled_loss
=
None
self
.
_loss_scaling
=
layers
.
create_global_var
(
self
.
_loss_scaling
=
None
name
=
unique_name
.
generate
(
"loss_scaling"
),
self
.
_init_loss_scaling
=
init_loss_scaling
shape
=
[
1
],
value
=
init_loss_scaling
,
dtype
=
'float32'
,
persistable
=
True
)
self
.
_use_dynamic_loss_scaling
=
use_dynamic_loss_scaling
self
.
_use_dynamic_loss_scaling
=
use_dynamic_loss_scaling
if
self
.
_use_dynamic_loss_scaling
:
if
self
.
_use_dynamic_loss_scaling
:
self
.
_incr_every_n_steps
=
incr_every_n_steps
self
.
_incr_every_n_steps
=
incr_every_n_steps
self
.
_decr_every_n_nan_or_inf
=
decr_every_n_nan_or_inf
self
.
_decr_every_n_nan_or_inf
=
decr_every_n_nan_or_inf
self
.
_incr_ratio
=
incr_ratio
self
.
_incr_ratio
=
incr_ratio
self
.
_decr_ratio
=
decr_ratio
self
.
_decr_ratio
=
decr_ratio
self
.
_num_good_steps
=
None
self
.
_num_bad_steps
=
None
def
get_loss_scaling
(
self
):
"""Return the real-time loss scaling factor.
"""
return
self
.
_loss_scaling
def
get_scaled_loss
(
self
):
"""Return the scaled loss.
It's useful when you feed customed loss into executor.
"""
return
self
.
_scaled_loss
def
_init_amp_var
(
self
):
self
.
_loss_scaling
=
layers
.
create_global_var
(
name
=
unique_name
.
generate
(
"loss_scaling"
),
shape
=
[
1
],
value
=
self
.
_init_loss_scaling
,
dtype
=
'float32'
,
persistable
=
True
)
if
self
.
_use_dynamic_loss_scaling
:
self
.
_num_good_steps
=
layers
.
create_global_var
(
self
.
_num_good_steps
=
layers
.
create_global_var
(
name
=
unique_name
.
generate
(
"num_good_steps"
),
name
=
unique_name
.
generate
(
"num_good_steps"
),
shape
=
[
1
],
shape
=
[
1
],
...
@@ -86,28 +106,16 @@ class OptimizerWithMixedPrecision(object):
...
@@ -86,28 +106,16 @@ class OptimizerWithMixedPrecision(object):
dtype
=
'int32'
,
dtype
=
'int32'
,
persistable
=
True
)
persistable
=
True
)
# Ensure the data type of learning rate vars is float32 (same as the
# Ensure the data type of learning rate vars is float32 (same as the
# master parameter dtype)
# master parameter dtype)
if
isinstance
(
optimizer
.
_learning_rate
,
float
):
if
isinstance
(
self
.
_optimizer
.
_learning_rate
,
float
):
optimizer
.
_learning_rate_map
[
default_main_program
()]
=
\
self
.
_optimizer
.
_learning_rate_map
[
default_main_program
()]
=
\
layers
.
create_global_var
(
layers
.
create_global_var
(
name
=
unique_name
.
generate
(
"learning_rate"
),
name
=
unique_name
.
generate
(
"learning_rate"
),
shape
=
[
1
],
shape
=
[
1
],
value
=
float
(
optimizer
.
_learning_rate
),
value
=
float
(
self
.
_optimizer
.
_learning_rate
),
dtype
=
'float32'
,
dtype
=
'float32'
,
persistable
=
True
)
persistable
=
True
)
def
get_loss_scaling
(
self
):
"""Return the real-time loss scaling factor.
"""
return
self
.
_loss_scaling
def
get_scaled_loss
(
self
):
"""Return the scaled loss.
It's useful when you feed customed loss into executor.
"""
return
self
.
_scaled_loss
def
backward
(
self
,
def
backward
(
self
,
loss
,
loss
,
...
@@ -131,16 +139,21 @@ class OptimizerWithMixedPrecision(object):
...
@@ -131,16 +139,21 @@ class OptimizerWithMixedPrecision(object):
A list of (param, grad), which is a tuple of a parameter and its
A list of (param, grad), which is a tuple of a parameter and its
gradient respectively, and the scaled loss.
gradient respectively, and the scaled loss.
"""
"""
rewrite_program
(
self
.
_train_program
,
self
.
_amp_lists
)
train_program
=
loss
.
block
.
program
self
.
_scaled_loss
=
loss
*
self
.
_loss_scaling
self
.
_train_program
=
train_program
self
.
_params_grads
=
self
.
_optimizer
.
backward
(
self
.
_scaled_loss
,
startup_program
,
parameter_list
,
no_grad_set
,
with
program_guard
(
train_program
,
startup_program
):
callbacks
)
self
.
_init_amp_var
()
# Change the op_role_var attr for some ops, so that gradients
# transferred across GPUs can be FP16.
rewrite_program
(
train_program
,
self
.
_amp_lists
)
update_role_var_grad
(
self
.
_train_program
,
self
.
_params_grads
)
self
.
_scaled_loss
=
loss
*
self
.
_loss_scaling
params_grads
=
self
.
_optimizer
.
backward
(
return
self
.
_params_grads
self
.
_scaled_loss
,
startup_program
,
parameter_list
,
no_grad_set
,
callbacks
)
# Change the op_role_var attr for some ops, so that gradients
# transferred across GPUs can be FP16.
update_role_var_grad
(
train_program
,
params_grads
)
return
params_grads
def
apply_gradients
(
self
,
params_grads
):
def
apply_gradients
(
self
,
params_grads
):
"""
"""
...
@@ -182,6 +195,12 @@ class OptimizerWithMixedPrecision(object):
...
@@ -182,6 +195,12 @@ class OptimizerWithMixedPrecision(object):
return
optimize_ops
return
optimize_ops
def
apply_optimize
(
self
,
loss
,
startup_program
,
params_grads
):
program
=
loss
.
block
.
program
with
program_guard
(
program
,
startup_program
):
optimize_ops
=
self
.
apply_gradients
(
params_grads
)
return
optimize_ops
def
minimize
(
self
,
def
minimize
(
self
,
loss
,
loss
,
startup_program
=
None
,
startup_program
=
None
,
...
@@ -207,7 +226,8 @@ class OptimizerWithMixedPrecision(object):
...
@@ -207,7 +226,8 @@ class OptimizerWithMixedPrecision(object):
parameter_list
=
parameter_list
,
parameter_list
=
parameter_list
,
no_grad_set
=
no_grad_set
)
no_grad_set
=
no_grad_set
)
optimize_ops
=
self
.
apply_gradients
(
scaled_params_grads
)
optimize_ops
=
self
.
apply_optimize
(
loss
,
startup_program
,
scaled_params_grads
)
return
optimize_ops
,
scaled_params_grads
return
optimize_ops
,
scaled_params_grads
...
...
python/paddle/fluid/optimizer.py
浏览文件 @
0a1862d1
...
@@ -731,9 +731,6 @@ class Optimizer(object):
...
@@ -731,9 +731,6 @@ class Optimizer(object):
outputs
=
{
"ParamOut"
:
param_and_grad
[
0
]})
outputs
=
{
"ParamOut"
:
param_and_grad
[
0
]})
return
new_param_grads
,
(
table_param
,
table_grad
),
sgd_op
return
new_param_grads
,
(
table_param
,
table_grad
),
sgd_op
def
_append_dgc_ops
(
self
,
param_and_grad
):
pass
def
backward
(
self
,
def
backward
(
self
,
loss
,
loss
,
startup_program
=
None
,
startup_program
=
None
,
...
@@ -801,9 +798,6 @@ class Optimizer(object):
...
@@ -801,9 +798,6 @@ class Optimizer(object):
with
program_guard
(
program
,
startup_program
):
with
program_guard
(
program
,
startup_program
):
params_grads
=
append_backward
(
loss
,
parameter_list
,
params_grads
=
append_backward
(
loss
,
parameter_list
,
act_no_grad_set
,
callbacks
)
act_no_grad_set
,
callbacks
)
# Note: since we can't use all_reduce_op now,
# dgc_op should be the last op of one grad.
self
.
_append_dgc_ops
(
params_grads
)
return
params_grads
return
params_grads
def
apply_gradients
(
self
,
params_grads
):
def
apply_gradients
(
self
,
params_grads
):
...
@@ -1569,6 +1563,11 @@ class DGCMomentumOptimizer(Optimizer):
...
@@ -1569,6 +1563,11 @@ class DGCMomentumOptimizer(Optimizer):
@
imperative_base
.
no_grad
@
imperative_base
.
no_grad
def
apply_gradients
(
self
,
params_grads
):
def
apply_gradients
(
self
,
params_grads
):
# Note: since we can't use all_reduce_op now,
# dgc_op should be the last op of one grad.
# Maybe need a grad allreduce pass.
self
.
_append_dgc_ops
(
params_grads
)
params_grads
=
sorted
(
params_grads
,
key
=
lambda
x
:
x
[
0
].
name
)
params_grads
=
sorted
(
params_grads
,
key
=
lambda
x
:
x
[
0
].
name
)
params_grads
,
table_param_and_grad
,
table_optimize_op
=
\
params_grads
,
table_param_and_grad
,
table_optimize_op
=
\
self
.
_process_distribute_lookuptable
(
params_grads
)
self
.
_process_distribute_lookuptable
(
params_grads
)
...
@@ -4784,10 +4783,6 @@ class RecomputeOptimizer(Optimizer):
...
@@ -4784,10 +4783,6 @@ class RecomputeOptimizer(Optimizer):
params_grads
=
append_backward
(
params_grads
=
append_backward
(
loss
,
parameter_list
,
no_grad_set
,
checkpoints
=
checkpoint_vars
)
loss
,
parameter_list
,
no_grad_set
,
checkpoints
=
checkpoint_vars
)
# Note: since we can't use all_reduce_op now,
# dgc_op should be the last op of one grad.
if
hasattr
(
self
.
_optimizer
,
"_append_dgc_ops"
):
self
.
_optimizer
.
_append_dgc_ops
(
params_grads
)
return
params_grads
return
params_grads
def
apply_optimize
(
self
,
loss
,
startup_program
,
params_grads
):
def
apply_optimize
(
self
,
loss
,
startup_program
,
params_grads
):
...
...
python/paddle/fluid/tests/unittests/fleet_meta_optimizer_base.py
0 → 100755
浏览文件 @
0a1862d1
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
paddle
from
paddle
import
fluid
import
os
import
paddle.distributed.fleet
as
fleet
import
paddle.distributed.fleet.base.role_maker
as
role_maker
class
TestFleetMetaOptimizer
(
unittest
.
TestCase
):
def
setUp
(
self
):
os
.
environ
[
"PADDLE_TRAINER_ID"
]
=
"1"
os
.
environ
[
"PADDLE_TRAINER_ENDPOINTS"
]
=
"127.0.0.1:36001,127.0.0.1:36002"
def
net
(
self
,
main_prog
,
startup_prog
):
with
fluid
.
program_guard
(
main_prog
,
startup_prog
):
with
fluid
.
unique_name
.
guard
():
role
=
role_maker
.
PaddleCloudRoleMaker
(
is_collective
=
True
)
fleet
.
init
(
role
)
input_x
=
paddle
.
fluid
.
layers
.
data
(
name
=
"x"
,
shape
=
[
32
],
dtype
=
'float32'
)
input_y
=
paddle
.
fluid
.
layers
.
data
(
name
=
"y"
,
shape
=
[
1
],
dtype
=
'int64'
)
fc_1
=
paddle
.
fluid
.
layers
.
fc
(
input
=
input_x
,
size
=
64
,
act
=
'tanh'
)
fc_2
=
paddle
.
fluid
.
layers
.
fc
(
input
=
fc_1
,
size
=
256
,
act
=
'tanh'
)
prediction
=
paddle
.
fluid
.
layers
.
fc
(
input
=
[
fc_2
],
size
=
2
,
act
=
'softmax'
)
cost
=
paddle
.
fluid
.
layers
.
cross_entropy
(
input
=
prediction
,
label
=
input_y
)
avg_cost
=
paddle
.
fluid
.
layers
.
mean
(
x
=
cost
)
strategy
=
paddle
.
distributed
.
fleet
.
DistributedStrategy
()
return
avg_cost
,
strategy
def
optimizer
(
self
,
loss
,
strategy
,
train_prog
,
startup_prog
,
name
=
'momentum'
):
with
fluid
.
program_guard
(
train_prog
,
startup_prog
):
with
fluid
.
unique_name
.
guard
():
if
name
==
'momentum'
:
optimizer
=
paddle
.
fluid
.
optimizer
.
Momentum
(
learning_rate
=
0.01
,
momentum
=
0.9
)
elif
name
==
'adam'
:
optimizer
=
paddle
.
fluid
.
optimizer
.
Adam
(
learning_rate
=
0.01
)
optimizer
=
fleet
.
distributed_optimizer
(
optimizer
,
strategy
=
strategy
)
optimizer
.
minimize
(
loss
)
def
set_strategy
(
self
,
strategy
,
name
):
if
name
==
'amp'
:
strategy
.
amp
=
True
strategy
.
amp_configs
=
{
"init_loss_scaling"
:
32768
,
"decr_every_n_nan_or_inf"
:
2
,
"incr_every_n_steps"
:
1000
,
"incr_ratio"
:
2.0
,
"use_dynamic_loss_scaling"
:
True
,
"decr_ratio"
:
0.5
,
"custom_white_list"
:
[
'softmax'
],
"custom_black_list"
:
[
'tanh'
],
}
elif
name
==
'dgc'
:
strategy
.
dgc
=
True
strategy
.
dgc_configs
=
{
"rampup_begin_step"
:
128
,
"rampup_step"
:
100
,
"sparsity"
:
[
0.996
,
0.999
]
}
elif
name
==
'recompute'
:
strategy
.
recompute
=
True
strategy
.
recompute_configs
=
{
"checkpoints"
:
[
"fc_0.tmp_2"
,
"fc_1.tmp_2"
]
}
elif
name
==
'lars'
:
strategy
.
lars
=
True
strategy
.
lars_configs
=
{
"lars_coeff"
:
0.001
,
"lars_weight_decay"
:
0.0005
,
"epsilon"
:
0
,
"exclude_from_weight_decay"
:
[
"batch_norm"
,
".b"
],
}
elif
name
==
'lamb'
:
strategy
.
lamb
=
True
strategy
.
lamb_configs
=
{
'lamb_weight_decay'
:
0.01
,
'exclude_from_weight_decay'
:
[],
}
elif
name
==
'localsgd'
:
strategy
.
localsgd
=
True
strategy
.
localsgd_configs
=
{
'k_steps'
:
1
,
'begin_step'
:
1
,
}
elif
name
==
'adaptive_localsgd'
:
strategy
.
adaptive_localsgd
=
True
strategy
.
adaptive_localsgd_configs
=
{
'init_k_steps'
:
1
,
'begin_step'
:
1
,
}
else
:
raise
NotImplementedError
()
python/paddle/fluid/tests/unittests/test_dgc_optimizer.py
浏览文件 @
0a1862d1
...
@@ -16,12 +16,14 @@ from __future__ import print_function
...
@@ -16,12 +16,14 @@ from __future__ import print_function
import
unittest
import
unittest
import
paddle
import
paddle.fluid.framework
as
framework
import
paddle.fluid.framework
as
framework
import
paddle.fluid.optimizer
as
optimizer
import
paddle.fluid.optimizer
as
optimizer
import
paddle.fluid.regularizer
as
regularizer
import
paddle.fluid.regularizer
as
regularizer
import
paddle.fluid.clip
as
clip
import
paddle.fluid.clip
as
clip
import
paddle.compat
as
cpt
import
paddle.compat
as
cpt
from
paddle.fluid.backward
import
append_backward
from
paddle.fluid.backward
import
append_backward
paddle
.
enable_static
()
class
TestDGCMomentumOptimizer
(
unittest
.
TestCase
):
class
TestDGCMomentumOptimizer
(
unittest
.
TestCase
):
...
@@ -86,13 +88,17 @@ class TestDGCMomentumOptimizer(unittest.TestCase):
...
@@ -86,13 +88,17 @@ class TestDGCMomentumOptimizer(unittest.TestCase):
block
.
append_op
(
block
.
append_op
(
type
=
"mean"
,
inputs
=
{
"X"
:
mul_out
},
outputs
=
{
"Out"
:
mean_out
})
type
=
"mean"
,
inputs
=
{
"X"
:
mul_out
},
outputs
=
{
"Out"
:
mean_out
})
# params_grads = append_backward(mean_out)
# params_grads = append_backward(mean_out)
params_grads
=
dgc_momentum_optimizer
.
backward
(
mean_out
)
params_grads
=
dgc_momentum_optimizer
.
backward
(
mean_out
,
startup_program
=
init_program
)
with
framework
.
program_guard
(
program
,
init_program
):
opts
=
dgc_momentum_optimizer
.
apply_gradients
(
params_grads
)
accumulator_count
=
1
if
name
==
"momentum"
else
2
accumulator_count
=
1
if
name
==
"momentum"
else
2
self
.
assertEqual
(
len
(
params_grads
),
1
)
self
.
assertEqual
(
len
(
params_grads
),
1
)
self
.
assertEqual
(
self
.
assertEqual
(
len
(
dgc_momentum_optimizer
.
get_accumulators
()),
accumulator_count
)
len
(
dgc_momentum_optimizer
.
get_accumulators
()),
accumulator_count
)
with
framework
.
program_guard
(
program
,
init_program
):
opts
=
dgc_momentum_optimizer
.
apply_gradients
(
params_grads
)
self
.
assertEqual
(
len
(
opts
),
2
)
self
.
assertEqual
(
len
(
opts
),
2
)
sgd_op
=
opts
[
-
1
]
sgd_op
=
opts
[
-
1
]
self
.
assertEqual
([
op
.
type
for
op
in
opts
],
[
"scale"
,
name
])
self
.
assertEqual
([
op
.
type
for
op
in
opts
],
[
"scale"
,
name
])
...
@@ -108,8 +114,11 @@ class TestDGCMomentumOptimizer(unittest.TestCase):
...
@@ -108,8 +114,11 @@ class TestDGCMomentumOptimizer(unittest.TestCase):
self
.
assertTrue
(
mul_x
.
name
in
velocity_acc
)
self
.
assertTrue
(
mul_x
.
name
in
velocity_acc
)
# Check init_program
# Check init_program
# dgc not apply include: lr, dgc(count, nranks, begin step), (u,)
# dgc apply include: lr, dgc(count, nranks, begin_step), (u,v,k,encode,gather)
init_ops_count
=
5
if
name
==
"momentum"
else
9
init_ops
=
init_program
.
global_block
().
ops
init_ops
=
init_program
.
global_block
().
ops
self
.
assertEqual
(
len
(
init_ops
),
1
)
self
.
assertEqual
(
len
(
init_ops
),
init_ops_count
)
self
.
assertEqual
(
init_ops
[
0
].
type
,
"fill_constant"
)
self
.
assertEqual
(
init_ops
[
0
].
type
,
"fill_constant"
)
self
.
assertAlmostEqual
(
init_ops
[
0
].
attr
(
'value'
),
learning_rate
)
self
.
assertAlmostEqual
(
init_ops
[
0
].
attr
(
'value'
),
learning_rate
)
...
...
python/paddle/fluid/tests/unittests/test_fleet_amp_meta_optimizer.py
浏览文件 @
0a1862d1
...
@@ -12,57 +12,97 @@
...
@@ -12,57 +12,97 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
paddle.distributed.fleet
as
fleet
import
paddle.distributed.fleet.base.role_maker
as
role_maker
import
unittest
import
unittest
import
paddle
import
paddle
import
paddle.fluid
as
fluid
import
paddle.distributed.fleet
as
fleet
from
paddle.distributed.fleet.meta_optimizers
import
AMPOptimizer
import
os
import
os
from
fleet_meta_optimizer_base
import
TestFleetMetaOptimizer
paddle
.
enable_static
()
paddle
.
enable_static
()
class
TestFleetAMPOptimizer
(
unittest
.
TestCase
):
class
TestFleetAMPOptimizer
(
TestFleetMetaOptimizer
):
def
setUp
(
self
):
def
test_amp_optimizer_backward
(
self
):
os
.
environ
[
"PADDLE_TRAINER_ID"
]
=
"0"
""" test amp optimizer backward """
os
.
environ
[
"PADDLE_TRAINER_ENDPOINTS"
]
=
"127.0.0.1:36001"
train_prog
,
startup_prog
=
fluid
.
Program
(),
fluid
.
Program
()
avg_cost
,
strategy
=
self
.
net
(
train_prog
,
startup_prog
)
opt
=
fluid
.
optimizer
.
MomentumOptimizer
(
learning_rate
=
0.001
,
momentum
=
0.9
)
opt
=
AMPOptimizer
(
opt
)
opt
.
user_defined_strategy
=
strategy
params_grads
=
opt
.
backward
(
avg_cost
,
startup_prog
)
ops
=
[
op
.
type
for
op
in
avg_cost
.
block
.
ops
]
self
.
assertIn
(
'cast'
,
ops
)
self
.
assertNotIn
(
'check_finite_and_unscale'
,
ops
)
def
test_amp_optimizer_backward_gradients
(
self
):
""" test amp optimizer backward + gradients"""
train_prog
,
startup_prog
=
fluid
.
Program
(),
fluid
.
Program
()
avg_cost
,
strategy
=
self
.
net
(
train_prog
,
startup_prog
)
opt
=
fluid
.
optimizer
.
MomentumOptimizer
(
learning_rate
=
0.001
,
momentum
=
0.9
)
opt
=
AMPOptimizer
(
opt
)
opt
.
user_defined_strategy
=
strategy
params_grads
=
opt
.
backward
(
avg_cost
,
startup_prog
)
with
fluid
.
program_guard
(
train_prog
,
startup_prog
):
opt
.
apply_gradients
(
params_grads
)
ops
=
[
op
.
type
for
op
in
avg_cost
.
block
.
ops
]
self
.
assertIn
(
'cast'
,
ops
)
self
.
assertIn
(
'check_finite_and_unscale'
,
ops
)
def
test_amp_optimizer_backward_optimize
(
self
):
""" test amp optimizer backward + optimizer """
train_prog
,
startup_prog
=
fluid
.
Program
(),
fluid
.
Program
()
avg_cost
,
strategy
=
self
.
net
(
train_prog
,
startup_prog
)
opt
=
fluid
.
optimizer
.
MomentumOptimizer
(
learning_rate
=
0.001
,
momentum
=
0.9
)
opt
=
AMPOptimizer
(
opt
)
opt
.
user_defined_strategy
=
strategy
params_grads
=
opt
.
backward
(
avg_cost
,
startup_prog
)
opt
.
apply_optimize
(
avg_cost
,
startup_prog
,
params_grads
)
ops
=
[
op
.
type
for
op
in
avg_cost
.
block
.
ops
]
self
.
assertIn
(
'cast'
,
ops
)
self
.
assertIn
(
'check_finite_and_unscale'
,
ops
)
def
test_amp_optimizer
(
self
):
def
test_amp_optimizer
(
self
):
role
=
role_maker
.
PaddleCloudRoleMaker
(
is_collective
=
True
)
""" test amp """
fleet
.
init
(
role
)
train_prog
,
startup_prog
=
fluid
.
Program
(),
fluid
.
Program
()
input_x
=
paddle
.
fluid
.
layers
.
data
(
avg_cost
,
strategy
=
self
.
net
(
train_prog
,
startup_prog
)
name
=
"x"
,
shape
=
[
32
],
dtype
=
'float32'
)
self
.
set_strategy
(
strategy
,
'amp'
)
input_y
=
paddle
.
fluid
.
layers
.
data
(
name
=
"y"
,
shape
=
[
1
],
dtype
=
'int64'
)
self
.
optimizer
(
avg_cost
,
strategy
,
train_prog
,
startup_prog
)
fc_1
=
paddle
.
fluid
.
layers
.
fc
(
input
=
input_x
,
size
=
64
,
act
=
'tanh'
)
ops
=
[
op
.
type
for
op
in
avg_cost
.
block
.
ops
]
fc_2
=
paddle
.
fluid
.
layers
.
fc
(
input
=
fc_1
,
size
=
64
,
act
=
'tanh'
)
self
.
assertIn
(
'cast'
,
ops
)
prediction
=
paddle
.
fluid
.
layers
.
fc
(
input
=
[
fc_2
],
size
=
2
,
act
=
'softmax'
)
self
.
assertIn
(
'check_finite_and_unscale'
,
ops
)
cost
=
paddle
.
fluid
.
layers
.
cross_entropy
(
input
=
prediction
,
label
=
input_y
)
def
test_amp_recompute_optimizer
(
self
):
avg_cost
=
paddle
.
fluid
.
layers
.
mean
(
x
=
cost
)
""" test amp + recompute """
train_prog
,
startup_prog
=
fluid
.
Program
(),
fluid
.
Program
()
strategy
=
paddle
.
distributed
.
fleet
.
DistributedStrategy
()
avg_cost
,
strategy
=
self
.
net
(
train_prog
,
startup_prog
)
strategy
.
amp
=
True
self
.
set_strategy
(
strategy
,
'amp'
)
strategy
.
amp_configs
=
{
self
.
set_strategy
(
strategy
,
'recompute'
)
"init_loss_scaling"
:
32768
,
self
.
optimizer
(
avg_cost
,
strategy
,
train_prog
,
startup_prog
)
"decr_every_n_nan_or_inf"
:
2
,
"incr_every_n_steps"
:
1000
,
"incr_ratio"
:
2.0
,
"use_dynamic_loss_scaling"
:
True
,
"decr_ratio"
:
0.5
,
"custom_white_list"
:
[
'softmax'
],
"custom_black_list"
:
[
'tanh'
],
}
optimizer
=
paddle
.
fluid
.
optimizer
.
SGD
(
learning_rate
=
0.01
)
optimizer
=
fleet
.
distributed_optimizer
(
optimizer
,
strategy
=
strategy
)
optimizer
.
minimize
(
avg_cost
)
strategy
=
fleet
.
_final_strategy
()
strategy
=
fleet
.
_final_strategy
()
ops
=
[
op
.
type
for
op
in
avg_cost
.
block
.
ops
]
ops
=
[
op
.
type
for
op
in
avg_cost
.
block
.
ops
]
outs
=
[
op
.
output
(
'Out'
)[
0
]
for
op
in
avg_cost
.
block
.
ops
if
op
.
type
==
'mul'
]
self
.
assertIn
(
'cast'
,
ops
)
self
.
assertIn
(
'cast'
,
ops
)
self
.
assertIn
(
'check_finite_and_unscale'
,
ops
)
self
.
assertIn
(
'check_finite_and_unscale'
,
ops
)
# recompute
self
.
assertIn
(
'subprog'
,
''
.
join
(
outs
))
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
unittest
.
main
()
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_fleet_dgc_meta_optimizer.py
浏览文件 @
0a1862d1
...
@@ -17,65 +17,82 @@ import paddle
...
@@ -17,65 +17,82 @@ import paddle
from
paddle
import
fluid
from
paddle
import
fluid
import
os
import
os
import
paddle.distributed.fleet
as
fleet
import
paddle.distributed.fleet
as
fleet
from
fleet_meta_optimizer_base
import
TestFleetMetaOptimizer
from
paddle.distributed.fleet.meta_optimizers
import
DGCOptimizer
import
paddle.distributed.fleet.base.role_maker
as
role_maker
import
paddle.distributed.fleet.base.role_maker
as
role_maker
paddle
.
enable_static
()
class
TestFleetDGCOptimizer
(
unittest
.
TestCase
):
def
setUp
(
self
):
class
TestFleetDGCOptimizer
(
TestFleetMetaOptimizer
):
os
.
environ
[
"PADDLE_TRAINER_ID"
]
=
"1"
def
test_dgc_optimizer_backward
(
self
):
os
.
environ
[
""" test dgc optimizer backward """
"PADDLE_TRAINER_ENDPOINTS"
]
=
"127.0.0.1:36001,127.0.0.1:36002"
train_prog
,
startup_prog
=
fluid
.
Program
(),
fluid
.
Program
()
avg_cost
,
strategy
=
self
.
net
(
train_prog
,
startup_prog
)
def
net
(
self
,
main_prog
,
startup_prog
):
with
fluid
.
program_guard
(
main_prog
,
startup_prog
):
self
.
set_strategy
(
strategy
,
'dgc'
)
with
fluid
.
unique_name
.
guard
():
opt
=
fluid
.
optimizer
.
MomentumOptimizer
(
role
=
role_maker
.
PaddleCloudRoleMaker
(
is_collective
=
True
)
learning_rate
=
0.001
,
momentum
=
0.9
)
fleet
.
init
(
role
)
dgc_opt
=
DGCOptimizer
(
opt
)
input_x
=
paddle
.
fluid
.
layers
.
data
(
role
=
role_maker
.
PaddleCloudRoleMaker
(
is_collective
=
True
)
name
=
"x"
,
shape
=
[
32
],
dtype
=
'float32'
)
dgc_opt
.
_set_basic_info
(
avg_cost
,
role
,
opt
,
strategy
)
input_y
=
paddle
.
fluid
.
layers
.
data
(
params_grads
=
dgc_opt
.
backward
(
avg_cost
,
startup_prog
)
name
=
"y"
,
shape
=
[
1
],
dtype
=
'int64'
)
ops
=
[
op
.
type
for
op
in
avg_cost
.
block
.
ops
]
fc_1
=
paddle
.
fluid
.
layers
.
fc
(
input
=
input_x
,
self
.
assertNotIn
(
'dgc'
,
ops
)
size
=
64
,
act
=
'tanh'
)
def
test_dgc_optimizer_gradients
(
self
):
fc_2
=
paddle
.
fluid
.
layers
.
fc
(
input
=
fc_1
,
size
=
256
,
act
=
'tanh'
)
""" test dgc optimizer backward + gradients """
prediction
=
paddle
.
fluid
.
layers
.
fc
(
input
=
[
fc_2
],
train_prog
,
startup_prog
=
fluid
.
Program
(),
fluid
.
Program
()
size
=
2
,
avg_cost
,
strategy
=
self
.
net
(
train_prog
,
startup_prog
)
act
=
'softmax'
)
cost
=
paddle
.
fluid
.
layers
.
cross_entropy
(
self
.
set_strategy
(
strategy
,
'dgc'
)
input
=
prediction
,
label
=
input_y
)
opt
=
fluid
.
optimizer
.
MomentumOptimizer
(
avg_cost
=
paddle
.
fluid
.
layers
.
mean
(
x
=
cost
)
learning_rate
=
0.001
,
momentum
=
0.9
)
dgc_opt
=
DGCOptimizer
(
opt
)
strategy
=
paddle
.
distributed
.
fleet
.
DistributedStrategy
()
role
=
role_maker
.
PaddleCloudRoleMaker
(
is_collective
=
True
)
strategy
.
dgc
=
True
dgc_opt
.
_set_basic_info
(
avg_cost
,
role
,
opt
,
strategy
)
strategy
.
dgc_configs
=
{
params_grads
=
dgc_opt
.
backward
(
avg_cost
,
startup_prog
)
"rampup_begin_step"
:
128
,
with
fluid
.
program_guard
(
train_prog
,
startup_prog
):
"rampup_step"
:
100
,
dgc_opt
.
apply_gradients
(
params_grads
)
"sparsity"
:
[
0.996
,
0.999
]
}
ops
=
[
op
.
type
for
op
in
avg_cost
.
block
.
ops
]
return
avg_cost
,
strategy
self
.
assertIn
(
'dgc'
,
ops
)
self
.
assertIn
(
'dgc_momentum'
,
ops
)
def
test_dgc_optimizer_optimize
(
self
):
""" test dgc optimizer backward + optimize """
train_prog
,
startup_prog
=
fluid
.
Program
(),
fluid
.
Program
()
avg_cost
,
strategy
=
self
.
net
(
train_prog
,
startup_prog
)
self
.
set_strategy
(
strategy
,
'dgc'
)
opt
=
fluid
.
optimizer
.
MomentumOptimizer
(
learning_rate
=
0.001
,
momentum
=
0.9
)
dgc_opt
=
DGCOptimizer
(
opt
)
role
=
role_maker
.
PaddleCloudRoleMaker
(
is_collective
=
True
)
dgc_opt
.
_set_basic_info
(
avg_cost
,
role
,
opt
,
strategy
)
params_grads
=
dgc_opt
.
backward
(
avg_cost
,
startup_prog
)
dgc_opt
.
apply_optimize
(
avg_cost
,
startup_prog
,
params_grads
)
ops
=
[
op
.
type
for
op
in
avg_cost
.
block
.
ops
]
self
.
assertIn
(
'dgc'
,
ops
)
self
.
assertIn
(
'dgc_momentum'
,
ops
)
def
test_dgc_optimizer
(
self
):
def
test_dgc_optimizer
(
self
):
startup_prog
=
fluid
.
Program
()
train_prog
,
startup_prog
=
fluid
.
Program
(),
fluid
.
Program
()
train_prog
=
fluid
.
Program
()
avg_cost
,
strategy
=
self
.
net
(
train_prog
,
startup_prog
)
avg_cost
,
strategy
=
self
.
net
(
train_prog
,
startup_prog
)
optimizer
=
paddle
.
fluid
.
optimizer
.
Momentum
(
self
.
set_strategy
(
strategy
,
'dgc'
)
learning_rate
=
0.01
,
momentum
=
0.9
)
self
.
optimizer
(
avg_cost
,
strategy
,
train_prog
,
startup_prog
)
optimizer
=
fleet
.
distributed_optimizer
(
optimizer
,
strategy
=
strategy
)
optimizer
.
minimize
(
avg_cost
)
ops
=
[
op
.
type
for
op
in
avg_cost
.
block
.
ops
]
ops
=
[
op
.
type
for
op
in
avg_cost
.
block
.
ops
]
self
.
assertIn
(
'dgc'
,
ops
)
self
.
assertIn
(
'dgc'
,
ops
)
self
.
assertIn
(
'dgc_momentum'
,
ops
)
self
.
assertIn
(
'dgc_momentum'
,
ops
)
def
test_dgc_not_apply_with_adam
(
self
):
def
test_dgc_not_apply_with_adam
(
self
):
startup_prog
=
fluid
.
Program
()
train_prog
,
startup_prog
=
fluid
.
Program
(),
fluid
.
Program
()
train_prog
=
fluid
.
Program
()
avg_cost
,
strategy
=
self
.
net
(
train_prog
,
startup_prog
)
avg_cost
,
strategy
=
self
.
net
(
train_prog
,
startup_prog
)
optimizer
=
paddle
.
fluid
.
optimizer
.
Adam
(
learning_rate
=
0.01
)
self
.
set_strategy
(
strategy
,
'dgc'
)
optimizer
=
fleet
.
distributed_optimizer
(
optimizer
,
strategy
=
strategy
)
self
.
optimizer
(
avg_cost
,
strategy
,
train_prog
,
startup_prog
,
'adam'
)
optimizer
.
minimize
(
avg_cost
)
ops
=
[
op
.
type
for
op
in
avg_cost
.
block
.
ops
]
ops
=
[
op
.
type
for
op
in
avg_cost
.
block
.
ops
]
self
.
assertNotIn
(
'dgc'
,
ops
)
self
.
assertNotIn
(
'dgc'
,
ops
)
...
@@ -85,18 +102,32 @@ class TestFleetDGCOptimizer(unittest.TestCase):
...
@@ -85,18 +102,32 @@ class TestFleetDGCOptimizer(unittest.TestCase):
os
.
environ
[
"PADDLE_TRAINER_ID"
]
=
"0"
os
.
environ
[
"PADDLE_TRAINER_ID"
]
=
"0"
os
.
environ
[
"PADDLE_TRAINER_ENDPOINTS"
]
=
"127.0.0.1:36001"
os
.
environ
[
"PADDLE_TRAINER_ENDPOINTS"
]
=
"127.0.0.1:36001"
startup_prog
=
fluid
.
Program
()
train_prog
,
startup_prog
=
fluid
.
Program
(),
fluid
.
Program
()
train_prog
=
fluid
.
Program
()
avg_cost
,
strategy
=
self
.
net
(
train_prog
,
startup_prog
)
avg_cost
,
strategy
=
self
.
net
(
train_prog
,
startup_prog
)
optimizer
=
paddle
.
fluid
.
optimizer
.
Momentum
(
self
.
set_strategy
(
strategy
,
'dgc'
)
learning_rate
=
0.01
,
momentum
=
0.9
)
self
.
optimizer
(
avg_cost
,
strategy
,
train_prog
,
startup_prog
)
optimizer
=
fleet
.
distributed_optimizer
(
optimizer
,
strategy
=
strategy
)
optimizer
.
minimize
(
avg_cost
)
ops
=
[
op
.
type
for
op
in
avg_cost
.
block
.
ops
]
ops
=
[
op
.
type
for
op
in
avg_cost
.
block
.
ops
]
self
.
assertNotIn
(
'dgc'
,
ops
)
self
.
assertNotIn
(
'dgc'
,
ops
)
self
.
assertNotIn
(
'dgc_momentum'
,
ops
)
self
.
assertNotIn
(
'dgc_momentum'
,
ops
)
def
test_dgc_recompute_optimizer
(
self
):
train_prog
,
startup_prog
=
fluid
.
Program
(),
fluid
.
Program
()
avg_cost
,
strategy
=
self
.
net
(
train_prog
,
startup_prog
)
self
.
set_strategy
(
strategy
,
'dgc'
)
self
.
set_strategy
(
strategy
,
'recompute'
)
self
.
optimizer
(
avg_cost
,
strategy
,
train_prog
,
startup_prog
)
ops
=
[
op
.
type
for
op
in
avg_cost
.
block
.
ops
]
outs
=
[
op
.
output
(
'Out'
)[
0
]
for
op
in
avg_cost
.
block
.
ops
if
op
.
type
==
'mul'
]
self
.
assertIn
(
'dgc'
,
ops
)
self
.
assertIn
(
'dgc_momentum'
,
ops
)
# recompute
self
.
assertIn
(
'subprog'
,
''
.
join
(
outs
))
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
unittest
.
main
()
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_fleet_localsgd_meta_optimizer.py
浏览文件 @
0a1862d1
...
@@ -16,71 +16,87 @@ import unittest
...
@@ -16,71 +16,87 @@ import unittest
import
paddle
import
paddle
import
os
import
os
import
paddle
import
paddle.fluid
as
fluid
import
paddle.distributed.fleet
as
fleet
import
paddle.distributed.fleet
as
fleet
import
paddle.distributed.fleet.base.role_maker
as
role_maker
import
paddle.distributed.fleet.base.role_maker
as
role_maker
from
fleet_meta_optimizer_base
import
TestFleetMetaOptimizer
paddle
.
enable_static
()
class
TestFleetLocalSGDMetaOptimizer
(
unittest
.
TestCase
):
def
setUp
(
self
):
os
.
environ
[
"PADDLE_TRAINER_ID"
]
=
"1"
os
.
environ
[
"PADDLE_TRAINER_ENDPOINTS"
]
=
"127.0.0.1:36001,127.0.0.1:36002"
class
TestFleetLocalSGDMetaOptimizer
(
TestFleetMetaOptimizer
):
def
test_localsgd_optimizer
(
self
):
def
test_localsgd_optimizer
(
self
):
role
=
role_maker
.
PaddleCloudRoleMaker
(
is_collective
=
True
)
train_prog
,
startup_prog
=
fluid
.
Program
(),
fluid
.
Program
()
fleet
.
init
(
role
)
avg_cost
,
strategy
=
self
.
net
(
train_prog
,
startup_prog
)
input_x
=
paddle
.
fluid
.
layers
.
data
(
self
.
set_strategy
(
strategy
,
'localsgd'
)
name
=
"x"
,
shape
=
[
32
],
dtype
=
'float32'
)
self
.
optimizer
(
avg_cost
,
strategy
,
train_prog
,
startup_prog
)
input_y
=
paddle
.
fluid
.
layers
.
data
(
name
=
"y"
,
shape
=
[
1
],
dtype
=
'int64'
)
ops
=
[
op
.
type
for
op
in
avg_cost
.
block
.
ops
]
fc
=
paddle
.
fluid
.
layers
.
fc
(
input
=
input_x
,
size
=
64
,
act
=
'tanh'
)
outs
=
[
prediction
=
paddle
.
fluid
.
layers
.
fc
(
input
=
[
fc
],
size
=
2
,
act
=
'softmax'
)
''
.
join
(
op
.
output
(
'Out'
))
for
op
in
avg_cost
.
block
.
ops
cost
=
paddle
.
fluid
.
layers
.
cross_entropy
(
if
op
.
type
==
'conditional_block'
input
=
prediction
,
label
=
input_y
)
]
avg_cost
=
paddle
.
fluid
.
layers
.
mean
(
x
=
cost
)
self
.
assertIn
(
'conditional_block'
,
ops
)
strategy
=
paddle
.
distributed
.
fleet
.
DistributedStrategy
()
self
.
assertIn
(
'@SNAPSHOT'
,
''
.
join
(
outs
))
strategy
.
localsgd
=
True
strategy
.
auto
=
True
def
test_localsgd_amp_optimizer
(
self
):
config
=
strategy
.
localsgd_configs
train_prog
,
startup_prog
=
fluid
.
Program
(),
fluid
.
Program
()
config
[
'k_steps'
]
=
1
avg_cost
,
strategy
=
self
.
net
(
train_prog
,
startup_prog
)
config
[
'begin_step'
]
=
1
self
.
set_strategy
(
strategy
,
'localsgd'
)
strategy
.
localsgd_configs
=
config
self
.
set_strategy
(
strategy
,
'amp'
)
self
.
optimizer
(
avg_cost
,
strategy
,
train_prog
,
startup_prog
)
optimizer
=
paddle
.
fluid
.
optimizer
.
SGD
(
learning_rate
=
0.01
)
optimizer
=
fleet
.
distributed_optimizer
(
optimizer
,
strategy
=
strategy
)
ops
=
[
op
.
type
for
op
in
avg_cost
.
block
.
ops
]
optimizer
.
minimize
(
avg_cost
)
outs
=
[
''
.
join
(
op
.
output
(
'Out'
))
for
op
in
avg_cost
.
block
.
ops
if
op
.
type
==
'conditional_block'
class
TestFleetAdaptiveLocalSGDMetaOptimizer
(
unittest
.
TestCase
):
]
def
setUp
(
self
):
os
.
environ
[
"PADDLE_TRAINER_ID"
]
=
"1"
self
.
assertIn
(
'conditional_block'
,
ops
)
os
.
environ
[
self
.
assertIn
(
'@SNAPSHOT'
,
''
.
join
(
outs
))
"PADDLE_TRAINER_ENDPOINTS"
]
=
"127.0.0.1:36001,127.0.0.1:36002"
# amp
self
.
assertIn
(
'cast'
,
ops
)
self
.
assertIn
(
'check_finite_and_unscale'
,
ops
)
class
TestFleetAdaptiveLocalSGDMetaOptimizer
(
TestFleetMetaOptimizer
):
def
test_adaptive_localsgd_optimizer
(
self
):
def
test_adaptive_localsgd_optimizer
(
self
):
role
=
role_maker
.
PaddleCloudRoleMaker
(
is_collective
=
True
)
train_prog
,
startup_prog
=
fluid
.
Program
(),
fluid
.
Program
()
fleet
.
init
(
role
)
avg_cost
,
strategy
=
self
.
net
(
train_prog
,
startup_prog
)
input_x
=
paddle
.
fluid
.
layers
.
data
(
self
.
set_strategy
(
strategy
,
'adaptive_localsgd'
)
name
=
"x"
,
shape
=
[
32
],
dtype
=
'float32'
)
self
.
optimizer
(
avg_cost
,
strategy
,
train_prog
,
startup_prog
)
input_y
=
paddle
.
fluid
.
layers
.
data
(
name
=
"y"
,
shape
=
[
1
],
dtype
=
'int64'
)
ops
=
[
op
.
type
for
op
in
avg_cost
.
block
.
ops
]
fc
=
paddle
.
fluid
.
layers
.
fc
(
input
=
input_x
,
size
=
64
,
act
=
'tanh'
)
outs
=
[
prediction
=
paddle
.
fluid
.
layers
.
fc
(
input
=
[
fc
],
size
=
2
,
act
=
'softmax'
)
''
.
join
(
op
.
output
(
'Out'
))
for
op
in
avg_cost
.
block
.
ops
cost
=
paddle
.
fluid
.
layers
.
cross_entropy
(
if
op
.
type
==
'conditional_block'
input
=
prediction
,
label
=
input_y
)
]
avg_cost
=
paddle
.
fluid
.
layers
.
mean
(
x
=
cost
)
self
.
assertIn
(
'conditional_block'
,
ops
)
strategy
=
paddle
.
distributed
.
fleet
.
DistributedStrategy
()
self
.
assertIn
(
'@SNAPSHOT'
,
''
.
join
(
outs
))
strategy
.
adaptive_localsgd
=
True
config
=
strategy
.
adaptive_localsgd_configs
def
test_localsgd_amp_optimizer
(
self
):
config
[
'init_k_steps'
]
=
1
train_prog
,
startup_prog
=
fluid
.
Program
(),
fluid
.
Program
()
config
[
'begin_step'
]
=
1
avg_cost
,
strategy
=
self
.
net
(
train_prog
,
startup_prog
)
strategy
.
adaptive_localsgd_configs
=
config
self
.
set_strategy
(
strategy
,
'adaptive_localsgd'
)
self
.
set_strategy
(
strategy
,
'amp'
)
optimizer
=
paddle
.
fluid
.
optimizer
.
SGD
(
learning_rate
=
0.01
)
self
.
optimizer
(
avg_cost
,
strategy
,
train_prog
,
startup_prog
)
optimizer
=
fleet
.
distributed_optimizer
(
optimizer
,
strategy
=
strategy
)
optimizer
.
minimize
(
avg_cost
)
ops
=
[
op
.
type
for
op
in
avg_cost
.
block
.
ops
]
outs
=
[
''
.
join
(
op
.
output
(
'Out'
))
for
op
in
avg_cost
.
block
.
ops
if
op
.
type
==
'conditional_block'
]
self
.
assertIn
(
'conditional_block'
,
ops
)
self
.
assertIn
(
'@SNAPSHOT'
,
''
.
join
(
outs
))
# amp
self
.
assertIn
(
'cast'
,
ops
)
self
.
assertIn
(
'check_finite_and_unscale'
,
ops
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
python/paddle/fluid/tests/unittests/test_fleet_recompute_meta_optimizer.py
浏览文件 @
0a1862d1
...
@@ -14,40 +14,144 @@
...
@@ -14,40 +14,144 @@
import
unittest
import
unittest
import
paddle
import
paddle
import
paddle.fluid
as
fluid
import
os
import
os
from
fleet_meta_optimizer_base
import
TestFleetMetaOptimizer
from
paddle.distributed.fleet.meta_optimizers
import
RecomputeOptimizer
paddle
.
enable_static
()
class
TestFleetRecomputeMetaOptimizer
(
unittest
.
TestCase
):
def
setUp
(
self
):
class
TestFleetRecomputeMetaOptimizer
(
TestFleetMetaOptimizer
):
os
.
environ
[
"POD_IP"
]
=
"127.0.0.1"
def
test_recompute_optimizer_backward
(
self
):
os
.
environ
[
"PADDLE_TRAINER_ENDPOINTS"
]
=
"127.0.0.1:36001"
""" test recompute optimizer backward """
os
.
environ
[
"PADDLE_TRAINERS_NUM"
]
=
"2"
train_prog
,
startup_prog
=
fluid
.
Program
(),
fluid
.
Program
()
os
.
environ
[
"PADDLE_PSERVERS_IP_PORT_LIST"
]
=
\
avg_cost
,
strategy
=
self
.
net
(
train_prog
,
startup_prog
)
"127.0.0.1:36001,127.0.0.2:36001"
self
.
set_strategy
(
strategy
,
'recompute'
)
opt
=
fluid
.
optimizer
.
MomentumOptimizer
(
learning_rate
=
0.001
,
momentum
=
0.9
)
opt
=
RecomputeOptimizer
(
opt
)
opt
.
user_defined_strategy
=
strategy
params_grads
=
opt
.
backward
(
avg_cost
,
startup_prog
)
outs
=
[
op
.
output
(
'Out'
)[
0
]
for
op
in
avg_cost
.
block
.
ops
if
op
.
type
==
'mul'
]
self
.
assertIn
(
'subprog'
,
''
.
join
(
outs
))
def
test_recompute_optimizer_backward_gradients
(
self
):
""" test recompute optimizer backward + gradients """
train_prog
,
startup_prog
=
fluid
.
Program
(),
fluid
.
Program
()
avg_cost
,
strategy
=
self
.
net
(
train_prog
,
startup_prog
)
self
.
set_strategy
(
strategy
,
'recompute'
)
opt
=
fluid
.
optimizer
.
MomentumOptimizer
(
learning_rate
=
0.001
,
momentum
=
0.9
)
opt
=
RecomputeOptimizer
(
opt
)
opt
.
user_defined_strategy
=
strategy
params_grads
=
opt
.
backward
(
avg_cost
,
startup_prog
)
with
fluid
.
program_guard
(
train_prog
,
startup_prog
):
opt
.
apply_gradients
(
params_grads
)
outs
=
[
op
.
output
(
'Out'
)[
0
]
for
op
in
avg_cost
.
block
.
ops
if
op
.
type
==
'mul'
]
self
.
assertIn
(
'subprog'
,
''
.
join
(
outs
))
def
test_recompute_optimizer_backward_optimize
(
self
):
""" test recompute optimizer backward + optimize """
train_prog
,
startup_prog
=
fluid
.
Program
(),
fluid
.
Program
()
avg_cost
,
strategy
=
self
.
net
(
train_prog
,
startup_prog
)
self
.
set_strategy
(
strategy
,
'recompute'
)
opt
=
fluid
.
optimizer
.
MomentumOptimizer
(
learning_rate
=
0.001
,
momentum
=
0.9
)
opt
=
RecomputeOptimizer
(
opt
)
opt
.
user_defined_strategy
=
strategy
params_grads
=
opt
.
backward
(
avg_cost
,
startup_prog
)
opt
.
apply_optimize
(
avg_cost
,
startup_prog
,
params_grads
)
outs
=
[
op
.
output
(
'Out'
)[
0
]
for
op
in
avg_cost
.
block
.
ops
if
op
.
type
==
'mul'
]
self
.
assertIn
(
'subprog'
,
''
.
join
(
outs
))
def
test_recompute_optimizer_backward
(
self
):
""" test recompute optimizer backward """
train_prog
,
startup_prog
=
fluid
.
Program
(),
fluid
.
Program
()
avg_cost
,
strategy
=
self
.
net
(
train_prog
,
startup_prog
)
self
.
set_strategy
(
strategy
,
'recompute'
)
opt
=
fluid
.
optimizer
.
MomentumOptimizer
(
learning_rate
=
0.001
,
momentum
=
0.9
)
opt
=
RecomputeOptimizer
(
opt
)
opt
.
user_defined_strategy
=
strategy
params_grads
=
opt
.
backward
(
avg_cost
,
startup_prog
)
outs
=
[
op
.
output
(
'Out'
)[
0
]
for
op
in
avg_cost
.
block
.
ops
if
op
.
type
==
'mul'
]
self
.
assertIn
(
'subprog'
,
''
.
join
(
outs
))
def
test_recompute_optimizer_backward
(
self
):
""" test recompute optimizer backward """
train_prog
,
startup_prog
=
fluid
.
Program
(),
fluid
.
Program
()
avg_cost
,
strategy
=
self
.
net
(
train_prog
,
startup_prog
)
self
.
set_strategy
(
strategy
,
'recompute'
)
opt
=
fluid
.
optimizer
.
MomentumOptimizer
(
learning_rate
=
0.001
,
momentum
=
0.9
)
opt
=
RecomputeOptimizer
(
opt
)
opt
.
user_defined_strategy
=
strategy
params_grads
=
opt
.
backward
(
avg_cost
,
startup_prog
)
outs
=
[
op
.
output
(
'Out'
)[
0
]
for
op
in
avg_cost
.
block
.
ops
if
op
.
type
==
'mul'
]
self
.
assertIn
(
'subprog'
,
''
.
join
(
outs
))
def
test_recompute_optimizer
(
self
):
def
test_recompute_optimizer
(
self
):
import
paddle.distributed.fleet
as
fleet
train_prog
,
startup_prog
=
fluid
.
Program
(),
fluid
.
Program
()
import
paddle.distributed.fleet.base.role_maker
as
role_maker
avg_cost
,
strategy
=
self
.
net
(
train_prog
,
startup_prog
)
role
=
role_maker
.
PaddleCloudRoleMaker
(
is_collective
=
True
)
self
.
set_strategy
(
strategy
,
'recompute'
)
fleet
.
init
(
role
)
self
.
optimizer
(
avg_cost
,
strategy
,
train_prog
,
startup_prog
)
input_x
=
paddle
.
fluid
.
layers
.
data
(
name
=
"x"
,
shape
=
[
32
],
dtype
=
'float32'
)
outs
=
[
input_y
=
paddle
.
fluid
.
layers
.
data
(
name
=
"y"
,
shape
=
[
1
],
dtype
=
'int64'
)
op
.
output
(
'Out'
)[
0
]
for
op
in
avg_cost
.
block
.
ops
if
op
.
type
==
'mul'
]
fc_1
=
paddle
.
fluid
.
layers
.
fc
(
input
=
input_x
,
size
=
64
,
act
=
'tanh'
)
fc_2
=
paddle
.
fluid
.
layers
.
fc
(
input
=
fc_1
,
size
=
64
,
act
=
'tanh'
)
self
.
assertIn
(
'subprog'
,
''
.
join
(
outs
))
prediction
=
paddle
.
fluid
.
layers
.
fc
(
input
=
[
fc_2
],
size
=
2
,
act
=
'softmax'
)
cost
=
paddle
.
fluid
.
layers
.
cross_entropy
(
def
test_recompute_lars_optimizer
(
self
):
input
=
prediction
,
label
=
input_y
)
train_prog
,
startup_prog
=
fluid
.
Program
(),
fluid
.
Program
()
avg_cost
=
paddle
.
fluid
.
layers
.
mean
(
x
=
cost
)
avg_cost
,
strategy
=
self
.
net
(
train_prog
,
startup_prog
)
self
.
set_strategy
(
strategy
,
'recompute'
)
strategy
=
paddle
.
distributed
.
fleet
.
DistributedStrategy
()
self
.
set_strategy
(
strategy
,
'lars'
)
strategy
.
recompute
=
True
self
.
optimizer
(
avg_cost
,
strategy
,
train_prog
,
startup_prog
)
strategy
.
recompute_configs
=
{
"checkpoints"
:
[
"fc_1.tmp_0"
]}
ops
=
[
op
.
type
for
op
in
avg_cost
.
block
.
ops
]
optimizer
=
paddle
.
fluid
.
optimizer
.
SGD
(
learning_rate
=
0.01
)
outs
=
[
optimizer
=
fleet
.
distributed_optimizer
(
optimizer
,
strategy
=
strategy
)
op
.
output
(
'Out'
)[
0
]
for
op
in
avg_cost
.
block
.
ops
if
op
.
type
==
'mul'
optimizer
.
minimize
(
avg_cost
)
]
self
.
assertIn
(
'subprog'
,
''
.
join
(
outs
))
self
.
assertIn
(
'lars_momentum'
,
ops
)
def
test_recompute_lamb_optimizer
(
self
):
train_prog
,
startup_prog
=
fluid
.
Program
(),
fluid
.
Program
()
avg_cost
,
strategy
=
self
.
net
(
train_prog
,
startup_prog
)
self
.
set_strategy
(
strategy
,
'recompute'
)
self
.
set_strategy
(
strategy
,
'lamb'
)
self
.
optimizer
(
avg_cost
,
strategy
,
train_prog
,
startup_prog
,
'adam'
)
ops
=
[
op
.
type
for
op
in
avg_cost
.
block
.
ops
]
outs
=
[
op
.
output
(
'Out'
)[
0
]
for
op
in
avg_cost
.
block
.
ops
if
op
.
type
==
'mul'
]
self
.
assertIn
(
'subprog'
,
''
.
join
(
outs
))
self
.
assertIn
(
'lamb'
,
ops
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录