# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
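"""Distributed (auto-parallel) implementations of matmul-style operators.

The module provides the ``DistributedMatmul`` operator-impl container with
column-parallel (Impl0), row-parallel (Impl1) and replicate (Impl2)
implementations, plus shared helpers for dims_mapping inference, the
parameter-side backward pass, parameter initialization synchronization and
cost modelling.
"""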

import copy

from .common import infer_shape
from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl
from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl
from .common import gradient_synchronization
from .common import is_parameter_related, set_comm_op_dist_attr_for_program
from ..utils import is_dim_shard
from ..utils import is_dim_replicate
from ..utils import is_valid_list_index
from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
from ..utils import set_dist_op_desc_original_id
from ..dist_attribute import OperatorDistributedAttribute
from paddle.fluid import core, unique_name
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole
from ..process_group import new_process_group
from ..utils import _get_comm_group, _get_corresponding_rank
from .dist_default import DistributedDefaultImpl0
from ..cost import (
    build_comp_desc_from_dist_op,
    build_comm_desc_from_dist_op,
    build_dp_costs,
)
from ..cost import build_comm_costs_from_descs, build_comp_costs_from_descs
from ..cost import MatmulV2OpCost, MatmulOpCost, MulOpCost
from ..cost import MatmulV2GradOpCost, MatmulGradOpCost, MulGradOpCost
from paddle.distributed.auto_parallel.cost.comm_op_cost import (
    AllreduceSumOpCost,
    IdentityOpCost,
)


def trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping):
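    """Swap the last two entries of a dims_mapping in place when the
    corresponding matmul operand is transposed.

    Illustrative sketch (mesh-axis ids are arbitrary):
        x_dims_mapping = [0, -1, 1]
        trans_x_y_dims_mapping(True, False, x_dims_mapping, None)
        # x_dims_mapping is now [0, 1, -1]
    """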
    if trans_x:
        x_dims_mapping[-1], x_dims_mapping[-2] = (
            x_dims_mapping[-2],
            x_dims_mapping[-1],
        )
    if trans_y:
        y_dims_mapping[-1], y_dims_mapping[-2] = (
            y_dims_mapping[-2],
            y_dims_mapping[-1],
        )


def copy_op_with_new_input_output(ctx, block, src_op, **kwargs):
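    """Append a copy of ``src_op`` to ``block`` with its inputs and outputs
    remapped through ``kwargs`` (a name -> argument-name-list mapping), while
    recording the original op id via ``set_dist_op_desc_original_id``.
    """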
    dist_op_desc = block.append_op(type='nop').desc
    dist_op_desc.copy_from(src_op.desc)
    set_dist_op_desc_original_id(dist_op_desc, src_op.desc, ctx)
    for input_name in src_op.desc.input_names():
        assert input_name in kwargs
        dist_op_desc.set_input(input_name, kwargs[input_name])
    for output_name in src_op.desc.output_names():
        assert output_name in kwargs
        dist_op_desc.set_output(output_name, kwargs[output_name])

    return dist_op_desc


def _update_dims_mapping_for_matmul(dist_op):
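    """Make the dims_mappings of X, Y and Out of a matmul compatible.

    The update works in three steps: temporarily pad 1-D operands to 2-D,
    align the broadcasted batch dimensions of X, Y and Out, and enforce the
    matmul contract on the last two axes (X[-1] ~ Y[-2], X[-2] ~ Out[-2],
    Y[-1] ~ Out[-1]). Returns True if any dims_mapping was changed.

    Illustrative sketch (mesh-axis ids are arbitrary): with X mapped to
    [0, -1] and Y mapped to [-1, 1], Out would be updated towards [0, 1].
    """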
    changed = False
    op_desc = dist_op.serial_op.desc
    op_dist_attr = dist_op.dist_attr
    x_name = op_desc.input('X')[0]
    y_name = op_desc.input('Y')[0]
    out_name = op_desc.output('Out')[0]
    trans_x = None
    trans_y = None
    if op_desc.type() == "matmul_v2":
        trans_x = op_desc.attr('trans_x')
        trans_y = op_desc.attr('trans_y')
    elif op_desc.type() == "matmul":
        trans_x = op_desc.attr('transpose_X')
        trans_y = op_desc.attr('transpose_Y')
    x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
    y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
    out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
    x_dims_mapping_len = len(x_dims_mapping)
    y_dims_mapping_len = len(y_dims_mapping)
    out_dims_mapping_len = len(out_dims_mapping)

    # Add dim mapping to make sure the length of dims_mapping is at least 2
    if x_dims_mapping_len == 1:
        assert trans_x is False
        x_dims_mapping.insert(0, -1)
        out_dims_mapping.insert(out_dims_mapping_len - 1, 0)
    if y_dims_mapping_len == 1:
        assert trans_y is False
        y_dims_mapping.insert(1, -1)
        out_dims_mapping.insert(out_dims_mapping_len, 0)

    trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping)

    new_x_dims_mapping_len = len(x_dims_mapping)
    new_y_dims_mapping_len = len(y_dims_mapping)
    new_out_dims_mapping_len = len(out_dims_mapping)
    # Deal with dim > 2 and take care of broadcasting
    if new_out_dims_mapping_len > 2:
        broadcast_x_dims_mapping = []
        broadcast_y_dims_mapping = []
        broadcast_out_dims_mapping = []

        for i in range(new_out_dims_mapping_len - new_x_dims_mapping_len):
            broadcast_x_dims_mapping.append(out_dims_mapping[i])
        for i in range(new_x_dims_mapping_len - 2):
            broadcast_x_dims_mapping.append(x_dims_mapping[i])

        for i in range(new_out_dims_mapping_len - new_y_dims_mapping_len):
            broadcast_y_dims_mapping.append(out_dims_mapping[i])
        for i in range(new_y_dims_mapping_len - 2):
            broadcast_y_dims_mapping.append(y_dims_mapping[i])

        for i in range(new_out_dims_mapping_len - 2):
            broadcast_out_dims_mapping.append(out_dims_mapping[i])

        compatible_dims_mapping = compute_compatible_dims_mapping(
            [
                broadcast_x_dims_mapping,
                broadcast_y_dims_mapping,
                broadcast_out_dims_mapping,
            ]
        )
        if compatible_dims_mapping is None:
            trans_x_y_dims_mapping(
                trans_x, trans_y, x_dims_mapping, y_dims_mapping
            )
            return False

        for i in range(new_x_dims_mapping_len - 2):
            new_idx = i + (out_dims_mapping_len - new_x_dims_mapping_len)
            if x_dims_mapping[i] != compatible_dims_mapping[new_idx]:
                x_dims_mapping[i] = compatible_dims_mapping[new_idx]
                changed = True

        for i in range(new_y_dims_mapping_len - 2):
            new_idx = i + (out_dims_mapping_len - new_y_dims_mapping_len)
            if y_dims_mapping[i] != compatible_dims_mapping[new_idx]:
                y_dims_mapping[i] = compatible_dims_mapping[new_idx]
                changed = True

        for i in range(new_out_dims_mapping_len - 2):
            if out_dims_mapping[i] != compatible_dims_mapping[i]:
                out_dims_mapping[i] = compatible_dims_mapping[i]
                changed = True

    # The following logic uses negative indices, so it works both when
    # len(out_dims_mapping) > 2 and when len(out_dims_mapping) <= 2
    dim_changed = compute_compatible_and_update_dim_mapping(
        [x_dims_mapping, y_dims_mapping], [-1, -2]
    )
    if dim_changed:
        changed = True

    dim_changed = compute_compatible_and_update_dim_mapping(
        [x_dims_mapping, out_dims_mapping], [-2, -2]
    )
    if dim_changed:
        changed = True

    dim_changed = compute_compatible_and_update_dim_mapping(
        [y_dims_mapping, out_dims_mapping], [-1, -1]
    )
    if dim_changed:
        changed = True

    trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping)

    # Remove the unnecessary dim mappings to make sure the length of each dims_mapping matches its tensor
    if x_dims_mapping_len == 1:
        x_dims_mapping.pop(0)
        out_dims_mapping.pop(out_dims_mapping_len - 1)
    if y_dims_mapping_len == 1:
        y_dims_mapping.pop(1)
        out_dims_mapping.pop(out_dims_mapping_len)

    assert len(x_dims_mapping) == x_dims_mapping_len
    assert len(y_dims_mapping) == y_dims_mapping_len
    assert len(out_dims_mapping) == out_dims_mapping_len

    return changed


def _is_auto_compatible_for_matmul(dist_op):
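    """Check, without modifying any dims_mapping, whether the annotated
    dims_mappings of X, Y and Out already satisfy the rules enforced by
    _update_dims_mapping_for_matmul (matching broadcast batch dims and
    matching contraction/output dims).
    """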
    op_desc = dist_op.serial_op.desc
    op_dist_attr = dist_op.dist_attr
    x_name = op_desc.input('X')[0]
    y_name = op_desc.input('Y')[0]
    out_name = op_desc.output('Out')[0]
    trans_x = None
    trans_y = None
    if op_desc.type() == "matmul_v2":
        trans_x = op_desc.attr('trans_x')
        trans_y = op_desc.attr('trans_y')
    elif op_desc.type() == "matmul":
        trans_x = op_desc.attr('transpose_X')
        trans_y = op_desc.attr('transpose_Y')

    # Deep copy these dims_mappings to keep the originals unchanged.
    x_dims_mapping = copy.deepcopy(op_dist_attr.get_input_dims_mapping(x_name))
    y_dims_mapping = copy.deepcopy(op_dist_attr.get_input_dims_mapping(y_name))
    out_dims_mapping = copy.deepcopy(
        op_dist_attr.get_output_dims_mapping(out_name)
    )
    x_dims_mapping_len = len(x_dims_mapping)
    y_dims_mapping_len = len(y_dims_mapping)
    out_dims_mapping_len = len(out_dims_mapping)

    # Add dim mapping to make sure the length of dims_mapping is at least 2
    if x_dims_mapping_len == 1:
        x_dims_mapping.insert(0, -1)
    if y_dims_mapping_len == 1:
        y_dims_mapping.insert(1, -1)

    trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping)

    # Deal with dim > 2 and take care of broadcasting
    if out_dims_mapping_len > 2:
        broadcast_x_dims_mapping = []
        broadcast_y_dims_mapping = []
        broadcast_out_dims_mapping = []

        for i in range(out_dims_mapping_len - x_dims_mapping_len):
            broadcast_x_dims_mapping.append(out_dims_mapping[i])
        for i in range(x_dims_mapping_len - 2):
            broadcast_x_dims_mapping.append(x_dims_mapping[i])

        for i in range(out_dims_mapping_len - y_dims_mapping_len):
            broadcast_y_dims_mapping.append(out_dims_mapping[i])
        for i in range(y_dims_mapping_len - 2):
            broadcast_y_dims_mapping.append(y_dims_mapping[i])

        for i in range(out_dims_mapping_len - 2):
            broadcast_out_dims_mapping.append(out_dims_mapping[i])

        is_same = (broadcast_x_dims_mapping == broadcast_y_dims_mapping) and (
            broadcast_x_dims_mapping == broadcast_out_dims_mapping
        )
        if not is_same:
            return False

    # The following logic uses negative indices, so it works both when
    # len(out_dims_mapping) > 2 and when len(out_dims_mapping) <= 2
    is_same = x_dims_mapping[-1] == y_dims_mapping[-2]
    if not is_same:
        return False

    is_same = x_dims_mapping[-2] == out_dims_mapping[-2]
    if not is_same:
        return False

    is_same = y_dims_mapping[-1] == out_dims_mapping[-1]
    if not is_same:
        return False

    return True


def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
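    """Shared backward for the dist matmul implementations.

    Three cases are handled, mirroring the forward lowering:
    * Y is a parameter sharded on its row axis: a c_identity on Out@GRAD is
      inserted before the copied grad op (row-parallel case);
    * Y is a parameter sharded on its column axis: the copied grad op writes
      into an intermediate X@GRAD which is then c_allreduce_sum-ed over the
      model-parallel group (column-parallel case);
    * otherwise the grad op is copied unchanged.
    gradient_synchronization() then adds the data-parallel allreduce for the
    parameter gradient when required.
    """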

    # by now the backward function only inserts the gradient allreduce for the dist op itself

    dist_op_context = ctx.dist_op_context
    main_block = dist_op_context.work_block
    backward_op = dist_op_context.cur_src_op
    rank_id = dist_op_context.rank_id
    dist_attr = ctx.get_op_dist_attr_for_program(backward_op)
    assert (
        dist_attr is not None
    ), "backward op [{}] doesn't have dist attribute!".format(str(backward_op))

    # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
    if rank_id not in dist_attr.process_mesh.processes:
        rank_id = _get_corresponding_rank(ctx, dist_attr.process_mesh, rank_id)

    assert 'Y' in kwargs, "input [{}] is not given".format('Y')
    assert 'X' in kwargs, "input [{}] is not given".format('X')
    assert 'Out@GRAD' in kwargs, "input [{}] is not given".format('Out@GRAD')
    assert 'Y@GRAD' in kwargs, "output [{}] is not given".format('Y@GRAD')
    assert 'X@GRAD' in kwargs, "output [{}] is not given".format('X@GRAD')
    assert (
        len(kwargs['Y']) == 1
    ), "dist matmul backward input Y takes 1 variable but got {}".format(
        kwargs['Y']
    )
    assert (
        len(kwargs['X']) == 1
    ), "dist matmul backward input X takes 1 variable but got {}".format(
        kwargs['X']
    )
    assert (
        len(kwargs['Out@GRAD']) == 1
    ), "dist matmul backward input Out@GRAD takes 1 variable but got {}".format(
        kwargs['Out@GRAD']
    )
    assert (
        len(kwargs['Y@GRAD']) == 1
    ), "dist matmul backward output Y@GRAD takes 1 variable but got {}".format(
        kwargs['Y@GRAD']
    )

    X_var = main_block._var_recursive(kwargs['X'][0])
    Y_var = main_block._var_recursive(kwargs['Y'][0])
    Out_grad = main_block._var_recursive(kwargs['Out@GRAD'][0])
    Y_grad = main_block._var_recursive(kwargs['Y@GRAD'][0])

    assert not is_parameter_related(
        X_var.name, main_block
    ), "left operand(X) [{}] of dist matmul should not be parameter".format(
        X_var.name
    )

    X_var_dims_mapping = dist_attr.get_input_dims_mapping(X_var.name)
    Y_var_dim_mapping = dist_attr.get_input_dims_mapping(Y_var.name)
    process_mesh_shape = dist_attr.process_mesh.topology
    process_mesh_group = dist_attr.process_mesh.processes

    trans_x = None
    trans_y = None
    if backward_op.desc.type() == "matmul_v2_grad":
        trans_x = backward_op.desc.attr('trans_x')
        trans_y = backward_op.desc.attr('trans_y')
    elif backward_op.desc.type() == "matmul_grad":
        trans_x = backward_op.desc.attr('transpose_X')
        trans_y = backward_op.desc.attr('transpose_Y')

    if trans_y:
        trans_x_y_dims_mapping(False, True, None, Y_var_dim_mapping)

    # assert len(
    #     Y_var_dim_mapping
    #     ) == 2, "dist matmul only supports Y operand with 2 dims now but Y({})'s dim is [{}]".format(
    #     Y_var.name, Y_var_dim_mapping)
    Y_var_partitioned = False
    for dim in Y_var_dim_mapping:
        if dim >= 0 and process_mesh_shape[dim] > 0:
            Y_var_partitioned = True
            break

    if is_parameter_related(Y_var.name, main_block) and Y_var_partitioned:

        if Y_var_dim_mapping[0] >= 0:
            # row parallel: c_identity + matmul
            assert Y_var_dim_mapping[1] < 0
            parallel_axis = Y_var_dim_mapping[0]

            check_variable_and_dtype(
                Out_grad,
                'tensor',
                ['float16', 'float32', 'float64', 'int32', 'int64'],
                '_c_identity',
            )

            intermediate_var_0 = main_block.create_var(
                name=unique_name.generate_with_ignorable_key(
                    ".".join(["c_identity", 'tmp'])
                )
                + "@GRAD",
                dtype=Out_grad.dtype,
                shape=Out_grad.shape,
                type=core.VarDesc.VarType.LOD_TENSOR,
                persistable=False,
                stop_gradient=Out_grad.stop_gradient,
            )

            # copy Out_grad's dist_attr to intermediate_var_0's dist_attr
            out_grad_dist_attr = dist_attr.get_input_dist_attr(Out_grad.name)
            assert out_grad_dist_attr is not None
            ctx.set_tensor_dist_attr_for_program(
                intermediate_var_0, out_grad_dist_attr
            )

            group_ranks = _get_comm_group(
                process_mesh_group, process_mesh_shape, parallel_axis, rank_id
            )
            group = new_process_group(group_ranks)
            c_identity_op = main_block.append_op(
                type='c_identity',
                inputs={'X': [Out_grad]},
                outputs={'Out': intermediate_var_0},
                attrs={
                    'ring_id': group.id,
                    'use_calc_stream': True,
                    'use_model_parallel': True,
                    OP_ROLE_KEY: OpRole.Backward,
                },
            )
            check_variable_and_dtype(
                intermediate_var_0,
                'x',
                ['float16', 'float32', 'float64'],
                'linear',
            )
            check_dtype(
                intermediate_var_0.dtype,
                'dtype',
                ['float16', 'float32', 'float64'],
                'linear',
            )
            set_comm_op_dist_attr_for_program(
                c_identity_op, dist_attr.process_mesh, out_grad_dist_attr, ctx
            )

            new_kwargs = copy.deepcopy(kwargs)
            new_kwargs['Out@GRAD'] = [intermediate_var_0.name]
            matmul_op_desc = copy_op_with_new_input_output(
                ctx, main_block, backward_op, **new_kwargs
            )
        else:
            # col parallel: matmul + allreduce
            assert Y_var_dim_mapping[0] < 0
            parallel_axis = Y_var_dim_mapping[1]
            new_kwargs = copy.deepcopy(kwargs)

            # NOTE (JZ-LIANG) should allow the left operand to be empty for matmul grad
            has_x_grad = len(kwargs['X@GRAD']) > 0
            if has_x_grad:
                assert len(kwargs['X@GRAD']) == 1
                X_grad = main_block._var_recursive(kwargs['X@GRAD'][0])
                intermediate_var_0 = main_block.create_var(
                    name=unique_name.generate_with_ignorable_key(
                        ".".join(["c_identity", 'tmp'])
                    )
                    + "@GRAD",
                    dtype=X_grad.dtype,
                    shape=X_grad.shape,
                    type=core.VarDesc.VarType.LOD_TENSOR,
                    persistable=False,
                    stop_gradient=X_grad.stop_gradient,
                )

                X_grad_dist_attr = dist_attr.get_output_dist_attr(X_grad.name)
                assert X_grad_dist_attr is not None
                ctx.set_tensor_dist_attr_for_program(
                    intermediate_var_0, X_grad_dist_attr
                )
                new_kwargs['X@GRAD'] = [intermediate_var_0.name]

            matmul_op_desc = copy_op_with_new_input_output(
                ctx, main_block, backward_op, **new_kwargs
            )

            # NOTE (JZ-LIANG) trick to skip one allreduce if the left operand has no grad
            if has_x_grad:
                group_ranks = _get_comm_group(
                    process_mesh_group,
                    process_mesh_shape,
                    parallel_axis,
                    rank_id,
                )
                group = new_process_group(group_ranks)
                c_allreduce_sum_op = main_block.append_op(
                    type='c_allreduce_sum',
                    inputs={'X': [intermediate_var_0.name]},
                    outputs={'Out': kwargs['X@GRAD']},
                    attrs={
                        'ring_id': group.id,
                        'use_calc_stream': True,
                        'use_model_parallel': True,
                        OP_ROLE_KEY: OpRole.Backward,
                    },
                )
                set_comm_op_dist_attr_for_program(
                    c_allreduce_sum_op,
                    dist_attr.process_mesh,
                    X_grad_dist_attr,
                    ctx,
                )
    else:
        # replicate
        matmul_op_desc = copy_op_with_new_input_output(
            ctx, main_block, backward_op, **kwargs
        )

    # data parallel gradient synchronization
    act_grad_names = [X_var.name]

    out_grad_names = []
    if is_parameter_related(Y_var.name, main_block):
        out_grad_names = [kwargs['Y@GRAD'][0]]

    if trans_x:
        trans_x_y_dims_mapping(True, False, X_var_dims_mapping, None)

    gradient_synchronization(
        ctx, backward_op, act_grad_names, out_grad_names, rank_id
    )

    if trans_x:
        trans_x_y_dims_mapping(True, False, X_var_dims_mapping, None)
    if trans_y:
        trans_x_y_dims_mapping(False, True, None, Y_var_dim_mapping)


def _init_param_sync(Weight_var, dist_op_context, startup_block, ctx, rank_id):
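    """Broadcast a newly initialized parameter so that all ranks that hold a
    replica start from identical values.

    One c_broadcast (root 0) is appended for every mesh axis whose size is
    greater than 1 and which does not appear in the parameter's dims_mapping,
    i.e. every axis along which the parameter is replicated. Parameters that
    were already synchronized are skipped via
    dist_op_context.already_init_sync_vars.
    """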

    if Weight_var.name in dist_op_context.already_init_sync_vars:
        return
    assert startup_block.has_var(Weight_var.name)
    dist_op_context.already_init_sync_vars.add(Weight_var.name)
    param = startup_block.var(Weight_var.name)
    param_dist_attr = ctx.get_tensor_dist_attr_for_program(param)
    process_mesh = param_dist_attr.process_mesh
    dim_mapping = param_dist_attr.dims_mapping

    for axis, size in enumerate(process_mesh.topology):
        if size <= 1 or axis in dim_mapping:
            pass
        else:
            group_ranks = _get_comm_group(
                process_mesh.processes, process_mesh.topology, axis, rank_id
            )
            sync_group = new_process_group(group_ranks)

            startup_block.append_op(
                type='c_broadcast',
                inputs={'X': param},
                outputs={'Out': param},
                attrs={
                    'ring_id': sync_group.id,
                    'root': 0,
                    'use_calc_stream': True,
                    OP_ROLE_KEY: OpRole.Forward,
                },
            )


class DistributedMatmul(DistributedOperatorImplContainer):
    def __init__(self, op_type):
        super().__init__(op_type)


register_distributed_operator_impl_container(DistributedMatmul("matmul"))


# ColumnParallel
class DistributedMatmulImpl0(DistributedOperatorImpl):
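    """Column-parallel matmul: the weight Y is sharded along its column
    (last) dimension.

    The lowered forward program is a c_identity on X followed by a local
    matmul against the column shard of Y, so Out inherits the column
    sharding; the backward is handled by
    _right_operand_parameter_matmul_backward.
    """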
    def __init__(self, name):
        super().__init__(name)
        self._forward_implemented = True
        self._backward_implemented = True

    def calc_cost(self, op_role, dist_op, ctx, cluster):
        cost = None
        if int(op_role) == int(OpRole.Forward):
            cost = self.calc_fwd_cost(dist_op, ctx, cluster)
        elif int(op_role) == int(OpRole.Backward):
            cost = self.calc_bwd_cost(dist_op, ctx, cluster)
        assert cost is not None
        return cost

    def calc_bwd_cost(self, dist_op, ctx, cluster):
        # by now the backward function only inserts the gradient allreduce for the dist op itself
        res = []
        backward_op = dist_op.serial_op
        dist_attr = dist_op.dist_attr
        main_block = backward_op.block
        Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("Y")[0]
        )
        # col parallel: matmul + allreduce
        assert Y_var_dim_mapping[0] < 0
        parallel_axis = Y_var_dim_mapping[1]

        has_x_grad = len(backward_op.output("X@GRAD")) > 0
        if has_x_grad:
            assert len(backward_op.output("X@GRAD")) == 1

        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        process_mesh = dist_attr.process_mesh
        processes = process_mesh.processes
        cost_mapping = build_comp_costs_from_descs(
            MatmulGradOpCost, ctx, processes, desc_mapping, cluster
        )
        res.append(cost_mapping)

        # calc comm op cost
        if has_x_grad:
            attrs = {"use_calc_stream": True, "use_model_parallel": True}
            var_names = backward_op.output("X@GRAD")
            c_allreduce_sum_desc_mapping = build_comm_desc_from_dist_op(
                "c_allreduce_sum",
                dist_op,
                ctx,
                var_names,
                attrs=attrs,
                parallel_axis=parallel_axis,
            )
            comm_op_cost_list = build_comm_costs_from_descs(
                AllreduceSumOpCost,
                ctx,
                processes,
                c_allreduce_sum_desc_mapping,
                cluster,
            )
            res.append(comm_op_cost_list)

        # need gradient allreduce
        var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("X")[0]
        )
        mesh_shape = process_mesh.topology
        batch_size_axis = var_dim_mapping[0]
        if (
            batch_size_axis > -1
            and mesh_shape[batch_size_axis] > 1
            and is_parameter_related(backward_op.input("Y")[0], main_block)
        ):
            parallel_axis = batch_size_axis
            attrs = {"use_calc_stream": True}
            var_names = [backward_op.output('Y@GRAD')[0]]
            build_dp_costs(
                res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
            )
        return res

    def calc_fwd_cost(self, dist_op, ctx, cluster):
        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        processes = dist_op.dist_attr.process_mesh.processes
        cost_mapping = build_comp_costs_from_descs(
            MatmulOpCost, ctx, processes, desc_mapping, cluster
        )

        # calc comm op cost
        serial_op = dist_op.serial_op
        parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
            serial_op.input("Y")[0]
        )[-1]
        attrs = {"use_calc_stream": True, "use_model_parallel": True}
        var_names = serial_op.input("X")
        c_identity_desc_mapping = build_comm_desc_from_dist_op(
            "c_identity",
            dist_op,
            ctx,
            var_names,
            attrs=attrs,
            parallel_axis=parallel_axis,
        )

        comm_op_cost_list = build_comm_costs_from_descs(
            IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster
        )
        res_cost = [comm_op_cost_list, cost_mapping]

        return res_cost

    def is_input_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        x_name = op_desc.input('X')[0]
        y_name = op_desc.input('Y')[0]
        x_dims_mapping = copy.deepcopy(
            op_dist_attr.get_input_dims_mapping(x_name)
        )
        y_dims_mapping = copy.deepcopy(
            op_dist_attr.get_input_dims_mapping(y_name)
        )
        trans_x = op_desc.attr('transpose_X')
        trans_y = op_desc.attr('transpose_Y')
        trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping)
        if is_dim_shard(x_dims_mapping[-1]):
            return False
        if is_dim_shard(y_dims_mapping[-2]) or is_dim_replicate(
            y_dims_mapping[-1]
        ):
            return False
        for mapping in x_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_output_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
        if is_dim_replicate(out_dims_mapping[-1]):
            return False
        for mapping in out_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_auto_compatible(self, dist_op):
        if (not self.is_input_compatible(dist_op)) or (
            not self.is_output_compatible(dist_op)
        ):
            return False
        if not _is_auto_compatible_for_matmul(dist_op):
            return False
        return True

    def update_dims_mapping(self, dist_op):
        changed = False
        dim_changed = _update_dims_mapping_for_matmul(dist_op)
        if dim_changed:
            changed = True
        return changed

    @staticmethod
    def forward(ctx, *args, **kwargs):
        """
        kwargs: inputname_mapping & outputname_mapping
        """

        dist_op_context = ctx.dist_op_context
        main_block = dist_op_context.work_block
        startup_block = dist_op_context.startup_block
        src_op = dist_op_context.cur_src_op
        rank_id = dist_op_context.rank_id
        op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
        assert (
            op_dist_attr is not None
        ), "forward op [{}] doesn't have dist attribute!".format(str(src_op))

        # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
        if rank_id not in op_dist_attr.process_mesh.processes:
            rank_id = _get_corresponding_rank(
                ctx, op_dist_attr.process_mesh, rank_id
            )

        # check validation of inputs / outputs
        for input_name in src_op.desc.input_names():
            assert input_name in kwargs, "input [{}] is not given".format(
                input_name
            )
            assert len(kwargs[input_name]) == len(
                src_op.desc.input(input_name)
            ), "number of tensors for input [{}] does not match".format(
                input_name
            )
        for output_name in src_op.desc.output_names():
            assert output_name in kwargs, "output [{}] is not given".format(
                output_name
            )
            assert len(kwargs[output_name]) == len(
                src_op.desc.output(output_name)
            ), "number of tensors for output [{}] does not match".format(
                output_name
            )

        X_var = main_block._var_recursive(kwargs['X'][0])
        Weight_var = main_block._var_recursive(kwargs['Y'][0])
        Out_var = main_block._var_recursive(kwargs['Out'][0])
        trans_x = src_op.attr("transpose_X")
        trans_y = src_op.attr("transpose_Y")

        # TODO infer logic comm presentation
        matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping(
            Weight_var.name
        )[-1]
        if trans_y:
            matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping(
                Weight_var.name
            )[-2]
        assert (
            matmul_col_dim_mapping >= 0
        ), "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
            matmul_col_dim_mapping
        )
        process_mesh_shape = op_dist_attr.process_mesh.topology
        process_mesh_group = op_dist_attr.process_mesh.processes

        parallel_axis = matmul_col_dim_mapping
        group_ranks = _get_comm_group(
            process_mesh_group, process_mesh_shape, parallel_axis, rank_id
        )
        group = new_process_group(group_ranks)

        # infer new var shape with op dist attr
        x_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(X_var)
        assert x_tensor_dist_attr is not None
        identity_var_dist_attr = op_dist_attr.get_input_dist_attr(X_var.name)
        assert identity_var_dist_attr is not None
        ref_shape_x = infer_shape(
            main_block, X_var, x_tensor_dist_attr, identity_var_dist_attr
        )
        # infer out var shape with op dist attr
        out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var)
        assert out_tensor_dist_attr is not None
        out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert out_var_dist_attr is not None
        ref_shape_out = infer_shape(
            main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr
        )

        intermediate_var_0 = main_block.create_var(
            name=unique_name.generate_with_ignorable_key(
                ".".join(["c_identity", 'tmp'])
            ),
            dtype=X_var.dtype,
            shape=X_var.shape,
            type=core.VarDesc.VarType.LOD_TENSOR,
            persistable=False,
            stop_gradient=X_var.stop_gradient,
        )
        # set intermediate_var_0's dist_attr with X_var's dist_attr
        ctx.set_tensor_dist_attr_for_program(
            intermediate_var_0, identity_var_dist_attr
        )

        check_variable_and_dtype(
            X_var,
            'tensor',
            ['float16', 'float32', 'float64', 'int32', 'int64'],
            '_c_identity',
        )

        c_identity_op = main_block.append_op(
            type='c_identity',
            inputs={'X': [X_var]},
            outputs={'Out': intermediate_var_0},
            attrs={
                'ring_id': group.id,
                'use_calc_stream': True,
                'use_model_parallel': True,
                OP_ROLE_KEY: src_op.attr('op_role'),
            },
        )
        if intermediate_var_0.shape != ref_shape_x:
            intermediate_var_0.desc.set_shape(ref_shape_x)

        check_variable_and_dtype(
            intermediate_var_0, 'x', ['float16', 'float32', 'float64'], 'linear'
        )
        check_dtype(
            intermediate_var_0.dtype,
            'dtype',
            ['float16', 'float32', 'float64'],
            'linear',
        )
        attrs = {
            'transpose_X': trans_x,
            'transpose_Y': trans_y,
            'alpha': 1,
            OP_ROLE_KEY: src_op.attr('op_role'),
        }
        inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]}
        matmul_op = main_block.append_op(
            type='matmul', inputs=inputs, outputs={'Out': Out_var}, attrs=attrs
        )
        if Out_var.shape != ref_shape_out:
            Out_var.desc.set_shape(ref_shape_out)

        # set dist op's dist_attr with serial op's dist_attr
        # c_identity
        identity_op_dist_attr = OperatorDistributedAttribute()
        identity_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        identity_op_dist_attr.impl_type = op_dist_attr.impl_type
        identity_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        # input
        input_varname = c_identity_op.desc.input_arg_names()[0]
        input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
        assert input_dist_attr is not None, "dist_attr is {}".format(
            op_dist_attr
        )
        identity_op_dist_attr.set_input_dist_attr(
            input_varname, input_dist_attr
        )
        # output
        output_varname = c_identity_op.desc.output_arg_names()[0]
        identity_op_dist_attr.set_output_dist_attr(
            output_varname, input_dist_attr
        )
        # set op dist attr
        ctx.set_op_dist_attr_for_program(c_identity_op, identity_op_dist_attr)

        # matmul
        matmul_op_dist_attr = OperatorDistributedAttribute()
        matmul_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        matmul_op_dist_attr.impl_type = op_dist_attr.impl_type
        matmul_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        # input
        for input_varname in matmul_op.desc.input_arg_names():
            if input_varname in src_op.desc.input_arg_names():
                input_dist_attr = op_dist_attr.get_input_dist_attr(
                    input_varname
                )
                assert input_dist_attr is not None, "dist_attr is {}".format(
                    op_dist_attr
                )
                matmul_op_dist_attr.set_input_dist_attr(
                    input_varname, input_dist_attr
                )
            else:
                input_var = main_block._var_recursive(input_varname)
                tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(
                    input_var
                )
                matmul_op_dist_attr.set_input_dist_attr(
                    input_varname, tensor_dist_attr
                )
        # output
        output_varname = matmul_op.desc.output_arg_names()[0]
        output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
        assert output_dist_attr is not None, "dist_attr is {}".format(
            op_dist_attr
        )
        matmul_op_dist_attr.set_output_dist_attr(
            output_varname, output_dist_attr
        )
        # set op dist attr
        ctx.set_op_dist_attr_for_program(matmul_op, matmul_op_dist_attr)

        # init param sync
        if Weight_var.is_parameter and not op_dist_attr.is_recompute:
            _init_param_sync(
                Weight_var, dist_op_context, startup_block, ctx, rank_id
            )

    @staticmethod
    def backward(ctx, *args, **kwargs):
        _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)


# RowParallel
class DistributedMatmulImpl1(DistributedOperatorImpl):
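    """Row-parallel matmul: the weight Y is sharded along its row (-2)
    dimension and X along its last dimension.

    The lowered forward program is a local matmul producing a partial result
    followed by a c_allreduce_sum into Out; the backward is handled by
    _right_operand_parameter_matmul_backward.
    """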
    def __init__(self, name):
        super().__init__(name)
        self._forward_implemented = True
        self._backward_implemented = True

    def calc_cost(self, op_role, dist_op, ctx, cluster):
        cost = None
        if int(op_role) == int(OpRole.Forward):
            cost = self.calc_fwd_cost(dist_op, ctx, cluster)
        elif int(op_role) == int(OpRole.Backward):
            cost = self.calc_bwd_cost(dist_op, ctx, cluster)
        assert cost is not None
        return cost

    def calc_bwd_cost(self, dist_op, ctx, cluster):
        # by now the backward function only inserts the gradient allreduce for the dist op itself
        res = []
        backward_op = dist_op.serial_op
        dist_attr = dist_op.dist_attr
        main_block = backward_op.block
        Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("Y")[0]
        )
        assert Y_var_dim_mapping[1] < 0
        parallel_axis = Y_var_dim_mapping[0]

        # calc comm op cost
        var_names = [backward_op.input("Out@GRAD")[0]]
        attrs = {"use_calc_stream": True, "use_model_parallel": True}
        c_identity_desc_mapping = build_comm_desc_from_dist_op(
            "c_identity",
            dist_op,
            ctx,
            var_names,
            attrs=attrs,
            parallel_axis=parallel_axis,
        )
        process_mesh = dist_attr.process_mesh
        processes = process_mesh.processes
        comm_op_cost_list = build_comm_costs_from_descs(
            IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster
        )
        res.append(comm_op_cost_list)

        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        cost_mapping = build_comp_costs_from_descs(
            MatmulGradOpCost, ctx, processes, desc_mapping, cluster
        )
        res.append(cost_mapping)

        # need gradient allreduce
        var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("X")[0]
        )
        mesh_shape = process_mesh.topology
        batch_size_axis = var_dim_mapping[0]
        if (
            batch_size_axis > -1
            and mesh_shape[batch_size_axis] > 1
            and is_parameter_related(backward_op.input("Y")[0], main_block)
        ):
            parallel_axis = batch_size_axis
            attrs = {"use_calc_stream": True}
            var_names = [backward_op.output('Y@GRAD')[0]]
            build_dp_costs(
                res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
            )
        return res

    def calc_fwd_cost(self, dist_op, ctx, cluster):
        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        processes = dist_op.dist_attr.process_mesh.processes
        cost_mapping = build_comp_costs_from_descs(
            MatmulOpCost, ctx, processes, desc_mapping, cluster
        )

        # calc comm op cost
        serial_op = dist_op.serial_op
        parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
            serial_op.input("Y")[0]
        )[-2]
        attrs = {"use_calc_stream": True, "use_model_parallel": True}

        var_names = serial_op.output("Out")
        c_allreduce_sum_desc_mapping = build_comm_desc_from_dist_op(
            "c_allreduce_sum",
            dist_op,
            ctx,
            var_names,
            attrs=attrs,
            parallel_axis=parallel_axis,
        )

        comm_op_cost_list = build_comm_costs_from_descs(
            AllreduceSumOpCost,
            ctx,
            processes,
            c_allreduce_sum_desc_mapping,
            cluster,
        )

        res_cost = [cost_mapping, comm_op_cost_list]

        return res_cost

    def is_input_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        x_name = op_desc.input('X')[0]
        y_name = op_desc.input('Y')[0]
        x_dims_mapping = copy.deepcopy(
            op_dist_attr.get_input_dims_mapping(x_name)
        )
        y_dims_mapping = copy.deepcopy(
            op_dist_attr.get_input_dims_mapping(y_name)
        )
        trans_x = op_desc.attr('transpose_X')
        trans_y = op_desc.attr('transpose_Y')
        trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping)
        if is_dim_replicate(x_dims_mapping[-1]):
            return False
        if is_dim_replicate(y_dims_mapping[-2]) or is_dim_shard(
            y_dims_mapping[-1]
        ):
            return False
        # Other dimensions must be replicated except the batch dimension
        for mapping in x_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_output_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
        if is_dim_shard(out_dims_mapping[-1]):
            return False
        # Other dimensions must be replicated except the batch dimension
        for mapping in out_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_auto_compatible(self, dist_op):
        if (not self.is_input_compatible(dist_op)) or (
            not self.is_output_compatible(dist_op)
        ):
            return False
        if not _is_auto_compatible_for_matmul(dist_op):
            return False
        return True

    def update_dims_mapping(self, dist_op):
        changed = False
        dim_changed = _update_dims_mapping_for_matmul(dist_op)
        if dim_changed:
            changed = True
        return changed

    @staticmethod
    def forward(ctx, *args, **kwargs):
        """
        kwargs: inputname_mapping & outputname_mapping
        """

        dist_op_context = ctx.dist_op_context
        main_block = dist_op_context.work_block
        startup_block = dist_op_context.startup_block
        src_op = dist_op_context.cur_src_op
        rank_id = dist_op_context.rank_id
        op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
        assert (
            op_dist_attr is not None
        ), "forward op [{}] doesn't have dist attribute!".format(str(src_op))

        # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
        if rank_id not in op_dist_attr.process_mesh.processes:
            rank_id = _get_corresponding_rank(
                ctx, op_dist_attr.process_mesh, rank_id
            )

        # check validation of inputs / outputs
        for input_name in src_op.desc.input_names():
            assert input_name in kwargs, "input [{}] is not given".format(
                input_name
            )
            assert len(kwargs[input_name]) == len(
                src_op.desc.input(input_name)
            ), "number of tensors for input [{}] does not match".format(
                input_name
            )
        for output_name in src_op.desc.output_names():
            assert output_name in kwargs, "output [{}] is not given".format(
                output_name
            )
            assert len(kwargs[output_name]) == len(
                src_op.desc.output(output_name)
            ), "number of tensors for output [{}] does not match".format(
                output_name
            )

        X_var = main_block._var_recursive(kwargs['X'][0])
        Weight_var = main_block._var_recursive(kwargs['Y'][0])
        Out_var = main_block._var_recursive(kwargs['Out'][0])
        trans_x = src_op.attr('transpose_X')
        trans_y = src_op.attr('transpose_Y')

        # TODO infer logic comm presentation
        matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
            Weight_var.name
        )[-2]
        if trans_y:
            matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
                Weight_var.name
            )[-1]
        assert (
            matmul_row_dim_mapping >= 0
        ), "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
            matmul_row_dim_mapping
        )
        process_mesh_shape = op_dist_attr.process_mesh.topology
        process_mesh_group = op_dist_attr.process_mesh.processes

        parallel_axis = matmul_row_dim_mapping
        group_ranks = _get_comm_group(
            process_mesh_group, process_mesh_shape, parallel_axis, rank_id
        )
        group = new_process_group(group_ranks)

        check_variable_and_dtype(
            X_var, 'x', ['float16', 'float32', 'float64'], 'linear'
        )
        check_dtype(
            X_var.dtype, 'dtype', ['float16', 'float32', 'float64'], 'linear'
        )
        attrs = {
            'transpose_X': trans_x,
            'transpose_Y': trans_y,
            'alpha': 1,
            OP_ROLE_KEY: src_op.attr('op_role'),
        }
        inputs = {'X': X_var, 'Y': Weight_var}

        # infer out var shape with op dist attr
        out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var)
        assert out_tensor_dist_attr is not None
        out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert out_var_dist_attr is not None
        ref_shape = infer_shape(
            main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr
        )

        intermediate_var_0 = main_block.create_var(
            name=unique_name.generate_with_ignorable_key(
                ".".join(["c_allreduce_sum", 'tmp'])
            ),
            shape=Out_var.shape,
            dtype=Out_var.dtype,
            type=Out_var.type,
            lod_level=Out_var.lod_level,
            persistable=False,
            is_data=False,
            need_check_feed=Out_var.desc.need_check_feed(),
        )
        # set intermediate_var_0's dist_attr with Out_var's dist_attr
        ctx.set_tensor_dist_attr_for_program(
            intermediate_var_0, out_var_dist_attr
        )

        matmul_op = main_block.append_op(
            type='matmul',
            inputs=inputs,
            outputs={'Out': intermediate_var_0},
            attrs=attrs,
        )
        if intermediate_var_0.shape != ref_shape:
            intermediate_var_0.desc.set_shape(ref_shape)

        c_allreduce_sum_op = main_block.append_op(
            type='c_allreduce_sum',
            inputs={'X': intermediate_var_0},
            outputs={'Out': Out_var},
            attrs={
                'ring_id': group.id,
                'use_calc_stream': True,
                'use_model_parallel': True,
                OP_ROLE_KEY: src_op.attr('op_role'),
            },
        )
        if Out_var.shape != ref_shape:
            Out_var.desc.set_shape(ref_shape)

        # set dist op's dist_attr with serial op's dist_attr
        # matmul
        matmul_op_dist_attr = OperatorDistributedAttribute()
        matmul_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        matmul_op_dist_attr.impl_type = op_dist_attr.impl_type
        matmul_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        for input_varname in matmul_op.desc.input_arg_names():
            input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
            assert input_dist_attr is not None, "dist_attr is {}".format(
                op_dist_attr
            )
            matmul_op_dist_attr.set_input_dist_attr(
                input_varname, input_dist_attr
            )
        output_varname = matmul_op.desc.output_arg_names()[0]
        output_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert output_dist_attr is not None, "dist_attr is {}".format(
            op_dist_attr
        )
        matmul_op_dist_attr.set_output_dist_attr(
            output_varname, output_dist_attr
        )
        ctx.set_op_dist_attr_for_program(matmul_op, matmul_op_dist_attr)

        # allreduce
        allreduce_op_dist_attr = OperatorDistributedAttribute()
        allreduce_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        allreduce_op_dist_attr.impl_type = op_dist_attr.impl_type
        allreduce_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        for input_varname in c_allreduce_sum_op.desc.input_arg_names():
            input_var = main_block._var_recursive(input_varname)
            tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(input_var)
            assert tensor_dist_attr is not None
            allreduce_op_dist_attr.set_input_dist_attr(
                input_varname, tensor_dist_attr
            )
        for output_varname in c_allreduce_sum_op.desc.output_arg_names():
            output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
            assert output_dist_attr is not None, "dist_attr is {}".format(
                op_dist_attr
            )
            allreduce_op_dist_attr.set_output_dist_attr(
                output_varname, output_dist_attr
            )
        ctx.set_op_dist_attr_for_program(
            c_allreduce_sum_op, allreduce_op_dist_attr
        )

        # init param sync
        if Weight_var.is_parameter and not op_dist_attr.is_recompute:
            _init_param_sync(
                Weight_var, dist_op_context, startup_block, ctx, rank_id
            )

    @staticmethod
    def backward(ctx, *args, **kwargs):
        _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)


# ReplicateParallel
class DistributedMatmulImpl2(DistributedOperatorImpl):
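    """Replicate-parallel matmul: neither operand is sharded on the
    contraction dimensions, so the forward computation is delegated to
    DistributedDefaultImpl0 and this class mainly provides the compatibility
    checks and the cost model.
    """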
    def __init__(self, name):
        super().__init__(name)

    def calc_cost(self, op_role, dist_op, ctx, cluster):
        cost = None
        if int(op_role) == int(OpRole.Forward):
            cost = self.calc_fwd_cost(dist_op, ctx, cluster)
        elif int(op_role) == int(OpRole.Backward):
            cost = self.calc_bwd_cost(dist_op, ctx, cluster)
        assert cost is not None
        return cost

    def calc_bwd_cost(self, dist_op, ctx, cluster):
        res = []
        backward_op = dist_op.serial_op
        dist_attr = dist_op.dist_attr
        main_block = backward_op.block

        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        process_mesh = dist_attr.process_mesh
        processes = process_mesh.processes
        cost_mapping = build_comp_costs_from_descs(
            MatmulGradOpCost, ctx, processes, desc_mapping, cluster
        )
        res.append(cost_mapping)

        # need gradient allreduce
        var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("X")[0]
        )
        mesh_shape = process_mesh.topology
        batch_size_axis = var_dim_mapping[0]
        if (
            batch_size_axis > -1
            and mesh_shape[batch_size_axis] > 1
            and is_parameter_related(backward_op.input("Y")[0], main_block)
        ):
            parallel_axis = batch_size_axis
            attrs = {"use_calc_stream": True}
            var_names = [backward_op.output('Y@GRAD')[0]]
            build_dp_costs(
                res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
            )

        return res

    def calc_fwd_cost(self, dist_op, ctx, cluster):
        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        processes = dist_op.dist_attr.process_mesh.processes
        cost_mapping = build_comp_costs_from_descs(
            MatmulOpCost, ctx, processes, desc_mapping, cluster
        )

        res_cost = [cost_mapping]
        return res_cost

    def is_input_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        x_name = op_desc.input('X')[0]
        y_name = op_desc.input('Y')[0]
        x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
        y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)

        if is_dim_shard(x_dims_mapping[-1]):
            return False
        if is_valid_list_index(x_dims_mapping, -2) and is_dim_shard(
            x_dims_mapping[-2]
        ):
            return False

        if is_dim_shard(y_dims_mapping[-1]):
            return False
        if is_valid_list_index(y_dims_mapping, -2) and is_dim_shard(
            y_dims_mapping[-2]
        ):
            return False

        return True

    def is_output_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)

        if is_dim_shard(out_dims_mapping[-1]):
            return False
        if is_valid_list_index(out_dims_mapping, -2) and is_dim_shard(
            out_dims_mapping[-2]
        ):
            return False

        return True

    def is_auto_compatible(self, dist_op):
        if (not self.is_input_compatible(dist_op)) or (
            not self.is_output_compatible(dist_op)
        ):
            return False

        if not _is_auto_compatible_for_matmul(dist_op):
            return False

        return True

    def update_dims_mapping(self, dist_op):
        changed = False
        dim_changed = _update_dims_mapping_for_matmul(dist_op)
        if dim_changed:
            changed = True
        return changed

    @staticmethod
    def forward(ctx, *args, **kwargs):
        DistributedDefaultImpl0.forward(ctx, *args, **kwargs)

    @staticmethod
    def backward(ctx, *args, **kwargs):
        _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)


register_distributed_operator_impl(
    "matmul", DistributedMatmulImpl0("column_parallel")
)
register_distributed_operator_impl(
    "matmul", DistributedMatmulImpl1("row_parallel")
)
register_distributed_operator_impl(
    "matmul", DistributedMatmulImpl2("replicate_parallel")
)


class DistributedMatmulV2(DistributedOperatorImplContainer):
    def __init__(self, op_type):
        super().__init__(op_type)


register_distributed_operator_impl_container(DistributedMatmulV2("matmul_v2"))


# ColumnParallel
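# Column parallel: X is replicated inside the model-parallel group and the
# weight Y is sharded along its last (column) dim, e.g. X [B, M, K] with
# Y [K, N/p] gives Out [B, M, N/p]. Forward prepends a c_identity on X so that
# backward can allreduce X@GRAD; Out needs no forward communication because it
# stays column-sharded.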
class DistributedMatmulV2Impl0(DistributedOperatorImpl):
    def __init__(self, name):
        super().__init__(name)
        self._forward_implemented = True
        self._backward_implemented = True

    def calc_cost(self, op_role, dist_op, ctx, cluster):
        cost = None
        if int(op_role) == int(OpRole.Forward):
            cost = self.calc_fwd_cost(dist_op, ctx, cluster)
        elif int(op_role) == int(OpRole.Backward):
            cost = self.calc_bwd_cost(dist_op, ctx, cluster)
        assert cost is not None
        return cost

    def calc_bwd_cost(self, dist_op, ctx, cluster):
        # for now, the backward function only inserts the gradient allreduce for the dist op itself
        res = []
        backward_op = dist_op.serial_op
        dist_attr = dist_op.dist_attr
        main_block = backward_op.block
        Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("Y")[0]
        )
        process_mesh = dist_attr.process_mesh
        processes = process_mesh.processes
        # col parallel: matmul + allreduce
        if backward_op.attr("trans_y"):
            Y_var_dim_mapping.reverse()
        assert Y_var_dim_mapping[0] < 0
        parallel_axis = Y_var_dim_mapping[1]

        has_x_grad = len(backward_op.output("X@GRAD")) > 0
        if has_x_grad:
            assert len(backward_op.output("X@GRAD")) == 1

        # calc comp op cost
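        # (column-parallel backward = a local matmul_v2_grad, an allreduce of
        # X@GRAD over the model-parallel axis when X@GRAD exists, plus the
        # data-parallel allreduce of Y@GRAD handled by build_dp_costs below)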
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )

        cost_mapping = build_comp_costs_from_descs(
            MatmulV2GradOpCost, ctx, processes, desc_mapping, cluster
        )
        res.append(cost_mapping)

        # calc comm op cost
        if has_x_grad:
            attrs = {"use_calc_stream": True, "use_model_parallel": True}
            var_names = backward_op.output("X@GRAD")
            c_allreduce_sum_desc_mapping = build_comm_desc_from_dist_op(
                "c_allreduce_sum",
                dist_op,
                ctx,
                var_names,
                attrs=attrs,
                parallel_axis=parallel_axis,
            )
            comm_op_cost_list = build_comm_costs_from_descs(
                AllreduceSumOpCost,
                ctx,
                processes,
                c_allreduce_sum_desc_mapping,
                cluster,
            )
            res.append(comm_op_cost_list)

        # need gradient allreduce
        process_mesh = dist_attr.process_mesh
        var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("X")[0]
        )
        mesh_shape = process_mesh.topology
        batch_size_axis = var_dim_mapping[0]
        if (
            batch_size_axis > -1
            and mesh_shape[batch_size_axis] > 1
            and is_parameter_related(backward_op.input("Y")[0], main_block)
        ):
            parallel_axis = batch_size_axis
            attrs = {"use_calc_stream": True}
            var_names = [backward_op.output('Y@GRAD')[0]]
            build_dp_costs(
                res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
            )
        return res

    def calc_fwd_cost(self, dist_op, ctx, cluster):
        # calc comp op cost
        # TODO: trans shape if trans_x or trans_y is True
        comp_desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        processes = dist_op.dist_attr.process_mesh.processes
        comp_cost_mapping = build_comp_costs_from_descs(
            MatmulV2OpCost, ctx, processes, comp_desc_mapping, cluster
        )

        # calc comm op cost
        serial_op = dist_op.serial_op
        parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
            serial_op.input("Y")[0]
        )[-1]
        attrs = {"use_calc_stream": True, "use_model_parallel": True}

        var_names = serial_op.input("X")
        c_identity_desc_mapping = build_comm_desc_from_dist_op(
            "c_identity",
            dist_op,
            ctx,
            var_names,
            attrs=attrs,
            parallel_axis=parallel_axis,
        )
        comm_op_cost_list = build_comm_costs_from_descs(
            IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster
        )

        res_cost = [comm_op_cost_list, comp_cost_mapping]
        return res_cost

    def is_input_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        x_name = op_desc.input('X')[0]
        y_name = op_desc.input('Y')[0]
        x_dims_mapping = copy.deepcopy(
            op_dist_attr.get_input_dims_mapping(x_name)
        )
        y_dims_mapping = copy.deepcopy(
            op_dist_attr.get_input_dims_mapping(y_name)
        )
        trans_x = op_desc.attr('trans_x')
        trans_y = op_desc.attr('trans_y')
        trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping)
        if is_dim_shard(x_dims_mapping[-1]):
            return False
        if is_dim_shard(y_dims_mapping[-2]) or is_dim_replicate(
            y_dims_mapping[-1]
        ):
            return False
        for mapping in x_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_output_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
        if is_dim_replicate(out_dims_mapping[-1]):
            return False
        for mapping in out_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_auto_compatible(self, dist_op):
        if (not self.is_input_compatible(dist_op)) or (
            not self.is_output_compatible(dist_op)
        ):
            return False
        if not _is_auto_compatible_for_matmul(dist_op):
            return False
        return True

    def update_dims_mapping(self, dist_op):
        changed = False
        dim_changed = _update_dims_mapping_for_matmul(dist_op)
        if dim_changed:
            changed = True
        return changed

    @staticmethod
    def forward(ctx, *args, **kwargs):
        """
        kwargs: inputname_mapping & outputname_mapping
        """

        dist_op_context = ctx.dist_op_context
        main_block = dist_op_context.work_block
        startup_block = dist_op_context.startup_block
        src_op = dist_op_context.cur_src_op
        rank_id = dist_op_context.rank_id
        op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
        assert (
            op_dist_attr is not None
        ), "op [{}] doesn't have dist attribute!".format(str(src_op))

        # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
        if rank_id not in op_dist_attr.process_mesh.processes:
            rank_id = _get_corresponding_rank(
                ctx, op_dist_attr.process_mesh, rank_id
            )

        # check validation of inputs / outputs
        for input_name in src_op.desc.input_names():
            assert input_name in kwargs, "input [{}] is not given".format(
                input_name
            )
            assert len(kwargs[input_name]) == len(
                src_op.desc.input(input_name)
            ), "number of tensors for input [{}] does not match".format(
                input_name
            )
        for output_name in src_op.desc.output_names():
            assert output_name in kwargs, "output [{}] is not given".format(
                output_name
            )
            assert len(kwargs[output_name]) == len(
                src_op.desc.output(output_name)
            ), "number of tensors for output [{}] does not match".format(
                output_name
            )

        X_var = main_block._var_recursive(kwargs['X'][0])
        Weight_var = main_block._var_recursive(kwargs['Y'][0])
        Out_var = main_block._var_recursive(kwargs['Out'][0])
        trans_x = src_op.attr('trans_x')
        trans_y = src_op.attr('trans_y')

        # TODO infer logic comm presentation
        matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping(
            Weight_var.name
        )[-1]
        if trans_y:
            matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping(
                Weight_var.name
            )[-2]
        assert (
            matmul_col_dim_mapping >= 0
        ), "col_parallel_matmul's col should be divided by a specific mesh axis, but got [{}]".format(
            matmul_col_dim_mapping
        )
        process_mesh_shape = op_dist_attr.process_mesh.topology
        process_mesh_group = op_dist_attr.process_mesh.processes

        parallel_axis = matmul_col_dim_mapping
        group_ranks = _get_comm_group(
            process_mesh_group, process_mesh_shape, parallel_axis, rank_id
        )
        group = new_process_group(group_ranks)
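        # column-parallel transformation: copy X through c_identity inside the
        # model-parallel group (its backward becomes an allreduce of X@GRAD),
        # then run the local matmul_v2 against the column-sharded weight; the
        # output keeps its column-sharded shape, so no forward collective on Out.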

        # infer new var shape with op dist attr
        x_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(X_var)
        assert x_tensor_dist_attr is not None
        identity_var_dist_attr = op_dist_attr.get_input_dist_attr(X_var.name)
        assert identity_var_dist_attr is not None
        ref_shape_x = infer_shape(
            main_block, X_var, x_tensor_dist_attr, identity_var_dist_attr
        )
        # infer out var shape with op dist attr
        out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var)
        assert out_tensor_dist_attr is not None
        out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert out_var_dist_attr is not None
        ref_shape_out = infer_shape(
            main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr
        )

        intermediate_var_0 = main_block.create_var(
            name=unique_name.generate_with_ignorable_key(
                ".".join(["c_identity", 'tmp'])
            ),
            dtype=X_var.dtype,
            shape=X_var.shape,
            type=core.VarDesc.VarType.LOD_TENSOR,
            persistable=False,
            stop_gradient=X_var.stop_gradient,
        )
        # set intermediate_var_0's dist_attr with X_var's dist_attr
        ctx.set_tensor_dist_attr_for_program(
            intermediate_var_0, identity_var_dist_attr
        )

        check_variable_and_dtype(
            X_var,
            'tensor',
            ['float16', 'float32', 'float64', 'int32', 'int64'],
            '_c_identity',
        )
        c_identity_op = main_block.append_op(
            type='c_identity',
            inputs={'X': [X_var]},
            outputs={'Out': intermediate_var_0},
            attrs={
                'ring_id': group.id,
                'use_calc_stream': True,
                'use_model_parallel': True,
                OP_ROLE_KEY: src_op.attr('op_role'),
            },
        )
        if intermediate_var_0.shape != ref_shape_x:
            intermediate_var_0.desc.set_shape(ref_shape_x)

        check_variable_and_dtype(
            intermediate_var_0, 'x', ['float16', 'float32', 'float64'], 'linear'
        )
        check_dtype(
            intermediate_var_0.dtype,
            'dtype',
            ['float16', 'float32', 'float64'],
            'linear',
        )
        attrs = {
            'trans_x': trans_x,
            'trans_y': trans_y,
            OP_ROLE_KEY: src_op.attr('op_role'),
        }
        inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]}
        matmul_v2_op = main_block.append_op(
            type='matmul_v2',
            inputs=inputs,
            outputs={'Out': Out_var},
            attrs=attrs,
        )
        if Out_var.shape != ref_shape_out:
            Out_var.desc.set_shape(ref_shape_out)

        # set dist op's dist_attr with serial op's dist_attr
        # c_identity
        identity_op_dist_attr = OperatorDistributedAttribute()
        identity_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        identity_op_dist_attr.impl_type = op_dist_attr.impl_type
        identity_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        # input
        input_varname = c_identity_op.desc.input_arg_names()[0]
        input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
        assert input_dist_attr is not None, "dist_attr is {}".format(
            op_dist_attr
        )
        identity_op_dist_attr.set_input_dist_attr(
            input_varname, input_dist_attr
        )
        # output
        output_varname = c_identity_op.desc.output_arg_names()[0]
        identity_op_dist_attr.set_output_dist_attr(
            output_varname, input_dist_attr
        )
        ctx.set_op_dist_attr_for_program(c_identity_op, identity_op_dist_attr)

        # matmulv2
        matmulv2_op_dist_attr = OperatorDistributedAttribute()
        matmulv2_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        matmulv2_op_dist_attr.impl_type = op_dist_attr.impl_type
        matmulv2_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        for input_varname in matmul_v2_op.desc.input_arg_names():
            if input_varname in src_op.desc.input_arg_names():
                input_dist_attr = op_dist_attr.get_input_dist_attr(
                    input_varname
                )
                assert input_dist_attr is not None, "dist_attr is {}".format(
                    op_dist_attr
                )
                matmulv2_op_dist_attr.set_input_dist_attr(
                    input_varname, input_dist_attr
                )
            else:
                input_var = main_block._var_recursive(input_varname)
                tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(
                    input_var
                )
                matmulv2_op_dist_attr.set_input_dist_attr(
                    input_varname, tensor_dist_attr
                )
        for output_varname in matmul_v2_op.desc.output_arg_names():
            output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
            assert output_dist_attr is not None, "dist_attr is {}".format(
                op_dist_attr
            )
            matmulv2_op_dist_attr.set_output_dist_attr(
                output_varname, output_dist_attr
            )
        ctx.set_op_dist_attr_for_program(matmul_v2_op, matmulv2_op_dist_attr)

        # init param sync
        if Weight_var.is_parameter and not op_dist_attr.is_recompute:
            _init_param_sync(
                Weight_var, dist_op_context, startup_block, ctx, rank_id
            )

    @staticmethod
    def backward(ctx, *args, **kwargs):
        _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)


# RowParallel
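# Row parallel: X is sharded along its last dim and the weight Y along its row
# dim, e.g. X [B, M, K/p] with Y [K/p, N]. Each rank computes a partial Out, so
# forward appends a c_allreduce_sum over the model-parallel group to recover
# the full result.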
class DistributedMatmulV2Impl1(DistributedOperatorImpl):
    def __init__(self, name):
        super().__init__(name)
        self._forward_implemented = True
        self._backward_implemented = True

    def calc_cost(self, op_role, dist_op, ctx, cluster):
        cost = None
        if int(op_role) == int(OpRole.Forward):
            cost = self.calc_fwd_cost(dist_op, ctx, cluster)
        elif int(op_role) == int(OpRole.Backward):
            cost = self.calc_bwd_cost(dist_op, ctx, cluster)
        assert cost is not None
        return cost

    def calc_bwd_cost(self, dist_op, ctx, cluster):
        # for now, the backward function only inserts the gradient allreduce for the dist op itself
        res = []
        backward_op = dist_op.serial_op
        dist_attr = dist_op.dist_attr
        main_block = backward_op.block

        Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("Y")[0]
        )
        assert Y_var_dim_mapping[1] < 0
        parallel_axis = Y_var_dim_mapping[0]

        process_mesh = dist_attr.process_mesh
        processes = process_mesh.processes
        # calc comm op cost
        var_names = [backward_op.input("Out@GRAD")[0]]
        attrs = {"use_calc_stream": True, "use_model_parallel": True}
        c_identity_desc_mapping = build_comm_desc_from_dist_op(
            "c_identity",
            dist_op,
            ctx,
            var_names,
            attrs=attrs,
            parallel_axis=parallel_axis,
        )
        comm_op_cost_list = build_comm_costs_from_descs(
            IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster
        )
        res.append(comm_op_cost_list)

        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        cost_mapping = build_comp_costs_from_descs(
            MatmulV2GradOpCost, ctx, processes, desc_mapping, cluster
        )
        res.append(cost_mapping)

        # need gradient allreduce
        process_mesh = dist_attr.process_mesh
        var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("X")[0]
        )
        mesh_shape = process_mesh.topology
        batch_size_axis = var_dim_mapping[0]
        if (
            batch_size_axis > -1
            and mesh_shape[batch_size_axis] > 1
            and is_parameter_related(backward_op.input("Y")[0], main_block)
        ):
            parallel_axis = batch_size_axis
            attrs = {"use_calc_stream": True}
            var_names = [backward_op.output('Y@GRAD')[0]]
            build_dp_costs(
                res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
            )
        return res

    def calc_fwd_cost(self, dist_op, ctx, cluster):
        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        processes = dist_op.dist_attr.process_mesh.processes
        cost_mapping = build_comp_costs_from_descs(
            MatmulV2OpCost, ctx, processes, desc_mapping, cluster
        )

        # calc comm op cost
        serial_op = dist_op.serial_op
        parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
            serial_op.input("Y")[0]
        )[-2]
        attrs = {"use_calc_stream": True, "use_model_parallel": True}

        var_names = serial_op.output("Out")
        c_allreduce_sum_desc_mapping = build_comm_desc_from_dist_op(
            "c_allreduce_sum",
            dist_op,
            ctx,
            var_names,
            attrs=attrs,
            parallel_axis=parallel_axis,
        )

        comm_op_cost_list = build_comm_costs_from_descs(
            AllreduceSumOpCost,
            ctx,
            processes,
            c_allreduce_sum_desc_mapping,
            cluster,
        )
        res_cost = [cost_mapping, comm_op_cost_list]

        return res_cost

    def is_input_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        x_name = op_desc.input('X')[0]
        y_name = op_desc.input('Y')[0]
        x_dims_mapping = copy.deepcopy(
            op_dist_attr.get_input_dims_mapping(x_name)
        )
        y_dims_mapping = copy.deepcopy(
            op_dist_attr.get_input_dims_mapping(y_name)
        )
        trans_x = op_desc.attr('trans_x')
        trans_y = op_desc.attr('trans_y')
        trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping)
        if is_dim_replicate(x_dims_mapping[-1]):
            return False
        if is_dim_replicate(y_dims_mapping[-2]) or is_dim_shard(
            y_dims_mapping[-1]
        ):
            return False
        # Other dimensions must be replicate except the batch dimension
        for mapping in x_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_output_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
        if is_dim_shard(out_dims_mapping[-1]):
            return False
        # Other dimensions must be replicate except the batch dimension
        for mapping in out_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_auto_compatible(self, dist_op):
        if (not self.is_input_compatible(dist_op)) or (
            not self.is_output_compatible(dist_op)
        ):
            return False
        if not _is_auto_compatible_for_matmul(dist_op):
            return False
        return True

    def update_dims_mapping(self, dist_op):
        changed = False
        dim_changed = _update_dims_mapping_for_matmul(dist_op)
        if dim_changed:
            changed = True
        return changed

    @staticmethod
    def forward(ctx, *args, **kwargs):
        """
        kwargs: inputname_mapping & outputname_mapping
        """

        dist_op_context = ctx.dist_op_context
        main_block = dist_op_context.work_block
        startup_block = dist_op_context.startup_block
        src_op = dist_op_context.cur_src_op
        rank_id = dist_op_context.rank_id
        op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
        assert (
            op_dist_attr is not None
        ), "op [{}] doesn't have dist attribute!".format(str(src_op))

        # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
        if rank_id not in op_dist_attr.process_mesh.processes:
            rank_id = _get_corresponding_rank(
                ctx, op_dist_attr.process_mesh, rank_id
            )

        # check validation of inputs / outputs
        for input_name in src_op.desc.input_names():
            assert input_name in kwargs, "input [{}] is not given".format(
                input_name
            )
            assert len(kwargs[input_name]) == len(
                src_op.desc.input(input_name)
            ), "number of tensors for input [{}] does not match".format(
                input_name
            )
        for output_name in src_op.desc.output_names():
            assert output_name in kwargs, "output [{}] is not given".format(
                output_name
            )
            assert len(kwargs[output_name]) == len(
                src_op.desc.output(output_name)
            ), "number of tensors for output [{}] does not match".format(
                output_name
            )

        X_var = main_block._var_recursive(kwargs['X'][0])
        Weight_var = main_block._var_recursive(kwargs['Y'][0])
        Out_var = main_block._var_recursive(kwargs['Out'][0])
        trans_x = src_op.attr('trans_x')
        trans_y = src_op.attr('trans_y')

        # TODO infer logic comm presentation
        matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
            Weight_var.name
        )[-2]
        if trans_y:
            matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
                Weight_var.name
            )[-1]
        assert (
            matmul_row_dim_mapping >= 0
        ), "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
            matmul_row_dim_mapping
        )
        process_mesh_shape = op_dist_attr.process_mesh.topology
        process_mesh_group = op_dist_attr.process_mesh.processes

        parallel_axis = matmul_row_dim_mapping
        group_ranks = _get_comm_group(
            process_mesh_group, process_mesh_shape, parallel_axis, rank_id
        )
        group = new_process_group(group_ranks)
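        # row-parallel transformation: run the local matmul_v2 into a temporary
        # partial-sum variable, then c_allreduce_sum it over the model-parallel
        # group into Out.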

        check_variable_and_dtype(
            X_var, 'x', ['float16', 'float32', 'float64'], 'linear'
        )
        check_dtype(
            X_var.dtype, 'dtype', ['float16', 'float32', 'float64'], 'linear'
        )
        attrs = {
            'trans_x': trans_x,
            'trans_y': trans_y,
            OP_ROLE_KEY: src_op.attr('op_role'),
        }
        inputs = {'X': X_var, 'Y': Weight_var}

        # infer out var shape with op dist attr
        out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var)
        assert out_tensor_dist_attr is not None
        out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert out_var_dist_attr is not None
        ref_shape = infer_shape(
            main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr
        )

        intermediate_var_0 = main_block.create_var(
            name=unique_name.generate_with_ignorable_key(
                ".".join(["c_allreduce_sum", 'tmp'])
            ),
            shape=Out_var.shape,
            dtype=Out_var.dtype,
            type=Out_var.type,
            lod_level=Out_var.lod_level,
            persistable=False,
            is_data=False,
            need_check_feed=Out_var.desc.need_check_feed(),
        )
        # set intermediate_var_0's dist_attr with Out_var's dist_attr
        ctx.set_tensor_dist_attr_for_program(
            intermediate_var_0, out_var_dist_attr
        )

        matmul_v2_op = main_block.append_op(
            type='matmul_v2',
            inputs=inputs,
            outputs={'Out': intermediate_var_0},
            attrs=attrs,
        )
        if intermediate_var_0.shape != ref_shape:
            intermediate_var_0.desc.set_shape(ref_shape)

        c_allreduce_sum_op = main_block.append_op(
            type='c_allreduce_sum',
            inputs={'X': intermediate_var_0},
            outputs={'Out': Out_var},
            attrs={
                'ring_id': group.id,
                'use_calc_stream': True,
                'use_model_parallel': True,
                OP_ROLE_KEY: src_op.attr('op_role'),
            },
        )
        if Out_var.shape != ref_shape:
            Out_var.desc.set_shape(ref_shape)

        # set dist op's dist_attr with serial op's dist_attr
        # matmulv2
        matmulv2_op_dist_attr = OperatorDistributedAttribute()
        matmulv2_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        matmulv2_op_dist_attr.impl_type = op_dist_attr.impl_type
        matmulv2_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        for input_varname in matmul_v2_op.desc.input_arg_names():
            input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
            assert input_dist_attr is not None, "dist_attr is {}".format(
                op_dist_attr
            )
            matmulv2_op_dist_attr.set_input_dist_attr(
                input_varname, input_dist_attr
            )
        output_varname = matmul_v2_op.desc.output_arg_names()[0]
        output_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert output_dist_attr is not None, "dist_attr is {}".format(
            op_dist_attr
        )
        matmulv2_op_dist_attr.set_output_dist_attr(
            output_varname, output_dist_attr
        )
        ctx.set_op_dist_attr_for_program(matmul_v2_op, matmulv2_op_dist_attr)

        # allreduce
        allreduce_op_dist_attr = OperatorDistributedAttribute()
        allreduce_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        allreduce_op_dist_attr.impl_type = op_dist_attr.impl_type
        allreduce_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        for input_varname in c_allreduce_sum_op.desc.input_arg_names():
            input_var = main_block._var_recursive(input_varname)
            tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(input_var)
            assert tensor_dist_attr is not None
            allreduce_op_dist_attr.set_input_dist_attr(
                input_varname, tensor_dist_attr
            )
        for output_varname in c_allreduce_sum_op.desc.output_arg_names():
            output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
            assert output_dist_attr is not None, "dist_attr is {}".format(
                op_dist_attr
            )
            allreduce_op_dist_attr.set_output_dist_attr(
                output_varname, output_dist_attr
            )
        ctx.set_op_dist_attr_for_program(
            c_allreduce_sum_op, allreduce_op_dist_attr
        )

        # init param sync
        if Weight_var.is_parameter and not op_dist_attr.is_recompute:
            _init_param_sync(
                Weight_var, dist_op_context, startup_block, ctx, rank_id
            )

    @staticmethod
    def backward(ctx, *args, **kwargs):
        _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)


# ReplicateParallel
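# Replicate parallel: neither operand is sharded on its reduction or output
# dims, so forward simply reuses DistributedDefaultImpl0 and backward only adds
# the data-parallel gradient synchronization.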
class DistributedMatmulV2Impl2(DistributedOperatorImpl):
    def __init__(self, name):
        super().__init__(name)

    def calc_cost(self, op_role, dist_op, ctx, cluster):
        cost = None
        if int(op_role) == int(OpRole.Forward):
            cost = self.calc_fwd_cost(dist_op, ctx, cluster)
        elif int(op_role) == int(OpRole.Backward):
            cost = self.calc_bwd_cost(dist_op, ctx, cluster)
        assert cost is not None
        return cost

    def calc_bwd_cost(self, dist_op, ctx, cluster):
        res = []
        backward_op = dist_op.serial_op
        dist_attr = dist_op.dist_attr
        main_block = backward_op.block
        process_mesh = dist_attr.process_mesh

        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        processes = process_mesh.processes
        cost_mapping = build_comp_costs_from_descs(
            MatmulV2GradOpCost, ctx, processes, desc_mapping, cluster
        )
        res.append(cost_mapping)

        # need gradient allreduce
        var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("X")[0]
        )
        mesh_shape = process_mesh.topology
        batch_size_axis = var_dim_mapping[0]
        if (
            batch_size_axis > -1
            and mesh_shape[batch_size_axis] > 1
            and is_parameter_related(backward_op.input("Y")[0], main_block)
        ):
            parallel_axis = batch_size_axis
            attrs = {"use_calc_stream": True}
            var_names = [backward_op.output('Y@GRAD')[0]]
            build_dp_costs(
                res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
            )

        return res

    def calc_fwd_cost(self, dist_op, ctx, cluster):
        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        processes = dist_op.dist_attr.process_mesh.processes
        cost_mapping = build_comp_costs_from_descs(
            MatmulV2OpCost, ctx, processes, desc_mapping, cluster
        )

        res_cost = [cost_mapping]

        return res_cost

    def is_input_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        x_name = op_desc.input('X')[0]
        y_name = op_desc.input('Y')[0]
        x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
        y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)

        if is_dim_shard(x_dims_mapping[-1]):
            return False
        if is_valid_list_index(x_dims_mapping, -2) and is_dim_shard(
            x_dims_mapping[-2]
        ):
            return False

        if is_dim_shard(y_dims_mapping[-1]):
            return False
        if is_valid_list_index(y_dims_mapping, -2) and is_dim_shard(
            y_dims_mapping[-2]
        ):
            return False
        return True

    def is_output_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)

        if is_dim_shard(out_dims_mapping[-1]):
            return False
        if is_valid_list_index(out_dims_mapping, -2) and is_dim_shard(
            out_dims_mapping[-2]
        ):
            return False

        return True

    def is_auto_compatible(self, dist_op):
        if (not self.is_input_compatible(dist_op)) or (
            not self.is_output_compatible(dist_op)
        ):
            return False

        if not _is_auto_compatible_for_matmul(dist_op):
            return False

        return True

    def update_dims_mapping(self, dist_op):
        changed = False
        dim_changed = _update_dims_mapping_for_matmul(dist_op)
        if dim_changed:
            changed = True
        return changed

    @staticmethod
    def forward(ctx, *args, **kwargs):
        DistributedDefaultImpl0.forward(ctx, *args, **kwargs)

    @staticmethod
    def backward(ctx, *args, **kwargs):
        _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)


register_distributed_operator_impl(
    "matmul_v2", DistributedMatmulV2Impl0("column_parallel")
)
register_distributed_operator_impl(
    "matmul_v2", DistributedMatmulV2Impl1("row_parallel")
)
register_distributed_operator_impl(
    "matmul_v2", DistributedMatmulV2Impl2("replicate_parallel")
)


class DistributedMul(DistributedOperatorImplContainer):
    def __init__(self, op_type):
        super().__init__(op_type)


register_distributed_operator_impl_container(DistributedMul("mul"))


# ColumnParallel
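# Column parallel for the legacy mul op: same scheme as the column-parallel
# matmul_v2 impl above (c_identity on X, weight sharded along its last dim),
# except that x_num_col_dims / y_num_col_dims are forwarded from the serial op
# instead of trans_x / trans_y.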
class DistributedMulImpl0(DistributedOperatorImpl):
    def __init__(self, name):
        super().__init__(name)
        self._forward_implemented = True
        self._backward_implemented = True

    def calc_cost(self, op_role, dist_op, ctx, cluster):
        cost = None
        if int(op_role) == int(OpRole.Forward):
            cost = self.calc_fwd_cost(dist_op, ctx, cluster)
        elif int(op_role) == int(OpRole.Backward):
            cost = self.calc_bwd_cost(dist_op, ctx, cluster)
        assert cost is not None
        return cost

    def calc_bwd_cost(self, dist_op, ctx, cluster):
        # for now, the backward function only inserts the gradient allreduce for the dist op itself
        res = []
        backward_op = dist_op.serial_op
        dist_attr = dist_op.dist_attr
        main_block = backward_op.block
        Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("Y")[0]
        )
        # col parallel: matmul + allreduce
        assert Y_var_dim_mapping[0] < 0
        parallel_axis = Y_var_dim_mapping[1]

        has_x_grad = len(backward_op.output("X@GRAD")) > 0
        if has_x_grad:
            assert len(backward_op.output("X@GRAD")) == 1

        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        process_mesh = dist_attr.process_mesh
        processes = process_mesh.processes
        cost_mapping = build_comp_costs_from_descs(
            MulGradOpCost, ctx, processes, desc_mapping, cluster
        )
        res.append(cost_mapping)

        # calc comm op cost
        if has_x_grad:
            attrs = {"use_calc_stream": True, "use_model_parallel": True}
            var_names = backward_op.output("X@GRAD")
            c_allreduce_sum_desc_mapping = build_comm_desc_from_dist_op(
                "c_allreduce_sum",
                dist_op,
                ctx,
                var_names,
                attrs=attrs,
                parallel_axis=parallel_axis,
            )
            comm_op_cost_list = build_comm_costs_from_descs(
                AllreduceSumOpCost,
                ctx,
                processes,
                c_allreduce_sum_desc_mapping,
                cluster,
            )
            res.append(comm_op_cost_list)

        # need gradient allreduce
        var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("X")[0]
        )
        mesh_shape = process_mesh.topology
        batch_size_axis = var_dim_mapping[0]
        if (
            batch_size_axis > -1
            and mesh_shape[batch_size_axis] > 1
            and is_parameter_related(backward_op.input("Y")[0], main_block)
        ):
            parallel_axis = batch_size_axis
            attrs = {"use_calc_stream": True}
            var_names = [backward_op.output('Y@GRAD')[0]]
            build_dp_costs(
                res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
            )
        return res

    def calc_fwd_cost(self, dist_op, ctx, cluster):
        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        processes = dist_op.dist_attr.process_mesh.processes
        cost_mapping = build_comp_costs_from_descs(
            MulOpCost, ctx, processes, desc_mapping, cluster
        )

        # calc comm op cost
        serial_op = dist_op.serial_op
        parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
            serial_op.input("Y")[0]
        )[-1]
        attrs = {"use_calc_stream": True, "use_model_parallel": True}
        var_names = serial_op.input("X")
        c_identity_desc_mapping = build_comm_desc_from_dist_op(
            "c_identity",
            dist_op,
            ctx,
            var_names,
            attrs=attrs,
            parallel_axis=parallel_axis,
        )

        comm_op_cost_list = build_comm_costs_from_descs(
            IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster
        )
        res_cost = [comm_op_cost_list, cost_mapping]

        return res_cost

    def is_input_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        x_name = op_desc.input('X')[0]
        y_name = op_desc.input('Y')[0]
        x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
        y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
        if is_dim_shard(x_dims_mapping[-1]):
            return False
        if is_dim_shard(y_dims_mapping[-2]) or is_dim_replicate(
            y_dims_mapping[-1]
        ):
            return False
        for mapping in x_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_output_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
        if is_dim_replicate(out_dims_mapping[-1]):
            return False
        for mapping in out_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_auto_compatible(self, dist_op):
        if (not self.is_input_compatible(dist_op)) or (
            not self.is_output_compatible(dist_op)
        ):
            return False

        if not _is_auto_compatible_for_matmul(dist_op):
            return False

        return True

    def update_dims_mapping(self, dist_op):
        changed = False
        dim_changed = _update_dims_mapping_for_matmul(dist_op)
        if dim_changed:
            changed = True
        return changed

    @staticmethod
    def forward(ctx, *args, **kwargs):
        """
        kwargs: inputname_mapping & outputname_mapping
        """

        dist_op_context = ctx.dist_op_context
        main_block = dist_op_context.work_block
        startup_block = dist_op_context.startup_block
        src_op = dist_op_context.cur_src_op
        rank_id = dist_op_context.rank_id
        op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
        assert (
            op_dist_attr is not None
        ), "op [{}] doesn't have dist attribute!".format(str(src_op))

        # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
        if rank_id not in op_dist_attr.process_mesh.processes:
            rank_id = _get_corresponding_rank(
                ctx, op_dist_attr.process_mesh, rank_id
            )

        # check validation of inputs / outputs
        for input_name in src_op.desc.input_names():
            assert input_name in kwargs, "input [{}] is not given".format(
                input_name
            )
            assert len(kwargs[input_name]) == len(
                src_op.desc.input(input_name)
            ), "number of tensors for input [{}] does not match".format(
                input_name
            )
        for output_name in src_op.desc.output_names():
            assert output_name in kwargs, "output [{}] is not given".format(
                output_name
            )
            assert len(kwargs[output_name]) == len(
                src_op.desc.output(output_name)
            ), "number of tensors for output [{}] does not match".format(
                output_name
            )

        X_var = main_block._var_recursive(kwargs['X'][0])
        Weight_var = main_block._var_recursive(kwargs['Y'][0])
        Out_var = main_block._var_recursive(kwargs['Out'][0])

        # TODO infer logic comm presentation
        matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping(
            Weight_var.name
        )[-1]
        assert (
            matmul_col_dim_mapping >= 0
        ), "col_parallel_matmul's col should be divided by a specific mesh axis, but got [{}]".format(
            matmul_col_dim_mapping
        )
        process_mesh_shape = op_dist_attr.process_mesh.topology
        process_mesh_group = op_dist_attr.process_mesh.processes

        parallel_axis = matmul_col_dim_mapping
        group_ranks = _get_comm_group(
            process_mesh_group, process_mesh_shape, parallel_axis, rank_id
        )
        group = new_process_group(group_ranks)

        # infer new var shape with op dist attr
        x_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(X_var)
        assert x_tensor_dist_attr is not None
        identity_var_dist_attr = op_dist_attr.get_input_dist_attr(X_var.name)
        assert identity_var_dist_attr is not None
        ref_shape_x = infer_shape(
            main_block, X_var, x_tensor_dist_attr, identity_var_dist_attr
        )
        # infer out var shape with op dist attr
        out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var)
        assert out_tensor_dist_attr is not None
        out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert out_var_dist_attr is not None
        ref_shape_out = infer_shape(
            main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr
        )

        intermediate_var_0 = main_block.create_var(
            name=unique_name.generate_with_ignorable_key(
                ".".join(["c_identity", 'tmp'])
            ),
            dtype=X_var.dtype,
            shape=X_var.shape,
            type=core.VarDesc.VarType.LOD_TENSOR,
            persistable=False,
            stop_gradient=X_var.stop_gradient,
        )
        # set intermediate_var_0's dist_attr with X_var's dist_attr
        ctx.set_tensor_dist_attr_for_program(
            intermediate_var_0, identity_var_dist_attr
        )

        check_variable_and_dtype(
            X_var,
            'tensor',
            ['float16', 'float32', 'float64', 'int32', 'int64'],
            '_c_identity',
        )
        c_identity_op = main_block.append_op(
            type='c_identity',
            inputs={'X': [X_var]},
            outputs={'Out': intermediate_var_0},
            attrs={
                'ring_id': group.id,
                'use_calc_stream': True,
                'use_model_parallel': True,
                OP_ROLE_KEY: src_op.attr('op_role'),
            },
        )
        if intermediate_var_0.shape != ref_shape_x:
            intermediate_var_0.desc.set_shape(ref_shape_x)

        check_variable_and_dtype(
            intermediate_var_0, 'x', ['float16', 'float32', 'float64'], 'linear'
        )
        check_dtype(
            intermediate_var_0.dtype,
            'dtype',
            ['float16', 'float32', 'float64'],
            'linear',
        )
        # attrs = {'trans_x': False, 'trans_y': False}
        attrs = {
            "x_num_col_dims": src_op.desc.attr("x_num_col_dims"),
            "y_num_col_dims": src_op.desc.attr("y_num_col_dims"),
            OP_ROLE_KEY: src_op.attr('op_role'),
        }
        inputs = {'X': intermediate_var_0, 'Y': Weight_var}

        inputs_ref_shape = {}
        inputs_original_shape = {}
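        # mul infers Out's shape from its input shapes, so temporarily set the
        # inputs to their sharded (inferred) shapes before append_op and restore
        # the original serial shapes right afterwards.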
        for var_name in inputs:
            if var_name == "X":
                var = X_var
            else:
                var = inputs[var_name]
            inputs_original_shape[var_name] = var.shape
            input_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(var)
            input_var_dist_attr = op_dist_attr.get_input_dist_attr(var.name)
            input_ref_shape = infer_shape(
                main_block, var, input_tensor_dist_attr, input_var_dist_attr
            )
            inputs_ref_shape[var_name] = input_ref_shape
            var.desc.set_shape(input_ref_shape)

        mul_op = main_block.append_op(
            type='mul', inputs=inputs, outputs={'Out': Out_var}, attrs=attrs
        )
        if Out_var.shape != ref_shape_out:
            Out_var.desc.set_shape(ref_shape_out)

        for var_name in inputs:
            var = inputs[var_name]
            original_shape = inputs_original_shape[var_name]
            var.desc.set_shape(original_shape)

        # set dist op's dist_attr with serial op's dist_attr
        # c_identity
        identity_op_dist_attr = OperatorDistributedAttribute()
        identity_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        identity_op_dist_attr.impl_type = op_dist_attr.impl_type
        identity_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        # input
        input_varname = c_identity_op.desc.input_arg_names()[0]
        input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
        assert input_dist_attr is not None, "dist_attr is {}".format(
            op_dist_attr
        )
        identity_op_dist_attr.set_input_dist_attr(
            input_varname, input_dist_attr
        )
        # output
        output_varname = c_identity_op.desc.output_arg_names()[0]
        identity_op_dist_attr.set_output_dist_attr(
            output_varname, input_dist_attr
        )
        ctx.set_op_dist_attr_for_program(c_identity_op, identity_op_dist_attr)

        # matmulv2
        matmulv2_op_dist_attr = OperatorDistributedAttribute()
        matmulv2_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        matmulv2_op_dist_attr.impl_type = op_dist_attr.impl_type
        matmulv2_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        for input_varname in mul_op.desc.input_arg_names():
            if input_varname in src_op.desc.input_arg_names():
                input_dist_attr = op_dist_attr.get_input_dist_attr(
                    input_varname
                )
                assert input_dist_attr is not None, "dist_attr is {}".format(
                    op_dist_attr
                )
                matmulv2_op_dist_attr.set_input_dist_attr(
                    input_varname, input_dist_attr
                )
            else:
                input_var = main_block._var_recursive(input_varname)
                tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(
                    input_var
                )
                matmulv2_op_dist_attr.set_input_dist_attr(
                    input_varname, tensor_dist_attr
                )
        for output_varname in mul_op.desc.output_arg_names():
            output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
            assert output_dist_attr is not None, "dist_attr is {}".format(
                op_dist_attr
            )
            matmulv2_op_dist_attr.set_output_dist_attr(
                output_varname, output_dist_attr
            )
        ctx.set_op_dist_attr_for_program(mul_op, matmulv2_op_dist_attr)

        # init param sync
        if Weight_var.is_parameter and not op_dist_attr.is_recompute:
            _init_param_sync(
                Weight_var, dist_op_context, startup_block, ctx, rank_id
            )

    @staticmethod
    def backward(ctx, *args, **kwargs):
        _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)


# RowParallel
class DistributedMulImpl1(DistributedOperatorImpl):
    def __init__(self, name):
        super().__init__(name)
        self._forward_implemented = True
        self._backward_implemented = True

    def calc_cost(self, op_role, dist_op, ctx, cluster):
        cost = None
        if int(op_role) == int(OpRole.Forward):
            cost = self.calc_fwd_cost(dist_op, ctx, cluster)
        elif int(op_role) == int(OpRole.Backward):
            cost = self.calc_bwd_cost(dist_op, ctx, cluster)
        assert cost is not None
        return cost

    def calc_bwd_cost(self, dist_op, ctx, cluster):
        # for now, the backward function only inserts the gradient allreduce for the dist op itself
        res = []
        backward_op = dist_op.serial_op
        dist_attr = dist_op.dist_attr
        process_mesh = dist_attr.process_mesh
        main_block = backward_op.block
        Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("Y")[0]
        )
        assert Y_var_dim_mapping[1] < 0
        parallel_axis = Y_var_dim_mapping[0]
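        # Y is sharded along its row dimension; backward communication cost is
        # modeled over that mesh axis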

        # calc comm op cost
        var_names = [backward_op.input("Out@GRAD")[0]]
        attrs = {"use_calc_stream": True, "use_model_parallel": True}
        c_identity_desc_mapping = build_comm_desc_from_dist_op(
            "c_identity",
            dist_op,
            ctx,
            var_names,
            attrs=attrs,
            parallel_axis=parallel_axis,
        )
        processes = process_mesh.processes
        comm_op_cost_list = build_comm_costs_from_descs(
            IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster
        )
        res.append(comm_op_cost_list)

        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        cost_mapping = build_comp_costs_from_descs(
            MulGradOpCost, ctx, processes, desc_mapping, cluster
        )
        res.append(cost_mapping)

        # need gradient allreduce
        var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("X")[0]
        )
        mesh_shape = process_mesh.topology
        batch_size_axis = var_dim_mapping[0]
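        # a sharded batch dimension on X implies a data-parallel allreduce of
        # Y@GRAD, whose cost is added below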
        if (
            batch_size_axis > -1
            and mesh_shape[batch_size_axis] > 1
            and is_parameter_related(backward_op.input("Y")[0], main_block)
        ):
            parallel_axis = batch_size_axis
            attrs = {"use_calc_stream": True}
            var_names = [backward_op.output('Y@GRAD')[0]]
            build_dp_costs(
                res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
            )
        return res

    def calc_fwd_cost(self, dist_op, ctx, cluster):
        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        processes = dist_op.dist_attr.process_mesh.processes
        cost_mapping = build_comp_costs_from_descs(
            MulOpCost, ctx, processes, desc_mapping, cluster
        )

        # calc comm op cost
        serial_op = dist_op.serial_op
        parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
            serial_op.input("Y")[0]
        )[-2]
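        # the row-parallel forward leaves a partial sum in Out, so the cost of
        # the c_allreduce_sum over the mesh axis sharding Y's rows is added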
        attrs = {"use_calc_stream": True, "use_model_parallel": True}

        var_names = serial_op.output("Out")
        c_allreduce_sum_desc_mapping = build_comm_desc_from_dist_op(
            "c_allreduce_sum",
            dist_op,
            ctx,
            var_names,
            attrs=attrs,
            parallel_axis=parallel_axis,
        )

        comm_op_cost_list = build_comm_costs_from_descs(
            AllreduceSumOpCost,
            ctx,
            processes,
            c_allreduce_sum_desc_mapping,
            cluster,
        )

        res_cost = [cost_mapping, comm_op_cost_list]

        return res_cost

    def is_input_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        x_name = op_desc.input('X')[0]
        y_name = op_desc.input('Y')[0]
        x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
        y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
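        # row-parallel: X must be sharded on its last dim and Y on its row
        # dim, with Y's column dim replicated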
        if is_dim_replicate(x_dims_mapping[-1]):
            return False
        if is_dim_replicate(y_dims_mapping[-2]) or is_dim_shard(
            y_dims_mapping[-1]
        ):
            return False
        # dimensions other than the batch dimension must be replicated
        for mapping in x_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_output_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
        if is_dim_shard(out_dims_mapping[-1]):
            return False
        # dimensions other than the batch dimension must be replicated
        for mapping in out_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_auto_compatible(self, dist_op):
        if (not self.is_input_compatible(dist_op)) or (
            not self.is_output_compatible(dist_op)
        ):
            return False

        if not _is_auto_compatible_for_matmul(dist_op):
            return False

        return True

    def update_dims_mapping(self, dist_op):
        changed = False
        dim_changed = _update_dims_mapping_for_matmul(dist_op)
        if dim_changed:
            changed = True
        return changed

    @staticmethod
    def forward(ctx, *args, **kwargs):
        """
        kwargs: inputname_mapping & outputname_mapping
        """

        dist_op_context = ctx.dist_op_context
        main_block = dist_op_context.work_block
        startup_block = dist_op_context.startup_block
        src_op = dist_op_context.cur_src_op
        rank_id = dist_op_context.rank_id
        op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
        assert (
            op_dist_attr is not None
        ), "op [{}] doesn't have a dist attribute!".format(str(src_op))

        # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
        if rank_id not in op_dist_attr.process_mesh.processes:
            rank_id = _get_corresponding_rank(
                ctx, op_dist_attr.process_mesh, rank_id
            )

        # check validation of inputs / outputs
        for input_name in src_op.desc.input_names():
            assert input_name in kwargs, "input [{}] is not given".format(
                input_name
            )
            assert len(kwargs[input_name]) == len(
                src_op.desc.input(input_name)
            ), "number of tensors for input [{}] does not match".format(
                input_name
            )
        for output_name in src_op.desc.output_names():
            assert output_name in kwargs, "output [{}] is not given".format(
                output_name
            )
            assert len(kwargs[output_name]) == len(
                src_op.desc.output(output_name)
            ), "number of tensors for output [{}] does not match".format(
                output_name
            )

        X_var = main_block._var_recursive(kwargs['X'][0])
        Weight_var = main_block._var_recursive(kwargs['Y'][0])
        Out_var = main_block._var_recursive(kwargs['Out'][0])

        # TODO infer logic comm presentation
        matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
            Weight_var.name
        )[-2]
        assert (
            matmul_row_dim_mapping >= 0
        ), "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
            matmul_row_dim_mapping
        )
        process_mesh_shape = op_dist_attr.process_mesh.topology
        process_mesh_group = op_dist_attr.process_mesh.processes

        parallel_axis = matmul_row_dim_mapping
        group_ranks = _get_comm_group(
            process_mesh_group, process_mesh_shape, parallel_axis, rank_id
        )
        group = new_process_group(group_ranks)
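        # ranks along the mesh axis that shards Y's rows form the group used
        # to allreduce the partial results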

        check_variable_and_dtype(
            X_var, 'x', ['float16', 'float32', 'float64'], 'linear'
        )
        check_dtype(
            X_var.dtype, 'dtype', ['float16', 'float32', 'float64'], 'linear'
        )
        # attrs = {'trans_x': False, 'trans_y': False}
        attrs = {
            "x_num_col_dims": src_op.desc.attr("x_num_col_dims"),
            "y_num_col_dims": src_op.desc.attr("y_num_col_dims"),
            OP_ROLE_KEY: src_op.attr('op_role'),
        }
        inputs = {'X': X_var, 'Y': Weight_var}

        # infer out var shape with op dist attr
        out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var)
        assert out_tensor_dist_attr is not None
        out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert out_var_dist_attr is not None
        ref_shape = infer_shape(
            main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr
        )

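        # intermediate var holds this rank's partial matmul result before the
        # allreduce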
        intermediate_var_0 = main_block.create_var(
            name=unique_name.generate_with_ignorable_key(
                ".".join(["c_allreduce_sum", 'tmp'])
            ),
            shape=Out_var.shape,
            dtype=Out_var.dtype,
            type=Out_var.type,
            lod_level=Out_var.lod_level,
            persistable=False,
            is_data=False,
            need_check_feed=Out_var.desc.need_check_feed(),
        )
        # set intermediate_var_0's dist_attr with Out_var's dist_attr
        ctx.set_tensor_dist_attr_for_program(
            intermediate_var_0, out_var_dist_attr
        )

        inputs_ref_shape = {}
        inputs_original_shape = {}
        for var_name in inputs:
            var = inputs[var_name]
            inputs_original_shape[var_name] = var.shape
            input_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(var)
            input_var_dist_attr = op_dist_attr.get_input_dist_attr(var.name)
            input_ref_shape = infer_shape(
                main_block, var, input_tensor_dist_attr, input_var_dist_attr
            )
            inputs_ref_shape[var_name] = input_ref_shape
            var.desc.set_shape(input_ref_shape)

        mul_op = main_block.append_op(
            type='mul',
            inputs=inputs,
            outputs={'Out': intermediate_var_0},
            attrs=attrs,
        )

        if intermediate_var_0.shape != ref_shape:
            intermediate_var_0.desc.set_shape(ref_shape)

        for var_name in inputs:
            var = inputs[var_name]
            original_shape = inputs_original_shape[var_name]
            var.desc.set_shape(original_shape)

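        # sum the partial results across the model-parallel group to produce
        # the full Out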
        c_allreduce_sum_op = main_block.append_op(
            type='c_allreduce_sum',
            inputs={'X': intermediate_var_0},
            outputs={'Out': Out_var},
            attrs={
                'ring_id': group.id,
                'use_calc_stream': True,
                'use_model_parallel': True,
                OP_ROLE_KEY: src_op.attr('op_role'),
            },
        )

        if Out_var.shape != ref_shape:
            Out_var.desc.set_shape(ref_shape)

        # set dist op's dist_attr with serial op's dist_attr
        # matmulv2
        matmulv2_op_dist_attr = OperatorDistributedAttribute()
        matmulv2_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        matmulv2_op_dist_attr.impl_type = op_dist_attr.impl_type
        matmulv2_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        for input_varname in mul_op.desc.input_arg_names():
            input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
            assert input_dist_attr is not None, "dist_attr is {}".format(
                op_dist_attr
            )
            matmulv2_op_dist_attr.set_input_dist_attr(
                input_varname, input_dist_attr
            )
        output_varname = mul_op.desc.output_arg_names()[0]
        output_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert output_dist_attr is not None, "dist_attr is {}".format(
            op_dist_attr
        )
        matmulv2_op_dist_attr.set_output_dist_attr(
            output_varname, output_dist_attr
        )
        ctx.set_op_dist_attr_for_program(mul_op, matmulv2_op_dist_attr)

        # allreduce
        allreduce_op_dist_attr = OperatorDistributedAttribute()
        allreduce_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        allreduce_op_dist_attr.impl_type = op_dist_attr.impl_type
        allreduce_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        for input_varname in c_allreduce_sum_op.desc.input_arg_names():
            input_var = main_block._var_recursive(input_varname)
            tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(input_var)
            assert tensor_dist_attr is not None
            allreduce_op_dist_attr.set_input_dist_attr(
                input_varname, tensor_dist_attr
            )
        for output_varname in c_allreduce_sum_op.desc.output_arg_names():
            output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
            assert output_dist_attr is not None, "dist_attr is {}".format(
                op_dist_attr
            )
            allreduce_op_dist_attr.set_output_dist_attr(
                output_varname, output_dist_attr
            )
        ctx.set_op_dist_attr_for_program(
            c_allreduce_sum_op, allreduce_op_dist_attr
        )

        # init param sync
        if Weight_var.is_parameter and not op_dist_attr.is_recompute:
            _init_param_sync(
                Weight_var, dist_op_context, startup_block, ctx, rank_id
            )

    @staticmethod
    def backward(ctx, *args, **kwargs):
        _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)


# ReplicateParallel
class DistributedMulImpl2(DistributedOperatorImpl):
    def __init__(self, name):
        super().__init__(name)

    def calc_cost(self, op_role, dist_op, ctx, cluster):
        cost = None
        if int(op_role) == int(OpRole.Forward):
            cost = self.calc_fwd_cost(dist_op, ctx, cluster)
        elif int(op_role) == int(OpRole.Backward):
            cost = self.calc_bwd_cost(dist_op, ctx, cluster)
        assert cost is not None
        return cost

    def calc_bwd_cost(self, dist_op, ctx, cluster):
        res = []
        backward_op = dist_op.serial_op
        dist_attr = dist_op.dist_attr
        main_block = backward_op.block

        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        process_mesh = dist_attr.process_mesh
        processes = process_mesh.processes
        cost_mapping = build_comp_costs_from_descs(
            MulGradOpCost, ctx, processes, desc_mapping, cluster
        )
        res.append(cost_mapping)

        # need gradient allreduce
        var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("X")[0]
        )
        mesh_shape = process_mesh.topology
        batch_size_axis = var_dim_mapping[0]
        if (
            batch_size_axis > -1
            and mesh_shape[batch_size_axis] > 1
            and is_parameter_related(backward_op.input("Y")[0], main_block)
        ):
            parallel_axis = batch_size_axis
            attrs = {"use_calc_stream": True}
            var_names = [backward_op.output('Y@GRAD')[0]]
            build_dp_costs(
                res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
            )

        return res

    def calc_fwd_cost(self, dist_op, ctx, cluster):
        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        processes = dist_op.dist_attr.process_mesh.processes
        cost_mapping = build_comp_costs_from_descs(
            MulOpCost, ctx, processes, desc_mapping, cluster
        )

        res_cost = [cost_mapping]
        return res_cost

    def is_input_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        x_name = op_desc.input('X')[0]
        y_name = op_desc.input('Y')[0]
        x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
        y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)

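        # replicated matmul: the last two dims of X and Y must not be sharded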
        if is_dim_shard(x_dims_mapping[-1]):
            return False
        if is_valid_list_index(x_dims_mapping, -2) and is_dim_shard(
            x_dims_mapping[-2]
        ):
            return False
        if is_dim_shard(y_dims_mapping[-1]):
            return False
        if is_valid_list_index(y_dims_mapping, -2) and is_dim_shard(
            y_dims_mapping[-2]
        ):
            return False
        return True

    def is_output_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)

        if is_dim_shard(out_dims_mapping[-1]):
            return False
        if is_valid_list_index(out_dims_mapping, -2) and is_dim_shard(
            out_dims_mapping[-2]
        ):
            return False

        return True

    def is_auto_compatible(self, dist_op):
        if (not self.is_input_compatible(dist_op)) or (
            not self.is_output_compatible(dist_op)
        ):
            return False

        if not _is_auto_compatible_for_matmul(dist_op):
            return False

        return True

    def update_dims_mapping(self, dist_op):
        changed = False
        dim_changed = _update_dims_mapping_for_matmul(dist_op)
        if dim_changed:
            changed = True
        return changed

    @staticmethod
    def forward(ctx, *args, **kwargs):
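        # fully replicated case: reuse the default implementation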
        DistributedDefaultImpl0.forward(ctx, *args, **kwargs)

    @staticmethod
    def backward(ctx, *args, **kwargs):
        _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)


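# register the column-parallel, row-parallel and replicated implementations of "mul"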
register_distributed_operator_impl(
    "mul", DistributedMulImpl0("column_parallel")
)
register_distributed_operator_impl("mul", DistributedMulImpl1("row_parallel"))
register_distributed_operator_impl(
    "mul", DistributedMulImpl2("replicate_parallel")
)