# partitioner.py
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License

import copy

import paddle.fluid as fluid
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.auto_parallel.operators.common import (
    get_distributed_operator_impl_container,
)
from paddle.fluid import core
from paddle.fluid.framework import Parameter, Program

from .dist_attribute import OperatorDistributedAttribute
from .operators.common import BACKWARD_ONLY_DIST_OPS
from .utils import (
    __no_shape_var_type__,
    is_backward_op,
    is_forward_op,
    is_loss_op,
    is_optimize_op,
)

# Op input names that may legitimately be absent from the serial block being
# partitioned; Partitioner.partition_block matches them by substring.
__varname_not_in_block__ = ["lod_tensor_blocking_queue"]


class Partitioner:
    """
    warning:: Partitioner is experimental and subject to change.

    Partitioner converts a program into another program.
    Given a serial program which has been auto completed with shard
    annotations, the Partitioner converts the serial program into a
    "distributed" program. The Partitioner modifies the serial program in the
    following two ways, which are also the major differences between a serial
    and a distributed program:
        1. partition op: replace a serial op with its corresponding dist op
           inferred from the shard annotation
        2. partition var: if a var is sharded, modify the shape of the var
           according to its shard annotation

    Partitioner is supposed to be called by the auto parallel framework and
    is not supposed to be called directly by users.
    """

    def __init__(self, dist_context, rank_id=0):
        """
        Args:
            dist_context (paddle.fluid.DistributedContext): used to access the
                distributed_attr of var & op. Every Partitioner object could
                maintain its own DistributedContext member and partition the
                program based on that shard scenario.
            rank_id (int): global rank id to which the partitioned distributed
                program belongs.

        Raises:
            TypeError: if ``dist_context`` is not a ``DistributedContext``.
        """
        if not isinstance(dist_context, DistributedContext):
            raise TypeError(
                "dist_context be paddle.fluid.DistributedContext, got %s here"
                % type(dist_context)
            )

        self._dist_context = dist_context
        self._rank_id = rank_id
        # serial var name -> partitioned var name. The same dict object is
        # installed as dist_op_context.varname_mapping in partition().
        self._serial2dist_varname_mapping = {}
        # Appended to serial names to form dist names; currently empty, so
        # partitioned vars keep their serial names.
        self._dist_varname_suffix = ""

    def partition(
        self, serial_main_program, serial_startup_program, params_grads
    ):
        """
        Partition the serial main and startup programs for this rank.

        Args:
            serial_main_program (Program): shard-annotated serial main program.
            serial_startup_program (Program|None): serial startup program;
                ``None`` skips startup partitioning.
            params_grads (list): ``(param, grad)`` pairs of serial vars;
                ``grad`` may be ``None``.

        Returns:
            tuple: ``(partitioned_main_prog, partitioned_startup_prog,
            partitioned_params_grads)``.

        Raises:
            TypeError: if ``serial_main_program`` is not a ``Program``.
            RuntimeError: if a global-block op or a shaped var of the main
                program has no dist annotation.
        """
        if not isinstance(serial_main_program, Program):
            raise TypeError(
                "main_program be paddle.fluid.framework.program, got %s here"
                % type(serial_main_program)
            )

        # check if shard annotated serial program valid
        if not self._is_valid_annotated_program(serial_main_program):
            raise RuntimeError(
                "Not all vars or ops are annotated in main program !"
            )

        # init distop helper
        dist_op_context = self._dist_context.dist_op_context
        dist_op_context.varname_mapping = self._serial2dist_varname_mapping
        dist_op_context.rank_id = self._rank_id

        # partition startup program
        if serial_startup_program is None:
            partitioned_startup_prog = None
        else:
            partitioned_startup_prog = self.partition_startup_program(
                serial_main_program, serial_startup_program
            )
        dist_op_context.dst_startup_program = partitioned_startup_prog

        # partition main program
        (
            partitioned_main_prog,
            partitioned_params_grads,
        ) = self.partition_main_program(serial_main_program, params_grads)

        return (
            partitioned_main_prog,
            partitioned_startup_prog,
            partitioned_params_grads,
        )

    def partition_startup_program(
        self, serial_main_program, serial_startup_program
    ):
        """
        Build the partitioned startup program.

        Every var of the serial startup program is recreated with its
        partitioned shape. The shape is computed from the var of the same
        name in the main program's global block. Every initializer op is
        copied with its output renamed and its ``shape`` attr set to that
        partitioned shape. A dist attr derived from the output var is then
        attached to the copied op.

        Raises:
            TypeError: if ``serial_startup_program`` is not a ``Program``.
        """
        if not isinstance(serial_startup_program, Program):
            raise TypeError(
                "startup_program be paddle.fluid.framework.program, got %s here"
                % type(serial_startup_program)
            )

        partitioned_startup_prog = fluid.Program()
        ref_block = serial_main_program.global_block()
        target_block = partitioned_startup_prog.global_block()
        var2shape = {}
        temp_varname_map = {}

        # tensors
        for var in serial_startup_program.list_vars():
            assert var.persistable
            new_name = var.name + self._dist_varname_suffix
            temp_varname_map[var.name] = new_name
            target_shape = _partition_var(
                self._dist_context, ref_block, target_block, var.name, new_name
            )
            var2shape[new_name] = target_shape

        # ops
        for op in serial_startup_program.global_block().ops:
            # TODO if var not belong to this rank, should be filtered
            output_vars = op.desc.output_arg_names()
            assert (
                len(output_vars) == 1
            ), "initializer should output only ONE variable, but got [{}]".format(
                str(op.desc)
            )
            serial_name = output_vars[0]
            new_name = temp_varname_map[serial_name]
            assert (
                new_name in var2shape
            ), "try to initialize [{}] which is not a persistable var".format(
                serial_name
            )
            new_op_desc = target_block.desc.append_op()
            new_op_desc.copy_from(op.desc)
            new_op_desc._rename_output(serial_name, new_name)
            new_op_desc._set_attr("shape", var2shape[new_name])
            target_block._sync_with_cpp()

            # set distribute attribute
            new_op = target_block.ops[-1]
            assert new_op.type == new_op_desc.type()
            assert new_op.desc == new_op_desc
            # The copied op writes the renamed var, so look it up by that name
            # (identical to the serial name only while the suffix is empty).
            output_var = target_block.var(new_name)
            output_var_attr = (
                self._dist_context.get_tensor_dist_attr_for_program(output_var)
            )
            op_attr = OperatorDistributedAttribute()
            op_attr.process_mesh = output_var_attr.process_mesh
            op_attr.set_output_dims_mapping(
                output_var.name, output_var_attr.dims_mapping
            )
            op_attr.set_input_dims_mapping(
                output_var.name, output_var_attr.dims_mapping
            )
            self._dist_context.set_op_dist_attr_for_program(new_op, op_attr)

        return partitioned_startup_prog

    def partition_main_program(self, serial_main_program, params_and_grads):
        """
        Partition every block of the serial main program.

        1. partition variables
        2. replace local op with corresponding dist op

        Returns:
            tuple: ``(partitioned_main_prog, partitioned_params_and_grads)``.
            The pairs reference vars of the partitioned program; a ``None``
            grad stays ``None``.
        """
        partitioned_main_prog = fluid.Program()
        dist_op_context = self._dist_context.dist_op_context
        dist_op_context.dst_main_program = partitioned_main_prog

        for idx in range(self._dist_context.block_state.nblock):
            ref_block = serial_main_program.blocks[idx]

            if idx == 0:
                target_block = partitioned_main_prog.blocks[0]
            else:
                target_block = partitioned_main_prog._create_block(
                    parent_idx=ref_block.parent_idx
                )
                assert ref_block.idx == target_block.idx
                target_block._set_forward_block_idx(ref_block.forward_block_idx)
            dist_op_context.work_block = target_block
            self.partition_block(ref_block, target_block)

        partitioned_main_prog.current_block_idx = 0

        # should reconnect the block_attr ptr to the correct block
        for block_id in range(self._dist_context.block_state.nblock):
            block = partitioned_main_prog.block(block_id)
            for op in block.ops:
                for attr_name in op.all_attrs():
                    if op.attr_type(attr_name) == core.AttrType.BLOCK:
                        relative_id = op._block_attr_id(attr_name)
                        op._set_attr(
                            attr_name, partitioned_main_prog.block(relative_id)
                        )

        partitioned_params_and_grads = []
        for p, g in params_and_grads:
            assert p.name in self._serial2dist_varname_mapping
            dist_p = self._get_dist_var_by_serial_var(p, partitioned_main_prog)
            if g is None:
                dist_g = None
            else:
                assert g.name in self._serial2dist_varname_mapping
                dist_g = self._get_dist_var_by_serial_var(
                    g, partitioned_main_prog
                )
            partitioned_params_and_grads.append((dist_p, dist_g))

        return partitioned_main_prog, partitioned_params_and_grads

    def partition_block(self, ref_block, target_block):
        """
        Partition the vars and ops of one serial block into ``target_block``.

        Ops are visited in order. Any input/output var not yet partitioned is
        created in ``target_block`` (via ``_partition_var``) and recorded in
        ``self._serial2dist_varname_mapping``. The op is then emitted through
        the forward or backward method of its dist op implementation.

        Raises:
            NotImplementedError: for ops that are neither forward, backward
                nor optimize ops.
        """
        dist_op_context = self._dist_context.dist_op_context
        serial_ops = ref_block.ops

        # Index of the first loss op, or len(serial_ops) when there is none.
        last_fwd_op_idx = next(
            (idx for idx, op in enumerate(serial_ops) if is_loss_op(op)),
            len(serial_ops),
        )

        # init mapping: original op id -> serial op, for ops up to and
        # including last_fwd_op_idx; consumed by _get_dist_op_backward_implement
        forward_op_id2forward_op = {
            op.desc.original_id(): op
            for op in serial_ops[: last_fwd_op_idx + 1]
        }

        # partition
        appended_grad_times = 0
        for idx, op in enumerate(serial_ops):
            op_dist_attr = self._dist_context.get_op_dist_attr_for_program(op)
            # Count (non-recompute) forward/loss -> backward transitions; the
            # count selects grad_var_to_var below. ``idx > 0`` prevents
            # serial_ops[idx - 1] from wrapping around to the last op.
            if (
                idx > 0
                and is_backward_op(op)
                and (
                    is_forward_op(serial_ops[idx - 1])
                    or is_loss_op(serial_ops[idx - 1])
                )
            ):
                if not op_dist_attr.is_recompute:
                    appended_grad_times += 1

            # partition input variables
            for serial_input_varname in op.desc.input_arg_names():
                if serial_input_varname in self._serial2dist_varname_mapping:
                    continue
                new_varname = serial_input_varname + self._dist_varname_suffix
                if ref_block.has_var(serial_input_varname):
                    _partition_var(
                        self._dist_context,
                        ref_block,
                        target_block,
                        serial_input_varname,
                        new_varname,
                    )
                else:
                    # Only the known special names may be missing from the
                    # block; any one of them matching is enough.
                    assert any(
                        varname_not_in_block in serial_input_varname
                        for varname_not_in_block in __varname_not_in_block__
                    ), "{} is not found".format(serial_input_varname)
                self._serial2dist_varname_mapping[
                    serial_input_varname
                ] = new_varname

            # partition output vars
            for serial_output_varname in op.desc.output_arg_names():
                if serial_output_varname in self._serial2dist_varname_mapping:
                    continue
                new_varname = serial_output_varname + self._dist_varname_suffix
                _partition_var(
                    self._dist_context,
                    ref_block,
                    target_block,
                    serial_output_varname,
                    new_varname,
                )
                self._serial2dist_varname_mapping[
                    serial_output_varname
                ] = new_varname

            # partition op
            if is_forward_op(op) or op_dist_attr.is_recompute:
                kinputs, koutputs = dist_op_context.prepare_context(op)
                dist_op_forward_impl = _get_dist_op_forward_implement(
                    op, self._dist_context
                )
                dist_op_forward_impl.forward(
                    self._dist_context, **kinputs, **koutputs
                )

            elif is_backward_op(op):
                kinputs, koutputs = dist_op_context.prepare_context(op)
                dist_op_backward_impl = _get_dist_op_backward_implement(
                    op, self._dist_context, forward_op_id2forward_op
                )
                grad_var_to_var = (
                    self._dist_context.dist_op_context.grad_var_to_var[
                        appended_grad_times
                    ]
                )
                dist_op_backward_impl.backward(
                    self._dist_context,
                    **kinputs,
                    **koutputs,
                    grad_var_to_var=grad_var_to_var
                )

            elif is_optimize_op(op):
                # NOTE: BACKWARD_ONLY_DIST_OPS's op_role must 2 because of 1F1B PASS
                kinputs, koutputs = dist_op_context.prepare_context(op)
                dist_op_opt_impl = _get_dist_op_backward_implement(
                    op, self._dist_context, forward_op_id2forward_op
                )
                dist_op_opt_impl.backward(
                    self._dist_context,
                    **kinputs,
                    **koutputs,
                    grad_var_to_var={}
                )

            else:
                raise NotImplementedError(
                    "partitioner only support forward and backward, optimize ops, but got {}".format(
                        str(op)
                    )
                )

    def _is_valid_annotated_program(self, program):
        """
        Return True when every op of the global block and every var whose
        type is not in ``__no_shape_var_type__`` has a dist attr.
        """
        # TODO (ZJ-LIANG) should check all block
        ops = program.global_block().ops
        vars_ = program.list_vars()
        op_dist_attrs = [
            self._dist_context.get_op_dist_attr_for_program(op) for op in ops
        ]
        var_dist_attrs = [
            self._dist_context.get_tensor_dist_attr_for_program(var)
            for var in vars_
            if var.type not in __no_shape_var_type__
        ]

        all_ops_annotated = all(
            dist_attr is not None for dist_attr in op_dist_attrs
        )
        all_vars_annotated = all(
            dist_attr is not None for dist_attr in var_dist_attrs
        )

        return all_ops_annotated and all_vars_annotated

    def _get_dist_var_by_serial_var(self, serial_var, partitioned_main_prog):
        """
        Return the var of ``partitioned_main_prog`` that ``serial_var`` was
        partitioned into. The lookup uses the block with the same index as
        ``serial_var``'s block.

        Raises:
            KeyError: if ``serial_var`` has not been partitioned.
            AssertionError: if the mapped name is missing from that block.
        """
        target_block = partitioned_main_prog.blocks[serial_var.block.idx]
        dist_var_name = self._serial2dist_varname_mapping[serial_var.name]
        assert target_block.has_var(
            dist_var_name
        ), "partitioned var [{}] not found in block [{}]".format(
            dist_var_name, serial_var.block.idx
        )
        return target_block.var(dist_var_name)


def _get_dist_shape(var, dist_attr):

    var_shape = var.shape
393 394
    mapping = dist_attr.dims_mapping
    mesh = dist_attr.process_mesh.topology
395 396 397
    if mapping == []:
        return var_shape

398 399 400
    assert len(var_shape) == len(
        mapping
    ), "variable shape [{}] and dim_mapping [{}] is NOT match !".format(
401 402
        var_shape, mapping
    )
403 404 405 406 407
    new_shape = []
    for idx in range(len(var_shape)):
        if var_shape[idx] == -1 or mapping[idx] == -1:
            new_shape.append(var_shape[idx])
        else:
408 409 410 411 412
            assert (
                var_shape[idx] % mesh[mapping[idx]] == 0
            ), "un-event partition: var_shape[idx]=[{}], mesh[{}]".format(
                var_shape[idx], mesh[mapping[idx]]
            )
413 414 415 416 417
            new_shape.append(var_shape[idx] // mesh[mapping[idx]])

    return new_shape


def _partition_parameter(
    dist_context, src_var, dst_block, dst_varname, dst_shape
):
    """
    Create a new ``Parameter`` named ``dst_varname`` in ``dst_block``.

    The new parameter has shape ``dst_shape`` and copies its type, dtype and
    the remaining var/parameter attributes from ``src_var``.
    ``dist_context`` is not used here; the signature mirrors
    ``_partition_intermediate_var``.

    Returns:
        The newly created ``Parameter``.
    """
    # NOTE hack to copy a Parameter: the new parameter is not initialized
    # here, so its Parameter-only attributes are copied explicitly.
    copied_kwargs = {
        'trainable': src_var.trainable,
        'optimize_attr': src_var.optimize_attr,
        'regularizer': src_var.regularizer,
        'do_model_average': src_var.do_model_average,
        'need_clip': src_var.need_clip,
    }

    return Parameter(
        block=dst_block,
        type=src_var.type,
        name=dst_varname,
        shape=dst_shape,
        dtype=src_var.dtype,
        lod_level=src_var.lod_level,
        error_clip=src_var.error_clip,
        stop_gradient=src_var.stop_gradient,
        is_data=src_var.is_data,
        belong_to_optimizer=src_var.belong_to_optimizer,
        **copied_kwargs
    )


447 448 449 450 451 452 453 454 455 456 457 458 459 460 461
def _partition_intermediate_var(
    dist_context, src_var, dst_block, dst_varname, dst_shape
):
    var = dst_block.create_var(
        type=src_var.type,
        name=dst_varname,
        shape=dst_shape,
        dtype=src_var.dtype,
        lod_level=src_var.lod_level,
        persistable=src_var.persistable,
        error_clip=src_var.error_clip,
        stop_gradient=src_var.stop_gradient,
        is_data=src_var.is_data,
        belong_to_optimizer=src_var.belong_to_optimizer,
    )
462

463
    return var
464 465


def _partition_var(
    dist_context, src_block, dst_block, src_varname, dst_varname
):
    """
    Create the partitioned counterpart of ``src_varname`` in ``dst_block``.

    partition include: split + replicate.

    - Vars whose type is in ``__no_shape_var_type__`` are recreated without a
      shape, with ``stop_gradient=True``.
    - ``Parameter`` vars go through ``_partition_parameter``.
    - All other vars go through ``_partition_intermediate_var``.

    The last two use the shape from ``_get_dist_shape``. A deep copy of the
    source var's tensor dist attr is attached to the new var.

    Returns:
        The partitioned shape, or ``None`` for no-shape var types.

    Raises:
        AssertionError: if the source var has no tensor dist attr.
    """
    src_var = src_block.var(src_varname)
    # Fetched once and checked before use: previously a missing attr crashed
    # inside _get_dist_shape before the intended assertion was reached.
    src_dist_attr = dist_context.get_tensor_dist_attr_for_program(src_var)
    assert (
        src_dist_attr is not None
    ), "var [{}] has no tensor dist attr".format(src_varname)

    if src_var.type in __no_shape_var_type__:
        persist = getattr(src_var, 'persistable', False)
        new_var = dst_block.create_var(
            type=src_var.type,
            name=dst_varname,
            persistable=persist,
            stop_gradient=True,
        )
        target_shape = None
    else:
        target_shape = _get_dist_shape(src_var, src_dist_attr)
        if isinstance(src_var, Parameter):
            new_var = _partition_parameter(
                dist_context, src_var, dst_block, dst_varname, target_shape
            )
        else:
            new_var = _partition_intermediate_var(
                dist_context, src_var, dst_block, dst_varname, target_shape
            )

    # deep copy so later edits to the new var's attr do not alias the source
    dist_context.set_tensor_dist_attr_for_program(
        new_var, copy.deepcopy(src_dist_attr)
    )

    return target_shape


def _get_dist_op_backward_implement(
    backward_op, dist_context, forward_op_id2forward_op
):
    """
    Select the dist op implementation used to partition ``backward_op``.

    The partitioner also calls this for optimize ops. Resolution order:
        1. If ``dist_op_context.grad_op_id_to_op_id`` maps the op's original
           id to a forward op id, reuse the implementation (``impl_type`` /
           ``impl_idx``) recorded on that forward op's dist attr.
        2. If the op type is in ``BACKWARD_ONLY_DIST_OPS``, use the
           implementation recorded on the op's own dist attr.
        3. Otherwise fall back to implementation 0 of the "default"
           container.
    """
    dist_op_context = dist_context.dist_op_context
    backward_op_id = backward_op.desc.original_id()

    if backward_op_id in dist_op_context.grad_op_id_to_op_id:
        forward_op_id = dist_op_context.grad_op_id_to_op_id[backward_op_id]
        forward_op = forward_op_id2forward_op[forward_op_id]
        forward_op_dist_attr = dist_context.get_op_dist_attr_for_program(
            forward_op
        )
        dist_op_impl_container = get_distributed_operator_impl_container(
            forward_op_dist_attr.impl_type
        )
        return dist_op_impl_container.get_impl(forward_op_dist_attr.impl_idx)

    # NOTE trick for dist ops that only have backward implement
    if backward_op.type in BACKWARD_ONLY_DIST_OPS:
        op_dist_attr = dist_context.get_op_dist_attr_for_program(backward_op)
        assert op_dist_attr.impl_idx >= 0
        return get_distributed_operator_impl_container(
            op_dist_attr.impl_type
        ).get_impl(op_dist_attr.impl_idx)

    dist_op = get_distributed_operator_impl_container("default")
    return dist_op.get_impl(0)


def _get_dist_op_forward_implement(forward_op, dist_context):
    """
    Return the dist op implementation selected for ``forward_op``.

    The container is looked up by the ``impl_type`` recorded on the op's
    dist attr in ``dist_context``. The implementation is taken from that
    container by the recorded ``impl_idx``.
    """
    dist_attr = dist_context.get_op_dist_attr_for_program(forward_op)
    dist_op_impl_container = get_distributed_operator_impl_container(
        dist_attr.impl_type
    )
    return dist_op_impl_container.get_impl(dist_attr.impl_idx)