# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle.static import (
    default_main_program,
    default_startup_program,
    program_guard,
)

from .common import OP_ROLE_KEY, CollectiveHelper, OpRole
from .meta_optimizer_base import MetaOptimizerBase

__all__ = []


class LocalSGDOptimizer(MetaOptimizerBase):
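    """LocalSGD meta optimizer for static-graph collective training.

    Each worker runs the wrapped SGD/Momentum optimizer independently and the
    workers average their parameters once every ``k_steps`` iterations (after
    a warm-up of ``begin_step`` iterations during which they average on every
    step).

    Minimal usage sketch (illustrative only; ``avg_cost`` and the surrounding
    program/executor setup are placeholders):

        import paddle
        import paddle.distributed.fleet as fleet

        paddle.enable_static()
        fleet.init(is_collective=True)

        strategy = fleet.DistributedStrategy()
        strategy.localsgd = True
        strategy.localsgd_configs = {"k_steps": 4, "begin_step": 30}

        optimizer = paddle.optimizer.SGD(learning_rate=0.01)
        optimizer = fleet.distributed_optimizer(optimizer, strategy)
        optimizer.minimize(avg_cost)
    """
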
    def __init__(self, optimizer):
        super().__init__(optimizer)
        self.inner_opt = optimizer
        self.meta_optimizers_white_list = ['AMPOptimizer']
        self.meta_optimizers_black_list = [
            "GraphExecutionOptimizer",
            "AdaptiveLocalSGDOptimizer",
        ]
        self.snapshot_key = '@SNAPSHOT'

    def _can_apply(self):
        if not self.role_maker._is_collective:
            return False

        if not self.user_defined_strategy.localsgd:
            return False

        if self.role_maker._worker_num() <= 1:
            return False

        return (
            isinstance(self.inner_opt, paddle.optimizer.momentum.Momentum)
            or isinstance(self.inner_opt, paddle.fluid.optimizer.Momentum)
            or isinstance(self.inner_opt, paddle.optimizer.sgd.SGD)
            or isinstance(self.inner_opt, paddle.fluid.optimizer.SGD)
        )

    def _disable_strategy(self, dist_strategy):
        dist_strategy.localsgd = False
        dist_strategy.localsgd_configs = {}

    def _enable_strategy(self, dist_strategy, context):
        dist_strategy.localsgd = True
        dist_strategy.localsgd_configs = {"k_steps": 1, "begin_step": 1}

    def snapshot_name(self, param_name):
        return param_name + self.snapshot_key

    def create_snapshot_vars(self, program):
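        # For every non-distributed parameter, create a persistent "snapshot"
        # variable that stores the parameter's value at the last synchronization.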
        block = program.global_block()

        non_dist_params = []
        for param in block.iter_parameters():
            if not param.is_distributed:
                non_dist_params.append(param)

        p2s = []
        for param in non_dist_params:
            snapshot = block.create_var(
                name=self.snapshot_name(param.name),
                shape=param.shape,
                persistable=True,
                stop_gradient=True,
                dtype=param.dtype,
            )
            p2s.append([param, snapshot])
        return p2s

    def init_snapshot_vars(self, startup_program, param2snapshot):
        with program_guard(startup_program):
            for param, snapshot in param2snapshot:
                paddle.assign(param, snapshot)

    def minimize_impl(
        self, loss, startup_program=None, parameter_list=None, no_grad_set=None
    ):
        minimized = self.inner_opt.minimize(
            loss, startup_program=startup_program
        )

        k_steps_value = self.user_defined_strategy.localsgd_configs['k_steps']
        begin_step_value = self.user_defined_strategy.localsgd_configs[
            'begin_step'
        ]

        if startup_program is None:
            startup_program = default_startup_program()
        main_block = loss.block

        self.nrings = 2
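        # Set up the collective communicators in the startup program; the
        # per-parameter allreduces in communicate() below are spread
        # round-robin over self.nrings communication rings.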
        collective_helper = CollectiveHelper(self.role_maker, self.nrings)
        collective_helper.update_startup_program(startup_program)
        p2s = self.create_snapshot_vars(startup_program)
        self.init_snapshot_vars(startup_program, p2s)

        p2s = self.create_snapshot_vars(main_block.program)
        with program_guard(main_block.program, startup_program):
            step = paddle.fluid.layers.autoincreased_step_counter(begin=1)
            k_steps = paddle.static.create_global_var(
                name="k_steps",
                shape=[1],
                value=k_steps_value,
                dtype='int64',
                persistable=True,
            )

            begin_step = paddle.static.create_global_var(
                name="begin_step",
                shape=[1],
                value=begin_step_value,
                dtype='int64',
                persistable=True,
            )

            last_step = paddle.static.create_global_var(
                name="last_step",
                shape=[1],
                value=begin_step_value,
                dtype='int64',
                persistable=True,
            )

            def communicate():
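                # On-graph parameter averaging with in-place ops, per parameter:
                #   param    <- snapshot - param                  (negated local progress)
                #   param    <- allreduce_sum(param) / worker_num (averaged across workers)
                #   param    <- snapshot - param                  (= average of all workers' params)
                #   snapshot <- param                             (new synchronization point)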
                sub_block = default_main_program().current_block()
                ring_id = -1
                for param, snapshot in p2s:
                    sub_block.append_op(
                        type='elementwise_sub',
                        inputs={'X': [snapshot], 'Y': [param]},
                        outputs={'Out': [param]},
                        attrs={OP_ROLE_KEY: OpRole.Optimize},
                    )
                    sub_block.append_op(
                        type='c_sync_calc_stream',
                        inputs={'X': param},
                        outputs={'Out': param},
                        attrs={OP_ROLE_KEY: OpRole.Optimize},
                    )
                    ring_id = (ring_id + 1) % self.nrings
                    sub_block.append_op(
                        type='c_allreduce_sum',
                        inputs={'X': [param]},
                        outputs={'Out': [param]},
                        attrs={
                            'ring_id': ring_id,
                            OP_ROLE_KEY: OpRole.Optimize,
                        },
                    )

                for ring_id in range(self.nrings):
                    sub_block.append_op(
                        type='c_sync_comm_stream',
                        inputs={'X': param},
                        outputs={'Out': param},
                        attrs={
                            'ring_id': ring_id,
                            OP_ROLE_KEY: OpRole.Optimize,
                        },
                    )

                for param, snapshot in p2s:
                    sub_block.append_op(
                        type='scale',
                        inputs={'X': [param]},
                        outputs={'Out': [param]},
                        attrs={
                            'scale': 1.0 / self.role_maker._worker_num(),
                            OP_ROLE_KEY: OpRole.Optimize,
                        },
                    )
                    sub_block.append_op(
                        type='elementwise_sub',
                        inputs={'X': [snapshot], 'Y': [param]},
                        outputs={'Out': [param]},
                        attrs={OP_ROLE_KEY: OpRole.Optimize},
                    )
                    sub_block.append_op(
                        type='assign',
                        inputs={'X': [param]},
                        outputs={'Out': [snapshot]},
                        attrs={OP_ROLE_KEY: OpRole.Optimize},
                    )
                paddle.assign(step, last_step)

            def begin_localsgd():
                paddle.static.nn.cond(step - last_step == k_steps, communicate)

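            # Warm-up: until step reaches begin_step, parameters are averaged on
            # every iteration; afterwards they are averaged only when k_steps
            # local iterations have elapsed since the last synchronization.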
            paddle.static.nn.cond(
                step > begin_step, begin_localsgd, communicate
            )
        return minimized


class AdaptiveLocalSGDOptimizer(MetaOptimizerBase):
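    """Adaptive variant of LocalSGD for static-graph collective training.

    Instead of a fixed ``k_steps``, the number of local iterations between
    parameter averagings is recomputed at every synchronization as roughly
    ceil(sqrt(init_k_steps * (lr_0 / lr_t) * (loss_t / loss_0))), then clipped
    to the range [1, 16], so communication becomes sparser as training
    converges.

    Minimal usage sketch (illustrative only; ``avg_cost`` and the surrounding
    program/executor setup are placeholders):

        import paddle
        import paddle.distributed.fleet as fleet

        paddle.enable_static()
        fleet.init(is_collective=True)

        strategy = fleet.DistributedStrategy()
        strategy.adaptive_localsgd = True
        strategy.adaptive_localsgd_configs = {"init_k_steps": 1, "begin_step": 30}

        optimizer = paddle.optimizer.SGD(learning_rate=0.01)
        optimizer = fleet.distributed_optimizer(optimizer, strategy)
        optimizer.minimize(avg_cost)
    """
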
    def __init__(self, optimizer):
        super().__init__(optimizer)
        self.inner_opt = optimizer
        self.meta_optimizers_white_list = ['AMPOptimizer']
        self.meta_optimizers_black_list = [
            "GraphExecutionOptimizer",
            "LocalSGDOptimizer",
        ]
        self.snapshot_key = '@SNAPSHOT'

    def _can_apply(self):
        if not self.role_maker._is_collective:
            return False

        if not self.user_defined_strategy.adaptive_localsgd:
            return False

        if self.role_maker._worker_num() <= 1:
            return False

        return (
            isinstance(self.inner_opt, paddle.optimizer.Momentum)
            or isinstance(self.inner_opt, paddle.fluid.optimizer.Momentum)
            or isinstance(self.inner_opt, paddle.optimizer.sgd.SGD)
            or isinstance(self.inner_opt, paddle.fluid.optimizer.SGD)
        )

    def _disable_strategy(self, dist_strategy):
        dist_strategy.adaptive_localsgd = False
        dist_strategy.adaptive_localsgd_configs = {}

    def _enable_strategy(self, dist_strategy, context):
        dist_strategy.adaptive_localsgd = True
        dist_strategy.adaptive_localsgd_configs = {
            "init_k_steps": 1,
            "begin_step": 1,
        }

    def snapshot_name(self, param_name):
        return param_name + self.snapshot_key

    def create_snapshot_vars(self, program):
        block = program.global_block()

        non_dist_params = []
        for param in block.iter_parameters():
            if not param.is_distributed:
                non_dist_params.append(param)

        p2s = []
        for param in non_dist_params:
            snapshot = block.create_var(
                name=self.snapshot_name(param.name),
                shape=param.shape,
                persistable=True,
                stop_gradient=True,
                dtype=param.dtype,
            )
            p2s.append([param, snapshot])
        return p2s

    def init_snapshot_vars(self, startup_program, param2snapshot):
        with program_guard(startup_program):
            for param, snapshot in param2snapshot:
                paddle.assign(param, snapshot)

    def _generate_avg_loss(self, program_block, loss, avg_loss):
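        # Allreduce the per-worker loss on ring 0 (using the calc stream) and
        # scale by 1/worker_num, so avg_loss holds the loss averaged over all
        # workers.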
        program_block.append_op(
            type='c_allreduce_sum',
            inputs={'X': [loss]},
            outputs={'Out': [avg_loss]},
            attrs={
                'ring_id': 0,
                OP_ROLE_KEY: OpRole.Optimize,
                'use_calc_stream': True,
            },
        )
        program_block.append_op(
            type='c_sync_calc_stream',
            inputs={'X': [avg_loss]},
            outputs={'Out': [avg_loss]},
            attrs={OP_ROLE_KEY: OpRole.Optimize},
        )

        program_block.append_op(
            type='scale',
            inputs={'X': [avg_loss]},
            outputs={'Out': [avg_loss]},
            attrs={
                'scale': 1.0 / self.role_maker._worker_num(),
                OP_ROLE_KEY: OpRole.Optimize,
            },
        )

    def minimize_impl(
        self, loss, startup_program=None, parameter_list=None, no_grad_set=None
    ):
        minimized = self.inner_opt.minimize(
            loss, startup_program=startup_program
        )

        init_k_steps = self.user_defined_strategy.adaptive_localsgd_configs[
            'init_k_steps'
        ]
        begin_step_value = self.user_defined_strategy.adaptive_localsgd_configs[
            'begin_step'
        ]

        if startup_program is None:
            startup_program = default_startup_program()
        main_block = loss.block

        self.nrings = 2
        collective_helper = CollectiveHelper(self.role_maker, self.nrings)
        collective_helper.update_startup_program(startup_program)
        p2s = self.create_snapshot_vars(startup_program)
        self.init_snapshot_vars(startup_program, p2s)

        p2s = self.create_snapshot_vars(main_block.program)
        with program_guard(main_block.program, startup_program):
            step = paddle.fluid.layers.autoincreased_step_counter(begin=1)

            k_steps = paddle.static.create_global_var(
                name="k_steps",
                shape=[1],
                value=int(init_k_steps),
                dtype='int64',
                persistable=True,
            )

            begin_step = paddle.static.create_global_var(
                name="begin_step",
                shape=[1],
                value=int(begin_step_value),
                dtype='int64',
                persistable=True,
            )

            last_step = paddle.static.create_global_var(
                name="last_step",
                shape=[1],
                value=int(0),
                dtype='int64',
                persistable=True,
            )

            avg_loss = paddle.static.create_global_var(
                name="avg_loss",
                shape=[1],
                value=float(0),
                dtype=loss.dtype,
                persistable=True,
            )

            lr_0 = paddle.static.create_global_var(
                name="lr_0",
                shape=[1],
                value=float(0),
                dtype='float32',
                persistable=True,
            )

            loss_0 = paddle.static.create_global_var(
                name="loss_0",
                shape=[1],
                value=float(0),
                dtype='float32',
                persistable=True,
            )

            global_lr = self.inner_opt._global_learning_rate()

            def initialize():
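                # At the first step, record the averaged loss and the learning
                # rate as the baselines loss_0 and lr_0 used by the adaptive rule.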
                self._generate_avg_loss(main_block, loss, avg_loss)
                paddle.assign(avg_loss, loss_0)
                paddle.assign(global_lr, lr_0)

            paddle.static.nn.cond(step == 1, initialize)

            def communicate():
                sub_block = default_main_program().current_block()
                ring_id = -1
                for param, snapshot in p2s:
                    sub_block.append_op(
                        type='elementwise_sub',
                        inputs={'X': [snapshot], 'Y': [param]},
                        outputs={'Out': [param]},
                        attrs={OP_ROLE_KEY: OpRole.Optimize},
                    )
                    sub_block.append_op(
                        type='c_sync_calc_stream',
                        inputs={'X': param},
                        outputs={'Out': param},
                        attrs={OP_ROLE_KEY: OpRole.Optimize},
                    )
                    ring_id = (ring_id + 1) % self.nrings
                    sub_block.append_op(
                        type='c_allreduce_sum',
                        inputs={'X': [param]},
                        outputs={'Out': [param]},
                        attrs={
                            'ring_id': ring_id,
                            OP_ROLE_KEY: OpRole.Optimize,
                        },
                    )

                for ring_id in range(self.nrings):
                    sub_block.append_op(
                        type='c_sync_comm_stream',
                        inputs={'X': param},
                        outputs={'Out': param},
                        attrs={
                            'ring_id': ring_id,
                            OP_ROLE_KEY: OpRole.Optimize,
                        },
                    )

                for param, snapshot in p2s:
                    sub_block.append_op(
                        type='scale',
                        inputs={'X': [param]},
                        outputs={'Out': [param]},
                        attrs={
                            'scale': 1.0 / self.role_maker._worker_num(),
                            OP_ROLE_KEY: OpRole.Optimize,
                        },
                    )
                    sub_block.append_op(
                        type='elementwise_sub',
                        inputs={'X': [snapshot], 'Y': [param]},
                        outputs={'Out': [param]},
                        attrs={OP_ROLE_KEY: OpRole.Optimize},
                    )
                    sub_block.append_op(
                        type='assign',
                        inputs={'X': [param]},
                        outputs={'Out': [snapshot]},
                        attrs={OP_ROLE_KEY: OpRole.Optimize},
                    )
                paddle.assign(step, last_step)

            def communicate_avg_loss():
                communicate()
                self._generate_avg_loss(main_block, loss, avg_loss)

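                # Adaptive rule:
                #   k_new = ceil(sqrt(init_k_steps * (lr_0 / lr_t) * (loss_t / loss_0)))
                # where lr_t is the current learning rate and loss_t the current
                # averaged loss; the result is then clipped to the range [1, 16].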
                next_local_steps = paddle.cast(
                    paddle.ceil(
                        paddle.sqrt(
                            lr_0
                            * avg_loss
                            / (global_lr * loss_0)
                            * float(init_k_steps)
                        )
                    ),
                    dtype='int64',
                )
                max_local_steps = paddle.full(
                    shape=[1], dtype='int64', fill_value=16
                )
                min_local_steps = paddle.full(
                    shape=[1], dtype='int64', fill_value=1
                )
                next_local_steps = paddle.minimum(
                    next_local_steps, max_local_steps
                )
                next_local_steps = paddle.maximum(
                    next_local_steps, min_local_steps
                )
                paddle.assign(next_local_steps, k_steps)

            def begin_localsgd():
                paddle.static.nn.cond(
                    step - last_step == k_steps, communicate_avg_loss
                )

            paddle.static.nn.cond(
                step > begin_step, begin_localsgd, communicate
            )

        return minimized