# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License

import copy

from paddle.distributed.auto_parallel.cost.comm_op_cost import (
    AllreduceSumOpCost,
    IdentityOpCost,
)
from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole
from paddle.fluid import core, unique_name
from paddle.fluid.data_feeder import check_dtype, check_variable_and_dtype

from ..cost import (
    MatmulGradOpCost,
    MatmulOpCost,
    MatmulV2GradOpCost,
    MatmulV2OpCost,
    MulGradOpCost,
    MulOpCost,
    build_comm_costs_from_descs,
    build_comm_desc_from_dist_op,
    build_comp_costs_from_descs,
    build_comp_desc_from_dist_op,
    build_dp_costs,
)
from ..dist_attribute import OperatorDistributedAttribute
from ..process_group import new_process_group
from ..utils import (
    _get_comm_group,
    _get_corresponding_rank,
    compute_compatible_and_update_dim_mapping,
    compute_compatible_dims_mapping,
    is_dim_replicate,
    is_dim_shard,
    is_valid_list_index,
    set_dist_op_desc_original_id,
)
from .common import (
    DistributedOperatorImpl,
    DistributedOperatorImplContainer,
    gradient_synchronization,
    infer_shape,
    is_parameter_related,
    register_distributed_operator_impl,
    register_distributed_operator_impl_container,
    set_comm_op_dist_attr_for_program,
)
from .dist_default import DistributedDefaultImpl0


def trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping):
    if trans_x:
        x_dims_mapping[-1], x_dims_mapping[-2] = (
            x_dims_mapping[-2],
            x_dims_mapping[-1],
        )
    if trans_y:
        y_dims_mapping[-1], y_dims_mapping[-2] = (
            y_dims_mapping[-2],
            y_dims_mapping[-1],
        )
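
# Illustrative example: trans_x_y_dims_mapping swaps the last two entries of a
# dims_mapping in place when the matching transpose flag is set, e.g. with
# trans_x=True, x_dims_mapping [0, -1, 1] becomes [0, 1, -1]; with
# trans_x=False it is left unchanged.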


def copy_op_with_new_input_output(ctx, block, src_op, **kwargs):
    dist_op_desc = block.append_op(type='nop').desc
    dist_op_desc.copy_from(src_op.desc)
    set_dist_op_desc_original_id(dist_op_desc, src_op.desc, ctx)
    for input_name in src_op.desc.input_names():
        assert input_name in kwargs
        dist_op_desc.set_input(input_name, kwargs[input_name])
    for output_name in src_op.desc.output_names():
        assert output_name in kwargs
        dist_op_desc.set_output(output_name, kwargs[output_name])

    return dist_op_desc
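
# Illustrative sketch (hypothetical variable names): copy_op_with_new_input_output
# clones src_op into the given block and rebinds every input/output slot from
# kwargs, which must contain one entry per slot name, e.g.
#   new_kwargs = {'X': ['x'], 'Y': ['w'], 'Out@GRAD': ['out@GRAD'],
#                 'X@GRAD': ['x@GRAD'], 'Y@GRAD': ['w@GRAD']}
#   copy_op_with_new_input_output(ctx, main_block, backward_op, **new_kwargs)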


def _update_dims_mapping_for_matmul(dist_op):
    changed = False
    op_desc = dist_op.serial_op.desc
    op_dist_attr = dist_op.dist_attr
    x_name = op_desc.input('X')[0]
    y_name = op_desc.input('Y')[0]
    out_name = op_desc.output('Out')[0]
    trans_x = None
    trans_y = None
    if op_desc.type() == "matmul_v2":
        trans_x = op_desc.attr('trans_x')
        trans_y = op_desc.attr('trans_y')
    elif op_desc.type() == "matmul":
        trans_x = op_desc.attr('transpose_X')
        trans_y = op_desc.attr('transpose_Y')
    x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
    y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
    out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
    x_dims_mapping_len = len(x_dims_mapping)
    y_dims_mapping_len = len(y_dims_mapping)
    out_dims_mapping_len = len(out_dims_mapping)

    # Add a dim mapping to make sure the length of dims_mapping is at least 2
    if x_dims_mapping_len == 1:
        assert trans_x is False
        x_dims_mapping.insert(0, -1)
        out_dims_mapping.insert(out_dims_mapping_len - 1, 0)
    if y_dims_mapping_len == 1:
        assert trans_y is False
        y_dims_mapping.insert(1, -1)
        out_dims_mapping.insert(out_dims_mapping_len, 0)

    trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping)

    new_x_dims_mapping_len = len(x_dims_mapping)
    new_y_dims_mapping_len = len(y_dims_mapping)
    new_out_dims_mapping_len = len(out_dims_mapping)
    # Deal with dim > 2 and take care of broadcasting
    if new_out_dims_mapping_len > 2:
        broadcast_x_dims_mapping = []
        broadcast_y_dims_mapping = []
        broadcast_out_dims_mapping = []

        for i in range(new_out_dims_mapping_len - new_x_dims_mapping_len):
            broadcast_x_dims_mapping.append(out_dims_mapping[i])
        for i in range(new_x_dims_mapping_len - 2):
            broadcast_x_dims_mapping.append(x_dims_mapping[i])

        for i in range(new_out_dims_mapping_len - new_y_dims_mapping_len):
            broadcast_y_dims_mapping.append(out_dims_mapping[i])
        for i in range(new_y_dims_mapping_len - 2):
            broadcast_y_dims_mapping.append(y_dims_mapping[i])

        for i in range(new_out_dims_mapping_len - 2):
            broadcast_out_dims_mapping.append(out_dims_mapping[i])

        compatible_dims_mapping = compute_compatible_dims_mapping(
            [
                broadcast_x_dims_mapping,
                broadcast_y_dims_mapping,
                broadcast_out_dims_mapping,
            ]
        )
        if compatible_dims_mapping is None:
            trans_x_y_dims_mapping(
                trans_x, trans_y, x_dims_mapping, y_dims_mapping
            )
            return False

        for i in range(new_x_dims_mapping_len - 2):
            new_idx = i + (out_dims_mapping_len - new_x_dims_mapping_len)
            if x_dims_mapping[i] != compatible_dims_mapping[new_idx]:
                x_dims_mapping[i] = compatible_dims_mapping[new_idx]
                changed = True

        for i in range(new_y_dims_mapping_len - 2):
            new_idx = i + (out_dims_mapping_len - new_y_dims_mapping_len)
            if y_dims_mapping[i] != compatible_dims_mapping[new_idx]:
                y_dims_mapping[i] = compatible_dims_mapping[new_idx]
                changed = True

        for i in range(new_out_dims_mapping_len - 2):
            if out_dims_mapping[i] != compatible_dims_mapping[i]:
                out_dims_mapping[i] = compatible_dims_mapping[i]
                changed = True

    # The following, which uses negative indices, works both when
    # len(out_dims_mapping) > 2 and when len(out_dims_mapping) <= 2
    dim_changed = compute_compatible_and_update_dim_mapping(
        [x_dims_mapping, y_dims_mapping], [-1, -2]
    )
    if dim_changed:
        changed = True

    dim_changed = compute_compatible_and_update_dim_mapping(
        [x_dims_mapping, out_dims_mapping], [-2, -2]
    )
    if dim_changed:
        changed = True

    dim_changed = compute_compatible_and_update_dim_mapping(
        [y_dims_mapping, out_dims_mapping], [-1, -1]
    )
    if dim_changed:
        changed = True

    trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping)

    # Remove the unnecessary dim mappings so that the length of dims_mapping matches its tensor's rank
    if x_dims_mapping_len == 1:
        x_dims_mapping.pop(0)
        out_dims_mapping.pop(out_dims_mapping_len - 1)
    if y_dims_mapping_len == 1:
        y_dims_mapping.pop(1)
        out_dims_mapping.pop(out_dims_mapping_len)

    assert len(x_dims_mapping) == x_dims_mapping_len
    assert len(y_dims_mapping) == y_dims_mapping_len
    assert len(out_dims_mapping) == out_dims_mapping_len

    return changed
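
# Illustrative worked example of the update above, assuming x is rank 3, y is
# rank 2 and out is rank 3 with
#   x_dims_mapping   = [0, -1, -1]
#   y_dims_mapping   = [-1, 1]
#   out_dims_mapping = [-1, -1, 1]
# The broadcast alignment propagates the batch mapping 0 from x into
# out_dims_mapping[0], and the negative-index updates keep the contracted and
# kept dims consistent, so out_dims_mapping ends up as [0, -1, 1] and the
# function returns True.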


def _is_auto_compatible_for_matmul(dist_op):
    op_desc = dist_op.serial_op.desc
    op_dist_attr = dist_op.dist_attr
    x_name = op_desc.input('X')[0]
    y_name = op_desc.input('Y')[0]
    out_name = op_desc.output('Out')[0]
    trans_x = None
    trans_y = None
    if op_desc.type() == "matmul_v2":
        trans_x = op_desc.attr('trans_x')
        trans_y = op_desc.attr('trans_y')
    elif op_desc.type() == "matmul":
        trans_x = op_desc.attr('transpose_X')
        trans_y = op_desc.attr('transpose_Y')

    # Deep copy these dims_mappings to keep the originals unchanged.
    x_dims_mapping = copy.deepcopy(op_dist_attr.get_input_dims_mapping(x_name))
    y_dims_mapping = copy.deepcopy(op_dist_attr.get_input_dims_mapping(y_name))
    out_dims_mapping = copy.deepcopy(
        op_dist_attr.get_output_dims_mapping(out_name)
    )
    x_dims_mapping_len = len(x_dims_mapping)
    y_dims_mapping_len = len(y_dims_mapping)
    out_dims_mapping_len = len(out_dims_mapping)

    # Add a dim mapping to make sure the length of dims_mapping is at least 2
    if x_dims_mapping_len == 1:
        x_dims_mapping.insert(0, -1)
    if y_dims_mapping_len == 1:
        y_dims_mapping.insert(1, -1)

    trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping)

    # Deal with dim > 2 and take care of broadcasting
    if out_dims_mapping_len > 2:
        broadcast_x_dims_mapping = []
        broadcast_y_dims_mapping = []
        broadcast_out_dims_mapping = []

        for i in range(out_dims_mapping_len - x_dims_mapping_len):
            broadcast_x_dims_mapping.append(out_dims_mapping[i])
        for i in range(x_dims_mapping_len - 2):
            broadcast_x_dims_mapping.append(x_dims_mapping[i])

        for i in range(out_dims_mapping_len - y_dims_mapping_len):
            broadcast_y_dims_mapping.append(out_dims_mapping[i])
        for i in range(y_dims_mapping_len - 2):
            broadcast_y_dims_mapping.append(y_dims_mapping[i])

        for i in range(out_dims_mapping_len - 2):
            broadcast_out_dims_mapping.append(out_dims_mapping[i])

        is_same = (broadcast_x_dims_mapping == broadcast_y_dims_mapping) and (
            broadcast_x_dims_mapping == broadcast_out_dims_mapping
        )
        if not is_same:
            return False

    # The following, which uses negative indices, works both when
    # len(out_dims_mapping) > 2 and when len(out_dims_mapping) <= 2
    is_same = x_dims_mapping[-1] == y_dims_mapping[-2]
    if not is_same:
        return False

    is_same = x_dims_mapping[-2] == out_dims_mapping[-2]
    if not is_same:
        return False

    is_same = y_dims_mapping[-1] == out_dims_mapping[-1]
    if not is_same:
        return False

    return True
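
# Illustrative example: for a plain 2-D matmul (no transpose), the shardings
#   x: [-1, -1], y: [-1, 0], out: [-1, 0]
# are auto-compatible (x[-1] == y[-2], x[-2] == out[-2], y[-1] == out[-1]),
# while x: [-1, 0], y: [-1, 1], out: [-1, 1] is rejected because the
# contracted dims x[-1] and y[-2] carry different mappings.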


def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):

    # by now the backward function only inserts the gradient allreduce for the dist op itself

    dist_op_context = ctx.dist_op_context
    main_block = dist_op_context.work_block
    backward_op = dist_op_context.cur_src_op
    rank_id = dist_op_context.rank_id
    dist_attr = ctx.get_op_dist_attr_for_program(backward_op)
    assert (
        dist_attr is not None
    ), "backward op [{}] don't have dist attribute !".format(str(backward_op))

    # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
    if rank_id not in dist_attr.process_mesh.process_ids:
        rank_id = _get_corresponding_rank(ctx, dist_attr.process_mesh, rank_id)

    assert 'Y' in kwargs, "input [{}] is not given".format('Y')
    assert 'X' in kwargs, "input [{}] is not given".format('X')
    assert 'Out@GRAD' in kwargs, "input [{}] is not given".format('Out@GRAD')
    assert 'Y@GRAD' in kwargs, "output [{}] is not given".format('Y@GRAD')
    assert 'X@GRAD' in kwargs, "output [{}] is not given".format('X@GRAD')
    assert (
        len(kwargs['Y']) == 1
    ), "dist matmul input Y takes 1 variable but got {}".format(kwargs['Y'])
    assert (
        len(kwargs['X']) == 1
    ), "dist matmul input X takes 1 variable but got {}".format(kwargs['X'])
    assert (
        len(kwargs['Out@GRAD']) == 1
    ), "dist matmul input Out@GRAD takes 1 variable but got {}".format(
        kwargs['Out@GRAD']
    )
    assert (
        len(kwargs['Y@GRAD']) == 1
    ), "dist matmul output Y@GRAD takes 1 variable but got {}".format(
        kwargs['Y@GRAD']
    )

    X_var = main_block._var_recursive(kwargs['X'][0])
    Y_var = main_block._var_recursive(kwargs['Y'][0])
    Out_grad = main_block._var_recursive(kwargs['Out@GRAD'][0])
    Y_grad = main_block._var_recursive(kwargs['Y@GRAD'][0])

    assert not is_parameter_related(
        X_var.name, main_block
    ), "left operand(X) [{}] of dist matmul should not be parameter".format(
        X_var.name
    )

    X_var_dims_mapping = dist_attr.get_input_dims_mapping(X_var.name)
    Y_var_dim_mapping = dist_attr.get_input_dims_mapping(Y_var.name)
    process_mesh_shape = dist_attr.process_mesh.shape
    process_mesh_group = dist_attr.process_mesh.process_ids

    trans_x = None
    trans_y = None
    if backward_op.desc.type() == "matmul_v2_grad":
        trans_x = backward_op.desc.attr('trans_x')
        trans_y = backward_op.desc.attr('trans_y')
    elif backward_op.desc.type() == "matmul_grad":
        trans_x = backward_op.desc.attr('transpose_X')
        trans_y = backward_op.desc.attr('transpose_Y')

    if trans_y:
        trans_x_y_dims_mapping(False, True, None, Y_var_dim_mapping)

    # assert len(
    #     Y_var_dim_mapping
    #     ) == 2, "dist matmul only supports Y operand with 2 dims now but Y({})'s dim is [{}]".format(
    #     Y_var.name, Y_var_dim_mapping)
    Y_var_partitioned = False
    for dim in Y_var_dim_mapping:
        if dim >= 0 and process_mesh_shape[dim] > 0:
            Y_var_partitioned = True
            break

    if is_parameter_related(Y_var.name, main_block) and Y_var_partitioned:

        if Y_var_dim_mapping[0] >= 0:
            # row parallel: c_identity + matmul
            assert Y_var_dim_mapping[1] < 0
            parallel_axis = Y_var_dim_mapping[0]

            check_variable_and_dtype(
                Out_grad,
                'tensor',
                ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
                '_c_identity',
            )

            intermediate_var_0 = main_block.create_var(
                name=unique_name.generate_with_ignorable_key(
                    ".".join(["c_identity", 'tmp'])
                )
                + "@GRAD",
                dtype=Out_grad.dtype,
                shape=Out_grad.shape,
                type=core.VarDesc.VarType.LOD_TENSOR,
                persistable=False,
                stop_gradient=Out_grad.stop_gradient,
            )

            # copy Out_grad's dist_attr to intermediate_var_0's dist_attr
            out_grad_dist_attr = dist_attr.get_input_dist_attr(Out_grad.name)
            assert out_grad_dist_attr is not None
            ctx.set_tensor_dist_attr_for_program(
                intermediate_var_0, out_grad_dist_attr
            )

            group_ranks = _get_comm_group(
                process_mesh_group, process_mesh_shape, parallel_axis, rank_id
            )
            group = new_process_group(group_ranks)
            c_identity_op = main_block.append_op(
                type='c_identity',
                inputs={'X': [Out_grad]},
                outputs={'Out': intermediate_var_0},
                attrs={
                    'ring_id': group.id,
                    'use_calc_stream': True,
                    'use_model_parallel': True,
                    OP_ROLE_KEY: OpRole.Backward,
                },
            )
            check_variable_and_dtype(
                intermediate_var_0,
                'x',
                ['float16', 'float32', 'float64', 'uint16'],
                'linear',
            )
            check_dtype(
                intermediate_var_0.dtype,
                'dtype',
                ['float16', 'float32', 'float64', 'uint16'],
                'linear',
            )
            set_comm_op_dist_attr_for_program(
                c_identity_op, dist_attr.process_mesh, out_grad_dist_attr, ctx
            )

            new_kwargs = copy.deepcopy(kwargs)
            new_kwargs['Out@GRAD'] = [intermediate_var_0.name]
            matmul_op_desc = copy_op_with_new_input_output(
                ctx, main_block, backward_op, **new_kwargs
            )
        else:
            # col parallel: matmul + allreduce
            assert Y_var_dim_mapping[0] < 0
            parallel_axis = Y_var_dim_mapping[1]
            new_kwargs = copy.deepcopy(kwargs)

            # NOTE (JZ-LIANG) should allow left operand be empty for matmul grad
            has_x_grad = len(kwargs['X@GRAD']) > 0
            if has_x_grad:
                assert len(kwargs['X@GRAD']) == 1
                X_grad = main_block._var_recursive(kwargs['X@GRAD'][0])
                intermediate_var_0 = main_block.create_var(
                    name=unique_name.generate_with_ignorable_key(
                        ".".join(["c_identity", 'tmp'])
                    )
                    + "@GRAD",
                    dtype=X_grad.dtype,
                    shape=X_grad.shape,
                    type=core.VarDesc.VarType.LOD_TENSOR,
                    persistable=False,
                    stop_gradient=X_grad.stop_gradient,
                )

                X_grad_dist_attr = dist_attr.get_output_dist_attr(X_grad.name)
                assert X_grad_dist_attr is not None
                ctx.set_tensor_dist_attr_for_program(
                    intermediate_var_0, X_grad_dist_attr
                )
                new_kwargs['X@GRAD'] = [intermediate_var_0.name]

            matmul_op_desc = copy_op_with_new_input_output(
                ctx, main_block, backward_op, **new_kwargs
            )

            # NOTE (JZ-LIANG) trick to skip one allreduce if left operand has not grad
            if has_x_grad:
                group_ranks = _get_comm_group(
                    process_mesh_group,
                    process_mesh_shape,
                    parallel_axis,
                    rank_id,
                )
                group = new_process_group(group_ranks)
                c_allreduce_sum_op = main_block.append_op(
                    type='c_allreduce_sum',
                    inputs={'X': [intermediate_var_0.name]},
                    outputs={'Out': kwargs['X@GRAD']},
                    attrs={
                        'ring_id': group.id,
                        'use_calc_stream': True,
                        'use_model_parallel': True,
                        OP_ROLE_KEY: OpRole.Backward,
                    },
                )
                set_comm_op_dist_attr_for_program(
                    c_allreduce_sum_op,
                    dist_attr.process_mesh,
                    X_grad_dist_attr,
                    ctx,
                )
    else:
        # replicate
        matmul_op_desc = copy_op_with_new_input_output(
            ctx, main_block, backward_op, **kwargs
        )

    # data parallel gradient synchronization
    act_grad_names = [X_var.name]

    out_grad_names = []
    if is_parameter_related(Y_var.name, main_block):
        out_grad_names = [kwargs['Y@GRAD'][0]]

    if trans_x:
        trans_x_y_dims_mapping(True, False, X_var_dims_mapping, None)

    gradient_synchronization(
        ctx, backward_op, act_grad_names, out_grad_names, rank_id
    )

    if trans_x:
        trans_x_y_dims_mapping(True, False, X_var_dims_mapping, None)
    if trans_y:
        trans_x_y_dims_mapping(False, True, None, Y_var_dim_mapping)
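
    # Illustrative summary of the sharded cases above: when Y is a parameter
    # sharded along its row axis (Y_var_dim_mapping[0] >= 0), Out@GRAD goes
    # through c_identity before the grad matmul (row parallel); when Y is
    # sharded along its column axis, the grad matmul runs first and X@GRAD is
    # reduced with c_allreduce_sum (col parallel); otherwise the op is simply
    # replicated.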


def _init_param_sync(Weight_var, dist_op_context, startup_block, ctx, rank_id):

    if Weight_var.name in dist_op_context.already_init_sync_vars:
        return
    assert startup_block.has_var(Weight_var.name)
    dist_op_context.already_init_sync_vars.add(Weight_var.name)
    param = startup_block.var(Weight_var.name)
    param_dist_attr = ctx.get_tensor_dist_attr_for_program(param)
    process_mesh = param_dist_attr.process_mesh
    dim_mapping = param_dist_attr.dims_mapping

    for axis, size in enumerate(process_mesh.shape):
        if size <= 1 or axis in dim_mapping:
            pass
        else:
            group_ranks = _get_comm_group(
                process_mesh.process_ids, process_mesh.shape, axis, rank_id
            )
            sync_group = new_process_group(group_ranks)

            startup_block.append_op(
                type='c_broadcast',
                inputs={'X': param},
                outputs={'Out': param},
                attrs={
                    'ring_id': sync_group.id,
                    'root': 0,
                    'use_calc_stream': True,
                    OP_ROLE_KEY: OpRole.Forward,
                },
            )
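
    # Illustrative example: with a 2-D process mesh of shape [2, 4] and a
    # weight whose dims_mapping is [-1, 0], mesh axis 0 already shards the
    # parameter and is skipped, so the broadcast above only runs along mesh
    # axis 1 (the replicated axis), with the first rank of that group as root.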


class DistributedMatmul(DistributedOperatorImplContainer):
    def __init__(self, op_type):
        super().__init__(op_type)


register_distributed_operator_impl_container(DistributedMatmul("matmul"))


# ColumnParallel
class DistributedMatmulImpl0(DistributedOperatorImpl):
    def __init__(self, name):
        super().__init__(name)
        self._forward_implemented = True
        self._backward_implemented = True

    def calc_cost(self, op_role, dist_op, ctx, cluster):
        cost = None
        if int(op_role) == int(OpRole.Forward):
            cost = self.calc_fwd_cost(dist_op, ctx, cluster)
        elif int(op_role) == int(OpRole.Backward):
            cost = self.calc_bwd_cost(dist_op, ctx, cluster)
        assert cost is not None
        return cost

    def calc_bwd_cost(self, dist_op, ctx, cluster):
        # by now the backward function only inserts the gradient allreduce for the dist op itself
        res = []
        backward_op = dist_op.serial_op
        dist_attr = dist_op.dist_attr
        main_block = backward_op.block
        Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("Y")[0]
        )
        # col parallel: matmul + allreduce
        assert Y_var_dim_mapping[0] < 0
        parallel_axis = Y_var_dim_mapping[1]

        has_x_grad = len(backward_op.output("X@GRAD")) > 0
        if has_x_grad:
            assert len(backward_op.output("X@GRAD")) == 1

        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        process_mesh = dist_attr.process_mesh
        processes = process_mesh.process_ids
        cost_mapping = build_comp_costs_from_descs(
            MatmulGradOpCost, ctx, processes, desc_mapping, cluster
        )
        res.append(cost_mapping)

        # calc comm op cost
        if has_x_grad:
            attrs = {"use_calc_stream": True, "use_model_parallel": True}
            var_names = backward_op.output("X@GRAD")
            c_allreduce_sum_desc_mapping = build_comm_desc_from_dist_op(
                "c_allreduce_sum",
                dist_op,
                ctx,
                var_names,
                attrs=attrs,
                parallel_axis=parallel_axis,
            )
            comm_op_cost_list = build_comm_costs_from_descs(
                AllreduceSumOpCost,
                ctx,
                processes,
                c_allreduce_sum_desc_mapping,
                cluster,
            )
            res.append(comm_op_cost_list)

        # need gradient allreduce
        var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("X")[0]
        )
        mesh_shape = process_mesh.shape
        batch_size_axis = var_dim_mapping[0]
        if (
            batch_size_axis > -1
            and mesh_shape[batch_size_axis] > 1
            and is_parameter_related(backward_op.input("Y")[0], main_block)
        ):
            parallel_axis = batch_size_axis
            attrs = {"use_calc_stream": True}
            var_names = [backward_op.output('Y@GRAD')[0]]
            build_dp_costs(
                res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
            )
        return res

    def calc_fwd_cost(self, dist_op, ctx, cluster):
        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        processes = dist_op.dist_attr.process_mesh.process_ids
        cost_mapping = build_comp_costs_from_descs(
            MatmulOpCost, ctx, processes, desc_mapping, cluster
        )

        # calc comm op cost
        serial_op = dist_op.serial_op
        parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
            serial_op.input("Y")[0]
        )[-1]
        attrs = {"use_calc_stream": True, "use_model_parallel": True}
        var_names = serial_op.input("X")
        c_identity_desc_mapping = build_comm_desc_from_dist_op(
            "c_identity",
            dist_op,
            ctx,
            var_names,
            attrs=attrs,
            parallel_axis=parallel_axis,
        )

        comm_op_cost_list = build_comm_costs_from_descs(
            IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster
        )
        res_cost = [comm_op_cost_list, cost_mapping]

        return res_cost
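
    # Note (illustrative): for this column-parallel impl the forward cost is
    # modeled as the c_identity communication on X followed by the matmul
    # computation, which is why res_cost is ordered
    # [comm_op_cost_list, cost_mapping].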

    def is_input_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        x_name = op_desc.input('X')[0]
        y_name = op_desc.input('Y')[0]
        x_dims_mapping = copy.deepcopy(
            op_dist_attr.get_input_dims_mapping(x_name)
        )
        y_dims_mapping = copy.deepcopy(
            op_dist_attr.get_input_dims_mapping(y_name)
        )
        trans_x = op_desc.attr('transpose_X')
        trans_y = op_desc.attr('transpose_Y')
        trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping)
        if is_dim_shard(x_dims_mapping[-1]):
            return False
        if is_dim_shard(y_dims_mapping[-2]) or is_dim_replicate(
            y_dims_mapping[-1]
        ):
            return False
        for mapping in x_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True
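
    # Illustrative example: a column-parallel-compatible sharding for a 2-D
    # matmul is x_dims_mapping = [-1, -1] with y_dims_mapping = [-1, 0]
    # (X replicated along the reduce dim, Y sharded along its columns).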

    def is_output_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
        if is_dim_replicate(out_dims_mapping[-1]):
            return False
        for mapping in out_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_auto_compatible(self, dist_op):
        if (not self.is_input_compatible(dist_op)) or (
            not self.is_output_compatible(dist_op)
        ):
            return False
        if not _is_auto_compatible_for_matmul(dist_op):
            return False
        return True

    def update_dims_mapping(self, dist_op):
        changed = False
        dim_changed = _update_dims_mapping_for_matmul(dist_op)
        if dim_changed:
            changed = True
        return changed

    @staticmethod
    def forward(ctx, *args, **kwargs):
        """
        kwargs: inputname_mapping & outputname_mapping
        """

        dist_op_context = ctx.dist_op_context
        main_block = dist_op_context.work_block
        startup_block = dist_op_context.startup_block
        src_op = dist_op_context.cur_src_op
        rank_id = dist_op_context.rank_id
        op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
        assert (
            op_dist_attr is not None
        ), "backward op [{}] don't have dist attribute !".format(str(src_op))

        # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
        if rank_id not in op_dist_attr.process_mesh.process_ids:
            rank_id = _get_corresponding_rank(
                ctx, op_dist_attr.process_mesh, rank_id
            )

        # check the validity of inputs / outputs
        for input_name in src_op.desc.input_names():
            assert input_name in kwargs, "input [{}] is not given".format(
                input_name
            )
            assert len(kwargs[input_name]) == len(
                src_op.desc.input(input_name)
            ), "number of tensors for input [{}] does not match".format(
                input_name
            )
        for output_name in src_op.desc.output_names():
            assert output_name in kwargs, "output [{}] is not given".format(
                output_name
            )
            assert len(kwargs[output_name]) == len(
                src_op.desc.output(output_name)
            ), "number of tensors for output [{}] does not match".format(
                output_name
            )

        X_var = main_block._var_recursive(kwargs['X'][0])
        Weight_var = main_block._var_recursive(kwargs['Y'][0])
        Out_var = main_block._var_recursive(kwargs['Out'][0])
        trans_x = src_op.attr("transpose_X")
        trans_y = src_op.attr("transpose_Y")

        # TODO infer logic comm presentation
        matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping(
            Weight_var.name
        )[-1]
        if trans_y:
            matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping(
                Weight_var.name
            )[-2]
        assert (
            matmul_col_dim_mapping >= 0
        ), "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
            matmul_col_dim_mapping
        )
        process_mesh_shape = op_dist_attr.process_mesh.shape
        process_mesh_group = op_dist_attr.process_mesh.process_ids

        parallel_axis = matmul_col_dim_mapping
        group_ranks = _get_comm_group(
            process_mesh_group, process_mesh_shape, parallel_axis, rank_id
        )
        group = new_process_group(group_ranks)

        # infer new var shape with op dist attr
        x_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(X_var)
        assert x_tensor_dist_attr is not None
        identity_var_dist_attr = op_dist_attr.get_input_dist_attr(X_var.name)
        assert identity_var_dist_attr is not None
        ref_shape_x = infer_shape(
            main_block, X_var, x_tensor_dist_attr, identity_var_dist_attr
        )
        # infer out var shape with op dist attr
        out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var)
        assert out_tensor_dist_attr is not None
        out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert out_var_dist_attr is not None
        ref_shape_out = infer_shape(
            main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr
        )

        intermediate_var_0 = main_block.create_var(
            name=unique_name.generate_with_ignorable_key(
                ".".join(["c_identity", 'tmp'])
            ),
            dtype=X_var.dtype,
            shape=X_var.shape,
            type=core.VarDesc.VarType.LOD_TENSOR,
            persistable=False,
            stop_gradient=X_var.stop_gradient,
        )
        # set intermediate_var_0's dist_attr with X_var's dist_attr
        ctx.set_tensor_dist_attr_for_program(
            intermediate_var_0, identity_var_dist_attr
        )

        check_variable_and_dtype(
            X_var,
            'tensor',
            ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
            '_c_identity',
        )

        c_identity_op = main_block.append_op(
            type='c_identity',
            inputs={'X': [X_var]},
            outputs={'Out': intermediate_var_0},
            attrs={
                'ring_id': group.id,
                'use_calc_stream': True,
                'use_model_parallel': True,
                OP_ROLE_KEY: src_op.attr('op_role'),
            },
        )
        if intermediate_var_0.shape != ref_shape_x:
            intermediate_var_0.desc.set_shape(ref_shape_x)

        check_variable_and_dtype(
            intermediate_var_0,
            'x',
            ['float16', 'float32', 'float64', 'uint16'],
            'linear',
        )
        check_dtype(
            intermediate_var_0.dtype,
            'dtype',
            ['float16', 'float32', 'float64', 'uint16'],
            'linear',
        )
        attrs = {
            'transpose_X': trans_x,
            'transpose_Y': trans_y,
            'alpha': 1,
            OP_ROLE_KEY: src_op.attr('op_role'),
        }
        inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]}
        matmul_op = main_block.append_op(
            type='matmul', inputs=inputs, outputs={'Out': Out_var}, attrs=attrs
        )
        if Out_var.shape != ref_shape_out:
            Out_var.desc.set_shape(ref_shape_out)

        # set dist op's dist_attr with serial op's dist_attr
        # c_identity
        identity_op_dist_attr = OperatorDistributedAttribute()
        identity_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        identity_op_dist_attr.impl_type = op_dist_attr.impl_type
        identity_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        # input
        input_varname = c_identity_op.desc.input_arg_names()[0]
        input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
        assert input_dist_attr is not None, "dist_attr is {}".format(
            op_dist_attr
        )
        identity_op_dist_attr.set_input_dist_attr(
            input_varname, input_dist_attr
        )
        # output
        output_varname = c_identity_op.desc.output_arg_names()[0]
        identity_op_dist_attr.set_output_dist_attr(
            output_varname, input_dist_attr
        )
        # set op dist attr
        ctx.set_op_dist_attr_for_program(c_identity_op, identity_op_dist_attr)

        # matmul
        matmul_op_dist_attr = OperatorDistributedAttribute()
        matmul_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        matmul_op_dist_attr.impl_type = op_dist_attr.impl_type
        matmul_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        # input
        for input_varname in matmul_op.desc.input_arg_names():
            if input_varname in src_op.desc.input_arg_names():
                input_dist_attr = op_dist_attr.get_input_dist_attr(
                    input_varname
                )
                assert input_dist_attr is not None, "dist_attr is {}".format(
                    op_dist_attr
                )
                matmul_op_dist_attr.set_input_dist_attr(
                    input_varname, input_dist_attr
                )
            else:
                input_var = main_block._var_recursive(input_varname)
                tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(
                    input_var
                )
                matmul_op_dist_attr.set_input_dist_attr(
                    input_varname, tensor_dist_attr
                )
        # output
        output_varname = matmul_op.desc.output_arg_names()[0]
        output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
        assert output_dist_attr is not None, "dist_attr is {}".format(
            op_dist_attr
        )
        matmul_op_dist_attr.set_output_dist_attr(
            output_varname, output_dist_attr
        )
        # set op dist attr
        ctx.set_op_dist_attr_for_program(matmul_op, matmul_op_dist_attr)

        # init param sync
        if Weight_var.is_parameter and not op_dist_attr.is_recompute:
            _init_param_sync(
                Weight_var, dist_op_context, startup_block, ctx, rank_id
            )

    @staticmethod
    def backward(ctx, *args, **kwargs):
        _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)


# RowParallel
class DistributedMatmulImpl1(DistributedOperatorImpl):
    def __init__(self, name):
        super().__init__(name)
        self._forward_implemented = True
        self._backward_implemented = True

    def calc_cost(self, op_role, dist_op, ctx, cluster):
        cost = None
        if int(op_role) == int(OpRole.Forward):
            cost = self.calc_fwd_cost(dist_op, ctx, cluster)
        elif int(op_role) == int(OpRole.Backward):
            cost = self.calc_bwd_cost(dist_op, ctx, cluster)
        assert cost is not None
        return cost

    def calc_bwd_cost(self, dist_op, ctx, cluster):
        # by now the backward function only inserts the gradient allreduce for the dist op itself
        res = []
        backward_op = dist_op.serial_op
        dist_attr = dist_op.dist_attr
        main_block = backward_op.block
        Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("Y")[0]
        )
        assert Y_var_dim_mapping[1] < 0
        parallel_axis = Y_var_dim_mapping[0]

        # calc comm op cost
        var_names = [backward_op.input("Out@GRAD")[0]]
        attrs = {"use_calc_stream": True, "use_model_parallel": True}
        c_identity_desc_mapping = build_comm_desc_from_dist_op(
            "c_identity",
            dist_op,
            ctx,
            var_names,
            attrs=attrs,
            parallel_axis=parallel_axis,
        )
        process_mesh = dist_attr.process_mesh
        processes = process_mesh.process_ids
        comm_op_cost_list = build_comm_costs_from_descs(
            IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster
        )
        res.append(comm_op_cost_list)

        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        cost_mapping = build_comp_costs_from_descs(
            MatmulGradOpCost, ctx, processes, desc_mapping, cluster
        )
        res.append(cost_mapping)

        # need gradient allreduce
        var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("X")[0]
        )
        mesh_shape = process_mesh.shape
        batch_size_axis = var_dim_mapping[0]
        if (
            batch_size_axis > -1
            and mesh_shape[batch_size_axis] > 1
            and is_parameter_related(backward_op.input("Y")[0], main_block)
        ):
            parallel_axis = batch_size_axis
            attrs = {"use_calc_stream": True}
            var_names = [backward_op.output('Y@GRAD')[0]]
            build_dp_costs(
                res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
            )
        return res

    def calc_fwd_cost(self, dist_op, ctx, cluster):
        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        processes = dist_op.dist_attr.process_mesh.process_ids
        cost_mapping = build_comp_costs_from_descs(
            MatmulOpCost, ctx, processes, desc_mapping, cluster
        )

        # calc comm op cost
        serial_op = dist_op.serial_op
        parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
            serial_op.input("Y")[0]
        )[-2]
        attrs = {"use_calc_stream": True, "use_model_parallel": True}

        var_names = serial_op.output("Out")
        c_allreduce_sum_desc_mapping = build_comm_desc_from_dist_op(
            "c_allreduce_sum",
            dist_op,
            ctx,
            var_names,
            attrs=attrs,
            parallel_axis=parallel_axis,
        )

        comm_op_cost_list = build_comm_costs_from_descs(
            AllreduceSumOpCost,
            ctx,
            processes,
            c_allreduce_sum_desc_mapping,
            cluster,
        )

        res_cost = [cost_mapping, comm_op_cost_list]

        return res_cost
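
    # Note (illustrative): for this row-parallel impl the forward cost is
    # modeled as the local matmul computation followed by the c_allreduce_sum
    # on Out, hence res_cost is ordered [cost_mapping, comm_op_cost_list].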

    def is_input_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        x_name = op_desc.input('X')[0]
        y_name = op_desc.input('Y')[0]
        x_dims_mapping = copy.deepcopy(
            op_dist_attr.get_input_dims_mapping(x_name)
        )
        y_dims_mapping = copy.deepcopy(
            op_dist_attr.get_input_dims_mapping(y_name)
        )
        trans_x = op_desc.attr('transpose_X')
        trans_y = op_desc.attr('transpose_Y')
        trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping)
        if is_dim_replicate(x_dims_mapping[-1]):
            return False
        if is_dim_replicate(y_dims_mapping[-2]) or is_dim_shard(
            y_dims_mapping[-1]
        ):
            return False
        # Other dimensions must be replicated, except the batch dimension
        for mapping in x_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True
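
    # Illustrative example: a row-parallel-compatible sharding for a 2-D
    # matmul is x_dims_mapping = [-1, 0] with y_dims_mapping = [0, -1]
    # (X sharded along the reduce dim, Y sharded along its rows).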

    def is_output_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
        if is_dim_shard(out_dims_mapping[-1]):
            return False
        # Other dimensions must be replicated, except the batch dimension
        for mapping in out_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_auto_compatible(self, dist_op):
        if (not self.is_input_compatible(dist_op)) or (
            not self.is_output_compatible(dist_op)
        ):
            return False
        if not _is_auto_compatible_for_matmul(dist_op):
            return False
        return True

    def update_dims_mapping(self, dist_op):
        changed = False
        dim_changed = _update_dims_mapping_for_matmul(dist_op)
        if dim_changed:
            changed = True
        return changed

    @staticmethod
    def forward(ctx, *args, **kwargs):
        """
        kwargs: inputname_mapping & outputname_mapping
        """

        dist_op_context = ctx.dist_op_context
        main_block = dist_op_context.work_block
        startup_block = dist_op_context.startup_block
        src_op = dist_op_context.cur_src_op
        rank_id = dist_op_context.rank_id
        op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
        assert (
            op_dist_attr is not None
        ), "backward op [{}] don't have dist attribute !".format(str(src_op))

        # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
        if rank_id not in op_dist_attr.process_mesh.process_ids:
            rank_id = _get_corresponding_rank(
                ctx, op_dist_attr.process_mesh, rank_id
            )

        # check the validity of inputs / outputs
        for input_name in src_op.desc.input_names():
            assert input_name in kwargs, "input [{}] is not given".format(
                input_name
            )
            assert len(kwargs[input_name]) == len(
                src_op.desc.input(input_name)
            ), "number of tensors for input [{}] does not match".format(
                input_name
            )
        for output_name in src_op.desc.output_names():
            assert output_name in kwargs, "output [{}] is not given".format(
                output_name
            )
            assert len(kwargs[output_name]) == len(
                src_op.desc.output(output_name)
            ), "number of tensors for output [{}] does not match".format(
                output_name
            )

        X_var = main_block._var_recursive(kwargs['X'][0])
        Weight_var = main_block._var_recursive(kwargs['Y'][0])
        Out_var = main_block._var_recursive(kwargs['Out'][0])
        trans_x = src_op.attr('transpose_X')
        trans_y = src_op.attr('transpose_Y')

        # TODO infer logic comm presentation
        matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
            Weight_var.name
        )[-2]
        if trans_y:
            matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
                Weight_var.name
            )[-1]
        assert (
            matmul_row_dim_mapping >= 0
        ), "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
            matmul_row_dim_mapping
        )
        process_mesh_shape = op_dist_attr.process_mesh.shape
        process_mesh_group = op_dist_attr.process_mesh.process_ids

        parallel_axis = matmul_row_dim_mapping
        group_ranks = _get_comm_group(
            process_mesh_group, process_mesh_shape, parallel_axis, rank_id
        )
        group = new_process_group(group_ranks)

        check_variable_and_dtype(
            X_var, 'x', ['float16', 'float32', 'float64', 'uint16'], 'linear'
        )
        check_dtype(
            X_var.dtype,
            'dtype',
            ['float16', 'float32', 'float64', 'uint16'],
            'linear',
        )
        attrs = {
            'transpose_X': trans_x,
            'transpose_Y': trans_y,
            'alpha': 1,
            OP_ROLE_KEY: src_op.attr('op_role'),
        }
        inputs = {'X': X_var, 'Y': Weight_var}

        # infer out var shape with op dist attr
        out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var)
        assert out_tensor_dist_attr is not None
        out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert out_var_dist_attr is not None
        ref_shape = infer_shape(
            main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr
        )

        intermediate_var_0 = main_block.create_var(
            name=unique_name.generate_with_ignorable_key(
                ".".join(["c_allreduce_sum", 'tmp'])
            ),
            shape=Out_var.shape,
            dtype=Out_var.dtype,
            type=Out_var.type,
            lod_level=Out_var.lod_level,
            persistable=False,
            is_data=False,
            need_check_feed=Out_var.desc.need_check_feed(),
        )
        # set intermediate_var_0's dist_attr with Out_var's dist_attr
        ctx.set_tensor_dist_attr_for_program(
            intermediate_var_0, out_var_dist_attr
        )

        matmul_op = main_block.append_op(
            type='matmul',
            inputs=inputs,
            outputs={'Out': intermediate_var_0},
            attrs=attrs,
        )
        if intermediate_var_0.shape != ref_shape:
            intermediate_var_0.desc.set_shape(ref_shape)

        c_allreduce_sum_op = main_block.append_op(
            type='c_allreduce_sum',
            inputs={'X': intermediate_var_0},
            outputs={'Out': Out_var},
            attrs={
                'ring_id': group.id,
                'use_calc_stream': True,
                'use_model_parallel': True,
                OP_ROLE_KEY: src_op.attr('op_role'),
            },
        )
        if Out_var.shape != ref_shape:
            Out_var.desc.set_shape(ref_shape)

        # set dist op's dist_attr with serial op's dist_attr
        # matmul
        matmul_op_dist_attr = OperatorDistributedAttribute()
        matmul_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        matmul_op_dist_attr.impl_type = op_dist_attr.impl_type
        matmul_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        for input_varname in matmul_op.desc.input_arg_names():
            input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
            assert input_dist_attr is not None, "dist_attr is {}".format(
                op_dist_attr
            )
            matmul_op_dist_attr.set_input_dist_attr(
                input_varname, input_dist_attr
            )
        output_varname = matmul_op.desc.output_arg_names()[0]
        output_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert output_dist_attr is not None, "dist_attr is {}".format(
            op_dist_attr
        )
        matmul_op_dist_attr.set_output_dist_attr(
            output_varname, output_dist_attr
        )
        ctx.set_op_dist_attr_for_program(matmul_op, matmul_op_dist_attr)

        # allreduce
        allreduce_op_dist_attr = OperatorDistributedAttribute()
        allreduce_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        allreduce_op_dist_attr.impl_type = op_dist_attr.impl_type
        allreduce_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        for input_varname in c_allreduce_sum_op.desc.input_arg_names():
            input_var = main_block._var_recursive(input_varname)
            tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(input_var)
            assert tensor_dist_attr is not None
            allreduce_op_dist_attr.set_input_dist_attr(
                input_varname, tensor_dist_attr
            )
        for output_varname in c_allreduce_sum_op.desc.output_arg_names():
            output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
            assert output_dist_attr is not None, "dist_attr is {}".format(
                op_dist_attr
            )
            allreduce_op_dist_attr.set_output_dist_attr(
                output_varname, output_dist_attr
            )
        ctx.set_op_dist_attr_for_program(
            c_allreduce_sum_op, allreduce_op_dist_attr
        )

        # init param sync
        if Weight_var.is_parameter and not op_dist_attr.is_recompute:
            _init_param_sync(
                Weight_var, dist_op_context, startup_block, ctx, rank_id
            )

    @staticmethod
    def backward(ctx, *args, **kwargs):
        _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)


# ReplicateParallel
class DistributedMatmulImpl2(DistributedOperatorImpl):
    def __init__(self, name):
        super().__init__(name)

    def calc_cost(self, op_role, dist_op, ctx, cluster):
        cost = None
        if int(op_role) == int(OpRole.Forward):
            cost = self.calc_fwd_cost(dist_op, ctx, cluster)
        elif int(op_role) == int(OpRole.Backward):
            cost = self.calc_bwd_cost(dist_op, ctx, cluster)
        assert cost is not None
        return cost

    def calc_bwd_cost(self, dist_op, ctx, cluster):
        res = []
        backward_op = dist_op.serial_op
        dist_attr = dist_op.dist_attr
        main_block = backward_op.block

        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        process_mesh = dist_attr.process_mesh
        processes = process_mesh.process_ids
        cost_mapping = build_comp_costs_from_descs(
            MatmulGradOpCost, ctx, processes, desc_mapping, cluster
        )
C
caozhou 已提交
1342 1343 1344 1345
        res.append(cost_mapping)

        # need gradient allreduce
        var_dim_mapping = dist_attr.get_input_dims_mapping(
1346 1347
            backward_op.input("X")[0]
        )
1348
        mesh_shape = process_mesh.shape
C
caozhou 已提交
1349
        batch_size_axis = var_dim_mapping[0]
1350 1351 1352 1353 1354
        if (
            batch_size_axis > -1
            and mesh_shape[batch_size_axis] > 1
            and is_parameter_related(backward_op.input("Y")[0], main_block)
        ):
C
caozhou 已提交
1355 1356 1357
            parallel_axis = batch_size_axis
            attrs = {"use_calc_stream": True}
            var_names = [backward_op.output('Y@GRAD')[0]]
1358 1359 1360
            build_dp_costs(
                res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
            )
C
caozhou 已提交
1361 1362 1363 1364 1365

        return res

    def calc_fwd_cost(self, dist_op, ctx, cluster):
        # calc comp op cost
1366 1367 1368
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
1369
        processes = dist_op.dist_attr.process_mesh.process_ids
1370 1371 1372
        cost_mapping = build_comp_costs_from_descs(
            MatmulOpCost, ctx, processes, desc_mapping, cluster
        )
C
caozhou 已提交
1373 1374 1375 1376

        res_cost = [cost_mapping]
        return res_cost

1377 1378 1379
    def is_input_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
1380 1381 1382 1383 1384 1385 1386
        x_name = op_desc.input('X')[0]
        y_name = op_desc.input('Y')[0]
        x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
        y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)

        if is_dim_shard(x_dims_mapping[-1]):
            return False
1387
        if is_valid_list_index(x_dims_mapping, -2) and is_dim_shard(
1388 1389
            x_dims_mapping[-2]
        ):
1390 1391 1392 1393
            return False

        if is_dim_shard(y_dims_mapping[-1]):
            return False
1394
        if is_valid_list_index(y_dims_mapping, -2) and is_dim_shard(
1395 1396
            y_dims_mapping[-2]
        ):
1397 1398 1399 1400
            return False

        return True

1401 1402 1403
    def is_output_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
1404 1405 1406 1407 1408
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)

        if is_dim_shard(out_dims_mapping[-1]):
            return False
1409
        if is_valid_list_index(out_dims_mapping, -2) and is_dim_shard(
1410 1411
            out_dims_mapping[-2]
        ):
1412 1413 1414 1415
            return False

        return True

1416
    def is_auto_compatible(self, dist_op):
1417 1418 1419
        if (not self.is_input_compatible(dist_op)) or (
            not self.is_output_compatible(dist_op)
        ):
1420 1421
            return False

1422
        if not _is_auto_compatible_for_matmul(dist_op):
1423 1424 1425 1426
            return False

        return True

1427
    def update_dims_mapping(self, dist_op):
1428
        changed = False
1429
        dim_changed = _update_dims_mapping_for_matmul(dist_op)
1430 1431 1432 1433
        if dim_changed:
            changed = True
        return changed

1434 1435 1436 1437
    @staticmethod
    def forward(ctx, *args, **kwargs):
        DistributedDefaultImpl0.forward(ctx, *args, **kwargs)

1438 1439 1440 1441
    @staticmethod
    def backward(ctx, *args, **kwargs):
        _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)

1442

1443 1444 1445 1446 1447 1448 1449 1450 1451
register_distributed_operator_impl(
    "matmul", DistributedMatmulImpl0("column_parallel")
)
register_distributed_operator_impl(
    "matmul", DistributedMatmulImpl1("row_parallel")
)
register_distributed_operator_impl(
    "matmul", DistributedMatmulImpl2("replicate_parallel")
)
1452 1453


1454
class DistributedMatmulV2(DistributedOperatorImplContainer):
1455
    def __init__(self, op_type):
1456
        super().__init__(op_type)
1457 1458


1459
register_distributed_operator_impl_container(DistributedMatmulV2("matmul_v2"))
1460 1461


1462 1463 1464
# ColumnParallel
class DistributedMatmulV2Impl0(DistributedOperatorImpl):
    def __init__(self, name):
1465
        super().__init__(name)
1466
        self._forward_implemented = True
1467
        self._backward_implemented = True
1468

1469 1470 1471 1472 1473 1474 1475 1476 1477 1478 1479 1480 1481 1482 1483 1484
    def calc_cost(self, op_role, dist_op, ctx, cluster):
        cost = None
        if int(op_role) == int(OpRole.Forward):
            cost = self.calc_fwd_cost(dist_op, ctx, cluster)
        elif int(op_role) == int(OpRole.Backward):
            cost = self.calc_bwd_cost(dist_op, ctx, cluster)
        assert cost is not None
        return cost

    def calc_bwd_cost(self, dist_op, ctx, cluster):
        # by now the backward function only insert the gradient allreduce for dist op itself
        res = []
        backward_op = dist_op.serial_op
        dist_attr = dist_op.dist_attr
        main_block = backward_op.block
        Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
1485 1486
            backward_op.input("Y")[0]
        )
1487
        process_mesh = dist_attr.process_mesh
1488
        processes = process_mesh.process_ids
1489
        # col parallel: matmul + allreduce
1490 1491
        if backward_op.attr("trans_y"):
            Y_var_dim_mapping.reverse()
1492 1493 1494 1495 1496 1497 1498 1499
        assert Y_var_dim_mapping[0] < 0
        parallel_axis = Y_var_dim_mapping[1]

        has_x_grad = len(backward_op.output("X@GRAD")) > 0
        if has_x_grad:
            assert len(backward_op.output("X@GRAD")) == 1

        # calc comp op cost
1500 1501 1502
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
1503

1504 1505 1506
        cost_mapping = build_comp_costs_from_descs(
            MatmulV2GradOpCost, ctx, processes, desc_mapping, cluster
        )
1507 1508 1509 1510 1511 1512 1513 1514 1515 1516 1517 1518
        res.append(cost_mapping)

        # calc comm op cost
        if has_x_grad:
            attrs = {"use_calc_stream": True, "use_model_parallel": True}
            var_names = backward_op.output("X@GRAD")
            c_allreduce_sum_desc_mapping = build_comm_desc_from_dist_op(
                "c_allreduce_sum",
                dist_op,
                ctx,
                var_names,
                attrs=attrs,
1519 1520
                parallel_axis=parallel_axis,
            )
1521
            comm_op_cost_list = build_comm_costs_from_descs(
1522 1523 1524 1525 1526 1527
                AllreduceSumOpCost,
                ctx,
                processes,
                c_allreduce_sum_desc_mapping,
                cluster,
            )
1528 1529 1530 1531 1532
            res.append(comm_op_cost_list)

        # need gradient allreduce
        process_mesh = dist_attr.process_mesh
        var_dim_mapping = dist_attr.get_input_dims_mapping(
1533 1534
            backward_op.input("X")[0]
        )
1535
        mesh_shape = process_mesh.shape
1536
        batch_size_axis = var_dim_mapping[0]
1537 1538 1539 1540 1541
        if (
            batch_size_axis > -1
            and mesh_shape[batch_size_axis] > 1
            and is_parameter_related(backward_op.input("Y")[0], main_block)
        ):
1542 1543 1544
            parallel_axis = batch_size_axis
            attrs = {"use_calc_stream": True}
            var_names = [backward_op.output('Y@GRAD')[0]]
1545 1546 1547
            build_dp_costs(
                res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
            )
1548 1549 1550 1551 1552
        return res

    def calc_fwd_cost(self, dist_op, ctx, cluster):
        # calc comp op cost
        # TODO: trans shape if trans_x or trans_y is True
1553 1554 1555
        comp_desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
1556
        processes = dist_op.dist_attr.process_mesh.process_ids
1557 1558 1559
        comp_cost_mapping = build_comp_costs_from_descs(
            MatmulV2OpCost, ctx, processes, comp_desc_mapping, cluster
        )
1560 1561 1562 1563

        # calc comm op cost
        serial_op = dist_op.serial_op
        parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
1564 1565
            serial_op.input("Y")[0]
        )[-1]
1566 1567 1568 1569 1570 1571 1572 1573 1574
        attrs = {"use_calc_stream": True, "use_model_parallel": True}

        var_names = serial_op.input("X")
        c_identity_desc_mapping = build_comm_desc_from_dist_op(
            "c_identity",
            dist_op,
            ctx,
            var_names,
            attrs=attrs,
1575 1576
            parallel_axis=parallel_axis,
        )
1577
        comm_op_cost_list = build_comm_costs_from_descs(
1578 1579
            IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster
        )
1580 1581 1582 1583

        res_cost = [comm_op_cost_list, comp_cost_mapping]
        return res_cost

1584 1585 1586
    def is_input_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
1587 1588
        x_name = op_desc.input('X')[0]
        y_name = op_desc.input('Y')[0]
1589
        x_dims_mapping = copy.deepcopy(
1590 1591
            op_dist_attr.get_input_dims_mapping(x_name)
        )
1592
        y_dims_mapping = copy.deepcopy(
1593 1594
            op_dist_attr.get_input_dims_mapping(y_name)
        )
1595 1596 1597
        trans_x = op_desc.attr('trans_x')
        trans_y = op_desc.attr('trans_y')
        trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping)
1598 1599
        if is_dim_shard(x_dims_mapping[-1]):
            return False
1600
        if is_dim_shard(y_dims_mapping[-2]) or is_dim_replicate(
1601 1602
            y_dims_mapping[-1]
        ):
1603 1604 1605 1606 1607 1608
            return False
        for mapping in x_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

1609 1610 1611
    def is_output_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
1612 1613 1614 1615 1616 1617 1618 1619 1620
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
        if is_dim_replicate(out_dims_mapping[-1]):
            return False
        for mapping in out_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

1621
    def is_auto_compatible(self, dist_op):
1622 1623 1624
        if (not self.is_input_compatible(dist_op)) or (
            not self.is_output_compatible(dist_op)
        ):
1625
            return False
1626
        if not _is_auto_compatible_for_matmul(dist_op):
1627 1628 1629
            return False
        return True

1630
    def update_dims_mapping(self, dist_op):
1631
        changed = False
1632
        dim_changed = _update_dims_mapping_for_matmul(dist_op)
1633 1634 1635 1636
        if dim_changed:
            changed = True
        return changed

1637 1638 1639 1640 1641 1642
    @staticmethod
    def forward(ctx, *args, **kwargs):
        """
        kwargs: inputname_mapping & outputname_mapping
        """

1643
        dist_op_context = ctx.dist_op_context
1644 1645 1646 1647
        main_block = dist_op_context.work_block
        startup_block = dist_op_context.startup_block
        src_op = dist_op_context.cur_src_op
        rank_id = dist_op_context.rank_id
1648
        op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
1649 1650 1651
        assert (
            op_dist_attr is not None
        ), "backward op [{}] don't have dist attribute !".format(str(src_op))
1652 1653

        # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
1654
        if rank_id not in op_dist_attr.process_mesh.process_ids:
1655 1656 1657
            rank_id = _get_corresponding_rank(
                ctx, op_dist_attr.process_mesh, rank_id
            )
1658

1659
        # check validation of inputs / outputs
1660 1661
        for input_name in src_op.desc.input_names():
            assert input_name in kwargs, "input [{}] is not given".format(
1662 1663
                input_name
            )
1664 1665 1666 1667 1668
            assert len(kwargs[input_name]) == len(
                src_op.desc.input(input_name)
            ), "number of tensor for input [{}] is not match".format(input_name)
        for output_name in src_op.desc.output_names():
            assert output_name in kwargs, "input [{}] is not given".format(
1669 1670
                output_name
            )
1671 1672 1673
            assert len(kwargs[output_name]) == len(
                src_op.desc.output(output_name)
            ), "number of tensor for input [{}] is not match".format(
1674 1675
                output_name
            )
1676

Z
zhaoyingli 已提交
1677
        X_var = main_block._var_recursive(kwargs['X'][0])
1678
        Weight_var = main_block._var_recursive(kwargs['Y'][0])
Z
zhaoyingli 已提交
1679
        Out_var = main_block._var_recursive(kwargs['Out'][0])
1680 1681
        trans_x = src_op.attr('trans_x')
        trans_y = src_op.attr('trans_y')
1682 1683 1684

        # TODO infer logic comm presentation
        matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping(
1685 1686
            Weight_var.name
        )[-1]
1687 1688
        if trans_y:
            matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping(
1689 1690 1691 1692 1693 1694 1695
                Weight_var.name
            )[-2]
        assert (
            matmul_col_dim_mapping >= 0
        ), "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
            matmul_col_dim_mapping
        )
1696 1697
        process_mesh_shape = op_dist_attr.process_mesh.shape
        process_mesh_group = op_dist_attr.process_mesh.process_ids
1698 1699

        parallel_axis = matmul_col_dim_mapping
1700 1701 1702
        group_ranks = _get_comm_group(
            process_mesh_group, process_mesh_shape, parallel_axis, rank_id
        )
1703 1704
        group = new_process_group(group_ranks)

Z
zhaoyingli 已提交
1705 1706 1707 1708 1709
        # infer new var shape with op dist attr
        x_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(X_var)
        assert x_tensor_dist_attr is not None
        identity_var_dist_attr = op_dist_attr.get_input_dist_attr(X_var.name)
        assert identity_var_dist_attr is not None
1710 1711 1712
        ref_shape_x = infer_shape(
            main_block, X_var, x_tensor_dist_attr, identity_var_dist_attr
        )
Z
zhaoyingli 已提交
1713 1714 1715 1716 1717
        # infer out var shape with op dist attr
        out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var)
        assert out_tensor_dist_attr is not None
        out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert out_var_dist_attr is not None
1718 1719 1720
        ref_shape_out = infer_shape(
            main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr
        )
Z
zhaoyingli 已提交
1721

1722
        intermediate_var_0 = main_block.create_var(
1723 1724 1725
            name=unique_name.generate_with_ignorable_key(
                ".".join(["c_identity", 'tmp'])
            ),
1726 1727 1728 1729
            dtype=X_var.dtype,
            shape=X_var.shape,
            type=core.VarDesc.VarType.LOD_TENSOR,
            persistable=False,
1730 1731
            stop_gradient=X_var.stop_gradient,
        )
Z
zhaoyingli 已提交
1732
        # set intermediate_var_0's dist_attr with X_var's dist_attr
1733 1734 1735
        ctx.set_tensor_dist_attr_for_program(
            intermediate_var_0, identity_var_dist_attr
        )
1736 1737

        check_variable_and_dtype(
1738 1739
            X_var,
            'tensor',
X
xu98bin 已提交
1740
            ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
1741 1742
            '_c_identity',
        )
1743 1744 1745 1746 1747 1748 1749 1750
        c_identity_op = main_block.append_op(
            type='c_identity',
            inputs={'X': [X_var]},
            outputs={'Out': intermediate_var_0},
            attrs={
                'ring_id': group.id,
                'use_calc_stream': True,
                'use_model_parallel': True,
1751
                OP_ROLE_KEY: src_op.attr('op_role'),
1752 1753
            },
        )
Z
zhaoyingli 已提交
1754 1755
        if intermediate_var_0.shape != ref_shape_x:
            intermediate_var_0.desc.set_shape(ref_shape_x)
1756

1757
        check_variable_and_dtype(
X
xu98bin 已提交
1758 1759 1760 1761
            intermediate_var_0,
            'x',
            ['float16', 'float32', 'float64', 'uint16'],
            'linear',
1762 1763 1764 1765
        )
        check_dtype(
            intermediate_var_0.dtype,
            'dtype',
X
xu98bin 已提交
1766
            ['float16', 'float32', 'float64', 'uint16'],
1767 1768
            'linear',
        )
1769
        attrs = {
1770 1771
            'trans_x': trans_x,
            'trans_y': trans_y,
1772
            OP_ROLE_KEY: src_op.attr('op_role'),
1773
        }
1774
        inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]}
1775 1776 1777 1778 1779 1780
        matmul_v2_op = main_block.append_op(
            type='matmul_v2',
            inputs=inputs,
            outputs={'Out': Out_var},
            attrs=attrs,
        )
Z
zhaoyingli 已提交
1781 1782 1783 1784 1785 1786 1787
        if Out_var.shape != ref_shape_out:
            Out_var.desc.set_shape(ref_shape_out)

        # set dist op's dist_attr with serial op's dist_attr
        # c_identity
        identity_op_dist_attr = OperatorDistributedAttribute()
        identity_op_dist_attr.process_mesh = op_dist_attr.process_mesh
1788
        identity_op_dist_attr.impl_type = op_dist_attr.impl_type
Z
zhaoyingli 已提交
1789 1790 1791 1792 1793
        identity_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        # input
        input_varname = c_identity_op.desc.input_arg_names()[0]
        input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
        assert input_dist_attr is not None, "dist_attr is {}".format(
1794 1795 1796 1797 1798
            op_dist_attr
        )
        identity_op_dist_attr.set_input_dist_attr(
            input_varname, input_dist_attr
        )
Z
zhaoyingli 已提交
1799 1800
        # output
        output_varname = c_identity_op.desc.output_arg_names()[0]
1801 1802 1803
        identity_op_dist_attr.set_output_dist_attr(
            output_varname, input_dist_attr
        )
Z
zhaoyingli 已提交
1804 1805 1806 1807 1808
        ctx.set_op_dist_attr_for_program(c_identity_op, identity_op_dist_attr)

        # matmulv2
        matmulv2_op_dist_attr = OperatorDistributedAttribute()
        matmulv2_op_dist_attr.process_mesh = op_dist_attr.process_mesh
1809
        matmulv2_op_dist_attr.impl_type = op_dist_attr.impl_type
Z
zhaoyingli 已提交
1810 1811 1812 1813
        matmulv2_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        for input_varname in matmul_v2_op.desc.input_arg_names():
            if input_varname in src_op.desc.input_arg_names():
                input_dist_attr = op_dist_attr.get_input_dist_attr(
1814 1815
                    input_varname
                )
Z
zhaoyingli 已提交
1816
                assert input_dist_attr is not None, "dist_attr is {}".format(
1817 1818
                    op_dist_attr
                )
1819
                matmulv2_op_dist_attr.set_input_dist_attr(
1820 1821
                    input_varname, input_dist_attr
                )
Z
zhaoyingli 已提交
1822
            else:
Z
zhaoyingli 已提交
1823
                input_var = main_block._var_recursive(input_varname)
Z
zhaoyingli 已提交
1824
                tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(
1825 1826
                    input_var
                )
1827
                matmulv2_op_dist_attr.set_input_dist_attr(
1828 1829
                    input_varname, tensor_dist_attr
                )
Z
zhaoyingli 已提交
1830 1831 1832
        for output_varname in matmul_v2_op.desc.output_arg_names():
            output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
            assert output_dist_attr is not None, "dist_attr is {}".format(
1833 1834 1835 1836 1837
                op_dist_attr
            )
            matmulv2_op_dist_attr.set_output_dist_attr(
                output_varname, output_dist_attr
            )
Z
zhaoyingli 已提交
1838
        ctx.set_op_dist_attr_for_program(matmul_v2_op, matmulv2_op_dist_attr)
1839 1840

        # init param sync
1841
        if Weight_var.is_parameter and not op_dist_attr.is_recompute:
1842 1843 1844
            _init_param_sync(
                Weight_var, dist_op_context, startup_block, ctx, rank_id
            )
1845 1846 1847 1848

    @staticmethod
    def backward(ctx, *args, **kwargs):
        _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)
1849 1850 1851 1852 1853


# RowParallel
class DistributedMatmulV2Impl1(DistributedOperatorImpl):
    def __init__(self, name):
1854
        super().__init__(name)
1855
        self._forward_implemented = True
1856
        self._backward_implemented = True
1857

1858 1859 1860 1861 1862 1863 1864 1865 1866 1867 1868 1869 1870 1871 1872
    def calc_cost(self, op_role, dist_op, ctx, cluster):
        cost = None
        if int(op_role) == int(OpRole.Forward):
            cost = self.calc_fwd_cost(dist_op, ctx, cluster)
        elif int(op_role) == int(OpRole.Backward):
            cost = self.calc_bwd_cost(dist_op, ctx, cluster)
        assert cost is not None
        return cost

    def calc_bwd_cost(self, dist_op, ctx, cluster):
        # by now the backward function only insert the gradient allreduce for dist op itself
        res = []
        backward_op = dist_op.serial_op
        dist_attr = dist_op.dist_attr
        main_block = backward_op.block
Z
zhaoyingli 已提交
1873

1874
        Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
1875 1876
            backward_op.input("Y")[0]
        )
1877 1878 1879 1880
        assert Y_var_dim_mapping[1] < 0
        parallel_axis = Y_var_dim_mapping[0]

        process_mesh = dist_attr.process_mesh
1881
        processes = process_mesh.process_ids
1882 1883 1884 1885 1886 1887 1888 1889 1890
        # calc comm op cost
        var_names = [backward_op.input("Out@GRAD")[0]]
        attrs = {"use_calc_stream": True, "use_model_parallel": True}
        c_identity_desc_mapping = build_comm_desc_from_dist_op(
            "c_identity",
            dist_op,
            ctx,
            var_names,
            attrs=attrs,
1891 1892
            parallel_axis=parallel_axis,
        )
1893
        comm_op_cost_list = build_comm_costs_from_descs(
1894 1895
            IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster
        )
1896 1897 1898
        res.append(comm_op_cost_list)

        # calc comp op cost
1899 1900 1901 1902 1903 1904
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        cost_mapping = build_comp_costs_from_descs(
            MatmulV2GradOpCost, ctx, processes, desc_mapping, cluster
        )
1905 1906 1907 1908 1909
        res.append(cost_mapping)

        # need gradient allreduce
        process_mesh = dist_attr.process_mesh
        var_dim_mapping = dist_attr.get_input_dims_mapping(
1910 1911
            backward_op.input("X")[0]
        )
1912
        mesh_shape = process_mesh.shape
1913
        batch_size_axis = var_dim_mapping[0]
1914 1915 1916 1917 1918
        if (
            batch_size_axis > -1
            and mesh_shape[batch_size_axis] > 1
            and is_parameter_related(backward_op.input("Y")[0], main_block)
        ):
1919 1920 1921
            parallel_axis = batch_size_axis
            attrs = {"use_calc_stream": True}
            var_names = [backward_op.output('Y@GRAD')[0]]
1922 1923 1924
            build_dp_costs(
                res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
            )
1925 1926 1927 1928
        return res

    def calc_fwd_cost(self, dist_op, ctx, cluster):
        # calc comp op cost
1929 1930 1931
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
1932
        processes = dist_op.dist_attr.process_mesh.process_ids
1933 1934 1935
        cost_mapping = build_comp_costs_from_descs(
            MatmulV2OpCost, ctx, processes, desc_mapping, cluster
        )
1936 1937 1938 1939

        # calc comm op cost
        serial_op = dist_op.serial_op
        parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
1940 1941
            serial_op.input("Y")[0]
        )[-2]
1942 1943 1944 1945 1946 1947 1948 1949 1950
        attrs = {"use_calc_stream": True, "use_model_parallel": True}

        var_names = serial_op.output("Out")
        c_allreduce_sum_desc_mapping = build_comm_desc_from_dist_op(
            "c_allreduce_sum",
            dist_op,
            ctx,
            var_names,
            attrs=attrs,
1951 1952
            parallel_axis=parallel_axis,
        )
1953 1954

        comm_op_cost_list = build_comm_costs_from_descs(
1955 1956 1957 1958 1959 1960
            AllreduceSumOpCost,
            ctx,
            processes,
            c_allreduce_sum_desc_mapping,
            cluster,
        )
1961 1962 1963 1964
        res_cost = [cost_mapping, comm_op_cost_list]

        return res_cost

1965 1966 1967
    def is_input_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
1968 1969
        x_name = op_desc.input('X')[0]
        y_name = op_desc.input('Y')[0]
1970
        x_dims_mapping = copy.deepcopy(
1971 1972
            op_dist_attr.get_input_dims_mapping(x_name)
        )
1973
        y_dims_mapping = copy.deepcopy(
1974 1975
            op_dist_attr.get_input_dims_mapping(y_name)
        )
1976 1977 1978
        trans_x = op_desc.attr('trans_x')
        trans_y = op_desc.attr('trans_y')
        trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping)
1979 1980
        if is_dim_replicate(x_dims_mapping[-1]):
            return False
1981
        if is_dim_replicate(y_dims_mapping[-2]) or is_dim_shard(
1982 1983
            y_dims_mapping[-1]
        ):
1984 1985 1986 1987 1988 1989 1990
            return False
        # Other dimensions must be replicate except the batch dimension
        for mapping in x_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

1991 1992 1993
    def is_output_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
1994 1995 1996 1997 1998 1999 2000 2001 2002 2003
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
        if is_dim_shard(out_dims_mapping[-1]):
            return False
        # Other dimensions must be replicate except the batch dimension
        for mapping in out_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

2004
    def is_auto_compatible(self, dist_op):
2005 2006 2007
        if (not self.is_input_compatible(dist_op)) or (
            not self.is_output_compatible(dist_op)
        ):
2008
            return False
2009
        if not _is_auto_compatible_for_matmul(dist_op):
2010 2011 2012
            return False
        return True

2013
    def update_dims_mapping(self, dist_op):
2014
        changed = False
2015
        dim_changed = _update_dims_mapping_for_matmul(dist_op)
2016 2017 2018 2019
        if dim_changed:
            changed = True
        return changed

2020 2021 2022 2023 2024 2025
    @staticmethod
    def forward(ctx, *args, **kwargs):
        """
        kwargs: inputname_mapping & outputname_mapping
        """

2026
        dist_op_context = ctx.dist_op_context
2027 2028 2029 2030
        main_block = dist_op_context.work_block
        startup_block = dist_op_context.startup_block
        src_op = dist_op_context.cur_src_op
        rank_id = dist_op_context.rank_id
2031
        op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
2032 2033 2034
        assert (
            op_dist_attr is not None
        ), "backward op [{}] don't have dist attribute !".format(str(src_op))
2035 2036

        # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
2037
        if rank_id not in op_dist_attr.process_mesh.process_ids:
2038 2039 2040
            rank_id = _get_corresponding_rank(
                ctx, op_dist_attr.process_mesh, rank_id
            )
2041

2042
        # check validation of inputs / outputs
2043 2044
        for input_name in src_op.desc.input_names():
            assert input_name in kwargs, "input [{}] is not given".format(
2045 2046
                input_name
            )
2047 2048 2049 2050 2051
            assert len(kwargs[input_name]) == len(
                src_op.desc.input(input_name)
            ), "number of tensor for input [{}] is not match".format(input_name)
        for output_name in src_op.desc.output_names():
            assert output_name in kwargs, "input [{}] is not given".format(
2052 2053
                output_name
            )
2054 2055 2056
            assert len(kwargs[output_name]) == len(
                src_op.desc.output(output_name)
            ), "number of tensor for input [{}] is not match".format(
2057 2058
                output_name
            )
2059

Z
zhaoyingli 已提交
2060
        X_var = main_block._var_recursive(kwargs['X'][0])
2061
        Weight_var = main_block._var_recursive(kwargs['Y'][0])
Z
zhaoyingli 已提交
2062
        Out_var = main_block._var_recursive(kwargs['Out'][0])
2063 2064
        trans_x = src_op.attr('trans_x')
        trans_y = src_op.attr('trans_y')
2065 2066 2067

        # TODO infer logic comm presentation
        matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
2068 2069
            Weight_var.name
        )[-2]
2070 2071
        if trans_y:
            matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
2072 2073 2074 2075 2076 2077 2078
                Weight_var.name
            )[-1]
        assert (
            matmul_row_dim_mapping >= 0
        ), "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
            matmul_row_dim_mapping
        )
2079 2080
        process_mesh_shape = op_dist_attr.process_mesh.shape
        process_mesh_group = op_dist_attr.process_mesh.process_ids
2081 2082

        parallel_axis = matmul_row_dim_mapping
2083 2084 2085
        group_ranks = _get_comm_group(
            process_mesh_group, process_mesh_shape, parallel_axis, rank_id
        )
2086 2087
        group = new_process_group(group_ranks)

2088
        check_variable_and_dtype(
X
xu98bin 已提交
2089
            X_var, 'x', ['float16', 'float32', 'float64', 'uint16'], 'linear'
2090 2091
        )
        check_dtype(
X
xu98bin 已提交
2092 2093 2094 2095
            X_var.dtype,
            'dtype',
            ['float16', 'float32', 'float64', 'uint16'],
            'linear',
2096
        )
2097
        attrs = {
2098 2099
            'trans_x': trans_x,
            'trans_y': trans_y,
2100
            OP_ROLE_KEY: src_op.attr('op_role'),
2101
        }
2102
        inputs = {'X': X_var, 'Y': Weight_var}
Z
zhaoyingli 已提交
2103 2104 2105 2106 2107 2108

        # infer out var shape with op dist attr
        out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var)
        assert out_tensor_dist_attr is not None
        out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert out_var_dist_attr is not None
2109 2110 2111
        ref_shape = infer_shape(
            main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr
        )
Z
zhaoyingli 已提交
2112

2113
        intermediate_var_0 = main_block.create_var(
2114 2115 2116
            name=unique_name.generate_with_ignorable_key(
                ".".join(["c_allreduce_sum", 'tmp'])
            ),
2117 2118 2119 2120 2121 2122
            shape=Out_var.shape,
            dtype=Out_var.dtype,
            type=Out_var.type,
            lod_level=Out_var.lod_level,
            persistable=False,
            is_data=False,
2123 2124
            need_check_feed=Out_var.desc.need_check_feed(),
        )
Z
zhaoyingli 已提交
2125
        # set intermediate_var_0's dist_attr with Out_var's dist_attr
2126 2127 2128
        ctx.set_tensor_dist_attr_for_program(
            intermediate_var_0, out_var_dist_attr
        )
2129

2130 2131 2132 2133 2134 2135
        matmul_v2_op = main_block.append_op(
            type='matmul_v2',
            inputs=inputs,
            outputs={'Out': intermediate_var_0},
            attrs=attrs,
        )
Z
zhaoyingli 已提交
2136 2137
        if intermediate_var_0.shape != ref_shape:
            intermediate_var_0.desc.set_shape(ref_shape)
2138 2139 2140 2141 2142 2143 2144 2145

        c_allreduce_sum_op = main_block.append_op(
            type='c_allreduce_sum',
            inputs={'X': intermediate_var_0},
            outputs={'Out': Out_var},
            attrs={
                'ring_id': group.id,
                'use_calc_stream': True,
2146
                'use_model_parallel': True,
2147 2148 2149
                OP_ROLE_KEY: src_op.attr('op_role'),
            },
        )
Z
zhaoyingli 已提交
2150 2151 2152 2153 2154 2155 2156
        if Out_var.shape != ref_shape:
            Out_var.desc.set_shape(ref_shape)

        # set dist op's dist_attr with serial op's dist_attr
        # matmulv2
        matmulv2_op_dist_attr = OperatorDistributedAttribute()
        matmulv2_op_dist_attr.process_mesh = op_dist_attr.process_mesh
2157
        matmulv2_op_dist_attr.impl_type = op_dist_attr.impl_type
Z
zhaoyingli 已提交
2158 2159 2160 2161
        matmulv2_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        for input_varname in matmul_v2_op.desc.input_arg_names():
            input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
            assert input_dist_attr is not None, "dist_attr is {}".format(
2162 2163 2164 2165 2166
                op_dist_attr
            )
            matmulv2_op_dist_attr.set_input_dist_attr(
                input_varname, input_dist_attr
            )
Z
zhaoyingli 已提交
2167 2168 2169
        output_varname = matmul_v2_op.desc.output_arg_names()[0]
        output_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert output_dist_attr is not None, "dist_attr is {}".format(
2170 2171 2172 2173 2174
            op_dist_attr
        )
        matmulv2_op_dist_attr.set_output_dist_attr(
            output_varname, output_dist_attr
        )
Z
zhaoyingli 已提交
2175 2176 2177 2178 2179
        ctx.set_op_dist_attr_for_program(matmul_v2_op, matmulv2_op_dist_attr)

        # allreduce
        allreduce_op_dist_attr = OperatorDistributedAttribute()
        allreduce_op_dist_attr.process_mesh = op_dist_attr.process_mesh
2180
        allreduce_op_dist_attr.impl_type = op_dist_attr.impl_type
Z
zhaoyingli 已提交
2181 2182
        allreduce_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        for input_varname in c_allreduce_sum_op.desc.input_arg_names():
Z
zhaoyingli 已提交
2183
            input_var = main_block._var_recursive(input_varname)
Z
zhaoyingli 已提交
2184 2185
            tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(input_var)
            assert tensor_dist_attr is not None
2186 2187 2188
            allreduce_op_dist_attr.set_input_dist_attr(
                input_varname, tensor_dist_attr
            )
Z
zhaoyingli 已提交
2189 2190 2191
        for output_varname in c_allreduce_sum_op.desc.output_arg_names():
            output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
            assert output_dist_attr is not None, "dist_attr is {}".format(
2192 2193 2194 2195 2196 2197 2198 2199
                op_dist_attr
            )
            allreduce_op_dist_attr.set_output_dist_attr(
                output_varname, output_dist_attr
            )
        ctx.set_op_dist_attr_for_program(
            c_allreduce_sum_op, allreduce_op_dist_attr
        )
2200 2201

        # init param sync
2202
        if Weight_var.is_parameter and not op_dist_attr.is_recompute:
2203 2204 2205
            _init_param_sync(
                Weight_var, dist_op_context, startup_block, ctx, rank_id
            )
2206 2207 2208 2209

    @staticmethod
    def backward(ctx, *args, **kwargs):
        _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)
2210 2211


2212
# ReplicateParallel
2213
class DistributedMatmulV2Impl2(DistributedOperatorImpl):
2214
    def __init__(self, name):
2215
        super().__init__(name)
2216

2217 2218 2219 2220 2221 2222 2223 2224 2225 2226 2227 2228 2229 2230 2231 2232 2233
    def calc_cost(self, op_role, dist_op, ctx, cluster):
        cost = None
        if int(op_role) == int(OpRole.Forward):
            cost = self.calc_fwd_cost(dist_op, ctx, cluster)
        elif int(op_role) == int(OpRole.Backward):
            cost = self.calc_bwd_cost(dist_op, ctx, cluster)
        assert cost is not None
        return cost

    def calc_bwd_cost(self, dist_op, ctx, cluster):
        res = []
        backward_op = dist_op.serial_op
        dist_attr = dist_op.dist_attr
        main_block = backward_op.block
        process_mesh = dist_attr.process_mesh

        # calc comp op cost
2234 2235 2236
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
2237
        processes = process_mesh.process_ids
2238 2239 2240
        cost_mapping = build_comp_costs_from_descs(
            MatmulV2GradOpCost, ctx, processes, desc_mapping, cluster
        )
2241 2242 2243 2244
        res.append(cost_mapping)

        # need gradient allreduce
        var_dim_mapping = dist_attr.get_input_dims_mapping(
2245 2246
            backward_op.input("X")[0]
        )
2247
        mesh_shape = process_mesh.shape
2248
        batch_size_axis = var_dim_mapping[0]
2249 2250 2251 2252 2253
        if (
            batch_size_axis > -1
            and mesh_shape[batch_size_axis] > 1
            and is_parameter_related(backward_op.input("Y")[0], main_block)
        ):
2254 2255 2256
            parallel_axis = batch_size_axis
            attrs = {"use_calc_stream": True}
            var_names = [backward_op.output('Y@GRAD')[0]]
2257 2258 2259
            build_dp_costs(
                res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
            )
2260 2261 2262 2263 2264

        return res

    def calc_fwd_cost(self, dist_op, ctx, cluster):
        # calc comp op cost
2265 2266 2267
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
2268
        processes = dist_op.dist_attr.process_mesh.process_ids
2269 2270 2271
        cost_mapping = build_comp_costs_from_descs(
            MatmulV2OpCost, ctx, processes, desc_mapping, cluster
        )
2272 2273 2274 2275 2276

        res_cost = [cost_mapping]

        return res_cost

2277 2278 2279
    def is_input_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
2280 2281 2282 2283 2284 2285 2286
        x_name = op_desc.input('X')[0]
        y_name = op_desc.input('Y')[0]
        x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
        y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)

        if is_dim_shard(x_dims_mapping[-1]):
            return False
2287
        if is_valid_list_index(x_dims_mapping, -2) and is_dim_shard(
2288 2289
            x_dims_mapping[-2]
        ):
2290 2291 2292 2293
            return False

        if is_dim_shard(y_dims_mapping[-1]):
            return False
2294
        if is_valid_list_index(y_dims_mapping, -2) and is_dim_shard(
2295 2296
            y_dims_mapping[-2]
        ):
2297 2298 2299
            return False
        return True

2300 2301 2302 2303 2304
    def is_output_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
2305 2306 2307 2308 2309
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)

        if is_dim_shard(out_dims_mapping[-1]):
            return False
2310
        if is_valid_list_index(out_dims_mapping, -2) and is_dim_shard(
2311 2312
            out_dims_mapping[-2]
        ):
2313 2314 2315 2316
            return False

        return True

2317
    def is_auto_compatible(self, dist_op):
2318 2319 2320
        if (not self.is_input_compatible(dist_op)) or (
            not self.is_output_compatible(dist_op)
        ):
2321 2322
            return False

2323
        if not _is_auto_compatible_for_matmul(dist_op):
2324 2325 2326 2327
            return False

        return True

2328
    def update_dims_mapping(self, dist_op):
2329
        changed = False
2330
        dim_changed = _update_dims_mapping_for_matmul(dist_op)
2331 2332 2333 2334
        if dim_changed:
            changed = True
        return changed

2335 2336 2337 2338
    @staticmethod
    def forward(ctx, *args, **kwargs):
        DistributedDefaultImpl0.forward(ctx, *args, **kwargs)

2339 2340 2341 2342
    @staticmethod
    def backward(ctx, *args, **kwargs):
        _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)

2343 2344

register_distributed_operator_impl(
2345 2346 2347 2348 2349 2350 2351 2352
    "matmul_v2", DistributedMatmulV2Impl0("column_parallel")
)
register_distributed_operator_impl(
    "matmul_v2", DistributedMatmulV2Impl1("row_parallel")
)
register_distributed_operator_impl(
    "matmul_v2", DistributedMatmulV2Impl2("replicate_parallel")
)
2353 2354 2355 2356


class DistributedMul(DistributedOperatorImplContainer):
    def __init__(self, op_type):
2357
        super().__init__(op_type)
2358 2359 2360 2361 2362 2363 2364 2365


register_distributed_operator_impl_container(DistributedMul("mul"))


# ColumnParallel
class DistributedMulImpl0(DistributedOperatorImpl):
    def __init__(self, name):
2366
        super().__init__(name)
2367 2368 2369
        self._forward_implemented = True
        self._backward_implemented = True

2370 2371 2372 2373 2374 2375 2376 2377 2378 2379 2380 2381 2382 2383 2384 2385
    def calc_cost(self, op_role, dist_op, ctx, cluster):
        cost = None
        if int(op_role) == int(OpRole.Forward):
            cost = self.calc_fwd_cost(dist_op, ctx, cluster)
        elif int(op_role) == int(OpRole.Backward):
            cost = self.calc_bwd_cost(dist_op, ctx, cluster)
        assert cost is not None
        return cost

    def calc_bwd_cost(self, dist_op, ctx, cluster):
        # by now the backward function only insert the gradient allreduce for dist op itself
        res = []
        backward_op = dist_op.serial_op
        dist_attr = dist_op.dist_attr
        main_block = backward_op.block
        Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
2386 2387
            backward_op.input("Y")[0]
        )
2388 2389 2390 2391 2392 2393 2394 2395 2396
        # col parallel: matmul + allreduce
        assert Y_var_dim_mapping[0] < 0
        parallel_axis = Y_var_dim_mapping[1]

        has_x_grad = len(backward_op.output("X@GRAD")) > 0
        if has_x_grad:
            assert len(backward_op.output("X@GRAD")) == 1

        # calc comp op cost
2397 2398 2399
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
2400
        process_mesh = dist_attr.process_mesh
2401
        processes = process_mesh.process_ids
2402 2403 2404
        cost_mapping = build_comp_costs_from_descs(
            MulGradOpCost, ctx, processes, desc_mapping, cluster
        )
2405 2406 2407 2408 2409 2410 2411 2412 2413 2414 2415 2416
        res.append(cost_mapping)

        # calc comm op cost
        if has_x_grad:
            attrs = {"use_calc_stream": True, "use_model_parallel": True}
            var_names = backward_op.output("X@GRAD")
            c_allreduce_sum_desc_mapping = build_comm_desc_from_dist_op(
                "c_allreduce_sum",
                dist_op,
                ctx,
                var_names,
                attrs=attrs,
2417 2418
                parallel_axis=parallel_axis,
            )
2419
            comm_op_cost_list = build_comm_costs_from_descs(
2420 2421 2422 2423 2424 2425
                AllreduceSumOpCost,
                ctx,
                processes,
                c_allreduce_sum_desc_mapping,
                cluster,
            )
2426 2427 2428 2429
            res.append(comm_op_cost_list)

        # need gradient allreduce
        var_dim_mapping = dist_attr.get_input_dims_mapping(
2430 2431
            backward_op.input("X")[0]
        )
2432
        mesh_shape = process_mesh.shape
2433
        batch_size_axis = var_dim_mapping[0]
2434 2435 2436 2437 2438
        if (
            batch_size_axis > -1
            and mesh_shape[batch_size_axis] > 1
            and is_parameter_related(backward_op.input("Y")[0], main_block)
        ):
2439 2440 2441
            parallel_axis = batch_size_axis
            attrs = {"use_calc_stream": True}
            var_names = [backward_op.output('Y@GRAD')[0]]
2442 2443 2444
            build_dp_costs(
                res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
            )
2445 2446 2447 2448
        return res

    def calc_fwd_cost(self, dist_op, ctx, cluster):
        # calc comp op cost
2449 2450 2451
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
2452
        processes = dist_op.dist_attr.process_mesh.process_ids
2453 2454 2455
        cost_mapping = build_comp_costs_from_descs(
            MulOpCost, ctx, processes, desc_mapping, cluster
        )
2456 2457 2458 2459

        # calc comm op cost
        serial_op = dist_op.serial_op
        parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
2460 2461
            serial_op.input("Y")[0]
        )[-1]
2462 2463 2464 2465 2466 2467 2468 2469
        attrs = {"use_calc_stream": True, "use_model_parallel": True}
        var_names = serial_op.input("X")
        c_identity_desc_mapping = build_comm_desc_from_dist_op(
            "c_identity",
            dist_op,
            ctx,
            var_names,
            attrs=attrs,
2470 2471
            parallel_axis=parallel_axis,
        )
2472 2473

        comm_op_cost_list = build_comm_costs_from_descs(
2474 2475
            IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster
        )
2476 2477 2478 2479
        res_cost = [comm_op_cost_list, cost_mapping]

        return res_cost

2480 2481 2482 2483 2484 2485 2486 2487 2488
    def is_input_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        x_name = op_desc.input('X')[0]
        y_name = op_desc.input('Y')[0]
        x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
        y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
        if is_dim_shard(x_dims_mapping[-1]):
            return False
2489
        if is_dim_shard(y_dims_mapping[-2]) or is_dim_replicate(
2490 2491
            y_dims_mapping[-1]
        ):
2492 2493 2494 2495 2496 2497 2498 2499 2500 2501 2502 2503 2504 2505 2506 2507 2508 2509 2510
            return False
        for mapping in x_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_output_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
        if is_dim_replicate(out_dims_mapping[-1]):
            return False
        for mapping in out_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_auto_compatible(self, dist_op):
2511 2512 2513
        if (not self.is_input_compatible(dist_op)) or (
            not self.is_output_compatible(dist_op)
        ):
2514 2515 2516 2517 2518 2519 2520 2521 2522 2523 2524 2525 2526 2527 2528 2529 2530 2531 2532 2533 2534 2535 2536 2537 2538 2539
            return False

        if not _is_auto_compatible_for_matmul(dist_op):
            return False

        return True

    def update_dims_mapping(self, dist_op):
        changed = False
        dim_changed = _update_dims_mapping_for_matmul(dist_op)
        if dim_changed:
            changed = True
        return changed

    @staticmethod
    def forward(ctx, *args, **kwargs):
        """
        kwargs: inputname_mapping & outputname_mapping
        """

        dist_op_context = ctx.dist_op_context
        main_block = dist_op_context.work_block
        startup_block = dist_op_context.startup_block
        src_op = dist_op_context.cur_src_op
        rank_id = dist_op_context.rank_id
        op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
2540 2541 2542
        assert (
            op_dist_attr is not None
        ), "backward op [{}] don't have dist attribute !".format(str(src_op))
2543 2544

        # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
2545
        if rank_id not in op_dist_attr.process_mesh.process_ids:
2546 2547 2548
            rank_id = _get_corresponding_rank(
                ctx, op_dist_attr.process_mesh, rank_id
            )
2549 2550 2551 2552

        # check validation of inputs / outputs
        for input_name in src_op.desc.input_names():
            assert input_name in kwargs, "input [{}] is not given".format(
2553 2554
                input_name
            )
2555 2556 2557 2558 2559
            assert len(kwargs[input_name]) == len(
                src_op.desc.input(input_name)
            ), "number of tensor for input [{}] is not match".format(input_name)
        for output_name in src_op.desc.output_names():
            assert output_name in kwargs, "input [{}] is not given".format(
2560 2561
                output_name
            )
2562 2563 2564
            assert len(kwargs[output_name]) == len(
                src_op.desc.output(output_name)
            ), "number of tensor for input [{}] is not match".format(
2565 2566
                output_name
            )
2567

Z
zhaoyingli 已提交
2568
        X_var = main_block._var_recursive(kwargs['X'][0])
2569
        Weight_var = main_block._var_recursive(kwargs['Y'][0])
Z
zhaoyingli 已提交
2570
        Out_var = main_block._var_recursive(kwargs['Out'][0])
2571 2572 2573

        # TODO infer logic comm presentation
        matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping(
2574 2575 2576 2577 2578 2579 2580
            Weight_var.name
        )[-1]
        assert (
            matmul_col_dim_mapping >= 0
        ), "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
            matmul_col_dim_mapping
        )
2581 2582
        process_mesh_shape = op_dist_attr.process_mesh.shape
        process_mesh_group = op_dist_attr.process_mesh.process_ids
2583 2584

        parallel_axis = matmul_col_dim_mapping
2585 2586 2587
        group_ranks = _get_comm_group(
            process_mesh_group, process_mesh_shape, parallel_axis, rank_id
        )
2588 2589 2590 2591 2592 2593 2594
        group = new_process_group(group_ranks)

        # infer new var shape with op dist attr
        x_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(X_var)
        assert x_tensor_dist_attr is not None
        identity_var_dist_attr = op_dist_attr.get_input_dist_attr(X_var.name)
        assert identity_var_dist_attr is not None
2595 2596 2597
        ref_shape_x = infer_shape(
            main_block, X_var, x_tensor_dist_attr, identity_var_dist_attr
        )
2598 2599 2600 2601 2602
        # infer out var shape with op dist attr
        out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var)
        assert out_tensor_dist_attr is not None
        out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert out_var_dist_attr is not None
2603 2604 2605
        ref_shape_out = infer_shape(
            main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr
        )
2606 2607

        intermediate_var_0 = main_block.create_var(
2608 2609 2610
            name=unique_name.generate_with_ignorable_key(
                ".".join(["c_identity", 'tmp'])
            ),
2611 2612 2613 2614
            dtype=X_var.dtype,
            shape=X_var.shape,
            type=core.VarDesc.VarType.LOD_TENSOR,
            persistable=False,
2615 2616
            stop_gradient=X_var.stop_gradient,
        )
2617
        # set intermediate_var_0's dist_attr with X_var's dist_attr
2618 2619 2620
        ctx.set_tensor_dist_attr_for_program(
            intermediate_var_0, identity_var_dist_attr
        )
2621 2622

        check_variable_and_dtype(
2623 2624
            X_var,
            'tensor',
X
xu98bin 已提交
2625
            ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
2626 2627
            '_c_identity',
        )
2628 2629 2630 2631 2632 2633 2634 2635
        c_identity_op = main_block.append_op(
            type='c_identity',
            inputs={'X': [X_var]},
            outputs={'Out': intermediate_var_0},
            attrs={
                'ring_id': group.id,
                'use_calc_stream': True,
                'use_model_parallel': True,
2636 2637 2638
                OP_ROLE_KEY: src_op.attr('op_role'),
            },
        )
2639 2640 2641
        if intermediate_var_0.shape != ref_shape_x:
            intermediate_var_0.desc.set_shape(ref_shape_x)

2642
        check_variable_and_dtype(
X
xu98bin 已提交
2643 2644 2645 2646
            intermediate_var_0,
            'x',
            ['float16', 'float32', 'float64', 'uint16'],
            'linear',
2647 2648 2649 2650
        )
        check_dtype(
            intermediate_var_0.dtype,
            'dtype',
X
xu98bin 已提交
2651
            ['float16', 'float32', 'float64', 'uint16'],
2652 2653
            'linear',
        )
2654 2655 2656
        # attrs = {'trans_x': False, 'trans_y': False}
        attrs = {
            "x_num_col_dims": src_op.desc.attr("x_num_col_dims"),
2657
            "y_num_col_dims": src_op.desc.attr("y_num_col_dims"),
2658
            OP_ROLE_KEY: src_op.attr('op_role'),
2659
        }
2660 2661 2662 2663 2664 2665 2666 2667 2668 2669 2670 2671
        inputs = {'X': intermediate_var_0, 'Y': Weight_var}

        inputs_ref_shape = {}
        inputs_original_shape = {}
        for var_name in inputs:
            if var_name == "X":
                var = X_var
            else:
                var = inputs[var_name]
            inputs_original_shape[var_name] = var.shape
            input_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(var)
            input_var_dist_attr = op_dist_attr.get_input_dist_attr(var.name)
2672 2673 2674
            input_ref_shape = infer_shape(
                main_block, var, input_tensor_dist_attr, input_var_dist_attr
            )
2675 2676 2677
            inputs_ref_shape[var_name] = input_ref_shape
            var.desc.set_shape(input_ref_shape)

2678 2679 2680
        mul_op = main_block.append_op(
            type='mul', inputs=inputs, outputs={'Out': Out_var}, attrs=attrs
        )
2681 2682 2683
        if Out_var.shape != ref_shape_out:
            Out_var.desc.set_shape(ref_shape_out)

2684 2685 2686 2687 2688
        for var_name in inputs:
            var = inputs[var_name]
            original_shape = inputs_original_shape[var_name]
            var.desc.set_shape(original_shape)

2689 2690 2691 2692 2693 2694 2695 2696 2697 2698
        # set dist op's dist_attr with serial op's dist_attr
        # c_identity
        identity_op_dist_attr = OperatorDistributedAttribute()
        identity_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        identity_op_dist_attr.impl_type = op_dist_attr.impl_type
        identity_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        # input
        input_varname = c_identity_op.desc.input_arg_names()[0]
        input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
        assert input_dist_attr is not None, "dist_attr is {}".format(
            op_dist_attr
        )
        identity_op_dist_attr.set_input_dist_attr(
            input_varname, input_dist_attr
        )
        # output
        output_varname = c_identity_op.desc.output_arg_names()[0]
        identity_op_dist_attr.set_output_dist_attr(
            output_varname, input_dist_attr
        )
        ctx.set_op_dist_attr_for_program(c_identity_op, identity_op_dist_attr)

        # mul
        matmulv2_op_dist_attr = OperatorDistributedAttribute()
        matmulv2_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        matmulv2_op_dist_attr.impl_type = op_dist_attr.impl_type
        matmulv2_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        for input_varname in mul_op.desc.input_arg_names():
            if input_varname in src_op.desc.input_arg_names():
                input_dist_attr = op_dist_attr.get_input_dist_attr(
                    input_varname
                )
                assert input_dist_attr is not None, "dist_attr is {}".format(
                    op_dist_attr
                )
                matmulv2_op_dist_attr.set_input_dist_attr(
                    input_varname, input_dist_attr
                )
            else:
                input_var = main_block._var_recursive(input_varname)
                tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(
                    input_var
                )
                matmulv2_op_dist_attr.set_input_dist_attr(
                    input_varname, tensor_dist_attr
                )
        for output_varname in mul_op.desc.output_arg_names():
            output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
            assert output_dist_attr is not None, "dist_attr is {}".format(
                op_dist_attr
            )
            matmulv2_op_dist_attr.set_output_dist_attr(
                output_varname, output_dist_attr
            )
        ctx.set_op_dist_attr_for_program(mul_op, matmulv2_op_dist_attr)

        # init param sync
        if Weight_var.is_parameter and not op_dist_attr.is_recompute:
            _init_param_sync(
                Weight_var, dist_op_context, startup_block, ctx, rank_id
            )

    @staticmethod
    def backward(ctx, *args, **kwargs):
        _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)


# RowParallel
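# Row-parallel mul: X is sharded along its last (reduction) dimension and the
# weight Y is sharded along its rows, so each rank computes a partial Out with
# a local mul and the partials are summed into the full Out via c_allreduce_sum
# (see forward() below).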
class DistributedMulImpl1(DistributedOperatorImpl):
    def __init__(self, name):
        super().__init__(name)
        self._forward_implemented = True
        self._backward_implemented = True

    def calc_cost(self, op_role, dist_op, ctx, cluster):
        cost = None
        if int(op_role) == int(OpRole.Forward):
            cost = self.calc_fwd_cost(dist_op, ctx, cluster)
        elif int(op_role) == int(OpRole.Backward):
            cost = self.calc_bwd_cost(dist_op, ctx, cluster)
        assert cost is not None
        return cost

    def calc_bwd_cost(self, dist_op, ctx, cluster):
        # for now, the backward function only inserts the gradient allreduce for the dist op itself
        res = []
        backward_op = dist_op.serial_op
        dist_attr = dist_op.dist_attr
        process_mesh = dist_attr.process_mesh
        main_block = backward_op.block
        Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("Y")[0]
        )
        assert Y_var_dim_mapping[1] < 0
        parallel_axis = Y_var_dim_mapping[0]

        # calc comm op cost
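        # The forward allreduce over Out corresponds to an identity on Out@GRAD
        # in the backward pass, so cost a c_identity over Out@GRAD along the
        # mesh axis that shards Y's rows, followed by the local mul_grad below.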
        var_names = [backward_op.input("Out@GRAD")[0]]
        attrs = {"use_calc_stream": True, "use_model_parallel": True}
        c_identity_desc_mapping = build_comm_desc_from_dist_op(
            "c_identity",
            dist_op,
            ctx,
            var_names,
            attrs=attrs,
            parallel_axis=parallel_axis,
        )
        processes = process_mesh.process_ids
        comm_op_cost_list = build_comm_costs_from_descs(
            IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster
        )
        res.append(comm_op_cost_list)

        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        cost_mapping = build_comp_costs_from_descs(
            MulGradOpCost, ctx, processes, desc_mapping, cluster
        )
        res.append(cost_mapping)

        # need gradient allreduce
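        # If X's batch dimension is sharded over a mesh axis with more than one
        # process and Y is a parameter, add the data-parallel allreduce cost
        # for Y@GRAD.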
        var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("X")[0]
        )
        mesh_shape = process_mesh.shape
        batch_size_axis = var_dim_mapping[0]
        if (
            batch_size_axis > -1
            and mesh_shape[batch_size_axis] > 1
            and is_parameter_related(backward_op.input("Y")[0], main_block)
        ):
            parallel_axis = batch_size_axis
            attrs = {"use_calc_stream": True}
            var_names = [backward_op.output('Y@GRAD')[0]]
            build_dp_costs(
                res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
            )
        return res

    def calc_fwd_cost(self, dist_op, ctx, cluster):
        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        processes = dist_op.dist_attr.process_mesh.process_ids
        cost_mapping = build_comp_costs_from_descs(
            MulOpCost, ctx, processes, desc_mapping, cluster
        )

        # calc comm op cost
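        # The row-parallel forward leaves each rank with a partial Out, so a
        # c_allreduce_sum over Out along the axis that shards Y's rows is
        # added to the cost.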
        serial_op = dist_op.serial_op
        parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
            serial_op.input("Y")[0]
        )[-2]
        attrs = {"use_calc_stream": True, "use_model_parallel": True}

        var_names = serial_op.output("Out")
        c_allreduce_sum_desc_mapping = build_comm_desc_from_dist_op(
            "c_allreduce_sum",
            dist_op,
            ctx,
            var_names,
            attrs=attrs,
            parallel_axis=parallel_axis,
        )

        comm_op_cost_list = build_comm_costs_from_descs(
            AllreduceSumOpCost,
            ctx,
            processes,
            c_allreduce_sum_desc_mapping,
            cluster,
        )

        res_cost = [cost_mapping, comm_op_cost_list]

        return res_cost

    def is_input_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        x_name = op_desc.input('X')[0]
        y_name = op_desc.input('Y')[0]
        x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
        y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
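        # Row parallelism requires X to be sharded along its last dimension and
        # Y to be sharded along its rows with replicated columns.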
        if is_dim_replicate(x_dims_mapping[-1]):
            return False
        if is_dim_replicate(y_dims_mapping[-2]) or is_dim_shard(
            y_dims_mapping[-1]
        ):
            return False
        # Other dimensions must be replicated except the batch dimension
        for mapping in x_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_output_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
        if is_dim_shard(out_dims_mapping[-1]):
            return False
        # Other dimensions must be replicated except the batch dimension
        for mapping in out_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_auto_compatible(self, dist_op):
        if (not self.is_input_compatible(dist_op)) or (
            not self.is_output_compatible(dist_op)
        ):
            return False

        if not _is_auto_compatible_for_matmul(dist_op):
            return False

        return True

    def update_dims_mapping(self, dist_op):
        changed = False
        dim_changed = _update_dims_mapping_for_matmul(dist_op)
        if dim_changed:
            changed = True
        return changed

    @staticmethod
    def forward(ctx, *args, **kwargs):
        """
        kwargs: inputname_mapping & outputname_mapping
        """

        dist_op_context = ctx.dist_op_context
        main_block = dist_op_context.work_block
        startup_block = dist_op_context.startup_block
        src_op = dist_op_context.cur_src_op
        rank_id = dist_op_context.rank_id
        op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
        assert (
            op_dist_attr is not None
        ), "forward op [{}] doesn't have dist attribute!".format(str(src_op))

        # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
        if rank_id not in op_dist_attr.process_mesh.process_ids:
            rank_id = _get_corresponding_rank(
                ctx, op_dist_attr.process_mesh, rank_id
            )

        # check the validity of inputs / outputs
        for input_name in src_op.desc.input_names():
            assert input_name in kwargs, "input [{}] is not given".format(
                input_name
            )
            assert len(kwargs[input_name]) == len(
                src_op.desc.input(input_name)
            ), "number of tensors for input [{}] does not match".format(
                input_name
            )
        for output_name in src_op.desc.output_names():
            assert output_name in kwargs, "output [{}] is not given".format(
                output_name
            )
            assert len(kwargs[output_name]) == len(
                src_op.desc.output(output_name)
            ), "number of tensors for output [{}] does not match".format(
                output_name
            )

        X_var = main_block._var_recursive(kwargs['X'][0])
        Weight_var = main_block._var_recursive(kwargs['Y'][0])
        Out_var = main_block._var_recursive(kwargs['Out'][0])

        # TODO infer logic comm presentation
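        # The mesh axis that shards Weight's rows defines the tensor-parallel
        # process group used for the allreduce of the partial results below.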
        matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
            Weight_var.name
        )[-2]
        assert (
            matmul_row_dim_mapping >= 0
        ), "row_parallel_matmul's row dimension should be sharded by a specific mesh axis, but got [{}]".format(
            matmul_row_dim_mapping
        )
        process_mesh_shape = op_dist_attr.process_mesh.shape
        process_mesh_group = op_dist_attr.process_mesh.process_ids

        parallel_axis = matmul_row_dim_mapping
        group_ranks = _get_comm_group(
            process_mesh_group, process_mesh_shape, parallel_axis, rank_id
        )
        group = new_process_group(group_ranks)

        check_variable_and_dtype(
            X_var, 'x', ['float16', 'float32', 'float64', 'uint16'], 'linear'
        )
        check_dtype(
            X_var.dtype,
            'dtype',
            ['float16', 'float32', 'float64', 'uint16'],
            'linear',
        )
        # attrs = {'trans_x': False, 'trans_y': False}
        attrs = {
            "x_num_col_dims": src_op.desc.attr("x_num_col_dims"),
            "y_num_col_dims": src_op.desc.attr("y_num_col_dims"),
            OP_ROLE_KEY: src_op.attr('op_role'),
        }
        inputs = {'X': X_var, 'Y': Weight_var}

        # infer out var shape with op dist attr
        out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var)
        assert out_tensor_dist_attr is not None
        out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert out_var_dist_attr is not None
        ref_shape = infer_shape(
            main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr
        )

        intermediate_var_0 = main_block.create_var(
            name=unique_name.generate_with_ignorable_key(
                ".".join(["c_allreduce_sum", 'tmp'])
            ),
            shape=Out_var.shape,
            dtype=Out_var.dtype,
            type=Out_var.type,
            lod_level=Out_var.lod_level,
            persistable=False,
            is_data=False,
            need_check_feed=Out_var.desc.need_check_feed(),
        )
        # set intermediate_var_0's dist_attr with Out_var's dist_attr
        ctx.set_tensor_dist_attr_for_program(
            intermediate_var_0, out_var_dist_attr
        )

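        # Same temporary shape bookkeeping as in the column-parallel forward:
        # switch inputs to their inferred shapes for op creation, then restore.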
        inputs_ref_shape = {}
        inputs_original_shape = {}
        for var_name in inputs:
            var = inputs[var_name]
            inputs_original_shape[var_name] = var.shape
            input_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(var)
            input_var_dist_attr = op_dist_attr.get_input_dist_attr(var.name)
            input_ref_shape = infer_shape(
                main_block, var, input_tensor_dist_attr, input_var_dist_attr
            )
            inputs_ref_shape[var_name] = input_ref_shape
            var.desc.set_shape(input_ref_shape)

        mul_op = main_block.append_op(
            type='mul',
            inputs=inputs,
            outputs={'Out': intermediate_var_0},
            attrs=attrs,
        )

        if intermediate_var_0.shape != ref_shape:
            intermediate_var_0.desc.set_shape(ref_shape)

        for var_name in inputs:
            var = inputs[var_name]
            original_shape = inputs_original_shape[var_name]
            var.desc.set_shape(original_shape)

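        # intermediate_var_0 holds this rank's partial product; c_allreduce_sum
        # reduces the partials into Out_var across the tensor-parallel group.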
        c_allreduce_sum_op = main_block.append_op(
            type='c_allreduce_sum',
            inputs={'X': intermediate_var_0},
            outputs={'Out': Out_var},
            attrs={
                'ring_id': group.id,
                'use_calc_stream': True,
                'use_model_parallel': True,
                OP_ROLE_KEY: src_op.attr('op_role'),
            },
        )

        if Out_var.shape != ref_shape:
            Out_var.desc.set_shape(ref_shape)

        # set dist op's dist_attr with serial op's dist_attr
        # mul
        matmulv2_op_dist_attr = OperatorDistributedAttribute()
        matmulv2_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        matmulv2_op_dist_attr.impl_type = op_dist_attr.impl_type
        matmulv2_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        for input_varname in mul_op.desc.input_arg_names():
            input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
            assert input_dist_attr is not None, "dist_attr is {}".format(
                op_dist_attr
            )
            matmulv2_op_dist_attr.set_input_dist_attr(
                input_varname, input_dist_attr
            )
        output_varname = mul_op.desc.output_arg_names()[0]
        output_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert output_dist_attr is not None, "dist_attr is {}".format(
            op_dist_attr
        )
        matmulv2_op_dist_attr.set_output_dist_attr(
            output_varname, output_dist_attr
        )
        ctx.set_op_dist_attr_for_program(mul_op, matmulv2_op_dist_attr)

        # allreduce
        allreduce_op_dist_attr = OperatorDistributedAttribute()
        allreduce_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        allreduce_op_dist_attr.impl_type = op_dist_attr.impl_type
        allreduce_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        for input_varname in c_allreduce_sum_op.desc.input_arg_names():
            input_var = main_block._var_recursive(input_varname)
            tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(input_var)
            assert tensor_dist_attr is not None
            allreduce_op_dist_attr.set_input_dist_attr(
                input_varname, tensor_dist_attr
            )
        for output_varname in c_allreduce_sum_op.desc.output_arg_names():
            output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
            assert output_dist_attr is not None, "dist_attr is {}".format(
                op_dist_attr
            )
            allreduce_op_dist_attr.set_output_dist_attr(
                output_varname, output_dist_attr
            )
        ctx.set_op_dist_attr_for_program(
            c_allreduce_sum_op, allreduce_op_dist_attr
        )

        # init param sync
        if Weight_var.is_parameter and not op_dist_attr.is_recompute:
            _init_param_sync(
                Weight_var, dist_op_context, startup_block, ctx, rank_id
            )

    @staticmethod
    def backward(ctx, *args, **kwargs):
        _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)


# ReplicateParallel
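# Fully replicated case: no tensor-parallel sharding on the contracted
# dimensions, so the forward simply reuses DistributedDefaultImpl0 and only the
# usual data-parallel gradient allreduce shows up in the backward cost.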
class DistributedMulImpl2(DistributedOperatorImpl):
    def __init__(self, name):
        super().__init__(name)

    def calc_cost(self, op_role, dist_op, ctx, cluster):
        cost = None
        if int(op_role) == int(OpRole.Forward):
            cost = self.calc_fwd_cost(dist_op, ctx, cluster)
        elif int(op_role) == int(OpRole.Backward):
            cost = self.calc_bwd_cost(dist_op, ctx, cluster)
        assert cost is not None
        return cost

    def calc_bwd_cost(self, dist_op, ctx, cluster):
        res = []
        backward_op = dist_op.serial_op
        dist_attr = dist_op.dist_attr
        main_block = backward_op.block

        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        process_mesh = dist_attr.process_mesh
        processes = process_mesh.process_ids
        cost_mapping = build_comp_costs_from_descs(
            MulGradOpCost, ctx, processes, desc_mapping, cluster
        )
        res.append(cost_mapping)

        # need gradient allreduce
        var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("X")[0]
        )
        mesh_shape = process_mesh.shape
        batch_size_axis = var_dim_mapping[0]
        if (
            batch_size_axis > -1
            and mesh_shape[batch_size_axis] > 1
            and is_parameter_related(backward_op.input("Y")[0], main_block)
        ):
            parallel_axis = batch_size_axis
            attrs = {"use_calc_stream": True}
            var_names = [backward_op.output('Y@GRAD')[0]]
            build_dp_costs(
                res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
            )

        return res

    def calc_fwd_cost(self, dist_op, ctx, cluster):
        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        processes = dist_op.dist_attr.process_mesh.process_ids
        cost_mapping = build_comp_costs_from_descs(
            MulOpCost, ctx, processes, desc_mapping, cluster
        )

        res_cost = [cost_mapping]
        return res_cost

    def is_input_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        x_name = op_desc.input('X')[0]
        y_name = op_desc.input('Y')[0]
        x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
        y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)

        if is_dim_shard(x_dims_mapping[-1]):
            return False
        if is_valid_list_index(x_dims_mapping, -2) and is_dim_shard(
            x_dims_mapping[-2]
        ):
            return False
        if is_dim_shard(y_dims_mapping[-1]):
            return False
        if is_valid_list_index(y_dims_mapping, -2) and is_dim_shard(
            y_dims_mapping[-2]
        ):
            return False
        return True

    def is_output_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)

        if is_dim_shard(out_dims_mapping[-1]):
            return False
        if is_valid_list_index(out_dims_mapping, -2) and is_dim_shard(
            out_dims_mapping[-2]
        ):
            return False

        return True

    def is_auto_compatible(self, dist_op):
        if (not self.is_input_compatible(dist_op)) or (
            not self.is_output_compatible(dist_op)
        ):
            return False

        if not _is_auto_compatible_for_matmul(dist_op):
            return False

        return True

    def update_dims_mapping(self, dist_op):
        changed = False
        dim_changed = _update_dims_mapping_for_matmul(dist_op)
        if dim_changed:
            changed = True
        return changed

    @staticmethod
    def forward(ctx, *args, **kwargs):
        DistributedDefaultImpl0.forward(ctx, *args, **kwargs)

    @staticmethod
    def backward(ctx, *args, **kwargs):
        _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)


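# Register the three sharding strategies for "mul": column-parallel (Y sharded
# by columns), row-parallel (Y sharded by rows), and fully replicated.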
register_distributed_operator_impl(
    "mul", DistributedMulImpl0("column_parallel")
)
register_distributed_operator_impl("mul", DistributedMulImpl1("row_parallel"))
register_distributed_operator_impl(
    "mul", DistributedMulImpl2("replicate_parallel")
)