# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
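
# This module implements the auto-parallel distributed operator for embedding
# lookups (lookup_table, lookup_table_v2 and c_embedding). It registers the
# operator container, a row-parallel implementation with forward/backward cost
# models, a forward that lowers the serial op to c_embedding + c_allreduce_sum,
# and a backward that lowers the grad op to c_identity + c_embedding_grad plus
# data-parallel gradient synchronization.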

from paddle.common_ops_import import check_dtype, check_variable_and_dtype
from paddle.distributed.auto_parallel.cost.comm_op_cost import (
    AllreduceSumOpCost,
    IdentityOpCost,
)
from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole
from paddle.framework import core
from paddle.utils import unique_name

from ..cost import (
    EmbeddingGradOpCost,
    EmbeddingOpCost,
    build_comm_costs_from_descs,
    build_comm_desc_from_dist_op,
    build_comp_costs_from_descs,
    build_comp_desc_from_dist_op,
    build_dp_costs,
)
from ..dist_attribute import OperatorDistAttr
from ..process_group import new_process_group
from ..utils import (
    _get_comm_group,
    _get_corresponding_rank,
    _get_idx_in_axis,
    compute_compatible_and_update_dim_mapping,
    is_dim_replicate,
    is_dim_shard,
    set_var_dist_attr,
)
from .common import (
    DistributedOperatorImpl,
    DistributedOperatorImplContainer,
    gradient_synchronization,
    infer_shape,
    naive_copy_op_dist_attr_for_program,
    register_distributed_operator_impl,
    register_distributed_operator_impl_container,
    set_comm_op_dist_attr_for_program,
)


class DistributedEmbedding(DistributedOperatorImplContainer):
    def __init__(self, op_type):
        super().__init__(op_type)


register_distributed_operator_impl_container(
    DistributedEmbedding("lookup_table_v2")
)
register_distributed_operator_impl_container(
    DistributedEmbedding("c_embedding")
)
register_distributed_operator_impl_container(
    DistributedEmbedding("lookup_table")
)


def adopt_lookup_table_v1(ctx, main_block, src_op, Ids_var):
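    """Adapt lookup_table (v1) to the c_embedding lowering path.

    lookup_table (v1) takes 3-D Ids, so a reshape2 op is inserted to flatten
    Ids into a 2-D variable, the dist attributes of the new reshape op and its
    output variables are registered in the dist context, and the reshaped Ids
    variable is returned for the c_embedding op to consume.
    """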

    assert (
        len(Ids_var.shape) == 3
    ), "input Ids to lookup_table should have 3 dimensions but got [{}] with shape [{}]".format(
        Ids_var.name, Ids_var.shape
    )
    if not Ids_var.stop_gradient:
        raise NotImplementedError(
            'Requiring the gradient of Ids of lookup_table(v1) dist op is not currently supported. Please open an issue with details on your use case so that we can prioritize adding this (for instance, adversarial training for language model).'
        )

    target_shape = list(Ids_var.shape[:-1])
    intermediate_var_0 = main_block.create_var(
        name=unique_name.generate_with_ignorable_key(
            ".".join(["dist_reshape", 'tmp'])
        ),
        dtype=Ids_var.dtype,
        shape=target_shape,
        type=core.VarDesc.VarType.LOD_TENSOR,
        persistable=False,
        stop_gradient=True,
    )

    target_shape = [0] + list(Ids_var.shape[:-1])
    xshape_var = main_block.create_var(
        name=unique_name.generate_with_ignorable_key(
            ".".join(["dist_Xshape", 'tmp'])
        ),
        dtype=Ids_var.dtype,
        shape=target_shape,
        type=core.VarDesc.VarType.LOD_TENSOR,
        persistable=False,
        stop_gradient=True,
    )

    # TODO use inplace reshape for memory saving
    reshape_op = main_block.append_op(
        type='reshape2',
        inputs={'X': [Ids_var]},
        outputs={'Out': [intermediate_var_0], 'XShape': [xshape_var]},
        attrs={
            "shape": [0, -1],
        },
    )

    # set dist attr
    op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
    Ids_var_dist_attr = op_dist_attr.get_input_dist_attr(Ids_var.name)
    assert Ids_var_dist_attr is not None
    intermediate_var_0_dist_attr = set_var_dist_attr(
        ctx,
        intermediate_var_0,
        Ids_var_dist_attr.dims_mapping,
        Ids_var_dist_attr.process_mesh,
    )
    set_var_dist_attr(
        ctx,
        xshape_var,
        [-1] + list(Ids_var_dist_attr.dims_mapping),
        Ids_var_dist_attr.process_mesh,
    )
    op_dist_attr.del_input_dist_attr(Ids_var.name)
    op_dist_attr.set_input_dist_attr(
        intermediate_var_0.name, intermediate_var_0_dist_attr
    )

    new_op_dist_attr = OperatorDistAttr()
    new_op_dist_attr.process_mesh = Ids_var_dist_attr.process_mesh
    new_op_dist_attr.impl_type = "default"
    new_op_dist_attr.impl_idx = 0
    new_op_dist_attr.set_input_dims_mapping(
        Ids_var.name, Ids_var_dist_attr.dims_mapping
    )
    new_op_dist_attr.set_output_dims_mapping(
        intermediate_var_0.name, Ids_var_dist_attr.dims_mapping
    )
    new_op_dist_attr.set_output_dims_mapping(
        xshape_var.name, [-1] + list(Ids_var_dist_attr.dims_mapping)
    )
    ctx.set_op_dist_attr_for_program(reshape_op, new_op_dist_attr)

    return intermediate_var_0


# RowParallel
class DistributedEmbeddingImpl(DistributedOperatorImpl):
    def __init__(self, name):
        super().__init__(name)
        self._forward_implemented = True
        self._backward_implemented = True

    def calc_cost(self, op_role, dist_op, ctx, cluster):
        """Calculate the cost by the op role."""
        cost = None
        if int(op_role) == int(OpRole.Forward):
            cost = self.calc_fwd_cost(dist_op, ctx, cluster)
        elif int(op_role) == int(OpRole.Backward):
            cost = self.calc_bwd_cost(dist_op, ctx, cluster)
        assert cost is not None
        return cost

    def calc_fwd_cost(self, dist_op, ctx, cluster):
        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        processes = dist_op.dist_attr.process_mesh.process_ids
        # embedding needs start_index
        cost_mapping = build_comp_costs_from_descs(
            EmbeddingOpCost, ctx, processes, desc_mapping, cluster
        )
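
        # Forward of the row-parallel embedding is a local c_embedding lookup
        # followed by a c_allreduce_sum over the model-parallel group, so an
        # allreduce communication cost is collected in addition to the
        # computation cost above.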
        serial_op = dist_op.serial_op
        parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
            serial_op.input("W")[0]
        )[0]
        attrs = {"use_calc_stream": True, "use_model_parallel": True}
        var_names = serial_op.output("Out")
        c_allreduce_sum_desc_mapping = build_comm_desc_from_dist_op(
            "c_allreduce_sum",
            dist_op,
            ctx,
            var_names,
            attrs=attrs,
            parallel_axis=parallel_axis,
        )

        comm_op_cost_list = build_comm_costs_from_descs(
            AllreduceSumOpCost,
            ctx,
            processes,
            c_allreduce_sum_desc_mapping,
            cluster,
        )

        res_cost = [cost_mapping, comm_op_cost_list]

        return res_cost

    def calc_bwd_cost(self, dist_op, ctx, cluster):
        # for now, the backward function only inserts the gradient allreduce for the dist op itself
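        # The cost has three parts: the c_identity communication on Out@GRAD,
        # the c_embedding_grad computation, and, if the batch dimension of Ids
        # is sharded, a data-parallel allreduce of W@GRAD.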
        res = []
        backward_op = dist_op.serial_op
        main_block = backward_op.block
        dist_attr = dist_op.dist_attr

        embedding_row_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("W")[0]
        )[0]
        parallel_axis = embedding_row_dim_mapping
        attrs = {"use_calc_stream": True, "use_model_parallel": True}
        var_names = [backward_op.input("Out@GRAD")[0]]
        c_identity_desc_mapping = build_comm_desc_from_dist_op(
            "c_identity",
            dist_op,
            ctx,
            var_names,
            attrs=attrs,
            parallel_axis=parallel_axis,
        )

        process_mesh = dist_attr.process_mesh
        processes = process_mesh.process_ids
        comm_op_cost_list = build_comm_costs_from_descs(
            IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster
        )
        res.append(comm_op_cost_list)

        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        cost_mapping = build_comp_costs_from_descs(
            EmbeddingGradOpCost, ctx, processes, desc_mapping, cluster
        )
        res.append(cost_mapping)

        # need gradient allreduce
        var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("Ids")[0]
        )
        mesh_shape = process_mesh.shape
        batch_size_axis = var_dim_mapping[0]
        if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
            parallel_axis = batch_size_axis
            attrs = {"use_calc_stream": True}
            var_names = [backward_op.output('W@GRAD')[0]]
            build_dp_costs(
                res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
            )

        return res

    def is_input_compatible(self, dist_op):
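        # Compatible iff W is sharded along its row (vocabulary) dimension and
        # replicated along its last (embedding) dimension, Ids is sharded on at
        # most its batch dimension, and the batch axis of Ids does not reuse
        # the mesh axis that shards the rows of W.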
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        ids_name = op_desc.input('Ids')[0]
        w_name = op_desc.input('W')[0]
        ids_dims_mapping = op_dist_attr.get_input_dims_mapping(ids_name)
        w_dims_mapping = op_dist_attr.get_input_dims_mapping(w_name)
        if is_dim_replicate(w_dims_mapping[-2]) or is_dim_shard(
            w_dims_mapping[-1]
        ):
            return False
        # All Ids dimensions except the batch dimension must be replicated
        for mapping in ids_dims_mapping[1:]:
            if is_dim_shard(mapping):
                return False

        if is_dim_shard(ids_dims_mapping[0]) and is_dim_shard(
            w_dims_mapping[-2]
        ):
            if ids_dims_mapping[0] == w_dims_mapping[-2]:
                return False
        return True

    def is_output_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
        # All output dimensions except the batch dimension must be replicated
        for mapping in out_dims_mapping[1:]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_auto_compatible(self, dist_op):
        if (not self.is_input_compatible(dist_op)) or (
            not self.is_output_compatible(dist_op)
        ):
            return False

        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        ids_name = op_desc.input('Ids')[0]
        w_name = op_desc.input('W')[0]
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
        ids_dims_mapping = op_dist_attr.get_input_dims_mapping(ids_name)
        w_dims_mapping = op_dist_attr.get_input_dims_mapping(w_name)

        if ids_dims_mapping != out_dims_mapping[: len(ids_dims_mapping)]:
            return False

        return True

    def update_dims_mapping(self, dist_op):
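        # Propagate dims_mapping between the serial op's inputs and output:
        # the leading dims of Ids must match the leading dims of Out, and the
        # last dim of W (the embedding dim) must match the last dim of Out.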
        changed = False
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        ids_name = op_desc.input('Ids')[0]
        w_name = op_desc.input('W')[0]
        out_name = op_desc.output('Out')[0]
        ids_dims_mapping = op_dist_attr.get_input_dims_mapping(ids_name)
        w_dims_mapping = op_dist_attr.get_input_dims_mapping(w_name)
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)

        for i in range(len(ids_dims_mapping)):
            dim_changed = compute_compatible_and_update_dim_mapping(
                [ids_dims_mapping, out_dims_mapping], [i, i]
            )
            if dim_changed:
                changed = True

        dim_changed = compute_compatible_and_update_dim_mapping(
            [w_dims_mapping, out_dims_mapping], [-1, -1]
        )
        if dim_changed:
            changed = True

        if changed:
            op_dist_attr.set_input_dims_mapping(ids_name, ids_dims_mapping)
            op_dist_attr.set_input_dims_mapping(w_name, w_dims_mapping)
            op_dist_attr.set_output_dims_mapping(out_name, out_dims_mapping)

        return changed

    @staticmethod
    def forward(ctx, *args, **kwargs):
        """
        kwargs: inputname_mapping & outputname_mapping
        """

        dist_op_context = ctx.dist_op_context
        main_block = dist_op_context.work_block
        startup_block = dist_op_context.startup_block
        src_op = dist_op_context.cur_src_op
        rank_id = dist_op_context.rank_id
        op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
        assert (
            op_dist_attr is not None
        ), f"forward op [{str(src_op)}] doesn't have a dist attribute!"

        # check validity of inputs / outputs
        assert 'Ids' in kwargs, "input [{}] is not given".format('Ids')
        assert 'W' in kwargs, "input [{}] is not given".format('W')
        assert 'Out' in kwargs, "output [{}] is not given".format('Out')

        assert (
            len(kwargs['Ids']) == 1
        ), "row_parallel_embedding input Ids takes 1 variable but got {}".format(
            kwargs['Ids']
        )
        assert (
            len(kwargs['W']) == 1
        ), "row_parallel_embedding input W takes 1 variable but got {}".format(
            kwargs['W']
        )
        assert (
            len(kwargs['Out']) == 1
        ), "row_parallel_embedding output Out takes 1 variable but got {}".format(
            kwargs['Out']
        )

        Ids_var = main_block._var_recursive(kwargs['Ids'][0])
        Weight_var = main_block._var_recursive(kwargs['W'][0])
        Out_var = main_block._var_recursive(kwargs['Out'][0])

        # support lookup_table_v1
        if src_op.type == 'lookup_table':
            Ids_var = adopt_lookup_table_v1(ctx, main_block, src_op, Ids_var)

        # get dist attribute info
        embedding_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
            Weight_var.name
        )[0]
        assert (
            embedding_row_dim_mapping >= 0
        ), "row_parallel_embedding requires the row dimension of W to be sharded by a mesh axis, but got dims_mapping [{}]".format(
            embedding_row_dim_mapping
        )
        process_mesh_shape = op_dist_attr.process_mesh.shape
        process_mesh_group = op_dist_attr.process_mesh.process_ids

        # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
        if rank_id not in process_mesh_group:
            rank_id = _get_corresponding_rank(
                ctx, op_dist_attr.process_mesh, rank_id
            )

        # A generalized method to calculate the embedding offset using a Cartesian product
        relative_idx = _get_idx_in_axis(
            process_mesh_group,
            process_mesh_shape,
            embedding_row_dim_mapping,
            rank_id,
        )

        per_part_size = Weight_var.shape[0]
        relative_idx = relative_idx * per_part_size
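        # Each rank holds a contiguous row shard of the embedding table with
        # per_part_size rows, so relative_idx becomes the global row offset
        # (start_index) of the local shard; ids outside the local range yield
        # zero rows, and the partial results are summed by c_allreduce_sum
        # below to reconstruct the full lookup.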

        # TODO calculate ring id
        parallel_axis = embedding_row_dim_mapping
        group_ranks = _get_comm_group(
            process_mesh_group, process_mesh_shape, parallel_axis, rank_id
        )
        group = new_process_group(group_ranks)

        # append op
        check_variable_and_dtype(
            Ids_var, 'input', ['int32', 'int64'], 'c_embedding'
        )

        # infer new var shape with op dist attr
        out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var)
        assert out_tensor_dist_attr is not None
        out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert out_var_dist_attr is not None
        ref_shape = infer_shape(
            main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr
        )

        intermediate_var_0 = main_block.create_var(
            name=unique_name.generate_with_ignorable_key(
                ".".join(["c_embedding", 'tmp'])
            ),
            dtype=Weight_var.dtype,
            shape=Out_var.shape,
            type=core.VarDesc.VarType.LOD_TENSOR,
            persistable=False,
            stop_gradient=Out_var.stop_gradient,
        )
        # set intermediate_var_0's dist_attr with Out_var's dist_attr
        ctx.set_tensor_dist_attr_for_program(
            intermediate_var_0, out_var_dist_attr
        )

        check_variable_and_dtype(
            Out_var,
            'tensor',
            ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
            'c_allreduce_sum',
        )

        c_embedding_op = main_block.append_op(
            type='c_embedding',
            inputs={'Ids': [Ids_var], 'W': [Weight_var]},
            outputs={'Out': [intermediate_var_0]},
            attrs={
                "start_index": relative_idx,
                OP_ROLE_KEY: src_op.attr('op_role'),
            },
        )
        if intermediate_var_0.shape != ref_shape:
            intermediate_var_0.desc.set_shape(ref_shape)

        # use_model_parallel
        c_allreduce_sum_op = main_block.append_op(
            type='c_allreduce_sum',
            inputs={'X': [intermediate_var_0]},
            outputs={'Out': [Out_var]},
            attrs={
                'ring_id': group.id,
                'use_calc_stream': True,
                'use_model_parallel': True,
                OP_ROLE_KEY: src_op.attr('op_role'),
            },
        )
        if Out_var.shape != ref_shape:
            Out_var.desc.set_shape(ref_shape)

        # set dist op's dist_attr with serial op's dist_attr
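        # The serial embedding op's dist_attr is split across the two new ops:
        # the c_embedding op reuses the serial input/output dist attrs, and the
        # allreduce op takes its input dist attr from the intermediate tensor.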
        # c_embedding
        embedding_op_dist_attr = OperatorDistAttr()
        embedding_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        embedding_op_dist_attr.impl_type = op_dist_attr.impl_type
        embedding_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        for input_varname in c_embedding_op.desc.input_arg_names():
            input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
            assert input_dist_attr is not None, "dist_attr is {}".format(
                op_dist_attr
            )
            embedding_op_dist_attr.set_input_dist_attr(
                input_varname, input_dist_attr
            )
        output_varname = c_embedding_op.desc.output_arg_names()[0]
        output_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert output_dist_attr is not None, "dist_attr is {}".format(
            op_dist_attr
        )
        embedding_op_dist_attr.set_output_dist_attr(
            output_varname, output_dist_attr
        )
        ctx.set_op_dist_attr_for_program(c_embedding_op, embedding_op_dist_attr)

        # allreduce
        allreduce_op_dist_attr = OperatorDistAttr()
        allreduce_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        allreduce_op_dist_attr.impl_type = op_dist_attr.impl_type
        allreduce_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        for input_varname in c_allreduce_sum_op.desc.input_arg_names():
            input_var = main_block._var_recursive(input_varname)
            tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(input_var)
            assert tensor_dist_attr is not None
            allreduce_op_dist_attr.set_input_dist_attr(
                input_varname, tensor_dist_attr
            )
        for output_varname in c_allreduce_sum_op.desc.output_arg_names():
            output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
            assert output_dist_attr is not None, "dist_attr is {}".format(
                op_dist_attr
            )
            allreduce_op_dist_attr.set_output_dist_attr(
                output_varname, output_dist_attr
            )
        ctx.set_op_dist_attr_for_program(
            c_allreduce_sum_op, allreduce_op_dist_attr
        )

        # param initialization sync
        if Weight_var.is_parameter and not op_dist_attr.is_recompute:
            if Weight_var.name in dist_op_context.already_init_sync_vars:
                return
            dist_op_context.already_init_sync_vars.add(Weight_var.name)
            param = startup_block.var(Weight_var.name)
            param_dist_attr = ctx.get_tensor_dist_attr_for_program(param)
            process_mesh = param_dist_attr.process_mesh
            dim_mapping = param_dist_attr.dims_mapping

            # NOTE: broadcast along every mesh axis on which the parameter is
            # not sharded, so that all replicas start from the same values
            for axis, size in enumerate(process_mesh.shape):
                if size <= 1 or axis in dim_mapping:
                    pass
                else:
                    group_ranks = _get_comm_group(
                        process_mesh.process_ids,
                        process_mesh.shape,
                        axis,
                        rank_id,
                    )
                    sync_group = new_process_group(group_ranks)

                    startup_block.append_op(
                        type='c_broadcast',
                        inputs={'X': param},
                        outputs={'Out': param},
                        attrs={
                            'ring_id': sync_group.id,
                            'root': 0,
                            'use_calc_stream': True,
                            OP_ROLE_KEY: OpRole.Forward,
                        },
                    )

    @staticmethod
    def backward(ctx, *args, **kwargs):

        # for now, the backward function only inserts the gradient allreduce for the dist op itself
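        # Lowering of the grad op: a c_identity copies Out@GRAD inside the
        # model-parallel group, a c_embedding_grad computes the gradient of the
        # local weight shard using the same start_index as the forward, and
        # gradient_synchronization appends the data-parallel allreduce on
        # W@GRAD when the batch dimension is sharded.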
        dist_op_context = ctx.dist_op_context
        main_block = dist_op_context.work_block
        backward_op = dist_op_context.cur_src_op
        rank_id = dist_op_context.rank_id
        dist_attr = ctx.get_op_dist_attr_for_program(backward_op)
        assert (
            dist_attr is not None
        ), "backward op [{}] doesn't have a dist attribute!".format(
            str(backward_op)
        )

        # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
        if rank_id not in dist_attr.process_mesh.process_ids:
            rank_id = _get_corresponding_rank(
                ctx, dist_attr.process_mesh, rank_id
            )

        assert 'Ids' in kwargs, "input [{}] is not given".format('Ids')
        assert 'W' in kwargs, "input [{}] is not given".format('W')
        assert 'Out@GRAD' in kwargs, "input [{}] is not given".format('Out@GRAD')
        assert 'W@GRAD' in kwargs, "output [{}] is not given".format('W@GRAD')

        assert (
            len(kwargs['Ids']) == 1
        ), "row_parallel_embedding input Ids takes 1 variable but got {}".format(
            kwargs['Ids']
        )
        assert (
            len(kwargs['W']) == 1
        ), "row_parallel_embedding input W takes 1 variable but got {}".format(
            kwargs['W']
        )
        assert (
            len(kwargs['Out@GRAD']) == 1
        ), "row_parallel_embedding input Out@GRAD takes 1 variable but got {}".format(
            kwargs['Out@GRAD']
        )
        assert (
            len(kwargs['W@GRAD']) == 1
        ), "row_parallel_embedding output W@GRAD takes 1 variable but got {}".format(
            kwargs['W@GRAD']
        )

        Ids_var = main_block._var_recursive(kwargs['Ids'][0])
        Weight_var = main_block._var_recursive(kwargs['W'][0])
        Out_grad = main_block._var_recursive(kwargs['Out@GRAD'][0])
        Weight_grad = main_block._var_recursive(kwargs['W@GRAD'][0])

        embedding_row_dim_mapping = dist_attr.get_input_dims_mapping(
            Weight_var.name
        )[0]
        assert (
            embedding_row_dim_mapping >= 0
        ), "row_parallel_embedding requires the row dimension of W to be sharded by a mesh axis, but got dims_mapping [{}]".format(
            embedding_row_dim_mapping
        )
        process_mesh_shape = dist_attr.process_mesh.shape
        process_mesh_group = dist_attr.process_mesh.process_ids

        # A generalized method to calculate the embedding offset using a Cartesian product
        relative_idx = _get_idx_in_axis(
            process_mesh_group,
            process_mesh_shape,
            embedding_row_dim_mapping,
            rank_id,
        )
        per_part_size = Weight_var.shape[0]
        relative_idx = relative_idx * per_part_size

        check_variable_and_dtype(
            Out_grad,
            'tensor',
            ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
            '_c_identity',
        )

        intermediate_var_0 = main_block.create_var(
            name=unique_name.generate_with_ignorable_key(
                ".".join(["c_embedding", '@tmp_0@GRAD'])
            ),
            dtype=Out_grad.dtype,
            shape=Out_grad.shape,
            type=core.VarDesc.VarType.LOD_TENSOR,
            persistable=False,
            stop_gradient=Out_grad.stop_gradient,
        )

        # copy Out_grad's dist_attr to intermediate_var_0's dist_attr
        out_grad_dist_attr = dist_attr.get_input_dist_attr(Out_grad.name)
        assert out_grad_dist_attr is not None
        ctx.set_tensor_dist_attr_for_program(
            intermediate_var_0, out_grad_dist_attr
        )

        group_ranks = _get_comm_group(
            process_mesh_group,
            process_mesh_shape,
            embedding_row_dim_mapping,
            rank_id,
        )
        group = new_process_group(group_ranks)

        c_identity_op = main_block.append_op(
            type='c_identity',
            inputs={'X': [Out_grad]},
            outputs={'Out': intermediate_var_0},
            attrs={
                'ring_id': group.id,
                'use_calc_stream': True,
                'use_model_parallel': True,
                OP_ROLE_KEY: OpRole.Backward,
            },
        )
        check_variable_and_dtype(
            intermediate_var_0,
            'x',
            ['float16', 'float32', 'float64', 'uint16'],
            'linear',
        )
        check_dtype(
            intermediate_var_0.dtype,
            'dtype',
            ['float16', 'float32', 'float64', 'uint16'],
            'linear',
        )

        set_comm_op_dist_attr_for_program(
            c_identity_op, dist_attr.process_mesh, out_grad_dist_attr, ctx
        )

        c_embedding_grad_op_desc = main_block.append_op(type='nop').desc
        c_embedding_grad_op_desc.set_type("c_embedding_grad")
        c_embedding_grad_op_desc.set_input('Ids', [Ids_var.name])
        c_embedding_grad_op_desc.set_input('W', [Weight_var.name])
        c_embedding_grad_op_desc.set_input(
            'Out@GRAD', [intermediate_var_0.name]
        )
        c_embedding_grad_op_desc.set_output('W@GRAD', [Weight_grad.name])
        c_embedding_grad_op_desc._set_attr('start_index', relative_idx)
        c_embedding_grad_op_desc._set_attr(OP_ROLE_KEY, OpRole.Backward)

        c_embedding_grad_op = main_block.ops[-1]
        assert c_embedding_grad_op.type == "c_embedding_grad"
        naive_copy_op_dist_attr_for_program(
            c_embedding_grad_op, backward_op, ctx
        )

        # data parallel gradient synchronization
        act_grad_names = [Ids_var.name]
        out_grad_names = [kwargs['W@GRAD'][0]]

        gradient_synchronization(
            ctx, backward_op, act_grad_names, out_grad_names, rank_id
        )


register_distributed_operator_impl(
    "lookup_table_v2", DistributedEmbeddingImpl("row_parallel")
)
register_distributed_operator_impl(
    "c_embedding", DistributedEmbeddingImpl("row_parallel")
)
register_distributed_operator_impl(
    "lookup_table", DistributedEmbeddingImpl("row_parallel")
)