# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License

from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl
from .common import register_distributed_operator_impl_container
from .common import gradient_synchronization
from .common import register_distributed_operator_impl, is_parameter_related
from ..utils import is_dim_shard
from ..utils import is_dim_replicate
from ..utils import is_valid_list_index, is_prim_op
from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
from ..utils import set_dist_op_desc_original_id
from ..dist_attribute import OperatorDistributedAttribute
from paddle.fluid import core, unique_name
from paddle.fluid.framework import _non_static_mode
from paddle.fluid.framework import Program, Parameter, Variable, program_guard
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from paddle.distributed.fleet.meta_optimizers.common import (
    OpRole,
    OP_ROLE_KEY,
    OP_ROLE_VAR_KEY,
)
from ..process_group import new_process_group
from ..utils import _get_comm_group, _get_corresponding_rank
from ..cost import _g_op_cost_factory
from ..cost import build_comp_desc_from_dist_op, build_dp_costs
from ..cost import build_comp_costs_from_descs

__op_not_need_param_init__ = ["while", "cond"]
__op_has_shape_attr__ = ["fill_constant_batch_size_like", "fill_constant"]


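# Data-parallel synchronization for primitive (prim) ops: when the gradient of
# a parameter is produced, append a c_allreduce_sum over the data-parallel
# group in the main program, broadcast the corresponding parameter from rank 0
# in the startup program, and attach a distributed attribute to the allreduce.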
def prim_operator_data_parallel_functor(ctx, src_op):
    dist_op_context = ctx.dist_op_context
    main_block = dist_op_context.work_block
    startup_block = dist_op_context.startup_block

    var_name = src_op.output_arg_names[0]
    if var_name in ctx.grads_params:
        assert (
            var_name not in ctx.synced_gradient
        ), "in primitive mode, grad [{}] is already synced".format(var_name)
        ctx.synced_gradient.add(var_name)
        sync_group = new_process_group(ctx.data_parallel_group)

        allreduce_op = main_block.append_op(
            type='c_allreduce_sum',
            inputs={'X': [var_name]},
            outputs={'Out': [var_name]},
            attrs={
                'ring_id': sync_group.id,
                'use_calc_stream': True,
                OP_ROLE_KEY: OpRole.Backward,
            },
        )

        param = ctx.grads_params[var_name]
        startup_block = dist_op_context.startup_block
        new_op = startup_block.append_op(
            type='c_broadcast',
            inputs={'X': [param]},
            outputs={'Out': [param]},
            attrs={
                'ring_id': sync_group.id,
                'root': 0,
                'use_calc_stream': True,
                OP_ROLE_KEY: OpRole.Forward,
            },
        )

        grad_var = main_block.var(var_name)
        dims_mapping = ctx.get_tensor_dist_attr_for_program(
            grad_var
        ).dims_mapping
        dist_attr = ctx.get_op_dist_attr_for_program(src_op)
        process_mesh = dist_attr.process_mesh
        op_attr = OperatorDistributedAttribute()
        op_attr.process_mesh = process_mesh
        op_attr.set_output_dims_mapping(grad_var.name, dims_mapping)
        op_attr.set_input_dims_mapping(grad_var.name, dims_mapping)
        ctx.set_op_dist_attr_for_program(allreduce_op, op_attr)

    return


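# "default" is the fallback operator type: ops without a specialized
# distributed implementation are expected to be handled by this container and
# by the replicated impl (DistributedDefaultImpl0) registered below.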
class DistributedDefault(DistributedOperatorImplContainer):
    def __init__(self, op_type):
        super(DistributedDefault, self).__init__(op_type)


register_distributed_operator_impl_container(DistributedDefault("default"))


# Replicated Default
class DistributedDefaultImpl0(DistributedOperatorImpl):
    def __init__(self, name):
        super(DistributedDefaultImpl0, self).__init__(name)
        self._forward_implemented = True
        self._backward_implemented = True

    def calc_cost(self, op_role, dist_op, ctx, cluster):
        """Calculate the cost by the op role."""
        cost = None
        if int(op_role) == int(OpRole.Backward):
            cost = self.calc_bwd_cost(dist_op, ctx, cluster)
        else:
            cost = self.calc_fwd_cost(dist_op, ctx, cluster)
        assert cost is not None
        return cost

    def calc_fwd_cost(self, dist_op, ctx, cluster):
        # calc comp op cost
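        # build_comp_desc_from_dist_op expands the dist op into per-process op
        # descriptions; build_comp_costs_from_descs then turns them into comp
        # costs using the cost class registered for this op type.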
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        processes = dist_op.dist_attr.process_mesh.processes
        op_type = dist_op.serial_op.type
        cost_mapping = build_comp_costs_from_descs(
            _g_op_cost_factory[op_type], ctx, processes, desc_mapping, cluster
        )
        res_cost = [cost_mapping]

        return res_cost

    def calc_bwd_cost(self, dist_op, ctx, cluster):
        # calc comp op cost
        res = []
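        # The backward cost is the computation cost of the grad op itself
        # plus, when the batch dimension of an activation input is sharded
        # over a mesh axis of size > 1, the data-parallel allreduce cost of
        # every parameter gradient (appended via build_dp_costs below).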
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        dist_attr = dist_op.dist_attr
        process_mesh = dist_attr.process_mesh
        processes = process_mesh.processes
        backward_op = dist_op.serial_op
        op_type = backward_op.type
        cost_mapping = build_comp_costs_from_descs(
            _g_op_cost_factory[op_type], ctx, processes, desc_mapping, cluster
        )
        res.append(cost_mapping)

        main_block = backward_op.block
        vars = main_block.vars
        need_gradient_allreduce = False
        for input_name in backward_op.desc.input_names():
            for varname in backward_op.desc.input(input_name):
                if "@GRAD" not in varname and not is_parameter_related(
                    varname, main_block
                ):
                    var_dim_mapping = dist_attr.get_input_dims_mapping(varname)
                    mesh_shape = process_mesh.topology
                    batch_size_axis = var_dim_mapping[0]
                    if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
                        need_gradient_allreduce = True
                        break

        if need_gradient_allreduce:
            for input_name in backward_op.desc.input_names():
                for varname in backward_op.desc.input(input_name):
                    if "@GRAD" not in varname and is_parameter_related(
                        varname, main_block
                    ):
                        var_dim_mapping = dist_attr.get_input_dims_mapping(
                            varname
                        )
                        mesh_shape = process_mesh.topology
                        batch_size_axis = var_dim_mapping[0]
                        parallel_axis = batch_size_axis
                        attrs = {"use_calc_stream": True}
                        var_names = [varname + "@GRAD"]
                        build_dp_costs(
                            res,
                            dist_op,
                            ctx,
                            var_names,
                            attrs,
                            parallel_axis,
                            cluster,
                        )
        return res

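    # Convention used by the checks below: in a dims_mapping, -1 means the
    # tensor dimension is replicated and a value i >= 0 means it is sharded
    # along the i-th mesh dimension. The replicated default impl only allows
    # the leading batch dimension to be sharded, requires parameters to be
    # fully replicated, and treats index 1 as the batch dimension for XShape
    # arguments, which carry an extra leading dimension.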
    def is_input_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        batch_dim_mappings = []
        input_names = op_desc.input_names()
        xshape_arg_names = []
        if "XShape" in input_names:
            xshape_arg_names = op_desc.input("XShape")
        for arg_name in op_desc.input_arg_names():
            serial_tensor = dist_op.get_serial_input(arg_name)
            dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
            if serial_tensor.is_parameter:
                for mapping in dims_mapping:
                    if mapping != -1:
                        return False
                continue
            if arg_name not in xshape_arg_names:
                if len(dims_mapping) > 1:
                    for mapping in dims_mapping[1:]:
                        if mapping != -1:
                            return False
                if len(dims_mapping) >= 1:
                    batch_dim_mappings.append(dims_mapping[0])
            else:
                if dims_mapping[0] != -1:
                    return False
                if len(dims_mapping) > 2:
                    for mapping in dims_mapping[2:]:
                        if mapping != -1:
                            return False
                if len(dims_mapping) >= 2:
                    batch_dim_mappings.append(dims_mapping[1])

        if compute_compatible_dim_mapping(batch_dim_mappings) is None:
            return False

        return True

    def is_output_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        output_names = op_desc.output_names()
        batch_dim_mappings = []
        xshape_arg_names = []
        if "XShape" in output_names:
            xshape_arg_names = op_desc.output("XShape")
        for arg_name in op_desc.output_arg_names():
            serial_tensor = dist_op.get_serial_output(arg_name)
            dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
            if serial_tensor.is_parameter:
                for mapping in dims_mapping:
                    if mapping != -1:
                        return False
                continue
            if arg_name not in xshape_arg_names:
                if len(dims_mapping) > 1:
                    for mapping in dims_mapping[1:]:
                        if mapping != -1:
                            return False
                if len(dims_mapping) >= 1:
                    batch_dim_mappings.append(dims_mapping[0])
            else:
                if dims_mapping[0] != -1:
                    return False
                if len(dims_mapping) > 2:
                    for mapping in dims_mapping[2:]:
                        if mapping != -1:
                            return False
                if len(dims_mapping) >= 2:
                    batch_dim_mappings.append(dims_mapping[1])

        if compute_compatible_dim_mapping(batch_dim_mappings) is None:
            return False

        return True

    def is_auto_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        batch_dim_mappings = []
        # Check input compatibility
        input_names = op_desc.input_names()
        xshape_arg_names = []
        if "XShape" in input_names:
            xshape_arg_names = op_desc.input("XShape")
        for arg_name in op_desc.input_arg_names():
            serial_tensor = dist_op.get_serial_input(arg_name)
            dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
            if serial_tensor is not None and serial_tensor.is_parameter:
                for mapping in dims_mapping:
                    if mapping != -1:
                        return False
                continue
            if arg_name not in xshape_arg_names:
                if len(dims_mapping) > 1:
                    for mapping in dims_mapping[1:]:
                        if mapping != -1:
                            return False
                if len(dims_mapping) >= 1:
                    batch_dim_mappings.append(dims_mapping[0])
            else:
                if dims_mapping[0] != -1:
                    return False
                if len(dims_mapping) > 2:
                    for mapping in dims_mapping[2:]:
                        if mapping != -1:
                            return False
                if len(dims_mapping) >= 2:
                    batch_dim_mappings.append(dims_mapping[1])

        # Check output compatibility
        output_names = op_desc.output_names()
        xshape_arg_names = []
        if "XShape" in output_names:
            xshape_arg_names = op_desc.output("XShape")
        for arg_name in op_desc.output_arg_names():
            serial_tensor = dist_op.get_serial_output(arg_name)
            dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
            if serial_tensor is not None and serial_tensor.is_parameter:
                for mapping in dims_mapping:
                    if mapping != -1:
                        return False
                continue
            if arg_name not in xshape_arg_names:
                if len(dims_mapping) > 1:
                    for mapping in dims_mapping[1:]:
                        if mapping != -1:
                            return False
                if len(dims_mapping) >= 1:
                    batch_dim_mappings.append(dims_mapping[0])
            else:
                if dims_mapping[0] != -1:
                    return False
                if len(dims_mapping) > 2:
                    for mapping in dims_mapping[2:]:
                        if mapping != -1:
                            return False
                if len(dims_mapping) >= 2:
                    batch_dim_mappings.append(dims_mapping[1])

        # Check batch dim mapping compatibility
        if not all(
            batch_dim_mappings[0] == dim_mapping
            for dim_mapping in batch_dim_mappings
        ):
            return False

        return True

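    # Infer a single compatible batch-dimension mapping across all
    # non-parameter inputs and outputs and write it back into the op's
    # dist_attr; returns True only if some dims_mapping was actually changed.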
    def update_dims_mapping(self, dist_op):
        changed = False
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr

        if op_desc.type() == "while":
            return False

        input_names = op_desc.input_names()
        input_xshape_arg_names = []
        if "XShape" in input_names:
            input_xshape_arg_names = op_desc.input("XShape")

        output_names = op_desc.output_names()
        output_xshape_arg_names = []
        if "XShape" in output_names:
            output_xshape_arg_names = op_desc.output("XShape")

        batch_dim_mappings = []
        for arg_name in op_desc.input_arg_names():
            serial_tensor = dist_op.get_serial_input(arg_name)
            if serial_tensor.is_parameter:
                continue
            dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
            if arg_name not in input_xshape_arg_names:
                if len(dims_mapping) >= 1:
                    batch_dim_mappings.append(dims_mapping[0])
            else:
                batch_dim_mappings.append(dims_mapping[1])
        for arg_name in op_desc.output_arg_names():
            if op_desc.type() == 'fill_any_like':
                input_tensor = dist_op.get_serial_input(
                    op_desc.input_arg_names()[0]
                )
                if input_tensor.is_parameter:
                    continue
            serial_tensor = dist_op.get_serial_output(arg_name)
            if serial_tensor.is_parameter:
                continue
            dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
            if arg_name not in output_xshape_arg_names:
                if len(dims_mapping) >= 1:
                    batch_dim_mappings.append(dims_mapping[0])
            else:
                batch_dim_mappings.append(dims_mapping[1])

        if not batch_dim_mappings:
            return changed

        compatible_dim_mapping = compute_compatible_dim_mapping(
            batch_dim_mappings
        )
        if compatible_dim_mapping is None:
            return False

        for arg_name in op_desc.input_arg_names():
            serial_tensor = dist_op.get_serial_input(arg_name)
            if serial_tensor.is_parameter:
                continue
            dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
            if arg_name not in input_xshape_arg_names:
                if (
                    len(dims_mapping) >= 1
                    and compatible_dim_mapping != dims_mapping[0]
                ):
                    dims_mapping[0] = compatible_dim_mapping
                    changed = True
            else:
                if (
                    len(dims_mapping) >= 2
                    and compatible_dim_mapping != dims_mapping[1]
                ):
                    dims_mapping[1] = compatible_dim_mapping
                    changed = True
        for arg_name in op_desc.output_arg_names():
            if op_desc.type() == 'fill_any_like':
                input_tensor = dist_op.get_serial_input(
                    op_desc.input_arg_names()[0]
                )
                if input_tensor.is_parameter:
                    continue
            if op_desc.type() in ["shape", "slice"]:
                continue
            serial_tensor = dist_op.get_serial_output(arg_name)
            if serial_tensor.is_parameter:
                continue
            dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
            if arg_name not in output_xshape_arg_names:
                if (
                    len(dims_mapping) >= 1
                    and compatible_dim_mapping != dims_mapping[0]
                ):
                    dims_mapping[0] = compatible_dim_mapping
                    changed = True
            else:
                if (
                    len(dims_mapping) >= 2
                    and compatible_dim_mapping != dims_mapping[1]
                ):
                    dims_mapping[1] = compatible_dim_mapping
                    changed = True

        return changed

    @staticmethod
    def forward(ctx, *args, **kwargs):
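        # Roughly: replicate the serial op into the distributed main program,
        # shard any static 'shape' attribute according to the output
        # dims_mapping, and broadcast newly encountered parameters in the
        # startup program so that every rank starts from identical values.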
        dist_op_context = ctx.dist_op_context
        main_block = dist_op_context.work_block
        startup_block = dist_op_context.startup_block
        src_op = dist_op_context.cur_src_op
        rank_id = dist_op_context.rank_id

        # check validity of inputs / outputs
        for input_name in src_op.desc.input_names():
            assert input_name in kwargs, "input [{}] is not given".format(
                input_name
            )
            assert len(kwargs[input_name]) == len(
                src_op.desc.input(input_name)
            ), "number of tensors for input [{}] does not match".format(
                input_name
            )
        for output_name in src_op.desc.output_names():
            assert output_name in kwargs, "output [{}] is not given".format(
                output_name
            )
            assert len(kwargs[output_name]) == len(
                src_op.desc.output(output_name)
            ), "number of tensors for output [{}] does not match".format(
                output_name
            )

        # replicate op in dist program
        dist_op_desc = main_block.append_op(type='nop').desc
        dist_op_desc.copy_from(src_op.desc)
        set_dist_op_desc_original_id(dist_op_desc, src_op.desc, ctx)
        for input_name in src_op.desc.input_names():
            dist_op_desc.set_input(input_name, kwargs[input_name])
        for output_name in src_op.desc.output_names():
            dist_op_desc.set_output(output_name, kwargs[output_name])

        if (
            src_op.has_attr('shape')
            and src_op.attr('shape')
            and src_op.type in __op_has_shape_attr__
        ):
            shape_list = src_op.attr('shape')
            Out_var = main_block._var_recursive(kwargs['Out'][0])
            op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
            dim_mapping = op_dist_attr.get_output_dims_mapping(Out_var.name)
            process_mesh_shape = op_dist_attr.process_mesh.shape
            assert len(shape_list) == len(dim_mapping)
            # modify target shape: each sharded axis is divided by the size of
            # its mesh dimension, e.g. shape [64, 32] with dim_mapping [0, -1]
            # on a mesh of shape [2, 4] becomes [32, 32]
            for idx, axis in enumerate(dim_mapping):
                if axis >= 0:
                    if len(shape_list) > idx:
                        shape_list[idx] = (
                            shape_list[idx] // process_mesh_shape[axis]
                        )
            dist_op_desc._set_attr('shape', shape_list)

        # data parallel synchronization for primitive operators
        from paddle.incubate.autograd import prim_enabled

        if prim_enabled():
            assert is_prim_op(src_op)
            prim_operator_data_parallel_functor(ctx, src_op)
            return

        # param initialization sync
        if src_op.type in __op_not_need_param_init__:
            return

        for varname in dist_op_desc.input_arg_names():
            if (
                startup_block.has_var(varname)
                and startup_block.var(varname).is_parameter
                and varname not in dist_op_context.already_init_sync_vars
            ):
                dist_op_context.already_init_sync_vars.add(varname)
                param = startup_block.var(varname)
                param_dist_attr = ctx.get_tensor_dist_attr_for_program(param)
                process_mesh = param_dist_attr.process_mesh
                dims_mapping = param_dist_attr.dims_mapping

                # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
                if rank_id not in process_mesh.processes:
                    rank_id = _get_corresponding_rank(
                        ctx, process_mesh, rank_id
                    )

                # NOTE: every axis that is not split should be present in the mesh
                for axis, size in enumerate(process_mesh.topology):
                    if size <= 1 or axis in dims_mapping:
                        pass
                    else:
                        group_ranks = _get_comm_group(
                            process_mesh.processes,
                            process_mesh.topology,
                            axis,
                            rank_id,
                        )
                        sync_group = new_process_group(group_ranks)

                        new_op = startup_block.append_op(
                            type='c_broadcast',
                            inputs={'X': param},
                            outputs={'Out': param},
                            attrs={
                                'ring_id': sync_group.id,
                                'root': 0,
                                'use_calc_stream': True,
                                OP_ROLE_KEY: OpRole.Forward,
                            },
                        )

                        # set distributed attribute
                        op_attr = OperatorDistributedAttribute()
                        op_attr.process_mesh = process_mesh
                        op_attr.set_output_dims_mapping(
                            param.name, dims_mapping
                        )
                        op_attr.set_input_dims_mapping(param.name, dims_mapping)
                        ctx.set_op_dist_attr_for_program(new_op, op_attr)

    @staticmethod
    def backward(ctx, *args, **kwargs):
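        # Roughly: clone the grad op into the distributed main program,
        # collect the activation inputs (act_grad_names) and parameter
        # gradient outputs (out_grad_names), and let gradient_synchronization
        # insert the data-parallel allreduce ops required on this rank.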

        # by now the backward function only inserts the gradient allreduce for the dist op itself
        dist_op_context = ctx.dist_op_context
        main_block = dist_op_context.work_block
        backward_op = dist_op_context.cur_src_op
        dist_attr = ctx.get_op_dist_attr_for_program(backward_op)
        assert (
            dist_attr is not None
        ), "backward op [{}] doesn't have a dist attribute!".format(
            str(backward_op)
        )
        rank_id = dist_op_context.rank_id

        # check validity of inputs / outputs
        for input_name in backward_op.desc.input_names():
            assert input_name in kwargs, "input [{}] is not given".format(
                input_name
            )
            assert len(kwargs[input_name]) == len(
                backward_op.desc.input(input_name)
            ), "number of tensors for input [{}] does not match".format(
                input_name
            )
        for output_name in backward_op.desc.output_names():
            assert output_name in kwargs, "output [{}] is not given".format(
                output_name
            )
            assert len(kwargs[output_name]) == len(
                backward_op.desc.output(output_name)
            ), "number of tensors for output [{}] does not match".format(
                output_name
            )

        # replicate op in dist program
        dist_op_desc = main_block.append_op(type='nop').desc
        dist_op_desc.copy_from(backward_op.desc)
        # Refer to the related dist op
        set_dist_op_desc_original_id(dist_op_desc, backward_op.desc, ctx)
        for input_name in backward_op.desc.input_names():
            dist_op_desc.set_input(input_name, kwargs[input_name])
        for output_name in backward_op.desc.output_names():
            dist_op_desc.set_output(output_name, kwargs[output_name])

        # data parallel gradient synchronization
        act_grad_names = []
        for input_name in backward_op.desc.input_names():
            for varname in backward_op.desc.input(input_name):
                if "@GRAD" not in varname and not is_parameter_related(
                    varname, main_block
                ):
                    act_grad_names.append(varname)

        out_grad_names = []
        for output_name in backward_op.desc.output_names():
            for varname in backward_op.desc.output(output_name):
                if varname in kwargs["grad_var_to_var"]:
                    fwd_name = kwargs["grad_var_to_var"][varname]
                    if fwd_name not in main_block.vars:
                        continue
                    if is_parameter_related(fwd_name, main_block):
                        out_grad_names.append(varname)

        gradient_synchronization(
            ctx, backward_op, act_grad_names, out_grad_names, rank_id
        )


register_distributed_operator_impl(
    "default", DistributedDefaultImpl0("replicate_parallel")
)
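
# A minimal sketch of how another op type could fall back to this replicated
# impl using the same helpers as above; "my_op" is purely illustrative and is
# not registered by this file:
#
#     register_distributed_operator_impl_container(DistributedDefault("my_op"))
#     register_distributed_operator_impl(
#         "my_op", DistributedDefaultImpl0("replicate_parallel")
#     )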