# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License

from paddle.common_ops_import import check_dtype, check_variable_and_dtype
from paddle.distributed.auto_parallel.cost.comm_op_cost import (
    AllreduceSumOpCost,
    IdentityOpCost,
)
from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole
from paddle.framework import core
from paddle.utils import unique_name

from ..cost import (
    EmbeddingGradOpCost,
    EmbeddingOpCost,
    build_comm_costs_from_descs,
    build_comm_desc_from_dist_op,
    build_comp_costs_from_descs,
    build_comp_desc_from_dist_op,
    build_dp_costs,
)
from ..dist_attribute import OperatorDistAttr
from ..process_group import new_process_group
from ..utils import (
    _get_comm_group,
    _get_corresponding_rank,
    _get_idx_in_axis,
    compute_compatible_and_update_dim_mapping,
    is_dim_replicate,
    is_dim_shard,
    set_var_dist_attr,
)
from .common import (
    DistributedOperatorImpl,
    DistributedOperatorImplContainer,
    gradient_synchronization,
    infer_shape,
    naive_copy_op_dist_attr_for_program,
    register_distributed_operator_impl,
    register_distributed_operator_impl_container,
    set_comm_op_dist_attr_for_program,
)


class DistributedEmbedding(DistributedOperatorImplContainer):
    def __init__(self, op_type):
        super().__init__(op_type)


register_distributed_operator_impl_container(
    DistributedEmbedding("lookup_table_v2")
)
register_distributed_operator_impl_container(
    DistributedEmbedding("c_embedding")
)
register_distributed_operator_impl_container(
    DistributedEmbedding("lookup_table")
)


def adopt_lookup_table_v1(ctx, main_block, src_op, Ids_var):
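    """Adapt a lookup_table (v1) op so it can reuse the c_embedding path.

    lookup_table (v1) takes a 3-D Ids tensor whose trailing dimension is 1, so a
    reshape2 op is appended to flatten Ids to 2-D, and dist attributes for the
    new intermediate variables are derived from the original Ids. The reshaped
    Ids variable is returned for the caller to feed into c_embedding.
    """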

    assert (
        len(Ids_var.shape) == 3
    ), "input Ids to lookup_table should have 3 dimensions but got [{}] with shape [{}]".format(
        Ids_var.name, Ids_var.shape
    )
    if not Ids_var.stop_gradient:
        raise NotImplementedError(
            'Requiring the gradient of Ids of lookup_table(v1) dist op is not currently supported. Please open an issue with details on your use case so that we can prioritize adding this (for instance, adversarial training for language model).'
        )

    target_shape = list(Ids_var.shape[:-1])
    intermediate_var_0 = main_block.create_var(
        name=unique_name.generate_with_ignorable_key(
            ".".join(["dist_reshape", 'tmp'])
        ),
        dtype=Ids_var.dtype,
        shape=target_shape,
        type=core.VarDesc.VarType.LOD_TENSOR,
        persistable=False,
        stop_gradient=True,
    )

    target_shape = [0] + list(Ids_var.shape[:-1])
    xshape_var = main_block.create_var(
        name=unique_name.generate_with_ignorable_key(
            ".".join(["dist_Xshape", 'tmp'])
        ),
        dtype=Ids_var.dtype,
        shape=target_shape,
        type=core.VarDesc.VarType.LOD_TENSOR,
        persistable=False,
        stop_gradient=True,
    )

    # TODO: use an in-place reshape to save memory
    reshape_op = main_block.append_op(
        type='reshape2',
        inputs={'X': [Ids_var]},
        outputs={'Out': [intermediate_var_0], 'XShape': [xshape_var]},
        attrs={
            "shape": [0, -1],
        },
    )

    # set dist attr
    op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
    Ids_var_dist_attr = op_dist_attr.get_input_dist_attr(Ids_var.name)
    assert Ids_var_dist_attr is not None
    intermediate_var_0_dist_attr = set_var_dist_attr(
        ctx,
        intermediate_var_0,
        Ids_var_dist_attr.dims_mapping,
        Ids_var_dist_attr.process_mesh,
    )
    set_var_dist_attr(
        ctx,
        xshape_var,
        [-1] + list(Ids_var_dist_attr.dims_mapping),
        Ids_var_dist_attr.process_mesh,
    )
    op_dist_attr.del_input_dist_attr(Ids_var.name)
    op_dist_attr.set_input_dist_attr(
        intermediate_var_0.name, intermediate_var_0_dist_attr
    )

    new_op_dist_attr = OperatorDistAttr()
    new_op_dist_attr.process_mesh = Ids_var_dist_attr.process_mesh
    new_op_dist_attr.impl_type = "default"
    new_op_dist_attr.impl_idx = 0
    new_op_dist_attr.set_input_dims_mapping(
        Ids_var.name, Ids_var_dist_attr.dims_mapping
    )
    new_op_dist_attr.set_output_dims_mapping(
        intermediate_var_0.name, Ids_var_dist_attr.dims_mapping
    )
    new_op_dist_attr.set_output_dims_mapping(
        xshape_var.name, [-1] + list(Ids_var_dist_attr.dims_mapping)
    )
    ctx.set_op_dist_attr_for_program(reshape_op, new_op_dist_attr)

    return intermediate_var_0


# RowParallel
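# The embedding weight W is sharded along its row (vocabulary) dimension over one
# mesh axis: each rank holds a contiguous slice of rows, looks up only the ids
# that fall into its slice, and the partial results are summed with an allreduce
# to recover the full embedding output.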
class DistributedEmbeddingImpl(DistributedOperatorImpl):
    def __init__(self, name):
        super().__init__(name)
        self._forward_implemented = True
        self._backward_implemented = True

    def calc_cost(self, op_role, dist_op, ctx, cluster):
        """Calculate the cost by the op role."""
        cost = None
        if int(op_role) == int(OpRole.Forward):
            cost = self.calc_fwd_cost(dist_op, ctx, cluster)
        elif int(op_role) == int(OpRole.Backward):
            cost = self.calc_bwd_cost(dist_op, ctx, cluster)
        assert cost is not None
        return cost

    def calc_fwd_cost(self, dist_op, ctx, cluster):
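        # forward cost = local c_embedding computation cost on each involved
        # process + c_allreduce_sum communication cost on the output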
        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        processes = dist_op.dist_attr.process_mesh.process_ids
        # the embedding computation needs start_index
        cost_mapping = build_comp_costs_from_descs(
            EmbeddingOpCost, ctx, processes, desc_mapping, cluster
        )

        serial_op = dist_op.serial_op
        parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
            serial_op.input("W")[0]
        )[0]
        attrs = {"use_calc_stream": True, "use_model_parallel": True}
        var_names = serial_op.output("Out")
        c_allreduce_sum_desc_mapping = build_comm_desc_from_dist_op(
            "c_allreduce_sum",
            dist_op,
            ctx,
            var_names,
            attrs=attrs,
            parallel_axis=parallel_axis,
        )

        comm_op_cost_list = build_comm_costs_from_descs(
            AllreduceSumOpCost,
            ctx,
            processes,
            c_allreduce_sum_desc_mapping,
            cluster,
        )

        res_cost = [cost_mapping, comm_op_cost_list]

        return res_cost

    def calc_bwd_cost(self, dist_op, ctx, cluster):
        # for now, the backward function only inserts the gradient allreduce for the dist op itself
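        # backward cost = c_identity communication cost on Out@GRAD + local
        # c_embedding_grad computation cost + (when the batch axis of Ids is
        # sharded) the data-parallel allreduce cost on W@GRAD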
        res = []
        backward_op = dist_op.serial_op
        main_block = backward_op.block
        dist_attr = dist_op.dist_attr

        embedding_row_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("W")[0]
        )[0]
        parallel_axis = embedding_row_dim_mapping
        attrs = {"use_calc_stream": True, "use_model_parallel": True}
        var_names = [backward_op.input("Out@GRAD")[0]]
        c_identity_desc_mapping = build_comm_desc_from_dist_op(
            "c_identity",
            dist_op,
            ctx,
            var_names,
            attrs=attrs,
            parallel_axis=parallel_axis,
        )

        process_mesh = dist_attr.process_mesh
        processes = process_mesh.process_ids
        comm_op_cost_list = build_comm_costs_from_descs(
            IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster
        )
        res.append(comm_op_cost_list)

        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        cost_mapping = build_comp_costs_from_descs(
            EmbeddingGradOpCost, ctx, processes, desc_mapping, cluster
        )
        res.append(cost_mapping)

        # need gradient allreduce
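        # W@GRAD needs a data-parallel allreduce only when the batch axis of
        # Ids is sharded over a mesh axis whose size is larger than 1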
        var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("Ids")[0]
        )
        mesh_shape = process_mesh.shape
        batch_size_axis = var_dim_mapping[0]
        if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
            parallel_axis = batch_size_axis
            attrs = {"use_calc_stream": True}
            var_names = [backward_op.output('W@GRAD')[0]]
            build_dp_costs(
                res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
            )

        return res

    def is_input_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        ids_name = op_desc.input('Ids')[0]
        w_name = op_desc.input('W')[0]
        ids_dims_mapping = op_dist_attr.get_input_dims_mapping(ids_name)
        w_dims_mapping = op_dist_attr.get_input_dims_mapping(w_name)
        if is_dim_replicate(w_dims_mapping[-2]) or is_dim_shard(
            w_dims_mapping[-1]
        ):
            return False
        # All dimensions except the batch dimension must be replicated
        for mapping in ids_dims_mapping[1:]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_output_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
        # All dimensions except the batch dimension must be replicated
        for mapping in out_dims_mapping[1:]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_auto_compatible(self, dist_op):
        if (not self.is_input_compatible(dist_op)) or (
            not self.is_output_compatible(dist_op)
        ):
            return False

        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        ids_name = op_desc.input('Ids')[0]
        w_name = op_desc.input('W')[0]
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
        ids_dims_mapping = op_dist_attr.get_input_dims_mapping(ids_name)
        w_dims_mapping = op_dist_attr.get_input_dims_mapping(w_name)

        if ids_dims_mapping != out_dims_mapping[: len(ids_dims_mapping)]:
            return False

        return True

    def update_dims_mapping(self, dist_op):
        changed = False
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        ids_name = op_desc.input('Ids')[0]
        w_name = op_desc.input('W')[0]
        out_name = op_desc.output('Out')[0]
        ids_dims_mapping = op_dist_attr.get_input_dims_mapping(ids_name)
        w_dims_mapping = op_dist_attr.get_input_dims_mapping(w_name)
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)

        for i in range(len(ids_dims_mapping)):
            dim_changed = compute_compatible_and_update_dim_mapping(
                [ids_dims_mapping, out_dims_mapping], [i, i]
            )
            if dim_changed:
                changed = True

        dim_changed = compute_compatible_and_update_dim_mapping(
            [w_dims_mapping, out_dims_mapping], [-1, -1]
        )
        if dim_changed:
            changed = True

        if changed:
            op_dist_attr.set_input_dims_mapping(ids_name, ids_dims_mapping)
            op_dist_attr.set_input_dims_mapping(w_name, w_dims_mapping)
            op_dist_attr.set_output_dims_mapping(out_name, out_dims_mapping)

        return changed

    @staticmethod
    def forward(ctx, *args, **kwargs):
        """
        kwargs: inputname_mapping & outputname_mapping
        """

        dist_op_context = ctx.dist_op_context
        main_block = dist_op_context.work_block
        startup_block = dist_op_context.startup_block
        src_op = dist_op_context.cur_src_op
        rank_id = dist_op_context.rank_id
        op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
        assert (
            op_dist_attr is not None
        ), "forward op [{}] doesn't have a dist attribute!".format(str(src_op))

        # check the validity of inputs / outputs
        assert 'Ids' in kwargs, "input [{}] is not given".format('Ids')
        assert 'W' in kwargs, "input [{}] is not given".format('W')
        assert 'Out' in kwargs, "output [{}] is not given".format('Out')

        assert (
            len(kwargs['Ids']) == 1
        ), "row_parallel_embedding input Ids takes 1 variable but got {}".format(
            kwargs['Ids']
        )
        assert (
            len(kwargs['W']) == 1
        ), "row_parallel_embedding input W takes 1 variable but got {}".format(
            kwargs['W']
        )
        assert (
            len(kwargs['Out']) == 1
        ), "row_parallel_embedding output Out takes 1 variable but got {}".format(
            kwargs['Out']
        )

        Ids_var = main_block._var_recursive(kwargs['Ids'][0])
        Weight_var = main_block._var_recursive(kwargs['W'][0])
        Out_var = main_block._var_recursive(kwargs['Out'][0])

        # support lookup_table_v1
        if src_op.type == 'lookup_table':
            Ids_var = adopt_lookup_table_v1(ctx, main_block, src_op, Ids_var)

        # get dist attribute info
        embedding_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
            Weight_var.name
        )[0]
        assert (
            embedding_row_dim_mapping >= 0
        ), "row_parallel_embedding's row dimension must be sharded along a mesh axis, but got [{}]".format(
            embedding_row_dim_mapping
        )
        process_mesh_shape = op_dist_attr.process_mesh.shape
        process_mesh_group = op_dist_attr.process_mesh.process_ids

        # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
        if rank_id not in process_mesh_group:
            rank_id = _get_corresponding_rank(
                ctx, op_dist_attr.process_mesh, rank_id
            )

        # A generalized method to calculate the embedding offset using the Cartesian product
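        # relative_idx is this rank's position along the sharded mesh axis; it is
        # multiplied by per_part_size (the number of rows held locally) below to
        # obtain the global row offset (start_index) of the local weight shard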
        relative_idx = _get_idx_in_axis(
            process_mesh_group,
            process_mesh_shape,
            embedding_row_dim_mapping,
            rank_id,
        )

        per_part_size = Weight_var.shape[0]
        relative_idx = relative_idx * per_part_size

        # TODO calculate ring id
        parallel_axis = embedding_row_dim_mapping
        group_ranks = _get_comm_group(
            process_mesh_group, process_mesh_shape, parallel_axis, rank_id
        )
        group = new_process_group(group_ranks)

        # append op
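        # two ops are appended: c_embedding looks up the local weight shard using
        # start_index as the row offset (ids outside the local range contribute
        # zero rows), then c_allreduce_sum combines the partial results across
        # the row-parallel group into Out_var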
        check_variable_and_dtype(
            Ids_var, 'input', ['int32', 'int64'], 'c_embedding'
        )

        # infer new var shape with op dist attr
        out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var)
        assert out_tensor_dist_attr is not None
        out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert out_var_dist_attr is not None
        ref_shape = infer_shape(
            main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr
        )

        intermediate_var_0 = main_block.create_var(
            name=unique_name.generate_with_ignorable_key(
                ".".join(["c_embedding", 'tmp'])
            ),
            dtype=Weight_var.dtype,
            shape=Out_var.shape,
            type=core.VarDesc.VarType.LOD_TENSOR,
            persistable=False,
            stop_gradient=Out_var.stop_gradient,
        )
        # set intermediate_var_0's dist_attr with Out_var's dist_attr
        ctx.set_tensor_dist_attr_for_program(
            intermediate_var_0, out_var_dist_attr
        )

        check_variable_and_dtype(
            Out_var,
            'tensor',
            ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
            'c_allreduce_sum',
        )

        c_embedding_op = main_block.append_op(
            type='c_embedding',
            inputs={'Ids': [Ids_var], 'W': [Weight_var]},
            outputs={'Out': [intermediate_var_0]},
            attrs={
                "start_index": relative_idx,
                OP_ROLE_KEY: src_op.attr('op_role'),
            },
        )
        if intermediate_var_0.shape != ref_shape:
            intermediate_var_0.desc.set_shape(ref_shape)

        # allreduce the partial embedding results across the model-parallel group
        c_allreduce_sum_op = main_block.append_op(
            type='c_allreduce_sum',
            inputs={'X': [intermediate_var_0]},
            outputs={'Out': [Out_var]},
            attrs={
                'ring_id': group.id,
                'use_calc_stream': True,
                'use_model_parallel': True,
                OP_ROLE_KEY: src_op.attr('op_role'),
            },
        )
        if Out_var.shape != ref_shape:
            Out_var.desc.set_shape(ref_shape)

        # set dist op's dist_attr with serial op's dist_attr
        # c_embedding
        embedding_op_dist_attr = OperatorDistAttr()
        embedding_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        embedding_op_dist_attr.impl_type = op_dist_attr.impl_type
        embedding_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        for input_varname in c_embedding_op.desc.input_arg_names():
            input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
            assert input_dist_attr is not None, "dist_attr is {}".format(
                op_dist_attr
            )
            embedding_op_dist_attr.set_input_dist_attr(
                input_varname, input_dist_attr
            )
        output_varname = c_embedding_op.desc.output_arg_names()[0]
        output_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert output_dist_attr is not None, "dist_attr is {}".format(
            op_dist_attr
        )
        embedding_op_dist_attr.set_output_dist_attr(
            output_varname, output_dist_attr
        )
        ctx.set_op_dist_attr_for_program(c_embedding_op, embedding_op_dist_attr)

        # allreduce
        allreduce_op_dist_attr = OperatorDistAttr()
        allreduce_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        allreduce_op_dist_attr.impl_type = op_dist_attr.impl_type
        allreduce_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        for input_varname in c_allreduce_sum_op.desc.input_arg_names():
            input_var = main_block._var_recursive(input_varname)
            tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(input_var)
            assert tensor_dist_attr is not None
            allreduce_op_dist_attr.set_input_dist_attr(
                input_varname, tensor_dist_attr
            )
        for output_varname in c_allreduce_sum_op.desc.output_arg_names():
            output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
            assert output_dist_attr is not None, "dist_attr is {}".format(
                op_dist_attr
            )
            allreduce_op_dist_attr.set_output_dist_attr(
                output_varname, output_dist_attr
            )
        ctx.set_op_dist_attr_for_program(
            c_allreduce_sum_op, allreduce_op_dist_attr
        )

        # param initialization sync
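        # broadcast the freshly initialized parameter within every mesh axis the
        # parameter is not sharded on (root is the first rank of each group), so
        # all replicas start from identical values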
        if Weight_var.is_parameter and not op_dist_attr.is_recompute:
            if Weight_var.name in dist_op_context.already_init_sync_vars:
                return
            dist_op_context.already_init_sync_vars.add(Weight_var.name)
            param = startup_block.var(Weight_var.name)
            param_dist_attr = ctx.get_tensor_dist_attr_for_program(param)
            process_mesh = param_dist_attr.process_mesh
            dim_mapping = param_dist_attr.dims_mapping

            # NOTE: synchronize the parameter along every mesh axis it is not split on
            for axis, size in enumerate(process_mesh.shape):
                if size <= 1 or axis in dim_mapping:
                    pass
                else:
                    group_ranks = _get_comm_group(
                        process_mesh.process_ids,
                        process_mesh.shape,
                        axis,
                        rank_id,
                    )
                    sync_group = new_process_group(group_ranks)

                    startup_block.append_op(
                        type='c_broadcast',
                        inputs={'X': param},
                        outputs={'Out': param},
                        attrs={
                            'ring_id': sync_group.id,
                            'root': 0,
                            'use_calc_stream': True,
                            OP_ROLE_KEY: OpRole.Forward,
                        },
                    )

    @staticmethod
    def backward(ctx, *args, **kwargs):

        # for now, the backward function only inserts the gradient allreduce for the dist op itself
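        # mirror of the forward pass: c_identity forwards Out@GRAD (the backward
        # of the forward allreduce is an identity), c_embedding_grad computes the
        # local W@GRAD with the same start_index, and gradient_synchronization
        # adds the data-parallel allreduce when the batch axis is sharded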
        dist_op_context = ctx.dist_op_context
        main_block = dist_op_context.work_block
        backward_op = dist_op_context.cur_src_op
        rank_id = dist_op_context.rank_id
        dist_attr = ctx.get_op_dist_attr_for_program(backward_op)
        assert (
            dist_attr is not None
        ), "backward op [{}] doesn't have a dist attribute!".format(
            str(backward_op)
        )

        # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
        if rank_id not in dist_attr.process_mesh.process_ids:
            rank_id = _get_corresponding_rank(
                ctx, dist_attr.process_mesh, rank_id
            )

        assert 'Ids' in kwargs, "input [{}] is not given".format('Ids')
        assert 'W' in kwargs, "input [{}] is not given".format('W')
        assert 'Out@GRAD' in kwargs, "input [{}] is not given".format('Out@GRAD')
        assert 'W@GRAD' in kwargs, "output [{}] is not given".format('W@GRAD')

        assert (
            len(kwargs['Ids']) == 1
        ), "row_parallel_embedding input Ids takes 1 variable but got {}".format(
            kwargs['Ids']
        )
        assert (
            len(kwargs['W']) == 1
        ), "row_parallel_embedding input W takes 1 variable but got {}".format(
            kwargs['W']
        )
        assert (
            len(kwargs['Out@GRAD']) == 1
        ), "row_parallel_embedding input Out@GRAD takes 1 variable but got {}".format(
            kwargs['Out@GRAD']
        )
        assert (
            len(kwargs['W@GRAD']) == 1
        ), "row_parallel_embedding output W@GRAD takes 1 variable but got {}".format(
            kwargs['W@GRAD']
        )

        Ids_var = main_block._var_recursive(kwargs['Ids'][0])
        Weight_var = main_block._var_recursive(kwargs['W'][0])
        Out_grad = main_block._var_recursive(kwargs['Out@GRAD'][0])
        Weight_grad = main_block._var_recursive(kwargs['W@GRAD'][0])

        embedding_row_dim_mapping = dist_attr.get_input_dims_mapping(
            Weight_var.name
        )[0]
        assert (
            embedding_row_dim_mapping >= 0
        ), "row_parallel_embedding's row dimension must be sharded along a mesh axis, but got [{}]".format(
            embedding_row_dim_mapping
        )
        process_mesh_shape = dist_attr.process_mesh.shape
        process_mesh_group = dist_attr.process_mesh.process_ids

        # A generalized method to calculate the embedding offset using the Cartesian product
        relative_idx = _get_idx_in_axis(
            process_mesh_group,
            process_mesh_shape,
            embedding_row_dim_mapping,
            rank_id,
        )
        per_part_size = Weight_var.shape[0]
        relative_idx = relative_idx * per_part_size

        check_variable_and_dtype(
            Out_grad,
            'tensor',
            ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
            '_c_identity',
        )

        intermediate_var_0 = main_block.create_var(
            name=unique_name.generate_with_ignorable_key(
                ".".join(["c_embedding", '@tmp_0@GRAD'])
            ),
            dtype=Out_grad.dtype,
            shape=Out_grad.shape,
            type=core.VarDesc.VarType.LOD_TENSOR,
            persistable=False,
            stop_gradient=Out_grad.stop_gradient,
        )

        # copy Out_grad's dist_attr to intermediate_var_0's dist_attr
        out_grad_dist_attr = dist_attr.get_input_dist_attr(Out_grad.name)
        assert out_grad_dist_attr is not None
        ctx.set_tensor_dist_attr_for_program(
            intermediate_var_0, out_grad_dist_attr
        )

        group_ranks = _get_comm_group(
            process_mesh_group,
            process_mesh_shape,
            embedding_row_dim_mapping,
            rank_id,
        )
        group = new_process_group(group_ranks)

        c_identity_op = main_block.append_op(
            type='c_identity',
            inputs={'X': [Out_grad]},
            outputs={'Out': intermediate_var_0},
            attrs={
                'ring_id': group.id,
                'use_calc_stream': True,
                'use_model_parallel': True,
                OP_ROLE_KEY: OpRole.Backward,
            },
        )
        check_variable_and_dtype(
            intermediate_var_0,
            'x',
            ['float16', 'float32', 'float64', 'uint16'],
            'linear',
        )
        check_dtype(
            intermediate_var_0.dtype,
            'dtype',
            ['float16', 'float32', 'float64', 'uint16'],
            'linear',
        )

        set_comm_op_dist_attr_for_program(
            c_identity_op, dist_attr.process_mesh, out_grad_dist_attr, ctx
        )

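        # build the c_embedding_grad op at the desc level: append a placeholder
        # 'nop' op and rewrite its type, inputs, outputs and attributes in place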
        c_embedding_grad_op_desc = main_block.append_op(type='nop').desc
        c_embedding_grad_op_desc.set_type("c_embedding_grad")
        c_embedding_grad_op_desc.set_input('Ids', [Ids_var.name])
        c_embedding_grad_op_desc.set_input('W', [Weight_var.name])
        c_embedding_grad_op_desc.set_input(
            'Out@GRAD', [intermediate_var_0.name]
        )
        c_embedding_grad_op_desc.set_output('W@GRAD', [Weight_grad.name])
        c_embedding_grad_op_desc._set_attr('start_index', relative_idx)
        c_embedding_grad_op_desc._set_attr(OP_ROLE_KEY, OpRole.Backward)

        c_embedding_grad_op = main_block.ops[-1]
        assert c_embedding_grad_op.type == "c_embedding_grad"
        naive_copy_op_dist_attr_for_program(
            c_embedding_grad_op, backward_op, ctx
        )

        # data parallel gradient synchronization
        act_grad_names = [Ids_var.name]
        out_grad_names = [kwargs['W@GRAD'][0]]

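        # act_grad_names (Ids) is used to detect the data-parallel mesh axis;
        # out_grad_names (W@GRAD) receives the inserted gradient allreduce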
        gradient_synchronization(
            ctx, backward_op, act_grad_names, out_grad_names, rank_id
        )


register_distributed_operator_impl(
    "lookup_table_v2", DistributedEmbeddingImpl("row_parallel")
)
register_distributed_operator_impl(
    "c_embedding", DistributedEmbeddingImpl("row_parallel")
)
register_distributed_operator_impl(
    "lookup_table", DistributedEmbeddingImpl("row_parallel")
)