# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License

import copy

from paddle.common_ops_import import check_dtype, check_variable_and_dtype
from paddle.distributed.auto_parallel.cost.comm_op_cost import (
    AllreduceSumOpCost,
    IdentityOpCost,
)
from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole
from paddle.framework import core
from paddle.utils import unique_name

from ..cost import (
    MatmulGradOpCost,
    MatmulOpCost,
    MatmulV2GradOpCost,
    MatmulV2OpCost,
    MulGradOpCost,
    MulOpCost,
    build_comm_costs_from_descs,
    build_comm_desc_from_dist_op,
    build_comp_costs_from_descs,
    build_comp_desc_from_dist_op,
    build_dp_costs,
)
from ..dist_attribute import OperatorDistAttr
from ..process_group import new_process_group
from ..utils import (
    _get_comm_group,
    _get_corresponding_rank,
    compute_compatible_and_update_dim_mapping,
    compute_compatible_dims_mapping,
    is_dim_replicate,
    is_dim_shard,
    is_valid_list_index,
    set_dist_op_desc_original_id,
)
from .common import (
    DistributedOperatorImpl,
    DistributedOperatorImplContainer,
    gradient_synchronization,
    infer_shape,
    is_parameter_related,
    register_distributed_operator_impl,
    register_distributed_operator_impl_container,
    set_comm_op_dist_attr_for_program,
)
from .dist_default import DistributedDefaultImpl0


def trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping):
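    """Swap the last two entries of a dims mapping in place when the
    corresponding transpose flag is set, so later shard inference can reason
    about the non-transposed layout.

    For example (with hypothetical mappings), trans_y=True turns a
    y_dims_mapping of [-1, 0] into [0, -1].
    """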
    if trans_x:
        x_dims_mapping[-1], x_dims_mapping[-2] = (
            x_dims_mapping[-2],
            x_dims_mapping[-1],
        )
    if trans_y:
        y_dims_mapping[-1], y_dims_mapping[-2] = (
            y_dims_mapping[-2],
            y_dims_mapping[-1],
        )


def copy_op_with_new_input_output(ctx, block, src_op, **kwargs):
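    """Append a copy of src_op to block, rename its inputs/outputs to the
    names given in kwargs, attach the correspondingly renamed dist_attr to
    the new op, and return the new op desc.
    """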

    src_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
    dist_attr = copy.deepcopy(src_dist_attr)
    dist_op = block.append_op(type='nop')
    dist_op_desc = dist_op.desc
    dist_op_desc.copy_from(src_op.desc)
    set_dist_op_desc_original_id(dist_op_desc, src_op.desc, ctx)
    for input_name in src_op.desc.input_names():
        assert input_name in kwargs
        dist_op_desc.set_input(input_name, kwargs[input_name])
        dist_attr.rename_input(
            src_op.desc.input(input_name)[0], kwargs[input_name][0]
        )
    for output_name in src_op.desc.output_names():
        assert output_name in kwargs
        dist_op_desc.set_output(output_name, kwargs[output_name])
        dist_attr.rename_output(
            src_op.desc.output(output_name)[0], kwargs[output_name][0]
        )
    # TODO: this call leads to a deepcopy when we init the dist op
    ctx.set_op_dist_attr_for_program(dist_op, dist_attr)

    return dist_op_desc


def _update_dims_mapping_for_matmul(dist_op):
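    """Make the dims mappings of X, Y and Out compatible for a (possibly
    batched and transposed) matmul: temporarily pad 1-D operands to 2-D,
    align the broadcastable batch dims, then align the contraction dim
    (X[-1] with Y[-2]) and the output dims. Returns True if any mapping
    was changed.
    """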
    changed = False
    op_desc = dist_op.serial_op.desc
    op_dist_attr = dist_op.dist_attr
    x_name = op_desc.input('X')[0]
    y_name = op_desc.input('Y')[0]
    out_name = op_desc.output('Out')[0]
    trans_x = None
    trans_y = None
    if op_desc.type() == "matmul_v2":
        trans_x = op_desc.attr('trans_x')
        trans_y = op_desc.attr('trans_y')
    elif op_desc.type() == "matmul":
        trans_x = op_desc.attr('transpose_X')
        trans_y = op_desc.attr('transpose_Y')
    x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
    y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
    out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
    x_dims_mapping_len = len(x_dims_mapping)
    y_dims_mapping_len = len(y_dims_mapping)
    out_dims_mapping_len = len(out_dims_mapping)

    # Add a dim mapping to make sure the length of dims_mapping is at least 2
    if x_dims_mapping_len == 1:
        assert trans_x is False
        x_dims_mapping.insert(0, -1)
        out_dims_mapping.insert(out_dims_mapping_len - 1, 0)
    if y_dims_mapping_len == 1:
        assert trans_y is False
        y_dims_mapping.insert(1, -1)
        out_dims_mapping.insert(out_dims_mapping_len, 0)

    trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping)

    new_x_dims_mapping_len = len(x_dims_mapping)
    new_y_dims_mapping_len = len(y_dims_mapping)
    new_out_dims_mapping_len = len(out_dims_mapping)
    # Deal with dim > 2 and take care of broadcasting
    if new_out_dims_mapping_len > 2:
        broadcast_x_dims_mapping = []
        broadcast_y_dims_mapping = []
        broadcast_out_dims_mapping = []

        for i in range(new_out_dims_mapping_len - new_x_dims_mapping_len):
            broadcast_x_dims_mapping.append(out_dims_mapping[i])
        for i in range(new_x_dims_mapping_len - 2):
            broadcast_x_dims_mapping.append(x_dims_mapping[i])

        for i in range(new_out_dims_mapping_len - new_y_dims_mapping_len):
            broadcast_y_dims_mapping.append(out_dims_mapping[i])
        for i in range(new_y_dims_mapping_len - 2):
            broadcast_y_dims_mapping.append(y_dims_mapping[i])

        for i in range(new_out_dims_mapping_len - 2):
            broadcast_out_dims_mapping.append(out_dims_mapping[i])

        compatible_dims_mapping = compute_compatible_dims_mapping(
            [
                broadcast_x_dims_mapping,
                broadcast_y_dims_mapping,
                broadcast_out_dims_mapping,
            ]
        )
        if compatible_dims_mapping is None:
            trans_x_y_dims_mapping(
                trans_x, trans_y, x_dims_mapping, y_dims_mapping
            )
            return False

        for i in range(new_x_dims_mapping_len - 2):
            new_idx = i + (out_dims_mapping_len - new_x_dims_mapping_len)
            if x_dims_mapping[i] != compatible_dims_mapping[new_idx]:
                x_dims_mapping[i] = compatible_dims_mapping[new_idx]
                changed = True

        for i in range(new_y_dims_mapping_len - 2):
            new_idx = i + (out_dims_mapping_len - new_y_dims_mapping_len)
            if y_dims_mapping[i] != compatible_dims_mapping[new_idx]:
                y_dims_mapping[i] = compatible_dims_mapping[new_idx]
                changed = True

        for i in range(new_out_dims_mapping_len - 2):
            if out_dims_mapping[i] != compatible_dims_mapping[i]:
                out_dims_mapping[i] = compatible_dims_mapping[i]
                changed = True

    # The following, which uses negative indices, works both when
    # len(out_dims_mapping) > 2 and when len(out_dims_mapping) <= 2
    dim_changed = compute_compatible_and_update_dim_mapping(
        [x_dims_mapping, y_dims_mapping], [-1, -2]
    )
    if dim_changed:
        changed = True

    dim_changed = compute_compatible_and_update_dim_mapping(
        [x_dims_mapping, out_dims_mapping], [-2, -2]
    )
    if dim_changed:
        changed = True

    dim_changed = compute_compatible_and_update_dim_mapping(
        [y_dims_mapping, out_dims_mapping], [-1, -1]
    )
    if dim_changed:
        changed = True

    trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping)

    # Remove the inserted dim mappings so that the length of each dims_mapping matches its tensor's rank
    if x_dims_mapping_len == 1:
        x_dims_mapping.pop(0)
        out_dims_mapping.pop(out_dims_mapping_len - 1)
    if y_dims_mapping_len == 1:
        y_dims_mapping.pop(1)
        out_dims_mapping.pop(out_dims_mapping_len)

    assert len(x_dims_mapping) == x_dims_mapping_len
    assert len(y_dims_mapping) == y_dims_mapping_len
    assert len(out_dims_mapping) == out_dims_mapping_len

    if changed:
        op_dist_attr.set_input_dims_mapping(x_name, x_dims_mapping)
        op_dist_attr.set_input_dims_mapping(y_name, y_dims_mapping)
        op_dist_attr.set_output_dims_mapping(out_name, out_dims_mapping)

    return changed


def _is_auto_compatible_for_matmul(dist_op):
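    """Check whether the current dims mappings of X, Y and Out are already
    consistent for matmul (identical broadcast/batch mappings and matching
    contraction and output dims) without modifying them.
    """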
    op_desc = dist_op.serial_op.desc
    op_dist_attr = dist_op.dist_attr
    x_name = op_desc.input('X')[0]
    y_name = op_desc.input('Y')[0]
    out_name = op_desc.output('Out')[0]
    trans_x = None
    trans_y = None
    if op_desc.type() == "matmul_v2":
        trans_x = op_desc.attr('trans_x')
        trans_y = op_desc.attr('trans_y')
    elif op_desc.type() == "matmul":
        trans_x = op_desc.attr('transpose_X')
        trans_y = op_desc.attr('transpose_Y')

    # Deep copy these dims_mappings to keep the originals unchanged.
    x_dims_mapping = copy.deepcopy(op_dist_attr.get_input_dims_mapping(x_name))
    y_dims_mapping = copy.deepcopy(op_dist_attr.get_input_dims_mapping(y_name))
    out_dims_mapping = copy.deepcopy(
        op_dist_attr.get_output_dims_mapping(out_name)
    )
    x_dims_mapping_len = len(x_dims_mapping)
    y_dims_mapping_len = len(y_dims_mapping)
    out_dims_mapping_len = len(out_dims_mapping)

    # Add a dim mapping to make sure the length of dims_mapping is at least 2
    if x_dims_mapping_len == 1:
        x_dims_mapping.insert(0, -1)
    if y_dims_mapping_len == 1:
        y_dims_mapping.insert(1, -1)

    trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping)

    # Deal with dim > 2 and take care of broadcasting
    if out_dims_mapping_len > 2:
        broadcast_x_dims_mapping = []
        broadcast_y_dims_mapping = []
        broadcast_out_dims_mapping = []

        for i in range(out_dims_mapping_len - x_dims_mapping_len):
            broadcast_x_dims_mapping.append(out_dims_mapping[i])
        for i in range(x_dims_mapping_len - 2):
            broadcast_x_dims_mapping.append(x_dims_mapping[i])

        for i in range(out_dims_mapping_len - y_dims_mapping_len):
            broadcast_y_dims_mapping.append(out_dims_mapping[i])
        for i in range(y_dims_mapping_len - 2):
            broadcast_y_dims_mapping.append(y_dims_mapping[i])

        for i in range(out_dims_mapping_len - 2):
            broadcast_out_dims_mapping.append(out_dims_mapping[i])

        is_same = (broadcast_x_dims_mapping == broadcast_y_dims_mapping) and (
            broadcast_x_dims_mapping == broadcast_out_dims_mapping
        )
        if not is_same:
            return False

    # The following, which uses negative indices, works both when
    # len(out_dims_mapping) > 2 and when len(out_dims_mapping) <= 2
    is_same = x_dims_mapping[-1] == y_dims_mapping[-2]
    if not is_same:
        return False

    is_same = x_dims_mapping[-2] == out_dims_mapping[-2]
    if not is_same:
        return False

    is_same = y_dims_mapping[-1] == out_dims_mapping[-1]
    if not is_same:
        return False

    return True


def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
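    """Shared backward for the dist matmul implementations. When the right
    operand Y is a sharded parameter, insert the required collectives
    (c_identity on Out@GRAD for the row-parallel case, c_allreduce_sum on
    X@GRAD for the column-parallel case); otherwise replicate the serial
    grad op. Finally apply data-parallel gradient synchronization.
    """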

    # for now, the backward function only inserts the gradient allreduce for the dist op itself

    dist_op_context = ctx.dist_op_context
    main_block = dist_op_context.work_block
    backward_op = dist_op_context.cur_src_op
    rank_id = dist_op_context.rank_id
    dist_attr = ctx.get_op_dist_attr_for_program(backward_op)
    assert (
        dist_attr is not None
    ), "backward op [{}] doesn't have a dist attribute!".format(str(backward_op))

    # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
    if rank_id not in dist_attr.process_mesh.process_ids:
        rank_id = _get_corresponding_rank(ctx, dist_attr.process_mesh, rank_id)

    assert 'Y' in kwargs, "input [{}] is not given".format('Y')
    assert 'X' in kwargs, "input [{}] is not given".format('X')
    assert 'Out@GRAD' in kwargs, "input [{}] is not given".format('Out@GRAD')
    assert 'Y@GRAD' in kwargs, "output [{}] is not given".format('Y@GRAD')
    assert 'X@GRAD' in kwargs, "output [{}] is not given".format('X@GRAD')
    assert (
        len(kwargs['Y']) == 1
    ), "dist matmul input Y takes 1 variable but got {}".format(kwargs['Y'])
    assert (
        len(kwargs['X']) == 1
    ), "dist matmul input X takes 1 variable but got {}".format(kwargs['X'])
    assert (
        len(kwargs['Out@GRAD']) == 1
    ), "dist matmul input Out@GRAD takes 1 variable but got {}".format(
        kwargs['Out@GRAD']
    )
    assert (
        len(kwargs['Y@GRAD']) == 1
    ), "dist matmul output Y@GRAD takes 1 variable but got {}".format(
        kwargs['Y@GRAD']
    )

    X_var = main_block._var_recursive(kwargs['X'][0])
    Y_var = main_block._var_recursive(kwargs['Y'][0])
    Out_grad = main_block._var_recursive(kwargs['Out@GRAD'][0])
    Y_grad = main_block._var_recursive(kwargs['Y@GRAD'][0])

    assert not is_parameter_related(
        X_var.name, main_block
    ), "left operand(X) [{}] of dist matmul should not be a parameter".format(
        X_var.name
    )

    X_var_dims_mapping = dist_attr.get_input_dims_mapping(X_var.name)
    Y_var_dim_mapping = dist_attr.get_input_dims_mapping(Y_var.name)
    process_mesh_shape = dist_attr.process_mesh.shape
    process_mesh_group = dist_attr.process_mesh.process_ids

    trans_x = None
    trans_y = None
    if backward_op.desc.type() == "matmul_v2_grad":
        trans_x = backward_op.desc.attr('trans_x')
        trans_y = backward_op.desc.attr('trans_y')
    elif backward_op.desc.type() == "matmul_grad":
        trans_x = backward_op.desc.attr('transpose_X')
        trans_y = backward_op.desc.attr('transpose_Y')

    if trans_y:
        trans_x_y_dims_mapping(False, True, None, Y_var_dim_mapping)

    # assert len(
    #     Y_var_dim_mapping
    # ) == 2, "dist matmul only supports Y operand with 2 dims now but Y({})'s dim is [{}]".format(
    #     Y_var.name, Y_var_dim_mapping)
    Y_var_partitioned = False
    for dim in Y_var_dim_mapping:
        if dim >= 0 and process_mesh_shape[dim] > 0:
            Y_var_partitioned = True
            break

    if is_parameter_related(Y_var.name, main_block) and Y_var_partitioned:

        if Y_var_dim_mapping[0] >= 0:
            # row parallel: c_identity + matmul
            assert Y_var_dim_mapping[1] < 0
            parallel_axis = Y_var_dim_mapping[0]

            check_variable_and_dtype(
                Out_grad,
                'tensor',
                ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
                '_c_identity',
            )

            intermediate_var_0 = main_block.create_var(
                name=unique_name.generate_with_ignorable_key(
                    ".".join(["c_identity", 'tmp'])
                )
                + "@GRAD",
                dtype=Out_grad.dtype,
                shape=Out_grad.shape,
                type=core.VarDesc.VarType.LOD_TENSOR,
                persistable=False,
                stop_gradient=Out_grad.stop_gradient,
            )

            # copy X_var's dist_attr to intermediate_var_0's dist_attr
            out_grad_dist_attr = dist_attr.get_input_dist_attr(Out_grad.name)
            assert out_grad_dist_attr is not None
            ctx.set_tensor_dist_attr_for_program(
                intermediate_var_0, out_grad_dist_attr
            )

            group_ranks = _get_comm_group(
                process_mesh_group, process_mesh_shape, parallel_axis, rank_id
            )
            group = new_process_group(group_ranks)
            c_identity_op = main_block.append_op(
                type='c_identity',
                inputs={'X': [Out_grad]},
                outputs={'Out': intermediate_var_0},
                attrs={
                    'ring_id': group.id,
                    'use_calc_stream': True,
                    'use_model_parallel': True,
                    OP_ROLE_KEY: OpRole.Backward,
                },
            )
            check_variable_and_dtype(
                intermediate_var_0,
                'x',
                ['float16', 'float32', 'float64', 'uint16'],
                'linear',
            )
            check_dtype(
                intermediate_var_0.dtype,
                'dtype',
                ['float16', 'float32', 'float64', 'uint16'],
                'linear',
            )
            set_comm_op_dist_attr_for_program(
                c_identity_op, dist_attr.process_mesh, out_grad_dist_attr, ctx
            )

            new_kwargs = copy.deepcopy(kwargs)
            new_kwargs['Out@GRAD'] = [intermediate_var_0.name]
            matmul_op_desc = copy_op_with_new_input_output(
                ctx, main_block, backward_op, **new_kwargs
            )
        else:
            # col parallel: matmul + allreduce
            assert Y_var_dim_mapping[0] < 0
            parallel_axis = Y_var_dim_mapping[1]
            new_kwargs = copy.deepcopy(kwargs)

            # NOTE (JZ-LIANG) should allow left operand be empty for matmul grad
            has_x_grad = len(kwargs['X@GRAD']) > 0
            if has_x_grad:
                assert len(kwargs['X@GRAD']) == 1
                X_grad = main_block._var_recursive(kwargs['X@GRAD'][0])
                intermediate_var_0 = main_block.create_var(
                    name=unique_name.generate_with_ignorable_key(
                        ".".join(["c_identity", 'tmp'])
                    )
                    + "@GRAD",
                    dtype=X_grad.dtype,
                    shape=X_grad.shape,
                    type=core.VarDesc.VarType.LOD_TENSOR,
                    persistable=False,
                    stop_gradient=X_grad.stop_gradient,
                )

                X_grad_dist_attr = dist_attr.get_output_dist_attr(X_grad.name)
                assert X_grad_dist_attr is not None
                ctx.set_tensor_dist_attr_for_program(
                    intermediate_var_0, X_grad_dist_attr
                )
                new_kwargs['X@GRAD'] = [intermediate_var_0.name]

            matmul_op_desc = copy_op_with_new_input_output(
                ctx, main_block, backward_op, **new_kwargs
            )

            # NOTE (JZ-LIANG) trick to skip one allreduce if left operand has not grad
            if has_x_grad:
                group_ranks = _get_comm_group(
                    process_mesh_group,
                    process_mesh_shape,
                    parallel_axis,
                    rank_id,
                )
                group = new_process_group(group_ranks)
                c_allreduce_sum_op = main_block.append_op(
                    type='c_allreduce_sum',
                    inputs={'X': [intermediate_var_0.name]},
                    outputs={'Out': kwargs['X@GRAD']},
                    attrs={
                        'ring_id': group.id,
                        'use_calc_stream': True,
                        'use_model_parallel': True,
                        OP_ROLE_KEY: OpRole.Backward,
                    },
                )
                set_comm_op_dist_attr_for_program(
                    c_allreduce_sum_op,
                    dist_attr.process_mesh,
                    X_grad_dist_attr,
                    ctx,
                )
    else:
        # replicate
        matmul_op_desc = copy_op_with_new_input_output(
            ctx, main_block, backward_op, **kwargs
        )

    # data parallel gradient synchronization
    act_grad_names = [X_var.name]

    out_grad_names = []
    if is_parameter_related(Y_var.name, main_block):
        out_grad_names = [kwargs['Y@GRAD'][0]]

    if trans_x:
        trans_x_y_dims_mapping(True, False, X_var_dims_mapping, None)

    gradient_synchronization(
        ctx, backward_op, act_grad_names, out_grad_names, rank_id
    )

    if trans_x:
        trans_x_y_dims_mapping(True, False, X_var_dims_mapping, None)
    if trans_y:
        trans_x_y_dims_mapping(False, True, None, Y_var_dim_mapping)


def _init_param_sync(Weight_var, dist_op_context, startup_block, ctx, rank_id):
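    """Broadcast the parameter from the root of each communication group
    along every mesh axis it is not sharded on, so replicated copies start
    out identical; each parameter is synchronized at most once.
    """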

    if Weight_var.name in dist_op_context.already_init_sync_vars:
        return
    assert startup_block.has_var(Weight_var.name)
    dist_op_context.already_init_sync_vars.add(Weight_var.name)
    param = startup_block.var(Weight_var.name)
    param_dist_attr = ctx.get_tensor_dist_attr_for_program(param)
    process_mesh = param_dist_attr.process_mesh
    dim_mapping = param_dist_attr.dims_mapping

    for axis, size in enumerate(process_mesh.shape):
        if size <= 1 or axis in dim_mapping:
            pass
        else:
            group_ranks = _get_comm_group(
                process_mesh.process_ids, process_mesh.shape, axis, rank_id
            )
            sync_group = new_process_group(group_ranks)

            startup_block.append_op(
                type='c_broadcast',
                inputs={'X': param},
                outputs={'Out': param},
                attrs={
                    'ring_id': sync_group.id,
                    'root': 0,
                    'use_calc_stream': True,
                    OP_ROLE_KEY: OpRole.Forward,
                },
            )


class DistributedMatmul(DistributedOperatorImplContainer):
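    """Container holding the distributed implementations registered for
    the matmul op."""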
    def __init__(self, op_type):
        super().__init__(op_type)


register_distributed_operator_impl_container(DistributedMatmul("matmul"))


# ColumnParallel
class DistributedMatmulImpl0(DistributedOperatorImpl):
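    """Column-parallel matmul.

    Forward inserts a c_identity on X and a matmul against the
    column-sharded weight; backward allreduces X@GRAD. A rough sketch of
    the per-rank shapes, assuming Y's last dim is sharded over a 2-way
    mesh axis:

        X   [B, K]  replicated   --c_identity-->  X [B, K]
        Y   [K, N]  sharded to [K, N/2]
        Out = matmul(X, Y_shard)  ->  [B, N/2]  (column-sharded)
    """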
    def __init__(self, name):
        super().__init__(name)
        self._forward_implemented = True
        self._backward_implemented = True

    def calc_cost(self, op_role, dist_op, ctx, cluster):
        cost = None
        if int(op_role) == int(OpRole.Forward):
            cost = self.calc_fwd_cost(dist_op, ctx, cluster)
        elif int(op_role) == int(OpRole.Backward):
            cost = self.calc_bwd_cost(dist_op, ctx, cluster)
        assert cost is not None
        return cost

    def calc_bwd_cost(self, dist_op, ctx, cluster):
        # for now, the backward function only inserts the gradient allreduce for the dist op itself
        res = []
        backward_op = dist_op.serial_op
        dist_attr = dist_op.dist_attr
        main_block = backward_op.block
        Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("Y")[0]
        )
        # col parallel: matmul + allreduce
        assert Y_var_dim_mapping[0] < 0
        parallel_axis = Y_var_dim_mapping[1]

        has_x_grad = len(backward_op.output("X@GRAD")) > 0
        if has_x_grad:
            assert len(backward_op.output("X@GRAD")) == 1

        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        process_mesh = dist_attr.process_mesh
        processes = process_mesh.process_ids
        cost_mapping = build_comp_costs_from_descs(
            MatmulGradOpCost, ctx, processes, desc_mapping, cluster
        )
        res.append(cost_mapping)

        # calc comm op cost
        if has_x_grad:
            attrs = {"use_calc_stream": True, "use_model_parallel": True}
            var_names = backward_op.output("X@GRAD")
            c_allreduce_sum_desc_mapping = build_comm_desc_from_dist_op(
                "c_allreduce_sum",
                dist_op,
                ctx,
                var_names,
                attrs=attrs,
                parallel_axis=parallel_axis,
            )
            comm_op_cost_list = build_comm_costs_from_descs(
                AllreduceSumOpCost,
                ctx,
                processes,
                c_allreduce_sum_desc_mapping,
                cluster,
            )
            res.append(comm_op_cost_list)

        # need gradient allreduce
        var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("X")[0]
        )
        mesh_shape = process_mesh.shape
        batch_size_axis = var_dim_mapping[0]
        if (
            batch_size_axis > -1
            and mesh_shape[batch_size_axis] > 1
            and is_parameter_related(backward_op.input("Y")[0], main_block)
        ):
            parallel_axis = batch_size_axis
            attrs = {"use_calc_stream": True}
            var_names = [backward_op.output('Y@GRAD')[0]]
            build_dp_costs(
                res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
            )
        return res

    def calc_fwd_cost(self, dist_op, ctx, cluster):
        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        processes = dist_op.dist_attr.process_mesh.process_ids
        cost_mapping = build_comp_costs_from_descs(
            MatmulOpCost, ctx, processes, desc_mapping, cluster
        )

        # calc comm op cost
        serial_op = dist_op.serial_op
        parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
            serial_op.input("Y")[0]
        )[-1]
        attrs = {"use_calc_stream": True, "use_model_parallel": True}
        var_names = serial_op.input("X")
        c_identity_desc_mapping = build_comm_desc_from_dist_op(
            "c_identity",
            dist_op,
            ctx,
            var_names,
            attrs=attrs,
            parallel_axis=parallel_axis,
        )

        comm_op_cost_list = build_comm_costs_from_descs(
            IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster
        )
        res_cost = [comm_op_cost_list, cost_mapping]

        return res_cost

    def is_input_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        x_name = op_desc.input('X')[0]
        y_name = op_desc.input('Y')[0]
        x_dims_mapping = copy.deepcopy(
            op_dist_attr.get_input_dims_mapping(x_name)
        )
        y_dims_mapping = copy.deepcopy(
            op_dist_attr.get_input_dims_mapping(y_name)
        )
        trans_x = op_desc.attr('transpose_X')
        trans_y = op_desc.attr('transpose_Y')
        trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping)
        if is_dim_shard(x_dims_mapping[-1]):
            return False
        if is_dim_shard(y_dims_mapping[-2]) or is_dim_replicate(
            y_dims_mapping[-1]
        ):
            return False
        for mapping in x_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_output_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
        if is_dim_replicate(out_dims_mapping[-1]):
            return False
        for mapping in out_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_auto_compatible(self, dist_op):
        if (not self.is_input_compatible(dist_op)) or (
            not self.is_output_compatible(dist_op)
        ):
            return False
        if not _is_auto_compatible_for_matmul(dist_op):
            return False
        return True

    def update_dims_mapping(self, dist_op):
        changed = False
        dim_changed = _update_dims_mapping_for_matmul(dist_op)
        if dim_changed:
            changed = True
        return changed

    @staticmethod
    def forward(ctx, *args, **kwargs):
        """
        kwargs: inputname_mapping & outputname_mapping
        """

        dist_op_context = ctx.dist_op_context
        main_block = dist_op_context.work_block
        startup_block = dist_op_context.startup_block
        src_op = dist_op_context.cur_src_op
        rank_id = dist_op_context.rank_id
        op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
        assert (
            op_dist_attr is not None
        ), "op [{}] doesn't have a dist attribute!".format(str(src_op))

        # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
        if rank_id not in op_dist_attr.process_mesh.process_ids:
            rank_id = _get_corresponding_rank(
                ctx, op_dist_attr.process_mesh, rank_id
            )

        # check the validity of inputs / outputs
        for input_name in src_op.desc.input_names():
            assert input_name in kwargs, "input [{}] is not given".format(
                input_name
            )
            assert len(kwargs[input_name]) == len(
                src_op.desc.input(input_name)
            ), "number of tensors for input [{}] does not match".format(
                input_name
            )
        for output_name in src_op.desc.output_names():
            assert output_name in kwargs, "output [{}] is not given".format(
                output_name
            )
            assert len(kwargs[output_name]) == len(
                src_op.desc.output(output_name)
            ), "number of tensors for output [{}] does not match".format(
                output_name
            )

        X_var = main_block._var_recursive(kwargs['X'][0])
        Weight_var = main_block._var_recursive(kwargs['Y'][0])
        Out_var = main_block._var_recursive(kwargs['Out'][0])
        trans_x = src_op.attr("transpose_X")
        trans_y = src_op.attr("transpose_Y")

        # TODO: infer the logical communication representation
        matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping(
            Weight_var.name
        )[-1]
        if trans_y:
            matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping(
                Weight_var.name
            )[-2]
        assert (
            matmul_col_dim_mapping >= 0
        ), "col_parallel_matmul's col should be divided by a specific mesh axis, but got [{}]".format(
            matmul_col_dim_mapping
        )
        process_mesh_shape = op_dist_attr.process_mesh.shape
        process_mesh_group = op_dist_attr.process_mesh.process_ids

        parallel_axis = matmul_col_dim_mapping
        group_ranks = _get_comm_group(
            process_mesh_group, process_mesh_shape, parallel_axis, rank_id
        )
        group = new_process_group(group_ranks)

        # infer new var shape with op dist attr
        x_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(X_var)
        assert x_tensor_dist_attr is not None
        identity_var_dist_attr = op_dist_attr.get_input_dist_attr(X_var.name)
        assert identity_var_dist_attr is not None
        ref_shape_x = infer_shape(
            main_block, X_var, x_tensor_dist_attr, identity_var_dist_attr
        )
        # infer out var shape with op dist attr
        out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var)
        assert out_tensor_dist_attr is not None
        out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert out_var_dist_attr is not None
        ref_shape_out = infer_shape(
            main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr
        )

        intermediate_var_0 = main_block.create_var(
            name=unique_name.generate_with_ignorable_key(
                ".".join(["c_identity", 'tmp'])
            ),
            dtype=X_var.dtype,
            shape=X_var.shape,
            type=core.VarDesc.VarType.LOD_TENSOR,
            persistable=False,
            stop_gradient=X_var.stop_gradient,
        )
        # set intermediate_var_0's dist_attr with X_var's dist_attr
        ctx.set_tensor_dist_attr_for_program(
            intermediate_var_0, identity_var_dist_attr
        )

        check_variable_and_dtype(
            X_var,
            'tensor',
            ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
            '_c_identity',
        )

        c_identity_op = main_block.append_op(
            type='c_identity',
            inputs={'X': [X_var]},
            outputs={'Out': intermediate_var_0},
            attrs={
                'ring_id': group.id,
                'use_calc_stream': True,
                'use_model_parallel': True,
                OP_ROLE_KEY: src_op.attr('op_role'),
            },
        )
        if intermediate_var_0.shape != ref_shape_x:
            intermediate_var_0.desc.set_shape(ref_shape_x)

        check_variable_and_dtype(
            intermediate_var_0,
            'x',
            ['float16', 'float32', 'float64', 'uint16'],
            'linear',
        )
        check_dtype(
            intermediate_var_0.dtype,
            'dtype',
            ['float16', 'float32', 'float64', 'uint16'],
            'linear',
        )
        attrs = {
            'transpose_X': trans_x,
            'transpose_Y': trans_y,
            'alpha': 1,
            OP_ROLE_KEY: src_op.attr('op_role'),
        }
        inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]}
        matmul_op = main_block.append_op(
            type='matmul', inputs=inputs, outputs={'Out': Out_var}, attrs=attrs
        )
        if Out_var.shape != ref_shape_out:
            Out_var.desc.set_shape(ref_shape_out)

        # set dist op's dist_attr with serial op's dist_attr
        # c_identity
        identity_op_dist_attr = OperatorDistAttr()
        identity_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        identity_op_dist_attr.impl_type = op_dist_attr.impl_type
        identity_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        # input
        input_varname = c_identity_op.desc.input_arg_names()[0]
        input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
        assert input_dist_attr is not None, "dist_attr is {}".format(
            op_dist_attr
        )
        identity_op_dist_attr.set_input_dist_attr(
            input_varname, input_dist_attr
        )
        # output
        output_varname = c_identity_op.desc.output_arg_names()[0]
        identity_op_dist_attr.set_output_dist_attr(
            output_varname, input_dist_attr
        )
        # set op dist attr
        ctx.set_op_dist_attr_for_program(c_identity_op, identity_op_dist_attr)

        # matmul
        matmul_op_dist_attr = OperatorDistAttr()
        matmul_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        matmul_op_dist_attr.impl_type = op_dist_attr.impl_type
        matmul_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        # input
        for input_varname in matmul_op.desc.input_arg_names():
            if input_varname in src_op.desc.input_arg_names():
                input_dist_attr = op_dist_attr.get_input_dist_attr(
                    input_varname
                )
                assert input_dist_attr is not None, "dist_attr is {}".format(
                    op_dist_attr
                )
                matmul_op_dist_attr.set_input_dist_attr(
                    input_varname, input_dist_attr
                )
            else:
                input_var = main_block._var_recursive(input_varname)
                tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(
                    input_var
                )
                matmul_op_dist_attr.set_input_dist_attr(
                    input_varname, tensor_dist_attr
                )
        # output
        output_varname = matmul_op.desc.output_arg_names()[0]
        output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
        assert output_dist_attr is not None, "dist_attr is {}".format(
            op_dist_attr
        )
        matmul_op_dist_attr.set_output_dist_attr(
            output_varname, output_dist_attr
        )
        # set op dist attr
        ctx.set_op_dist_attr_for_program(matmul_op, matmul_op_dist_attr)

        # init param sync
        if Weight_var.is_parameter and not op_dist_attr.is_recompute:
            _init_param_sync(
                Weight_var, dist_op_context, startup_block, ctx, rank_id
            )

    @staticmethod
    def backward(ctx, *args, **kwargs):
        _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)


# RowParallel
class DistributedMatmulImpl1(DistributedOperatorImpl):
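    """Row-parallel matmul.

    X is sharded on its last dim to match the row-sharded weight; each rank
    computes a partial Out that is combined with c_allreduce_sum. Backward
    uses a c_identity on Out@GRAD. A rough sketch of the per-rank shapes,
    assuming Y's first dim is sharded over a 2-way mesh axis:

        X   [B, K]  sharded to [B, K/2]
        Y   [K, N]  sharded to [K/2, N]
        Out = allreduce_sum(matmul(X_shard, Y_shard))  ->  [B, N]
    """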
    def __init__(self, name):
        super().__init__(name)
        self._forward_implemented = True
        self._backward_implemented = True

    def calc_cost(self, op_role, dist_op, ctx, cluster):
        cost = None
        if int(op_role) == int(OpRole.Forward):
            cost = self.calc_fwd_cost(dist_op, ctx, cluster)
        elif int(op_role) == int(OpRole.Backward):
            cost = self.calc_bwd_cost(dist_op, ctx, cluster)
        assert cost is not None
        return cost

    def calc_bwd_cost(self, dist_op, ctx, cluster):
        # for now, the backward function only inserts the gradient allreduce for the dist op itself
        res = []
        backward_op = dist_op.serial_op
        dist_attr = dist_op.dist_attr
        main_block = backward_op.block
        Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("Y")[0]
        )
        assert Y_var_dim_mapping[1] < 0
        parallel_axis = Y_var_dim_mapping[0]

        # calc comm op cost
        var_names = [backward_op.input("Out@GRAD")[0]]
        attrs = {"use_calc_stream": True, "use_model_parallel": True}
        c_identity_desc_mapping = build_comm_desc_from_dist_op(
            "c_identity",
            dist_op,
            ctx,
            var_names,
            attrs=attrs,
            parallel_axis=parallel_axis,
        )
        process_mesh = dist_attr.process_mesh
        processes = process_mesh.process_ids
        comm_op_cost_list = build_comm_costs_from_descs(
            IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster
        )
        res.append(comm_op_cost_list)

        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        cost_mapping = build_comp_costs_from_descs(
            MatmulGradOpCost, ctx, processes, desc_mapping, cluster
        )
        res.append(cost_mapping)

        # need gradient allreduce
        var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("X")[0]
        )
        mesh_shape = process_mesh.shape
        batch_size_axis = var_dim_mapping[0]
        if (
            batch_size_axis > -1
            and mesh_shape[batch_size_axis] > 1
            and is_parameter_related(backward_op.input("Y")[0], main_block)
        ):
            parallel_axis = batch_size_axis
            attrs = {"use_calc_stream": True}
            var_names = [backward_op.output('Y@GRAD')[0]]
            build_dp_costs(
                res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
            )
        return res

    def calc_fwd_cost(self, dist_op, ctx, cluster):
        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        processes = dist_op.dist_attr.process_mesh.process_ids
        cost_mapping = build_comp_costs_from_descs(
            MatmulOpCost, ctx, processes, desc_mapping, cluster
        )

        # calc comm op cost
        serial_op = dist_op.serial_op
        parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
            serial_op.input("Y")[0]
        )[-2]
        attrs = {"use_calc_stream": True, "use_model_parallel": True}

        var_names = serial_op.output("Out")
        c_allreduce_sum_desc_mapping = build_comm_desc_from_dist_op(
            "c_allreduce_sum",
            dist_op,
            ctx,
            var_names,
            attrs=attrs,
            parallel_axis=parallel_axis,
        )

        comm_op_cost_list = build_comm_costs_from_descs(
            AllreduceSumOpCost,
            ctx,
            processes,
            c_allreduce_sum_desc_mapping,
            cluster,
        )

        res_cost = [cost_mapping, comm_op_cost_list]

        return res_cost

    def is_input_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        x_name = op_desc.input('X')[0]
        y_name = op_desc.input('Y')[0]
        x_dims_mapping = copy.deepcopy(
            op_dist_attr.get_input_dims_mapping(x_name)
        )
        y_dims_mapping = copy.deepcopy(
            op_dist_attr.get_input_dims_mapping(y_name)
        )
        trans_x = op_desc.attr('transpose_X')
        trans_y = op_desc.attr('transpose_Y')
        trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping)
        if is_dim_replicate(x_dims_mapping[-1]):
            return False
        if is_dim_replicate(y_dims_mapping[-2]) or is_dim_shard(
            y_dims_mapping[-1]
        ):
            return False
        # Other dimensions must be replicated except the batch dimension
        for mapping in x_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_output_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
        if is_dim_shard(out_dims_mapping[-1]):
            return False
        # Other dimensions must be replicated except the batch dimension
        for mapping in out_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_auto_compatible(self, dist_op):
        if (not self.is_input_compatible(dist_op)) or (
            not self.is_output_compatible(dist_op)
        ):
            return False
        if not _is_auto_compatible_for_matmul(dist_op):
            return False
        return True

    def update_dims_mapping(self, dist_op):
        changed = False
        dim_changed = _update_dims_mapping_for_matmul(dist_op)
        if dim_changed:
            changed = True
        return changed

    @staticmethod
    def forward(ctx, *args, **kwargs):
        """
        kwargs: inputname_mapping & outputname_mapping
        """

        dist_op_context = ctx.dist_op_context
        main_block = dist_op_context.work_block
        startup_block = dist_op_context.startup_block
        src_op = dist_op_context.cur_src_op
        rank_id = dist_op_context.rank_id
        op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
        assert (
            op_dist_attr is not None
        ), "op [{}] doesn't have a dist attribute!".format(str(src_op))

        # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
        if rank_id not in op_dist_attr.process_mesh.process_ids:
            rank_id = _get_corresponding_rank(
                ctx, op_dist_attr.process_mesh, rank_id
            )

        # check the validity of inputs / outputs
        for input_name in src_op.desc.input_names():
            assert input_name in kwargs, "input [{}] is not given".format(
                input_name
            )
            assert len(kwargs[input_name]) == len(
                src_op.desc.input(input_name)
            ), "number of tensors for input [{}] does not match".format(
                input_name
            )
        for output_name in src_op.desc.output_names():
            assert output_name in kwargs, "output [{}] is not given".format(
                output_name
            )
            assert len(kwargs[output_name]) == len(
                src_op.desc.output(output_name)
            ), "number of tensors for output [{}] does not match".format(
                output_name
            )

        X_var = main_block._var_recursive(kwargs['X'][0])
        Weight_var = main_block._var_recursive(kwargs['Y'][0])
        Out_var = main_block._var_recursive(kwargs['Out'][0])
        trans_x = src_op.attr('transpose_X')
        trans_y = src_op.attr('transpose_Y')

        # TODO: infer the logical communication representation
        matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
            Weight_var.name
        )[-2]
        if trans_y:
            matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
                Weight_var.name
            )[-1]
        assert (
            matmul_row_dim_mapping >= 0
        ), "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
            matmul_row_dim_mapping
        )
        process_mesh_shape = op_dist_attr.process_mesh.shape
        process_mesh_group = op_dist_attr.process_mesh.process_ids

        parallel_axis = matmul_row_dim_mapping
        group_ranks = _get_comm_group(
            process_mesh_group, process_mesh_shape, parallel_axis, rank_id
        )
        group = new_process_group(group_ranks)

        check_variable_and_dtype(
            X_var, 'x', ['float16', 'float32', 'float64', 'uint16'], 'linear'
        )
        check_dtype(
            X_var.dtype,
            'dtype',
            ['float16', 'float32', 'float64', 'uint16'],
            'linear',
        )
        attrs = {
            'transpose_X': trans_x,
            'transpose_Y': trans_y,
            'alpha': 1,
            OP_ROLE_KEY: src_op.attr('op_role'),
        }
        inputs = {'X': X_var, 'Y': Weight_var}

        # infer out var shape with op dist attr
        out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var)
        assert out_tensor_dist_attr is not None
        out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert out_var_dist_attr is not None
        ref_shape = infer_shape(
            main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr
        )

        intermediate_var_0 = main_block.create_var(
            name=unique_name.generate_with_ignorable_key(
                ".".join(["c_allreduce_sum", 'tmp'])
            ),
            shape=Out_var.shape,
            dtype=Out_var.dtype,
            type=Out_var.type,
            lod_level=Out_var.lod_level,
            persistable=False,
            is_data=False,
            need_check_feed=Out_var.desc.need_check_feed(),
        )
        # set intermediate_var_0's dist_attr with Out_var's dist_attr
        ctx.set_tensor_dist_attr_for_program(
            intermediate_var_0, out_var_dist_attr
        )

        matmul_op = main_block.append_op(
            type='matmul',
            inputs=inputs,
            outputs={'Out': intermediate_var_0},
            attrs=attrs,
        )
        if intermediate_var_0.shape != ref_shape:
            intermediate_var_0.desc.set_shape(ref_shape)

        c_allreduce_sum_op = main_block.append_op(
            type='c_allreduce_sum',
            inputs={'X': intermediate_var_0},
            outputs={'Out': Out_var},
            attrs={
                'ring_id': group.id,
                'use_calc_stream': True,
                'use_model_parallel': True,
                OP_ROLE_KEY: src_op.attr('op_role'),
            },
        )
        if Out_var.shape != ref_shape:
            Out_var.desc.set_shape(ref_shape)

        # set dist op's dist_attr with serial op's dist_attr
        # matmul
        matmul_op_dist_attr = OperatorDistAttr()
        matmul_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        matmul_op_dist_attr.impl_type = op_dist_attr.impl_type
        matmul_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        for input_varname in matmul_op.desc.input_arg_names():
            input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
            assert input_dist_attr is not None, "dist_attr is {}".format(
                op_dist_attr
            )
            matmul_op_dist_attr.set_input_dist_attr(
                input_varname, input_dist_attr
            )
        output_varname = matmul_op.desc.output_arg_names()[0]
        output_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert output_dist_attr is not None, "dist_attr is {}".format(
            op_dist_attr
        )
        matmul_op_dist_attr.set_output_dist_attr(
            output_varname, output_dist_attr
        )
        ctx.set_op_dist_attr_for_program(matmul_op, matmul_op_dist_attr)

        # allreduce
        allreduce_op_dist_attr = OperatorDistAttr()
        allreduce_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        allreduce_op_dist_attr.impl_type = op_dist_attr.impl_type
        allreduce_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        for input_varname in c_allreduce_sum_op.desc.input_arg_names():
            input_var = main_block._var_recursive(input_varname)
            tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(input_var)
            assert tensor_dist_attr is not None
            allreduce_op_dist_attr.set_input_dist_attr(
                input_varname, tensor_dist_attr
            )
        for output_varname in c_allreduce_sum_op.desc.output_arg_names():
            output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
            assert output_dist_attr is not None, "dist_attr is {}".format(
                op_dist_attr
            )
            allreduce_op_dist_attr.set_output_dist_attr(
                output_varname, output_dist_attr
            )
        ctx.set_op_dist_attr_for_program(
            c_allreduce_sum_op, allreduce_op_dist_attr
        )

        # init param sync
        if Weight_var.is_parameter and not op_dist_attr.is_recompute:
            _init_param_sync(
                Weight_var, dist_op_context, startup_block, ctx, rank_id
            )

    @staticmethod
    def backward(ctx, *args, **kwargs):
        _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)


# ReplicateParallel
class DistributedMatmulImpl2(DistributedOperatorImpl):
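    """Replicated matmul: the contraction dim and the output dims are not
    sharded, so no model-parallel communication is required; cost modeling
    covers only the computation plus data-parallel gradient allreduce when
    needed.
    """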
    def __init__(self, name):
        super().__init__(name)

    def calc_cost(self, op_role, dist_op, ctx, cluster):
        cost = None
        if int(op_role) == int(OpRole.Forward):
            cost = self.calc_fwd_cost(dist_op, ctx, cluster)
        elif int(op_role) == int(OpRole.Backward):
            cost = self.calc_bwd_cost(dist_op, ctx, cluster)
        assert cost is not None
        return cost

    def calc_bwd_cost(self, dist_op, ctx, cluster):
        res = []
        backward_op = dist_op.serial_op
        dist_attr = dist_op.dist_attr
        main_block = backward_op.block

        # calc comp op cost
1353 1354 1355
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
C
caozhou 已提交
1356
        process_mesh = dist_attr.process_mesh
1357
        processes = process_mesh.process_ids
1358 1359 1360
        cost_mapping = build_comp_costs_from_descs(
            MatmulGradOpCost, ctx, processes, desc_mapping, cluster
        )
C
caozhou 已提交
1361 1362 1363 1364
        res.append(cost_mapping)

        # need gradient allreduce
        var_dim_mapping = dist_attr.get_input_dims_mapping(
1365 1366
            backward_op.input("X")[0]
        )
1367
        mesh_shape = process_mesh.shape
C
caozhou 已提交
1368
        batch_size_axis = var_dim_mapping[0]
1369 1370 1371 1372 1373
        if (
            batch_size_axis > -1
            and mesh_shape[batch_size_axis] > 1
            and is_parameter_related(backward_op.input("Y")[0], main_block)
        ):
C
caozhou 已提交
1374 1375 1376
            parallel_axis = batch_size_axis
            attrs = {"use_calc_stream": True}
            var_names = [backward_op.output('Y@GRAD')[0]]
1377 1378 1379
            build_dp_costs(
                res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
            )
C
caozhou 已提交
1380 1381 1382 1383 1384

        return res

    def calc_fwd_cost(self, dist_op, ctx, cluster):
        # calc comp op cost
1385 1386 1387
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
1388
        processes = dist_op.dist_attr.process_mesh.process_ids
1389 1390 1391
        cost_mapping = build_comp_costs_from_descs(
            MatmulOpCost, ctx, processes, desc_mapping, cluster
        )
C
caozhou 已提交
1392 1393 1394 1395

        res_cost = [cost_mapping]
        return res_cost

1396 1397 1398
    def is_input_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
1399 1400 1401 1402 1403 1404 1405
        x_name = op_desc.input('X')[0]
        y_name = op_desc.input('Y')[0]
        x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
        y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)

        if is_dim_shard(x_dims_mapping[-1]):
            return False
1406
        if is_valid_list_index(x_dims_mapping, -2) and is_dim_shard(
1407 1408
            x_dims_mapping[-2]
        ):
1409 1410 1411 1412
            return False

        if is_dim_shard(y_dims_mapping[-1]):
            return False
1413
        if is_valid_list_index(y_dims_mapping, -2) and is_dim_shard(
1414 1415
            y_dims_mapping[-2]
        ):
1416 1417 1418 1419
            return False

        return True

1420 1421 1422
    def is_output_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
1423 1424 1425 1426 1427
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)

        if is_dim_shard(out_dims_mapping[-1]):
            return False
1428
        if is_valid_list_index(out_dims_mapping, -2) and is_dim_shard(
1429 1430
            out_dims_mapping[-2]
        ):
1431 1432 1433 1434
            return False

        return True

1435
    def is_auto_compatible(self, dist_op):
1436 1437 1438
        if (not self.is_input_compatible(dist_op)) or (
            not self.is_output_compatible(dist_op)
        ):
1439 1440
            return False

1441
        if not _is_auto_compatible_for_matmul(dist_op):
1442 1443 1444 1445
            return False

        return True

1446
    def update_dims_mapping(self, dist_op):
1447
        changed = False
1448
        dim_changed = _update_dims_mapping_for_matmul(dist_op)
1449 1450 1451 1452
        if dim_changed:
            changed = True
        return changed

1453 1454 1455 1456
    @staticmethod
    def forward(ctx, *args, **kwargs):
        DistributedDefaultImpl0.forward(ctx, *args, **kwargs)

1457 1458 1459 1460
    @staticmethod
    def backward(ctx, *args, **kwargs):
        _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)

1461

1462 1463 1464 1465 1466 1467 1468 1469 1470
register_distributed_operator_impl(
    "matmul", DistributedMatmulImpl0("column_parallel")
)
register_distributed_operator_impl(
    "matmul", DistributedMatmulImpl1("row_parallel")
)
register_distributed_operator_impl(
    "matmul", DistributedMatmulImpl2("replicate_parallel")
)
1471 1472


1473
class DistributedMatmulV2(DistributedOperatorImplContainer):
1474
    def __init__(self, op_type):
1475
        super().__init__(op_type)
1476 1477


1478
register_distributed_operator_impl_container(DistributedMatmulV2("matmul_v2"))
1479 1480


1481 1482 1483
# ColumnParallel
class DistributedMatmulV2Impl0(DistributedOperatorImpl):
    def __init__(self, name):
1484
        super().__init__(name)
1485
        self._forward_implemented = True
1486
        self._backward_implemented = True
1487

1488 1489 1490 1491 1492 1493 1494 1495 1496 1497 1498 1499 1500 1501 1502 1503
    def calc_cost(self, op_role, dist_op, ctx, cluster):
        cost = None
        if int(op_role) == int(OpRole.Forward):
            cost = self.calc_fwd_cost(dist_op, ctx, cluster)
        elif int(op_role) == int(OpRole.Backward):
            cost = self.calc_bwd_cost(dist_op, ctx, cluster)
        assert cost is not None
        return cost

    def calc_bwd_cost(self, dist_op, ctx, cluster):
        # by now the backward function only insert the gradient allreduce for dist op itself
        res = []
        backward_op = dist_op.serial_op
        dist_attr = dist_op.dist_attr
        main_block = backward_op.block
        Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
1504 1505
            backward_op.input("Y")[0]
        )
1506
        process_mesh = dist_attr.process_mesh
1507
        processes = process_mesh.process_ids
1508
        # col parallel: matmul + allreduce
1509 1510
        if backward_op.attr("trans_y"):
            Y_var_dim_mapping.reverse()
1511 1512 1513 1514 1515 1516 1517 1518
        assert Y_var_dim_mapping[0] < 0
        parallel_axis = Y_var_dim_mapping[1]

        has_x_grad = len(backward_op.output("X@GRAD")) > 0
        if has_x_grad:
            assert len(backward_op.output("X@GRAD")) == 1

        # calc comp op cost
1519 1520 1521
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
1522

1523 1524 1525
        cost_mapping = build_comp_costs_from_descs(
            MatmulV2GradOpCost, ctx, processes, desc_mapping, cluster
        )
1526 1527 1528 1529 1530 1531 1532 1533 1534 1535 1536 1537
        res.append(cost_mapping)

        # calc comm op cost
        if has_x_grad:
            attrs = {"use_calc_stream": True, "use_model_parallel": True}
            var_names = backward_op.output("X@GRAD")
            c_allreduce_sum_desc_mapping = build_comm_desc_from_dist_op(
                "c_allreduce_sum",
                dist_op,
                ctx,
                var_names,
                attrs=attrs,
1538 1539
                parallel_axis=parallel_axis,
            )
1540
            comm_op_cost_list = build_comm_costs_from_descs(
1541 1542 1543 1544 1545 1546
                AllreduceSumOpCost,
                ctx,
                processes,
                c_allreduce_sum_desc_mapping,
                cluster,
            )
1547 1548 1549 1550 1551
            res.append(comm_op_cost_list)

        # need gradient allreduce
        process_mesh = dist_attr.process_mesh
        var_dim_mapping = dist_attr.get_input_dims_mapping(
1552 1553
            backward_op.input("X")[0]
        )
1554
        mesh_shape = process_mesh.shape
1555
        batch_size_axis = var_dim_mapping[0]
1556 1557 1558 1559 1560
        if (
            batch_size_axis > -1
            and mesh_shape[batch_size_axis] > 1
            and is_parameter_related(backward_op.input("Y")[0], main_block)
        ):
1561 1562 1563
            parallel_axis = batch_size_axis
            attrs = {"use_calc_stream": True}
            var_names = [backward_op.output('Y@GRAD')[0]]
1564 1565 1566
            build_dp_costs(
                res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
            )
1567 1568 1569 1570 1571
        return res

    def calc_fwd_cost(self, dist_op, ctx, cluster):
        # calc comp op cost
        # TODO: trans shape if trans_x or trans_y is True
1572 1573 1574
        comp_desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
1575
        processes = dist_op.dist_attr.process_mesh.process_ids
1576 1577 1578
        comp_cost_mapping = build_comp_costs_from_descs(
            MatmulV2OpCost, ctx, processes, comp_desc_mapping, cluster
        )
1579 1580 1581 1582

        # calc comm op cost
        serial_op = dist_op.serial_op
        parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
1583 1584
            serial_op.input("Y")[0]
        )[-1]
1585 1586 1587 1588 1589 1590 1591 1592 1593
        attrs = {"use_calc_stream": True, "use_model_parallel": True}

        var_names = serial_op.input("X")
        c_identity_desc_mapping = build_comm_desc_from_dist_op(
            "c_identity",
            dist_op,
            ctx,
            var_names,
            attrs=attrs,
1594 1595
            parallel_axis=parallel_axis,
        )
1596
        comm_op_cost_list = build_comm_costs_from_descs(
1597 1598
            IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster
        )
1599 1600 1601 1602

        res_cost = [comm_op_cost_list, comp_cost_mapping]
        return res_cost

1603 1604 1605
    def is_input_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
1606 1607
        x_name = op_desc.input('X')[0]
        y_name = op_desc.input('Y')[0]
1608
        x_dims_mapping = copy.deepcopy(
1609 1610
            op_dist_attr.get_input_dims_mapping(x_name)
        )
1611
        y_dims_mapping = copy.deepcopy(
1612 1613
            op_dist_attr.get_input_dims_mapping(y_name)
        )
1614 1615 1616
        trans_x = op_desc.attr('trans_x')
        trans_y = op_desc.attr('trans_y')
        trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping)
1617 1618
        if is_dim_shard(x_dims_mapping[-1]):
            return False
1619
        if is_dim_shard(y_dims_mapping[-2]) or is_dim_replicate(
1620 1621
            y_dims_mapping[-1]
        ):
1622 1623 1624 1625 1626 1627
            return False
        for mapping in x_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

1628 1629 1630
    def is_output_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
1631 1632 1633 1634 1635 1636 1637 1638 1639
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
        if is_dim_replicate(out_dims_mapping[-1]):
            return False
        for mapping in out_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

1640
    def is_auto_compatible(self, dist_op):
1641 1642 1643
        if (not self.is_input_compatible(dist_op)) or (
            not self.is_output_compatible(dist_op)
        ):
1644
            return False
1645
        if not _is_auto_compatible_for_matmul(dist_op):
1646 1647 1648
            return False
        return True

1649
    def update_dims_mapping(self, dist_op):
1650
        changed = False
1651
        dim_changed = _update_dims_mapping_for_matmul(dist_op)
1652 1653 1654 1655
        if dim_changed:
            changed = True
        return changed

1656 1657 1658 1659 1660 1661
    @staticmethod
    def forward(ctx, *args, **kwargs):
        """
        kwargs: inputname_mapping & outputname_mapping
        """

1662
        dist_op_context = ctx.dist_op_context
1663 1664 1665 1666
        main_block = dist_op_context.work_block
        startup_block = dist_op_context.startup_block
        src_op = dist_op_context.cur_src_op
        rank_id = dist_op_context.rank_id
1667
        op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
1668 1669 1670
        assert (
            op_dist_attr is not None
        ), "backward op [{}] don't have dist attribute !".format(str(src_op))
1671 1672

        # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
1673
        if rank_id not in op_dist_attr.process_mesh.process_ids:
1674 1675 1676
            rank_id = _get_corresponding_rank(
                ctx, op_dist_attr.process_mesh, rank_id
            )
1677

1678
        # check validation of inputs / outputs
1679 1680
        for input_name in src_op.desc.input_names():
            assert input_name in kwargs, "input [{}] is not given".format(
1681 1682
                input_name
            )
1683 1684 1685 1686 1687
            assert len(kwargs[input_name]) == len(
                src_op.desc.input(input_name)
            ), "number of tensor for input [{}] is not match".format(input_name)
        for output_name in src_op.desc.output_names():
            assert output_name in kwargs, "input [{}] is not given".format(
1688 1689
                output_name
            )
1690 1691 1692
            assert len(kwargs[output_name]) == len(
                src_op.desc.output(output_name)
            ), "number of tensor for input [{}] is not match".format(
1693 1694
                output_name
            )
1695

Z
zhaoyingli 已提交
1696
        X_var = main_block._var_recursive(kwargs['X'][0])
1697
        Weight_var = main_block._var_recursive(kwargs['Y'][0])
Z
zhaoyingli 已提交
1698
        Out_var = main_block._var_recursive(kwargs['Out'][0])
1699 1700
        trans_x = src_op.attr('trans_x')
        trans_y = src_op.attr('trans_y')
1701 1702 1703

        # TODO infer logic comm presentation
        matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping(
1704 1705
            Weight_var.name
        )[-1]
1706 1707
        if trans_y:
            matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping(
1708 1709 1710 1711 1712 1713 1714
                Weight_var.name
            )[-2]
        assert (
            matmul_col_dim_mapping >= 0
        ), "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
            matmul_col_dim_mapping
        )
1715 1716
        process_mesh_shape = op_dist_attr.process_mesh.shape
        process_mesh_group = op_dist_attr.process_mesh.process_ids
1717 1718

        parallel_axis = matmul_col_dim_mapping
1719 1720 1721
        group_ranks = _get_comm_group(
            process_mesh_group, process_mesh_shape, parallel_axis, rank_id
        )
1722 1723
        group = new_process_group(group_ranks)

Z
zhaoyingli 已提交
1724 1725 1726 1727 1728
        # infer new var shape with op dist attr
        x_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(X_var)
        assert x_tensor_dist_attr is not None
        identity_var_dist_attr = op_dist_attr.get_input_dist_attr(X_var.name)
        assert identity_var_dist_attr is not None
1729 1730 1731
        ref_shape_x = infer_shape(
            main_block, X_var, x_tensor_dist_attr, identity_var_dist_attr
        )
Z
zhaoyingli 已提交
1732 1733 1734 1735 1736
        # infer out var shape with op dist attr
        out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var)
        assert out_tensor_dist_attr is not None
        out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert out_var_dist_attr is not None
1737 1738 1739
        ref_shape_out = infer_shape(
            main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr
        )
Z
zhaoyingli 已提交
1740

1741
        intermediate_var_0 = main_block.create_var(
1742 1743 1744
            name=unique_name.generate_with_ignorable_key(
                ".".join(["c_identity", 'tmp'])
            ),
1745 1746 1747 1748
            dtype=X_var.dtype,
            shape=X_var.shape,
            type=core.VarDesc.VarType.LOD_TENSOR,
            persistable=False,
1749 1750
            stop_gradient=X_var.stop_gradient,
        )
Z
zhaoyingli 已提交
1751
        # set intermediate_var_0's dist_attr with X_var's dist_attr
1752 1753 1754
        ctx.set_tensor_dist_attr_for_program(
            intermediate_var_0, identity_var_dist_attr
        )
1755 1756

        check_variable_and_dtype(
1757 1758
            X_var,
            'tensor',
X
xu98bin 已提交
1759
            ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
1760 1761
            '_c_identity',
        )
1762 1763 1764 1765 1766 1767 1768 1769
        c_identity_op = main_block.append_op(
            type='c_identity',
            inputs={'X': [X_var]},
            outputs={'Out': intermediate_var_0},
            attrs={
                'ring_id': group.id,
                'use_calc_stream': True,
                'use_model_parallel': True,
1770
                OP_ROLE_KEY: src_op.attr('op_role'),
1771 1772
            },
        )
Z
zhaoyingli 已提交
1773 1774
        if intermediate_var_0.shape != ref_shape_x:
            intermediate_var_0.desc.set_shape(ref_shape_x)
1775

1776
        check_variable_and_dtype(
X
xu98bin 已提交
1777 1778 1779 1780
            intermediate_var_0,
            'x',
            ['float16', 'float32', 'float64', 'uint16'],
            'linear',
1781 1782 1783 1784
        )
        check_dtype(
            intermediate_var_0.dtype,
            'dtype',
X
xu98bin 已提交
1785
            ['float16', 'float32', 'float64', 'uint16'],
1786 1787
            'linear',
        )
1788
        attrs = {
1789 1790
            'trans_x': trans_x,
            'trans_y': trans_y,
1791
            OP_ROLE_KEY: src_op.attr('op_role'),
1792
        }
1793
        inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]}
1794 1795 1796 1797 1798 1799
        matmul_v2_op = main_block.append_op(
            type='matmul_v2',
            inputs=inputs,
            outputs={'Out': Out_var},
            attrs=attrs,
        )
Z
zhaoyingli 已提交
1800 1801 1802 1803 1804
        if Out_var.shape != ref_shape_out:
            Out_var.desc.set_shape(ref_shape_out)

        # set dist op's dist_attr with serial op's dist_attr
        # c_identity
1805
        identity_op_dist_attr = OperatorDistAttr()
Z
zhaoyingli 已提交
1806
        identity_op_dist_attr.process_mesh = op_dist_attr.process_mesh
1807
        identity_op_dist_attr.impl_type = op_dist_attr.impl_type
Z
zhaoyingli 已提交
1808 1809 1810 1811 1812
        identity_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        # input
        input_varname = c_identity_op.desc.input_arg_names()[0]
        input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
        assert input_dist_attr is not None, "dist_attr is {}".format(
1813 1814 1815 1816 1817
            op_dist_attr
        )
        identity_op_dist_attr.set_input_dist_attr(
            input_varname, input_dist_attr
        )
Z
zhaoyingli 已提交
1818 1819
        # output
        output_varname = c_identity_op.desc.output_arg_names()[0]
1820 1821 1822
        identity_op_dist_attr.set_output_dist_attr(
            output_varname, input_dist_attr
        )
Z
zhaoyingli 已提交
1823 1824 1825
        ctx.set_op_dist_attr_for_program(c_identity_op, identity_op_dist_attr)

        # matmulv2
1826
        matmulv2_op_dist_attr = OperatorDistAttr()
Z
zhaoyingli 已提交
1827
        matmulv2_op_dist_attr.process_mesh = op_dist_attr.process_mesh
1828
        matmulv2_op_dist_attr.impl_type = op_dist_attr.impl_type
Z
zhaoyingli 已提交
1829 1830 1831 1832
        matmulv2_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        for input_varname in matmul_v2_op.desc.input_arg_names():
            if input_varname in src_op.desc.input_arg_names():
                input_dist_attr = op_dist_attr.get_input_dist_attr(
1833 1834
                    input_varname
                )
Z
zhaoyingli 已提交
1835
                assert input_dist_attr is not None, "dist_attr is {}".format(
1836 1837
                    op_dist_attr
                )
1838
                matmulv2_op_dist_attr.set_input_dist_attr(
1839 1840
                    input_varname, input_dist_attr
                )
Z
zhaoyingli 已提交
1841
            else:
Z
zhaoyingli 已提交
1842
                input_var = main_block._var_recursive(input_varname)
Z
zhaoyingli 已提交
1843
                tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(
1844 1845
                    input_var
                )
1846
                matmulv2_op_dist_attr.set_input_dist_attr(
1847 1848
                    input_varname, tensor_dist_attr
                )
Z
zhaoyingli 已提交
1849 1850 1851
        for output_varname in matmul_v2_op.desc.output_arg_names():
            output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
            assert output_dist_attr is not None, "dist_attr is {}".format(
1852 1853 1854 1855 1856
                op_dist_attr
            )
            matmulv2_op_dist_attr.set_output_dist_attr(
                output_varname, output_dist_attr
            )
Z
zhaoyingli 已提交
1857
        ctx.set_op_dist_attr_for_program(matmul_v2_op, matmulv2_op_dist_attr)
1858 1859

        # init param sync
1860
        if Weight_var.is_parameter and not op_dist_attr.is_recompute:
1861 1862 1863
            _init_param_sync(
                Weight_var, dist_op_context, startup_block, ctx, rank_id
            )
1864 1865 1866 1867

    @staticmethod
    def backward(ctx, *args, **kwargs):
        _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)
1868 1869 1870 1871 1872


# RowParallel
class DistributedMatmulV2Impl1(DistributedOperatorImpl):
    def __init__(self, name):
1873
        super().__init__(name)
1874
        self._forward_implemented = True
1875
        self._backward_implemented = True
1876

1877 1878 1879 1880 1881 1882 1883 1884 1885 1886 1887 1888 1889 1890 1891
    def calc_cost(self, op_role, dist_op, ctx, cluster):
        cost = None
        if int(op_role) == int(OpRole.Forward):
            cost = self.calc_fwd_cost(dist_op, ctx, cluster)
        elif int(op_role) == int(OpRole.Backward):
            cost = self.calc_bwd_cost(dist_op, ctx, cluster)
        assert cost is not None
        return cost

    def calc_bwd_cost(self, dist_op, ctx, cluster):
        # by now the backward function only insert the gradient allreduce for dist op itself
        res = []
        backward_op = dist_op.serial_op
        dist_attr = dist_op.dist_attr
        main_block = backward_op.block
Z
zhaoyingli 已提交
1892

1893
        Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
1894 1895
            backward_op.input("Y")[0]
        )
1896 1897 1898 1899
        assert Y_var_dim_mapping[1] < 0
        parallel_axis = Y_var_dim_mapping[0]

        process_mesh = dist_attr.process_mesh
1900
        processes = process_mesh.process_ids
1901 1902 1903 1904 1905 1906 1907 1908 1909
        # calc comm op cost
        var_names = [backward_op.input("Out@GRAD")[0]]
        attrs = {"use_calc_stream": True, "use_model_parallel": True}
        c_identity_desc_mapping = build_comm_desc_from_dist_op(
            "c_identity",
            dist_op,
            ctx,
            var_names,
            attrs=attrs,
1910 1911
            parallel_axis=parallel_axis,
        )
1912
        comm_op_cost_list = build_comm_costs_from_descs(
1913 1914
            IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster
        )
1915 1916 1917
        res.append(comm_op_cost_list)

        # calc comp op cost
1918 1919 1920 1921 1922 1923
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        cost_mapping = build_comp_costs_from_descs(
            MatmulV2GradOpCost, ctx, processes, desc_mapping, cluster
        )
1924 1925 1926 1927 1928
        res.append(cost_mapping)

        # need gradient allreduce
        process_mesh = dist_attr.process_mesh
        var_dim_mapping = dist_attr.get_input_dims_mapping(
1929 1930
            backward_op.input("X")[0]
        )
1931
        mesh_shape = process_mesh.shape
1932
        batch_size_axis = var_dim_mapping[0]
1933 1934 1935 1936 1937
        if (
            batch_size_axis > -1
            and mesh_shape[batch_size_axis] > 1
            and is_parameter_related(backward_op.input("Y")[0], main_block)
        ):
1938 1939 1940
            parallel_axis = batch_size_axis
            attrs = {"use_calc_stream": True}
            var_names = [backward_op.output('Y@GRAD')[0]]
1941 1942 1943
            build_dp_costs(
                res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
            )
1944 1945 1946 1947
        return res

    def calc_fwd_cost(self, dist_op, ctx, cluster):
        # calc comp op cost
1948 1949 1950
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
1951
        processes = dist_op.dist_attr.process_mesh.process_ids
1952 1953 1954
        cost_mapping = build_comp_costs_from_descs(
            MatmulV2OpCost, ctx, processes, desc_mapping, cluster
        )
1955 1956 1957 1958

        # calc comm op cost
        serial_op = dist_op.serial_op
        parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
1959 1960
            serial_op.input("Y")[0]
        )[-2]
1961 1962 1963 1964 1965 1966 1967 1968 1969
        attrs = {"use_calc_stream": True, "use_model_parallel": True}

        var_names = serial_op.output("Out")
        c_allreduce_sum_desc_mapping = build_comm_desc_from_dist_op(
            "c_allreduce_sum",
            dist_op,
            ctx,
            var_names,
            attrs=attrs,
1970 1971
            parallel_axis=parallel_axis,
        )
1972 1973

        comm_op_cost_list = build_comm_costs_from_descs(
1974 1975 1976 1977 1978 1979
            AllreduceSumOpCost,
            ctx,
            processes,
            c_allreduce_sum_desc_mapping,
            cluster,
        )
1980 1981 1982 1983
        res_cost = [cost_mapping, comm_op_cost_list]

        return res_cost

1984 1985 1986
    def is_input_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
1987 1988
        x_name = op_desc.input('X')[0]
        y_name = op_desc.input('Y')[0]
1989
        x_dims_mapping = copy.deepcopy(
1990 1991
            op_dist_attr.get_input_dims_mapping(x_name)
        )
1992
        y_dims_mapping = copy.deepcopy(
1993 1994
            op_dist_attr.get_input_dims_mapping(y_name)
        )
1995 1996 1997
        trans_x = op_desc.attr('trans_x')
        trans_y = op_desc.attr('trans_y')
        trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping)
1998 1999
        if is_dim_replicate(x_dims_mapping[-1]):
            return False
2000
        if is_dim_replicate(y_dims_mapping[-2]) or is_dim_shard(
2001 2002
            y_dims_mapping[-1]
        ):
2003 2004 2005 2006 2007 2008 2009
            return False
        # Other dimensions must be replicate except the batch dimension
        for mapping in x_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

2010 2011 2012
    def is_output_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
2013 2014 2015 2016 2017 2018 2019 2020 2021 2022
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
        if is_dim_shard(out_dims_mapping[-1]):
            return False
        # Other dimensions must be replicate except the batch dimension
        for mapping in out_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

2023
    def is_auto_compatible(self, dist_op):
2024 2025 2026
        if (not self.is_input_compatible(dist_op)) or (
            not self.is_output_compatible(dist_op)
        ):
2027
            return False
2028
        if not _is_auto_compatible_for_matmul(dist_op):
2029 2030 2031
            return False
        return True

2032
    def update_dims_mapping(self, dist_op):
2033
        changed = False
2034
        dim_changed = _update_dims_mapping_for_matmul(dist_op)
2035 2036 2037 2038
        if dim_changed:
            changed = True
        return changed

2039 2040 2041 2042 2043 2044
    @staticmethod
    def forward(ctx, *args, **kwargs):
        """
        kwargs: inputname_mapping & outputname_mapping
        """

2045
        dist_op_context = ctx.dist_op_context
2046 2047 2048 2049
        main_block = dist_op_context.work_block
        startup_block = dist_op_context.startup_block
        src_op = dist_op_context.cur_src_op
        rank_id = dist_op_context.rank_id
2050
        op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
2051 2052 2053
        assert (
            op_dist_attr is not None
        ), "backward op [{}] don't have dist attribute !".format(str(src_op))
2054 2055

        # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
2056
        if rank_id not in op_dist_attr.process_mesh.process_ids:
2057 2058 2059
            rank_id = _get_corresponding_rank(
                ctx, op_dist_attr.process_mesh, rank_id
            )
2060

2061
        # check validation of inputs / outputs
2062 2063
        for input_name in src_op.desc.input_names():
            assert input_name in kwargs, "input [{}] is not given".format(
2064 2065
                input_name
            )
2066 2067 2068 2069 2070
            assert len(kwargs[input_name]) == len(
                src_op.desc.input(input_name)
            ), "number of tensor for input [{}] is not match".format(input_name)
        for output_name in src_op.desc.output_names():
            assert output_name in kwargs, "input [{}] is not given".format(
2071 2072
                output_name
            )
2073 2074 2075
            assert len(kwargs[output_name]) == len(
                src_op.desc.output(output_name)
            ), "number of tensor for input [{}] is not match".format(
2076 2077
                output_name
            )
2078

Z
zhaoyingli 已提交
2079
        X_var = main_block._var_recursive(kwargs['X'][0])
2080
        Weight_var = main_block._var_recursive(kwargs['Y'][0])
Z
zhaoyingli 已提交
2081
        Out_var = main_block._var_recursive(kwargs['Out'][0])
2082 2083
        trans_x = src_op.attr('trans_x')
        trans_y = src_op.attr('trans_y')
2084 2085 2086

        # TODO infer logic comm presentation
        matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
2087 2088
            Weight_var.name
        )[-2]
2089 2090
        if trans_y:
            matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
2091 2092 2093 2094 2095 2096 2097
                Weight_var.name
            )[-1]
        assert (
            matmul_row_dim_mapping >= 0
        ), "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
            matmul_row_dim_mapping
        )
2098 2099
        process_mesh_shape = op_dist_attr.process_mesh.shape
        process_mesh_group = op_dist_attr.process_mesh.process_ids
2100 2101

        parallel_axis = matmul_row_dim_mapping
2102 2103 2104
        group_ranks = _get_comm_group(
            process_mesh_group, process_mesh_shape, parallel_axis, rank_id
        )
2105 2106
        group = new_process_group(group_ranks)

2107
        check_variable_and_dtype(
X
xu98bin 已提交
2108
            X_var, 'x', ['float16', 'float32', 'float64', 'uint16'], 'linear'
2109 2110
        )
        check_dtype(
X
xu98bin 已提交
2111 2112 2113 2114
            X_var.dtype,
            'dtype',
            ['float16', 'float32', 'float64', 'uint16'],
            'linear',
2115
        )
2116
        attrs = {
2117 2118
            'trans_x': trans_x,
            'trans_y': trans_y,
2119
            OP_ROLE_KEY: src_op.attr('op_role'),
2120
        }
2121
        inputs = {'X': X_var, 'Y': Weight_var}
Z
zhaoyingli 已提交
2122 2123 2124 2125 2126 2127

        # infer out var shape with op dist attr
        out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var)
        assert out_tensor_dist_attr is not None
        out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert out_var_dist_attr is not None
2128 2129 2130
        ref_shape = infer_shape(
            main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr
        )
Z
zhaoyingli 已提交
2131

2132
        intermediate_var_0 = main_block.create_var(
2133 2134 2135
            name=unique_name.generate_with_ignorable_key(
                ".".join(["c_allreduce_sum", 'tmp'])
            ),
2136 2137 2138 2139 2140 2141
            shape=Out_var.shape,
            dtype=Out_var.dtype,
            type=Out_var.type,
            lod_level=Out_var.lod_level,
            persistable=False,
            is_data=False,
2142 2143
            need_check_feed=Out_var.desc.need_check_feed(),
        )
Z
zhaoyingli 已提交
2144
        # set intermediate_var_0's dist_attr with Out_var's dist_attr
2145 2146 2147
        ctx.set_tensor_dist_attr_for_program(
            intermediate_var_0, out_var_dist_attr
        )
2148

2149 2150 2151 2152 2153 2154
        matmul_v2_op = main_block.append_op(
            type='matmul_v2',
            inputs=inputs,
            outputs={'Out': intermediate_var_0},
            attrs=attrs,
        )
Z
zhaoyingli 已提交
2155 2156
        if intermediate_var_0.shape != ref_shape:
            intermediate_var_0.desc.set_shape(ref_shape)
2157 2158 2159 2160 2161 2162 2163 2164

        c_allreduce_sum_op = main_block.append_op(
            type='c_allreduce_sum',
            inputs={'X': intermediate_var_0},
            outputs={'Out': Out_var},
            attrs={
                'ring_id': group.id,
                'use_calc_stream': True,
2165
                'use_model_parallel': True,
2166 2167 2168
                OP_ROLE_KEY: src_op.attr('op_role'),
            },
        )
Z
zhaoyingli 已提交
2169 2170 2171 2172 2173
        if Out_var.shape != ref_shape:
            Out_var.desc.set_shape(ref_shape)

        # set dist op's dist_attr with serial op's dist_attr
        # matmulv2
2174
        matmulv2_op_dist_attr = OperatorDistAttr()
Z
zhaoyingli 已提交
2175
        matmulv2_op_dist_attr.process_mesh = op_dist_attr.process_mesh
2176
        matmulv2_op_dist_attr.impl_type = op_dist_attr.impl_type
Z
zhaoyingli 已提交
2177 2178 2179 2180
        matmulv2_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        for input_varname in matmul_v2_op.desc.input_arg_names():
            input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
            assert input_dist_attr is not None, "dist_attr is {}".format(
2181 2182 2183 2184 2185
                op_dist_attr
            )
            matmulv2_op_dist_attr.set_input_dist_attr(
                input_varname, input_dist_attr
            )
Z
zhaoyingli 已提交
2186 2187 2188
        output_varname = matmul_v2_op.desc.output_arg_names()[0]
        output_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert output_dist_attr is not None, "dist_attr is {}".format(
2189 2190 2191 2192 2193
            op_dist_attr
        )
        matmulv2_op_dist_attr.set_output_dist_attr(
            output_varname, output_dist_attr
        )
Z
zhaoyingli 已提交
2194 2195 2196
        ctx.set_op_dist_attr_for_program(matmul_v2_op, matmulv2_op_dist_attr)

        # allreduce
2197
        allreduce_op_dist_attr = OperatorDistAttr()
Z
zhaoyingli 已提交
2198
        allreduce_op_dist_attr.process_mesh = op_dist_attr.process_mesh
2199
        allreduce_op_dist_attr.impl_type = op_dist_attr.impl_type
Z
zhaoyingli 已提交
2200 2201
        allreduce_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        for input_varname in c_allreduce_sum_op.desc.input_arg_names():
Z
zhaoyingli 已提交
2202
            input_var = main_block._var_recursive(input_varname)
Z
zhaoyingli 已提交
2203 2204
            tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(input_var)
            assert tensor_dist_attr is not None
2205 2206 2207
            allreduce_op_dist_attr.set_input_dist_attr(
                input_varname, tensor_dist_attr
            )
Z
zhaoyingli 已提交
2208 2209 2210
        for output_varname in c_allreduce_sum_op.desc.output_arg_names():
            output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
            assert output_dist_attr is not None, "dist_attr is {}".format(
2211 2212 2213 2214 2215 2216 2217 2218
                op_dist_attr
            )
            allreduce_op_dist_attr.set_output_dist_attr(
                output_varname, output_dist_attr
            )
        ctx.set_op_dist_attr_for_program(
            c_allreduce_sum_op, allreduce_op_dist_attr
        )
2219 2220

        # init param sync
2221
        if Weight_var.is_parameter and not op_dist_attr.is_recompute:
2222 2223 2224
            _init_param_sync(
                Weight_var, dist_op_context, startup_block, ctx, rank_id
            )
2225 2226 2227 2228

    @staticmethod
    def backward(ctx, *args, **kwargs):
        _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)
2229 2230


2231
# ReplicateParallel
2232
class DistributedMatmulV2Impl2(DistributedOperatorImpl):
2233
    def __init__(self, name):
2234
        super().__init__(name)
2235

2236 2237 2238 2239 2240 2241 2242 2243 2244 2245 2246 2247 2248 2249 2250 2251 2252
    def calc_cost(self, op_role, dist_op, ctx, cluster):
        cost = None
        if int(op_role) == int(OpRole.Forward):
            cost = self.calc_fwd_cost(dist_op, ctx, cluster)
        elif int(op_role) == int(OpRole.Backward):
            cost = self.calc_bwd_cost(dist_op, ctx, cluster)
        assert cost is not None
        return cost

    def calc_bwd_cost(self, dist_op, ctx, cluster):
        res = []
        backward_op = dist_op.serial_op
        dist_attr = dist_op.dist_attr
        main_block = backward_op.block
        process_mesh = dist_attr.process_mesh

        # calc comp op cost
2253 2254 2255
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
2256
        processes = process_mesh.process_ids
2257 2258 2259
        cost_mapping = build_comp_costs_from_descs(
            MatmulV2GradOpCost, ctx, processes, desc_mapping, cluster
        )
2260 2261 2262 2263
        res.append(cost_mapping)

        # need gradient allreduce
        var_dim_mapping = dist_attr.get_input_dims_mapping(
2264 2265
            backward_op.input("X")[0]
        )
2266
        mesh_shape = process_mesh.shape
2267
        batch_size_axis = var_dim_mapping[0]
2268 2269 2270 2271 2272
        if (
            batch_size_axis > -1
            and mesh_shape[batch_size_axis] > 1
            and is_parameter_related(backward_op.input("Y")[0], main_block)
        ):
2273 2274 2275
            parallel_axis = batch_size_axis
            attrs = {"use_calc_stream": True}
            var_names = [backward_op.output('Y@GRAD')[0]]
2276 2277 2278
            build_dp_costs(
                res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
            )
2279 2280 2281 2282 2283

        return res

    def calc_fwd_cost(self, dist_op, ctx, cluster):
        # calc comp op cost
2284 2285 2286
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
2287
        processes = dist_op.dist_attr.process_mesh.process_ids
2288 2289 2290
        cost_mapping = build_comp_costs_from_descs(
            MatmulV2OpCost, ctx, processes, desc_mapping, cluster
        )
2291 2292 2293 2294 2295

        res_cost = [cost_mapping]

        return res_cost

2296 2297 2298
    def is_input_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
2299 2300 2301 2302 2303 2304 2305
        x_name = op_desc.input('X')[0]
        y_name = op_desc.input('Y')[0]
        x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
        y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)

        if is_dim_shard(x_dims_mapping[-1]):
            return False
2306
        if is_valid_list_index(x_dims_mapping, -2) and is_dim_shard(
2307 2308
            x_dims_mapping[-2]
        ):
2309 2310 2311 2312
            return False

        if is_dim_shard(y_dims_mapping[-1]):
            return False
2313
        if is_valid_list_index(y_dims_mapping, -2) and is_dim_shard(
2314 2315
            y_dims_mapping[-2]
        ):
2316 2317 2318
            return False
        return True

2319 2320 2321 2322 2323
    def is_output_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
2324 2325 2326 2327 2328
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)

        if is_dim_shard(out_dims_mapping[-1]):
            return False
2329
        if is_valid_list_index(out_dims_mapping, -2) and is_dim_shard(
2330 2331
            out_dims_mapping[-2]
        ):
2332 2333 2334 2335
            return False

        return True

2336
    def is_auto_compatible(self, dist_op):
2337 2338 2339
        if (not self.is_input_compatible(dist_op)) or (
            not self.is_output_compatible(dist_op)
        ):
2340 2341
            return False

2342
        if not _is_auto_compatible_for_matmul(dist_op):
2343 2344 2345 2346
            return False

        return True

2347
    def update_dims_mapping(self, dist_op):
2348
        changed = False
2349
        dim_changed = _update_dims_mapping_for_matmul(dist_op)
2350 2351 2352 2353
        if dim_changed:
            changed = True
        return changed

2354 2355 2356 2357
    @staticmethod
    def forward(ctx, *args, **kwargs):
        DistributedDefaultImpl0.forward(ctx, *args, **kwargs)

2358 2359 2360 2361
    @staticmethod
    def backward(ctx, *args, **kwargs):
        _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)

2362 2363

register_distributed_operator_impl(
2364 2365 2366 2367 2368 2369 2370 2371
    "matmul_v2", DistributedMatmulV2Impl0("column_parallel")
)
register_distributed_operator_impl(
    "matmul_v2", DistributedMatmulV2Impl1("row_parallel")
)
register_distributed_operator_impl(
    "matmul_v2", DistributedMatmulV2Impl2("replicate_parallel")
)
2372 2373 2374 2375


class DistributedMul(DistributedOperatorImplContainer):
    def __init__(self, op_type):
2376
        super().__init__(op_type)
2377 2378 2379 2380 2381 2382 2383 2384


register_distributed_operator_impl_container(DistributedMul("mul"))


# ColumnParallel
class DistributedMulImpl0(DistributedOperatorImpl):
    def __init__(self, name):
2385
        super().__init__(name)
2386 2387 2388
        self._forward_implemented = True
        self._backward_implemented = True

2389 2390 2391 2392 2393 2394 2395 2396 2397 2398 2399 2400 2401 2402 2403 2404
    def calc_cost(self, op_role, dist_op, ctx, cluster):
        cost = None
        if int(op_role) == int(OpRole.Forward):
            cost = self.calc_fwd_cost(dist_op, ctx, cluster)
        elif int(op_role) == int(OpRole.Backward):
            cost = self.calc_bwd_cost(dist_op, ctx, cluster)
        assert cost is not None
        return cost

    def calc_bwd_cost(self, dist_op, ctx, cluster):
        # by now the backward function only insert the gradient allreduce for dist op itself
        res = []
        backward_op = dist_op.serial_op
        dist_attr = dist_op.dist_attr
        main_block = backward_op.block
        Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
2405 2406
            backward_op.input("Y")[0]
        )
2407 2408 2409 2410 2411 2412 2413 2414 2415
        # col parallel: matmul + allreduce
        assert Y_var_dim_mapping[0] < 0
        parallel_axis = Y_var_dim_mapping[1]

        has_x_grad = len(backward_op.output("X@GRAD")) > 0
        if has_x_grad:
            assert len(backward_op.output("X@GRAD")) == 1

        # calc comp op cost
2416 2417 2418
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
2419
        process_mesh = dist_attr.process_mesh
2420
        processes = process_mesh.process_ids
2421 2422 2423
        cost_mapping = build_comp_costs_from_descs(
            MulGradOpCost, ctx, processes, desc_mapping, cluster
        )
2424 2425 2426 2427 2428 2429 2430 2431 2432 2433 2434 2435
        res.append(cost_mapping)

        # calc comm op cost
        if has_x_grad:
            attrs = {"use_calc_stream": True, "use_model_parallel": True}
            var_names = backward_op.output("X@GRAD")
            c_allreduce_sum_desc_mapping = build_comm_desc_from_dist_op(
                "c_allreduce_sum",
                dist_op,
                ctx,
                var_names,
                attrs=attrs,
2436 2437
                parallel_axis=parallel_axis,
            )
2438
            comm_op_cost_list = build_comm_costs_from_descs(
2439 2440 2441 2442 2443 2444
                AllreduceSumOpCost,
                ctx,
                processes,
                c_allreduce_sum_desc_mapping,
                cluster,
            )
2445 2446 2447 2448
            res.append(comm_op_cost_list)

        # need gradient allreduce
        var_dim_mapping = dist_attr.get_input_dims_mapping(
2449 2450
            backward_op.input("X")[0]
        )
2451
        mesh_shape = process_mesh.shape
2452
        batch_size_axis = var_dim_mapping[0]
2453 2454 2455 2456 2457
        if (
            batch_size_axis > -1
            and mesh_shape[batch_size_axis] > 1
            and is_parameter_related(backward_op.input("Y")[0], main_block)
        ):
2458 2459 2460
            parallel_axis = batch_size_axis
            attrs = {"use_calc_stream": True}
            var_names = [backward_op.output('Y@GRAD')[0]]
2461 2462 2463
            build_dp_costs(
                res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
            )
2464 2465 2466 2467
        return res

    def calc_fwd_cost(self, dist_op, ctx, cluster):
        # calc comp op cost
2468 2469 2470
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
2471
        processes = dist_op.dist_attr.process_mesh.process_ids
2472 2473 2474
        cost_mapping = build_comp_costs_from_descs(
            MulOpCost, ctx, processes, desc_mapping, cluster
        )
2475 2476 2477 2478

        # calc comm op cost
        serial_op = dist_op.serial_op
        parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
2479 2480
            serial_op.input("Y")[0]
        )[-1]
2481 2482 2483 2484 2485 2486 2487 2488
        attrs = {"use_calc_stream": True, "use_model_parallel": True}
        var_names = serial_op.input("X")
        c_identity_desc_mapping = build_comm_desc_from_dist_op(
            "c_identity",
            dist_op,
            ctx,
            var_names,
            attrs=attrs,
2489 2490
            parallel_axis=parallel_axis,
        )
2491 2492

        comm_op_cost_list = build_comm_costs_from_descs(
2493 2494
            IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster
        )
2495 2496 2497 2498
        res_cost = [comm_op_cost_list, cost_mapping]

        return res_cost

2499 2500 2501 2502 2503 2504 2505 2506 2507
    def is_input_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        x_name = op_desc.input('X')[0]
        y_name = op_desc.input('Y')[0]
        x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
        y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
        if is_dim_shard(x_dims_mapping[-1]):
            return False
2508
        if is_dim_shard(y_dims_mapping[-2]) or is_dim_replicate(
2509 2510
            y_dims_mapping[-1]
        ):
2511 2512 2513 2514 2515 2516 2517 2518 2519 2520 2521 2522 2523 2524 2525 2526 2527 2528 2529
            return False
        for mapping in x_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_output_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
        if is_dim_replicate(out_dims_mapping[-1]):
            return False
        for mapping in out_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_auto_compatible(self, dist_op):
2530 2531 2532
        if (not self.is_input_compatible(dist_op)) or (
            not self.is_output_compatible(dist_op)
        ):
2533 2534 2535 2536 2537 2538 2539 2540 2541 2542 2543 2544 2545 2546 2547 2548 2549 2550 2551 2552 2553 2554 2555 2556 2557 2558
            return False

        if not _is_auto_compatible_for_matmul(dist_op):
            return False

        return True

    def update_dims_mapping(self, dist_op):
        changed = False
        dim_changed = _update_dims_mapping_for_matmul(dist_op)
        if dim_changed:
            changed = True
        return changed

    @staticmethod
    def forward(ctx, *args, **kwargs):
        """
        kwargs: inputname_mapping & outputname_mapping
        """

        dist_op_context = ctx.dist_op_context
        main_block = dist_op_context.work_block
        startup_block = dist_op_context.startup_block
        src_op = dist_op_context.cur_src_op
        rank_id = dist_op_context.rank_id
        op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
2559 2560 2561
        assert (
            op_dist_attr is not None
        ), "backward op [{}] don't have dist attribute !".format(str(src_op))
2562 2563

        # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
2564
        if rank_id not in op_dist_attr.process_mesh.process_ids:
2565 2566 2567
            rank_id = _get_corresponding_rank(
                ctx, op_dist_attr.process_mesh, rank_id
            )
2568 2569 2570 2571

        # check validation of inputs / outputs
        for input_name in src_op.desc.input_names():
            assert input_name in kwargs, "input [{}] is not given".format(
2572 2573
                input_name
            )
2574 2575 2576 2577 2578
            assert len(kwargs[input_name]) == len(
                src_op.desc.input(input_name)
            ), "number of tensor for input [{}] is not match".format(input_name)
        for output_name in src_op.desc.output_names():
            assert output_name in kwargs, "input [{}] is not given".format(
2579 2580
                output_name
            )
2581 2582 2583
            assert len(kwargs[output_name]) == len(
                src_op.desc.output(output_name)
            ), "number of tensor for input [{}] is not match".format(
2584 2585
                output_name
            )
2586

Z
zhaoyingli 已提交
2587
        X_var = main_block._var_recursive(kwargs['X'][0])
2588
        Weight_var = main_block._var_recursive(kwargs['Y'][0])
Z
zhaoyingli 已提交
2589
        Out_var = main_block._var_recursive(kwargs['Out'][0])
2590 2591 2592

        # TODO infer logic comm presentation
        matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping(
2593 2594 2595 2596 2597 2598 2599
            Weight_var.name
        )[-1]
        assert (
            matmul_col_dim_mapping >= 0
        ), "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
            matmul_col_dim_mapping
        )
2600 2601
        process_mesh_shape = op_dist_attr.process_mesh.shape
        process_mesh_group = op_dist_attr.process_mesh.process_ids
2602 2603

        parallel_axis = matmul_col_dim_mapping
2604 2605 2606
        group_ranks = _get_comm_group(
            process_mesh_group, process_mesh_shape, parallel_axis, rank_id
        )
2607 2608 2609 2610 2611 2612 2613
        group = new_process_group(group_ranks)

        # infer new var shape with op dist attr
        x_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(X_var)
        assert x_tensor_dist_attr is not None
        identity_var_dist_attr = op_dist_attr.get_input_dist_attr(X_var.name)
        assert identity_var_dist_attr is not None
2614 2615 2616
        ref_shape_x = infer_shape(
            main_block, X_var, x_tensor_dist_attr, identity_var_dist_attr
        )
2617 2618 2619 2620 2621
        # infer out var shape with op dist attr
        out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var)
        assert out_tensor_dist_attr is not None
        out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert out_var_dist_attr is not None
2622 2623 2624
        ref_shape_out = infer_shape(
            main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr
        )
2625 2626

        intermediate_var_0 = main_block.create_var(
2627 2628 2629
            name=unique_name.generate_with_ignorable_key(
                ".".join(["c_identity", 'tmp'])
            ),
2630 2631 2632 2633
            dtype=X_var.dtype,
            shape=X_var.shape,
            type=core.VarDesc.VarType.LOD_TENSOR,
            persistable=False,
2634 2635
            stop_gradient=X_var.stop_gradient,
        )
2636
        # set intermediate_var_0's dist_attr with X_var's dist_attr
2637 2638 2639
        ctx.set_tensor_dist_attr_for_program(
            intermediate_var_0, identity_var_dist_attr
        )
2640 2641

        check_variable_and_dtype(
2642 2643
            X_var,
            'tensor',
X
xu98bin 已提交
2644
            ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
2645 2646
            '_c_identity',
        )
2647 2648 2649 2650 2651 2652 2653 2654
        c_identity_op = main_block.append_op(
            type='c_identity',
            inputs={'X': [X_var]},
            outputs={'Out': intermediate_var_0},
            attrs={
                'ring_id': group.id,
                'use_calc_stream': True,
                'use_model_parallel': True,
2655 2656 2657
                OP_ROLE_KEY: src_op.attr('op_role'),
            },
        )
2658 2659 2660
        if intermediate_var_0.shape != ref_shape_x:
            intermediate_var_0.desc.set_shape(ref_shape_x)

2661
        check_variable_and_dtype(
X
xu98bin 已提交
2662 2663 2664 2665
            intermediate_var_0,
            'x',
            ['float16', 'float32', 'float64', 'uint16'],
            'linear',
2666 2667 2668 2669
        )
        check_dtype(
            intermediate_var_0.dtype,
            'dtype',
X
xu98bin 已提交
2670
            ['float16', 'float32', 'float64', 'uint16'],
2671 2672
            'linear',
        )
2673 2674 2675
        # attrs = {'trans_x': False, 'trans_y': False}
        attrs = {
            "x_num_col_dims": src_op.desc.attr("x_num_col_dims"),
2676
            "y_num_col_dims": src_op.desc.attr("y_num_col_dims"),
2677
            OP_ROLE_KEY: src_op.attr('op_role'),
2678
        }
2679 2680 2681 2682 2683 2684 2685 2686 2687 2688 2689 2690
        inputs = {'X': intermediate_var_0, 'Y': Weight_var}

        inputs_ref_shape = {}
        inputs_original_shape = {}
        for var_name in inputs:
            if var_name == "X":
                var = X_var
            else:
                var = inputs[var_name]
            inputs_original_shape[var_name] = var.shape
            input_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(var)
            input_var_dist_attr = op_dist_attr.get_input_dist_attr(var.name)
2691 2692 2693
            input_ref_shape = infer_shape(
                main_block, var, input_tensor_dist_attr, input_var_dist_attr
            )
2694 2695 2696
            inputs_ref_shape[var_name] = input_ref_shape
            var.desc.set_shape(input_ref_shape)

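        # Each rank multiplies the identity output with its column shard of
        # the weight, producing its own slice of Out; no cross-rank reduction
        # is needed in this forward pass.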
        mul_op = main_block.append_op(
            type='mul', inputs=inputs, outputs={'Out': Out_var}, attrs=attrs
        )
        if Out_var.shape != ref_shape_out:
            Out_var.desc.set_shape(ref_shape_out)

        for var_name in inputs:
            var = inputs[var_name]
            original_shape = inputs_original_shape[var_name]
            var.desc.set_shape(original_shape)

        # set dist op's dist_attr with serial op's dist_attr
        # c_identity
        identity_op_dist_attr = OperatorDistAttr()
        identity_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        identity_op_dist_attr.impl_type = op_dist_attr.impl_type
        identity_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        # input
        input_varname = c_identity_op.desc.input_arg_names()[0]
        input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
        assert input_dist_attr is not None, "dist_attr is {}".format(
            op_dist_attr
        )
        identity_op_dist_attr.set_input_dist_attr(
            input_varname, input_dist_attr
        )
        # output
        output_varname = c_identity_op.desc.output_arg_names()[0]
        identity_op_dist_attr.set_output_dist_attr(
            output_varname, input_dist_attr
        )
        ctx.set_op_dist_attr_for_program(c_identity_op, identity_op_dist_attr)

        # mul (dist attr for the appended mul op)
        matmulv2_op_dist_attr = OperatorDistAttr()
        matmulv2_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        matmulv2_op_dist_attr.impl_type = op_dist_attr.impl_type
        matmulv2_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        for input_varname in mul_op.desc.input_arg_names():
            if input_varname in src_op.desc.input_arg_names():
                input_dist_attr = op_dist_attr.get_input_dist_attr(
                    input_varname
                )
                assert input_dist_attr is not None, "dist_attr is {}".format(
                    op_dist_attr
                )
                matmulv2_op_dist_attr.set_input_dist_attr(
                    input_varname, input_dist_attr
                )
            else:
                input_var = main_block._var_recursive(input_varname)
                tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(
                    input_var
                )
                matmulv2_op_dist_attr.set_input_dist_attr(
                    input_varname, tensor_dist_attr
                )
        for output_varname in mul_op.desc.output_arg_names():
            output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
            assert output_dist_attr is not None, "dist_attr is {}".format(
                op_dist_attr
            )
            matmulv2_op_dist_attr.set_output_dist_attr(
                output_varname, output_dist_attr
            )
        ctx.set_op_dist_attr_for_program(mul_op, matmulv2_op_dist_attr)

        # init param sync
        if Weight_var.is_parameter and not op_dist_attr.is_recompute:
            _init_param_sync(
                Weight_var, dist_op_context, startup_block, ctx, rank_id
            )

    @staticmethod
    def backward(ctx, *args, **kwargs):
        _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)


# RowParallel
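# The weight is sharded along its row (reduction) axis and X along its last
# axis, so every rank computes a partial X * Y and the partial results are
# summed with c_allreduce_sum in the forward pass.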
class DistributedMulImpl1(DistributedOperatorImpl):
    def __init__(self, name):
        super().__init__(name)
        self._forward_implemented = True
        self._backward_implemented = True

    def calc_cost(self, op_role, dist_op, ctx, cluster):
        cost = None
        if int(op_role) == int(OpRole.Forward):
            cost = self.calc_fwd_cost(dist_op, ctx, cluster)
        elif int(op_role) == int(OpRole.Backward):
            cost = self.calc_bwd_cost(dist_op, ctx, cluster)
        assert cost is not None
        return cost

    def calc_bwd_cost(self, dist_op, ctx, cluster):
        # by now the backward function only inserts the gradient allreduce for the dist op itself
        res = []
        backward_op = dist_op.serial_op
        dist_attr = dist_op.dist_attr
        process_mesh = dist_attr.process_mesh
        main_block = backward_op.block
        Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("Y")[0]
        )
        assert Y_var_dim_mapping[1] < 0
        parallel_axis = Y_var_dim_mapping[0]
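        # Y is sharded along its row axis, so the backward cost is modeled as
        # a c_identity on Out@GRAD (communication) plus mul_grad (computation),
        # with an optional data-parallel allreduce on Y@GRAD added below.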

        # calc comm op cost
        var_names = [backward_op.input("Out@GRAD")[0]]
        attrs = {"use_calc_stream": True, "use_model_parallel": True}
        c_identity_desc_mapping = build_comm_desc_from_dist_op(
            "c_identity",
            dist_op,
            ctx,
            var_names,
            attrs=attrs,
            parallel_axis=parallel_axis,
        )
        processes = process_mesh.process_ids
        comm_op_cost_list = build_comm_costs_from_descs(
            IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster
        )
        res.append(comm_op_cost_list)

        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        cost_mapping = build_comp_costs_from_descs(
            MulGradOpCost, ctx, processes, desc_mapping, cluster
        )
        res.append(cost_mapping)

        # need gradient allreduce
        var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("X")[0]
        )
        mesh_shape = process_mesh.shape
        batch_size_axis = var_dim_mapping[0]
        if (
            batch_size_axis > -1
            and mesh_shape[batch_size_axis] > 1
            and is_parameter_related(backward_op.input("Y")[0], main_block)
        ):
            parallel_axis = batch_size_axis
            attrs = {"use_calc_stream": True}
            var_names = [backward_op.output('Y@GRAD')[0]]
            build_dp_costs(
                res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
            )
        return res

    def calc_fwd_cost(self, dist_op, ctx, cluster):
        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        processes = dist_op.dist_attr.process_mesh.process_ids
        cost_mapping = build_comp_costs_from_descs(
            MulOpCost, ctx, processes, desc_mapping, cluster
        )

        # calc comm op cost
        serial_op = dist_op.serial_op
        parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
            serial_op.input("Y")[0]
        )[-2]
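        # the mesh axis that shards Y's rows also defines the group for the
        # c_allreduce_sum on the forward output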
        attrs = {"use_calc_stream": True, "use_model_parallel": True}

        var_names = serial_op.output("Out")
        c_allreduce_sum_desc_mapping = build_comm_desc_from_dist_op(
            "c_allreduce_sum",
            dist_op,
            ctx,
            var_names,
            attrs=attrs,
            parallel_axis=parallel_axis,
        )

        comm_op_cost_list = build_comm_costs_from_descs(
            AllreduceSumOpCost,
            ctx,
            processes,
            c_allreduce_sum_desc_mapping,
            cluster,
        )

        res_cost = [cost_mapping, comm_op_cost_list]

        return res_cost

    def is_input_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        x_name = op_desc.input('X')[0]
        y_name = op_desc.input('Y')[0]
        x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
        y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
        if is_dim_replicate(x_dims_mapping[-1]):
            return False
        if is_dim_replicate(y_dims_mapping[-2]) or is_dim_shard(
            y_dims_mapping[-1]
        ):
            return False
        # Other dimensions must be replicated except the batch dimension
        for mapping in x_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_output_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
        if is_dim_shard(out_dims_mapping[-1]):
            return False
        # Other dimensions must be replicated except the batch dimension
        for mapping in out_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_auto_compatible(self, dist_op):
        if (not self.is_input_compatible(dist_op)) or (
            not self.is_output_compatible(dist_op)
        ):
            return False

        if not _is_auto_compatible_for_matmul(dist_op):
            return False

        return True

    def update_dims_mapping(self, dist_op):
        changed = False
        dim_changed = _update_dims_mapping_for_matmul(dist_op)
        if dim_changed:
            changed = True
        return changed

    @staticmethod
    def forward(ctx, *args, **kwargs):
        """
        kwargs: inputname_mapping & outputname_mapping
        """

        dist_op_context = ctx.dist_op_context
        main_block = dist_op_context.work_block
        startup_block = dist_op_context.startup_block
        src_op = dist_op_context.cur_src_op
        rank_id = dist_op_context.rank_id
        op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
        assert (
            op_dist_attr is not None
        ), "op [{}] doesn't have dist attribute!".format(str(src_op))

        # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
        if rank_id not in op_dist_attr.process_mesh.process_ids:
            rank_id = _get_corresponding_rank(
                ctx, op_dist_attr.process_mesh, rank_id
            )

        # check validation of inputs / outputs
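        # kwargs maps each serial op slot to the dist var names, e.g.
        # (hypothetical names): {'X': ['x_0'], 'Y': ['fc_w_0'], 'Out': ['out_0']}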
        for input_name in src_op.desc.input_names():
            assert input_name in kwargs, "input [{}] is not given".format(
                input_name
            )
            assert len(kwargs[input_name]) == len(
                src_op.desc.input(input_name)
            ), "number of tensors for input [{}] does not match".format(input_name)
        for output_name in src_op.desc.output_names():
            assert output_name in kwargs, "output [{}] is not given".format(
                output_name
            )
            assert len(kwargs[output_name]) == len(
                src_op.desc.output(output_name)
            ), "number of tensors for output [{}] does not match".format(
                output_name
            )

        X_var = main_block._var_recursive(kwargs['X'][0])
        Weight_var = main_block._var_recursive(kwargs['Y'][0])
        Out_var = main_block._var_recursive(kwargs['Out'][0])

        # TODO: infer the logical communication representation
        matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
            Weight_var.name
        )[-2]
        assert (
            matmul_row_dim_mapping >= 0
        ), "row_parallel_matmul's row should be sharded along a specific mesh axis, but got [{}]".format(
            matmul_row_dim_mapping
        )
        process_mesh_shape = op_dist_attr.process_mesh.shape
        process_mesh_group = op_dist_attr.process_mesh.process_ids

        parallel_axis = matmul_row_dim_mapping
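        # All ranks that differ only along `parallel_axis` of the process mesh
        # form one model-parallel group. Illustrative example (not taken from
        # this program): with a [2, 2] mesh over ranks [0, 1, 2, 3] and
        # parallel_axis 1, rank 0 would join the group [0, 1].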
        group_ranks = _get_comm_group(
            process_mesh_group, process_mesh_shape, parallel_axis, rank_id
        )
        group = new_process_group(group_ranks)

        check_variable_and_dtype(
            X_var, 'x', ['float16', 'float32', 'float64', 'uint16'], 'linear'
        )
        check_dtype(
            X_var.dtype,
            'dtype',
            ['float16', 'float32', 'float64', 'uint16'],
            'linear',
        )
        # the serial 'mul' op carries x_num_col_dims / y_num_col_dims instead
        # of matmul's trans_x / trans_y
        attrs = {
            "x_num_col_dims": src_op.desc.attr("x_num_col_dims"),
            "y_num_col_dims": src_op.desc.attr("y_num_col_dims"),
            OP_ROLE_KEY: src_op.attr('op_role'),
        }
        inputs = {'X': X_var, 'Y': Weight_var}

        # infer out var shape with op dist attr
        out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var)
        assert out_tensor_dist_attr is not None
        out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert out_var_dist_attr is not None
        ref_shape = infer_shape(
            main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr
        )

        intermediate_var_0 = main_block.create_var(
            name=unique_name.generate_with_ignorable_key(
                ".".join(["c_allreduce_sum", 'tmp'])
            ),
            shape=Out_var.shape,
            dtype=Out_var.dtype,
            type=Out_var.type,
            lod_level=Out_var.lod_level,
            persistable=False,
            is_data=False,
            need_check_feed=Out_var.desc.need_check_feed(),
        )
        # set intermediate_var_0's dist_attr with Out_var's dist_attr
        ctx.set_tensor_dist_attr_for_program(
            intermediate_var_0, out_var_dist_attr
        )

        inputs_ref_shape = {}
        inputs_original_shape = {}
        for var_name in inputs:
            var = inputs[var_name]
            inputs_original_shape[var_name] = var.shape
            input_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(var)
            input_var_dist_attr = op_dist_attr.get_input_dist_attr(var.name)
            input_ref_shape = infer_shape(
                main_block, var, input_tensor_dist_attr, input_var_dist_attr
            )
            inputs_ref_shape[var_name] = input_ref_shape
            var.desc.set_shape(input_ref_shape)

        mul_op = main_block.append_op(
            type='mul',
            inputs=inputs,
            outputs={'Out': intermediate_var_0},
            attrs=attrs,
        )

        if intermediate_var_0.shape != ref_shape:
            intermediate_var_0.desc.set_shape(ref_shape)

        for var_name in inputs:
            var = inputs[var_name]
            original_shape = inputs_original_shape[var_name]
            var.desc.set_shape(original_shape)

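        # Each rank now holds a partial sum of X * Y (the reduction axis is
        # sharded), so a c_allreduce_sum over the model-parallel group yields
        # the complete Out.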
        c_allreduce_sum_op = main_block.append_op(
            type='c_allreduce_sum',
            inputs={'X': intermediate_var_0},
            outputs={'Out': Out_var},
            attrs={
                'ring_id': group.id,
                'use_calc_stream': True,
                'use_model_parallel': True,
                OP_ROLE_KEY: src_op.attr('op_role'),
            },
        )

        if Out_var.shape != ref_shape:
            Out_var.desc.set_shape(ref_shape)

        # set dist op's dist_attr with serial op's dist_attr
        # mul (dist attr for the appended mul op)
        matmulv2_op_dist_attr = OperatorDistAttr()
        matmulv2_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        matmulv2_op_dist_attr.impl_type = op_dist_attr.impl_type
        matmulv2_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        for input_varname in mul_op.desc.input_arg_names():
            input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
            assert input_dist_attr is not None, "dist_attr is {}".format(
                op_dist_attr
            )
            matmulv2_op_dist_attr.set_input_dist_attr(
                input_varname, input_dist_attr
            )
        output_varname = mul_op.desc.output_arg_names()[0]
        output_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert output_dist_attr is not None, "dist_attr is {}".format(
            op_dist_attr
        )
        matmulv2_op_dist_attr.set_output_dist_attr(
            output_varname, output_dist_attr
        )
        ctx.set_op_dist_attr_for_program(mul_op, matmulv2_op_dist_attr)

        # allreduce
        allreduce_op_dist_attr = OperatorDistAttr()
        allreduce_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        allreduce_op_dist_attr.impl_type = op_dist_attr.impl_type
        allreduce_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        for input_varname in c_allreduce_sum_op.desc.input_arg_names():
            input_var = main_block._var_recursive(input_varname)
            tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(input_var)
            assert tensor_dist_attr is not None
            allreduce_op_dist_attr.set_input_dist_attr(
                input_varname, tensor_dist_attr
            )
        for output_varname in c_allreduce_sum_op.desc.output_arg_names():
            output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
            assert output_dist_attr is not None, "dist_attr is {}".format(
                op_dist_attr
            )
            allreduce_op_dist_attr.set_output_dist_attr(
                output_varname, output_dist_attr
            )
        ctx.set_op_dist_attr_for_program(
            c_allreduce_sum_op, allreduce_op_dist_attr
        )

        # init param sync
        if Weight_var.is_parameter and not op_dist_attr.is_recompute:
            _init_param_sync(
                Weight_var, dist_op_context, startup_block, ctx, rank_id
            )

    @staticmethod
    def backward(ctx, *args, **kwargs):
        _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)


# ReplicateParallel
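# Neither input is sharded along a contraction-relevant axis, so the serial
# 'mul' can run unchanged on every rank and forward simply delegates to
# DistributedDefaultImpl0.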
class DistributedMulImpl2(DistributedOperatorImpl):
    def __init__(self, name):
        super().__init__(name)

    def calc_cost(self, op_role, dist_op, ctx, cluster):
        cost = None
        if int(op_role) == int(OpRole.Forward):
            cost = self.calc_fwd_cost(dist_op, ctx, cluster)
        elif int(op_role) == int(OpRole.Backward):
            cost = self.calc_bwd_cost(dist_op, ctx, cluster)
        assert cost is not None
        return cost

    def calc_bwd_cost(self, dist_op, ctx, cluster):
        res = []
        backward_op = dist_op.serial_op
        dist_attr = dist_op.dist_attr
        main_block = backward_op.block

        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        process_mesh = dist_attr.process_mesh
        processes = process_mesh.process_ids
        cost_mapping = build_comp_costs_from_descs(
            MulGradOpCost, ctx, processes, desc_mapping, cluster
        )
        res.append(cost_mapping)

        # need gradient allreduce
        var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("X")[0]
        )
        mesh_shape = process_mesh.shape
        batch_size_axis = var_dim_mapping[0]
        if (
            batch_size_axis > -1
            and mesh_shape[batch_size_axis] > 1
            and is_parameter_related(backward_op.input("Y")[0], main_block)
        ):
            parallel_axis = batch_size_axis
            attrs = {"use_calc_stream": True}
            var_names = [backward_op.output('Y@GRAD')[0]]
            build_dp_costs(
                res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
            )

        return res

    def calc_fwd_cost(self, dist_op, ctx, cluster):
        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        processes = dist_op.dist_attr.process_mesh.process_ids
        cost_mapping = build_comp_costs_from_descs(
            MulOpCost, ctx, processes, desc_mapping, cluster
        )

        res_cost = [cost_mapping]
        return res_cost

    def is_input_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        x_name = op_desc.input('X')[0]
        y_name = op_desc.input('Y')[0]
        x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
        y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)

        if is_dim_shard(x_dims_mapping[-1]):
            return False
        if is_valid_list_index(x_dims_mapping, -2) and is_dim_shard(
            x_dims_mapping[-2]
        ):
            return False
        if is_dim_shard(y_dims_mapping[-1]):
            return False
        if is_valid_list_index(y_dims_mapping, -2) and is_dim_shard(
            y_dims_mapping[-2]
        ):
            return False
        return True

    def is_output_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)

        if is_dim_shard(out_dims_mapping[-1]):
            return False
        if is_valid_list_index(out_dims_mapping, -2) and is_dim_shard(
            out_dims_mapping[-2]
        ):
            return False

        return True

    def is_auto_compatible(self, dist_op):
        if (not self.is_input_compatible(dist_op)) or (
            not self.is_output_compatible(dist_op)
        ):
            return False

        if not _is_auto_compatible_for_matmul(dist_op):
            return False

        return True

    def update_dims_mapping(self, dist_op):
        changed = False
        dim_changed = _update_dims_mapping_for_matmul(dist_op)
        if dim_changed:
            changed = True
        return changed

    @staticmethod
    def forward(ctx, *args, **kwargs):
        DistributedDefaultImpl0.forward(ctx, *args, **kwargs)

    @staticmethod
    def backward(ctx, *args, **kwargs):
        _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)


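# Register the three implementations for the serial 'mul' op: column-parallel,
# row-parallel and fully replicated.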
register_distributed_operator_impl(
    "mul", DistributedMulImpl0("column_parallel")
)
register_distributed_operator_impl("mul", DistributedMulImpl1("row_parallel"))
register_distributed_operator_impl(
    "mul", DistributedMulImpl2("replicate_parallel")
)