# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License

import copy

from paddle.distributed.auto_parallel.cost.comm_op_cost import (
    AllreduceSumOpCost,
    IdentityOpCost,
)
from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole
from paddle.fluid import core, unique_name
from paddle.fluid.data_feeder import check_dtype, check_variable_and_dtype

from ..cost import (
    MatmulGradOpCost,
    MatmulOpCost,
    MatmulV2GradOpCost,
    MatmulV2OpCost,
    MulGradOpCost,
    MulOpCost,
    build_comm_costs_from_descs,
    build_comm_desc_from_dist_op,
    build_comp_costs_from_descs,
    build_comp_desc_from_dist_op,
    build_dp_costs,
)
from ..dist_attribute import OperatorDistributedAttribute
from ..process_group import new_process_group
from ..utils import (
    _get_comm_group,
    _get_corresponding_rank,
    compute_compatible_and_update_dim_mapping,
    compute_compatible_dims_mapping,
    is_dim_replicate,
    is_dim_shard,
    is_valid_list_index,
    set_dist_op_desc_original_id,
)
from .common import (
    DistributedOperatorImpl,
    DistributedOperatorImplContainer,
    gradient_synchronization,
    infer_shape,
    is_parameter_related,
    register_distributed_operator_impl,
    register_distributed_operator_impl_container,
    set_comm_op_dist_attr_for_program,
)
from .dist_default import DistributedDefaultImpl0


def trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping):
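    # Swap the last two dims of a dims_mapping in place when the matching
    # operand is transposed, so the sharding annotation follows the
    # transposed layout; a mapping may be None when its flag is False.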
    if trans_x:
        x_dims_mapping[-1], x_dims_mapping[-2] = (
            x_dims_mapping[-2],
            x_dims_mapping[-1],
        )
    if trans_y:
        y_dims_mapping[-1], y_dims_mapping[-2] = (
            y_dims_mapping[-2],
            y_dims_mapping[-1],
        )


def copy_op_with_new_input_output(ctx, block, src_op, **kwargs):
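    # Clone the serial op's desc into the current block and rebind its
    # inputs and outputs to the names given in kwargs, keeping a link to
    # the original op id for the distributed context.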
    dist_op_desc = block.append_op(type='nop').desc
    dist_op_desc.copy_from(src_op.desc)
    set_dist_op_desc_original_id(dist_op_desc, src_op.desc, ctx)
    for input_name in src_op.desc.input_names():
        assert input_name in kwargs
        dist_op_desc.set_input(input_name, kwargs[input_name])
    for output_name in src_op.desc.output_names():
        assert output_name in kwargs
        dist_op_desc.set_output(output_name, kwargs[output_name])

    return dist_op_desc


def _update_dims_mapping_for_matmul(dist_op):
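    # Make the dims_mapping of X, Y and Out compatible for a (batched)
    # matmul: temporarily pad 1-D operands to 2-D, align the broadcast
    # batch dims, then unify the contracted and output dims, and finally
    # restore the original ranks. Returns True if any mapping changed.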
    changed = False
    op_desc = dist_op.serial_op.desc
    op_dist_attr = dist_op.dist_attr
    x_name = op_desc.input('X')[0]
    y_name = op_desc.input('Y')[0]
    out_name = op_desc.output('Out')[0]
    trans_x = None
    trans_y = None
    if op_desc.type() == "matmul_v2":
        trans_x = op_desc.attr('trans_x')
        trans_y = op_desc.attr('trans_y')
    elif op_desc.type() == "matmul":
        trans_x = op_desc.attr('transpose_X')
        trans_y = op_desc.attr('transpose_Y')
    x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
    y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
    out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
    x_dims_mapping_len = len(x_dims_mapping)
    y_dims_mapping_len = len(y_dims_mapping)
    out_dims_mapping_len = len(out_dims_mapping)

    # Add dim mapping to make sure the length of dims_mapping is at least 2
    if x_dims_mapping_len == 1:
        assert trans_x is False
        x_dims_mapping.insert(0, -1)
        out_dims_mapping.insert(out_dims_mapping_len - 1, 0)
    if y_dims_mapping_len == 1:
        assert trans_y is False
        y_dims_mapping.insert(1, -1)
        out_dims_mapping.insert(out_dims_mapping_len, 0)

    trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping)

    new_x_dims_mapping_len = len(x_dims_mapping)
    new_y_dims_mapping_len = len(y_dims_mapping)
    new_out_dims_mapping_len = len(out_dims_mapping)
    # Deal with dim > 2 and take care of broadcasting
    if new_out_dims_mapping_len > 2:
        broadcast_x_dims_mapping = []
        broadcast_y_dims_mapping = []
        broadcast_out_dims_mapping = []

        for i in range(new_out_dims_mapping_len - new_x_dims_mapping_len):
            broadcast_x_dims_mapping.append(out_dims_mapping[i])
        for i in range(new_x_dims_mapping_len - 2):
            broadcast_x_dims_mapping.append(x_dims_mapping[i])

        for i in range(new_out_dims_mapping_len - new_y_dims_mapping_len):
            broadcast_y_dims_mapping.append(out_dims_mapping[i])
        for i in range(new_y_dims_mapping_len - 2):
            broadcast_y_dims_mapping.append(y_dims_mapping[i])

        for i in range(new_out_dims_mapping_len - 2):
            broadcast_out_dims_mapping.append(out_dims_mapping[i])

        compatible_dims_mapping = compute_compatible_dims_mapping(
            [
                broadcast_x_dims_mapping,
                broadcast_y_dims_mapping,
                broadcast_out_dims_mapping,
            ]
        )
        if compatible_dims_mapping is None:
            trans_x_y_dims_mapping(
                trans_x, trans_y, x_dims_mapping, y_dims_mapping
            )
            return False

        for i in range(new_x_dims_mapping_len - 2):
            new_idx = i + (out_dims_mapping_len - new_x_dims_mapping_len)
            if x_dims_mapping[i] != compatible_dims_mapping[new_idx]:
                x_dims_mapping[i] = compatible_dims_mapping[new_idx]
                changed = True

        for i in range(new_y_dims_mapping_len - 2):
            new_idx = i + (out_dims_mapping_len - new_y_dims_mapping_len)
            if y_dims_mapping[i] != compatible_dims_mapping[new_idx]:
                y_dims_mapping[i] = compatible_dims_mapping[new_idx]
                changed = True

        for i in range(new_out_dims_mapping_len - 2):
            if out_dims_mapping[i] != compatible_dims_mapping[i]:
                out_dims_mapping[i] = compatible_dims_mapping[i]
                changed = True

    # The following, which uses negative indices, works both
    # when len(out_dims_mapping) > 2 and when len(out_dims_mapping) <= 2
    dim_changed = compute_compatible_and_update_dim_mapping(
        [x_dims_mapping, y_dims_mapping], [-1, -2]
    )
    if dim_changed:
        changed = True

    dim_changed = compute_compatible_and_update_dim_mapping(
        [x_dims_mapping, out_dims_mapping], [-2, -2]
    )
    if dim_changed:
        changed = True

    dim_changed = compute_compatible_and_update_dim_mapping(
        [y_dims_mapping, out_dims_mapping], [-1, -1]
    )
    if dim_changed:
        changed = True

    trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping)

    # Remove the inserted dim mappings so the length of each dims_mapping matches its tensor's rank
    if x_dims_mapping_len == 1:
        x_dims_mapping.pop(0)
        out_dims_mapping.pop(out_dims_mapping_len - 1)
    if y_dims_mapping_len == 1:
        y_dims_mapping.pop(1)
        out_dims_mapping.pop(out_dims_mapping_len)

    assert len(x_dims_mapping) == x_dims_mapping_len
    assert len(y_dims_mapping) == y_dims_mapping_len
    assert len(out_dims_mapping) == out_dims_mapping_len

    return changed


def _is_auto_compatible_for_matmul(dist_op):
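    # Check, on deep copies of the dims_mappings, whether X, Y and Out
    # are already consistent for matmul (broadcast batch dims identical,
    # contracted and output dims aligned) without modifying the originals.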
    op_desc = dist_op.serial_op.desc
    op_dist_attr = dist_op.dist_attr
    x_name = op_desc.input('X')[0]
    y_name = op_desc.input('Y')[0]
    out_name = op_desc.output('Out')[0]
    trans_x = None
    trans_y = None
    if op_desc.type() == "matmul_v2":
        trans_x = op_desc.attr('trans_x')
        trans_y = op_desc.attr('trans_y')
    elif op_desc.type() == "matmul":
        trans_x = op_desc.attr('transpose_X')
        trans_y = op_desc.attr('transpose_Y')

    # Deep copy these dims_mappings for keeping them unchanged.
    x_dims_mapping = copy.deepcopy(op_dist_attr.get_input_dims_mapping(x_name))
    y_dims_mapping = copy.deepcopy(op_dist_attr.get_input_dims_mapping(y_name))
    out_dims_mapping = copy.deepcopy(
        op_dist_attr.get_output_dims_mapping(out_name)
    )
    x_dims_mapping_len = len(x_dims_mapping)
    y_dims_mapping_len = len(y_dims_mapping)
    out_dims_mapping_len = len(out_dims_mapping)

    # Add dim mapping to make sure the length of dims_mapping is at least 2
    if x_dims_mapping_len == 1:
        x_dims_mapping.insert(0, -1)
    if y_dims_mapping_len == 1:
        y_dims_mapping.insert(1, -1)

    trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping)

    # Deal with dim > 2 and take care of broadcasting
    if out_dims_mapping_len > 2:
        broadcast_x_dims_mapping = []
        broadcast_y_dims_mapping = []
        broadcast_out_dims_mapping = []

        for i in range(out_dims_mapping_len - x_dims_mapping_len):
            broadcast_x_dims_mapping.append(out_dims_mapping[i])
        for i in range(x_dims_mapping_len - 2):
            broadcast_x_dims_mapping.append(x_dims_mapping[i])

        for i in range(out_dims_mapping_len - y_dims_mapping_len):
            broadcast_y_dims_mapping.append(out_dims_mapping[i])
        for i in range(y_dims_mapping_len - 2):
            broadcast_y_dims_mapping.append(y_dims_mapping[i])

        for i in range(out_dims_mapping_len - 2):
            broadcast_out_dims_mapping.append(out_dims_mapping[i])

        is_same = (broadcast_x_dims_mapping == broadcast_y_dims_mapping) and (
            broadcast_x_dims_mapping == broadcast_out_dims_mapping
        )
        if not is_same:
            return False

    # The following, which uses negative indices, works both
    # when len(out_dims_mapping) > 2 and when len(out_dims_mapping) <= 2
    is_same = x_dims_mapping[-1] == y_dims_mapping[-2]
    if not is_same:
        return False

    is_same = x_dims_mapping[-2] == out_dims_mapping[-2]
    if not is_same:
        return False

    is_same = y_dims_mapping[-1] == out_dims_mapping[-1]
    if not is_same:
        return False

    return True


def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):

    # for now, the backward function only inserts the gradient allreduce for the dist op itself
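    # Two sharded cases are handled when Y is a partitioned parameter:
    #   * Y sharded along its row dim: insert c_identity on Out@GRAD and
    #     then run the local matmul grad (row parallel).
    #   * Y sharded along its column dim: run the local matmul grad and
    #     then allreduce the partial X@GRAD (col parallel).
    # Otherwise the grad op is simply replicated.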

    dist_op_context = ctx.dist_op_context
    main_block = dist_op_context.work_block
    backward_op = dist_op_context.cur_src_op
    rank_id = dist_op_context.rank_id
    dist_attr = ctx.get_op_dist_attr_for_program(backward_op)
    assert (
        dist_attr is not None
    ), "backward op [{}] doesn't have dist attribute!".format(str(backward_op))

    # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
    if rank_id not in dist_attr.process_mesh.process_ids:
        rank_id = _get_corresponding_rank(ctx, dist_attr.process_mesh, rank_id)

    assert 'Y' in kwargs, "input [{}] is not given".format('Y')
    assert 'X' in kwargs, "input [{}] is not given".format('X')
    assert 'Out@GRAD' in kwargs, "input [{}] is not given".format('Out@GRAD')
    assert 'Y@GRAD' in kwargs, "output [{}] is not given".format('Y@GRAD')
    assert 'X@GRAD' in kwargs, "output [{}] is not given".format('X@GRAD')
    assert (
        len(kwargs['Y']) == 1
    ), "dist matmul input Y takes 1 variable but got {}".format(kwargs['Y'])
    assert (
        len(kwargs['X']) == 1
    ), "dist matmul input X takes 1 variable but got {}".format(kwargs['X'])
    assert (
        len(kwargs['Out@GRAD']) == 1
    ), "dist matmul input Out@GRAD takes 1 variable but got {}".format(
        kwargs['Out@GRAD']
    )
    assert (
        len(kwargs['Y@GRAD']) == 1
    ), "dist matmul output Y@GRAD takes 1 variable but got {}".format(
        kwargs['Y@GRAD']
    )

    X_var = main_block._var_recursive(kwargs['X'][0])
    Y_var = main_block._var_recursive(kwargs['Y'][0])
    Out_grad = main_block._var_recursive(kwargs['Out@GRAD'][0])
    Y_grad = main_block._var_recursive(kwargs['Y@GRAD'][0])

    assert not is_parameter_related(
        X_var.name, main_block
    ), "left operand(X) [{}] of dist matmul should not be parameter".format(
        X_var.name
    )

    X_var_dims_mapping = dist_attr.get_input_dims_mapping(X_var.name)
    Y_var_dim_mapping = dist_attr.get_input_dims_mapping(Y_var.name)
    process_mesh_shape = dist_attr.process_mesh.shape
    process_mesh_group = dist_attr.process_mesh.process_ids

    trans_x = None
    trans_y = None
    if backward_op.desc.type() == "matmul_v2_grad":
        trans_x = backward_op.desc.attr('trans_x')
        trans_y = backward_op.desc.attr('trans_y')
    elif backward_op.desc.type() == "matmul_grad":
        trans_x = backward_op.desc.attr('transpose_X')
        trans_y = backward_op.desc.attr('transpose_Y')

    if trans_y:
        trans_x_y_dims_mapping(False, True, None, Y_var_dim_mapping)

    # assert len(
    #     Y_var_dim_mapping
    # ) == 2, "dist matmual only support Y operand with 2 dims now but Y({})'s dim is [{}]".format(
    #     Y_var.name, Y_var_dim_mapping)
    Y_var_partitioned = False
    for dim in Y_var_dim_mapping:
        if dim >= 0 and process_mesh_shape[dim] > 0:
            Y_var_partitioned = True
            break

    if is_parameter_related(Y_var.name, main_block) and Y_var_partitioned:

        if Y_var_dim_mapping[0] >= 0:
            # row parallel: c_identity + matmul
            assert Y_var_dim_mapping[1] < 0
            parallel_axis = Y_var_dim_mapping[0]

            check_variable_and_dtype(
                Out_grad,
                'tensor',
                ['float16', 'float32', 'float64', 'int32', 'int64'],
                '_c_identity',
            )

            intermediate_var_0 = main_block.create_var(
                name=unique_name.generate_with_ignorable_key(
                    ".".join(["c_identity", 'tmp'])
                )
                + "@GRAD",
                dtype=Out_grad.dtype,
                shape=Out_grad.shape,
                type=core.VarDesc.VarType.LOD_TENSOR,
                persistable=False,
                stop_gradient=Out_grad.stop_gradient,
            )

            # copy Out_grad's dist_attr to intermediate_var_0's dist_attr
            out_grad_dist_attr = dist_attr.get_input_dist_attr(Out_grad.name)
            assert out_grad_dist_attr is not None
            ctx.set_tensor_dist_attr_for_program(
                intermediate_var_0, out_grad_dist_attr
            )

            group_ranks = _get_comm_group(
                process_mesh_group, process_mesh_shape, parallel_axis, rank_id
            )
            group = new_process_group(group_ranks)
            c_identity_op = main_block.append_op(
                type='c_identity',
                inputs={'X': [Out_grad]},
                outputs={'Out': intermediate_var_0},
                attrs={
                    'ring_id': group.id,
                    'use_calc_stream': True,
                    'use_model_parallel': True,
                    OP_ROLE_KEY: OpRole.Backward,
                },
            )
            check_variable_and_dtype(
                intermediate_var_0,
                'x',
                ['float16', 'float32', 'float64'],
                'linear',
            )
            check_dtype(
                intermediate_var_0.dtype,
                'dtype',
                ['float16', 'float32', 'float64'],
                'linear',
            )
            set_comm_op_dist_attr_for_program(
                c_identity_op, dist_attr.process_mesh, out_grad_dist_attr, ctx
            )

            new_kwargs = copy.deepcopy(kwargs)
            new_kwargs['Out@GRAD'] = [intermediate_var_0.name]
            matmul_op_desc = copy_op_with_new_input_output(
                ctx, main_block, backward_op, **new_kwargs
            )
        else:
            # col parallel: matmul + allreduce
            assert Y_var_dim_mapping[0] < 0
            parallel_axis = Y_var_dim_mapping[1]
            new_kwargs = copy.deepcopy(kwargs)

            # NOTE (JZ-LIANG) should allow left operand be empty for matmul grad
            has_x_grad = len(kwargs['X@GRAD']) > 0
            if has_x_grad:
                assert len(kwargs['X@GRAD']) == 1
                X_grad = main_block._var_recursive(kwargs['X@GRAD'][0])
                intermediate_var_0 = main_block.create_var(
                    name=unique_name.generate_with_ignorable_key(
                        ".".join(["c_identity", 'tmp'])
                    )
                    + "@GRAD",
                    dtype=X_grad.dtype,
                    shape=X_grad.shape,
                    type=core.VarDesc.VarType.LOD_TENSOR,
                    persistable=False,
                    stop_gradient=X_grad.stop_gradient,
                )

                X_grad_dist_attr = dist_attr.get_output_dist_attr(X_grad.name)
                assert X_grad_dist_attr is not None
                ctx.set_tensor_dist_attr_for_program(
                    intermediate_var_0, X_grad_dist_attr
                )
                new_kwargs['X@GRAD'] = [intermediate_var_0.name]

            matmul_op_desc = copy_op_with_new_input_output(
                ctx, main_block, backward_op, **new_kwargs
            )

            # NOTE (JZ-LIANG) trick to skip one allreduce if left operand has not grad
            if has_x_grad:
                group_ranks = _get_comm_group(
                    process_mesh_group,
                    process_mesh_shape,
                    parallel_axis,
                    rank_id,
                )
                group = new_process_group(group_ranks)
                c_allreduce_sum_op = main_block.append_op(
                    type='c_allreduce_sum',
                    inputs={'X': [intermediate_var_0.name]},
                    outputs={'Out': kwargs['X@GRAD']},
                    attrs={
                        'ring_id': group.id,
                        'use_calc_stream': True,
                        'use_model_parallel': True,
                        OP_ROLE_KEY: OpRole.Backward,
                    },
                )
                set_comm_op_dist_attr_for_program(
                    c_allreduce_sum_op,
                    dist_attr.process_mesh,
                    X_grad_dist_attr,
                    ctx,
                )
    else:
        # replicate
        matmul_op_desc = copy_op_with_new_input_output(
            ctx, main_block, backward_op, **kwargs
        )

    # data parallel gradient synchronization
    act_grad_names = [X_var.name]

    out_grad_names = []
    if is_parameter_related(Y_var.name, main_block):
        out_grad_names = [kwargs['Y@GRAD'][0]]

    if trans_x:
        trans_x_y_dims_mapping(True, False, X_var_dims_mapping, None)

    gradient_synchronization(
        ctx, backward_op, act_grad_names, out_grad_names, rank_id
    )

    if trans_x:
        trans_x_y_dims_mapping(True, False, X_var_dims_mapping, None)
    if trans_y:
        trans_x_y_dims_mapping(False, True, None, Y_var_dim_mapping)


def _init_param_sync(Weight_var, dist_op_context, startup_block, ctx, rank_id):
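    # Broadcast a newly initialized parameter from the root rank of every
    # mesh axis along which it is replicated (every axis not present in
    # its dims_mapping), so all ranks start from identical values.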

    if Weight_var.name in dist_op_context.already_init_sync_vars:
        return
    assert startup_block.has_var(Weight_var.name)
    dist_op_context.already_init_sync_vars.add(Weight_var.name)
    param = startup_block.var(Weight_var.name)
    param_dist_attr = ctx.get_tensor_dist_attr_for_program(param)
    process_mesh = param_dist_attr.process_mesh
    dim_mapping = param_dist_attr.dims_mapping

    for axis, size in enumerate(process_mesh.shape):
        if size <= 1 or axis in dim_mapping:
            pass
        else:
            group_ranks = _get_comm_group(
                process_mesh.process_ids, process_mesh.shape, axis, rank_id
            )
            sync_group = new_process_group(group_ranks)

            startup_block.append_op(
                type='c_broadcast',
                inputs={'X': param},
                outputs={'Out': param},
                attrs={
                    'ring_id': sync_group.id,
                    'root': 0,
                    'use_calc_stream': True,
                    OP_ROLE_KEY: OpRole.Forward,
                },
            )


class DistributedMatmul(DistributedOperatorImplContainer):
    def __init__(self, op_type):
        super().__init__(op_type)


register_distributed_operator_impl_container(DistributedMatmul("matmul"))


# ColumnParallel
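# (weight Y sharded along its column dim: forward = c_identity on X plus a
# local matmul; backward = local matmul grad plus an allreduce of X@GRAD)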
class DistributedMatmulImpl0(DistributedOperatorImpl):
    def __init__(self, name):
        super().__init__(name)
        self._forward_implemented = True
        self._backward_implemented = True

    def calc_cost(self, op_role, dist_op, ctx, cluster):
        cost = None
        if int(op_role) == int(OpRole.Forward):
            cost = self.calc_fwd_cost(dist_op, ctx, cluster)
        elif int(op_role) == int(OpRole.Backward):
            cost = self.calc_bwd_cost(dist_op, ctx, cluster)
        assert cost is not None
        return cost

    def calc_bwd_cost(self, dist_op, ctx, cluster):
        # for now, the backward function only inserts the gradient allreduce for the dist op itself
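        # cost = local matmul_grad comp cost, plus an allreduce_sum comm
        # cost on X@GRAD when X has a gradient, plus the data-parallel
        # grad allreduce on Y@GRAD when the batch dim is sharded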
        res = []
        backward_op = dist_op.serial_op
        dist_attr = dist_op.dist_attr
        main_block = backward_op.block
        Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("Y")[0]
        )
        # col parallel: matmul + allreduce
        assert Y_var_dim_mapping[0] < 0
        parallel_axis = Y_var_dim_mapping[1]

        has_x_grad = len(backward_op.output("X@GRAD")) > 0
        if has_x_grad:
            assert len(backward_op.output("X@GRAD")) == 1

        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        process_mesh = dist_attr.process_mesh
        processes = process_mesh.process_ids
        cost_mapping = build_comp_costs_from_descs(
            MatmulGradOpCost, ctx, processes, desc_mapping, cluster
        )
        res.append(cost_mapping)

        # calc comm op cost
        if has_x_grad:
            attrs = {"use_calc_stream": True, "use_model_parallel": True}
            var_names = backward_op.output("X@GRAD")
            c_allreduce_sum_desc_mapping = build_comm_desc_from_dist_op(
                "c_allreduce_sum",
                dist_op,
                ctx,
                var_names,
                attrs=attrs,
                parallel_axis=parallel_axis,
            )
            comm_op_cost_list = build_comm_costs_from_descs(
                AllreduceSumOpCost,
                ctx,
                processes,
                c_allreduce_sum_desc_mapping,
                cluster,
            )
            res.append(comm_op_cost_list)

        # need gradient allreduce
        var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("X")[0]
        )
        mesh_shape = process_mesh.shape
        batch_size_axis = var_dim_mapping[0]
        if (
            batch_size_axis > -1
            and mesh_shape[batch_size_axis] > 1
            and is_parameter_related(backward_op.input("Y")[0], main_block)
        ):
            parallel_axis = batch_size_axis
            attrs = {"use_calc_stream": True}
            var_names = [backward_op.output('Y@GRAD')[0]]
            build_dp_costs(
                res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
            )
        return res

    def calc_fwd_cost(self, dist_op, ctx, cluster):
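        # cost = c_identity comm cost on X plus the local matmul comp cost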
        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        processes = dist_op.dist_attr.process_mesh.process_ids
        cost_mapping = build_comp_costs_from_descs(
            MatmulOpCost, ctx, processes, desc_mapping, cluster
        )

        # calc comm op cost
        serial_op = dist_op.serial_op
        parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
            serial_op.input("Y")[0]
        )[-1]
        attrs = {"use_calc_stream": True, "use_model_parallel": True}
        var_names = serial_op.input("X")
        c_identity_desc_mapping = build_comm_desc_from_dist_op(
            "c_identity",
            dist_op,
            ctx,
            var_names,
            attrs=attrs,
            parallel_axis=parallel_axis,
        )

        comm_op_cost_list = build_comm_costs_from_descs(
            IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster
        )
        res_cost = [comm_op_cost_list, cost_mapping]

        return res_cost

    def is_input_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        x_name = op_desc.input('X')[0]
        y_name = op_desc.input('Y')[0]
        x_dims_mapping = copy.deepcopy(
            op_dist_attr.get_input_dims_mapping(x_name)
        )
        y_dims_mapping = copy.deepcopy(
            op_dist_attr.get_input_dims_mapping(y_name)
        )
        trans_x = op_desc.attr('transpose_X')
        trans_y = op_desc.attr('transpose_Y')
        trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping)
        if is_dim_shard(x_dims_mapping[-1]):
            return False
        if is_dim_shard(y_dims_mapping[-2]) or is_dim_replicate(
            y_dims_mapping[-1]
        ):
            return False
        for mapping in x_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_output_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
        if is_dim_replicate(out_dims_mapping[-1]):
            return False
        for mapping in out_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_auto_compatible(self, dist_op):
        if (not self.is_input_compatible(dist_op)) or (
            not self.is_output_compatible(dist_op)
        ):
            return False
        if not _is_auto_compatible_for_matmul(dist_op):
            return False
        return True

    def update_dims_mapping(self, dist_op):
        changed = False
        dim_changed = _update_dims_mapping_for_matmul(dist_op)
        if dim_changed:
            changed = True
        return changed

    @staticmethod
    def forward(ctx, *args, **kwargs):
        """
        kwargs: inputname_mapping & outputname_mapping
        """

        dist_op_context = ctx.dist_op_context
        main_block = dist_op_context.work_block
        startup_block = dist_op_context.startup_block
        src_op = dist_op_context.cur_src_op
        rank_id = dist_op_context.rank_id
        op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
        assert (
            op_dist_attr is not None
        ), "forward op [{}] doesn't have dist attribute!".format(str(src_op))

        # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
        if rank_id not in op_dist_attr.process_mesh.process_ids:
            rank_id = _get_corresponding_rank(
                ctx, op_dist_attr.process_mesh, rank_id
            )

        # check the validity of inputs / outputs
        for input_name in src_op.desc.input_names():
            assert input_name in kwargs, "input [{}] is not given".format(
                input_name
            )
            assert len(kwargs[input_name]) == len(
                src_op.desc.input(input_name)
            ), "number of tensors for input [{}] does not match".format(
                input_name
            )
        for output_name in src_op.desc.output_names():
            assert output_name in kwargs, "output [{}] is not given".format(
                output_name
            )
            assert len(kwargs[output_name]) == len(
                src_op.desc.output(output_name)
            ), "number of tensors for output [{}] does not match".format(
                output_name
            )

        X_var = main_block._var_recursive(kwargs['X'][0])
        Weight_var = main_block._var_recursive(kwargs['Y'][0])
        Out_var = main_block._var_recursive(kwargs['Out'][0])
        trans_x = src_op.attr("transpose_X")
        trans_y = src_op.attr("transpose_Y")

        # TODO infer logic comm presentation
        matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping(
            Weight_var.name
        )[-1]
        if trans_y:
            matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping(
                Weight_var.name
            )[-2]
        assert (
            matmul_col_dim_mapping >= 0
        ), "col_parallel_matmul's col should be divided by a specific mesh axis, but got [{}]".format(
            matmul_col_dim_mapping
        )
        process_mesh_shape = op_dist_attr.process_mesh.shape
        process_mesh_group = op_dist_attr.process_mesh.process_ids

        parallel_axis = matmul_col_dim_mapping
        group_ranks = _get_comm_group(
            process_mesh_group, process_mesh_shape, parallel_axis, rank_id
        )
        group = new_process_group(group_ranks)

        # infer new var shape with op dist attr
        x_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(X_var)
        assert x_tensor_dist_attr is not None
        identity_var_dist_attr = op_dist_attr.get_input_dist_attr(X_var.name)
        assert identity_var_dist_attr is not None
        ref_shape_x = infer_shape(
            main_block, X_var, x_tensor_dist_attr, identity_var_dist_attr
        )
        # infer out var shape with op dist attr
        out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var)
        assert out_tensor_dist_attr is not None
        out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert out_var_dist_attr is not None
        ref_shape_out = infer_shape(
            main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr
        )

        intermediate_var_0 = main_block.create_var(
            name=unique_name.generate_with_ignorable_key(
                ".".join(["c_identity", 'tmp'])
            ),
            dtype=X_var.dtype,
            shape=X_var.shape,
            type=core.VarDesc.VarType.LOD_TENSOR,
            persistable=False,
            stop_gradient=X_var.stop_gradient,
        )
        # set intermediate_var_0's dist_attr with X_var's dist_attr
        ctx.set_tensor_dist_attr_for_program(
            intermediate_var_0, identity_var_dist_attr
        )

        check_variable_and_dtype(
            X_var,
            'tensor',
            ['float16', 'float32', 'float64', 'int32', 'int64'],
            '_c_identity',
        )

        c_identity_op = main_block.append_op(
            type='c_identity',
            inputs={'X': [X_var]},
            outputs={'Out': intermediate_var_0},
            attrs={
                'ring_id': group.id,
                'use_calc_stream': True,
                'use_model_parallel': True,
                OP_ROLE_KEY: src_op.attr('op_role'),
            },
        )
        if intermediate_var_0.shape != ref_shape_x:
            intermediate_var_0.desc.set_shape(ref_shape_x)

        check_variable_and_dtype(
            intermediate_var_0, 'x', ['float16', 'float32', 'float64'], 'linear'
        )
        check_dtype(
            intermediate_var_0.dtype,
            'dtype',
            ['float16', 'float32', 'float64'],
            'linear',
        )
        attrs = {
            'transpose_X': trans_x,
            'transpose_Y': trans_y,
            'alpha': 1,
            OP_ROLE_KEY: src_op.attr('op_role'),
        }
        inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]}
        matmul_op = main_block.append_op(
            type='matmul', inputs=inputs, outputs={'Out': Out_var}, attrs=attrs
        )
        if Out_var.shape != ref_shape_out:
            Out_var.desc.set_shape(ref_shape_out)

        # set dist op's dist_attr with serial op's dist_attr
        # c_identity
        identity_op_dist_attr = OperatorDistributedAttribute()
        identity_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        identity_op_dist_attr.impl_type = op_dist_attr.impl_type
        identity_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        # input
        input_varname = c_identity_op.desc.input_arg_names()[0]
        input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
        assert input_dist_attr is not None, "dist_attr is {}".format(
            op_dist_attr
        )
        identity_op_dist_attr.set_input_dist_attr(
            input_varname, input_dist_attr
        )
        # output
        output_varname = c_identity_op.desc.output_arg_names()[0]
        identity_op_dist_attr.set_output_dist_attr(
            output_varname, input_dist_attr
        )
        # set op dist attr
        ctx.set_op_dist_attr_for_program(c_identity_op, identity_op_dist_attr)

        # matmul
        matmul_op_dist_attr = OperatorDistributedAttribute()
        matmul_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        matmul_op_dist_attr.impl_type = op_dist_attr.impl_type
        matmul_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        # input
        for input_varname in matmul_op.desc.input_arg_names():
            if input_varname in src_op.desc.input_arg_names():
                input_dist_attr = op_dist_attr.get_input_dist_attr(
                    input_varname
                )
                assert input_dist_attr is not None, "dist_attr is {}".format(
                    op_dist_attr
                )
                matmul_op_dist_attr.set_input_dist_attr(
                    input_varname, input_dist_attr
                )
            else:
                input_var = main_block._var_recursive(input_varname)
                tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(
                    input_var
                )
                matmul_op_dist_attr.set_input_dist_attr(
                    input_varname, tensor_dist_attr
                )
        # output
        output_varname = matmul_op.desc.output_arg_names()[0]
        output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
        assert output_dist_attr is not None, "dist_attr is {}".format(
            op_dist_attr
        )
        matmul_op_dist_attr.set_output_dist_attr(
            output_varname, output_dist_attr
        )
        # set op dist attr
        ctx.set_op_dist_attr_for_program(matmul_op, matmul_op_dist_attr)

        # init param sync
        if Weight_var.is_parameter and not op_dist_attr.is_recompute:
            _init_param_sync(
                Weight_var, dist_op_context, startup_block, ctx, rank_id
            )

    @staticmethod
    def backward(ctx, *args, **kwargs):
        _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)


# RowParallel
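# (weight Y sharded along its row dim: forward = local matmul plus a
# c_allreduce_sum of the partial Out; backward = c_identity on Out@GRAD
# plus the local matmul grad)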
class DistributedMatmulImpl1(DistributedOperatorImpl):
    def __init__(self, name):
        super().__init__(name)
        self._forward_implemented = True
        self._backward_implemented = True

    def calc_cost(self, op_role, dist_op, ctx, cluster):
        cost = None
        if int(op_role) == int(OpRole.Forward):
            cost = self.calc_fwd_cost(dist_op, ctx, cluster)
        elif int(op_role) == int(OpRole.Backward):
            cost = self.calc_bwd_cost(dist_op, ctx, cluster)
        assert cost is not None
        return cost

    def calc_bwd_cost(self, dist_op, ctx, cluster):
        # for now, the backward function only inserts the gradient allreduce for the dist op itself
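        # cost = c_identity comm cost on Out@GRAD, plus the local
        # matmul_grad comp cost, plus the data-parallel grad allreduce
        # on Y@GRAD when the batch dim is sharded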
        res = []
        backward_op = dist_op.serial_op
        dist_attr = dist_op.dist_attr
        main_block = backward_op.block
        Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("Y")[0]
        )
        assert Y_var_dim_mapping[1] < 0
        parallel_axis = Y_var_dim_mapping[0]

        # calc comm op cost
        var_names = [backward_op.input("Out@GRAD")[0]]
        attrs = {"use_calc_stream": True, "use_model_parallel": True}
        c_identity_desc_mapping = build_comm_desc_from_dist_op(
            "c_identity",
            dist_op,
            ctx,
            var_names,
            attrs=attrs,
            parallel_axis=parallel_axis,
        )
        process_mesh = dist_attr.process_mesh
        processes = process_mesh.process_ids
        comm_op_cost_list = build_comm_costs_from_descs(
            IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster
        )
        res.append(comm_op_cost_list)

        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        cost_mapping = build_comp_costs_from_descs(
            MatmulGradOpCost, ctx, processes, desc_mapping, cluster
        )
        res.append(cost_mapping)

        # need gradient allreduce
        var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("X")[0]
        )
        mesh_shape = process_mesh.shape
        batch_size_axis = var_dim_mapping[0]
        if (
            batch_size_axis > -1
            and mesh_shape[batch_size_axis] > 1
            and is_parameter_related(backward_op.input("Y")[0], main_block)
        ):
            parallel_axis = batch_size_axis
            attrs = {"use_calc_stream": True}
            var_names = [backward_op.output('Y@GRAD')[0]]
            build_dp_costs(
                res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
            )
        return res

    def calc_fwd_cost(self, dist_op, ctx, cluster):
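        # cost = local matmul comp cost plus the c_allreduce_sum comm
        # cost on Out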
        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        processes = dist_op.dist_attr.process_mesh.process_ids
        cost_mapping = build_comp_costs_from_descs(
            MatmulOpCost, ctx, processes, desc_mapping, cluster
        )

        # calc comm op cost
        serial_op = dist_op.serial_op
        parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
            serial_op.input("Y")[0]
        )[-2]
        attrs = {"use_calc_stream": True, "use_model_parallel": True}

        var_names = serial_op.output("Out")
        c_allreduce_sum_desc_mapping = build_comm_desc_from_dist_op(
            "c_allreduce_sum",
            dist_op,
            ctx,
            var_names,
            attrs=attrs,
            parallel_axis=parallel_axis,
        )

        comm_op_cost_list = build_comm_costs_from_descs(
            AllreduceSumOpCost,
            ctx,
            processes,
            c_allreduce_sum_desc_mapping,
            cluster,
        )

        res_cost = [cost_mapping, comm_op_cost_list]

        return res_cost

    def is_input_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        x_name = op_desc.input('X')[0]
        y_name = op_desc.input('Y')[0]
        x_dims_mapping = copy.deepcopy(
            op_dist_attr.get_input_dims_mapping(x_name)
        )
        y_dims_mapping = copy.deepcopy(
            op_dist_attr.get_input_dims_mapping(y_name)
        )
        trans_x = op_desc.attr('transpose_X')
        trans_y = op_desc.attr('transpose_Y')
        trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping)
        if is_dim_replicate(x_dims_mapping[-1]):
            return False
        if is_dim_replicate(y_dims_mapping[-2]) or is_dim_shard(
            y_dims_mapping[-1]
        ):
            return False
        # Other dimensions must be replicated except the batch dimension
        for mapping in x_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_output_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
        if is_dim_shard(out_dims_mapping[-1]):
            return False
        # Other dimensions must be replicated except the batch dimension
        for mapping in out_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_auto_compatible(self, dist_op):
        if (not self.is_input_compatible(dist_op)) or (
            not self.is_output_compatible(dist_op)
        ):
            return False
        if not _is_auto_compatible_for_matmul(dist_op):
            return False
        return True

    def update_dims_mapping(self, dist_op):
        changed = False
        dim_changed = _update_dims_mapping_for_matmul(dist_op)
        if dim_changed:
            changed = True
        return changed

    @staticmethod
    def forward(ctx, *args, **kwargs):
        """
        kwargs: inputname_mapping & outputname_mapping
        """

        dist_op_context = ctx.dist_op_context
        main_block = dist_op_context.work_block
        startup_block = dist_op_context.startup_block
        src_op = dist_op_context.cur_src_op
        rank_id = dist_op_context.rank_id
        op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
        assert (
            op_dist_attr is not None
        ), "forward op [{}] doesn't have dist attribute!".format(str(src_op))

        # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
        if rank_id not in op_dist_attr.process_mesh.process_ids:
            rank_id = _get_corresponding_rank(
                ctx, op_dist_attr.process_mesh, rank_id
            )

        # check the validity of inputs / outputs
        for input_name in src_op.desc.input_names():
            assert input_name in kwargs, "input [{}] is not given".format(
                input_name
            )
            assert len(kwargs[input_name]) == len(
                src_op.desc.input(input_name)
            ), "number of tensors for input [{}] does not match".format(
                input_name
            )
        for output_name in src_op.desc.output_names():
            assert output_name in kwargs, "output [{}] is not given".format(
                output_name
            )
            assert len(kwargs[output_name]) == len(
                src_op.desc.output(output_name)
            ), "number of tensors for output [{}] does not match".format(
                output_name
            )

        X_var = main_block._var_recursive(kwargs['X'][0])
        Weight_var = main_block._var_recursive(kwargs['Y'][0])
        Out_var = main_block._var_recursive(kwargs['Out'][0])
        trans_x = src_op.attr('transpose_X')
        trans_y = src_op.attr('transpose_Y')

        # TODO infer logic comm presentation
        matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
            Weight_var.name
        )[-2]
        if trans_y:
            matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
                Weight_var.name
            )[-1]
        assert (
            matmul_row_dim_mapping >= 0
        ), "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
            matmul_row_dim_mapping
        )
        process_mesh_shape = op_dist_attr.process_mesh.shape
        process_mesh_group = op_dist_attr.process_mesh.process_ids

        parallel_axis = matmul_row_dim_mapping
        group_ranks = _get_comm_group(
            process_mesh_group, process_mesh_shape, parallel_axis, rank_id
        )
        group = new_process_group(group_ranks)

        check_variable_and_dtype(
            X_var, 'x', ['float16', 'float32', 'float64'], 'linear'
        )
        check_dtype(
            X_var.dtype, 'dtype', ['float16', 'float32', 'float64'], 'linear'
        )
        attrs = {
            'transpose_X': trans_x,
            'transpose_Y': trans_y,
            'alpha': 1,
            OP_ROLE_KEY: src_op.attr('op_role'),
        }
        inputs = {'X': X_var, 'Y': Weight_var}

        # infer out var shape with op dist attr
        out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var)
        assert out_tensor_dist_attr is not None
        out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert out_var_dist_attr is not None
        ref_shape = infer_shape(
            main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr
        )

        intermediate_var_0 = main_block.create_var(
            name=unique_name.generate_with_ignorable_key(
                ".".join(["c_allreduce_sum", 'tmp'])
            ),
            shape=Out_var.shape,
            dtype=Out_var.dtype,
            type=Out_var.type,
            lod_level=Out_var.lod_level,
            persistable=False,
            is_data=False,
            need_check_feed=Out_var.desc.need_check_feed(),
        )
        # set intermediate_var_0's dist_attr with Out_var's dist_attr
        ctx.set_tensor_dist_attr_for_program(
            intermediate_var_0, out_var_dist_attr
        )

        matmul_op = main_block.append_op(
            type='matmul',
            inputs=inputs,
            outputs={'Out': intermediate_var_0},
            attrs=attrs,
        )
        if intermediate_var_0.shape != ref_shape:
            intermediate_var_0.desc.set_shape(ref_shape)

        c_allreduce_sum_op = main_block.append_op(
            type='c_allreduce_sum',
            inputs={'X': intermediate_var_0},
            outputs={'Out': Out_var},
            attrs={
                'ring_id': group.id,
                'use_calc_stream': True,
                'use_model_parallel': True,
                OP_ROLE_KEY: src_op.attr('op_role'),
            },
        )
        if Out_var.shape != ref_shape:
            Out_var.desc.set_shape(ref_shape)

        # set dist op's dist_attr with serial op's dist_attr
        # matmul
        matmul_op_dist_attr = OperatorDistributedAttribute()
        matmul_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        matmul_op_dist_attr.impl_type = op_dist_attr.impl_type
        matmul_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        for input_varname in matmul_op.desc.input_arg_names():
            input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
            assert input_dist_attr is not None, "dist_attr is {}".format(
                op_dist_attr
            )
            matmul_op_dist_attr.set_input_dist_attr(
                input_varname, input_dist_attr
            )
        output_varname = matmul_op.desc.output_arg_names()[0]
        output_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert output_dist_attr is not None, "dist_attr is {}".format(
            op_dist_attr
        )
        matmul_op_dist_attr.set_output_dist_attr(
            output_varname, output_dist_attr
        )
        ctx.set_op_dist_attr_for_program(matmul_op, matmul_op_dist_attr)

        # allreduce
        allreduce_op_dist_attr = OperatorDistributedAttribute()
        allreduce_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        allreduce_op_dist_attr.impl_type = op_dist_attr.impl_type
        allreduce_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        for input_varname in c_allreduce_sum_op.desc.input_arg_names():
            input_var = main_block._var_recursive(input_varname)
            tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(input_var)
            assert tensor_dist_attr is not None
            allreduce_op_dist_attr.set_input_dist_attr(
                input_varname, tensor_dist_attr
            )
        for output_varname in c_allreduce_sum_op.desc.output_arg_names():
            output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
            assert output_dist_attr is not None, "dist_attr is {}".format(
                op_dist_attr
            )
            allreduce_op_dist_attr.set_output_dist_attr(
                output_varname, output_dist_attr
            )
        ctx.set_op_dist_attr_for_program(
            c_allreduce_sum_op, allreduce_op_dist_attr
        )

        # init param sync
        if Weight_var.is_parameter and not op_dist_attr.is_recompute:
            _init_param_sync(
                Weight_var, dist_op_context, startup_block, ctx, rank_id
            )

    @staticmethod
    def backward(ctx, *args, **kwargs):
        _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)


# ReplicateParallel
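# (no operand is sharded along the contracted or output dims: forward falls
# back to the default implementation; backward may still add the
# data-parallel gradient allreduce)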
class DistributedMatmulImpl2(DistributedOperatorImpl):
    def __init__(self, name):
        super().__init__(name)

    def calc_cost(self, op_role, dist_op, ctx, cluster):
        cost = None
        if int(op_role) == int(OpRole.Forward):
            cost = self.calc_fwd_cost(dist_op, ctx, cluster)
        elif int(op_role) == int(OpRole.Backward):
            cost = self.calc_bwd_cost(dist_op, ctx, cluster)
        assert cost is not None
        return cost

    def calc_bwd_cost(self, dist_op, ctx, cluster):
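        # cost = local matmul_grad comp cost, plus the data-parallel grad
        # allreduce on Y@GRAD when the batch dim is sharded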
        res = []
        backward_op = dist_op.serial_op
        dist_attr = dist_op.dist_attr
        main_block = backward_op.block

        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        process_mesh = dist_attr.process_mesh
        processes = process_mesh.process_ids
        cost_mapping = build_comp_costs_from_descs(
            MatmulGradOpCost, ctx, processes, desc_mapping, cluster
        )
        res.append(cost_mapping)

        # need gradient allreduce
        var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("X")[0]
        )
        mesh_shape = process_mesh.shape
        batch_size_axis = var_dim_mapping[0]
        if (
            batch_size_axis > -1
            and mesh_shape[batch_size_axis] > 1
            and is_parameter_related(backward_op.input("Y")[0], main_block)
        ):
            parallel_axis = batch_size_axis
            attrs = {"use_calc_stream": True}
            var_names = [backward_op.output('Y@GRAD')[0]]
            build_dp_costs(
                res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
            )

        return res

    def calc_fwd_cost(self, dist_op, ctx, cluster):
        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        processes = dist_op.dist_attr.process_mesh.process_ids
        cost_mapping = build_comp_costs_from_descs(
            MatmulOpCost, ctx, processes, desc_mapping, cluster
        )

        res_cost = [cost_mapping]
        return res_cost

    def is_input_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        x_name = op_desc.input('X')[0]
        y_name = op_desc.input('Y')[0]
        x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
        y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)

        if is_dim_shard(x_dims_mapping[-1]):
            return False
        if is_valid_list_index(x_dims_mapping, -2) and is_dim_shard(
            x_dims_mapping[-2]
        ):
            return False

        if is_dim_shard(y_dims_mapping[-1]):
            return False
        if is_valid_list_index(y_dims_mapping, -2) and is_dim_shard(
            y_dims_mapping[-2]
        ):
            return False

        return True

    def is_output_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)

        if is_dim_shard(out_dims_mapping[-1]):
            return False
        if is_valid_list_index(out_dims_mapping, -2) and is_dim_shard(
            out_dims_mapping[-2]
        ):
            return False

        return True

    def is_auto_compatible(self, dist_op):
        if (not self.is_input_compatible(dist_op)) or (
            not self.is_output_compatible(dist_op)
        ):
            return False

        if not _is_auto_compatible_for_matmul(dist_op):
            return False

        return True

    def update_dims_mapping(self, dist_op):
        changed = False
        dim_changed = _update_dims_mapping_for_matmul(dist_op)
        if dim_changed:
            changed = True
        return changed

    @staticmethod
    def forward(ctx, *args, **kwargs):
        DistributedDefaultImpl0.forward(ctx, *args, **kwargs)

    @staticmethod
    def backward(ctx, *args, **kwargs):
        _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)


register_distributed_operator_impl(
    "matmul", DistributedMatmulImpl0("column_parallel")
)
register_distributed_operator_impl(
    "matmul", DistributedMatmulImpl1("row_parallel")
)
register_distributed_operator_impl(
    "matmul", DistributedMatmulImpl2("replicate_parallel")
)


class DistributedMatmulV2(DistributedOperatorImplContainer):
    def __init__(self, op_type):
        super().__init__(op_type)


register_distributed_operator_impl_container(DistributedMatmulV2("matmul_v2"))


# ColumnParallel
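# DistributedMatmulV2Impl0 shards Y along its last (column) dimension. The forward
# pass inserts a c_identity on X for the model-parallel group followed by a local
# matmul_v2, so the output stays sharded along the same column axis.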
class DistributedMatmulV2Impl0(DistributedOperatorImpl):
    def __init__(self, name):
        super().__init__(name)
        self._forward_implemented = True
        self._backward_implemented = True

    def calc_cost(self, op_role, dist_op, ctx, cluster):
        cost = None
        if int(op_role) == int(OpRole.Forward):
            cost = self.calc_fwd_cost(dist_op, ctx, cluster)
        elif int(op_role) == int(OpRole.Backward):
            cost = self.calc_bwd_cost(dist_op, ctx, cluster)
        assert cost is not None
        return cost

    def calc_bwd_cost(self, dist_op, ctx, cluster):
        # by now the backward function only insert the gradient allreduce for dist op itself
        res = []
        backward_op = dist_op.serial_op
        dist_attr = dist_op.dist_attr
        main_block = backward_op.block
        Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("Y")[0]
        )
        process_mesh = dist_attr.process_mesh
        processes = process_mesh.process_ids
        # col parallel: matmul + allreduce
        if backward_op.attr("trans_y"):
            Y_var_dim_mapping.reverse()
        assert Y_var_dim_mapping[0] < 0
        parallel_axis = Y_var_dim_mapping[1]

        has_x_grad = len(backward_op.output("X@GRAD")) > 0
        if has_x_grad:
            assert len(backward_op.output("X@GRAD")) == 1

        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )

        cost_mapping = build_comp_costs_from_descs(
            MatmulV2GradOpCost, ctx, processes, desc_mapping, cluster
        )
        res.append(cost_mapping)

        # calc comm op cost
        if has_x_grad:
            attrs = {"use_calc_stream": True, "use_model_parallel": True}
            var_names = backward_op.output("X@GRAD")
            c_allreduce_sum_desc_mapping = build_comm_desc_from_dist_op(
                "c_allreduce_sum",
                dist_op,
                ctx,
                var_names,
                attrs=attrs,
                parallel_axis=parallel_axis,
            )
            comm_op_cost_list = build_comm_costs_from_descs(
                AllreduceSumOpCost,
                ctx,
                processes,
                c_allreduce_sum_desc_mapping,
                cluster,
            )
            res.append(comm_op_cost_list)

        # need gradient allreduce
        process_mesh = dist_attr.process_mesh
        var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("X")[0]
        )
        mesh_shape = process_mesh.shape
        batch_size_axis = var_dim_mapping[0]
        if (
            batch_size_axis > -1
            and mesh_shape[batch_size_axis] > 1
            and is_parameter_related(backward_op.input("Y")[0], main_block)
        ):
            parallel_axis = batch_size_axis
            attrs = {"use_calc_stream": True}
            var_names = [backward_op.output('Y@GRAD')[0]]
            build_dp_costs(
                res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
            )
        return res

    def calc_fwd_cost(self, dist_op, ctx, cluster):
        # calc comp op cost
        # TODO: trans shape if trans_x or trans_y is True
        comp_desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        processes = dist_op.dist_attr.process_mesh.process_ids
        comp_cost_mapping = build_comp_costs_from_descs(
            MatmulV2OpCost, ctx, processes, comp_desc_mapping, cluster
        )

        # calc comm op cost
        serial_op = dist_op.serial_op
        parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
            serial_op.input("Y")[0]
        )[-1]
        attrs = {"use_calc_stream": True, "use_model_parallel": True}

        var_names = serial_op.input("X")
        c_identity_desc_mapping = build_comm_desc_from_dist_op(
            "c_identity",
            dist_op,
            ctx,
            var_names,
            attrs=attrs,
            parallel_axis=parallel_axis,
        )
        comm_op_cost_list = build_comm_costs_from_descs(
            IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster
        )

        res_cost = [comm_op_cost_list, comp_cost_mapping]
        return res_cost

    def is_input_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        x_name = op_desc.input('X')[0]
        y_name = op_desc.input('Y')[0]
        x_dims_mapping = copy.deepcopy(
            op_dist_attr.get_input_dims_mapping(x_name)
        )
        y_dims_mapping = copy.deepcopy(
            op_dist_attr.get_input_dims_mapping(y_name)
        )
        trans_x = op_desc.attr('trans_x')
        trans_y = op_desc.attr('trans_y')
        trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping)
        if is_dim_shard(x_dims_mapping[-1]):
            return False
        if is_dim_shard(y_dims_mapping[-2]) or is_dim_replicate(
            y_dims_mapping[-1]
        ):
            return False
        for mapping in x_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_output_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
        if is_dim_replicate(out_dims_mapping[-1]):
            return False
        for mapping in out_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_auto_compatible(self, dist_op):
        if (not self.is_input_compatible(dist_op)) or (
            not self.is_output_compatible(dist_op)
        ):
            return False
        if not _is_auto_compatible_for_matmul(dist_op):
            return False
        return True

    def update_dims_mapping(self, dist_op):
        changed = False
        dim_changed = _update_dims_mapping_for_matmul(dist_op)
        if dim_changed:
            changed = True
        return changed

    @staticmethod
    def forward(ctx, *args, **kwargs):
        """
        kwargs: inputname_mapping & outputname_mapping
        """

        dist_op_context = ctx.dist_op_context
        main_block = dist_op_context.work_block
        startup_block = dist_op_context.startup_block
        src_op = dist_op_context.cur_src_op
        rank_id = dist_op_context.rank_id
        op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
        assert (
            op_dist_attr is not None
        ), "backward op [{}] don't have dist attribute !".format(str(src_op))

        # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
        if rank_id not in op_dist_attr.process_mesh.process_ids:
            rank_id = _get_corresponding_rank(
                ctx, op_dist_attr.process_mesh, rank_id
            )

        # validate inputs / outputs
        for input_name in src_op.desc.input_names():
            assert input_name in kwargs, "input [{}] is not given".format(
                input_name
            )
            assert len(kwargs[input_name]) == len(
                src_op.desc.input(input_name)
            ), "number of tensor for input [{}] is not match".format(input_name)
        for output_name in src_op.desc.output_names():
            assert output_name in kwargs, "output [{}] is not given".format(
                output_name
            )
            assert len(kwargs[output_name]) == len(
                src_op.desc.output(output_name)
            ), "number of tensor for input [{}] is not match".format(
                output_name
            )

        X_var = main_block._var_recursive(kwargs['X'][0])
        Weight_var = main_block._var_recursive(kwargs['Y'][0])
        Out_var = main_block._var_recursive(kwargs['Out'][0])
        trans_x = src_op.attr('trans_x')
        trans_y = src_op.attr('trans_y')

        # TODO infer logic comm presentation
        matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping(
            Weight_var.name
        )[-1]
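        # with trans_y the column axis of Y is its -2 dimension, so re-read the mapping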
        if trans_y:
            matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping(
                Weight_var.name
            )[-2]
        assert (
            matmul_col_dim_mapping >= 0
        ), "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
            matmul_col_dim_mapping
        )
        process_mesh_shape = op_dist_attr.process_mesh.shape
        process_mesh_group = op_dist_attr.process_mesh.process_ids

        parallel_axis = matmul_col_dim_mapping
        group_ranks = _get_comm_group(
            process_mesh_group, process_mesh_shape, parallel_axis, rank_id
        )
        group = new_process_group(group_ranks)
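        # the ranks along the mesh axis that shards Y's columns form the
        # model-parallel communication group used by c_identity below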

        # infer new var shape with op dist attr
        x_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(X_var)
        assert x_tensor_dist_attr is not None
        identity_var_dist_attr = op_dist_attr.get_input_dist_attr(X_var.name)
        assert identity_var_dist_attr is not None
        ref_shape_x = infer_shape(
            main_block, X_var, x_tensor_dist_attr, identity_var_dist_attr
        )
        # infer out var shape with op dist attr
        out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var)
        assert out_tensor_dist_attr is not None
        out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert out_var_dist_attr is not None
        ref_shape_out = infer_shape(
            main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr
        )

        intermediate_var_0 = main_block.create_var(
            name=unique_name.generate_with_ignorable_key(
                ".".join(["c_identity", 'tmp'])
            ),
            dtype=X_var.dtype,
            shape=X_var.shape,
            type=core.VarDesc.VarType.LOD_TENSOR,
            persistable=False,
            stop_gradient=X_var.stop_gradient,
        )
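        # intermediate_var_0 holds the c_identity copy of X that feeds the local
        # matmul_v2 below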
        # set intermediate_var_0's dist_attr with X_var's dist_attr
        ctx.set_tensor_dist_attr_for_program(
            intermediate_var_0, identity_var_dist_attr
        )

        check_variable_and_dtype(
            X_var,
            'tensor',
            ['float16', 'float32', 'float64', 'int32', 'int64'],
            '_c_identity',
        )
        c_identity_op = main_block.append_op(
            type='c_identity',
            inputs={'X': [X_var]},
            outputs={'Out': intermediate_var_0},
            attrs={
                'ring_id': group.id,
                'use_calc_stream': True,
                'use_model_parallel': True,
                OP_ROLE_KEY: src_op.attr('op_role'),
            },
        )
        if intermediate_var_0.shape != ref_shape_x:
            intermediate_var_0.desc.set_shape(ref_shape_x)

        check_variable_and_dtype(
            intermediate_var_0, 'x', ['float16', 'float32', 'float64'], 'linear'
        )
        check_dtype(
            intermediate_var_0.dtype,
            'dtype',
            ['float16', 'float32', 'float64'],
            'linear',
        )
        attrs = {
            'trans_x': trans_x,
            'trans_y': trans_y,
            OP_ROLE_KEY: src_op.attr('op_role'),
        }
        inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]}
        matmul_v2_op = main_block.append_op(
            type='matmul_v2',
            inputs=inputs,
            outputs={'Out': Out_var},
            attrs=attrs,
        )
        if Out_var.shape != ref_shape_out:
            Out_var.desc.set_shape(ref_shape_out)

        # set dist op's dist_attr with serial op's dist_attr
        # c_identity
        identity_op_dist_attr = OperatorDistributedAttribute()
        identity_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        identity_op_dist_attr.impl_type = op_dist_attr.impl_type
        identity_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        # input
        input_varname = c_identity_op.desc.input_arg_names()[0]
        input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
        assert input_dist_attr is not None, "dist_attr is {}".format(
            op_dist_attr
        )
        identity_op_dist_attr.set_input_dist_attr(
            input_varname, input_dist_attr
        )
        # output
        output_varname = c_identity_op.desc.output_arg_names()[0]
        identity_op_dist_attr.set_output_dist_attr(
            output_varname, input_dist_attr
        )
        ctx.set_op_dist_attr_for_program(c_identity_op, identity_op_dist_attr)

        # matmulv2
        matmulv2_op_dist_attr = OperatorDistributedAttribute()
        matmulv2_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        matmulv2_op_dist_attr.impl_type = op_dist_attr.impl_type
        matmulv2_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        for input_varname in matmul_v2_op.desc.input_arg_names():
            if input_varname in src_op.desc.input_arg_names():
                input_dist_attr = op_dist_attr.get_input_dist_attr(
                    input_varname
                )
                assert input_dist_attr is not None, "dist_attr is {}".format(
                    op_dist_attr
                )
                matmulv2_op_dist_attr.set_input_dist_attr(
                    input_varname, input_dist_attr
                )
            else:
                input_var = main_block._var_recursive(input_varname)
                tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(
                    input_var
                )
                matmulv2_op_dist_attr.set_input_dist_attr(
                    input_varname, tensor_dist_attr
                )
        for output_varname in matmul_v2_op.desc.output_arg_names():
            output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
            assert output_dist_attr is not None, "dist_attr is {}".format(
                op_dist_attr
            )
            matmulv2_op_dist_attr.set_output_dist_attr(
                output_varname, output_dist_attr
            )
        ctx.set_op_dist_attr_for_program(matmul_v2_op, matmulv2_op_dist_attr)

        # init param sync
        if Weight_var.is_parameter and not op_dist_attr.is_recompute:
            _init_param_sync(
                Weight_var, dist_op_context, startup_block, ctx, rank_id
            )

    @staticmethod
    def backward(ctx, *args, **kwargs):
        _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)


# RowParallel
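# DistributedMatmulV2Impl1 shards Y along its row (-2) dimension and X along its last
# dimension: each rank computes a partial matmul_v2 and the partial results are
# summed with c_allreduce_sum.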
class DistributedMatmulV2Impl1(DistributedOperatorImpl):
    def __init__(self, name):
        super().__init__(name)
        self._forward_implemented = True
        self._backward_implemented = True

    def calc_cost(self, op_role, dist_op, ctx, cluster):
        cost = None
        if int(op_role) == int(OpRole.Forward):
            cost = self.calc_fwd_cost(dist_op, ctx, cluster)
        elif int(op_role) == int(OpRole.Backward):
            cost = self.calc_bwd_cost(dist_op, ctx, cluster)
        assert cost is not None
        return cost

    def calc_bwd_cost(self, dist_op, ctx, cluster):
        # by now the backward function only insert the gradient allreduce for dist op itself
        res = []
        backward_op = dist_op.serial_op
        dist_attr = dist_op.dist_attr
        main_block = backward_op.block

        Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("Y")[0]
        )
        assert Y_var_dim_mapping[1] < 0
        parallel_axis = Y_var_dim_mapping[0]

        process_mesh = dist_attr.process_mesh
        processes = process_mesh.process_ids
        # calc comm op cost
        var_names = [backward_op.input("Out@GRAD")[0]]
        attrs = {"use_calc_stream": True, "use_model_parallel": True}
        c_identity_desc_mapping = build_comm_desc_from_dist_op(
            "c_identity",
            dist_op,
            ctx,
            var_names,
            attrs=attrs,
            parallel_axis=parallel_axis,
        )
        comm_op_cost_list = build_comm_costs_from_descs(
            IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster
        )
        res.append(comm_op_cost_list)

        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        cost_mapping = build_comp_costs_from_descs(
            MatmulV2GradOpCost, ctx, processes, desc_mapping, cluster
        )
        res.append(cost_mapping)

        # need gradient allreduce
        process_mesh = dist_attr.process_mesh
        var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("X")[0]
        )
        mesh_shape = process_mesh.shape
        batch_size_axis = var_dim_mapping[0]
        if (
            batch_size_axis > -1
            and mesh_shape[batch_size_axis] > 1
            and is_parameter_related(backward_op.input("Y")[0], main_block)
        ):
            parallel_axis = batch_size_axis
            attrs = {"use_calc_stream": True}
            var_names = [backward_op.output('Y@GRAD')[0]]
            build_dp_costs(
                res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
            )
        return res

    def calc_fwd_cost(self, dist_op, ctx, cluster):
        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        processes = dist_op.dist_attr.process_mesh.process_ids
        cost_mapping = build_comp_costs_from_descs(
            MatmulV2OpCost, ctx, processes, desc_mapping, cluster
        )

        # calc comm op cost
        serial_op = dist_op.serial_op
        parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
            serial_op.input("Y")[0]
        )[-2]
        attrs = {"use_calc_stream": True, "use_model_parallel": True}

        var_names = serial_op.output("Out")
        c_allreduce_sum_desc_mapping = build_comm_desc_from_dist_op(
            "c_allreduce_sum",
            dist_op,
            ctx,
            var_names,
            attrs=attrs,
            parallel_axis=parallel_axis,
        )

        comm_op_cost_list = build_comm_costs_from_descs(
            AllreduceSumOpCost,
            ctx,
            processes,
            c_allreduce_sum_desc_mapping,
            cluster,
        )
        res_cost = [cost_mapping, comm_op_cost_list]

        return res_cost

    def is_input_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        x_name = op_desc.input('X')[0]
        y_name = op_desc.input('Y')[0]
        x_dims_mapping = copy.deepcopy(
            op_dist_attr.get_input_dims_mapping(x_name)
        )
        y_dims_mapping = copy.deepcopy(
            op_dist_attr.get_input_dims_mapping(y_name)
        )
        trans_x = op_desc.attr('trans_x')
        trans_y = op_desc.attr('trans_y')
        trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping)
        if is_dim_replicate(x_dims_mapping[-1]):
            return False
        if is_dim_replicate(y_dims_mapping[-2]) or is_dim_shard(
            y_dims_mapping[-1]
        ):
            return False
        # Other dimensions must be replicate except the batch dimension
        for mapping in x_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_output_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
        if is_dim_shard(out_dims_mapping[-1]):
            return False
        # Other dimensions must be replicate except the batch dimension
        for mapping in out_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_auto_compatible(self, dist_op):
        if (not self.is_input_compatible(dist_op)) or (
            not self.is_output_compatible(dist_op)
        ):
            return False
        if not _is_auto_compatible_for_matmul(dist_op):
            return False
        return True

    def update_dims_mapping(self, dist_op):
        changed = False
        dim_changed = _update_dims_mapping_for_matmul(dist_op)
        if dim_changed:
            changed = True
        return changed

    @staticmethod
    def forward(ctx, *args, **kwargs):
        """
        kwargs: inputname_mapping & outputname_mapping
        """

        dist_op_context = ctx.dist_op_context
        main_block = dist_op_context.work_block
        startup_block = dist_op_context.startup_block
        src_op = dist_op_context.cur_src_op
        rank_id = dist_op_context.rank_id
        op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
        assert (
            op_dist_attr is not None
        ), "backward op [{}] don't have dist attribute !".format(str(src_op))

        # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
        if rank_id not in op_dist_attr.process_mesh.process_ids:
            rank_id = _get_corresponding_rank(
                ctx, op_dist_attr.process_mesh, rank_id
            )

        # validate inputs / outputs
        for input_name in src_op.desc.input_names():
            assert input_name in kwargs, "input [{}] is not given".format(
                input_name
            )
            assert len(kwargs[input_name]) == len(
                src_op.desc.input(input_name)
            ), "number of tensor for input [{}] is not match".format(input_name)
        for output_name in src_op.desc.output_names():
            assert output_name in kwargs, "output [{}] is not given".format(
                output_name
            )
            assert len(kwargs[output_name]) == len(
                src_op.desc.output(output_name)
            ), "number of tensor for input [{}] is not match".format(
                output_name
            )

        X_var = main_block._var_recursive(kwargs['X'][0])
        Weight_var = main_block._var_recursive(kwargs['Y'][0])
        Out_var = main_block._var_recursive(kwargs['Out'][0])
        trans_x = src_op.attr('trans_x')
        trans_y = src_op.attr('trans_y')

        # TODO infer logic comm presentation
        matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
            Weight_var.name
        )[-2]
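        # with trans_y the row axis of Y is its -1 dimension, so re-read the mapping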
        if trans_y:
            matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
                Weight_var.name
            )[-1]
        assert (
            matmul_row_dim_mapping >= 0
        ), "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
            matmul_row_dim_mapping
        )
        process_mesh_shape = op_dist_attr.process_mesh.shape
        process_mesh_group = op_dist_attr.process_mesh.process_ids

        parallel_axis = matmul_row_dim_mapping
        group_ranks = _get_comm_group(
            process_mesh_group, process_mesh_shape, parallel_axis, rank_id
        )
        group = new_process_group(group_ranks)

        check_variable_and_dtype(
            X_var, 'x', ['float16', 'float32', 'float64'], 'linear'
        )
        check_dtype(
            X_var.dtype, 'dtype', ['float16', 'float32', 'float64'], 'linear'
        )
        attrs = {
            'trans_x': trans_x,
            'trans_y': trans_y,
            OP_ROLE_KEY: src_op.attr('op_role'),
        }
        inputs = {'X': X_var, 'Y': Weight_var}

        # infer out var shape with op dist attr
        out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var)
        assert out_tensor_dist_attr is not None
        out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert out_var_dist_attr is not None
        ref_shape = infer_shape(
            main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr
        )

        intermediate_var_0 = main_block.create_var(
            name=unique_name.generate_with_ignorable_key(
                ".".join(["c_allreduce_sum", 'tmp'])
            ),
            shape=Out_var.shape,
            dtype=Out_var.dtype,
            type=Out_var.type,
            lod_level=Out_var.lod_level,
            persistable=False,
            is_data=False,
            need_check_feed=Out_var.desc.need_check_feed(),
        )
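        # intermediate_var_0 holds the rank-local partial product; the c_allreduce_sum
        # below reduces it into Out_var across the row-parallel group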
        # set intermediate_var_0's dist_attr with Out_var's dist_attr
        ctx.set_tensor_dist_attr_for_program(
            intermediate_var_0, out_var_dist_attr
        )

        matmul_v2_op = main_block.append_op(
            type='matmul_v2',
            inputs=inputs,
            outputs={'Out': intermediate_var_0},
            attrs=attrs,
        )
        if intermediate_var_0.shape != ref_shape:
            intermediate_var_0.desc.set_shape(ref_shape)

        c_allreduce_sum_op = main_block.append_op(
            type='c_allreduce_sum',
            inputs={'X': intermediate_var_0},
            outputs={'Out': Out_var},
            attrs={
                'ring_id': group.id,
                'use_calc_stream': True,
                'use_model_parallel': True,
                OP_ROLE_KEY: src_op.attr('op_role'),
            },
        )
        if Out_var.shape != ref_shape:
            Out_var.desc.set_shape(ref_shape)

        # set dist op's dist_attr with serial op's dist_attr
        # matmulv2
        matmulv2_op_dist_attr = OperatorDistributedAttribute()
        matmulv2_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        matmulv2_op_dist_attr.impl_type = op_dist_attr.impl_type
        matmulv2_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        for input_varname in matmul_v2_op.desc.input_arg_names():
            input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
            assert input_dist_attr is not None, "dist_attr is {}".format(
                op_dist_attr
            )
            matmulv2_op_dist_attr.set_input_dist_attr(
                input_varname, input_dist_attr
            )
        output_varname = matmul_v2_op.desc.output_arg_names()[0]
        output_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert output_dist_attr is not None, "dist_attr is {}".format(
            op_dist_attr
        )
        matmulv2_op_dist_attr.set_output_dist_attr(
            output_varname, output_dist_attr
        )
        ctx.set_op_dist_attr_for_program(matmul_v2_op, matmulv2_op_dist_attr)

        # allreduce
        allreduce_op_dist_attr = OperatorDistributedAttribute()
        allreduce_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        allreduce_op_dist_attr.impl_type = op_dist_attr.impl_type
        allreduce_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        for input_varname in c_allreduce_sum_op.desc.input_arg_names():
            input_var = main_block._var_recursive(input_varname)
            tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(input_var)
            assert tensor_dist_attr is not None
            allreduce_op_dist_attr.set_input_dist_attr(
                input_varname, tensor_dist_attr
            )
        for output_varname in c_allreduce_sum_op.desc.output_arg_names():
            output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
            assert output_dist_attr is not None, "dist_attr is {}".format(
                op_dist_attr
            )
            allreduce_op_dist_attr.set_output_dist_attr(
                output_varname, output_dist_attr
            )
        ctx.set_op_dist_attr_for_program(
            c_allreduce_sum_op, allreduce_op_dist_attr
        )

        # init param sync
        if Weight_var.is_parameter and not op_dist_attr.is_recompute:
            _init_param_sync(
                Weight_var, dist_op_context, startup_block, ctx, rank_id
            )

    @staticmethod
    def backward(ctx, *args, **kwargs):
        _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)


# ReplicateParallel
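# Fully replicated case: no tensor-parallel communication is required, so the forward
# simply falls back to DistributedDefaultImpl0.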
class DistributedMatmulV2Impl2(DistributedOperatorImpl):
    def __init__(self, name):
        super().__init__(name)

    def calc_cost(self, op_role, dist_op, ctx, cluster):
        cost = None
        if int(op_role) == int(OpRole.Forward):
            cost = self.calc_fwd_cost(dist_op, ctx, cluster)
        elif int(op_role) == int(OpRole.Backward):
            cost = self.calc_bwd_cost(dist_op, ctx, cluster)
        assert cost is not None
        return cost

    def calc_bwd_cost(self, dist_op, ctx, cluster):
        res = []
        backward_op = dist_op.serial_op
        dist_attr = dist_op.dist_attr
        main_block = backward_op.block
        process_mesh = dist_attr.process_mesh

        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        processes = process_mesh.process_ids
        cost_mapping = build_comp_costs_from_descs(
            MatmulV2GradOpCost, ctx, processes, desc_mapping, cluster
        )
        res.append(cost_mapping)

        # need gradient allreduce
        var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("X")[0]
        )
        mesh_shape = process_mesh.shape
        batch_size_axis = var_dim_mapping[0]
        if (
            batch_size_axis > -1
            and mesh_shape[batch_size_axis] > 1
            and is_parameter_related(backward_op.input("Y")[0], main_block)
        ):
            parallel_axis = batch_size_axis
            attrs = {"use_calc_stream": True}
            var_names = [backward_op.output('Y@GRAD')[0]]
            build_dp_costs(
                res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
            )

        return res

    def calc_fwd_cost(self, dist_op, ctx, cluster):
        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        processes = dist_op.dist_attr.process_mesh.process_ids
        cost_mapping = build_comp_costs_from_descs(
            MatmulV2OpCost, ctx, processes, desc_mapping, cluster
        )

        res_cost = [cost_mapping]

        return res_cost

    def is_input_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        x_name = op_desc.input('X')[0]
        y_name = op_desc.input('Y')[0]
        x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
        y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)

        if is_dim_shard(x_dims_mapping[-1]):
            return False
        if is_valid_list_index(x_dims_mapping, -2) and is_dim_shard(
            x_dims_mapping[-2]
        ):
            return False

        if is_dim_shard(y_dims_mapping[-1]):
            return False
        if is_valid_list_index(y_dims_mapping, -2) and is_dim_shard(
            y_dims_mapping[-2]
        ):
            return False
        return True

    def is_output_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)

        if is_dim_shard(out_dims_mapping[-1]):
            return False
        if is_valid_list_index(out_dims_mapping, -2) and is_dim_shard(
            out_dims_mapping[-2]
        ):
            return False

        return True

    def is_auto_compatible(self, dist_op):
        if (not self.is_input_compatible(dist_op)) or (
            not self.is_output_compatible(dist_op)
        ):
            return False

        if not _is_auto_compatible_for_matmul(dist_op):
            return False

        return True

    def update_dims_mapping(self, dist_op):
        changed = False
        dim_changed = _update_dims_mapping_for_matmul(dist_op)
        if dim_changed:
            changed = True
        return changed

    @staticmethod
    def forward(ctx, *args, **kwargs):
        DistributedDefaultImpl0.forward(ctx, *args, **kwargs)

    @staticmethod
    def backward(ctx, *args, **kwargs):
        _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)


register_distributed_operator_impl(
    "matmul_v2", DistributedMatmulV2Impl0("column_parallel")
)
register_distributed_operator_impl(
    "matmul_v2", DistributedMatmulV2Impl1("row_parallel")
)
register_distributed_operator_impl(
    "matmul_v2", DistributedMatmulV2Impl2("replicate_parallel")
)


class DistributedMul(DistributedOperatorImplContainer):
    def __init__(self, op_type):
        super().__init__(op_type)


register_distributed_operator_impl_container(DistributedMul("mul"))


# ColumnParallel
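# Same column-parallel scheme as DistributedMatmulV2Impl0, but for the legacy mul op,
# which flattens its inputs according to x_num_col_dims / y_num_col_dims.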
class DistributedMulImpl0(DistributedOperatorImpl):
    def __init__(self, name):
        super().__init__(name)
        self._forward_implemented = True
        self._backward_implemented = True

    def calc_cost(self, op_role, dist_op, ctx, cluster):
        cost = None
        if int(op_role) == int(OpRole.Forward):
            cost = self.calc_fwd_cost(dist_op, ctx, cluster)
        elif int(op_role) == int(OpRole.Backward):
            cost = self.calc_bwd_cost(dist_op, ctx, cluster)
        assert cost is not None
        return cost

    def calc_bwd_cost(self, dist_op, ctx, cluster):
        # by now the backward function only insert the gradient allreduce for dist op itself
        res = []
        backward_op = dist_op.serial_op
        dist_attr = dist_op.dist_attr
        main_block = backward_op.block
        Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("Y")[0]
        )
        # col parallel: matmul + allreduce
        assert Y_var_dim_mapping[0] < 0
        parallel_axis = Y_var_dim_mapping[1]

        has_x_grad = len(backward_op.output("X@GRAD")) > 0
        if has_x_grad:
            assert len(backward_op.output("X@GRAD")) == 1

        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        process_mesh = dist_attr.process_mesh
        processes = process_mesh.process_ids
        cost_mapping = build_comp_costs_from_descs(
            MulGradOpCost, ctx, processes, desc_mapping, cluster
        )
        res.append(cost_mapping)

        # calc comm op cost
        if has_x_grad:
            attrs = {"use_calc_stream": True, "use_model_parallel": True}
            var_names = backward_op.output("X@GRAD")
            c_allreduce_sum_desc_mapping = build_comm_desc_from_dist_op(
                "c_allreduce_sum",
                dist_op,
                ctx,
                var_names,
                attrs=attrs,
                parallel_axis=parallel_axis,
            )
            comm_op_cost_list = build_comm_costs_from_descs(
                AllreduceSumOpCost,
                ctx,
                processes,
                c_allreduce_sum_desc_mapping,
                cluster,
            )
            res.append(comm_op_cost_list)

        # need gradient allreduce
        var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("X")[0]
        )
        mesh_shape = process_mesh.shape
        batch_size_axis = var_dim_mapping[0]
        if (
            batch_size_axis > -1
            and mesh_shape[batch_size_axis] > 1
            and is_parameter_related(backward_op.input("Y")[0], main_block)
        ):
            parallel_axis = batch_size_axis
            attrs = {"use_calc_stream": True}
            var_names = [backward_op.output('Y@GRAD')[0]]
            build_dp_costs(
                res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
            )
        return res

    def calc_fwd_cost(self, dist_op, ctx, cluster):
        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        processes = dist_op.dist_attr.process_mesh.process_ids
        cost_mapping = build_comp_costs_from_descs(
            MulOpCost, ctx, processes, desc_mapping, cluster
        )

        # calc comm op cost
        serial_op = dist_op.serial_op
        parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
            serial_op.input("Y")[0]
        )[-1]
        attrs = {"use_calc_stream": True, "use_model_parallel": True}
        var_names = serial_op.input("X")
        c_identity_desc_mapping = build_comm_desc_from_dist_op(
            "c_identity",
            dist_op,
            ctx,
            var_names,
            attrs=attrs,
            parallel_axis=parallel_axis,
        )

        comm_op_cost_list = build_comm_costs_from_descs(
            IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster
        )
        res_cost = [comm_op_cost_list, cost_mapping]

        return res_cost

    def is_input_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        x_name = op_desc.input('X')[0]
        y_name = op_desc.input('Y')[0]
        x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
        y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
        if is_dim_shard(x_dims_mapping[-1]):
            return False
        if is_dim_shard(y_dims_mapping[-2]) or is_dim_replicate(
            y_dims_mapping[-1]
        ):
            return False
        for mapping in x_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_output_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
        if is_dim_replicate(out_dims_mapping[-1]):
            return False
        for mapping in out_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_auto_compatible(self, dist_op):
        if (not self.is_input_compatible(dist_op)) or (
            not self.is_output_compatible(dist_op)
        ):
            return False

        if not _is_auto_compatible_for_matmul(dist_op):
            return False

        return True

    def update_dims_mapping(self, dist_op):
        changed = False
        dim_changed = _update_dims_mapping_for_matmul(dist_op)
        if dim_changed:
            changed = True
        return changed

    @staticmethod
    def forward(ctx, *args, **kwargs):
        """
        kwargs: inputname_mapping & outputname_mapping
        """

        dist_op_context = ctx.dist_op_context
        main_block = dist_op_context.work_block
        startup_block = dist_op_context.startup_block
        src_op = dist_op_context.cur_src_op
        rank_id = dist_op_context.rank_id
        op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
        assert (
            op_dist_attr is not None
        ), "backward op [{}] don't have dist attribute !".format(str(src_op))

        # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
        if rank_id not in op_dist_attr.process_mesh.process_ids:
            rank_id = _get_corresponding_rank(
                ctx, op_dist_attr.process_mesh, rank_id
            )

        # validate inputs / outputs
        for input_name in src_op.desc.input_names():
            assert input_name in kwargs, "input [{}] is not given".format(
                input_name
            )
            assert len(kwargs[input_name]) == len(
                src_op.desc.input(input_name)
            ), "number of tensor for input [{}] is not match".format(input_name)
        for output_name in src_op.desc.output_names():
            assert output_name in kwargs, "output [{}] is not given".format(
                output_name
            )
            assert len(kwargs[output_name]) == len(
                src_op.desc.output(output_name)
            ), "number of tensor for input [{}] is not match".format(
                output_name
            )

        X_var = main_block._var_recursive(kwargs['X'][0])
        Weight_var = main_block._var_recursive(kwargs['Y'][0])
        Out_var = main_block._var_recursive(kwargs['Out'][0])

        # TODO infer logic comm presentation
        matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping(
            Weight_var.name
        )[-1]
        assert (
            matmul_col_dim_mapping >= 0
        ), "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
            matmul_col_dim_mapping
        )
        process_mesh_shape = op_dist_attr.process_mesh.shape
        process_mesh_group = op_dist_attr.process_mesh.process_ids

        parallel_axis = matmul_col_dim_mapping
        group_ranks = _get_comm_group(
            process_mesh_group, process_mesh_shape, parallel_axis, rank_id
        )
        group = new_process_group(group_ranks)

        # infer new var shape with op dist attr
        x_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(X_var)
        assert x_tensor_dist_attr is not None
        identity_var_dist_attr = op_dist_attr.get_input_dist_attr(X_var.name)
        assert identity_var_dist_attr is not None
        ref_shape_x = infer_shape(
            main_block, X_var, x_tensor_dist_attr, identity_var_dist_attr
        )
        # infer out var shape with op dist attr
        out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var)
        assert out_tensor_dist_attr is not None
        out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert out_var_dist_attr is not None
        ref_shape_out = infer_shape(
            main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr
        )

        intermediate_var_0 = main_block.create_var(
            name=unique_name.generate_with_ignorable_key(
                ".".join(["c_identity", 'tmp'])
            ),
            dtype=X_var.dtype,
            shape=X_var.shape,
            type=core.VarDesc.VarType.LOD_TENSOR,
            persistable=False,
            stop_gradient=X_var.stop_gradient,
        )
        # set intermediate_var_0's dist_attr with X_var's dist_attr
        ctx.set_tensor_dist_attr_for_program(
            intermediate_var_0, identity_var_dist_attr
        )

        check_variable_and_dtype(
            X_var,
            'tensor',
            ['float16', 'float32', 'float64', 'int32', 'int64'],
            '_c_identity',
        )
        c_identity_op = main_block.append_op(
            type='c_identity',
            inputs={'X': [X_var]},
            outputs={'Out': intermediate_var_0},
            attrs={
                'ring_id': group.id,
                'use_calc_stream': True,
                'use_model_parallel': True,
                OP_ROLE_KEY: src_op.attr('op_role'),
            },
        )
        if intermediate_var_0.shape != ref_shape_x:
            intermediate_var_0.desc.set_shape(ref_shape_x)

        check_variable_and_dtype(
            intermediate_var_0, 'x', ['float16', 'float32', 'float64'], 'linear'
        )
        check_dtype(
            intermediate_var_0.dtype,
            'dtype',
            ['float16', 'float32', 'float64'],
            'linear',
        )
        # attrs = {'trans_x': False, 'trans_y': False}
        attrs = {
            "x_num_col_dims": src_op.desc.attr("x_num_col_dims"),
            "y_num_col_dims": src_op.desc.attr("y_num_col_dims"),
            OP_ROLE_KEY: src_op.attr('op_role'),
        }
        inputs = {'X': intermediate_var_0, 'Y': Weight_var}
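        # append the mul op while X / Y temporarily carry the local shapes inferred
        # from their dist attributes; the original shapes are restored right after
        # the op is created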

        inputs_ref_shape = {}
        inputs_original_shape = {}
        for var_name in inputs:
            if var_name == "X":
                var = X_var
            else:
                var = inputs[var_name]
            inputs_original_shape[var_name] = var.shape
            input_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(var)
            input_var_dist_attr = op_dist_attr.get_input_dist_attr(var.name)
            input_ref_shape = infer_shape(
                main_block, var, input_tensor_dist_attr, input_var_dist_attr
            )
            inputs_ref_shape[var_name] = input_ref_shape
            var.desc.set_shape(input_ref_shape)

        mul_op = main_block.append_op(
            type='mul', inputs=inputs, outputs={'Out': Out_var}, attrs=attrs
        )
        if Out_var.shape != ref_shape_out:
            Out_var.desc.set_shape(ref_shape_out)

        for var_name in inputs:
            var = inputs[var_name]
            original_shape = inputs_original_shape[var_name]
            var.desc.set_shape(original_shape)

        # set dist op's dist_attr with serial op's dist_attr
        # c_identity
        identity_op_dist_attr = OperatorDistributedAttribute()
        identity_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        identity_op_dist_attr.impl_type = op_dist_attr.impl_type
        identity_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        # input
        input_varname = c_identity_op.desc.input_arg_names()[0]
        input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
        assert input_dist_attr is not None, "dist_attr is {}".format(
            op_dist_attr
        )
        identity_op_dist_attr.set_input_dist_attr(
            input_varname, input_dist_attr
        )
        # output
        output_varname = c_identity_op.desc.output_arg_names()[0]
        identity_op_dist_attr.set_output_dist_attr(
            output_varname, input_dist_attr
        )
        ctx.set_op_dist_attr_for_program(c_identity_op, identity_op_dist_attr)

        # matmulv2
        matmulv2_op_dist_attr = OperatorDistributedAttribute()
        matmulv2_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        matmulv2_op_dist_attr.impl_type = op_dist_attr.impl_type
        matmulv2_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        for input_varname in mul_op.desc.input_arg_names():
            if input_varname in src_op.desc.input_arg_names():
                input_dist_attr = op_dist_attr.get_input_dist_attr(
                    input_varname
                )
                assert input_dist_attr is not None, "dist_attr is {}".format(
                    op_dist_attr
                )
                matmulv2_op_dist_attr.set_input_dist_attr(
                    input_varname, input_dist_attr
                )
            else:
                input_var = main_block._var_recursive(input_varname)
                tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(
                    input_var
                )
                matmulv2_op_dist_attr.set_input_dist_attr(
                    input_varname, tensor_dist_attr
                )
        for output_varname in mul_op.desc.output_arg_names():
            output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
            assert output_dist_attr is not None, "dist_attr is {}".format(
                op_dist_attr
            )
            matmulv2_op_dist_attr.set_output_dist_attr(
                output_varname, output_dist_attr
            )
        ctx.set_op_dist_attr_for_program(mul_op, matmulv2_op_dist_attr)

        # init param sync
        if Weight_var.is_parameter and not op_dist_attr.is_recompute:
            _init_param_sync(
                Weight_var, dist_op_context, startup_block, ctx, rank_id
            )

    @staticmethod
    def backward(ctx, *args, **kwargs):
        _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)


# RowParallel
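# In the row-parallel implementation the weight is sharded along its rows (the
# reduction dimension) and X along its last dimension, so each rank computes a
# partial product that is then summed across the model-parallel group with
# c_allreduce_sum. For example, on a 1-D mesh of size 2, dims mappings of
# [-1, 0] for X and [0, -1] for Y select this implementation.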
class DistributedMulImpl1(DistributedOperatorImpl):
    def __init__(self, name):
        super().__init__(name)
        self._forward_implemented = True
        self._backward_implemented = True

    def calc_cost(self, op_role, dist_op, ctx, cluster):
        cost = None
        if int(op_role) == int(OpRole.Forward):
            cost = self.calc_fwd_cost(dist_op, ctx, cluster)
        elif int(op_role) == int(OpRole.Backward):
            cost = self.calc_bwd_cost(dist_op, ctx, cluster)
        assert cost is not None
        return cost

    def calc_bwd_cost(self, dist_op, ctx, cluster):
        # By now, the backward function only inserts the gradient allreduce for the dist op itself.
        res = []
        backward_op = dist_op.serial_op
        dist_attr = dist_op.dist_attr
        process_mesh = dist_attr.process_mesh
        main_block = backward_op.block
        Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("Y")[0]
        )
        assert Y_var_dim_mapping[1] < 0
        parallel_axis = Y_var_dim_mapping[0]

        # calc comm op cost
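        # The forward pass ends with a c_allreduce_sum on Out, so the backward
        # pass consumes Out@GRAD through a c_identity (the adjoint of that
        # allreduce); its communication cost is modeled with IdentityOpCost.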
        var_names = [backward_op.input("Out@GRAD")[0]]
        attrs = {"use_calc_stream": True, "use_model_parallel": True}
        c_identity_desc_mapping = build_comm_desc_from_dist_op(
            "c_identity",
            dist_op,
            ctx,
            var_names,
            attrs=attrs,
            parallel_axis=parallel_axis,
        )
        processes = process_mesh.process_ids
        comm_op_cost_list = build_comm_costs_from_descs(
            IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster
        )
        res.append(comm_op_cost_list)

        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        cost_mapping = build_comp_costs_from_descs(
            MulGradOpCost, ctx, processes, desc_mapping, cluster
        )
        res.append(cost_mapping)

        # need gradient allreduce
        var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("X")[0]
        )
        mesh_shape = process_mesh.shape
        batch_size_axis = var_dim_mapping[0]
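        # If the batch dimension of X is sharded over a mesh axis with more than
        # one process and Y is a parameter, a data-parallel allreduce of Y@GRAD
        # over that axis is needed as well; account for it with build_dp_costs.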
        if (
            batch_size_axis > -1
            and mesh_shape[batch_size_axis] > 1
            and is_parameter_related(backward_op.input("Y")[0], main_block)
        ):
            parallel_axis = batch_size_axis
            attrs = {"use_calc_stream": True}
            var_names = [backward_op.output('Y@GRAD')[0]]
            build_dp_costs(
                res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
            )
        return res

    def calc_fwd_cost(self, dist_op, ctx, cluster):
        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        processes = dist_op.dist_attr.process_mesh.process_ids
        cost_mapping = build_comp_costs_from_descs(
            MulOpCost, ctx, processes, desc_mapping, cluster
        )

        # calc comm op cost
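        # The model-parallel axis is the mesh axis onto which Y's row (reduction)
        # dimension is mapped; the partial outputs must be allreduce-summed over it.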
        serial_op = dist_op.serial_op
        parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
            serial_op.input("Y")[0]
        )[-2]
        attrs = {"use_calc_stream": True, "use_model_parallel": True}

        var_names = serial_op.output("Out")
        c_allreduce_sum_desc_mapping = build_comm_desc_from_dist_op(
            "c_allreduce_sum",
            dist_op,
            ctx,
            var_names,
            attrs=attrs,
            parallel_axis=parallel_axis,
        )

        # print("dist_matmul.py dist_op: ", dist_op)
        comm_op_cost_list = build_comm_costs_from_descs(
            AllreduceSumOpCost,
            ctx,
            processes,
            c_allreduce_sum_desc_mapping,
            cluster,
        )

        res_cost = [cost_mapping, comm_op_cost_list]

        return res_cost

    def is_input_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        x_name = op_desc.input('X')[0]
        y_name = op_desc.input('Y')[0]
        x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
        y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
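        # Row parallelism requires the reduction dimension to be sharded on both
        # operands: X's last dim and Y's second-to-last dim must be sharded,
        # while Y's last dim stays replicated.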
        if is_dim_replicate(x_dims_mapping[-1]):
            return False
        if is_dim_replicate(y_dims_mapping[-2]) or is_dim_shard(
            y_dims_mapping[-1]
        ):
            return False
        # Other dimensions must be replicated, except the batch dimension
        for mapping in x_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_output_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
        if is_dim_shard(out_dims_mapping[-1]):
            return False
        # Other dimensions must be replicated, except the batch dimension
        for mapping in out_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_auto_compatible(self, dist_op):
        if (not self.is_input_compatible(dist_op)) or (
            not self.is_output_compatible(dist_op)
        ):
            return False

        if not _is_auto_compatible_for_matmul(dist_op):
            return False

        return True

    def update_dims_mapping(self, dist_op):
        changed = False
        dim_changed = _update_dims_mapping_for_matmul(dist_op)
        if dim_changed:
            changed = True
        return changed

    @staticmethod
    def forward(ctx, *args, **kwargs):
        """
        kwargs: inputname_mapping & outputname_mapping
        """

        dist_op_context = ctx.dist_op_context
        main_block = dist_op_context.work_block
        startup_block = dist_op_context.startup_block
        src_op = dist_op_context.cur_src_op
        rank_id = dist_op_context.rank_id
        op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
        assert (
            op_dist_attr is not None
        ), "backward op [{}] don't have dist attribute !".format(str(src_op))

        # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
        if rank_id not in op_dist_attr.process_mesh.process_ids:
            rank_id = _get_corresponding_rank(
                ctx, op_dist_attr.process_mesh, rank_id
            )

        # check validation of inputs / outputs
        for input_name in src_op.desc.input_names():
            assert input_name in kwargs, "input [{}] is not given".format(
                input_name
            )
            assert len(kwargs[input_name]) == len(
                src_op.desc.input(input_name)
            ), "number of tensor for input [{}] is not match".format(input_name)
        for output_name in src_op.desc.output_names():
            assert output_name in kwargs, "output [{}] is not given".format(
                output_name
            )
            assert len(kwargs[output_name]) == len(
                src_op.desc.output(output_name)
            ), "number of tensor for input [{}] is not match".format(
                output_name
            )

        X_var = main_block._var_recursive(kwargs['X'][0])
        Weight_var = main_block._var_recursive(kwargs['Y'][0])
        Out_var = main_block._var_recursive(kwargs['Out'][0])

        # TODO infer logic comm presentation
        matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
            Weight_var.name
        )[-2]
        assert (
            matmul_row_dim_mapping >= 0
        ), "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
            matmul_row_dim_mapping
        )
        process_mesh_shape = op_dist_attr.process_mesh.shape
        process_mesh_group = op_dist_attr.process_mesh.process_ids

        parallel_axis = matmul_row_dim_mapping
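        # The allreduce group consists of the ranks lying along the mesh axis on
        # which the weight's row dimension is sharded.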
        group_ranks = _get_comm_group(
            process_mesh_group, process_mesh_shape, parallel_axis, rank_id
        )
        group = new_process_group(group_ranks)

        check_variable_and_dtype(
            X_var, 'x', ['float16', 'float32', 'float64'], 'linear'
        )
        check_dtype(
            X_var.dtype, 'dtype', ['float16', 'float32', 'float64'], 'linear'
        )
        # attrs = {'trans_x': False, 'trans_y': False}
        attrs = {
            "x_num_col_dims": src_op.desc.attr("x_num_col_dims"),
            "y_num_col_dims": src_op.desc.attr("y_num_col_dims"),
            OP_ROLE_KEY: src_op.attr('op_role'),
        }
        inputs = {'X': X_var, 'Y': Weight_var}

        # infer out var shape with op dist attr
        out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var)
        assert out_tensor_dist_attr is not None
        out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert out_var_dist_attr is not None
        ref_shape = infer_shape(
            main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr
        )

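        # intermediate_var_0 holds this rank's partial product from the mul op;
        # it is reduced into Out_var by the c_allreduce_sum appended below.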
        intermediate_var_0 = main_block.create_var(
            name=unique_name.generate_with_ignorable_key(
                ".".join(["c_allreduce_sum", 'tmp'])
            ),
            shape=Out_var.shape,
            dtype=Out_var.dtype,
            type=Out_var.type,
            lod_level=Out_var.lod_level,
            persistable=False,
            is_data=False,
            need_check_feed=Out_var.desc.need_check_feed(),
        )
        # set intermediate_var_0's dist_attr with Out_var's dist_attr
        ctx.set_tensor_dist_attr_for_program(
            intermediate_var_0, out_var_dist_attr
        )

        inputs_ref_shape = {}
        inputs_original_shape = {}
        for var_name in inputs:
            var = inputs[var_name]
            inputs_original_shape[var_name] = var.shape
            input_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(var)
            input_var_dist_attr = op_dist_attr.get_input_dist_attr(var.name)
            input_ref_shape = infer_shape(
                main_block, var, input_tensor_dist_attr, input_var_dist_attr
            )
            inputs_ref_shape[var_name] = input_ref_shape
            var.desc.set_shape(input_ref_shape)

        mul_op = main_block.append_op(
            type='mul',
            inputs=inputs,
            outputs={'Out': intermediate_var_0},
            attrs=attrs,
        )

        if intermediate_var_0.shape != ref_shape:
            intermediate_var_0.desc.set_shape(ref_shape)

        for var_name in inputs:
            var = inputs[var_name]
            original_shape = inputs_original_shape[var_name]
            var.desc.set_shape(original_shape)

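        # Sum the rank-local partial results across the model-parallel group to
        # produce the complete Out_var.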
        c_allreduce_sum_op = main_block.append_op(
            type='c_allreduce_sum',
            inputs={'X': intermediate_var_0},
            outputs={'Out': Out_var},
            attrs={
                'ring_id': group.id,
                'use_calc_stream': True,
                'use_model_parallel': True,
                OP_ROLE_KEY: src_op.attr('op_role'),
            },
        )

        if Out_var.shape != ref_shape:
            Out_var.desc.set_shape(ref_shape)

        # set dist op's dist_attr with serial op's dist_attr
        # matmulv2
        matmulv2_op_dist_attr = OperatorDistributedAttribute()
        matmulv2_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        matmulv2_op_dist_attr.impl_type = op_dist_attr.impl_type
        matmulv2_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        for input_varname in mul_op.desc.input_arg_names():
            input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
            assert input_dist_attr is not None, "dist_attr is {}".format(
                op_dist_attr
            )
            matmulv2_op_dist_attr.set_input_dist_attr(
                input_varname, input_dist_attr
            )
        output_varname = mul_op.desc.output_arg_names()[0]
        output_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert output_dist_attr is not None, "dist_attr is {}".format(
            op_dist_attr
        )
        matmulv2_op_dist_attr.set_output_dist_attr(
            output_varname, output_dist_attr
        )
        ctx.set_op_dist_attr_for_program(mul_op, matmulv2_op_dist_attr)

        # allreduce
        allreduce_op_dist_attr = OperatorDistributedAttribute()
        allreduce_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        allreduce_op_dist_attr.impl_type = op_dist_attr.impl_type
        allreduce_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        for input_varname in c_allreduce_sum_op.desc.input_arg_names():
            input_var = main_block._var_recursive(input_varname)
            tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(input_var)
            assert tensor_dist_attr is not None
            allreduce_op_dist_attr.set_input_dist_attr(
                input_varname, tensor_dist_attr
            )
        for output_varname in c_allreduce_sum_op.desc.output_arg_names():
            output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
            assert output_dist_attr is not None, "dist_attr is {}".format(
3089 3090 3091 3092 3093 3094 3095 3096
                op_dist_attr
            )
            allreduce_op_dist_attr.set_output_dist_attr(
                output_varname, output_dist_attr
            )
        ctx.set_op_dist_attr_for_program(
            c_allreduce_sum_op, allreduce_op_dist_attr
        )

        # init param sync
        if Weight_var.is_parameter and not op_dist_attr.is_recompute:
            _init_param_sync(
                Weight_var, dist_op_context, startup_block, ctx, rank_id
            )

    @staticmethod
    def backward(ctx, *args, **kwargs):
        _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)


# ReplicateParallel
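# In the replicate-parallel implementation neither operand is sharded along the
# reduction dimension, so the forward simply reuses DistributedDefaultImpl0 and
# the backward only needs the usual data-parallel gradient synchronization.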
class DistributedMulImpl2(DistributedOperatorImpl):
    def __init__(self, name):
        super().__init__(name)

    def calc_cost(self, op_role, dist_op, ctx, cluster):
        cost = None
        if int(op_role) == int(OpRole.Forward):
            cost = self.calc_fwd_cost(dist_op, ctx, cluster)
        elif int(op_role) == int(OpRole.Backward):
            cost = self.calc_bwd_cost(dist_op, ctx, cluster)
        assert cost is not None
        return cost

    def calc_bwd_cost(self, dist_op, ctx, cluster):
        res = []
        backward_op = dist_op.serial_op
        dist_attr = dist_op.dist_attr
        main_block = backward_op.block

        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        process_mesh = dist_attr.process_mesh
        processes = process_mesh.process_ids
        cost_mapping = build_comp_costs_from_descs(
            MulGradOpCost, ctx, processes, desc_mapping, cluster
        )
        res.append(cost_mapping)

        # need gradient allreduce
        var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("X")[0]
        )
        mesh_shape = process_mesh.shape
        batch_size_axis = var_dim_mapping[0]
        if (
            batch_size_axis > -1
            and mesh_shape[batch_size_axis] > 1
            and is_parameter_related(backward_op.input("Y")[0], main_block)
        ):
            parallel_axis = batch_size_axis
            attrs = {"use_calc_stream": True}
            var_names = [backward_op.output('Y@GRAD')[0]]
            build_dp_costs(
                res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
            )

        return res

    def calc_fwd_cost(self, dist_op, ctx, cluster):
        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        processes = dist_op.dist_attr.process_mesh.process_ids
        cost_mapping = build_comp_costs_from_descs(
            MulOpCost, ctx, processes, desc_mapping, cluster
        )

        res_cost = [cost_mapping]
        return res_cost

    def is_input_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        x_name = op_desc.input('X')[0]
        y_name = op_desc.input('Y')[0]
        x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
        y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)

        if is_dim_shard(x_dims_mapping[-1]):
            return False
        if is_valid_list_index(x_dims_mapping, -2) and is_dim_shard(
            x_dims_mapping[-2]
        ):
            return False
        if is_dim_shard(y_dims_mapping[-1]):
            return False
        if is_valid_list_index(y_dims_mapping, -2) and is_dim_shard(
            y_dims_mapping[-2]
        ):
            return False
        return True

    def is_output_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)

        if is_dim_shard(out_dims_mapping[-1]):
            return False
        if is_valid_list_index(out_dims_mapping, -2) and is_dim_shard(
            out_dims_mapping[-2]
        ):
            return False

        return True

    def is_auto_compatible(self, dist_op):
        if (not self.is_input_compatible(dist_op)) or (
            not self.is_output_compatible(dist_op)
        ):
            return False

        if not _is_auto_compatible_for_matmul(dist_op):
            return False

        return True

    def update_dims_mapping(self, dist_op):
        changed = False
        dim_changed = _update_dims_mapping_for_matmul(dist_op)
        if dim_changed:
            changed = True
        return changed

    @staticmethod
    def forward(ctx, *args, **kwargs):
        DistributedDefaultImpl0.forward(ctx, *args, **kwargs)

    @staticmethod
    def backward(ctx, *args, **kwargs):
        _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)


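# Register the three implementations for the serial 'mul' op: column-parallel
# (weight sharded along its columns), row-parallel (weight sharded along its
# rows), and fully replicated.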
register_distributed_operator_impl(
    "mul", DistributedMulImpl0("column_parallel")
)
register_distributed_operator_impl("mul", DistributedMulImpl1("row_parallel"))
register_distributed_operator_impl(
    "mul", DistributedMulImpl2("replicate_parallel")
)