From 3b2c580a6658ff570980036f2e6c2a0a57134b5b Mon Sep 17 00:00:00 2001
From: Yi Liu
Date: Mon, 17 Aug 2020 18:45:08 +0800
Subject: [PATCH] 【paddle.fleet】make fleet_localsgd_meta_optimizer work
 (#26213)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* make fleet_localsgd_meta_optimizer work

* fix bug in localsgd meta optimizer
---
 .../meta_optimizers/localsgd_optimizer.py     | 107 ++++++++++--------
 1 file changed, 59 insertions(+), 48 deletions(-)

diff --git a/python/paddle/distributed/fleet/meta_optimizers/localsgd_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/localsgd_optimizer.py
index 05a120f8163..9a5c6745164 100644
--- a/python/paddle/distributed/fleet/meta_optimizers/localsgd_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/localsgd_optimizer.py
@@ -14,7 +14,7 @@
 
 from __future__ import print_function
 
-from paddle.fluid import program_guard, layers
+from paddle.fluid import program_guard, layers, default_main_program
 from paddle.fluid.optimizer import Momentum, SGD
 from .meta_optimizer_base import MetaOptimizerBase
 from .common import OpRole, OP_ROLE_KEY, CollectiveHelper, is_update_op
@@ -44,6 +44,30 @@ class LocalSGDOptimizer(MetaOptimizerBase):
     def snapshot_name(self, param_name):
         return param_name + self.snapshot_key
 
+    def create_snapshot_vars(self, program):
+        block = program.global_block()
+
+        non_dist_params = []
+        for param in block.iter_parameters():
+            if not param.is_distributed:
+                non_dist_params.append(param)
+
+        p2s = []
+        for param in non_dist_params:
+            snapshot = block.create_var(
+                name=self.snapshot_name(param.name),
+                shape=param.shape,
+                persistable=True,
+                stop_gradient=True,
+                dtype=param.dtype)
+            p2s.append([param, snapshot])
+        return p2s
+
+    def init_snapshot_vars(self, startup_program, param2snapshot):
+        with program_guard(startup_program):
+            for param, snapshot in param2snapshot:
+                layers.assign(param, snapshot)
+
     def minimize_impl(self,
                       loss,
                       startup_program=None,
@@ -62,8 +86,11 @@ class LocalSGDOptimizer(MetaOptimizerBase):
         self.nrings = 2
         collective_helper = CollectiveHelper(self.role_maker, self.nrings)
         collective_helper.update_startup_program(startup_program)
+        p2s = self.create_snapshot_vars(startup_program)
+        self.init_snapshot_vars(startup_program, p2s)
 
-        with program_guard(main_block.program):
+        p2s = self.create_snapshot_vars(main_block.program)
+        with program_guard(main_block.program, startup_program):
             step = layers.autoincreased_step_counter(begin=0)
             k_steps = layers.create_global_var(
                 name="k_steps",
@@ -79,6 +106,9 @@
                 persistable=True)
 
             if auto_steps:
+                avg_loss = layers.collective._c_allreduce(
+                    loss) / self.role_maker.worker_num()
+
                 lr_0 = layers.create_global_var(
                     name="lr_0",
                     shape=[1],
@@ -101,49 +131,32 @@
                 layers.cond(step == 0, initialize)
 
             def communicate():
-                ordered_param_snapshot = []
+                sub_block = default_main_program().current_block()
                 ring_id = -1
-                for idx, op in reversed(list(enumerate(main_block.ops))):
-                    if is_update_op(op):
-                        param = main_block.vars[op.input('Param')[0]]
-                        if param.is_distributed:
-                            continue
-
-                        snapshot = main_block.create_var(
-                            name=self.snapshot_name(param.name),
-                            shape=param.shape,
-                            persistable=True,
-                            stop_gradient=True,
-                            dtype=param.dtype)
-
-                        main_block._insert_op(
-                            idx + 1,
-                            type='elementwise_sub',
-                            inputs={'X': [snapshot],
-                                    'Y': [param]},
-                            outputs={'Out': [param]},
-                            attrs={OP_ROLE_KEY: OpRole.Optimize})
-                        main_block._insert_op(
-                            idx + 2,
-                            type='c_sync_calc_stream',
-                            inputs={'X': param},
-                            outputs={'Out': param},
-                            attrs={OP_ROLE_KEY: OpRole.Optimize})
-                        ring_id = (ring_id + 1) % self.nrings
-                        main_block._insert_op(
-                            idx + 3,
-                            type='c_allreduce_sum',
-                            inputs={'X': [param]},
-                            outputs={'Out': [param]},
-                            attrs={
-                                'ring_id': ring_id,
-                                OP_ROLE_KEY: OpRole.Optimize
-                            })
-
-                        ordered_param_snapshot.append((param, snapshot))
+                for param, snapshot in p2s:
+                    sub_block.append_op(
+                        type='elementwise_sub',
+                        inputs={'X': [snapshot],
+                                'Y': [param]},
+                        outputs={'Out': [param]},
+                        attrs={OP_ROLE_KEY: OpRole.Optimize})
+                    sub_block.append_op(
+                        type='c_sync_calc_stream',
+                        inputs={'X': param},
+                        outputs={'Out': param},
+                        attrs={OP_ROLE_KEY: OpRole.Optimize})
+                    ring_id = (ring_id + 1) % self.nrings
+                    sub_block.append_op(
+                        type='c_allreduce_sum',
+                        inputs={'X': [param]},
+                        outputs={'Out': [param]},
+                        attrs={
+                            'ring_id': ring_id,
+                            OP_ROLE_KEY: OpRole.Optimize
+                        })
 
                 for ring_id in range(self.nrings):
-                    main_block.append_op(
+                    sub_block.append_op(
                         type='c_sync_comm_stream',
                         inputs={'X': param},
                         outputs={'Out': param},
@@ -152,10 +165,8 @@
                             OP_ROLE_KEY: OpRole.Optimize
                         })
 
-            for param_snapshot in reversed(ordered_param_snapshot):
-                param = param_snapshot[0]
-                snapshot = param_snapshot[1]
-                main_block.append_op(
+            for param, snapshot in p2s:
+                sub_block.append_op(
                     type='scale',
                     inputs={'X': [param]},
                     outputs={'Out': [param]},
                     attrs={
                         'scale': 1.0 / self.role_maker.worker_num(),
                         OP_ROLE_KEY: OpRole.Optimize
                     })
-                main_block.append_op(
+                sub_block.append_op(
                     type='elementwise_sub',
                     inputs={'X': [snapshot],
                             'Y': [param]},
                     outputs={'Out': [param]},
                     attrs={OP_ROLE_KEY: OpRole.Optimize})
-                main_block.append_op(
+                sub_block.append_op(
                     type='assign',
                     inputs={'X': [param]},
                     outputs={'Out': [snapshot]},
                     attrs={OP_ROLE_KEY: OpRole.Optimize})
-- 
GitLab
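
Note on the change: the patch hoists snapshot-variable creation out of communicate() into create_snapshot_vars()/init_snapshot_vars(), so parameter snapshots exist once at graph-build time; communicate() then only appends the subtract/allreduce/scale/assign ops into the cond sub-block it obtains via default_main_program().current_block(), closing over the prebuilt p2s list.

For context, below is a minimal sketch of how a user would exercise this meta optimizer. It assumes the collective fleet API of this period (fleet.init(is_collective=True), DistributedStrategy.localsgd and localsgd_configs with a "k_steps" key); the toy network and hyperparameters are illustrative only, and exact names may differ across Paddle versions.

    # Hypothetical driver script; run one process per device, e.g. via
    # python -m paddle.distributed.launch (flags vary by Paddle version).
    import paddle
    import paddle.fluid as fluid
    import paddle.distributed.fleet as fleet

    paddle.enable_static()          # LocalSGD rewrites static-graph programs
    fleet.init(is_collective=True)  # collective mode, one worker per device

    # Toy static-graph network; any model built on the main program works.
    x = fluid.data(name='x', shape=[None, 32], dtype='float32')
    loss = fluid.layers.mean(fluid.layers.fc(input=x, size=1))

    strategy = fleet.DistributedStrategy()
    strategy.localsgd = True
    # Average parameters across workers every k_steps local updates.
    strategy.localsgd_configs = {"k_steps": 4}

    # The LocalSGD meta optimizer wraps plain SGD/Momentum inner optimizers.
    optimizer = fluid.optimizer.SGD(learning_rate=0.01)
    optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
    optimizer.minimize(loss)  # inserts the snapshot / c_allreduce_sum ops above

Each worker then runs k_steps local SGD updates between synchronizations; at each sync point the ops added by communicate() allreduce the parameter deltas against the snapshots and refresh the snapshots for the next round.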