未验证 提交 3b2c580a 编写于 作者: Y Yi Liu 提交者: GitHub

【paddle.fleet】make fleet_localsgd_meta_optimizer work (#26213)

* make fleet_localsgd_meta_optimizer work

* fix bug in localsgd meta optimizer
上级 d549a9b1
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
from __future__ import print_function from __future__ import print_function
from paddle.fluid import program_guard, layers from paddle.fluid import program_guard, layers, default_main_program
from paddle.fluid.optimizer import Momentum, SGD from paddle.fluid.optimizer import Momentum, SGD
from .meta_optimizer_base import MetaOptimizerBase from .meta_optimizer_base import MetaOptimizerBase
from .common import OpRole, OP_ROLE_KEY, CollectiveHelper, is_update_op from .common import OpRole, OP_ROLE_KEY, CollectiveHelper, is_update_op
...@@ -44,6 +44,30 @@ class LocalSGDOptimizer(MetaOptimizerBase): ...@@ -44,6 +44,30 @@ class LocalSGDOptimizer(MetaOptimizerBase):
def snapshot_name(self, param_name): def snapshot_name(self, param_name):
return param_name + self.snapshot_key return param_name + self.snapshot_key
def create_snapshot_vars(self, program):
block = program.global_block()
non_dist_params = []
for param in block.iter_parameters():
if not param.is_distributed:
non_dist_params.append(param)
p2s = []
for param in non_dist_params:
snapshot = block.create_var(
name=self.snapshot_name(param.name),
shape=param.shape,
persistable=True,
stop_gradient=True,
dtype=param.dtype)
p2s.append([param, snapshot])
return p2s
def init_snapshot_vars(self, startup_program, param2snapshot):
with program_guard(startup_program):
for param, snapshot in param2snapshot:
layers.assign(param, snapshot)
def minimize_impl(self, def minimize_impl(self,
loss, loss,
startup_program=None, startup_program=None,
...@@ -62,8 +86,11 @@ class LocalSGDOptimizer(MetaOptimizerBase): ...@@ -62,8 +86,11 @@ class LocalSGDOptimizer(MetaOptimizerBase):
self.nrings = 2 self.nrings = 2
collective_helper = CollectiveHelper(self.role_maker, self.nrings) collective_helper = CollectiveHelper(self.role_maker, self.nrings)
collective_helper.update_startup_program(startup_program) collective_helper.update_startup_program(startup_program)
p2s = self.create_snapshot_vars(startup_program)
self.init_snapshot_vars(startup_program, p2s)
with program_guard(main_block.program): p2s = self.create_snapshot_vars(main_block.program)
with program_guard(main_block.program, startup_program):
step = layers.autoincreased_step_counter(begin=0) step = layers.autoincreased_step_counter(begin=0)
k_steps = layers.create_global_var( k_steps = layers.create_global_var(
name="k_steps", name="k_steps",
...@@ -79,6 +106,9 @@ class LocalSGDOptimizer(MetaOptimizerBase): ...@@ -79,6 +106,9 @@ class LocalSGDOptimizer(MetaOptimizerBase):
persistable=True) persistable=True)
if auto_steps: if auto_steps:
avg_loss = layers.collective._c_allreduce(
loss) / self.role_maker.worker_num()
lr_0 = layers.create_global_var( lr_0 = layers.create_global_var(
name="lr_0", name="lr_0",
shape=[1], shape=[1],
...@@ -101,49 +131,32 @@ class LocalSGDOptimizer(MetaOptimizerBase): ...@@ -101,49 +131,32 @@ class LocalSGDOptimizer(MetaOptimizerBase):
layers.cond(step == 0, initialize) layers.cond(step == 0, initialize)
def communicate(): def communicate():
ordered_param_snapshot = [] sub_block = default_main_program().current_block()
ring_id = -1 ring_id = -1
for idx, op in reversed(list(enumerate(main_block.ops))): for param, snapshot in p2s:
if is_update_op(op): sub_block.append_op(
param = main_block.vars[op.input('Param')[0]] type='elementwise_sub',
if param.is_distributed: inputs={'X': [snapshot],
continue 'Y': [param]},
outputs={'Out': [param]},
snapshot = main_block.create_var( attrs={OP_ROLE_KEY: OpRole.Optimize})
name=self.snapshot_name(param.name), sub_block.append_op(
shape=param.shape, type='c_sync_calc_stream',
persistable=True, inputs={'X': param},
stop_gradient=True, outputs={'Out': param},
dtype=param.dtype) attrs={OP_ROLE_KEY: OpRole.Optimize})
ring_id = (ring_id + 1) % self.nrings
main_block._insert_op( sub_block.append_op(
idx + 1, type='c_allreduce_sum',
type='elementwise_sub', inputs={'X': [param]},
inputs={'X': [snapshot], outputs={'Out': [param]},
'Y': [param]}, attrs={
outputs={'Out': [param]}, 'ring_id': ring_id,
attrs={OP_ROLE_KEY: OpRole.Optimize}) OP_ROLE_KEY: OpRole.Optimize
main_block._insert_op( })
idx + 2,
type='c_sync_calc_stream',
inputs={'X': param},
outputs={'Out': param},
attrs={OP_ROLE_KEY: OpRole.Optimize})
ring_id = (ring_id + 1) % self.nrings
main_block._insert_op(
idx + 3,
type='c_allreduce_sum',
inputs={'X': [param]},
outputs={'Out': [param]},
attrs={
'ring_id': ring_id,
OP_ROLE_KEY: OpRole.Optimize
})
ordered_param_snapshot.append((param, snapshot))
for ring_id in range(self.nrings): for ring_id in range(self.nrings):
main_block.append_op( sub_block.append_op(
type='c_sync_comm_stream', type='c_sync_comm_stream',
inputs={'X': param}, inputs={'X': param},
outputs={'Out': param}, outputs={'Out': param},
...@@ -152,10 +165,8 @@ class LocalSGDOptimizer(MetaOptimizerBase): ...@@ -152,10 +165,8 @@ class LocalSGDOptimizer(MetaOptimizerBase):
OP_ROLE_KEY: OpRole.Optimize OP_ROLE_KEY: OpRole.Optimize
}) })
for param_snapshot in reversed(ordered_param_snapshot): for param, snapshot in p2s:
param = param_snapshot[0] sub_block.append_op(
snapshot = param_snapshot[1]
main_block.append_op(
type='scale', type='scale',
inputs={'X': [param]}, inputs={'X': [param]},
outputs={'Out': [param]}, outputs={'Out': [param]},
...@@ -163,13 +174,13 @@ class LocalSGDOptimizer(MetaOptimizerBase): ...@@ -163,13 +174,13 @@ class LocalSGDOptimizer(MetaOptimizerBase):
'scale': 1.0 / self.role_maker.worker_num(), 'scale': 1.0 / self.role_maker.worker_num(),
OP_ROLE_KEY: OpRole.Optimize OP_ROLE_KEY: OpRole.Optimize
}) })
main_block.append_op( sub_block.append_op(
type='elementwise_sub', type='elementwise_sub',
inputs={'X': [snapshot], inputs={'X': [snapshot],
'Y': [param]}, 'Y': [param]},
outputs={'Out': [param]}, outputs={'Out': [param]},
attrs={OP_ROLE_KEY: OpRole.Optimize}) attrs={OP_ROLE_KEY: OpRole.Optimize})
main_block.append_op( sub_block.append_op(
type='assign', type='assign',
inputs={'X': [param]}, inputs={'X': [param]},
outputs={'Out': [snapshot]}, outputs={'Out': [snapshot]},
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册