# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License

import copy
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.framework import Parameter, Program
from paddle.distributed.auto_parallel.operators.common import (
    get_distributed_operator_impl_container,
)
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from .dist_attribute import OperatorDistributedAttribute
from .operators.common import BACKWARD_ONLY_DIST_OPS
from .utils import (
    is_backward_op,
    is_forward_op,
    is_loss_op,
    is_optimize_op,
    __no_shape_var_type__,
)

__varname_not_in_block__ = ["lod_tensor_blocking_queue"]


class Partitioner:
    """
    Warning: Partitioner is experimental and subject to change.

    Partitioner converts a program into another program.
    Given a serial program that has been auto-completed with shard annotations, the Partitioner
    converts the serial program into a "distributed" program. The Partitioner modifies the serial
    program in the following two ways, which is also the major difference between a serial and a distributed program:
        1. partition op: replace a serial op with its corresponding dist op inferred from the shard annotation
        2. partition var: if a var is sharded, modify the shape of the var according to its shard annotation

    Partitioner is supposed to be called by the auto parallel framework, and is not supposed to be called directly by the user.
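
    The snippet below is only an illustrative sketch of how the framework is expected
    to drive the Partitioner; the names ``dist_context``, ``serial_main_prog``,
    ``serial_startup_prog`` and ``params_grads`` are placeholders, not objects defined
    in this module:

    .. code-block:: python

        partitioner = Partitioner(dist_context, rank_id=0)
        (
            dist_main_prog,
            dist_startup_prog,
            dist_params_grads,
        ) = partitioner.partition(
            serial_main_prog, serial_startup_prog, params_grads
        )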
    """

    def __init__(self, dist_context, rank_id=0):
        """
        Args:
            dist_context (paddle.distributed.auto_parallel.dist_context.DistributedContext): used to access the distributed_attr of var & op. Every Partitioner object could maintain its own DistributedContext member, and partition the program based on that shard scenario.
            rank_id (int): global rank id to which the partitioned distributed program belongs.
        """
        if not isinstance(dist_context, DistributedContext):
            raise TypeError(
                "dist_context should be a DistributedContext, got %s here"
                % type(dist_context)
            )

        self._dist_context = dist_context
        self._rank_id = rank_id
        self._serial2dist_varname_mapping = {}
        self._dist_varname_suffix = ""

    def partition(
        self, serial_main_program, serial_startup_program, params_grads
    ):
        if not isinstance(serial_main_program, (Program)):
            raise TypeError(
                "serial_main_program should be a paddle.fluid.framework.Program, got %s here"
                % type(serial_main_program)
            )

        # check whether the shard-annotated serial program is valid
        if not self._is_valid_annotated_program(serial_main_program):
            raise RuntimeError(
                "Not all vars or ops are annotated in the main program!"
            )

        # init dist op helper
        dist_op_context = self._dist_context.dist_op_context
        dist_op_context.varname_mapping = self._serial2dist_varname_mapping
        dist_op_context.rank_id = self._rank_id

        # partition startup program
        if serial_startup_program is None:
            partitioned_startup_prog = None
        else:
            partitioned_startup_prog = self.partition_startup_program(
                serial_main_program, serial_startup_program
            )
        dist_op_context.dst_startup_program = partitioned_startup_prog

        # partition main program
        (
            partitioned_main_prog,
            partitioned_params_grads,
        ) = self.partition_main_program(serial_main_program, params_grads)

        return (
            partitioned_main_prog,
            partitioned_startup_prog,
            partitioned_params_grads,
        )

    def partition_startup_program(
        self, serial_main_program, serial_startup_program
    ):

        if not isinstance(serial_startup_program, (Program)):
            raise TypeError(
                "serial_startup_program should be a paddle.fluid.framework.Program, got %s here"
                % type(serial_startup_program)
            )

        partitioned_startup_prog = fluid.Program()
        ref_block = serial_main_program.global_block()
        target_block = partitioned_startup_prog.global_block()
        var2shape = {}
        temp_varname_map = {}

        # tensors
        for var in serial_startup_program.list_vars():
            assert var.persistable
            new_name = var.name + self._dist_varname_suffix
            temp_varname_map[var.name] = new_name
            target_shape = _partition_var(
                self._dist_context, ref_block, target_block, var.name, new_name
            )
            var2shape[new_name] = target_shape

        # ops
        for op in serial_startup_program.global_block().ops:
            # TODO: vars that do not belong to this rank should be filtered out
            output_vars = op.desc.output_arg_names()
            assert (
                len(output_vars) == 1
            ), "initializer should output only ONE variable, but got [{}]".format(
                str(op.desc)
            )
            assert (
                temp_varname_map[output_vars[0]] in var2shape
            ), "try to initialize [{}] which is not a persistable var".format(
                output_vars[0]
            )
            new_op_desc = target_block.desc.append_op()
            new_op_desc.copy_from(op.desc)
            new_op_desc._rename_output(
                output_vars[0], temp_varname_map[output_vars[0]]
            )
            new_op_desc._set_attr(
                "shape", var2shape[temp_varname_map[output_vars[0]]]
            )
            target_block._sync_with_cpp()

            # set distributed attribute
            new_op = target_block.ops[-1]
            assert new_op.type == new_op_desc.type()
            assert new_op.desc == new_op_desc
            output_var = target_block.var(output_vars[0])
            output_var_attr = (
                self._dist_context.get_tensor_dist_attr_for_program(output_var)
            )
            op_attr = OperatorDistributedAttribute()
            op_attr.process_mesh = output_var_attr.process_mesh
            op_attr.set_output_dims_mapping(
                output_var.name, output_var_attr.dims_mapping
            )
            op_attr.set_input_dims_mapping(
                output_var.name, output_var_attr.dims_mapping
            )
            self._dist_context.set_op_dist_attr_for_program(new_op, op_attr)

        return partitioned_startup_prog

    def partition_main_program(self, serial_main_program, params_and_grads):
        """
        1. partition variables
        2. replace local op with corresponding dist op
        """

        partitioned_main_prog = fluid.Program()
        dist_op_context = self._dist_context.dist_op_context
        dist_op_context.dst_main_program = partitioned_main_prog

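        # partition each block of the serial program; block indices and
        # parent/forward-block links are kept aligned with the serial program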
        for idx in range(self._dist_context.block_state.nblock):
            ref_block = serial_main_program.blocks[idx]

            if idx == 0:
                target_block = partitioned_main_prog.blocks[0]
            else:
                target_block = partitioned_main_prog._create_block(
                    parent_idx=ref_block.parent_idx
                )
                assert ref_block.idx == target_block.idx
                target_block._set_forward_block_idx(ref_block.forward_block_idx)
            dist_op_context.work_block = target_block
            self.partition_block(ref_block, target_block)

        partitioned_main_prog.current_block_idx = 0

        # should reconnect the block_attr ptr to the correct block
        for block_id in range(self._dist_context.block_state.nblock):
            block = partitioned_main_prog.block(block_id)
            for op in block.ops:
                for attr_name in op.all_attrs():
                    if op.attr_type(attr_name) == core.AttrType.BLOCK:
                        relative_id = op._block_attr_id(attr_name)
                        op._set_attr(
                            attr_name, partitioned_main_prog.block(relative_id)
                        )

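        # map the serial (param, grad) pairs to their partitioned counterparts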
        partitioned_params_and_grads = []
        for p, g in params_and_grads:
            assert p.name in self._serial2dist_varname_mapping
            dist_p = self._get_dist_var_by_serial_var(p, partitioned_main_prog)
            if g is None:
                dist_g = None
            else:
                assert g.name in self._serial2dist_varname_mapping
                dist_g = self._get_dist_var_by_serial_var(
                    g, partitioned_main_prog
                )
            partitioned_params_and_grads.append((dist_p, dist_g))

        return partitioned_main_prog, partitioned_params_and_grads

    def partition_block(self, ref_block, target_block):

        dist_op_context = self._dist_context.dist_op_context
        serial_ops = ref_block.ops

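        # locate the loss op: every op up to and including it is treated as a
        # forward op when building the forward-op lookup table below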
        last_fwd_op_idx = -1
        for idx, op in enumerate(ref_block.ops):
            if is_loss_op(op):
                last_fwd_op_idx = idx
                break

        if last_fwd_op_idx == -1:
            last_fwd_op_idx = len(ref_block.ops)

        # init mapping
        forward_op_id2forward_op = {}
        for idx in range(len(serial_ops)):
            if idx <= last_fwd_op_idx:
                forward_op_id2forward_op[
                    serial_ops[idx].desc.original_id()
                ] = serial_ops[idx]

        # partition
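        # appended_grad_times indexes dist_op_context.grad_var_to_var; it advances each
        # time a non-recompute backward section starts right after a forward/loss op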
        appended_grad_times = 0
        for idx, op in enumerate(serial_ops):

            op_dist_attr = self._dist_context.get_op_dist_attr_for_program(op)
            if is_backward_op(op) and (
                is_forward_op(serial_ops[idx - 1])
                or is_loss_op(serial_ops[idx - 1])
            ):
                if not op_dist_attr.is_recompute:
                    appended_grad_times += 1

            # partition input variables
            for serial_input_varname in op.desc.input_arg_names():
                if (
                    serial_input_varname
                    not in self._serial2dist_varname_mapping
                ):
                    new_varname = (
                        serial_input_varname + self._dist_varname_suffix
                    )
                    if ref_block.has_var(serial_input_varname):
                        _partition_var(
                            self._dist_context,
                            ref_block,
                            target_block,
                            serial_input_varname,
                            new_varname,
                        )
                    else:
                        for varname_not_in_block in __varname_not_in_block__:
                            assert (
                                varname_not_in_block in serial_input_varname
                            ), "{} is not found".format(serial_input_varname)

                    self._serial2dist_varname_mapping[
                        serial_input_varname
                    ] = new_varname

            # partition output vars
            for serial_output_varname in op.desc.output_arg_names():
                if (
                    serial_output_varname
                    not in self._serial2dist_varname_mapping
                ):
                    new_varname = (
                        serial_output_varname + self._dist_varname_suffix
                    )
                    _partition_var(
                        self._dist_context,
                        ref_block,
                        target_block,
                        serial_output_varname,
                        new_varname,
                    )
                    self._serial2dist_varname_mapping[
                        serial_output_varname
                    ] = new_varname

            # partition op: dispatch to the corresponding dist op implementation
            if is_forward_op(op) or op_dist_attr.is_recompute:
                kinputs, koutputs = dist_op_context.prepare_context(op)
                dist_op_forward_impl = _get_dist_op_forward_implement(
                    op, self._dist_context
                )
                dist_op_forward_impl.forward(
                    self._dist_context, **kinputs, **koutputs
                )

            elif is_backward_op(op):
                kinputs, koutputs = dist_op_context.prepare_context(op)
                dist_op_backward_impl = _get_dist_op_backward_implement(
                    op, self._dist_context, forward_op_id2forward_op
                )
                grad_var_to_var = (
                    self._dist_context.dist_op_context.grad_var_to_var[
                        appended_grad_times
                    ]
                )
                dist_op_backward_impl.backward(
                    self._dist_context,
                    **kinputs,
                    **koutputs,
                    **{"grad_var_to_var": grad_var_to_var}
                )
            elif is_optimize_op(op):
                # NOTE: BACKWARD_ONLY_DIST_OPS's op_role must be 2 because of the 1F1B pass
                kinputs, koutputs = dist_op_context.prepare_context(op)
                dist_op_opt_impl = _get_dist_op_backward_implement(
                    op, self._dist_context, forward_op_id2forward_op
                )
                dist_op_opt_impl.backward(
                    self._dist_context,
                    **kinputs,
                    **koutputs,
                    **{"grad_var_to_var": {}}
                )
            else:
                raise NotImplementedError(
                    "partitioner only supports forward, backward and optimize ops, but got {}".format(
                        str(op)
                    )
                )

    def _is_valid_annotated_program(self, program):

        # TODO (ZJ-LIANG) should check all blocks
        ops = program.global_block().ops
        vars_ = program.list_vars()
        op_dist_attrs = [
            self._dist_context.get_op_dist_attr_for_program(op) for op in ops
        ]
        var_dist_attrs = [
            self._dist_context.get_tensor_dist_attr_for_program(var)
            for var in vars_
            if (var.type not in __no_shape_var_type__)
        ]

        all_ops_annotated = all(
            dist_attr is not None for dist_attr in op_dist_attrs
        )
        all_vars_annotated = all(
            dist_attr is not None for dist_attr in var_dist_attrs
        )

        return all_ops_annotated and all_vars_annotated

    def _get_dist_var_by_serial_var(self, serial_var, partitioned_main_prog):

        block_idx = serial_var.block.idx
        target_block = partitioned_main_prog.blocks[block_idx]
        dist_var_name = self._serial2dist_varname_mapping[serial_var.name]
        assert target_block.has_var(dist_var_name)
        return target_block.var(dist_var_name)


def _get_dist_shape(var, dist_attr):
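    """
    Compute the local (per-rank) shape of `var`: every dim whose dims_mapping
    entry is not -1 is divided by the size of the mesh axis it is mapped to.

    Illustrative (hypothetical) values: var_shape [8, 1024] with dims_mapping
    [0, -1] on a process mesh of topology [2, 4] gives a local shape [4, 1024].
    """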

    var_shape = var.shape
    mapping = dist_attr.dims_mapping
    mesh = dist_attr.process_mesh.topology
    if mapping == []:
        return var_shape

    assert len(var_shape) == len(
        mapping
    ), "variable shape [{}] and dims_mapping [{}] do NOT match!".format(
        var_shape, mapping
    )
    new_shape = []
    for idx in range(len(var_shape)):
        if var_shape[idx] == -1 or mapping[idx] == -1:
            new_shape.append(var_shape[idx])
        else:
            assert (
                var_shape[idx] % mesh[mapping[idx]] == 0
            ), "uneven partition: var_shape[idx]=[{}], mesh dim=[{}]".format(
                var_shape[idx], mesh[mapping[idx]]
            )
            new_shape.append(var_shape[idx] // mesh[mapping[idx]])

    return new_shape


def _partition_parameter(
    dist_context, src_var, dst_block, dst_varname, dst_shape
):
    # NOTE: hack to copy a Parameter
    # the parameter is not initialized here and needs to be initialized later
    copied_kwargs = {}
    copied_kwargs['trainable'] = src_var.trainable
    copied_kwargs['optimize_attr'] = src_var.optimize_attr
    copied_kwargs['regularizer'] = src_var.regularizer
    copied_kwargs['do_model_average'] = src_var.do_model_average
    copied_kwargs['need_clip'] = src_var.need_clip

    param = Parameter(
        block=dst_block,
        type=src_var.type,
        name=dst_varname,
        shape=dst_shape,
        dtype=src_var.dtype,
        lod_level=src_var.lod_level,
        error_clip=src_var.error_clip,
        stop_gradient=src_var.stop_gradient,
        is_data=src_var.is_data,
        belong_to_optimizer=src_var.belong_to_optimizer,
        **copied_kwargs
    )

    return param


def _partition_intermediate_var(
    dist_context, src_var, dst_block, dst_varname, dst_shape
):
    var = dst_block.create_var(
        type=src_var.type,
        name=dst_varname,
        shape=dst_shape,
        dtype=src_var.dtype,
        lod_level=src_var.lod_level,
        persistable=src_var.persistable,
        error_clip=src_var.error_clip,
        stop_gradient=src_var.stop_gradient,
        is_data=src_var.is_data,
        belong_to_optimizer=src_var.belong_to_optimizer,
    )

    return var


def _partition_var(
    dist_context, src_block, dst_block, src_varname, dst_varname
):
    """
    Partition a var: shrink the sharded dims according to its dist_attr (split) and keep the others as-is (replicate).
    """
    src_var = src_block.var(src_varname)

    if src_var.type in __no_shape_var_type__:
        persist = getattr(src_var, 'persistable', False)
        new_var = dst_block.create_var(
            type=src_var.type,
            name=dst_varname,
            persistable=persist,
            stop_gradient=True,
        )
        target_shape = None
    else:
        dist_attr = dist_context.get_tensor_dist_attr_for_program(src_var)
        target_shape = _get_dist_shape(src_var, dist_attr)

        if isinstance(src_var, Parameter):
            new_var = _partition_parameter(
                dist_context, src_var, dst_block, dst_varname, target_shape
            )
        else:
            new_var = _partition_intermediate_var(
                dist_context, src_var, dst_block, dst_varname, target_shape
            )

    dist_attr = copy.deepcopy(
        dist_context.get_tensor_dist_attr_for_program(src_var)
    )
    assert dist_attr is not None
    dist_context.set_tensor_dist_attr_for_program(new_var, dist_attr)

    return target_shape


def _get_dist_op_backward_implement(
    backward_op, dist_context, forward_op_id2forward_op
):
    dist_op_context = dist_context.dist_op_context
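    # a grad op reuses the dist op implementation registered for its forward op:
    # look up the forward op via grad_op_id_to_op_id, then pick the impl container
    # by the forward op's dist attr (impl_type / impl_idx)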
    if backward_op.desc.original_id() in dist_op_context.grad_op_id_to_op_id:
        forward_op_id = dist_op_context.grad_op_id_to_op_id[
            backward_op.desc.original_id()
        ]
        forward_op = forward_op_id2forward_op[forward_op_id]
        forward_op_dist_attr = dist_context.get_op_dist_attr_for_program(
            forward_op
        )
        dist_op_impl_container = get_distributed_operator_impl_container(
            forward_op_dist_attr.impl_type
        )
        dist_op_impl = dist_op_impl_container.get_impl(
            forward_op_dist_attr.impl_idx
        )
        return dist_op_impl

    # NOTE: trick for dist ops that only have a backward implementation
    if backward_op.type in BACKWARD_ONLY_DIST_OPS:
        op_dist_attr = dist_context.get_op_dist_attr_for_program(backward_op)
        assert op_dist_attr.impl_idx >= 0
        dist_op_impl = get_distributed_operator_impl_container(
            op_dist_attr.impl_type
        ).get_impl(op_dist_attr.impl_idx)
        return dist_op_impl

    dist_op = get_distributed_operator_impl_container("default")
    return dist_op.get_impl(0)


def _get_dist_op_forward_implement(forward_op, dist_context):
    dist_attr = dist_context.get_op_dist_attr_for_program(forward_op)
    dist_op_impl_container = get_distributed_operator_impl_container(
        dist_attr.impl_type
    )
    dist_op_impl = dist_op_impl_container.get_impl(dist_attr.impl_idx)
    return dist_op_impl