未验证 提交 54b81fa3 编写于 作者: S ShenLiang 提交者: GitHub

add adaptivelsgd in meta_optimizer (#27289)

* add adaptivelsgd

* Todo fix the code to avoid the conflict.
上级 6e29c2da
......@@ -41,6 +41,11 @@ message LocalSGDConfig {
optional int32 begin_step = 2 [ default = 1 ];
}
message AdaptiveLocalSGDConfig {
optional int32 init_k_steps = 1 [ default = 1 ];
optional int32 begin_step = 2 [ default = 1 ];
}
message GradientMergeConfig {
optional int32 k_steps = 1 [ default = 1 ];
optional bool avg = 2 [ default = true ];
......@@ -121,6 +126,7 @@ message DistributedStrategy {
optional bool cudnn_exhaustive_search = 21 [ default = true ];
optional int32 conv_workspace_size_limit = 22 [ default = 4000 ];
optional bool cudnn_batchnorm_spatial_persistent = 23 [ default = true ];
optional bool adaptive_localsgd = 24 [ default = false ];
optional RecomputeConfig recompute_configs = 101;
optional AMPConfig amp_configs = 102;
......@@ -131,6 +137,7 @@ message DistributedStrategy {
optional AsyncConfig a_sync_configs = 107;
optional LarsConfig lars_configs = 108;
optional LambConfig lamb_configs = 109;
optional AdaptiveLocalSGDConfig adaptive_localsgd_configs = 110;
optional BuildStrategy build_strategy = 201;
optional ExecutionStrategy execution_strategy = 202;
}
......
......@@ -728,6 +728,63 @@ class DistributedStrategy(object):
"localsgd_configs")
assign_configs_value(self.strategy.localsgd_configs, configs)
@property
def adaptive_localsgd(self):
"""
Indicating whether we are using Adaptive Local SGD training. Default Value: False
For more details, please refer to `Adaptive Communication Strategies to Achieve
the Best Error-Runtime Trade-off in Local-Update SGD <https://arxiv.org/pdf/1810.08313.pdf>`_.
Examples:
.. code-block:: python
import paddle.distributed.fleet as fleet
strategy = fleet.DistributedStrategy()
strategy.adaptive_localsgd = True # by default this is false
"""
return self.strategy.localsgd
@adaptive_localsgd.setter
@is_strict_auto
def adaptive_localsgd(self, flag):
if isinstance(flag, bool):
self.strategy.localsgd = flag
else:
print("WARNING: adaptive_localsgd should have value of bool type")
@property
def adaptive_localsgd_configs(self):
"""
Set AdaptiveLocalSGD training configurations. AdaptiveLocalSGD has a configurable
setting that can be configured through a dict.
**Notes**:
init_k_steps(int) The initial steps for training before adaptive localsgd.
Then, the adaptive localsgd method will modify init_k_steps automatically.
Default 1.
begin_step(int) The step of begining training by adaptive localsgd. Default 1.
Examples:
.. code-block:: python
import paddle.distributed.fleet as fleet
strategy = fleet.DistributedStrategy()
strategy.adaptive_localsgd = True
strategy.adaptive_localsgd_configs = {"init_k_steps": 1,
"begin_step": 30}
"""
return get_msg_dict(self.strategy.adaptive_localsgd_configs)
@adaptive_localsgd_configs.setter
@is_strict_auto
def adaptive_localsgd_configs(self, configs):
check_configs_key(self.strategy.adaptive_localsgd_configs, configs,
"adaptive_localsgd_configs")
assign_configs_value(self.strategy.adaptive_localsgd_configs, configs)
@property
def dgc(self):
"""
......
......@@ -18,6 +18,7 @@ from .graph_execution_optimizer import GraphExecutionOptimizer
from .parameter_server_optimizer import ParameterServerOptimizer
from .pipeline_optimizer import PipelineOptimizer
from .localsgd_optimizer import LocalSGDOptimizer
from .localsgd_optimizer import AdaptiveLocalSGDOptimizer
from .lars_optimizer import LarsOptimizer
from .parameter_server_graph_optimizer import ParameterServerGraphOptimizer
from .dgc_optimizer import DGCOptimizer
......
......@@ -24,7 +24,7 @@ class AMPOptimizer(MetaOptimizerBase):
self.meta_optimizers_white_list = [
"LarsOptimizer", "LambOptimizer", "RecomputeOptimizer",
"LocalSGDOptimizer", "GradientMergeOptimizer",
"GraphExecutionOptimizer"
"GraphExecutionOptimizer", "AdaptiveLocalSGDOptimizer"
]
self.meta_optimizers_black_list = ["DGCOptimizer"]
......
......@@ -25,7 +25,9 @@ class LocalSGDOptimizer(MetaOptimizerBase):
super(LocalSGDOptimizer, self).__init__(optimizer)
self.inner_opt = optimizer
self.meta_optimizers_white_list = []
self.meta_optimizers_black_list = ["GraphExecutionOptimizer"]
self.meta_optimizers_black_list = [
"GraphExecutionOptimizer", "AdaptiveLocalSGDOptimizer"
]
self.snapshot_key = '@SNAPSHOT'
def _can_apply(self):
......@@ -186,3 +188,252 @@ class LocalSGDOptimizer(MetaOptimizerBase):
layers.cond(step > begin_step, begin_localsgd, communicate)
return minimized
class AdaptiveLocalSGDOptimizer(MetaOptimizerBase):
def __init__(self, optimizer):
super(AdaptiveLocalSGDOptimizer, self).__init__(optimizer)
self.inner_opt = optimizer
self.meta_optimizers_white_list = []
self.meta_optimizers_black_list = [
"GraphExecutionOptimizer", "LocalSGDOptimizer"
]
self.snapshot_key = '@SNAPSHOT'
def _can_apply(self):
if not self.role_maker._is_collective:
return False
if not self.user_defined_strategy.adaptive_localsgd:
return False
if self.role_maker.worker_num() <= 1:
return False
return isinstance(self.inner_opt, paddle.optimizer.momentum.Momentum) \
or isinstance(self.inner_opt, paddle.fluid.optimizer.Momentum) \
or isinstance(self.inner_opt, paddle.optimizer.sgd.SGD) \
or isinstance(self.inner_opt, paddle.fluid.optimizer.SGD)
def _disable_strategy(self, dist_strategy):
dist_strategy.adaptive_localsgd = False
dist_strategy.adaptive_localsgd_configs = {}
def _enable_strategy(self, dist_strategy, context):
dist_strategy.adaptive_localsgd = True
dist_strategy.adaptive_localsgd_configs = {
"init_k_steps": 1,
"begin_step": 1
}
def snapshot_name(self, param_name):
return param_name + self.snapshot_key
def create_snapshot_vars(self, program):
block = program.global_block()
non_dist_params = []
for param in block.iter_parameters():
if not param.is_distributed:
non_dist_params.append(param)
p2s = []
for param in non_dist_params:
snapshot = block.create_var(
name=self.snapshot_name(param.name),
shape=param.shape,
persistable=True,
stop_gradient=True,
dtype=param.dtype)
p2s.append([param, snapshot])
return p2s
def init_snapshot_vars(self, startup_program, param2snapshot):
with program_guard(startup_program):
for param, snapshot in param2snapshot:
layers.assign(param, snapshot)
def _generate_avg_loss(self, program_block, loss, avg_loss):
program_block.append_op(
type='c_allreduce_sum',
inputs={'X': [loss]},
outputs={'Out': [avg_loss]},
attrs={
'ring_id': 0,
OP_ROLE_KEY: OpRole.Optimize,
'use_calc_stream': True
})
program_block.append_op(
type='c_sync_calc_stream',
inputs={'X': [avg_loss]},
outputs={'Out': [avg_loss]},
attrs={OP_ROLE_KEY: OpRole.Optimize})
program_block.append_op(
type='scale',
inputs={'X': [avg_loss]},
outputs={'Out': [avg_loss]},
attrs={
'scale': 1.0 / self.role_maker.worker_num(),
OP_ROLE_KEY: OpRole.Optimize
})
def minimize_impl(self,
loss,
startup_program=None,
parameter_list=None,
no_grad_set=None):
minimized = self.inner_opt.minimize(
loss, startup_program=startup_program)
init_k_steps = self.user_defined_strategy.adaptive_localsgd_configs[
'init_k_steps']
begin_step_value = self.user_defined_strategy.adaptive_localsgd_configs[
'begin_step']
if startup_program is None:
startup_program = default_startup_program()
main_block = loss.block
self.nrings = 2
collective_helper = CollectiveHelper(self.role_maker, self.nrings)
collective_helper.update_startup_program(startup_program)
p2s = self.create_snapshot_vars(startup_program)
self.init_snapshot_vars(startup_program, p2s)
p2s = self.create_snapshot_vars(main_block.program)
with program_guard(main_block.program, startup_program):
step = layers.autoincreased_step_counter(begin=1)
k_steps = layers.create_global_var(
name="k_steps",
shape=[1],
value=int(init_k_steps),
dtype='int64',
persistable=True)
begin_step = layers.create_global_var(
name="begin_step",
shape=[1],
value=int(begin_step_value),
dtype='int64',
persistable=True)
last_step = layers.create_global_var(
name="last_step",
shape=[1],
value=int(0),
dtype='int64',
persistable=True)
avg_loss = layers.create_global_var(
name="avg_loss",
shape=[1],
value=float(0),
dtype=loss.dtype,
persistable=True)
lr_0 = layers.create_global_var(
name="lr_0",
shape=[1],
value=float(0),
dtype='float32',
persistable=True)
loss_0 = layers.create_global_var(
name="loss_0",
shape=[1],
value=float(0),
dtype='float32',
persistable=True)
global_lr = self.inner_opt._global_learning_rate()
def initialize():
self._generate_avg_loss(main_block, loss, avg_loss)
layers.assign(avg_loss, loss_0)
layers.assign(global_lr, lr_0)
layers.cond(step == 1, initialize)
def communicate():
sub_block = default_main_program().current_block()
ring_id = -1
for param, snapshot in p2s:
sub_block.append_op(
type='elementwise_sub',
inputs={'X': [snapshot],
'Y': [param]},
outputs={'Out': [param]},
attrs={OP_ROLE_KEY: OpRole.Optimize})
sub_block.append_op(
type='c_sync_calc_stream',
inputs={'X': param},
outputs={'Out': param},
attrs={OP_ROLE_KEY: OpRole.Optimize})
ring_id = (ring_id + 1) % self.nrings
sub_block.append_op(
type='c_allreduce_sum',
inputs={'X': [param]},
outputs={'Out': [param]},
attrs={
'ring_id': ring_id,
OP_ROLE_KEY: OpRole.Optimize
})
for ring_id in range(self.nrings):
sub_block.append_op(
type='c_sync_comm_stream',
inputs={'X': param},
outputs={'Out': param},
attrs={
'ring_id': ring_id,
OP_ROLE_KEY: OpRole.Optimize
})
for param, snapshot in p2s:
sub_block.append_op(
type='scale',
inputs={'X': [param]},
outputs={'Out': [param]},
attrs={
'scale': 1.0 / self.role_maker.worker_num(),
OP_ROLE_KEY: OpRole.Optimize
})
sub_block.append_op(
type='elementwise_sub',
inputs={'X': [snapshot],
'Y': [param]},
outputs={'Out': [param]},
attrs={OP_ROLE_KEY: OpRole.Optimize})
sub_block.append_op(
type='assign',
inputs={'X': [param]},
outputs={'Out': [snapshot]},
attrs={OP_ROLE_KEY: OpRole.Optimize})
layers.assign(step, last_step)
def communicate_avg_loss():
communicate()
self._generate_avg_loss(main_block, loss, avg_loss)
next_local_steps = layers.cast(
layers.ceil(
layers.sqrt(lr_0 * avg_loss / (global_lr * loss_0) *
float(init_k_steps))),
dtype='int64')
max_local_steps = layers.fill_constant(
shape=[1], dtype='int64', value=16)
min_local_steps = layers.fill_constant(
shape=[1], dtype='int64', value=1)
next_local_steps = layers.elementwise_min(next_local_steps,
max_local_steps)
next_local_steps = layers.elementwise_max(next_local_steps,
min_local_steps)
layers.assign(next_local_steps, k_steps)
def begin_localsgd():
layers.cond(step - last_step == k_steps, communicate_avg_loss)
layers.cond(step > begin_step, begin_localsgd, communicate)
return minimized
......@@ -86,6 +86,13 @@ class TestStrategyConfig(unittest.TestCase):
self.assertEqual(strategy.localsgd_configs["k_steps"], 4)
self.assertEqual(strategy.localsgd_configs["begin_step"], 120)
def test_adaptive_localsgd_configs(self):
strategy = paddle.distributed.fleet.DistributedStrategy()
configs = {"init_k_steps": 1, "begin_step": 120}
strategy.adaptive_localsgd_configs = configs
self.assertEqual(strategy.adaptive_localsgd_configs["init_k_steps"], 1)
self.assertEqual(strategy.adaptive_localsgd_configs["begin_step"], 120)
def test_dgc(self):
strategy = paddle.distributed.fleet.DistributedStrategy()
strategy.dgc = True
......
......@@ -52,5 +52,36 @@ class TestFleetLocalSGDMetaOptimizer(unittest.TestCase):
optimizer.minimize(avg_cost)
class TestFleetAdaptiveLocalSGDMetaOptimizer(unittest.TestCase):
def setUp(self):
os.environ["PADDLE_TRAINER_ID"] = "1"
os.environ[
"PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001,127.0.0.1:36002"
def test_adaptive_localsgd_optimizer(self):
role = role_maker.PaddleCloudRoleMaker(is_collective=True)
fleet.init(role)
input_x = paddle.fluid.layers.data(
name="x", shape=[32], dtype='float32')
input_y = paddle.fluid.layers.data(name="y", shape=[1], dtype='int64')
fc = paddle.fluid.layers.fc(input=input_x, size=64, act='tanh')
prediction = paddle.fluid.layers.fc(input=[fc], size=2, act='softmax')
cost = paddle.fluid.layers.cross_entropy(
input=prediction, label=input_y)
avg_cost = paddle.fluid.layers.mean(x=cost)
strategy = paddle.distributed.fleet.DistributedStrategy()
strategy.adaptive_localsgd = True
config = strategy.adaptive_localsgd_configs
config['init_k_steps'] = 1
config['begin_step'] = 1
strategy.adaptive_localsgd_configs = config
optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.01)
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
optimizer.minimize(avg_cost)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册