# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License

from .common import infer_shape
from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl
from .common import register_distributed_operator_impl_container
from .common import gradient_synchronization
from .common import (
    register_distributed_operator_impl,
    set_comm_op_dist_attr_for_program,
    naive_copy_op_dist_attr_for_program,
    is_parameter_related,
)
from ..utils import is_dim_shard
from ..utils import is_dim_replicate
from ..utils import is_valid_list_index
from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
from ..dist_attribute import (
    OperatorDistributedAttribute,
    TensorDistributedAttribute,
)
from paddle.fluid import core, unique_name
from paddle.fluid.framework import _non_static_mode
from paddle.fluid.framework import Program, Parameter, Variable
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from paddle.distributed.fleet.meta_optimizers.common import (
    OpRole,
    OP_ROLE_KEY,
    OP_ROLE_VAR_KEY,
)
from ..process_group import new_process_group
from ..utils import (
    _get_comm_group,
    _get_idx_in_axis,
    _get_corresponding_rank,
    set_var_dist_attr,
)
from ..cost import build_comp_desc_from_dist_op, build_comm_desc_from_dist_op
from ..cost import (
    build_comm_costs_from_descs,
    build_comp_costs_from_descs,
    build_dp_costs,
)
from ..cost import EmbeddingOpCost, EmbeddingGradOpCost
from paddle.distributed.auto_parallel.cost.comm_op_cost import (
    AllreduceSumOpCost,
    IdentityOpCost,
)


class DistributedEmbedding(DistributedOperatorImplContainer):
    def __init__(self, op_type):
        super(DistributedEmbedding, self).__init__(op_type)


register_distributed_operator_impl_container(
    DistributedEmbedding("lookup_table_v2")
)
register_distributed_operator_impl_container(
    DistributedEmbedding("c_embedding")
)
register_distributed_operator_impl_container(
    DistributedEmbedding("lookup_table")
)


def adopt_lookup_table_v1(ctx, main_block, src_op, Ids_var):
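    """
    Adapt the legacy lookup_table (v1) op to the c_embedding path: the v1 op
    expects Ids with a trailing dimension (typically of size 1), so a reshape2
    is inserted to fold it away, and the distributed attributes of Ids are
    propagated to the new intermediate variables.
    """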

    assert (
        len(Ids_var.shape) == 3
    ), "input Ids of lookup_table should have 3 dimensions, but got [{}] with shape [{}]".format(
        Ids_var.name, Ids_var.shape
    )
    if not Ids_var.stop_gradient:
        raise NotImplementedError(
            'Requiring the gradient of Ids of the lookup_table (v1) dist op is not currently supported. Please open an issue with details on your use case so that we can prioritize adding it (for instance, adversarial training for language models).'
        )

    target_shape = list(Ids_var.shape[:-1])
    intermediate_var_0 = main_block.create_var(
        name=unique_name.generate_with_ignorable_key(
            ".".join(["dist_reshape", 'tmp'])
        ),
        dtype=Ids_var.dtype,
        shape=target_shape,
        type=core.VarDesc.VarType.LOD_TENSOR,
        persistable=False,
        stop_gradient=True,
    )

    target_shape = [0] + list(Ids_var.shape[:-1])
    xshape_var = main_block.create_var(
        name=unique_name.generate_with_ignorable_key(
            ".".join(["dist_Xshape", 'tmp'])
        ),
        dtype=Ids_var.dtype,
        shape=target_shape,
        type=core.VarDesc.VarType.LOD_TENSOR,
        persistable=False,
        stop_gradient=True,
    )

    # TODO use inplace reshape for memory saving
    reshape_op = main_block.append_op(
        type='reshape2',
        inputs={'X': [Ids_var]},
        outputs={'Out': [intermediate_var_0], 'XShape': [xshape_var]},
        attrs={
            "shape": [0, -1],
        },
    )

    # set dist attr
    op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
    Ids_var_dist_attr = op_dist_attr.get_input_dist_attr(Ids_var.name)
    assert Ids_var_dist_attr is not None
    intermediate_var_0_dist_attr = set_var_dist_attr(
        ctx,
        intermediate_var_0,
        Ids_var_dist_attr.dims_mapping,
        Ids_var_dist_attr.process_mesh,
    )
    set_var_dist_attr(
        ctx,
        xshape_var,
        [-1] + list(Ids_var_dist_attr.dims_mapping),
        Ids_var_dist_attr.process_mesh,
    )
    op_dist_attr.del_input_dist_attr(Ids_var.name)
    op_dist_attr.set_input_dist_attr(
        intermediate_var_0.name, intermediate_var_0_dist_attr
    )

    new_op_dist_attr = OperatorDistributedAttribute()
    new_op_dist_attr.process_mesh = Ids_var_dist_attr.process_mesh
    new_op_dist_attr.impl_type = "default"
    new_op_dist_attr.impl_idx = 0
    new_op_dist_attr.set_input_dims_mapping(
        Ids_var.name, Ids_var_dist_attr.dims_mapping
    )
    new_op_dist_attr.set_output_dims_mapping(
        intermediate_var_0.name, Ids_var_dist_attr.dims_mapping
    )
    new_op_dist_attr.set_output_dims_mapping(
        xshape_var.name, [-1] + list(Ids_var_dist_attr.dims_mapping)
    )
    ctx.set_op_dist_attr_for_program(reshape_op, new_op_dist_attr)

    return intermediate_var_0


# RowParallel
class DistributedEmbeddingImpl(DistributedOperatorImpl):
    def __init__(self, name):
        super(DistributedEmbeddingImpl, self).__init__(name)
        self._forward_implemented = True
        self._backward_implemented = True

    def calc_cost(self, op_role, dist_op, ctx, cluster):
        """Calculate the cost by the op role."""
        cost = None
        if int(op_role) == int(OpRole.Forward):
            cost = self.calc_fwd_cost(dist_op, ctx, cluster)
        elif int(op_role) == int(OpRole.Backward):
            cost = self.calc_bwd_cost(dist_op, ctx, cluster)
        assert cost is not None
        return cost

    def calc_fwd_cost(self, dist_op, ctx, cluster):
        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        processes = dist_op.dist_attr.process_mesh.processes
        # embedding needs start_index
        cost_mapping = build_comp_costs_from_descs(
            EmbeddingOpCost, ctx, processes, desc_mapping, cluster
        )

        serial_op = dist_op.serial_op
        parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
            serial_op.input("W")[0]
        )[0]
        attrs = {"use_calc_stream": True, "use_model_parallel": True}
        var_names = serial_op.output("Out")
        c_allreduce_sum_desc_mapping = build_comm_desc_from_dist_op(
            "c_allreduce_sum",
            dist_op,
            ctx,
            var_names,
            attrs=attrs,
            parallel_axis=parallel_axis,
        )

        comm_op_cost_list = build_comm_costs_from_descs(
            AllreduceSumOpCost,
            ctx,
            processes,
            c_allreduce_sum_desc_mapping,
            cluster,
        )

        res_cost = [cost_mapping, comm_op_cost_list]

        return res_cost

    def calc_bwd_cost(self, dist_op, ctx, cluster):
        # for now, the backward function only inserts the gradient allreduce for the dist op itself
        res = []
        backward_op = dist_op.serial_op
        main_block = backward_op.block
        dist_attr = dist_op.dist_attr

        embedding_row_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("W")[0]
        )[0]
        parallel_axis = embedding_row_dim_mapping
        attrs = {"use_calc_stream": True, "use_model_parallel": True}
        var_names = [backward_op.input("Out@GRAD")[0]]
        c_identity_desc_mapping = build_comm_desc_from_dist_op(
            "c_identity",
            dist_op,
            ctx,
            var_names,
            attrs=attrs,
            parallel_axis=parallel_axis,
        )

        process_mesh = dist_attr.process_mesh
        processes = process_mesh.processes
        comm_op_cost_list = build_comm_costs_from_descs(
            IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster
        )
        res.append(comm_op_cost_list)

        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        cost_mapping = build_comp_costs_from_descs(
            EmbeddingGradOpCost, ctx, processes, desc_mapping, cluster
        )
        res.append(cost_mapping)

        # need gradient allreduce
        var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("Ids")[0]
        )
        mesh_shape = process_mesh.topology
        batch_size_axis = var_dim_mapping[0]
        if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
            parallel_axis = batch_size_axis
            attrs = {"use_calc_stream": True}
            var_names = [backward_op.output('W@GRAD')[0]]
            build_dp_costs(
                res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
            )

        return res

    def is_input_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        ids_name = op_desc.input('Ids')[0]
        w_name = op_desc.input('W')[0]
        ids_dims_mapping = op_dist_attr.get_input_dims_mapping(ids_name)
        w_dims_mapping = op_dist_attr.get_input_dims_mapping(w_name)
        if is_dim_replicate(w_dims_mapping[-2]) or is_dim_shard(
            w_dims_mapping[-1]
        ):
            return False
        # Other dimensions must be replicated except the batch dimension
        for mapping in ids_dims_mapping[1:]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_output_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
        # Other dimensions must be replicated except the batch dimension
        for mapping in out_dims_mapping[1:]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_auto_compatible(self, dist_op):
        if (not self.is_input_compatible(dist_op)) or (
            not self.is_output_compatible(dist_op)
        ):
            return False

        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        ids_name = op_desc.input('Ids')[0]
        w_name = op_desc.input('W')[0]
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
        ids_dims_mapping = op_dist_attr.get_input_dims_mapping(ids_name)
        w_dims_mapping = op_dist_attr.get_input_dims_mapping(w_name)

        if ids_dims_mapping != out_dims_mapping[: len(ids_dims_mapping)]:
            return False

        return True

    def update_dims_mapping(self, dist_op):
        changed = False
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        ids_name = op_desc.input('Ids')[0]
        w_name = op_desc.input('W')[0]
        out_name = op_desc.output('Out')[0]
        ids_dims_mapping = op_dist_attr.get_input_dims_mapping(ids_name)
        w_dims_mapping = op_dist_attr.get_input_dims_mapping(w_name)
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)

        for i in range(len(ids_dims_mapping)):
            dim_changed = compute_compatible_and_update_dim_mapping(
                [ids_dims_mapping, out_dims_mapping], [i, i]
            )
            if dim_changed:
                changed = True

        dim_changed = compute_compatible_and_update_dim_mapping(
            [w_dims_mapping, out_dims_mapping], [-1, -1]
        )
        if dim_changed:
            changed = True

        return changed

    @staticmethod
    def forward(ctx, *args, **kwargs):
        """
        kwargs: inputname_mapping & outputname_mapping
        """

        dist_op_context = ctx.dist_op_context
        main_block = dist_op_context.work_block
        startup_block = dist_op_context.startup_block
        src_op = dist_op_context.cur_src_op
        rank_id = dist_op_context.rank_id
        op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
        assert (
            op_dist_attr is not None
        ), "forward op [{}] doesn't have dist attribute!".format(str(src_op))

        # check the validity of inputs / outputs
        assert 'Ids' in kwargs, "input [{}] is not given".format('Ids')
        assert 'W' in kwargs, "input [{}] is not given".format('W')
        assert 'Out' in kwargs, "output [{}] is not given".format('Out')

        assert (
            len(kwargs['Ids']) == 1
        ), "row_parallel_embedding input Ids takes 1 variable but got {}".format(
            kwargs['Ids']
        )
        assert (
            len(kwargs['W']) == 1
        ), "row_parallel_embedding input W takes 1 variable but got {}".format(
            kwargs['W']
        )
        assert (
            len(kwargs['Out']) == 1
        ), "row_parallel_embedding output Out takes 1 variable but got {}".format(
            kwargs['Out']
        )

        Ids_var = main_block.var(kwargs['Ids'][0])
        Weight_var = main_block._var_recursive(kwargs['W'][0])
        Out_var = main_block.var(kwargs['Out'][0])

        # support lookup_table_v1
        if src_op.type == 'lookup_table':
            Ids_var = adopt_lookup_table_v1(ctx, main_block, src_op, Ids_var)

        # get dist attribute info
        embedding_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
            Weight_var.name
        )[0]
        assert (
            embedding_row_dim_mapping >= 0
        ), "row_parallel_embedding's row should be sharded along a specific mesh axis, but got [{}]".format(
            embedding_row_dim_mapping
        )
        process_mesh_shape = op_dist_attr.process_mesh.topology
        process_mesh_group = op_dist_attr.process_mesh.processes

        # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
        if rank_id not in process_mesh_group:
            rank_id = _get_corresponding_rank(
                ctx, op_dist_attr.process_mesh, rank_id
            )

        # A generalized method to calculate the embedding offset using the Cartesian product
        relative_idx = _get_idx_in_axis(
            process_mesh_group,
            process_mesh_shape,
            embedding_row_dim_mapping,
            rank_id,
        )

        per_part_size = Weight_var.shape[0]
        relative_idx = relative_idx * per_part_size
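        # Example (hypothetical numbers): with a vocabulary of 100000 rows split
        # across 2 ranks along the sharded mesh axis, each rank holds
        # per_part_size = 50000 rows and the rank at index 1 passes
        # start_index = 1 * 50000 = 50000 to c_embedding below.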

        # TODO: calculate ring id
        parallel_axis = embedding_row_dim_mapping
        group_ranks = _get_comm_group(
            process_mesh_group, process_mesh_shape, parallel_axis, rank_id
        )
        group = new_process_group(group_ranks)
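        # NOTE: group_ranks are the ranks that share this rank's coordinates on
        # every mesh axis except parallel_axis, i.e. the model-parallel group
        # over which the partial embedding results are all-reduced below.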

        # append op
        check_variable_and_dtype(
            Ids_var, 'input', ['int32', 'int64'], 'c_embedding'
        )

        # infer new var shape with op dist attr
        out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var)
        assert out_tensor_dist_attr is not None
        out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert out_var_dist_attr is not None
        ref_shape = infer_shape(
            main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr
        )

        intermediate_var_0 = main_block.create_var(
            name=unique_name.generate_with_ignorable_key(
                ".".join(["c_embedding", 'tmp'])
            ),
            dtype=Weight_var.dtype,
            shape=Out_var.shape,
            type=core.VarDesc.VarType.LOD_TENSOR,
            persistable=False,
            stop_gradient=Out_var.stop_gradient,
        )
        # set intermediate_var_0's dist_attr with Out_var's dist_attr
        ctx.set_tensor_dist_attr_for_program(
            intermediate_var_0, out_var_dist_attr
        )

        check_variable_and_dtype(
            Out_var,
            'tensor',
            ['float16', 'float32', 'float64', 'int32', 'int64'],
            'c_allreduce_sum',
        )

        c_embedding_op = main_block.append_op(
            type='c_embedding',
            inputs={'Ids': [Ids_var], 'W': [Weight_var]},
            outputs={'Out': [intermediate_var_0]},
            attrs={
                "start_index": relative_idx,
                OP_ROLE_KEY: src_op.attr('op_role'),
            },
        )
        if intermediate_var_0.shape != ref_shape:
            intermediate_var_0.desc.set_shape(ref_shape)

        # use_model_parallel
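        # NOTE: c_embedding produces a partial result on each rank (ids that fall
        # outside the local vocabulary slice contribute zero rows), so an
        # allreduce_sum over the model-parallel group reconstructs the full output.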
        c_allreduce_sum_op = main_block.append_op(
            type='c_allreduce_sum',
            inputs={'X': [intermediate_var_0]},
            outputs={'Out': [Out_var]},
            attrs={
                'ring_id': group.id,
                'use_calc_stream': True,
                'use_model_parallel': True,
                OP_ROLE_KEY: src_op.attr('op_role'),
            },
        )
        if Out_var.shape != ref_shape:
            Out_var.desc.set_shape(ref_shape)

        # set dist op's dist_attr with serial op's dist_attr
        # c_embedding
        embedding_op_dist_attr = OperatorDistributedAttribute()
        embedding_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        embedding_op_dist_attr.impl_type = op_dist_attr.impl_type
        embedding_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        for input_varname in c_embedding_op.desc.input_arg_names():
            input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
            assert input_dist_attr is not None, "dist_attr is {}".format(
                op_dist_attr
            )
            embedding_op_dist_attr.set_input_dist_attr(
                input_varname, input_dist_attr
            )
        output_varname = c_embedding_op.desc.output_arg_names()[0]
        output_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert output_dist_attr is not None, "dist_attr is {}".format(
            op_dist_attr
        )
        embedding_op_dist_attr.set_output_dist_attr(
            output_varname, output_dist_attr
        )
        ctx.set_op_dist_attr_for_program(c_embedding_op, embedding_op_dist_attr)

        # allreduce
        allreduce_op_dist_attr = OperatorDistributedAttribute()
        allreduce_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        allreduce_op_dist_attr.impl_type = op_dist_attr.impl_type
        allreduce_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        for input_varname in c_allreduce_sum_op.desc.input_arg_names():
            input_var = main_block.var(input_varname)
            tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(input_var)
            assert tensor_dist_attr is not None
            allreduce_op_dist_attr.set_input_dist_attr(
                input_varname, tensor_dist_attr
            )
        for output_varname in c_allreduce_sum_op.desc.output_arg_names():
            output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
            assert output_dist_attr is not None, "dist_attr is {}".format(
                op_dist_attr
            )
            allreduce_op_dist_attr.set_output_dist_attr(
                output_varname, output_dist_attr
            )
        ctx.set_op_dist_attr_for_program(
            c_allreduce_sum_op, allreduce_op_dist_attr
        )

        # param initialization sync
        if Weight_var.is_parameter and not op_dist_attr.is_recompute:
            if Weight_var.name in dist_op_context.already_init_sync_vars:
                return
            dist_op_context.already_init_sync_vars.add(Weight_var.name)
            param = startup_block.var(Weight_var.name)
            param_dist_attr = ctx.get_tensor_dist_attr_for_program(param)
            process_mesh = param_dist_attr.process_mesh
            dim_mapping = param_dist_attr.dims_mapping

            # NOTE: sync (broadcast) the parameter along every mesh axis of size > 1 on which it is not sharded
            for axis, size in enumerate(process_mesh.topology):
                if size <= 1 or axis in dim_mapping:
                    pass
                else:
                    group_ranks = _get_comm_group(
                        process_mesh.processes,
                        process_mesh.topology,
                        axis,
                        rank_id,
                    )
                    sync_group = new_process_group(group_ranks)

                    startup_block.append_op(
                        type='c_broadcast',
                        inputs={'X': param},
                        outputs={'Out': param},
                        attrs={
                            'ring_id': sync_group.id,
                            'root': 0,
                            'use_calc_stream': True,
                            OP_ROLE_KEY: OpRole.Forward,
                        },
                    )

    @staticmethod
    def backward(ctx, *args, **kwargs):

        # for now, the backward function only inserts the gradient allreduce for the dist op itself
        dist_op_context = ctx.dist_op_context
        main_block = dist_op_context.work_block
        backward_op = dist_op_context.cur_src_op
        rank_id = dist_op_context.rank_id
        dist_attr = ctx.get_op_dist_attr_for_program(backward_op)
        assert (
            dist_attr is not None
        ), "backward op [{}] doesn't have dist attribute!".format(
            str(backward_op)
        )

        # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
        if rank_id not in dist_attr.process_mesh.processes:
            rank_id = _get_corresponding_rank(
                ctx, dist_attr.process_mesh, rank_id
            )

        assert 'Ids' in kwargs, "input [{}] is not given".format('Ids')
        assert 'W' in kwargs, "input [{}] is not given".format('W')
        assert 'Out@GRAD' in kwargs, "input [{}] is not given".format('Out@GRAD')
        assert 'W@GRAD' in kwargs, "output [{}] is not given".format('W@GRAD')

        assert (
            len(kwargs['Ids']) == 1
        ), "row_parallel_embedding input Ids takes 1 variable but got {}".format(
            kwargs['Ids']
        )
        assert (
            len(kwargs['W']) == 1
        ), "row_parallel_embedding input W takes 1 variable but got {}".format(
            kwargs['W']
        )
        assert (
            len(kwargs['Out@GRAD']) == 1
        ), "row_parallel_embedding input Out@GRAD takes 1 variable but got {}".format(
            kwargs['Out@GRAD']
        )
        assert (
            len(kwargs['W@GRAD']) == 1
        ), "row_parallel_embedding output W@GRAD takes 1 variable but got {}".format(
            kwargs['W@GRAD']
        )

        Ids_var = main_block.var(kwargs['Ids'][0])
        Weight_var = main_block.var(kwargs['W'][0])
        Out_grad = main_block.var(kwargs['Out@GRAD'][0])
        Weight_grad = main_block.var(kwargs['W@GRAD'][0])

        embedding_row_dim_mapping = dist_attr.get_input_dims_mapping(
            Weight_var.name
        )[0]
        assert (
            embedding_row_dim_mapping >= 0
        ), "row_parallel_embedding's row should be sharded along a specific mesh axis, but got [{}]".format(
            embedding_row_dim_mapping
        )
        process_mesh_shape = dist_attr.process_mesh.topology
        process_mesh_group = dist_attr.process_mesh.processes

        # A generalized method to calculate the embedding offset using the Cartesian product
        relative_idx = _get_idx_in_axis(
            process_mesh_group,
            process_mesh_shape,
            embedding_row_dim_mapping,
            rank_id,
        )
        per_part_size = Weight_var.shape[0]
        relative_idx = relative_idx * per_part_size

        check_variable_and_dtype(
            Out_grad,
            'tensor',
            ['float16', 'float32', 'float64', 'int32', 'int64'],
            '_c_identity',
        )

        intermediate_var_0 = main_block.create_var(
            name=unique_name.generate_with_ignorable_key(
                ".".join(["c_embedding", '@tmp_0@GRAD'])
            ),
            dtype=Out_grad.dtype,
            shape=Out_grad.shape,
            type=core.VarDesc.VarType.LOD_TENSOR,
            persistable=False,
            stop_gradient=Out_grad.stop_gradient,
        )

        # copy Out_grad's dist_attr to intermediate_var_0's dist_attr
        out_grad_dist_attr = dist_attr.get_input_dist_attr(Out_grad.name)
        assert out_grad_dist_attr is not None
        ctx.set_tensor_dist_attr_for_program(
            intermediate_var_0, out_grad_dist_attr
        )

        group_ranks = _get_comm_group(
            process_mesh_group,
            process_mesh_shape,
            embedding_row_dim_mapping,
            rank_id,
        )
        group = new_process_group(group_ranks)

        c_identity_op = main_block.append_op(
            type='c_identity',
            inputs={'X': [Out_grad]},
            outputs={'Out': intermediate_var_0},
            attrs={
                'ring_id': group.id,
                'use_calc_stream': True,
                'use_model_parallel': True,
                OP_ROLE_KEY: OpRole.Backward,
            },
        )
        check_variable_and_dtype(
            intermediate_var_0, 'x', ['float16', 'float32', 'float64'], 'linear'
        )
        check_dtype(
            intermediate_var_0.dtype,
            'dtype',
            ['float16', 'float32', 'float64'],
            'linear',
        )

        set_comm_op_dist_attr_for_program(
            c_identity_op, dist_attr.process_mesh, out_grad_dist_attr, ctx
        )

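        # NOTE (assumption): there is no dedicated Python wrapper for
        # c_embedding_grad here, so a placeholder 'nop' op is appended and its
        # desc is rewritten in place below.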
        c_embedding_grad_op_desc = main_block.append_op(type='nop').desc
        c_embedding_grad_op_desc.set_type("c_embedding_grad")
        c_embedding_grad_op_desc.set_input('Ids', [Ids_var.name])
        c_embedding_grad_op_desc.set_input('W', [Weight_var.name])
        c_embedding_grad_op_desc.set_input(
            'Out@GRAD', [intermediate_var_0.name]
        )
        c_embedding_grad_op_desc.set_output('W@GRAD', [Weight_grad.name])
        c_embedding_grad_op_desc._set_attr('start_index', relative_idx)
        c_embedding_grad_op_desc._set_attr(OP_ROLE_KEY, OpRole.Backward)

        c_embedding_grad_op = main_block.ops[-1]
        assert c_embedding_grad_op.type == "c_embedding_grad"
        naive_copy_op_dist_attr_for_program(
            c_embedding_grad_op, backward_op, ctx
        )

        # data parallel gradient synchronization
        act_grad_names = [Ids_var.name]
        out_grad_names = [kwargs['W@GRAD'][0]]

        gradient_synchronization(
            ctx, backward_op, act_grad_names, out_grad_names, rank_id
        )


register_distributed_operator_impl(
    "lookup_table_v2", DistributedEmbeddingImpl("row_parallel")
)
register_distributed_operator_impl(
    "c_embedding", DistributedEmbeddingImpl("row_parallel")
)
register_distributed_operator_impl(
    "lookup_table", DistributedEmbeddingImpl("row_parallel")
)