# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License

import copy

from .common import infer_shape
from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl
from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl
from .common import gradient_synchronization
from .common import (
    set_comm_op_dist_attr_for_program,
    naive_copy_op_dist_attr_for_program,
    is_parameter_related,
)
from ..utils import is_dim_shard
from ..utils import is_dim_replicate
from ..utils import is_valid_list_index
from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
from ..utils import set_dist_op_desc_original_id
from ..dist_attribute import OperatorDistributedAttribute
from paddle.fluid import core, unique_name
from paddle.fluid.framework import _non_static_mode
from paddle.fluid.framework import Program, Parameter, Variable, program_guard
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from paddle.distributed.fleet.meta_optimizers.common import (
    OpRole,
    OP_ROLE_KEY,
    OP_ROLE_VAR_KEY,
)
from ..process_group import new_process_group
from ..utils import _get_comm_group, _get_corresponding_rank
from .dist_default import DistributedDefaultImpl0
from ..cost import (
    build_comp_desc_from_dist_op,
    build_comm_desc_from_dist_op,
    build_dp_costs,
)
from ..cost import build_comm_costs_from_descs, build_comp_costs_from_descs
from ..cost import MatmulV2OpCost, MatmulOpCost, MulOpCost
from ..cost import MatmulV2GradOpCost, MatmulGradOpCost, MulGradOpCost
from paddle.distributed.auto_parallel.cost.comm_op_cost import (
    AllreduceSumOpCost,
    IdentityOpCost,
)


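# Swap the last two entries of a dims_mapping in place when the corresponding
# matmul operand is transposed. Illustrative example (assumed values): with
# trans_x=True and x_dims_mapping == [-1, 0, 1], the mapping becomes
# [-1, 1, 0].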
def trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping):
    if trans_x:
        x_dims_mapping[-1], x_dims_mapping[-2] = (
            x_dims_mapping[-2],
            x_dims_mapping[-1],
        )
    if trans_y:
        y_dims_mapping[-1], y_dims_mapping[-2] = (
            y_dims_mapping[-2],
            y_dims_mapping[-1],
        )


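# Append a copy of src_op to the given block and rebind its inputs/outputs to
# the variable names supplied in kwargs (one entry per input/output slot name
# of the original op).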
def copy_op_with_new_input_output(ctx, block, src_op, **kwargs):
    dist_op_desc = block.append_op(type='nop').desc
    dist_op_desc.copy_from(src_op.desc)
    set_dist_op_desc_original_id(dist_op_desc, src_op.desc, ctx)
    for input_name in src_op.desc.input_names():
        assert input_name in kwargs
        dist_op_desc.set_input(input_name, kwargs[input_name])
    for output_name in src_op.desc.output_names():
        assert output_name in kwargs
        dist_op_desc.set_output(output_name, kwargs[output_name])

    return dist_op_desc


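# Dims-mapping inference for matmul-like ops. A dims_mapping assigns each
# tensor dimension to a mesh axis (-1 means replicated). Illustrative example
# (assumed values): X [0, -1] x Y [-1, 1] -> Out [0, 1] is consistent, since
# the contraction axes X[-1] and Y[-2] agree and the output axes follow
# X[-2] and Y[-1].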
def _update_dims_mapping_for_matmul(dist_op):
    changed = False
    op_desc = dist_op.serial_op.desc
    op_dist_attr = dist_op.dist_attr
    x_name = op_desc.input('X')[0]
    y_name = op_desc.input('Y')[0]
    out_name = op_desc.output('Out')[0]
    trans_x = None
    trans_y = None
    if op_desc.type() == "matmul_v2":
        trans_x = op_desc.attr('trans_x')
        trans_y = op_desc.attr('trans_y')
    elif op_desc.type() == "matmul":
        trans_x = op_desc.attr('transpose_X')
        trans_y = op_desc.attr('transpose_Y')
    x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
    y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
    out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
    x_dims_mapping_len = len(x_dims_mapping)
    y_dims_mapping_len = len(y_dims_mapping)
    out_dims_mapping_len = len(out_dims_mapping)

    # Add dim mappings to make sure the length of each dims_mapping is at least 2
    if x_dims_mapping_len == 1:
        assert trans_x is False
        x_dims_mapping.insert(0, -1)
        out_dims_mapping.insert(out_dims_mapping_len - 1, 0)
    if y_dims_mapping_len == 1:
        assert trans_y is False
        y_dims_mapping.insert(1, -1)
        out_dims_mapping.insert(out_dims_mapping_len, 0)

    trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping)

    new_x_dims_mapping_len = len(x_dims_mapping)
    new_y_dims_mapping_len = len(y_dims_mapping)
    new_out_dims_mapping_len = len(out_dims_mapping)
    # Deal with dim > 2 and take care of broadcasting
    if new_out_dims_mapping_len > 2:
        broadcast_x_dims_mapping = []
        broadcast_y_dims_mapping = []
        broadcast_out_dims_mapping = []

        for i in range(new_out_dims_mapping_len - new_x_dims_mapping_len):
            broadcast_x_dims_mapping.append(out_dims_mapping[i])
        for i in range(new_x_dims_mapping_len - 2):
            broadcast_x_dims_mapping.append(x_dims_mapping[i])

        for i in range(new_out_dims_mapping_len - new_y_dims_mapping_len):
            broadcast_y_dims_mapping.append(out_dims_mapping[i])
        for i in range(new_y_dims_mapping_len - 2):
            broadcast_y_dims_mapping.append(y_dims_mapping[i])

        for i in range(new_out_dims_mapping_len - 2):
            broadcast_out_dims_mapping.append(out_dims_mapping[i])

        compatible_dims_mapping = compute_compatible_dims_mapping(
            [
                broadcast_x_dims_mapping,
                broadcast_y_dims_mapping,
                broadcast_out_dims_mapping,
            ]
        )
        if compatible_dims_mapping is None:
            trans_x_y_dims_mapping(
                trans_x, trans_y, x_dims_mapping, y_dims_mapping
            )
            return False

        for i in range(new_x_dims_mapping_len - 2):
            new_idx = i + (out_dims_mapping_len - new_x_dims_mapping_len)
            if x_dims_mapping[i] != compatible_dims_mapping[new_idx]:
                x_dims_mapping[i] = compatible_dims_mapping[new_idx]
                changed = True

        for i in range(new_y_dims_mapping_len - 2):
            new_idx = i + (out_dims_mapping_len - new_y_dims_mapping_len)
            if y_dims_mapping[i] != compatible_dims_mapping[new_idx]:
                y_dims_mapping[i] = compatible_dims_mapping[new_idx]
                changed = True

        for i in range(new_out_dims_mapping_len - 2):
            if out_dims_mapping[i] != compatible_dims_mapping[i]:
                out_dims_mapping[i] = compatible_dims_mapping[i]
                changed = True

    # The following, which uses negative indices, works both when
    # len(out_dims_mapping) > 2 and when len(out_dims_mapping) <= 2
    dim_changed = compute_compatible_and_update_dim_mapping(
        [x_dims_mapping, y_dims_mapping], [-1, -2]
    )
    if dim_changed:
        changed = True

    dim_changed = compute_compatible_and_update_dim_mapping(
        [x_dims_mapping, out_dims_mapping], [-2, -2]
    )
    if dim_changed:
        changed = True

    dim_changed = compute_compatible_and_update_dim_mapping(
        [y_dims_mapping, out_dims_mapping], [-1, -1]
    )
    if dim_changed:
        changed = True

    trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping)

    # Remove the inserted dim mappings so that the length of each dims_mapping
    # matches the rank of its tensor again
    if x_dims_mapping_len == 1:
        x_dims_mapping.pop(0)
        out_dims_mapping.pop(out_dims_mapping_len - 1)
    if y_dims_mapping_len == 1:
        y_dims_mapping.pop(1)
        out_dims_mapping.pop(out_dims_mapping_len)

    assert len(x_dims_mapping) == x_dims_mapping_len
    assert len(y_dims_mapping) == y_dims_mapping_len
    assert len(out_dims_mapping) == out_dims_mapping_len

    return changed


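# Read-only counterpart of _update_dims_mapping_for_matmul: checks whether the
# annotated dims_mappings of X, Y and Out are already mutually consistent
# (equal broadcast dims, aligned contraction and output dims) without
# modifying them.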
def _is_auto_compatible_for_matmul(dist_op):
    op_desc = dist_op.serial_op.desc
    op_dist_attr = dist_op.dist_attr
    x_name = op_desc.input('X')[0]
    y_name = op_desc.input('Y')[0]
    out_name = op_desc.output('Out')[0]
    trans_x = None
    trans_y = None
    if op_desc.type() == "matmul_v2":
        trans_x = op_desc.attr('trans_x')
        trans_y = op_desc.attr('trans_y')
    elif op_desc.type() == "matmul":
        trans_x = op_desc.attr('transpose_X')
        trans_y = op_desc.attr('transpose_Y')

    # Deep copy these dims_mappings to keep the originals unchanged.
    x_dims_mapping = copy.deepcopy(op_dist_attr.get_input_dims_mapping(x_name))
    y_dims_mapping = copy.deepcopy(op_dist_attr.get_input_dims_mapping(y_name))
    out_dims_mapping = copy.deepcopy(
        op_dist_attr.get_output_dims_mapping(out_name)
    )
    x_dims_mapping_len = len(x_dims_mapping)
    y_dims_mapping_len = len(y_dims_mapping)
    out_dims_mapping_len = len(out_dims_mapping)

    # Add dim mappings to make sure the length of each dims_mapping is at least 2
    if x_dims_mapping_len == 1:
        x_dims_mapping.insert(0, -1)
    if y_dims_mapping_len == 1:
        y_dims_mapping.insert(1, -1)

    trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping)

    # Deal with dim > 2 and take care of broadcasting
    if out_dims_mapping_len > 2:
        broadcast_x_dims_mapping = []
        broadcast_y_dims_mapping = []
        broadcast_out_dims_mapping = []

        for i in range(out_dims_mapping_len - x_dims_mapping_len):
            broadcast_x_dims_mapping.append(out_dims_mapping[i])
        for i in range(x_dims_mapping_len - 2):
            broadcast_x_dims_mapping.append(x_dims_mapping[i])

        for i in range(out_dims_mapping_len - y_dims_mapping_len):
            broadcast_y_dims_mapping.append(out_dims_mapping[i])
        for i in range(y_dims_mapping_len - 2):
            broadcast_y_dims_mapping.append(y_dims_mapping[i])

        for i in range(out_dims_mapping_len - 2):
            broadcast_out_dims_mapping.append(out_dims_mapping[i])

        is_same = (broadcast_x_dims_mapping == broadcast_y_dims_mapping) and (
            broadcast_x_dims_mapping == broadcast_out_dims_mapping
        )
        if not is_same:
            return False

    # The following, which uses negative indices, works both when
    # len(out_dims_mapping) > 2 and when len(out_dims_mapping) <= 2
    is_same = x_dims_mapping[-1] == y_dims_mapping[-2]
    if not is_same:
        return False

    is_same = x_dims_mapping[-2] == out_dims_mapping[-2]
    if not is_same:
        return False

    is_same = y_dims_mapping[-1] == out_dims_mapping[-1]
    if not is_same:
        return False

    return True


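# Shared backward for the matmul impls below. Sketch of what it emits,
# depending on how the parameter operand Y is sharded:
#   * Y sharded along its rows   : c_identity on Out@GRAD, then the local grad op
#   * Y sharded along its columns: local grad op, then c_allreduce_sum on X@GRAD
#   * Y replicated               : a plain copy of the serial grad op
# followed by data-parallel gradient synchronization for Y@GRAD when needed.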
def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):

    # by now the backward function only inserts the gradient allreduce for the dist op itself

    dist_op_context = ctx.dist_op_context
    main_block = dist_op_context.work_block
    backward_op = dist_op_context.cur_src_op
    rank_id = dist_op_context.rank_id
    dist_attr = ctx.get_op_dist_attr_for_program(backward_op)
    assert (
        dist_attr is not None
    ), "backward op [{}] does not have a dist attribute!".format(str(backward_op))

    # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
    if rank_id not in dist_attr.process_mesh.processes:
        rank_id = _get_corresponding_rank(ctx, dist_attr.process_mesh, rank_id)

    assert 'Y' in kwargs, "input [{}] is not given".format('Y')
    assert 'X' in kwargs, "input [{}] is not given".format('X')
    assert 'Out@GRAD' in kwargs, "input [{}] is not given".format('Out@GRAD')
    assert 'Y@GRAD' in kwargs, "output [{}] is not given".format('Y@GRAD')
    assert 'X@GRAD' in kwargs, "output [{}] is not given".format('X@GRAD')
    assert (
        len(kwargs['Y']) == 1
    ), "input Y takes 1 variable but got {}".format(kwargs['Y'])
    assert (
        len(kwargs['X']) == 1
    ), "input X takes 1 variable but got {}".format(kwargs['X'])
    assert (
        len(kwargs['Out@GRAD']) == 1
    ), "input Out@GRAD takes 1 variable but got {}".format(kwargs['Out@GRAD'])
    assert (
        len(kwargs['Y@GRAD']) == 1
    ), "output Y@GRAD takes 1 variable but got {}".format(kwargs['Y@GRAD'])

    X_var = main_block.var(kwargs['X'][0])
    Y_var = main_block._var_recursive(kwargs['Y'][0])
    Out_grad = main_block.var(kwargs['Out@GRAD'][0])
    Y_grad = main_block.var(kwargs['Y@GRAD'][0])

    assert not is_parameter_related(
        X_var.name, main_block
    ), "left operand(X) [{}] of dist matmul should not be a parameter".format(
        X_var.name
    )

    X_var_dims_mapping = dist_attr.get_input_dims_mapping(X_var.name)
    Y_var_dim_mapping = dist_attr.get_input_dims_mapping(Y_var.name)
    process_mesh_shape = dist_attr.process_mesh.topology
    process_mesh_group = dist_attr.process_mesh.processes

    trans_x = None
    trans_y = None
    if backward_op.desc.type() == "matmul_v2_grad":
        trans_x = backward_op.desc.attr('trans_x')
        trans_y = backward_op.desc.attr('trans_y')
    elif backward_op.desc.type() == "matmul_grad":
        trans_x = backward_op.desc.attr('transpose_X')
        trans_y = backward_op.desc.attr('transpose_Y')

    if trans_y:
        trans_x_y_dims_mapping(False, True, None, Y_var_dim_mapping)

    # assert len(
    #     Y_var_dim_mapping
    # ) == 2, "dist matmul only supports Y operand with 2 dims now but Y({})'s dim is [{}]".format(
    #     Y_var.name, Y_var_dim_mapping)
    Y_var_partitioned = False
    for dim in Y_var_dim_mapping:
        if dim >= 0 and process_mesh_shape[dim] > 0:
            Y_var_partitioned = True
            break

    if is_parameter_related(Y_var.name, main_block) and Y_var_partitioned:

        if Y_var_dim_mapping[0] >= 0:
            # row parallel: c_identity + matmul
            assert Y_var_dim_mapping[1] < 0
            parallel_axis = Y_var_dim_mapping[0]

            check_variable_and_dtype(
                Out_grad,
                'tensor',
                ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
                '_c_identity',
            )

            intermediate_var_0 = main_block.create_var(
                name=unique_name.generate_with_ignorable_key(
                    ".".join(["c_identity", 'tmp'])
                )
                + "@GRAD",
                dtype=Out_grad.dtype,
                shape=Out_grad.shape,
                type=core.VarDesc.VarType.LOD_TENSOR,
                persistable=False,
                stop_gradient=Out_grad.stop_gradient,
            )

            # copy X_var's dist_attr to intermediate_var_0's dist_attr
            out_grad_dist_attr = dist_attr.get_input_dist_attr(Out_grad.name)
            assert out_grad_dist_attr is not None
            ctx.set_tensor_dist_attr_for_program(
                intermediate_var_0, out_grad_dist_attr
            )

            group_ranks = _get_comm_group(
                process_mesh_group, process_mesh_shape, parallel_axis, rank_id
            )
            group = new_process_group(group_ranks)
            c_identity_op = main_block.append_op(
                type='c_identity',
                inputs={'X': [Out_grad]},
                outputs={'Out': intermediate_var_0},
                attrs={
                    'ring_id': group.id,
                    'use_calc_stream': True,
                    'use_model_parallel': True,
                    OP_ROLE_KEY: OpRole.Backward,
                },
            )
            check_variable_and_dtype(
                intermediate_var_0,
                'x',
                ['float16', 'float32', 'float64', 'uint16'],
                'linear',
            )
            check_dtype(
                intermediate_var_0.dtype,
                'dtype',
                ['float16', 'float32', 'float64', 'uint16'],
                'linear',
            )
            set_comm_op_dist_attr_for_program(
                c_identity_op, dist_attr.process_mesh, out_grad_dist_attr, ctx
            )

            new_kwargs = copy.deepcopy(kwargs)
            new_kwargs['Out@GRAD'] = [intermediate_var_0.name]
            matmul_op_desc = copy_op_with_new_input_output(
                ctx, main_block, backward_op, **new_kwargs
            )
        else:
            # col parallel: matmul + allreduce
            assert Y_var_dim_mapping[0] < 0
            parallel_axis = Y_var_dim_mapping[1]
            new_kwargs = copy.deepcopy(kwargs)

            # NOTE (JZ-LIANG) should allow the left operand to be empty for matmul grad
            has_x_grad = len(kwargs['X@GRAD']) > 0
            if has_x_grad:
                assert len(kwargs['X@GRAD']) == 1
                X_grad = main_block.var(kwargs['X@GRAD'][0])
                intermediate_var_0 = main_block.create_var(
                    name=unique_name.generate_with_ignorable_key(
                        ".".join(["c_identity", 'tmp'])
                    )
                    + "@GRAD",
                    dtype=X_grad.dtype,
                    shape=X_grad.shape,
                    type=core.VarDesc.VarType.LOD_TENSOR,
                    persistable=False,
                    stop_gradient=X_grad.stop_gradient,
                )

                X_grad_dist_attr = dist_attr.get_output_dist_attr(X_grad.name)
                assert X_grad_dist_attr is not None
                ctx.set_tensor_dist_attr_for_program(
                    intermediate_var_0, X_grad_dist_attr
                )
                new_kwargs['X@GRAD'] = [intermediate_var_0.name]

            matmul_op_desc = copy_op_with_new_input_output(
                ctx, main_block, backward_op, **new_kwargs
            )

            # NOTE (JZ-LIANG) trick to skip one allreduce if the left operand has no grad
            if has_x_grad:
                group_ranks = _get_comm_group(
                    process_mesh_group,
                    process_mesh_shape,
                    parallel_axis,
                    rank_id,
                )
                group = new_process_group(group_ranks)
                c_allreduce_sum_op = main_block.append_op(
                    type='c_allreduce_sum',
                    inputs={'X': [intermediate_var_0.name]},
                    outputs={'Out': kwargs['X@GRAD']},
                    attrs={
                        'ring_id': group.id,
                        'use_calc_stream': True,
                        'use_model_parallel': True,
                        OP_ROLE_KEY: OpRole.Backward,
                    },
                )
                set_comm_op_dist_attr_for_program(
                    c_allreduce_sum_op,
                    dist_attr.process_mesh,
                    X_grad_dist_attr,
                    ctx,
                )
    else:
        # replicate
        matmul_op_desc = copy_op_with_new_input_output(
            ctx, main_block, backward_op, **kwargs
        )

    # data parallel gradient synchronization
    act_grad_names = [X_var.name]

    out_grad_names = []
    if is_parameter_related(Y_var.name, main_block):
        out_grad_names = [kwargs['Y@GRAD'][0]]

    if trans_x:
        trans_x_y_dims_mapping(True, False, X_var_dims_mapping, None)

    gradient_synchronization(
        ctx, backward_op, act_grad_names, out_grad_names, rank_id
    )

    if trans_x:
        trans_x_y_dims_mapping(True, False, X_var_dims_mapping, None)
    if trans_y:
        trans_x_y_dims_mapping(False, True, None, Y_var_dim_mapping)


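# Broadcast a freshly initialized parameter from rank 0 along every mesh axis
# it is replicated on, so all ranks start from identical weights; each
# parameter is synchronized at most once.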
def _init_param_sync(Weight_var, dist_op_context, startup_block, ctx, rank_id):

    if Weight_var.name in dist_op_context.already_init_sync_vars:
        return
    assert startup_block.has_var(Weight_var.name)
    dist_op_context.already_init_sync_vars.add(Weight_var.name)
    param = startup_block.var(Weight_var.name)
    param_dist_attr = ctx.get_tensor_dist_attr_for_program(param)
    process_mesh = param_dist_attr.process_mesh
    dim_mapping = param_dist_attr.dims_mapping

    for axis, size in enumerate(process_mesh.topology):
        if size <= 1 or axis in dim_mapping:
            pass
        else:
            group_ranks = _get_comm_group(
                process_mesh.processes, process_mesh.topology, axis, rank_id
            )
            sync_group = new_process_group(group_ranks)

            startup_block.append_op(
                type='c_broadcast',
                inputs={'X': param},
                outputs={'Out': param},
                attrs={
                    'ring_id': sync_group.id,
                    'root': 0,
                    'use_calc_stream': True,
                    OP_ROLE_KEY: OpRole.Forward,
                },
            )


class DistributedMatmul(DistributedOperatorImplContainer):
    def __init__(self, op_type):
        super(DistributedMatmul, self).__init__(op_type)


register_distributed_operator_impl_container(DistributedMatmul("matmul"))


# ColumnParallel
class DistributedMatmulImpl0(DistributedOperatorImpl):
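    """Column-parallel matmul: Y is sharded along its last (column) dimension.

    Forward lowering (sketch): c_identity on X, then a local matmul against
    this rank's column shard of Y, producing a column shard of Out.
    """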
    def __init__(self, name):
        super(DistributedMatmulImpl0, self).__init__(name)
        self._forward_implemented = True
        self._backward_implemented = True

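    # Cost estimation: forward cost is the computation descs plus the inserted
    # communication (c_identity here); backward cost is the matmul_grad
    # computation plus any tensor-parallel allreduce and, when the weight
    # gradient requires it, a data-parallel allreduce (build_dp_costs).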
    def calc_cost(self, op_role, dist_op, ctx, cluster):
        cost = None
        if int(op_role) == int(OpRole.Forward):
            cost = self.calc_fwd_cost(dist_op, ctx, cluster)
        elif int(op_role) == int(OpRole.Backward):
            cost = self.calc_bwd_cost(dist_op, ctx, cluster)
        assert cost is not None
        return cost

    def calc_bwd_cost(self, dist_op, ctx, cluster):
        # by now the backward function only inserts the gradient allreduce for the dist op itself
        res = []
        backward_op = dist_op.serial_op
        dist_attr = dist_op.dist_attr
        main_block = backward_op.block
        vars = main_block.vars
        Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("Y")[0]
        )
        # col parallel: matmul + allreduce
        assert Y_var_dim_mapping[0] < 0
        parallel_axis = Y_var_dim_mapping[1]

        has_x_grad = len(backward_op.output("X@GRAD")) > 0
        if has_x_grad:
            assert len(backward_op.output("X@GRAD")) == 1

        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        process_mesh = dist_attr.process_mesh
        processes = process_mesh.processes
        cost_mapping = build_comp_costs_from_descs(
            MatmulGradOpCost, ctx, processes, desc_mapping, cluster
        )
        res.append(cost_mapping)

        # calc comm op cost
        if has_x_grad:
            attrs = {"use_calc_stream": True, "use_model_parallel": True}
            var_names = backward_op.output("X@GRAD")
            c_allreduce_sum_desc_mapping = build_comm_desc_from_dist_op(
                "c_allreduce_sum",
                dist_op,
                ctx,
                var_names,
                attrs=attrs,
                parallel_axis=parallel_axis,
            )
            comm_op_cost_list = build_comm_costs_from_descs(
                AllreduceSumOpCost,
                ctx,
                processes,
                c_allreduce_sum_desc_mapping,
                cluster,
            )
            res.append(comm_op_cost_list)

        # need gradient allreduce
        var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("X")[0]
        )
        mesh_shape = process_mesh.topology
        batch_size_axis = var_dim_mapping[0]
        if (
            batch_size_axis > -1
            and mesh_shape[batch_size_axis] > 1
            and is_parameter_related(backward_op.input("Y")[0], main_block)
        ):
            parallel_axis = batch_size_axis
            attrs = {"use_calc_stream": True}
            var_names = [backward_op.output('Y@GRAD')[0]]
            build_dp_costs(
                res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
            )
        return res

    def calc_fwd_cost(self, dist_op, ctx, cluster):
        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        processes = dist_op.dist_attr.process_mesh.processes
        cost_mapping = build_comp_costs_from_descs(
            MatmulOpCost, ctx, processes, desc_mapping, cluster
        )

        # calc comm op cost
        serial_op = dist_op.serial_op
        vars = serial_op.block.vars
        parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
            serial_op.input("Y")[0]
        )[-1]
        attrs = {"use_calc_stream": True, "use_model_parallel": True}
        var_names = serial_op.input("X")
        c_identity_desc_mapping = build_comm_desc_from_dist_op(
            "c_identity",
            dist_op,
            ctx,
            var_names,
            attrs=attrs,
            parallel_axis=parallel_axis,
        )

        comm_op_cost_list = build_comm_costs_from_descs(
            IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster
        )
        res_cost = [comm_op_cost_list, cost_mapping]

        return res_cost

    def is_input_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        x_name = op_desc.input('X')[0]
        y_name = op_desc.input('Y')[0]
        x_dims_mapping = copy.deepcopy(
            op_dist_attr.get_input_dims_mapping(x_name)
        )
        y_dims_mapping = copy.deepcopy(
            op_dist_attr.get_input_dims_mapping(y_name)
        )
        trans_x = op_desc.attr('transpose_X')
        trans_y = op_desc.attr('transpose_Y')
        trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping)
        if is_dim_shard(x_dims_mapping[-1]):
            return False
        if is_dim_shard(y_dims_mapping[-2]) or is_dim_replicate(
            y_dims_mapping[-1]
        ):
            return False
        for mapping in x_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_output_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
        if is_dim_replicate(out_dims_mapping[-1]):
            return False
        for mapping in out_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_auto_compatible(self, dist_op):
        if (not self.is_input_compatible(dist_op)) or (
            not self.is_output_compatible(dist_op)
        ):
            return False
        if not _is_auto_compatible_for_matmul(dist_op):
            return False
        return True

    def update_dims_mapping(self, dist_op):
        changed = False
        dim_changed = _update_dims_mapping_for_matmul(dist_op)
        if dim_changed:
            changed = True
        return changed

    @staticmethod
    def forward(ctx, *args, **kwargs):
        """
        kwargs: inputname_mapping & outputname_mapping
        """

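        # Lowering sketch (column parallel): insert c_identity on X within the
        # tensor-parallel group, then run the local matmul against this rank's
        # column shard of Y; the matching allreduce on X@GRAD is added by the
        # shared backward helper above.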
        dist_op_context = ctx.dist_op_context
        main_block = dist_op_context.work_block
        startup_block = dist_op_context.startup_block
        src_op = dist_op_context.cur_src_op
        rank_id = dist_op_context.rank_id
        op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
        assert (
            op_dist_attr is not None
        ), "forward op [{}] does not have a dist attribute!".format(str(src_op))

        # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
        if rank_id not in op_dist_attr.process_mesh.processes:
            rank_id = _get_corresponding_rank(
                ctx, op_dist_attr.process_mesh, rank_id
            )

        # check validation of inputs / outputs
        for input_name in src_op.desc.input_names():
            assert input_name in kwargs, "input [{}] is not given".format(
                input_name
            )
            assert len(kwargs[input_name]) == len(
                src_op.desc.input(input_name)
            ), "number of tensors for input [{}] does not match".format(
                input_name
            )
        for output_name in src_op.desc.output_names():
            assert output_name in kwargs, "output [{}] is not given".format(
                output_name
            )
            assert len(kwargs[output_name]) == len(
                src_op.desc.output(output_name)
            ), "number of tensors for output [{}] does not match".format(
                output_name
            )

        X_var = main_block.var(kwargs['X'][0])
        Weight_var = main_block.var(kwargs['Y'][0])
        Out_var = main_block.var(kwargs['Out'][0])
        trans_x = src_op.attr("transpose_X")
        trans_y = src_op.attr("transpose_Y")

        # TODO infer logic comm presentation
        matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping(
            Weight_var.name
        )[-1]
        if trans_y:
            matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping(
                Weight_var.name
            )[-2]
        assert (
            matmul_col_dim_mapping >= 0
        ), "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
            matmul_col_dim_mapping
        )
        process_mesh_shape = op_dist_attr.process_mesh.topology
        process_mesh_group = op_dist_attr.process_mesh.processes

        parallel_axis = matmul_col_dim_mapping
        group_ranks = _get_comm_group(
            process_mesh_group, process_mesh_shape, parallel_axis, rank_id
        )
        group = new_process_group(group_ranks)

        # infer new var shape with op dist attr
        x_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(X_var)
        assert x_tensor_dist_attr is not None
        identity_var_dist_attr = op_dist_attr.get_input_dist_attr(X_var.name)
        assert identity_var_dist_attr is not None
        ref_shape_x = infer_shape(
            main_block, X_var, x_tensor_dist_attr, identity_var_dist_attr
        )
        # infer out var shape with op dist attr
        out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var)
        assert out_tensor_dist_attr is not None
        out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert out_var_dist_attr is not None
        ref_shape_out = infer_shape(
            main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr
        )

        intermediate_var_0 = main_block.create_var(
            name=unique_name.generate_with_ignorable_key(
                ".".join(["c_identity", 'tmp'])
            ),
            dtype=X_var.dtype,
            shape=X_var.shape,
            type=core.VarDesc.VarType.LOD_TENSOR,
            persistable=False,
            stop_gradient=X_var.stop_gradient,
        )
        # set intermediate_var_0's dist_attr with X_var's dist_attr
        ctx.set_tensor_dist_attr_for_program(
            intermediate_var_0, identity_var_dist_attr
        )

        check_variable_and_dtype(
            X_var,
            'tensor',
            ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
            '_c_identity',
        )

        c_identity_op = main_block.append_op(
            type='c_identity',
            inputs={'X': [X_var]},
            outputs={'Out': intermediate_var_0},
            attrs={
                'ring_id': group.id,
                'use_calc_stream': True,
                'use_model_parallel': True,
                OP_ROLE_KEY: src_op.attr('op_role'),
            },
        )
        if intermediate_var_0.shape != ref_shape_x:
            intermediate_var_0.desc.set_shape(ref_shape_x)

        check_variable_and_dtype(
            intermediate_var_0,
            'x',
            ['float16', 'float32', 'float64', 'uint16'],
            'linear',
        )
        check_dtype(
            intermediate_var_0.dtype,
            'dtype',
            ['float16', 'float32', 'float64', 'uint16'],
            'linear',
        )
        attrs = {
            'transpose_X': trans_x,
            'transpose_Y': trans_y,
            'alpha': 1,
            OP_ROLE_KEY: src_op.attr('op_role'),
        }
        inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]}
        matmul_op = main_block.append_op(
            type='matmul', inputs=inputs, outputs={'Out': Out_var}, attrs=attrs
        )
        if Out_var.shape != ref_shape_out:
            Out_var.desc.set_shape(ref_shape_out)

        # set dist op's dist_attr with serial op's dist_attr
        # c_identity
        identity_op_dist_attr = OperatorDistributedAttribute()
        identity_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        identity_op_dist_attr.impl_type = op_dist_attr.impl_type
        identity_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        # input
        input_varname = c_identity_op.desc.input_arg_names()[0]
        input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
        assert input_dist_attr is not None, "dist_attr is {}".format(
            op_dist_attr
        )
        identity_op_dist_attr.set_input_dist_attr(
            input_varname, input_dist_attr
        )
        # output
        output_varname = c_identity_op.desc.output_arg_names()[0]
        identity_op_dist_attr.set_output_dist_attr(
            output_varname, input_dist_attr
        )
        # set op dist attr
        ctx.set_op_dist_attr_for_program(c_identity_op, identity_op_dist_attr)

        # matmul
        matmul_op_dist_attr = OperatorDistributedAttribute()
        matmul_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        matmul_op_dist_attr.impl_type = op_dist_attr.impl_type
        matmul_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        # input
        for input_varname in matmul_op.desc.input_arg_names():
            if input_varname in src_op.desc.input_arg_names():
                input_dist_attr = op_dist_attr.get_input_dist_attr(
                    input_varname
                )
                assert input_dist_attr is not None, "dist_attr is {}".format(
                    op_dist_attr
                )
                matmul_op_dist_attr.set_input_dist_attr(
                    input_varname, input_dist_attr
                )
            else:
                input_var = main_block.var(input_varname)
                tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(
                    input_var
                )
                matmul_op_dist_attr.set_input_dist_attr(
                    input_varname, tensor_dist_attr
                )
        # output
        output_varname = matmul_op.desc.output_arg_names()[0]
        output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
        assert output_dist_attr is not None, "dist_attr is {}".format(
            op_dist_attr
        )
        matmul_op_dist_attr.set_output_dist_attr(
            output_varname, output_dist_attr
        )
        # set op dist attr
        ctx.set_op_dist_attr_for_program(matmul_op, matmul_op_dist_attr)

        # init param sync
        if Weight_var.is_parameter and not op_dist_attr.is_recompute:
            _init_param_sync(
                Weight_var, dist_op_context, startup_block, ctx, rank_id
            )

    @staticmethod
    def backward(ctx, *args, **kwargs):
        _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)


# RowParallel
class DistributedMatmulImpl1(DistributedOperatorImpl):
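    """Row-parallel matmul: Y is sharded along its second-to-last (row)
    dimension and X along its last dimension, so each rank computes a partial
    Out that is summed with c_allreduce_sum across the tensor-parallel group.
    """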
    def __init__(self, name):
        super(DistributedMatmulImpl1, self).__init__(name)
        self._forward_implemented = True
        self._backward_implemented = True

    def calc_cost(self, op_role, dist_op, ctx, cluster):
        cost = None
        if int(op_role) == int(OpRole.Forward):
            cost = self.calc_fwd_cost(dist_op, ctx, cluster)
        elif int(op_role) == int(OpRole.Backward):
            cost = self.calc_bwd_cost(dist_op, ctx, cluster)
        assert cost is not None
        return cost

    def calc_bwd_cost(self, dist_op, ctx, cluster):
        # by now the backward function only inserts the gradient allreduce for the dist op itself
        res = []
        backward_op = dist_op.serial_op
        dist_attr = dist_op.dist_attr
        main_block = backward_op.block
        vars = main_block.vars
        Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("Y")[0]
        )
        assert Y_var_dim_mapping[1] < 0
        parallel_axis = Y_var_dim_mapping[0]

        # calc comm op cost
        var_names = [backward_op.input("Out@GRAD")[0]]
        attrs = {"use_calc_stream": True, "use_model_parallel": True}
        c_identity_desc_mapping = build_comm_desc_from_dist_op(
            "c_identity",
            dist_op,
            ctx,
            var_names,
            attrs=attrs,
            parallel_axis=parallel_axis,
        )
        process_mesh = dist_attr.process_mesh
        processes = process_mesh.processes
        comm_op_cost_list = build_comm_costs_from_descs(
            IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster
        )
        res.append(comm_op_cost_list)

        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        cost_mapping = build_comp_costs_from_descs(
            MatmulGradOpCost, ctx, processes, desc_mapping, cluster
        )
        res.append(cost_mapping)

        # need gradient allreduce
        var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("X")[0]
        )
        mesh_shape = process_mesh.topology
        batch_size_axis = var_dim_mapping[0]
        if (
            batch_size_axis > -1
            and mesh_shape[batch_size_axis] > 1
            and is_parameter_related(backward_op.input("Y")[0], main_block)
        ):
            parallel_axis = batch_size_axis
            attrs = {"use_calc_stream": True}
            var_names = [backward_op.output('Y@GRAD')[0]]
            build_dp_costs(
                res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
            )
        return res

    def calc_fwd_cost(self, dist_op, ctx, cluster):
        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        processes = dist_op.dist_attr.process_mesh.processes
        cost_mapping = build_comp_costs_from_descs(
            MatmulOpCost, ctx, processes, desc_mapping, cluster
        )

        # calc comm op cost
        serial_op = dist_op.serial_op
        vars = serial_op.block.vars

        parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
            serial_op.input("Y")[0]
        )[-2]
        attrs = {"use_calc_stream": True, "use_model_parallel": True}

        var_names = serial_op.output("Out")
        c_allreduce_sum_desc_mapping = build_comm_desc_from_dist_op(
            "c_allreduce_sum",
            dist_op,
            ctx,
            var_names,
            attrs=attrs,
            parallel_axis=parallel_axis,
        )

        comm_op_cost_list = build_comm_costs_from_descs(
            AllreduceSumOpCost,
            ctx,
            processes,
            c_allreduce_sum_desc_mapping,
            cluster,
        )

        res_cost = [cost_mapping, comm_op_cost_list]

        return res_cost

    def is_input_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        x_name = op_desc.input('X')[0]
        y_name = op_desc.input('Y')[0]
        x_dims_mapping = copy.deepcopy(
            op_dist_attr.get_input_dims_mapping(x_name)
        )
        y_dims_mapping = copy.deepcopy(
            op_dist_attr.get_input_dims_mapping(y_name)
        )
        trans_x = op_desc.attr('transpose_X')
        trans_y = op_desc.attr('transpose_Y')
        trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping)
        if is_dim_replicate(x_dims_mapping[-1]):
            return False
        if is_dim_replicate(y_dims_mapping[-2]) or is_dim_shard(
            y_dims_mapping[-1]
        ):
            return False
        # Other dimensions must be replicate except the batch dimension
        for mapping in x_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_output_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
        if is_dim_shard(out_dims_mapping[-1]):
            return False
        # Other dimensions must be replicate except the batch dimension
        for mapping in out_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_auto_compatible(self, dist_op):
        if (not self.is_input_compatible(dist_op)) or (
            not self.is_output_compatible(dist_op)
        ):
            return False
        if not _is_auto_compatible_for_matmul(dist_op):
            return False
        return True

    def update_dims_mapping(self, dist_op):
        changed = False
        dim_changed = _update_dims_mapping_for_matmul(dist_op)
        if dim_changed:
            changed = True
        return changed

    @staticmethod
    def forward(ctx, *args, **kwargs):
        """
        kwargs: inputname_mapping & outputname_mapping
        """

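        # Lowering sketch (row parallel): run the local matmul on this rank's
        # shards into an intermediate variable, then c_allreduce_sum the
        # partial results into Out across the tensor-parallel group.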
1130
        dist_op_context = ctx.dist_op_context
1131 1132 1133 1134
        main_block = dist_op_context.work_block
        startup_block = dist_op_context.startup_block
        src_op = dist_op_context.cur_src_op
        rank_id = dist_op_context.rank_id
1135
        op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
J
JZ-LIANG 已提交
1136 1137 1138
        assert (
            op_dist_attr is not None
        ), "backward op [{}] don't have dist attribute !".format(str(src_op))
1139 1140

        # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
1141
        if rank_id not in op_dist_attr.process_mesh.processes:
J
JZ-LIANG 已提交
1142 1143 1144
            rank_id = _get_corresponding_rank(
                ctx, op_dist_attr.process_mesh, rank_id
            )
1145

1146
        # check validation of inputs / outputs
1147 1148
        for input_name in src_op.desc.input_names():
            assert input_name in kwargs, "input [{}] is not given".format(
J
JZ-LIANG 已提交
1149 1150
                input_name
            )
1151 1152 1153 1154 1155
            assert len(kwargs[input_name]) == len(
                src_op.desc.input(input_name)
            ), "number of tensor for input [{}] is not match".format(input_name)
        for output_name in src_op.desc.output_names():
            assert output_name in kwargs, "input [{}] is not given".format(
J
JZ-LIANG 已提交
1156 1157
                output_name
            )
1158 1159 1160
            assert len(kwargs[output_name]) == len(
                src_op.desc.output(output_name)
            ), "number of tensor for input [{}] is not match".format(
J
JZ-LIANG 已提交
1161 1162
                output_name
            )
1163 1164 1165 1166

        X_var = main_block.var(kwargs['X'][0])
        Weight_var = main_block.var(kwargs['Y'][0])
        Out_var = main_block.var(kwargs['Out'][0])
1167 1168
        trans_x = src_op.attr('transpose_X')
        trans_y = src_op.attr('transpose_Y')
1169 1170 1171

        # TODO infer logic comm presentation
        matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
J
JZ-LIANG 已提交
1172 1173
            Weight_var.name
        )[-2]
1174 1175
        if trans_y:
            matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
J
JZ-LIANG 已提交
1176 1177 1178 1179 1180 1181 1182
                Weight_var.name
            )[-1]
        assert (
            matmul_row_dim_mapping >= 0
        ), "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
            matmul_row_dim_mapping
        )
1183 1184
        process_mesh_shape = op_dist_attr.process_mesh.topology
        process_mesh_group = op_dist_attr.process_mesh.processes
1185 1186

        parallel_axis = matmul_row_dim_mapping
J
JZ-LIANG 已提交
1187 1188 1189
        group_ranks = _get_comm_group(
            process_mesh_group, process_mesh_shape, parallel_axis, rank_id
        )
1190 1191
        group = new_process_group(group_ranks)

J
JZ-LIANG 已提交
1192 1193 1194 1195 1196 1197 1198 1199 1200
        check_variable_and_dtype(
            X_var, 'x', ['float16', 'float32', 'float64', 'uint16'], 'linear'
        )
        check_dtype(
            X_var.dtype,
            'dtype',
            ['float16', 'float32', 'float64', 'uint16'],
            'linear',
        )
1201
        attrs = {
1202 1203
            'transpose_X': trans_x,
            'transpose_Y': trans_y,
1204
            'alpha': 1,
J
JZ-LIANG 已提交
1205
            OP_ROLE_KEY: src_op.attr('op_role'),
1206 1207
        }
        inputs = {'X': X_var, 'Y': Weight_var}
Z
zhaoyingli 已提交
1208 1209 1210 1211 1212 1213

        # infer out var shape with op dist attr
        out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var)
        assert out_tensor_dist_attr is not None
        out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert out_var_dist_attr is not None
J
JZ-LIANG 已提交
1214 1215 1216
        ref_shape = infer_shape(
            main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr
        )
Z
zhaoyingli 已提交
1217

1218
        intermediate_var_0 = main_block.create_var(
J
JZ-LIANG 已提交
1219 1220 1221
            name=unique_name.generate_with_ignorable_key(
                ".".join(["c_allreduce_sum", 'tmp'])
            ),
1222 1223 1224 1225 1226 1227
            shape=Out_var.shape,
            dtype=Out_var.dtype,
            type=Out_var.type,
            lod_level=Out_var.lod_level,
            persistable=False,
            is_data=False,
J
JZ-LIANG 已提交
1228 1229
            need_check_feed=Out_var.desc.need_check_feed(),
        )
Z
zhaoyingli 已提交
1230
        # set intermediate_var_0's dist_attr with Out_var's dist_attr
J
JZ-LIANG 已提交
1231 1232 1233
        ctx.set_tensor_dist_attr_for_program(
            intermediate_var_0, out_var_dist_attr
        )
1234

J
JZ-LIANG 已提交
1235 1236 1237 1238 1239 1240
        matmul_op = main_block.append_op(
            type='matmul',
            inputs=inputs,
            outputs={'Out': intermediate_var_0},
            attrs=attrs,
        )
Z
zhaoyingli 已提交
1241 1242
        if intermediate_var_0.shape != ref_shape:
            intermediate_var_0.desc.set_shape(ref_shape)
1243 1244 1245 1246 1247 1248 1249 1250

        c_allreduce_sum_op = main_block.append_op(
            type='c_allreduce_sum',
            inputs={'X': intermediate_var_0},
            outputs={'Out': Out_var},
            attrs={
                'ring_id': group.id,
                'use_calc_stream': True,
1251
                'use_model_parallel': True,
J
JZ-LIANG 已提交
1252 1253 1254
                OP_ROLE_KEY: src_op.attr('op_role'),
            },
        )
Z
zhaoyingli 已提交
1255 1256 1257 1258 1259 1260 1261
        if Out_var.shape != ref_shape:
            Out_var.desc.set_shape(ref_shape)

        # set dist op's dist_attr with serial op's dist_attr
        # matmul
        matmul_op_dist_attr = OperatorDistributedAttribute()
        matmul_op_dist_attr.process_mesh = op_dist_attr.process_mesh
1262
        matmul_op_dist_attr.impl_type = op_dist_attr.impl_type
Z
zhaoyingli 已提交
1263 1264 1265 1266
        matmul_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        for input_varname in matmul_op.desc.input_arg_names():
            input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
            assert input_dist_attr is not None, "dist_attr is {}".format(
J
JZ-LIANG 已提交
1267 1268 1269 1270 1271
                op_dist_attr
            )
            matmul_op_dist_attr.set_input_dist_attr(
                input_varname, input_dist_attr
            )
Z
zhaoyingli 已提交
1272 1273 1274
        output_varname = matmul_op.desc.output_arg_names()[0]
        output_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert output_dist_attr is not None, "dist_attr is {}".format(
J
JZ-LIANG 已提交
1275 1276 1277 1278 1279
            op_dist_attr
        )
        matmul_op_dist_attr.set_output_dist_attr(
            output_varname, output_dist_attr
        )
Z
zhaoyingli 已提交
1280 1281 1282 1283 1284
        ctx.set_op_dist_attr_for_program(matmul_op, matmul_op_dist_attr)

        # allreduce
        allreduce_op_dist_attr = OperatorDistributedAttribute()
        allreduce_op_dist_attr.process_mesh = op_dist_attr.process_mesh
1285
        allreduce_op_dist_attr.impl_type = op_dist_attr.impl_type
Z
zhaoyingli 已提交
1286 1287 1288 1289 1290
        allreduce_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        for input_varname in c_allreduce_sum_op.desc.input_arg_names():
            input_var = main_block.var(input_varname)
            tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(input_var)
            assert tensor_dist_attr is not None
J
JZ-LIANG 已提交
1291 1292 1293
            allreduce_op_dist_attr.set_input_dist_attr(
                input_varname, tensor_dist_attr
            )
        for output_varname in c_allreduce_sum_op.desc.output_arg_names():
            output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
            assert output_dist_attr is not None, "dist_attr is {}".format(
                op_dist_attr
            )
            allreduce_op_dist_attr.set_output_dist_attr(
                output_varname, output_dist_attr
            )
        ctx.set_op_dist_attr_for_program(
            c_allreduce_sum_op, allreduce_op_dist_attr
        )

        # init param sync
        if Weight_var.is_parameter and not op_dist_attr.is_recompute:
            _init_param_sync(
                Weight_var, dist_op_context, startup_block, ctx, rank_id
            )

    @staticmethod
    def backward(ctx, *args, **kwargs):
        _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)


# ReplicateParallel
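# (no tensor is sharded along the reduction or output dimensions here, so the
#  forward pass simply falls back to DistributedDefaultImpl0 without inserting
#  any communication op)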
class DistributedMatmulImpl2(DistributedOperatorImpl):
    def __init__(self, name):
        super(DistributedMatmulImpl2, self).__init__(name)

    def calc_cost(self, op_role, dist_op, ctx, cluster):
        cost = None
        if int(op_role) == int(OpRole.Forward):
            cost = self.calc_fwd_cost(dist_op, ctx, cluster)
        elif int(op_role) == int(OpRole.Backward):
            cost = self.calc_bwd_cost(dist_op, ctx, cluster)
        assert cost is not None
        return cost

    def calc_bwd_cost(self, dist_op, ctx, cluster):
        res = []
        backward_op = dist_op.serial_op
        dist_attr = dist_op.dist_attr
        main_block = backward_op.block
        vars = main_block.vars

        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        process_mesh = dist_attr.process_mesh
        processes = process_mesh.processes
        cost_mapping = build_comp_costs_from_descs(
            MatmulGradOpCost, ctx, processes, desc_mapping, cluster
        )
        res.append(cost_mapping)

        # need gradient allreduce
        var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("X")[0]
        )
        mesh_shape = process_mesh.topology
        batch_size_axis = var_dim_mapping[0]
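        # If the batch dimension of X is sharded over a mesh axis of size > 1
        # and Y is a parameter, its gradient additionally needs a
        # data-parallel allreduce along that axis, modeled by build_dp_costs.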
        if (
            batch_size_axis > -1
            and mesh_shape[batch_size_axis] > 1
            and is_parameter_related(backward_op.input("Y")[0], main_block)
        ):
            parallel_axis = batch_size_axis
            attrs = {"use_calc_stream": True}
            var_names = [backward_op.output('Y@GRAD')[0]]
            build_dp_costs(
                res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
            )

        return res

    def calc_fwd_cost(self, dist_op, ctx, cluster):
        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        processes = dist_op.dist_attr.process_mesh.processes
        cost_mapping = build_comp_costs_from_descs(
            MatmulOpCost, ctx, processes, desc_mapping, cluster
        )

        res_cost = [cost_mapping]
        return res_cost

    def is_input_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        x_name = op_desc.input('X')[0]
        y_name = op_desc.input('Y')[0]
        x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
        y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)

        if is_dim_shard(x_dims_mapping[-1]):
            return False
        if is_valid_list_index(x_dims_mapping, -2) and is_dim_shard(
            x_dims_mapping[-2]
        ):
            return False

        if is_dim_shard(y_dims_mapping[-1]):
            return False
        if is_valid_list_index(y_dims_mapping, -2) and is_dim_shard(
            y_dims_mapping[-2]
        ):
            return False

        return True

    def is_output_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)

        if is_dim_shard(out_dims_mapping[-1]):
            return False
        if is_valid_list_index(out_dims_mapping, -2) and is_dim_shard(
            out_dims_mapping[-2]
        ):
            return False

        return True

    def is_auto_compatible(self, dist_op):
        if (not self.is_input_compatible(dist_op)) or (
            not self.is_output_compatible(dist_op)
        ):
            return False

        if not _is_auto_compatible_for_matmul(dist_op):
            return False

        return True

    def update_dims_mapping(self, dist_op):
        changed = False
        dim_changed = _update_dims_mapping_for_matmul(dist_op)
        if dim_changed:
            changed = True
        return changed

    @staticmethod
    def forward(ctx, *args, **kwargs):
        DistributedDefaultImpl0.forward(ctx, *args, **kwargs)

    @staticmethod
    def backward(ctx, *args, **kwargs):
        _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)


register_distributed_operator_impl(
    "matmul", DistributedMatmulImpl0("column_parallel")
)
register_distributed_operator_impl(
    "matmul", DistributedMatmulImpl1("row_parallel")
)
register_distributed_operator_impl(
    "matmul", DistributedMatmulImpl2("replicate_parallel")
)
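# Three implementations are registered for the legacy matmul op: column
# parallel, row parallel and a replicated fallback; which one applies to a
# given dist op depends on its dims mapping (see the *_compatible checks).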


class DistributedMatmulV2(DistributedOperatorImplContainer):
    def __init__(self, op_type):
        super(DistributedMatmulV2, self).__init__(op_type)


register_distributed_operator_impl_container(DistributedMatmulV2("matmul_v2"))


# ColumnParallel
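# (forward inserts a c_identity on X followed by matmul_v2 against the
#  column-sharded Y; the matching X@GRAD allreduce shows up in calc_bwd_cost)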
class DistributedMatmulV2Impl0(DistributedOperatorImpl):
    def __init__(self, name):
        super(DistributedMatmulV2Impl0, self).__init__(name)
        self._forward_implemented = True
        self._backward_implemented = True

    def calc_cost(self, op_role, dist_op, ctx, cluster):
        cost = None
        if int(op_role) == int(OpRole.Forward):
            cost = self.calc_fwd_cost(dist_op, ctx, cluster)
        elif int(op_role) == int(OpRole.Backward):
            cost = self.calc_bwd_cost(dist_op, ctx, cluster)
        assert cost is not None
        return cost

    def calc_bwd_cost(self, dist_op, ctx, cluster):
        # for now, the backward function only inserts the gradient allreduce for the dist op itself
        res = []
        backward_op = dist_op.serial_op
        dist_attr = dist_op.dist_attr
        main_block = backward_op.block
        vars = main_block.vars
        Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("Y")[0]
        )
        process_mesh = dist_attr.process_mesh
        processes = process_mesh.processes
        # col parallel: matmul + allreduce
        if backward_op.attr("trans_y"):
            Y_var_dim_mapping.reverse()
        assert Y_var_dim_mapping[0] < 0
        parallel_axis = Y_var_dim_mapping[1]
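        # Column parallel: Y's row dimension stays replicated (mapping < 0)
        # while its column dimension is sharded; that mesh axis is the
        # model-parallel axis used for the communication costs below.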

        has_x_grad = len(backward_op.output("X@GRAD")) > 0
        if has_x_grad:
            assert len(backward_op.output("X@GRAD")) == 1

        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )

        cost_mapping = build_comp_costs_from_descs(
            MatmulV2GradOpCost, ctx, processes, desc_mapping, cluster
        )
        res.append(cost_mapping)

        # calc comm op cost
        if has_x_grad:
            attrs = {"use_calc_stream": True, "use_model_parallel": True}
            var_names = backward_op.output("X@GRAD")
            c_allreduce_sum_desc_mapping = build_comm_desc_from_dist_op(
                "c_allreduce_sum",
                dist_op,
                ctx,
                var_names,
                attrs=attrs,
                parallel_axis=parallel_axis,
            )
            comm_op_cost_list = build_comm_costs_from_descs(
                AllreduceSumOpCost,
                ctx,
                processes,
                c_allreduce_sum_desc_mapping,
                cluster,
            )
            res.append(comm_op_cost_list)

        # need gradient allreduce
        process_mesh = dist_attr.process_mesh
        var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("X")[0]
        )
        mesh_shape = process_mesh.topology
        batch_size_axis = var_dim_mapping[0]
        if (
            batch_size_axis > -1
            and mesh_shape[batch_size_axis] > 1
            and is_parameter_related(backward_op.input("Y")[0], main_block)
        ):
            parallel_axis = batch_size_axis
            attrs = {"use_calc_stream": True}
            var_names = [backward_op.output('Y@GRAD')[0]]
            build_dp_costs(
                res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
            )
        return res

    def calc_fwd_cost(self, dist_op, ctx, cluster):
        # calc comp op cost
        # TODO: trans shape if trans_x or trans_y is True
        comp_desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        processes = dist_op.dist_attr.process_mesh.processes
        comp_cost_mapping = build_comp_costs_from_descs(
            MatmulV2OpCost, ctx, processes, comp_desc_mapping, cluster
        )

        # calc comm op cost
        serial_op = dist_op.serial_op
        vars = serial_op.block.vars

        parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
            serial_op.input("Y")[0]
        )[-1]
        attrs = {"use_calc_stream": True, "use_model_parallel": True}

        var_names = serial_op.input("X")
        c_identity_desc_mapping = build_comm_desc_from_dist_op(
            "c_identity",
            dist_op,
            ctx,
            var_names,
            attrs=attrs,
            parallel_axis=parallel_axis,
        )
        comm_op_cost_list = build_comm_costs_from_descs(
            IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster
        )

        res_cost = [comm_op_cost_list, comp_cost_mapping]
        return res_cost

    def is_input_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        x_name = op_desc.input('X')[0]
        y_name = op_desc.input('Y')[0]
        x_dims_mapping = copy.deepcopy(
            op_dist_attr.get_input_dims_mapping(x_name)
        )
        y_dims_mapping = copy.deepcopy(
            op_dist_attr.get_input_dims_mapping(y_name)
        )
        trans_x = op_desc.attr('trans_x')
        trans_y = op_desc.attr('trans_y')
        trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping)
        if is_dim_shard(x_dims_mapping[-1]):
            return False
        if is_dim_shard(y_dims_mapping[-2]) or is_dim_replicate(
            y_dims_mapping[-1]
        ):
            return False
        for mapping in x_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_output_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
        if is_dim_replicate(out_dims_mapping[-1]):
            return False
        for mapping in out_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_auto_compatible(self, dist_op):
        if (not self.is_input_compatible(dist_op)) or (
            not self.is_output_compatible(dist_op)
        ):
            return False
        if not _is_auto_compatible_for_matmul(dist_op):
            return False
        return True

    def update_dims_mapping(self, dist_op):
        changed = False
        dim_changed = _update_dims_mapping_for_matmul(dist_op)
        if dim_changed:
            changed = True
        return changed

    @staticmethod
    def forward(ctx, *args, **kwargs):
        """
        kwargs: inputname_mapping & outputname_mapping
        """

        dist_op_context = ctx.dist_op_context
        main_block = dist_op_context.work_block
        startup_block = dist_op_context.startup_block
        src_op = dist_op_context.cur_src_op
        rank_id = dist_op_context.rank_id
        op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
        assert (
            op_dist_attr is not None
        ), "backward op [{}] don't have dist attribute !".format(str(src_op))

        # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
        if rank_id not in op_dist_attr.process_mesh.processes:
            rank_id = _get_corresponding_rank(
                ctx, op_dist_attr.process_mesh, rank_id
            )

        # check validation of inputs / outputs
        for input_name in src_op.desc.input_names():
            assert input_name in kwargs, "input [{}] is not given".format(
                input_name
            )
            assert len(kwargs[input_name]) == len(
                src_op.desc.input(input_name)
            ), "number of tensor for input [{}] is not match".format(input_name)
        for output_name in src_op.desc.output_names():
            assert output_name in kwargs, "output [{}] is not given".format(
                output_name
            )
            assert len(kwargs[output_name]) == len(
                src_op.desc.output(output_name)
            ), "number of tensor for input [{}] is not match".format(
                output_name
            )

        X_var = main_block.var(kwargs['X'][0])
        Weight_var = main_block._var_recursive(kwargs['Y'][0])
        Out_var = main_block.var(kwargs['Out'][0])
        trans_x = src_op.attr('trans_x')
        trans_y = src_op.attr('trans_y')

        # TODO infer logic comm presentation
        matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping(
            Weight_var.name
        )[-1]
        if trans_y:
            matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping(
                Weight_var.name
            )[-2]
        assert (
            matmul_col_dim_mapping >= 0
        ), "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
            matmul_col_dim_mapping
        )
        process_mesh_shape = op_dist_attr.process_mesh.topology
        process_mesh_group = op_dist_attr.process_mesh.processes

        parallel_axis = matmul_col_dim_mapping
        group_ranks = _get_comm_group(
            process_mesh_group, process_mesh_shape, parallel_axis, rank_id
        )
        group = new_process_group(group_ranks)
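        # The ranks that differ only along `parallel_axis` of the process mesh
        # form the model-parallel group used by the ops inserted below.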

        # infer new var shape with op dist attr
        x_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(X_var)
        assert x_tensor_dist_attr is not None
        identity_var_dist_attr = op_dist_attr.get_input_dist_attr(X_var.name)
        assert identity_var_dist_attr is not None
        ref_shape_x = infer_shape(
            main_block, X_var, x_tensor_dist_attr, identity_var_dist_attr
        )
        # infer out var shape with op dist attr
        out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var)
        assert out_tensor_dist_attr is not None
        out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert out_var_dist_attr is not None
        ref_shape_out = infer_shape(
            main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr
        )

        intermediate_var_0 = main_block.create_var(
            name=unique_name.generate_with_ignorable_key(
                ".".join(["c_identity", 'tmp'])
            ),
            dtype=X_var.dtype,
            shape=X_var.shape,
            type=core.VarDesc.VarType.LOD_TENSOR,
            persistable=False,
            stop_gradient=X_var.stop_gradient,
        )
        # set intermediate_var_0's dist_attr with X_var's dist_attr
        ctx.set_tensor_dist_attr_for_program(
            intermediate_var_0, identity_var_dist_attr
        )

        check_variable_and_dtype(
            X_var,
            'tensor',
            ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
            '_c_identity',
        )
        c_identity_op = main_block.append_op(
            type='c_identity',
            inputs={'X': [X_var]},
            outputs={'Out': intermediate_var_0},
            attrs={
                'ring_id': group.id,
                'use_calc_stream': True,
                'use_model_parallel': True,
                OP_ROLE_KEY: src_op.attr('op_role'),
            },
        )
        if intermediate_var_0.shape != ref_shape_x:
            intermediate_var_0.desc.set_shape(ref_shape_x)
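        # c_identity is a no-op copy in the forward pass; it marks the point
        # where the backward pass will allreduce X@GRAD across the
        # column-parallel group (mirrored by calc_bwd_cost above).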

        check_variable_and_dtype(
            intermediate_var_0,
            'x',
            ['float16', 'float32', 'float64', 'uint16'],
            'linear',
        )
        check_dtype(
            intermediate_var_0.dtype,
            'dtype',
            ['float16', 'float32', 'float64', 'uint16'],
            'linear',
        )
        attrs = {
            'trans_x': trans_x,
            'trans_y': trans_y,
            OP_ROLE_KEY: src_op.attr('op_role'),
        }
        inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]}
        matmul_v2_op = main_block.append_op(
            type='matmul_v2',
            inputs=inputs,
            outputs={'Out': Out_var},
            attrs=attrs,
        )
        if Out_var.shape != ref_shape_out:
            Out_var.desc.set_shape(ref_shape_out)

        # set dist op's dist_attr with serial op's dist_attr
        # c_identity
        identity_op_dist_attr = OperatorDistributedAttribute()
        identity_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        identity_op_dist_attr.impl_type = op_dist_attr.impl_type
        identity_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        # input
        input_varname = c_identity_op.desc.input_arg_names()[0]
        input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
        assert input_dist_attr is not None, "dist_attr is {}".format(
            op_dist_attr
        )
        identity_op_dist_attr.set_input_dist_attr(
            input_varname, input_dist_attr
        )
        # output
        output_varname = c_identity_op.desc.output_arg_names()[0]
        identity_op_dist_attr.set_output_dist_attr(
            output_varname, input_dist_attr
        )
        ctx.set_op_dist_attr_for_program(c_identity_op, identity_op_dist_attr)

        # matmulv2
        matmulv2_op_dist_attr = OperatorDistributedAttribute()
        matmulv2_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        matmulv2_op_dist_attr.impl_type = op_dist_attr.impl_type
        matmulv2_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        for input_varname in matmul_v2_op.desc.input_arg_names():
            if input_varname in src_op.desc.input_arg_names():
                input_dist_attr = op_dist_attr.get_input_dist_attr(
                    input_varname
                )
                assert input_dist_attr is not None, "dist_attr is {}".format(
                    op_dist_attr
                )
                matmulv2_op_dist_attr.set_input_dist_attr(
                    input_varname, input_dist_attr
                )
            else:
                input_var = main_block.var(input_varname)
                tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(
                    input_var
                )
                matmulv2_op_dist_attr.set_input_dist_attr(
                    input_varname, tensor_dist_attr
                )
        for output_varname in matmul_v2_op.desc.output_arg_names():
            output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
            assert output_dist_attr is not None, "dist_attr is {}".format(
                op_dist_attr
            )
            matmulv2_op_dist_attr.set_output_dist_attr(
                output_varname, output_dist_attr
            )
        ctx.set_op_dist_attr_for_program(matmul_v2_op, matmulv2_op_dist_attr)

        # init param sync
        if Weight_var.is_parameter and not op_dist_attr.is_recompute:
            _init_param_sync(
                Weight_var, dist_op_context, startup_block, ctx, rank_id
            )

    @staticmethod
    def backward(ctx, *args, **kwargs):
        _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)


# RowParallel
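# (forward runs matmul_v2 against the row-sharded Y, producing a partial Out
#  that is then summed with c_allreduce_sum across the model-parallel group)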
class DistributedMatmulV2Impl1(DistributedOperatorImpl):
    def __init__(self, name):
        super(DistributedMatmulV2Impl1, self).__init__(name)
        self._forward_implemented = True
        self._backward_implemented = True

    def calc_cost(self, op_role, dist_op, ctx, cluster):
        cost = None
        if int(op_role) == int(OpRole.Forward):
            cost = self.calc_fwd_cost(dist_op, ctx, cluster)
        elif int(op_role) == int(OpRole.Backward):
            cost = self.calc_bwd_cost(dist_op, ctx, cluster)
        assert cost is not None
        return cost

    def calc_bwd_cost(self, dist_op, ctx, cluster):
        # for now, the backward function only inserts the gradient allreduce for the dist op itself
        res = []
        backward_op = dist_op.serial_op
        dist_attr = dist_op.dist_attr
        main_block = backward_op.block
        vars = main_block.vars
        Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("Y")[0]
        )
        assert Y_var_dim_mapping[1] < 0
        parallel_axis = Y_var_dim_mapping[0]
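        # Row parallel: Y's column dimension is replicated and its row
        # dimension is sharded; that mesh axis drives the c_identity cost on
        # Out@GRAD below.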

        process_mesh = dist_attr.process_mesh
        processes = process_mesh.processes
        # calc comm op cost
        var_names = [backward_op.input("Out@GRAD")[0]]
        attrs = {"use_calc_stream": True, "use_model_parallel": True}
        c_identity_desc_mapping = build_comm_desc_from_dist_op(
            "c_identity",
            dist_op,
            ctx,
            var_names,
            attrs=attrs,
            parallel_axis=parallel_axis,
        )
        comm_op_cost_list = build_comm_costs_from_descs(
            IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster
        )
        res.append(comm_op_cost_list)

        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        cost_mapping = build_comp_costs_from_descs(
            MatmulV2GradOpCost, ctx, processes, desc_mapping, cluster
        )
        res.append(cost_mapping)

        # need gradient allreduce
        process_mesh = dist_attr.process_mesh
        var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("X")[0]
        )
        mesh_shape = process_mesh.topology
        batch_size_axis = var_dim_mapping[0]
        if (
            batch_size_axis > -1
            and mesh_shape[batch_size_axis] > 1
            and is_parameter_related(backward_op.input("Y")[0], main_block)
        ):
            parallel_axis = batch_size_axis
            attrs = {"use_calc_stream": True}
            var_names = [backward_op.output('Y@GRAD')[0]]
            build_dp_costs(
                res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
            )
        return res

    def calc_fwd_cost(self, dist_op, ctx, cluster):
        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        processes = dist_op.dist_attr.process_mesh.processes
        cost_mapping = build_comp_costs_from_descs(
            MatmulV2OpCost, ctx, processes, desc_mapping, cluster
        )

        # calc comm op cost
        serial_op = dist_op.serial_op
        vars = serial_op.block.vars

        parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
            serial_op.input("Y")[0]
        )[-2]
        attrs = {"use_calc_stream": True, "use_model_parallel": True}

        var_names = serial_op.output("Out")
        c_allreduce_sum_desc_mapping = build_comm_desc_from_dist_op(
            "c_allreduce_sum",
            dist_op,
            ctx,
            var_names,
            attrs=attrs,
            parallel_axis=parallel_axis,
        )

        comm_op_cost_list = build_comm_costs_from_descs(
            AllreduceSumOpCost,
            ctx,
            processes,
            c_allreduce_sum_desc_mapping,
            cluster,
        )
        res_cost = [cost_mapping, comm_op_cost_list]

        return res_cost

    def is_input_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        x_name = op_desc.input('X')[0]
        y_name = op_desc.input('Y')[0]
        x_dims_mapping = copy.deepcopy(
            op_dist_attr.get_input_dims_mapping(x_name)
        )
        y_dims_mapping = copy.deepcopy(
            op_dist_attr.get_input_dims_mapping(y_name)
        )
        trans_x = op_desc.attr('trans_x')
        trans_y = op_desc.attr('trans_y')
        trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping)
        if is_dim_replicate(x_dims_mapping[-1]):
            return False
        if is_dim_replicate(y_dims_mapping[-2]) or is_dim_shard(
            y_dims_mapping[-1]
        ):
            return False
        # Other dimensions must be replicate except the batch dimension
        for mapping in x_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_output_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
        if is_dim_shard(out_dims_mapping[-1]):
            return False
        # Other dimensions must be replicate except the batch dimension
        for mapping in out_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_auto_compatible(self, dist_op):
        if (not self.is_input_compatible(dist_op)) or (
            not self.is_output_compatible(dist_op)
        ):
            return False
        if not _is_auto_compatible_for_matmul(dist_op):
            return False
        return True

    def update_dims_mapping(self, dist_op):
        changed = False
        dim_changed = _update_dims_mapping_for_matmul(dist_op)
        if dim_changed:
            changed = True
        return changed

    @staticmethod
    def forward(ctx, *args, **kwargs):
        """
        kwargs: inputname_mapping & outputname_mapping
        """

        dist_op_context = ctx.dist_op_context
        main_block = dist_op_context.work_block
        startup_block = dist_op_context.startup_block
        src_op = dist_op_context.cur_src_op
        rank_id = dist_op_context.rank_id
        op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
        assert (
            op_dist_attr is not None
        ), "backward op [{}] don't have dist attribute !".format(str(src_op))

        # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
        if rank_id not in op_dist_attr.process_mesh.processes:
            rank_id = _get_corresponding_rank(
                ctx, op_dist_attr.process_mesh, rank_id
            )

        # check validation of inputs / outputs
        for input_name in src_op.desc.input_names():
            assert input_name in kwargs, "input [{}] is not given".format(
                input_name
            )
            assert len(kwargs[input_name]) == len(
                src_op.desc.input(input_name)
            ), "number of tensor for input [{}] is not match".format(input_name)
        for output_name in src_op.desc.output_names():
            assert output_name in kwargs, "output [{}] is not given".format(
                output_name
            )
            assert len(kwargs[output_name]) == len(
                src_op.desc.output(output_name)
            ), "number of tensor for input [{}] is not match".format(
                output_name
            )

        X_var = main_block.var(kwargs['X'][0])
        Weight_var = main_block._var_recursive(kwargs['Y'][0])
        Out_var = main_block.var(kwargs['Out'][0])
        trans_x = src_op.attr('trans_x')
        trans_y = src_op.attr('trans_y')

        # TODO infer logic comm presentation
        matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
            Weight_var.name
        )[-2]
        if trans_y:
            matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
                Weight_var.name
            )[-1]
        assert (
            matmul_row_dim_mapping >= 0
        ), "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
            matmul_row_dim_mapping
        )
        process_mesh_shape = op_dist_attr.process_mesh.topology
        process_mesh_group = op_dist_attr.process_mesh.processes

        parallel_axis = matmul_row_dim_mapping
        group_ranks = _get_comm_group(
            process_mesh_group, process_mesh_shape, parallel_axis, rank_id
        )
        group = new_process_group(group_ranks)

        check_variable_and_dtype(
            X_var, 'x', ['float16', 'float32', 'float64', 'uint16'], 'linear'
        )
        check_dtype(
            X_var.dtype,
            'dtype',
            ['float16', 'float32', 'float64', 'uint16'],
            'linear',
        )
        attrs = {
            'trans_x': trans_x,
            'trans_y': trans_y,
            OP_ROLE_KEY: src_op.attr('op_role'),
        }
        inputs = {'X': X_var, 'Y': Weight_var}
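        # X is consumed directly: its last dimension is already sharded to
        # match the sharded rows of Weight, so no input communication is
        # needed; only the partial output is allreduced below.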

        # infer out var shape with op dist attr
        out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var)
        assert out_tensor_dist_attr is not None
        out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert out_var_dist_attr is not None
        ref_shape = infer_shape(
            main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr
        )

        intermediate_var_0 = main_block.create_var(
            name=unique_name.generate_with_ignorable_key(
                ".".join(["c_allreduce_sum", 'tmp'])
            ),
            shape=Out_var.shape,
            dtype=Out_var.dtype,
            type=Out_var.type,
            lod_level=Out_var.lod_level,
            persistable=False,
            is_data=False,
            need_check_feed=Out_var.desc.need_check_feed(),
        )
        # set intermediate_var_0's dist_attr with Out_var's dist_attr
        ctx.set_tensor_dist_attr_for_program(
            intermediate_var_0, out_var_dist_attr
        )

        matmul_v2_op = main_block.append_op(
            type='matmul_v2',
            inputs=inputs,
            outputs={'Out': intermediate_var_0},
            attrs=attrs,
        )
        if intermediate_var_0.shape != ref_shape:
            intermediate_var_0.desc.set_shape(ref_shape)

        c_allreduce_sum_op = main_block.append_op(
            type='c_allreduce_sum',
            inputs={'X': intermediate_var_0},
            outputs={'Out': Out_var},
            attrs={
                'ring_id': group.id,
                'use_calc_stream': True,
                'use_model_parallel': True,
                OP_ROLE_KEY: src_op.attr('op_role'),
            },
        )
        if Out_var.shape != ref_shape:
            Out_var.desc.set_shape(ref_shape)

        # set dist op's dist_attr with serial op's dist_attr
        # matmulv2
        matmulv2_op_dist_attr = OperatorDistributedAttribute()
        matmulv2_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        matmulv2_op_dist_attr.impl_type = op_dist_attr.impl_type
        matmulv2_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        for input_varname in matmul_v2_op.desc.input_arg_names():
            input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
            assert input_dist_attr is not None, "dist_attr is {}".format(
                op_dist_attr
            )
            matmulv2_op_dist_attr.set_input_dist_attr(
                input_varname, input_dist_attr
            )
        output_varname = matmul_v2_op.desc.output_arg_names()[0]
        output_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert output_dist_attr is not None, "dist_attr is {}".format(
            op_dist_attr
        )
        matmulv2_op_dist_attr.set_output_dist_attr(
            output_varname, output_dist_attr
        )
        ctx.set_op_dist_attr_for_program(matmul_v2_op, matmulv2_op_dist_attr)

        # allreduce
        allreduce_op_dist_attr = OperatorDistributedAttribute()
        allreduce_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        allreduce_op_dist_attr.impl_type = op_dist_attr.impl_type
        allreduce_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        for input_varname in c_allreduce_sum_op.desc.input_arg_names():
            input_var = main_block.var(input_varname)
            tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(input_var)
            assert tensor_dist_attr is not None
            allreduce_op_dist_attr.set_input_dist_attr(
                input_varname, tensor_dist_attr
            )
        for output_varname in c_allreduce_sum_op.desc.output_arg_names():
            output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
            assert output_dist_attr is not None, "dist_attr is {}".format(
                op_dist_attr
            )
            allreduce_op_dist_attr.set_output_dist_attr(
                output_varname, output_dist_attr
            )
        ctx.set_op_dist_attr_for_program(
            c_allreduce_sum_op, allreduce_op_dist_attr
        )

        # init param sync
        if Weight_var.is_parameter and not op_dist_attr.is_recompute:
            _init_param_sync(
                Weight_var, dist_op_context, startup_block, ctx, rank_id
            )

    @staticmethod
    def backward(ctx, *args, **kwargs):
        _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)


# ReplicateParallel
class DistributedMatmulV2Impl2(DistributedOperatorImpl):
    def __init__(self, name):
        super(DistributedMatmulV2Impl2, self).__init__(name)

    def calc_cost(self, op_role, dist_op, ctx, cluster):
        cost = None
        if int(op_role) == int(OpRole.Forward):
            cost = self.calc_fwd_cost(dist_op, ctx, cluster)
        elif int(op_role) == int(OpRole.Backward):
            cost = self.calc_bwd_cost(dist_op, ctx, cluster)
        assert cost is not None
        return cost

    def calc_bwd_cost(self, dist_op, ctx, cluster):
        res = []
        backward_op = dist_op.serial_op
        dist_attr = dist_op.dist_attr
        main_block = backward_op.block
        vars = main_block.vars
        process_mesh = dist_attr.process_mesh

        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        processes = process_mesh.processes
        cost_mapping = build_comp_costs_from_descs(
            MatmulV2GradOpCost, ctx, processes, desc_mapping, cluster
        )
        res.append(cost_mapping)

        # need gradient allreduce
        var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("X")[0]
        )
        mesh_shape = process_mesh.topology
        batch_size_axis = var_dim_mapping[0]
        if (
            batch_size_axis > -1
            and mesh_shape[batch_size_axis] > 1
            and is_parameter_related(backward_op.input("Y")[0], main_block)
        ):
            parallel_axis = batch_size_axis
            attrs = {"use_calc_stream": True}
            var_names = [backward_op.output('Y@GRAD')[0]]
            build_dp_costs(
                res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
            )

        return res

    def calc_fwd_cost(self, dist_op, ctx, cluster):
        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        processes = dist_op.dist_attr.process_mesh.processes
        cost_mapping = build_comp_costs_from_descs(
            MatmulV2OpCost, ctx, processes, desc_mapping, cluster
        )

        res_cost = [cost_mapping]

        return res_cost

    def is_input_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        x_name = op_desc.input('X')[0]
        y_name = op_desc.input('Y')[0]
        x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
        y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)

        if is_dim_shard(x_dims_mapping[-1]):
            return False
        if is_valid_list_index(x_dims_mapping, -2) and is_dim_shard(
            x_dims_mapping[-2]
        ):
            return False

        if is_dim_shard(y_dims_mapping[-1]):
            return False
        if is_valid_list_index(y_dims_mapping, -2) and is_dim_shard(
            y_dims_mapping[-2]
        ):
            return False
        return True

    def is_output_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)

        if is_dim_shard(out_dims_mapping[-1]):
            return False
        if is_valid_list_index(out_dims_mapping, -2) and is_dim_shard(
            out_dims_mapping[-2]
        ):
            return False

        return True

    def is_auto_compatible(self, dist_op):
        if (not self.is_input_compatible(dist_op)) or (
            not self.is_output_compatible(dist_op)
        ):
            return False

        if not _is_auto_compatible_for_matmul(dist_op):
            return False

        return True

    def update_dims_mapping(self, dist_op):
        changed = False
        dim_changed = _update_dims_mapping_for_matmul(dist_op)
        if dim_changed:
            changed = True
        return changed

    @staticmethod
    def forward(ctx, *args, **kwargs):
        DistributedDefaultImpl0.forward(ctx, *args, **kwargs)

    @staticmethod
    def backward(ctx, *args, **kwargs):
        _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)


register_distributed_operator_impl(
    "matmul_v2", DistributedMatmulV2Impl0("column_parallel")
)
register_distributed_operator_impl(
    "matmul_v2", DistributedMatmulV2Impl1("row_parallel")
)
register_distributed_operator_impl(
    "matmul_v2", DistributedMatmulV2Impl2("replicate_parallel")
)


class DistributedMul(DistributedOperatorImplContainer):
    def __init__(self, op_type):
        super(DistributedMul, self).__init__(op_type)


register_distributed_operator_impl_container(DistributedMul("mul"))


# ColumnParallel
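# (same column-parallel pattern as matmul_v2, except that mul carries the
#  x_num_col_dims / y_num_col_dims attributes instead of trans_x / trans_y)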
class DistributedMulImpl0(DistributedOperatorImpl):
    def __init__(self, name):
        super(DistributedMulImpl0, self).__init__(name)
        self._forward_implemented = True
        self._backward_implemented = True

    def calc_cost(self, op_role, dist_op, ctx, cluster):
        cost = None
        if int(op_role) == int(OpRole.Forward):
            cost = self.calc_fwd_cost(dist_op, ctx, cluster)
        elif int(op_role) == int(OpRole.Backward):
            cost = self.calc_bwd_cost(dist_op, ctx, cluster)
        assert cost is not None
        return cost

    def calc_bwd_cost(self, dist_op, ctx, cluster):
        # for now, the backward function only inserts the gradient allreduce for the dist op itself
        res = []
        backward_op = dist_op.serial_op
        dist_attr = dist_op.dist_attr
        main_block = backward_op.block
        vars = main_block.vars
        Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("Y")[0]
        )
        # col parallel: matmul + allreduce
        assert Y_var_dim_mapping[0] < 0
        parallel_axis = Y_var_dim_mapping[1]

        has_x_grad = len(backward_op.output("X@GRAD")) > 0
        if has_x_grad:
            assert len(backward_op.output("X@GRAD")) == 1

        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        process_mesh = dist_attr.process_mesh
        processes = process_mesh.processes
        cost_mapping = build_comp_costs_from_descs(
            MulGradOpCost, ctx, processes, desc_mapping, cluster
        )
        res.append(cost_mapping)

        # calc comm op cost
        if has_x_grad:
            attrs = {"use_calc_stream": True, "use_model_parallel": True}
            var_names = backward_op.output("X@GRAD")
            c_allreduce_sum_desc_mapping = build_comm_desc_from_dist_op(
                "c_allreduce_sum",
                dist_op,
                ctx,
                var_names,
                attrs=attrs,
                parallel_axis=parallel_axis,
            )
            comm_op_cost_list = build_comm_costs_from_descs(
                AllreduceSumOpCost,
                ctx,
                processes,
                c_allreduce_sum_desc_mapping,
                cluster,
            )
            res.append(comm_op_cost_list)

        # need gradient allreduce
        var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("X")[0]
        )
        mesh_shape = process_mesh.topology
        batch_size_axis = var_dim_mapping[0]
        if (
            batch_size_axis > -1
            and mesh_shape[batch_size_axis] > 1
            and is_parameter_related(backward_op.input("Y")[0], main_block)
        ):
            parallel_axis = batch_size_axis
            attrs = {"use_calc_stream": True}
            var_names = [backward_op.output('Y@GRAD')[0]]
            build_dp_costs(
                res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
            )
        return res

    def calc_fwd_cost(self, dist_op, ctx, cluster):
        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        processes = dist_op.dist_attr.process_mesh.processes
        cost_mapping = build_comp_costs_from_descs(
            MulOpCost, ctx, processes, desc_mapping, cluster
        )

        # calc comm op cost
        serial_op = dist_op.serial_op
        vars = serial_op.block.vars
        parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
            serial_op.input("Y")[0]
        )[-1]
        attrs = {"use_calc_stream": True, "use_model_parallel": True}
        var_names = serial_op.input("X")
        c_identity_desc_mapping = build_comm_desc_from_dist_op(
            "c_identity",
            dist_op,
            ctx,
            var_names,
            attrs=attrs,
            parallel_axis=parallel_axis,
        )

        comm_op_cost_list = build_comm_costs_from_descs(
            IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster
        )
        res_cost = [comm_op_cost_list, cost_mapping]

        return res_cost

    def is_input_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        x_name = op_desc.input('X')[0]
        y_name = op_desc.input('Y')[0]
        x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
        y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
        if is_dim_shard(x_dims_mapping[-1]):
            return False
        if is_dim_shard(y_dims_mapping[-2]) or is_dim_replicate(
            y_dims_mapping[-1]
        ):
            return False
        for mapping in x_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_output_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
        if is_dim_replicate(out_dims_mapping[-1]):
            return False
        for mapping in out_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_auto_compatible(self, dist_op):
        if (not self.is_input_compatible(dist_op)) or (
            not self.is_output_compatible(dist_op)
        ):
            return False

        if not _is_auto_compatible_for_matmul(dist_op):
            return False

        return True

    def update_dims_mapping(self, dist_op):
        changed = False
        dim_changed = _update_dims_mapping_for_matmul(dist_op)
        if dim_changed:
            changed = True
        return changed

    @staticmethod
    def forward(ctx, *args, **kwargs):
        """
        kwargs: inputname_mapping & outputname_mapping
        """

        dist_op_context = ctx.dist_op_context
        main_block = dist_op_context.work_block
        startup_block = dist_op_context.startup_block
        src_op = dist_op_context.cur_src_op
        rank_id = dist_op_context.rank_id
        op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
        assert (
            op_dist_attr is not None
        ), "backward op [{}] don't have dist attribute !".format(str(src_op))

        # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
        if rank_id not in op_dist_attr.process_mesh.processes:
            rank_id = _get_corresponding_rank(
                ctx, op_dist_attr.process_mesh, rank_id
            )

        # check validation of inputs / outputs
        for input_name in src_op.desc.input_names():
            assert input_name in kwargs, "input [{}] is not given".format(
                input_name
            )
            assert len(kwargs[input_name]) == len(
                src_op.desc.input(input_name)
            ), "number of tensor for input [{}] is not match".format(input_name)
        for output_name in src_op.desc.output_names():
            assert output_name in kwargs, "output [{}] is not given".format(
                output_name
            )
            assert len(kwargs[output_name]) == len(
                src_op.desc.output(output_name)
            ), "number of tensor for input [{}] is not match".format(
                output_name
            )

        X_var = main_block.var(kwargs['X'][0])
        Weight_var = main_block._var_recursive(kwargs['Y'][0])
        Out_var = main_block.var(kwargs['Out'][0])

        # TODO infer logic comm presentation
        matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping(
            Weight_var.name
        )[-1]
        assert (
            matmul_col_dim_mapping >= 0
        ), "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
            matmul_col_dim_mapping
        )
        process_mesh_shape = op_dist_attr.process_mesh.topology
        process_mesh_group = op_dist_attr.process_mesh.processes

        parallel_axis = matmul_col_dim_mapping
        group_ranks = _get_comm_group(
            process_mesh_group, process_mesh_shape, parallel_axis, rank_id
        )
        group = new_process_group(group_ranks)

        # infer new var shape with op dist attr
        x_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(X_var)
        assert x_tensor_dist_attr is not None
        identity_var_dist_attr = op_dist_attr.get_input_dist_attr(X_var.name)
        assert identity_var_dist_attr is not None
        ref_shape_x = infer_shape(
            main_block, X_var, x_tensor_dist_attr, identity_var_dist_attr
        )
        # infer out var shape with op dist attr
        out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var)
        assert out_tensor_dist_attr is not None
        out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert out_var_dist_attr is not None
        ref_shape_out = infer_shape(
            main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr
        )

        intermediate_var_0 = main_block.create_var(
            name=unique_name.generate_with_ignorable_key(
                ".".join(["c_identity", 'tmp'])
            ),
            dtype=X_var.dtype,
            shape=X_var.shape,
            type=core.VarDesc.VarType.LOD_TENSOR,
            persistable=False,
            stop_gradient=X_var.stop_gradient,
        )
        # set intermediate_var_0's dist_attr with X_var's dist_attr
        ctx.set_tensor_dist_attr_for_program(
            intermediate_var_0, identity_var_dist_attr
        )

        check_variable_and_dtype(
            X_var,
            'tensor',
            ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
            '_c_identity',
        )
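        # c_identity forwards X unchanged inside the model-parallel group; it
        # marks the point where backward aggregates X@GRAD across the group.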
        c_identity_op = main_block.append_op(
            type='c_identity',
            inputs={'X': [X_var]},
            outputs={'Out': intermediate_var_0},
            attrs={
                'ring_id': group.id,
                'use_calc_stream': True,
                'use_model_parallel': True,
                OP_ROLE_KEY: src_op.attr('op_role'),
            },
        )
        if intermediate_var_0.shape != ref_shape_x:
            intermediate_var_0.desc.set_shape(ref_shape_x)

        check_variable_and_dtype(
            intermediate_var_0,
            'x',
            ['float16', 'float32', 'float64', 'uint16'],
            'linear',
        )
        check_dtype(
            intermediate_var_0.dtype,
            'dtype',
            ['float16', 'float32', 'float64', 'uint16'],
            'linear',
        )
        # attrs = {'trans_x': False, 'trans_y': False}
        attrs = {
            "x_num_col_dims": src_op.desc.attr("x_num_col_dims"),
            "y_num_col_dims": src_op.desc.attr("y_num_col_dims"),
            OP_ROLE_KEY: src_op.attr('op_role'),
        }
        inputs = {'X': intermediate_var_0, 'Y': Weight_var}

        inputs_ref_shape = {}
        inputs_original_shape = {}
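        # Temporarily give each input its local (sharded) shape so the mul op
        # appended below infers the local output shape; the original serial
        # shapes are restored right after the op is created.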
        for var_name in inputs:
            if var_name == "X":
                var = X_var
            else:
                var = inputs[var_name]
            inputs_original_shape[var_name] = var.shape
            input_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(var)
            input_var_dist_attr = op_dist_attr.get_input_dist_attr(var.name)
            input_ref_shape = infer_shape(
                main_block, var, input_tensor_dist_attr, input_var_dist_attr
            )
            inputs_ref_shape[var_name] = input_ref_shape
            var.desc.set_shape(input_ref_shape)

        mul_op = main_block.append_op(
            type='mul', inputs=inputs, outputs={'Out': Out_var}, attrs=attrs
        )
        if Out_var.shape != ref_shape_out:
            Out_var.desc.set_shape(ref_shape_out)

        for var_name in inputs:
            var = inputs[var_name]
            original_shape = inputs_original_shape[var_name]
            var.desc.set_shape(original_shape)

        # set dist op's dist_attr with serial op's dist_attr
        # c_identity
        identity_op_dist_attr = OperatorDistributedAttribute()
        identity_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        identity_op_dist_attr.impl_type = op_dist_attr.impl_type
        identity_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        # input
        input_varname = c_identity_op.desc.input_arg_names()[0]
        input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
        assert input_dist_attr is not None, "dist_attr is {}".format(
            op_dist_attr
        )
        identity_op_dist_attr.set_input_dist_attr(
            input_varname, input_dist_attr
        )
        # output
        output_varname = c_identity_op.desc.output_arg_names()[0]
        identity_op_dist_attr.set_output_dist_attr(
            output_varname, input_dist_attr
        )
        ctx.set_op_dist_attr_for_program(c_identity_op, identity_op_dist_attr)

        # matmulv2
        matmulv2_op_dist_attr = OperatorDistributedAttribute()
        matmulv2_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        matmulv2_op_dist_attr.impl_type = op_dist_attr.impl_type
        matmulv2_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        for input_varname in mul_op.desc.input_arg_names():
            if input_varname in src_op.desc.input_arg_names():
                input_dist_attr = op_dist_attr.get_input_dist_attr(
                    input_varname
                )
                assert input_dist_attr is not None, "dist_attr is {}".format(
                    op_dist_attr
                )
                matmulv2_op_dist_attr.set_input_dist_attr(
                    input_varname, input_dist_attr
                )
            else:
                input_var = main_block.var(input_varname)
                tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(
                    input_var
                )
                matmulv2_op_dist_attr.set_input_dist_attr(
                    input_varname, tensor_dist_attr
                )
        for output_varname in mul_op.desc.output_arg_names():
            output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
            assert output_dist_attr is not None, "dist_attr is {}".format(
                op_dist_attr
            )
            matmulv2_op_dist_attr.set_output_dist_attr(
                output_varname, output_dist_attr
            )
        ctx.set_op_dist_attr_for_program(mul_op, matmulv2_op_dist_attr)

        # init param sync
        if Weight_var.is_parameter and not op_dist_attr.is_recompute:
            _init_param_sync(
                Weight_var, dist_op_context, startup_block, ctx, rank_id
            )

    @staticmethod
    def backward(ctx, *args, **kwargs):
        _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)


# RowParallel
class DistributedMulImpl1(DistributedOperatorImpl):
    def __init__(self, name):
        super(DistributedMulImpl1, self).__init__(name)
        self._forward_implemented = True
        self._backward_implemented = True

    def calc_cost(self, op_role, dist_op, ctx, cluster):
        cost = None
        if int(op_role) == int(OpRole.Forward):
            cost = self.calc_fwd_cost(dist_op, ctx, cluster)
        elif int(op_role) == int(OpRole.Backward):
            cost = self.calc_bwd_cost(dist_op, ctx, cluster)
        assert cost is not None
        return cost

    def calc_bwd_cost(self, dist_op, ctx, cluster):
        # by now the backward function only inserts the gradient allreduce for the dist op itself
        res = []
        backward_op = dist_op.serial_op
        dist_attr = dist_op.dist_attr
        process_mesh = dist_attr.process_mesh
        main_block = backward_op.block
        vars = main_block.vars
        Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("Y")[0]
        )
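        # Row parallel: Y must be replicated along its column dim and sharded
        # along its row dim; that row-shard mesh axis is the parallel axis.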
        assert Y_var_dim_mapping[1] < 0
        parallel_axis = Y_var_dim_mapping[0]

        # calc comm op cost
        var_names = [backward_op.input("Out@GRAD")[0]]
        attrs = {"use_calc_stream": True, "use_model_parallel": True}
        c_identity_desc_mapping = build_comm_desc_from_dist_op(
            "c_identity",
            dist_op,
            ctx,
            var_names,
            attrs=attrs,
            parallel_axis=parallel_axis,
        )
        processes = process_mesh.processes
        comm_op_cost_list = build_comm_costs_from_descs(
            IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster
        )
        res.append(comm_op_cost_list)

        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        cost_mapping = build_comp_costs_from_descs(
            MulGradOpCost, ctx, processes, desc_mapping, cluster
        )
        res.append(cost_mapping)

        # need gradient allreduce
        var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("X")[0]
        )
        mesh_shape = process_mesh.topology
        batch_size_axis = var_dim_mapping[0]
        if (
            batch_size_axis > -1
            and mesh_shape[batch_size_axis] > 1
            and is_parameter_related(backward_op.input("Y")[0], main_block)
        ):
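            # X is sharded along the batch (data-parallel) axis, so Y@GRAD is
            # a partial result and needs an extra data-parallel allreduce.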
            parallel_axis = batch_size_axis
            attrs = {"use_calc_stream": True}
            var_names = [backward_op.output('Y@GRAD')[0]]
            build_dp_costs(
                res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
            )
        return res

    def calc_fwd_cost(self, dist_op, ctx, cluster):
        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        processes = dist_op.dist_attr.process_mesh.processes
        cost_mapping = build_comp_costs_from_descs(
            MulOpCost, ctx, processes, desc_mapping, cluster
        )

        # calc comm op cost
        serial_op = dist_op.serial_op
        vars = serial_op.block.vars

        parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
            serial_op.input("Y")[0]
        )[-2]
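        # forward communication runs along the mesh axis that shards Y's row
        # dim: the partial Out produced by mul is finished by c_allreduce_sum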
        attrs = {"use_calc_stream": True, "use_model_parallel": True}

        var_names = serial_op.output("Out")
        c_allreduce_sum_desc_mapping = build_comm_desc_from_dist_op(
            "c_allreduce_sum",
            dist_op,
            ctx,
            var_names,
            attrs=attrs,
            parallel_axis=parallel_axis,
        )

        comm_op_cost_list = build_comm_costs_from_descs(
            AllreduceSumOpCost,
            ctx,
            processes,
            c_allreduce_sum_desc_mapping,
            cluster,
        )

        res_cost = [cost_mapping, comm_op_cost_list]

        return res_cost

    def is_input_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        x_name = op_desc.input('X')[0]
        y_name = op_desc.input('Y')[0]
        x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
        y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
        if is_dim_replicate(x_dims_mapping[-1]):
            return False
        if is_dim_replicate(y_dims_mapping[-2]) or is_dim_shard(
            y_dims_mapping[-1]
        ):
            return False
        # Other dimensions must be replicated except the batch dimension
        for mapping in x_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_output_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
        if is_dim_shard(out_dims_mapping[-1]):
            return False
        # Other dimensions must be replicated except the batch dimension
        for mapping in out_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_auto_compatible(self, dist_op):
        if (not self.is_input_compatible(dist_op)) or (
            not self.is_output_compatible(dist_op)
        ):
            return False

        if not _is_auto_compatible_for_matmul(dist_op):
            return False

        return True

    def update_dims_mapping(self, dist_op):
        changed = False
        dim_changed = _update_dims_mapping_for_matmul(dist_op)
        if dim_changed:
            changed = True
        return changed

    @staticmethod
    def forward(ctx, *args, **kwargs):
        """
        kwargs: inputname_mapping & outputname_mapping
        """

        dist_op_context = ctx.dist_op_context
        main_block = dist_op_context.work_block
        startup_block = dist_op_context.startup_block
        src_op = dist_op_context.cur_src_op
        rank_id = dist_op_context.rank_id
        op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
        assert (
            op_dist_attr is not None
        ), "forward op [{}] doesn't have dist attribute!".format(str(src_op))

        # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
        if rank_id not in op_dist_attr.process_mesh.processes:
            rank_id = _get_corresponding_rank(
                ctx, op_dist_attr.process_mesh, rank_id
            )

        # check the validity of inputs / outputs
        for input_name in src_op.desc.input_names():
            assert input_name in kwargs, "input [{}] is not given".format(
                input_name
            )
            assert len(kwargs[input_name]) == len(
                src_op.desc.input(input_name)
            ), "number of tensors for input [{}] does not match".format(input_name)
        for output_name in src_op.desc.output_names():
            assert output_name in kwargs, "output [{}] is not given".format(
                output_name
            )
            assert len(kwargs[output_name]) == len(
                src_op.desc.output(output_name)
            ), "number of tensors for output [{}] does not match".format(
                output_name
            )

        X_var = main_block.var(kwargs['X'][0])
        Weight_var = main_block._var_recursive(kwargs['Y'][0])
        Out_var = main_block.var(kwargs['Out'][0])

        # TODO: infer the logical communication representation
        matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
            Weight_var.name
        )[-2]
        assert (
            matmul_row_dim_mapping >= 0
        ), "row_parallel_matmul's row should be sharded along a specific mesh axis, but got [{}]".format(
            matmul_row_dim_mapping
        )
        process_mesh_shape = op_dist_attr.process_mesh.topology
        process_mesh_group = op_dist_attr.process_mesh.processes

        parallel_axis = matmul_row_dim_mapping
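        # Row-parallel scheme: Y is sharded along its row (k) dim and X along
        # its last dim on `parallel_axis`, so each rank produces a partial Out
        # that is summed across the group with c_allreduce_sum below.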
        group_ranks = _get_comm_group(
            process_mesh_group, process_mesh_shape, parallel_axis, rank_id
        )
        group = new_process_group(group_ranks)

        check_variable_and_dtype(
            X_var, 'x', ['float16', 'float32', 'float64', 'uint16'], 'linear'
        )
        check_dtype(
            X_var.dtype,
            'dtype',
            ['float16', 'float32', 'float64', 'uint16'],
            'linear',
        )
        # attrs = {'trans_x': False, 'trans_y': False}
        attrs = {
            "x_num_col_dims": src_op.desc.attr("x_num_col_dims"),
            "y_num_col_dims": src_op.desc.attr("y_num_col_dims"),
            OP_ROLE_KEY: src_op.attr('op_role'),
        }
        inputs = {'X': X_var, 'Y': Weight_var}

        # infer out var shape with op dist attr
        out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var)
        assert out_tensor_dist_attr is not None
        out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert out_var_dist_attr is not None
        ref_shape = infer_shape(
            main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr
        )

        intermediate_var_0 = main_block.create_var(
            name=unique_name.generate_with_ignorable_key(
                ".".join(["c_allreduce_sum", 'tmp'])
            ),
            shape=Out_var.shape,
            dtype=Out_var.dtype,
            type=Out_var.type,
            lod_level=Out_var.lod_level,
            persistable=False,
            is_data=False,
            need_check_feed=Out_var.desc.need_check_feed(),
        )
        # set intermediate_var_0's dist_attr with Out_var's dist_attr
        ctx.set_tensor_dist_attr_for_program(
            intermediate_var_0, out_var_dist_attr
        )

        inputs_ref_shape = {}
        inputs_original_shape = {}
        for var_name in inputs:
            var = inputs[var_name]
            inputs_original_shape[var_name] = var.shape
            input_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(var)
            input_var_dist_attr = op_dist_attr.get_input_dist_attr(var.name)
            input_ref_shape = infer_shape(
                main_block, var, input_tensor_dist_attr, input_var_dist_attr
            )
            inputs_ref_shape[var_name] = input_ref_shape
            var.desc.set_shape(input_ref_shape)

        mul_op = main_block.append_op(
            type='mul',
            inputs=inputs,
            outputs={'Out': intermediate_var_0},
            attrs=attrs,
        )

        if intermediate_var_0.shape != ref_shape:
            intermediate_var_0.desc.set_shape(ref_shape)

        for var_name in inputs:
            var = inputs[var_name]
            original_shape = inputs_original_shape[var_name]
            var.desc.set_shape(original_shape)

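        # Each rank holds only a partial sum of Out (the contracted dim is
        # sharded), so an allreduce over the row-parallel group completes it.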
        c_allreduce_sum_op = main_block.append_op(
            type='c_allreduce_sum',
            inputs={'X': intermediate_var_0},
            outputs={'Out': Out_var},
            attrs={
                'ring_id': group.id,
                'use_calc_stream': True,
                'use_model_parallel': True,
                OP_ROLE_KEY: src_op.attr('op_role'),
            },
        )

        if Out_var.shape != ref_shape:
            Out_var.desc.set_shape(ref_shape)

        # set dist op's dist_attr with serial op's dist_attr
        # matmulv2
        matmulv2_op_dist_attr = OperatorDistributedAttribute()
        matmulv2_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        matmulv2_op_dist_attr.impl_type = op_dist_attr.impl_type
        matmulv2_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        for input_varname in mul_op.desc.input_arg_names():
            input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
            assert input_dist_attr is not None, "dist_attr is {}".format(
                op_dist_attr
            )
            matmulv2_op_dist_attr.set_input_dist_attr(
                input_varname, input_dist_attr
            )
        output_varname = mul_op.desc.output_arg_names()[0]
        output_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert output_dist_attr is not None, "dist_attr is {}".format(
            op_dist_attr
        )
        matmulv2_op_dist_attr.set_output_dist_attr(
            output_varname, output_dist_attr
        )
        ctx.set_op_dist_attr_for_program(mul_op, matmulv2_op_dist_attr)

        # allreduce
        allreduce_op_dist_attr = OperatorDistributedAttribute()
        allreduce_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        allreduce_op_dist_attr.impl_type = op_dist_attr.impl_type
        allreduce_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        for input_varname in c_allreduce_sum_op.desc.input_arg_names():
            input_var = main_block.var(input_varname)
            tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(input_var)
            assert tensor_dist_attr is not None
            allreduce_op_dist_attr.set_input_dist_attr(
                input_varname, tensor_dist_attr
            )
        for output_varname in c_allreduce_sum_op.desc.output_arg_names():
            output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
            assert output_dist_attr is not None, "dist_attr is {}".format(
                op_dist_attr
            )
            allreduce_op_dist_attr.set_output_dist_attr(
                output_varname, output_dist_attr
            )
        ctx.set_op_dist_attr_for_program(
            c_allreduce_sum_op, allreduce_op_dist_attr
        )

        # init param sync
        if Weight_var.is_parameter and not op_dist_attr.is_recompute:
            _init_param_sync(
                Weight_var, dist_op_context, startup_block, ctx, rank_id
            )

    @staticmethod
    def backward(ctx, *args, **kwargs):
        _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)


# ReplicateParallel
class DistributedMulImpl2(DistributedOperatorImpl):
    def __init__(self, name):
        super(DistributedMulImpl2, self).__init__(name)

    def calc_cost(self, op_role, dist_op, ctx, cluster):
        cost = None
        if int(op_role) == int(OpRole.Forward):
            cost = self.calc_fwd_cost(dist_op, ctx, cluster)
        elif int(op_role) == int(OpRole.Backward):
            cost = self.calc_bwd_cost(dist_op, ctx, cluster)
        assert cost is not None
        return cost

    def calc_bwd_cost(self, dist_op, ctx, cluster):
        res = []
        backward_op = dist_op.serial_op
        dist_attr = dist_op.dist_attr
        main_block = backward_op.block
        vars = main_block.vars

        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        process_mesh = dist_attr.process_mesh
        processes = process_mesh.processes
        cost_mapping = build_comp_costs_from_descs(
            MulGradOpCost, ctx, processes, desc_mapping, cluster
        )
        res.append(cost_mapping)

        # need gradient allreduce
        var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("X")[0]
        )
        mesh_shape = process_mesh.topology
        batch_size_axis = var_dim_mapping[0]
        if (
            batch_size_axis > -1
            and mesh_shape[batch_size_axis] > 1
            and is_parameter_related(backward_op.input("Y")[0], main_block)
        ):
            parallel_axis = batch_size_axis
            attrs = {"use_calc_stream": True}
            var_names = [backward_op.output('Y@GRAD')[0]]
            build_dp_costs(
                res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
            )

        return res

    def calc_fwd_cost(self, dist_op, ctx, cluster):
        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        processes = dist_op.dist_attr.process_mesh.processes
        cost_mapping = build_comp_costs_from_descs(
            MulOpCost, ctx, processes, desc_mapping, cluster
        )

        res_cost = [cost_mapping]
        return res_cost

    def is_input_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        x_name = op_desc.input('X')[0]
        y_name = op_desc.input('Y')[0]
        x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
        y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)

        if is_dim_shard(x_dims_mapping[-1]):
            return False
        if is_valid_list_index(x_dims_mapping, -2) and is_dim_shard(
            x_dims_mapping[-2]
        ):
            return False
        if is_dim_shard(y_dims_mapping[-1]):
            return False
        if is_valid_list_index(y_dims_mapping, -2) and is_dim_shard(
            y_dims_mapping[-2]
        ):
            return False
        return True

    def is_output_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)

        if is_dim_shard(out_dims_mapping[-1]):
            return False
        if is_valid_list_index(out_dims_mapping, -2) and is_dim_shard(
            out_dims_mapping[-2]
        ):
            return False

        return True

    def is_auto_compatible(self, dist_op):
        if (not self.is_input_compatible(dist_op)) or (
            not self.is_output_compatible(dist_op)
        ):
            return False

        if not _is_auto_compatible_for_matmul(dist_op):
            return False

        return True

    def update_dims_mapping(self, dist_op):
        changed = False
        dim_changed = _update_dims_mapping_for_matmul(dist_op)
        if dim_changed:
            changed = True
        return changed

    @staticmethod
    def forward(ctx, *args, **kwargs):
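        # Replicate parallel: neither input is sharded along the contracted or
        # output feature dims, so the serial mul is copied unchanged by the
        # default implementation (a sharded batch dim is plain data parallelism).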
        DistributedDefaultImpl0.forward(ctx, *args, **kwargs)

    @staticmethod
    def backward(ctx, *args, **kwargs):
        _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)


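# Register the mul implementations; a concrete impl (column parallel, row
# parallel, or replicate parallel) is chosen per op according to the dims
# mapping of its inputs and outputs.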
register_distributed_operator_impl(
    "mul", DistributedMulImpl0("column_parallel")
)
register_distributed_operator_impl("mul", DistributedMulImpl1("row_parallel"))
register_distributed_operator_impl(
    "mul", DistributedMulImpl2("replicate_parallel")
)