# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle.framework import core
from paddle.fluid import layers
from paddle.distributed.fleet.meta_optimizers.common import (
    OpRole,
    OP_ROLE_KEY,
    OP_ROLE_VAR_KEY,
)
from paddle.distributed.auto_parallel.utils import (
    set_var_dist_attr,
    is_optimize_op,
    is_backward_op,
    naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
)
from paddle.distributed.auto_parallel.process_group import (
    get_world_process_group,
)
from paddle.distributed.auto_parallel.operators.common import (
    is_data_parallel_reduce_op,
    is_data_parallel_scale_op,
)

from .pass_base import PassBase, PassType, register_pass

world_process_group = get_world_process_group()


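# gradient clip ops are created under the "/gradient_clip" name scope, which is
# what this helper relies on to recognize them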
def is_gradient_clip_op(op_desc):
    return op_desc.has_attr("op_namescope") and op_desc.attr(
        "op_namescope"
    ).startswith("/gradient_clip")


def _remove_and_get_ops(main_program, dist_context):
    # 1. create a temp block
    # 2. move the optimizer ops and data-parallel allreduce/scale ops from the global block into the temp block
    # 3. delete the moved ops from dist_context
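    # (the moved ops are re-appended later inside the gradient-merge conditional block)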
    main_block = main_program.global_block()
    temp_block = main_program._create_block()
    removed_op_idx = []
    optimize_ops_desc = []
    allreduce_sum_desc = []
    for idx, op in enumerate(main_block.ops):
        # append optimizer op to tmp block
        if is_optimize_op(op):
            new_op_desc = temp_block.desc.append_op()
            new_op_desc.copy_from(op.desc)
            optimize_ops_desc.append(new_op_desc)
            removed_op_idx.append(idx)
            dist_context.del_dist_op_for_program(op)

        # append allreduce_op and scale_op to tmp block
        if is_backward_op(op):
            if is_data_parallel_reduce_op(op) or is_data_parallel_scale_op(op):
                assert len(op.desc.output_arg_names()) == 1
                new_op_desc = temp_block.desc.append_op()
                new_op_desc.copy_from(op.desc)
                allreduce_sum_desc.append(new_op_desc)
                removed_op_idx.append(idx)
                dist_context.del_dist_op_for_program(op)

    for idx in removed_op_idx[::-1]:
        main_block._remove_op(idx, sync=False)
    main_block._sync_with_cpp()

    return optimize_ops_desc, allreduce_sum_desc


def _create_gm_cond_var(main_program, k_steps, dist_context):
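    # cond_var becomes True once every k_steps iterations:
    #   step = (step + 1) % k_steps;  cond = (step == 0)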
    main_block = main_program.global_block()
    # Add const var
    k_step_var = layers.create_global_var(
        name="gradient_merge_k",
        shape=[1],
        value=int(k_steps),
        dtype='int32',
        persistable=True,
        force_cpu=True,
    )
    set_var_dist_attr(dist_context, k_step_var, [-1], world_process_group.ranks)

    zero_var = layers.create_global_var(
        name="gradient_merge_zero",
        shape=[1],
        value=int(0),
        dtype='int32',
        persistable=True,
        force_cpu=True,
    )
    set_var_dist_attr(dist_context, zero_var, [-1], world_process_group.ranks)

    # Add step var & cond var
    step_var = layers.create_global_var(
        name="gradient_merge_step",
        shape=[1],
        value=int(0),
        dtype='int32',
        persistable=True,
        force_cpu=True,
    )
    set_var_dist_attr(dist_context, step_var, [-1], world_process_group.ranks)

    cond_var = main_block.create_var(
        name="gradient_merge_cond", shape=[1], dtype='bool'
    )
    set_var_dist_attr(dist_context, cond_var, [-1], world_process_group.ranks)

    with paddle.static.device_guard("cpu"):
        # step_var += 1
        increment_op = main_block.append_op(
            type='increment',
            inputs={'X': [step_var]},
            outputs={'Out': [step_var]},
            attrs={'step': float(1.0), OP_ROLE_KEY: OpRole.Backward},
        )
        naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
            increment_op, world_process_group.ranks, [-1], dist_context
        )
        # step_var %= k_step
        elementwise_mod_op = main_block.append_op(
            type='elementwise_mod',
            inputs={'X': step_var, 'Y': k_step_var},
            outputs={'Out': step_var},
            attrs={
                'axis': -1,
                'use_mkldnn': False,
                OP_ROLE_KEY: OpRole.Backward,
            },
        )
        naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
            elementwise_mod_op, world_process_group.ranks, [-1], dist_context
        )
        # cond_var = (step_var == 0)
        equal_op = main_block.append_op(
            type='equal',
            inputs={'X': step_var, 'Y': zero_var},
            outputs={'Out': cond_var},
            attrs={OP_ROLE_KEY: OpRole.Backward},
        )
        naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
            equal_op, world_process_group.ranks, [-1], dist_context
        )

    return cond_var


def _append_gradient_merge_backward_op(
    main_program,
    startup_program,
    params_grads,
    master_grad,
    dist_context,
):

    main_block = main_program.global_block()
    startup_block = startup_program.global_block()

    # step1: check params_grads; SELECTED_ROWS is not supported
    for param, grad in params_grads:
        assert (
            param.type != core.VarDesc.VarType.SELECTED_ROWS
        ), "SELECTED_ROWS is not supported in GradientMergeOptimizer for now"

    # {grad.name: gradient_merge_var.name} to rename opt inputs
    grad_to_gradient_merge = {}
    # {param: gradient_merge_var} to insert scale op and fill_constant op
    new_params_to_grads = []
    # step2: create gradient_merge var and init with 0
    for param, grad in params_grads:
        param_name = param.name
        param_var = main_block.var(param_name)
        assert param_var is not None

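        # with master_grad, accumulate in FP32 even when the gradient itself is FP16/BF16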
        dst_dtype = (
            core.VarDesc.VarType.FP32 if master_grad else param_var.dtype
        )

        # 2.1 create param@GRAD@MERGED var in startup_block
        startup_gradient_merge_var = startup_block.create_var(
            name=param_name + "@GRAD@MERGED",
            shape=param_var.shape,
            dtype=dst_dtype,
            persistable=True,
        )
        startup_block.append_op(
            type="fill_constant",
            outputs={"Out": startup_gradient_merge_var},
            attrs={
                "shape": param_var.shape,
                "dtype": dst_dtype,
                "value": float(0),
            },
        )

        # 2.2 create param@GRAD@MERGED var in main_block
        ref_dist_attr = dist_context.get_tensor_dist_attr_for_program(param_var)
        assert ref_dist_attr is not None
        gradient_merge_var = main_block.create_var(
            name=param_name + "@GRAD@MERGED",
            shape=param_var.shape,
            dtype=dst_dtype,
            persistable=True,
        )
        ref_process_mesh = ref_dist_attr.process_mesh
        ref_dims_mapping = ref_dist_attr.dims_mapping
        set_var_dist_attr(
            dist_context, gradient_merge_var, ref_dims_mapping, ref_process_mesh
        )

        # 2.3 grad_merge += grad
        grad_name = grad.name
        if grad.dtype != dst_dtype:
            cast_grad_name = grad_name + "@TMP"
            cast_grad_var = main_block.create_var(
                name=cast_grad_name,
                shape=grad.shape,
                dtype=dst_dtype,
                persistable=False,
                stop_gradient=grad.stop_gradient,
            )
            set_var_dist_attr(
                dist_context, cast_grad_var, ref_dims_mapping, ref_process_mesh
            )
            cast_op = main_block.append_op(
                type="cast",
                inputs={"X": grad},
                outputs={"Out": cast_grad_var},
                attrs={
                    "in_dtype": grad.dtype,
                    "out_dtype": dst_dtype,
                    OP_ROLE_KEY: OpRole.Backward,
                },
            )
            naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                cast_op, ref_process_mesh, ref_dims_mapping, dist_context
            )
            grad = cast_grad_var

        new_grad_op = main_block.append_op(
            type="elementwise_add",
            inputs={'X': grad, 'Y': gradient_merge_var},
            outputs={'Out': gradient_merge_var},
            attrs={
                'axis': -1,
                'use_mkldnn': False,
                OP_ROLE_KEY: OpRole.Backward,
            },
        )
        new_params_to_grads.append([param, gradient_merge_var])
        grad_to_gradient_merge[grad_name] = gradient_merge_var.name
        naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
            new_grad_op, ref_process_mesh, ref_dims_mapping, dist_context
        )

    return new_params_to_grads, grad_to_gradient_merge


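# rename the inputs/outputs of an op_desc according to var_name_dict
# (e.g. param@GRAD -> param@GRAD@MERGED)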
def _rename_arg_names(op_desc, var_name_dict):
    for input_name in op_desc.input_arg_names():
        if input_name in var_name_dict:
            op_desc._rename_input(input_name, var_name_dict[input_name])

    for output_name in op_desc.output_arg_names():
        if output_name in var_name_dict:
            op_desc._rename_output(output_name, var_name_dict[output_name])


def _create_cond_block_and_update_optimizer(
    main_program,
    cond_var,
    params_grads,
    new_params_to_grads,
    grad_to_gradient_merge,
    optimize_ops_desc,
    allreduce_sum_desc,
    k_steps,
    avg,
    master_grad,
):
    def true_apply_gradient():
        cur_block_idx = main_program.current_block_idx
        cur_block = main_program.current_block()

        # cur_block serves as its own forward_block and backward_block
        cur_block._set_forward_block_idx(cur_block_idx)

        # record gradient names so that only the matching c_allreduce_sum / scale ops are re-inserted
        grads_name = [grad.name for _, grad in params_grads]
        # append c_allreduce_sum ops and scale ops
        for op_desc in allreduce_sum_desc:
            outputs_name = op_desc.output_arg_names()
            assert len(outputs_name) == 1
            if outputs_name[0] in grads_name:
                new_op_desc = cur_block.desc.append_op()
                new_op_desc.copy_from(op_desc)
                _rename_arg_names(new_op_desc, grad_to_gradient_merge)
                new_op_desc._set_attr(OP_ROLE_KEY, OpRole.Optimize)
        cur_block._sync_with_cpp()

        if avg:
            for _, new_grad in new_params_to_grads:
                # grad /= k_steps
                cur_block.append_op(
                    type='scale',
                    inputs={'X': new_grad},
                    outputs={'Out': new_grad},
                    attrs={
                        'scale': 1.0 / k_steps,
                        'bias': 0.0,
                        'bias_after_scale': False,
                    },
                )
                new_grad.op._set_attr(OP_ROLE_KEY, OpRole.Optimize)

        cast_name_dict = {}
        # append optimizer ops
        for op_desc in optimize_ops_desc:
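            # VarDesc dtype enums in the checks below: 4 == FP16, 22 == BF16, 5 == FP32;
            # FP16/BF16 <-> FP32 casts inside gradient clip are dropped and their outputs
            # are remapped to the cast inputs, so everything stays FP32 under master_grad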
            if master_grad and is_gradient_clip_op(op_desc):
                if op_desc.type() == "cast":
                    if (
                        op_desc.attr('out_dtype') in [4, 22]
                        and op_desc.attr('in_dtype') == 5
                    ):
                        cast_name_dict[
                            op_desc.output_arg_names()[0]
                        ] = op_desc.input_arg_names()[0]
                    elif (
                        op_desc.attr('in_dtype') in [4, 22]
                        and op_desc.attr('out_dtype') == 5
                    ):
                        cast_name_dict[
                            op_desc.output_arg_names()[0]
                        ] = op_desc.input_arg_names()[0]
                    continue

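                # the merged gradients are FP32 under master_grad, so force the
                # remaining clip ops' outputs to FP32 as well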
                for out_name in op_desc.output_arg_names():
                    out_var = cur_block._var_recursive(out_name)
                    out_var.desc.set_dtype(core.VarDesc.VarType.FP32)

                _rename_arg_names(op_desc, cast_name_dict)

            new_op_desc = cur_block.desc.append_op()
            new_op_desc.copy_from(op_desc)

            # update input/output
            _rename_arg_names(new_op_desc, grad_to_gradient_merge)

            # remove op_role_var
            if new_op_desc.has_attr(OP_ROLE_VAR_KEY):
                new_op_desc.remove_attr(OP_ROLE_VAR_KEY)

            # redirect the optimizer op's Grad input to the merged gradient
            if core.grad_var_suffix() in new_op_desc.input_arg_names():
                grad_value = new_op_desc.input("Grad")[0]
                # TODO FIXME(xym) support fp16
                grad_merge_value = grad_value + '@MERGED'
                new_op_desc.set_input("Grad", [grad_merge_value])

        cur_block._sync_with_cpp()

        # clear gradient_merge_vars
        for _, new_grad in new_params_to_grads:
            layers.fill_constant(
                shape=new_grad.shape,
                dtype=new_grad.dtype,
                value=0.0,
                out=new_grad,
            )
            new_grad.op._set_attr(OP_ROLE_KEY, OpRole.Optimize)

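    # only run the apply-gradient block when cond_var is True, i.e. once every k_steps steps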
    layers.cond(cond_var, true_fn=true_apply_gradient, false_fn=None)
    cond_op = main_program.global_block().ops[-1]
    cond_op._set_attr(OP_ROLE_KEY, OpRole.Optimize)


def parse_program(
    main_program,
    startup_program,
    params_grads,
    k_steps,
    avg,
    master_grad,
    dist_context,
):
    # 1 remove optimizer_op, allreduce_sum_op and scale_op from main_program
    optimize_ops_desc, allreduce_sum_desc = _remove_and_get_ops(
        main_program, dist_context
    )

    # roll back to block 0 (a temp block was created in _remove_and_get_ops)
    main_program._rollback()

    # 2 append gradient merge backward op to main_program
    (
        new_params_to_grads,
        grad_to_gradient_merge,
    ) = _append_gradient_merge_backward_op(
        main_program, startup_program, params_grads, master_grad, dist_context
    )

    # 3 create gradient_merge_cond
    cond_var = _create_gm_cond_var(main_program, k_steps, dist_context)

    # 4 create ConditionalBlock and append gradient merge optimizer ops
    _create_cond_block_and_update_optimizer(
        main_program,
        cond_var,
        params_grads,
        new_params_to_grads,
        grad_to_gradient_merge,
        optimize_ops_desc,
        allreduce_sum_desc,
        k_steps,
        avg,
        master_grad,
    )


@register_pass("auto_parallel_gradient_merge_pass")
class GradientMergePass(PassBase):
    def __init__(self):
        super(GradientMergePass, self).__init__()
        self.set_attr("k_steps", -1)
        self.set_attr("avg", True)
        self.set_attr("master_grad", False)

    def _check_self(self):
        if self.get_attr("k_steps") < 1:
            return False
        return True

    def _check_conflict(self, other_pass):
        return True

    def _type(self):
        return PassType.COMM_OPT

    def _apply_single_impl(self, main_program, startup_program, context):
        k_steps = self.get_attr("k_steps", -1)
        avg = self.get_attr("avg", False)
        master_grad = self.get_attr("master_grad", False)
        dist_context = self.get_attr("dist_context")
        params_grads = self.get_attr("params_grads")
        # TODO(zyl): make master_grad configurable
        master_grad = True
        with paddle.static.program_guard(main_program, startup_program):
            parse_program(
                main_program,
                startup_program,
                params_grads,
                k_steps,
                avg,
                master_grad,
                dist_context,
            )

        main_program._sync_with_cpp()
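
# Usage sketch (illustrative only, not part of this file): this pass is normally
# created and applied through paddle.distributed.passes.new_pass. The config keys
# below mirror the attributes read in _apply_single_impl; main_program,
# startup_program, dist_context and params_grads are assumed to come from the
# auto-parallel planner/parallelizer.
#
#     from paddle.distributed.passes import new_pass, PassContext
#
#     config = {
#         "k_steps": 4,          # accumulate gradients over 4 micro-steps
#         "avg": True,           # divide the merged gradient by k_steps
#         "dist_context": dist_context,
#         "params_grads": params_grads,
#     }
#     gm_pass = new_pass("auto_parallel_gradient_merge_pass", config)
#     gm_pass.apply([main_program], [startup_program], PassContext())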