# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License

import copy
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.framework import Parameter, Program
from paddle.distributed.auto_parallel.operators.common import (
    get_distributed_operator_impl_container,
)
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from .dist_attribute import OperatorDistributedAttribute
from .utils import is_backward_op, is_forward_op, is_loss_op, is_optimize_op
from .operators.common import BACKWARD_ONLY_DIST_OPS

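# input var names that an op may reference without the var being created in
# the block (e.g. the DataLoader's lod_tensor_blocking_queue)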
__varname_not_in_block__ = ["lod_tensor_blocking_queue"]
__not_shape_var_type__ = [
    core.VarDesc.VarType.READER,
    core.VarDesc.VarType.STEP_SCOPES,
]


class Partitioner:
    """
    warning:: Partitioner is experimental and subject to change.

    Partitioner converts a program into another program.
    Given a serial program that has been auto completed with shard annotations, the Partitioner
    converts the serial program into a "distributed" program. The Partitioner modifies the serial
    program in the following two ways, which are also the major differences between a serial and a distributed program:
        1. partition op: replace a serial op with its corresponding dist op inferred from the shard annotation
        2. partition var: if a var is sharded, modify the shape of the var according to its shard annotation

    Partitioner is supposed to be called by the auto parallel framework, and is not supposed to be directly called by the user.
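
    Examples:
        An illustrative sketch of the expected call sequence (the annotated
        serial programs and ``params_grads`` are assumed to come from the auto
        completion pass and are not constructed here):

        .. code-block:: python

            partitioner = Partitioner(dist_context, rank_id=0)
            (
                dist_main_prog,
                dist_startup_prog,
                dist_params_grads,
            ) = partitioner.partition(
                serial_main_program, serial_startup_program, params_grads
            )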
    """

    def __init__(self, dist_context, rank_id=0):
        """
        Args:
            dist_context (DistributedContext): used to access the distributed attributes of vars & ops. Every Partitioner object may maintain its own DistributedContext member and partition the program based on that shard scenario.
            rank_id (int): global rank id to which the partitioned distributed program belongs.
        """
        if not isinstance(dist_context, DistributedContext):
            raise TypeError(
                "dist_context must be a DistributedContext, but got %s"
                % type(dist_context)
            )

        self._dist_context = dist_context
        self._rank_id = rank_id
        self._serial2dist_varname_mapping = {}
        self._dist_varname_suffix = ""

    def partition(
        self, serial_main_program, serial_startup_program, params_grads
    ):
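        """
        Partition the annotated serial programs for ``self._rank_id``.

        Args:
            serial_main_program (Program): serial main program that has been annotated with shard attributes.
            serial_startup_program (Program|None): serial startup program, or None if there is no startup program to partition.
            params_grads (list): list of (parameter, gradient) pairs of the serial program.

        Returns:
            tuple: (partitioned_main_prog, partitioned_startup_prog, partitioned_params_grads).
        """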
        if not isinstance(serial_main_program, (Program)):
            raise TypeError(
                "serial_main_program must be a paddle.fluid.framework.Program, but got %s"
                % type(serial_main_program)
            )

        # check whether the shard-annotated serial program is valid
        if not self._is_valid_annotated_program(serial_main_program):
            raise RuntimeError(
                "Not all vars or ops are annotated in main program !"
            )

        # init distop helper
        dist_op_context = self._dist_context.dist_op_context
        dist_op_context.varname_mapping = self._serial2dist_varname_mapping
        dist_op_context.rank_id = self._rank_id

        # partition startup program
        if serial_startup_program is None:
            partitioned_startup_prog = None
        else:
            partitioned_startup_prog = self.partition_startup_program(
                serial_main_program, serial_startup_program
            )
        dist_op_context.dst_startup_program = partitioned_startup_prog

        # partition main program
        (
            partitioned_main_prog,
            partitioned_params_grads,
        ) = self.partition_main_program(serial_main_program, params_grads)

        return (
            partitioned_main_prog,
            partitioned_startup_prog,
            partitioned_params_grads,
        )

    def partition_startup_program(
        self, serial_main_program, serial_startup_program
    ):

        if not isinstance(serial_startup_program, (Program)):
            raise TypeError(
                "serial_startup_program must be a paddle.fluid.framework.Program, but got %s"
                % type(serial_startup_program)
            )

        partitioned_startup_prog = fluid.Program()
        ref_block = serial_main_program.global_block()
        target_block = partitioned_startup_prog.global_block()
        var2shape = {}
        temp_varname_map = {}

        # tensors
        for var in serial_startup_program.list_vars():
            assert var.persistable
            new_name = var.name + self._dist_varname_suffix
            temp_varname_map[var.name] = new_name
            target_shape = _partition_var(
                self._dist_context, ref_block, target_block, var.name, new_name
            )
            var2shape[new_name] = target_shape

        # ops
        for op in serial_startup_program.global_block().ops:
            # TODO: vars that do not belong to this rank should be filtered out
            output_vars = op.desc.output_arg_names()
            assert (
                len(output_vars) == 1
            ), "initializer should output only ONE variable, but got [{}]".format(
                str(op.desc)
            )
            assert (
                temp_varname_map[output_vars[0]] in var2shape
            ), "try to initialize [{}] which is not a persistable var".format(
                output_vars[0]
            )
            new_op_desc = target_block.desc.append_op()
            new_op_desc.copy_from(op.desc)
            new_op_desc._rename_output(
                output_vars[0], temp_varname_map[output_vars[0]]
            )
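            # make the initializer output the partitioned (local) shape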
            new_op_desc._set_attr(
                "shape", var2shape[temp_varname_map[output_vars[0]]]
            )
            target_block._sync_with_cpp()

            # set distributed attribute
            new_op = target_block.ops[-1]
            assert new_op.type == new_op_desc.type()
            assert new_op.desc == new_op_desc
            output_var = target_block.var(output_vars[0])
            output_var_attr = (
                self._dist_context.get_tensor_dist_attr_for_program(output_var)
            )
            op_attr = OperatorDistributedAttribute()
            op_attr.process_mesh = output_var_attr.process_mesh
            op_attr.set_output_dims_mapping(
                output_var.name, output_var_attr.dims_mapping
            )
            op_attr.set_input_dims_mapping(
                output_var.name, output_var_attr.dims_mapping
            )
            self._dist_context.set_op_dist_attr_for_program(new_op, op_attr)

        return partitioned_startup_prog

    def partition_main_program(self, serial_main_program, params_and_grads):
        """
        1. partition variables
        2. replace local op with corresponding dist op
        """

        partitioned_main_prog = fluid.Program()
        dist_op_context = self._dist_context.dist_op_context
        dist_op_context.dst_main_program = partitioned_main_prog

        for idx in range(self._dist_context.block_state.nblock):
            ref_block = serial_main_program.blocks[idx]

            if idx == 0:
                target_block = partitioned_main_prog.blocks[0]
            else:
                target_block = partitioned_main_prog._create_block(
                    parent_idx=ref_block.parent_idx
                )
                assert ref_block.idx == target_block.idx
                target_block._set_forward_block_idx(ref_block.forward_block_idx)
            dist_op_context.work_block = target_block
            self.partition_block(ref_block, target_block)

        partitioned_main_prog.current_block_idx = 0

        # should reconnect the block_attr ptr to the correct block
        for block_id in range(self._dist_context.block_state.nblock):
            block = partitioned_main_prog.block(block_id)
            for op in block.ops:
                for attr_name in op.all_attrs():
                    if op.attr_type(attr_name) == core.AttrType.BLOCK:
                        relative_id = op._block_attr_id(attr_name)
                        op._set_attr(
                            attr_name, partitioned_main_prog.block(relative_id)
                        )

        partitioned_params_and_grads = []
        for p, g in params_and_grads:
            assert p.name in self._serial2dist_varname_mapping
            dist_p = self._get_dist_var_by_serial_var(p, partitioned_main_prog)
            if g is None:
                dist_g = None
            else:
                assert g.name in self._serial2dist_varname_mapping
                dist_g = self._get_dist_var_by_serial_var(
                    g, partitioned_main_prog
                )
            partitioned_params_and_grads.append((dist_p, dist_g))

        return partitioned_main_prog, partitioned_params_and_grads

    def partition_block(self, ref_block, target_block):

        dist_op_context = self._dist_context.dist_op_context
        serial_ops = ref_block.ops

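        # ops up to and including the first loss op are treated as the forward
        # range; if no loss op is found, the whole block is treated as forward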
        last_fwd_op_idx = -1
        for idx, op in enumerate(ref_block.ops):
            if is_loss_op(op):
                last_fwd_op_idx = idx
                break

        if last_fwd_op_idx == -1:
            last_fwd_op_idx = len(ref_block.ops)

        # init mapping
        forward_op_id2forward_op = {}
        for idx in range(len(serial_ops)):
            if idx <= last_fwd_op_idx:
                forward_op_id2forward_op[
                    serial_ops[idx].desc.original_id()
                ] = serial_ops[idx]

        # partition
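        # appended_grad_times indexes dist_op_context.grad_var_to_var below; it
        # is incremented each time a (non-recompute) backward op directly
        # follows a forward / loss op, i.e. when a new backward segment starts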
        appended_grad_times = 0
        for idx, op in enumerate(serial_ops):

            op_dist_attr = self._dist_context.get_op_dist_attr_for_program(op)
            if is_backward_op(op) and (
                is_forward_op(serial_ops[idx - 1])
                or is_loss_op(serial_ops[idx - 1])
            ):
                if not op_dist_attr.is_recompute:
                    appended_grad_times += 1

            # partition input variables
            for serial_input_varname in op.desc.input_arg_names():
                if (
                    serial_input_varname
                    not in self._serial2dist_varname_mapping
                ):
                    new_varname = (
                        serial_input_varname + self._dist_varname_suffix
                    )
                    if ref_block.has_var(serial_input_varname):
                        _partition_var(
                            self._dist_context,
                            ref_block,
                            target_block,
                            serial_input_varname,
                            new_varname,
                        )
                    else:
                        for varname_not_in_block in __varname_not_in_block__:
                            assert (
                                varname_not_in_block in serial_input_varname
                            ), "{} is not found".format(serial_input_varname)

                    self._serial2dist_varname_mapping[
                        serial_input_varname
                    ] = new_varname

            # partition output vars
            for serial_output_varname in op.desc.output_arg_names():
                if (
                    serial_output_varname
                    not in self._serial2dist_varname_mapping
                ):
                    new_varname = (
                        serial_output_varname + self._dist_varname_suffix
                    )
                    _partition_var(
                        self._dist_context,
                        ref_block,
                        target_block,
                        serial_output_varname,
                        new_varname,
                    )
                    self._serial2dist_varname_mapping[
                        serial_output_varname
                    ] = new_varname

            # partition op
            if is_forward_op(op) or op_dist_attr.is_recompute:
                kinputs, koutputs = dist_op_context.prepare_context(op)
                dist_op_forward_impl = _get_dist_op_forward_implement(
                    op, self._dist_context
                )
                dist_op_forward_impl.forward(
                    self._dist_context, **kinputs, **koutputs
                )

            elif is_backward_op(op):
                kinputs, koutputs = dist_op_context.prepare_context(op)
                dist_op_backward_impl = _get_dist_op_backward_implement(
                    op, self._dist_context, forward_op_id2forward_op
                )
                grad_var_to_var = (
                    self._dist_context.dist_op_context.grad_var_to_var[
                        appended_grad_times
                    ]
                )
                dist_op_backward_impl.backward(
                    self._dist_context,
                    **kinputs,
                    **koutputs,
                    **{"grad_var_to_var": grad_var_to_var}
                )
            elif is_optimize_op(op):
                # NOTE: BACKWARD_ONLY_DIST_OPS's op_role must be 2 because of the 1F1B pass
                kinputs, koutputs = dist_op_context.prepare_context(op)
                dist_op_opt_impl = _get_dist_op_backward_implement(
                    op, self._dist_context, forward_op_id2forward_op
                )
                dist_op_opt_impl.backward(
                    self._dist_context,
                    **kinputs,
                    **koutputs,
                    **{"grad_var_to_var": {}}
                )
            else:
                raise NotImplementedError(
                    "partitioner only supports forward, backward and optimize ops, but got {}".format(
                        str(op)
                    )
                )

    def _is_valid_annotated_program(self, program):

        # TODO (ZJ-LIANG) should check all blocks
        ops = program.global_block().ops
        vars_ = program.list_vars()
        op_dist_attrs = [
            self._dist_context.get_op_dist_attr_for_program(op) for op in ops
        ]
        var_dist_attrs = [
            self._dist_context.get_tensor_dist_attr_for_program(var)
            for var in vars_
            if (var.type not in __not_shape_var_type__)
        ]

        all_ops_annotated = all(
            dist_attr is not None for dist_attr in op_dist_attrs
        )
        all_vars_annotated = all(
            dist_attr is not None for dist_attr in var_dist_attrs
        )

        return all_ops_annotated and all_vars_annotated

    def _get_dist_var_by_serial_var(self, serial_var, partitioned_main_prog):

        block_idx = serial_var.block.idx
        target_block = partitioned_main_prog.blocks[block_idx]
        dist_var_name = self._serial2dist_varname_mapping[serial_var.name]
        assert target_block.has_var(dist_var_name)
        return target_block.var(dist_var_name)


def _get_dist_shape(var, dist_attr):
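    """
    Compute the local (per-rank) shape of `var` from its dims_mapping and the
    process mesh topology: each dim mapped to a mesh axis is divided by that
    axis' size, while dims with a mapping of -1 (or a dynamic -1 size) are
    kept as-is.

    Illustrative example (numbers are assumptions): var.shape == [48, 64] with
    dims_mapping == [0, -1] on a mesh of topology [4, 2] yields [12, 64].
    """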

    var_shape = var.shape
    mapping = dist_attr.dims_mapping
    mesh = dist_attr.process_mesh.topology
    if mapping == []:
        return var_shape

    assert len(var_shape) == len(
        mapping
    ), "variable shape [{}] and dims_mapping [{}] do NOT match!".format(
        var_shape, mapping
    )
    new_shape = []
    for idx in range(len(var_shape)):
        if var_shape[idx] == -1 or mapping[idx] == -1:
            new_shape.append(var_shape[idx])
        else:
            assert (
                var_shape[idx] % mesh[mapping[idx]] == 0
            ), "uneven partition: var_shape[idx]=[{}], mesh dim size=[{}]".format(
                var_shape[idx], mesh[mapping[idx]]
            )
            new_shape.append(var_shape[idx] // mesh[mapping[idx]])

    return new_shape


def _partition_parameter(
    dist_context, src_var, dst_block, dst_varname, dst_shape
):
    # NOTE: hack to copy a Parameter; the copied parameter is not
    # initialized here and still needs to be initialized
    copied_kwargs = {}
    copied_kwargs['trainable'] = src_var.trainable
    copied_kwargs['optimize_attr'] = src_var.optimize_attr
    copied_kwargs['regularizer'] = src_var.regularizer
    copied_kwargs['do_model_average'] = src_var.do_model_average
    copied_kwargs['need_clip'] = src_var.need_clip

    param = Parameter(
        block=dst_block,
        type=src_var.type,
        name=dst_varname,
        shape=dst_shape,
        dtype=src_var.dtype,
        lod_level=src_var.lod_level,
        error_clip=src_var.error_clip,
        stop_gradient=src_var.stop_gradient,
        is_data=src_var.is_data,
        belong_to_optimizer=src_var.belong_to_optimizer,
        **copied_kwargs
    )

    return param


def _partition_intermediate_var(
    dist_context, src_var, dst_block, dst_varname, dst_shape
):
    var = dst_block.create_var(
        type=src_var.type,
        name=dst_varname,
        shape=dst_shape,
        dtype=src_var.dtype,
        lod_level=src_var.lod_level,
        persistable=src_var.persistable,
        error_clip=src_var.error_clip,
        stop_gradient=src_var.stop_gradient,
        is_data=src_var.is_data,
        belong_to_optimizer=src_var.belong_to_optimizer,
    )

    return var


def _partition_var(
    dist_context, src_block, dst_block, src_varname, dst_varname
):
    """
    partitioning includes: split + replicate
    """
    src_var = src_block.var(src_varname)

    if src_var.type in __not_shape_var_type__:
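        # vars without a meaningful shape (e.g. readers and step scopes) are
        # replicated as-is rather than shard-partitioned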
        persist = getattr(src_var, 'persistable', False)
        new_var = dst_block.create_var(
            type=src_var.type,
            name=dst_varname,
            persistable=persist,
            stop_gradient=True,
        )
        target_shape = None
    else:
        dist_attr = dist_context.get_tensor_dist_attr_for_program(src_var)
        target_shape = _get_dist_shape(src_var, dist_attr)

        if isinstance(src_var, Parameter):
            new_var = _partition_parameter(
                dist_context, src_var, dst_block, dst_varname, target_shape
            )
        else:
            new_var = _partition_intermediate_var(
                dist_context, src_var, dst_block, dst_varname, target_shape
            )

    dist_attr = copy.deepcopy(
        dist_context.get_tensor_dist_attr_for_program(src_var)
    )
    assert dist_attr is not None
    dist_context.set_tensor_dist_attr_for_program(new_var, dist_attr)

    return target_shape


def _get_dist_op_backward_implement(
    backward_op, dist_context, forward_op_id2forward_op
):
    dist_op_context = dist_context.dist_op_context
    if backward_op.desc.original_id() in dist_op_context.grad_op_id_to_op_id:
        forward_op_id = dist_op_context.grad_op_id_to_op_id[
            backward_op.desc.original_id()
        ]
        forward_op = forward_op_id2forward_op[forward_op_id]
        forward_op_dist_attr = dist_context.get_op_dist_attr_for_program(
            forward_op
        )
        dist_op_impl_container = get_distributed_operator_impl_container(
            forward_op_dist_attr.impl_type
        )
        dist_op_impl = dist_op_impl_container.get_impl(
            forward_op_dist_attr.impl_idx
        )
        return dist_op_impl

    # NOTE: trick for dist ops that only have a backward implementation
    if backward_op.type in BACKWARD_ONLY_DIST_OPS:
        op_dist_attr = dist_context.get_op_dist_attr_for_program(backward_op)
        assert op_dist_attr.impl_idx >= 0
        dist_op_impl = get_distributed_operator_impl_container(
            op_dist_attr.impl_type
        ).get_impl(op_dist_attr.impl_idx)
        return dist_op_impl

    dist_op = get_distributed_operator_impl_container("default")
    return dist_op.get_impl(0)


def _get_dist_op_forward_implement(forward_op, dist_context):
    dist_attr = dist_context.get_op_dist_attr_for_program(forward_op)
    dist_op_impl_container = get_distributed_operator_impl_container(
        dist_attr.impl_type
    )
    dist_op_impl = dist_op_impl_container.get_impl(dist_attr.impl_idx)
    return dist_op_impl