dist_matmul.py 85.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License

15
import copy
Z
zhaoyingli 已提交
16
from .common import infer_shape
17
from .common import DistributedOperatorImplContainer
18
from .common import DistributedOperatorImpl
19
from .common import register_distributed_operator_impl_container
20
from .common import register_distributed_operator_impl
J
JZ-LIANG 已提交
21
from .common import set_comm_op_dist_attr_for_program, naive_copy_op_dist_attr_for_program, is_parameter_related
22 23 24 25 26 27
from ..utils import is_dim_shard
from ..utils import is_dim_replicate
from ..utils import is_valid_list_index
from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
28
from ..utils import set_dist_op_desc_original_id
29
from ..dist_attribute import OperatorDistributedAttribute
30
from paddle.fluid import core, unique_name
J
Jiabin Yang 已提交
31
from paddle.fluid.framework import _non_static_mode
32 33
from paddle.fluid.framework import Program, Parameter, Variable, program_guard
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
34
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY
35
from ..process_group import new_process_group
36
from ..utils import _get_comm_group, _get_corresponding_rank
37
from .dist_default import DistributedDefaultImpl0
38 39


40
def copy_op_with_new_input_output(ctx, block, src_op, **kwargs):
    """Append a copy of ``src_op`` to ``block`` with rebound inputs/outputs.

    Args:
        ctx: distributed context, forwarded to
            ``set_dist_op_desc_original_id``.
        block: block the copy is appended to.
        src_op: op whose desc is copied.
        **kwargs: maps every input and output slot name of ``src_op`` to the
            list of variable names bound to that slot in the copy.

    Returns:
        The desc of the newly appended op.
    """
    dist_op_desc = block.append_op(type='nop').desc
    dist_op_desc.copy_from(src_op.desc)
    set_dist_op_desc_original_id(dist_op_desc, src_op.desc, ctx)
    for input_name in src_op.desc.input_names():
        assert input_name in kwargs, "input [{}] is not given".format(
            input_name)
        dist_op_desc.set_input(input_name, kwargs[input_name])
    for output_name in src_op.desc.output_names():
        # Validate the output slot itself before rebinding it.
        assert output_name in kwargs, "output [{}] is not given".format(
            output_name)
        dist_op_desc.set_output(output_name, kwargs[output_name])

    return dist_op_desc


54
def _update_dims_mapping_for_matmul(dist_op):
    """Make the dims mappings of a matmul's X, Y and Out mutually compatible.

    X is aligned as [..., M, K], Y as [..., K, N] and Out as [..., M, N].
    The broadcast (batch) dims of all three, the K dims of X/Y, the M dims of
    X/Out and the N dims of Y/Out are reconciled with the
    ``compute_compatible_*`` helpers.

    The lists returned by ``op_dist_attr.get_*_dims_mapping`` are modified in
    place and never written back explicitly. This presumably relies on those
    getters returning the stored lists (confirm in dist_attribute).

    Args:
        dist_op: distributed op wrapping a ``matmul``/``matmul_v2`` op.

    Returns:
        bool: True if any dims mapping changed. Returns False, leaving every
        mapping as it was, when the batch dims cannot be made compatible.
    """
    changed = False
    op_desc = dist_op.serial_op.desc
    op_dist_attr = dist_op.dist_attr
    x_name = op_desc.input('X')[0]
    y_name = op_desc.input('Y')[0]
    out_name = op_desc.output('Out')[0]
    x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
    y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
    out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
    x_dims_mapping_len = len(x_dims_mapping)
    y_dims_mapping_len = len(y_dims_mapping)
    out_dims_mapping_len = len(out_dims_mapping)

    # A 1-D X is handled as [1, K] and a 1-D Y as [K, 1]. In both cases the
    # matching M / N dim is absent from Out. Pad X/Y *and* Out with a
    # replicated (-1) placeholder so every mapping has at least 2 dims and
    # the negative indices below refer to the same logical dims.
    out_x_pad_idx = max(out_dims_mapping_len - 1, 0)
    if x_dims_mapping_len == 1:
        x_dims_mapping.insert(0, -1)
        out_dims_mapping.insert(out_x_pad_idx, -1)
    if y_dims_mapping_len == 1:
        y_dims_mapping.insert(1, -1)
        out_dims_mapping.append(-1)

    try:
        # Lengths after padding; the broadcast alignment must use these.
        x_len = len(x_dims_mapping)
        y_len = len(y_dims_mapping)
        out_len = len(out_dims_mapping)

        # Deal with dim > 2 and take care of broadcasting
        if out_len > 2:
            # Leading Out dims that X / Y lack are broadcast from Out.
            broadcast_x_dims_mapping = (
                out_dims_mapping[:max(out_len - x_len, 0)] +
                x_dims_mapping[:max(x_len - 2, 0)])
            broadcast_y_dims_mapping = (
                out_dims_mapping[:max(out_len - y_len, 0)] +
                y_dims_mapping[:max(y_len - 2, 0)])
            broadcast_out_dims_mapping = out_dims_mapping[:out_len - 2]

            compatible_dims_mapping = compute_compatible_dims_mapping([
                broadcast_x_dims_mapping, broadcast_y_dims_mapping,
                broadcast_out_dims_mapping
            ])
            if compatible_dims_mapping is None:
                # The finally clause below removes the padding again.
                return False

            for i in range(x_len - 2):
                new_idx = i + (out_len - x_len)
                if x_dims_mapping[i] != compatible_dims_mapping[new_idx]:
                    x_dims_mapping[i] = compatible_dims_mapping[new_idx]
                    changed = True

            for i in range(y_len - 2):
                new_idx = i + (out_len - y_len)
                if y_dims_mapping[i] != compatible_dims_mapping[new_idx]:
                    y_dims_mapping[i] = compatible_dims_mapping[new_idx]
                    changed = True

            for i in range(out_len - 2):
                if out_dims_mapping[i] != compatible_dims_mapping[i]:
                    out_dims_mapping[i] = compatible_dims_mapping[i]
                    changed = True

        # Negative indices work for any rank >= 2, which padding guarantees.
        for dims_mappings, dim_indices in (
            ([x_dims_mapping, y_dims_mapping], [-1, -2]),  # K of X and Y
            ([x_dims_mapping, out_dims_mapping], [-2, -2]),  # M of X and Out
            ([y_dims_mapping, out_dims_mapping], [-1, -1]),  # N of Y and Out
        ):
            if compute_compatible_and_update_dim_mapping(
                    dims_mappings, dim_indices):
                changed = True
    finally:
        # Remove the placeholders, in reverse order of insertion, so that
        # each mapping's length matches its tensor's rank again.
        if y_dims_mapping_len == 1:
            y_dims_mapping.pop(1)
            out_dims_mapping.pop()
        if x_dims_mapping_len == 1:
            x_dims_mapping.pop(0)
            out_dims_mapping.pop(out_x_pad_idx)

    assert len(x_dims_mapping) == x_dims_mapping_len
    assert len(y_dims_mapping) == y_dims_mapping_len
    assert len(out_dims_mapping) == out_dims_mapping_len

    return changed


147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167
def _is_auto_compatible_for_matmul(dist_op):
    op_desc = dist_op.serial_op.desc
    op_dist_attr = dist_op.dist_attr
    x_name = op_desc.input('X')[0]
    y_name = op_desc.input('Y')[0]
    out_name = op_desc.output('Out')[0]
    # Deep copy these dims_mappings for keeping them unchanged.
    x_dims_mapping = copy.deepcopy(op_dist_attr.get_input_dims_mapping(x_name))
    y_dims_mapping = copy.deepcopy(op_dist_attr.get_input_dims_mapping(y_name))
    out_dims_mapping = copy.deepcopy(
        op_dist_attr.get_output_dims_mapping(out_name))
    x_dims_mapping_len = len(x_dims_mapping)
    y_dims_mapping_len = len(y_dims_mapping)
    out_dims_mapping_len = len(out_dims_mapping)

    # Add dim mapping to Make sure the length dims_mapping be at least 2
    if x_dims_mapping_len == 1:
        x_dims_mapping.insert(0, -1)
    if y_dims_mapping_len == 1:
        y_dims_mapping.insert(1, -1)

168 169 170
    # NOTE: Partition is not supported if matmul op has trans.
    if op_desc.type() == "matmul_v2":
        if op_desc.attr('trans_x') or op_desc.attr('trans_y'):
171 172
            if x_dims_mapping[-2:] != [-1, -1
                                       ] or y_dims_mapping[-2:] != [-1, -1]:
173 174 175
                return False
    elif op_desc.type() == "matmul":
        if op_desc.attr('transpose_X') or op_desc.attr('transpose_Y'):
176 177
            if x_dims_mapping[-2:] != [-1, -1
                                       ] or y_dims_mapping[-2:] != [-1, -1]:
178 179
                return False

180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198
    # Deal with dim > 2 and take care of broadcasting
    if out_dims_mapping_len > 2:
        broadcast_x_dims_mapping = []
        broadcast_y_dims_mapping = []
        broadcast_out_dims_mapping = []

        for i in range(out_dims_mapping_len - x_dims_mapping_len):
            broadcast_x_dims_mapping.append(out_dims_mapping[i])
        for i in range(x_dims_mapping_len - 2):
            broadcast_x_dims_mapping.append(x_dims_mapping[i])

        for i in range(out_dims_mapping_len - y_dims_mapping_len):
            broadcast_y_dims_mapping.append(out_dims_mapping[i])
        for i in range(y_dims_mapping_len - 2):
            broadcast_y_dims_mapping.append(y_dims_mapping[i])

        for i in range(out_dims_mapping_len - 2):
            broadcast_out_dims_mapping.append(out_dims_mapping[i])

199 200
        is_same = ((broadcast_x_dims_mapping == broadcast_y_dims_mapping)
                   and (broadcast_x_dims_mapping == broadcast_out_dims_mapping))
201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220
        if not is_same:
            return False

    # The following which uses negative index can be work
    # when len(out_dims_mapping) > 2 and len(out_dims_mapping) <=2
    is_same = (x_dims_mapping[-1] == y_dims_mapping[-2])
    if not is_same:
        return False

    is_same = (x_dims_mapping[-2] == out_dims_mapping[-2])
    if not is_same:
        return False

    is_same = (y_dims_mapping[-1] == out_dims_mapping[-1])
    if not is_same:
        return False

    return True


221 222 223 224
def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
    """Partition the grad op of a matmul whose right operand Y may be a
    (sharded) parameter.

    Appends to ``ctx.dist_op_context.work_block``:

    * Y is parameter related and sharded, and its first dim is mapped
      (row parallel): a ``c_identity`` on ``Out@GRAD``, followed by a copy of
      the grad op that reads the identity's output.
    * Y is parameter related and sharded, and its first dim is not mapped
      (column parallel): a copy of the grad op. When ``X@GRAD`` is
      non-empty, the copy writes it to a temporary, and a ``c_allreduce_sum``
      of that temporary into the real ``X@GRAD`` follows.
    * Otherwise: an unchanged copy of the grad op.

    Afterwards, a ``c_allreduce_sum`` of ``Y@GRAD`` is added over the group
    of that mesh axis when both of these hold:

    * X's first dim is mapped to a mesh axis of size > 1;
    * Y is parameter related.

    When ``ctx.gradient_scale`` is set, a ``scale`` by ``1 / group size``
    follows the allreduce.

    Args:
        ctx: distributed context providing ``dist_op_context`` and the
            dist-attr getters/setters.
        *args: unused.
        **kwargs: grad-op slot name ('X', 'Y', 'Out@GRAD', 'X@GRAD',
            'Y@GRAD') -> list of variable names.
    """
    dist_op_context = ctx.dist_op_context
    main_block = dist_op_context.work_block
    backward_op = dist_op_context.cur_src_op
    rank_id = dist_op_context.rank_id
    dist_attr = ctx.get_op_dist_attr_for_program(backward_op)
    assert dist_attr is not None, "backward op [{}] don't have dist attribute !".format(
        str(backward_op))

    # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
    if rank_id not in dist_attr.process_mesh.processes:
        rank_id = _get_corresponding_rank(ctx, dist_attr.process_mesh, rank_id)

    for input_name in ('Y', 'X', 'Out@GRAD'):
        assert input_name in kwargs, "input [{}] is not given".format(
            input_name)
    for output_name in ('Y@GRAD', 'X@GRAD'):
        assert output_name in kwargs, "output [{}] is not given".format(
            output_name)
    # X@GRAD may legitimately be empty (see the column-parallel branch), so
    # its length is not checked here.
    for slot_name in ('Y', 'X', 'Out@GRAD', 'Y@GRAD'):
        assert len(kwargs[slot_name]) == 1, \
            "dist matmul grad slot [{}] should take 1 variable but got {}".format(
                slot_name, kwargs[slot_name])

    X_var = main_block.var(kwargs['X'][0])
    Y_var = main_block._var_recursive(kwargs['Y'][0])
    Out_grad = main_block.var(kwargs['Out@GRAD'][0])
    Y_grad = main_block.var(kwargs['Y@GRAD'][0])

    assert not is_parameter_related(
        X_var.name, main_block
    ), "left operand(X) [{}] of dist matmul should not be parameter".format(
        X_var.name)

    Y_var_dim_mapping = dist_attr.get_input_dims_mapping(Y_var.name)
    process_mesh_shape = dist_attr.process_mesh.topology
    process_mesh_group = dist_attr.process_mesh.processes
    Y_var_partitioned = any(dim >= 0 and process_mesh_shape[dim] > 0
                            for dim in Y_var_dim_mapping)

    if is_parameter_related(Y_var.name, main_block) and Y_var_partitioned:

        if Y_var_dim_mapping[0] >= 0:
            # row parallel: c_identity + matmul
            assert Y_var_dim_mapping[1] < 0
            parallel_axis = Y_var_dim_mapping[0]

            check_variable_and_dtype(
                Out_grad, 'tensor',
                ['float16', 'float32', 'float64', 'int32', 'int64'],
                '_c_identity')

            intermediate_var_0 = main_block.create_var(
                name=unique_name.generate_with_ignorable_key(".".join(
                    ["c_identity", 'tmp'])) + "@GRAD",
                dtype=Out_grad.dtype,
                shape=Out_grad.shape,
                type=core.VarDesc.VarType.LOD_TENSOR,
                persistable=False,
                stop_gradient=Out_grad.stop_gradient)

            # copy Out@GRAD's dist_attr to intermediate_var_0's dist_attr
            out_grad_dist_attr = dist_attr.get_input_dist_attr(Out_grad.name)
            assert out_grad_dist_attr is not None
            ctx.set_tensor_dist_attr_for_program(intermediate_var_0,
                                                 out_grad_dist_attr)

            group_ranks = _get_comm_group(process_mesh_group,
                                          process_mesh_shape, parallel_axis,
                                          rank_id)
            group = new_process_group(group_ranks)
            c_identity_op = main_block.append_op(
                type='c_identity',
                inputs={'X': [Out_grad]},
                outputs={'Out': intermediate_var_0},
                attrs={
                    'ring_id': group.id,
                    'use_calc_stream': True,
                    'use_model_parallel': True,
                    OP_ROLE_KEY: OpRole.Backward,
                })
            check_variable_and_dtype(intermediate_var_0, 'x',
                                     ['float16', 'float32', 'float64'],
                                     'linear')
            check_dtype(intermediate_var_0.dtype, 'dtype',
                        ['float16', 'float32', 'float64'], 'linear')
            set_comm_op_dist_attr_for_program(c_identity_op,
                                              dist_attr.process_mesh,
                                              out_grad_dist_attr, ctx)

            new_kwargs = copy.deepcopy(kwargs)
            new_kwargs['Out@GRAD'] = [intermediate_var_0.name]
            matmul_op_desc = copy_op_with_new_input_output(
                ctx, main_block, backward_op, **new_kwargs)
        else:
            # col parallel: matmul + allreduce
            assert Y_var_dim_mapping[0] < 0
            parallel_axis = Y_var_dim_mapping[1]
            new_kwargs = copy.deepcopy(kwargs)

            # NOTE (JZ-LIANG) should allow left operand be empty for matmul grad
            has_x_grad = len(kwargs['X@GRAD']) > 0
            if has_x_grad:
                assert len(kwargs['X@GRAD']) == 1
                X_grad = main_block.var(kwargs['X@GRAD'][0])
                intermediate_var_0 = main_block.create_var(
                    name=unique_name.generate_with_ignorable_key(".".join(
                        ["c_identity", 'tmp'])) + "@GRAD",
                    dtype=X_grad.dtype,
                    shape=X_grad.shape,
                    type=core.VarDesc.VarType.LOD_TENSOR,
                    persistable=False,
                    stop_gradient=X_grad.stop_gradient)

                X_grad_dist_attr = dist_attr.get_output_dist_attr(X_grad.name)
                assert X_grad_dist_attr is not None
                ctx.set_tensor_dist_attr_for_program(intermediate_var_0,
                                                     X_grad_dist_attr)
                new_kwargs['X@GRAD'] = [intermediate_var_0.name]

            matmul_op_desc = copy_op_with_new_input_output(
                ctx, main_block, backward_op, **new_kwargs)

            # NOTE (JZ-LIANG) trick to skip one allreduce if left operand has not grad
            if has_x_grad:
                group_ranks = _get_comm_group(process_mesh_group,
                                              process_mesh_shape, parallel_axis,
                                              rank_id)
                group = new_process_group(group_ranks)
                c_allreduce_sum_op = main_block.append_op(
                    type='c_allreduce_sum',
                    inputs={'X': [intermediate_var_0.name]},
                    outputs={'Out': kwargs['X@GRAD']},
                    attrs={
                        'ring_id': group.id,
                        'use_calc_stream': True,
                        'use_model_parallel': True,
                        OP_ROLE_KEY: OpRole.Backward
                    })
                set_comm_op_dist_attr_for_program(c_allreduce_sum_op,
                                                  dist_attr.process_mesh,
                                                  X_grad_dist_attr, ctx)
    else:
        # replicate
        matmul_op_desc = copy_op_with_new_input_output(ctx, main_block,
                                                       backward_op, **kwargs)

    # check if need gradient allreduce
    need_gradient_allreduce = False

    process_mesh = dist_attr.process_mesh
    var_dim_mapping = dist_attr.get_input_dims_mapping(X_var.name)
    mesh_shape = process_mesh.topology
    batch_size_axis = var_dim_mapping[0]
    if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
        need_gradient_allreduce = True
        group_ranks = _get_comm_group(process_mesh.processes,
                                      process_mesh.topology, batch_size_axis,
                                      rank_id)
        dp_degree = len(group_ranks)
        dp_group = new_process_group(group_ranks)

    if need_gradient_allreduce and is_parameter_related(Y_var.name, main_block):
        added_ops = []
        Y_Grad_var = main_block.var(kwargs['Y@GRAD'][0])
        allreduce_op = main_block.append_op(type='c_allreduce_sum',
                                            inputs={'X': [Y_Grad_var]},
                                            outputs={'Out': [Y_Grad_var]},
                                            attrs={
                                                'ring_id': dp_group.id,
                                                'use_calc_stream': True,
                                                OP_ROLE_KEY: OpRole.Backward
                                            })
        added_ops.append(allreduce_op)

        if ctx.gradient_scale:
            scale_op = main_block.append_op(type='scale',
                                            inputs={'X': Y_Grad_var},
                                            outputs={'Out': Y_Grad_var},
                                            attrs={
                                                'scale': 1.0 / dp_degree,
                                                OP_ROLE_KEY: OpRole.Backward
                                            })
            added_ops.append(scale_op)

        main_block._sync_with_cpp()

        dims_mapping = ctx.get_tensor_dist_attr_for_program(
            Y_Grad_var).dims_mapping
        process_mesh = dist_attr.process_mesh
        for op in added_ops:
            op_attr = OperatorDistributedAttribute()
            op_attr.process_mesh = process_mesh
            op_attr.set_output_dims_mapping(Y_Grad_var.name, dims_mapping)
            op_attr.set_input_dims_mapping(Y_Grad_var.name, dims_mapping)
            ctx.set_op_dist_attr_for_program(op, op_attr)
438 439


440
def _init_param_sync(Weight_var, dist_op_context, startup_block, ctx, rank_id):
    """Append ``c_broadcast`` ops for a parameter to the startup block.

    One ``c_broadcast`` (``root`` 0) is appended for every process-mesh axis
    that meets both of these conditions:

    * its size is greater than 1;
    * it does not appear in the parameter's dims mapping.

    Each broadcast runs over ``rank_id``'s communication group on that axis.
    A parameter whose name is already in
    ``dist_op_context.already_init_sync_vars`` is skipped, so each parameter
    is processed at most once.
    """
    param_name = Weight_var.name
    synced_names = dist_op_context.already_init_sync_vars
    if param_name in synced_names:
        return
    assert startup_block.has_var(param_name)
    synced_names.add(param_name)

    param = startup_block.var(param_name)
    param_dist_attr = ctx.get_tensor_dist_attr_for_program(param)
    mesh = param_dist_attr.process_mesh
    dims_mapping = param_dist_attr.dims_mapping

    for mesh_axis, axis_size in enumerate(mesh.topology):
        # Trivial axes and axes the parameter is sharded on need no sync.
        if axis_size <= 1 or mesh_axis in dims_mapping:
            continue
        sync_group = new_process_group(
            _get_comm_group(mesh.processes, mesh.topology, mesh_axis,
                            rank_id))
        startup_block.append_op(
            type='c_broadcast',
            inputs={'X': param},
            outputs={'Out': param},
            attrs={
                'ring_id': sync_group.id,
                'root': 0,
                'use_calc_stream': True,
                OP_ROLE_KEY: OpRole.Forward,
            })
468 469


470
class DistributedMatmul(DistributedOperatorImplContainer):
    """Distributed-operator impl container for a matmul op type.

    Args:
        op_type (str): op type the container is created for.
    """

    def __init__(self, op_type):
        super().__init__(op_type)
474 475


476
# Register, at import time, a DistributedMatmul container for the "matmul"
# op type.
register_distributed_operator_impl_container(DistributedMatmul("matmul"))
477 478 479 480


# ColumnParallel
class DistributedMatmulImpl0(DistributedOperatorImpl):
    """Column-parallel distributed implementation of ``matmul``.

    Y is expected to be sharded along its last (column) dim, while X's last
    dim stays unsharded. The forward pass inserts a ``c_identity`` on X over
    the mesh axis that shards Y's last dim, followed by a local ``matmul``.
    """

    def __init__(self, name):
        super(DistributedMatmulImpl0, self).__init__(name)
        self._forward_implemented = True
        self._backward_implemented = True

    def is_input_compatible(self, dist_op):
        """Return True if the X/Y dims mappings fit the column-parallel layout.

        Rejects X when its last dim, or any dim strictly between its first
        and last, is sharded. Rejects Y when its second-to-last dim is
        sharded or its last dim is replicated.
        """
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        x_name = op_desc.input('X')[0]
        y_name = op_desc.input('Y')[0]
        x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
        y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
        if is_dim_shard(x_dims_mapping[-1]):
            return False
        if is_dim_shard(y_dims_mapping[-2]) or is_dim_replicate(
                y_dims_mapping[-1]):
            return False
        # Other dimensions must be replicate except the batch dimension
        for mapping in x_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_output_compatible(self, dist_op):
        """Return True if Out's dims mapping fits the column-parallel layout.

        Rejects Out when its last dim is replicated, or when any dim strictly
        between its first and last is sharded.
        """
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
        if is_dim_replicate(out_dims_mapping[-1]):
            return False
        # Other dimensions must be replicate except the batch dimension
        for mapping in out_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_auto_compatible(self, dist_op):
        """Return True if inputs, output and the matmul dims relations fit."""
        if (not self.is_input_compatible(dist_op)) or \
                (not self.is_output_compatible(dist_op)):
            return False
        if not _is_auto_compatible_for_matmul(dist_op):
            return False
        return True

    def update_dims_mapping(self, dist_op):
        """Reconcile the X/Y/Out dims mappings; return True if any changed."""
        return bool(_update_dims_mapping_for_matmul(dist_op))

    @staticmethod
    def forward(ctx, *args, **kwargs):
        """Partition the serial matmul into ``c_identity`` + local ``matmul``.

        kwargs: inputname_mapping & outputname_mapping

        Appends to ``ctx.dist_op_context.work_block``:

        1. a ``c_identity`` on X, whose ring is ``rank_id``'s comm group
           along the mesh axis that shards Y's last dim, writing a new
           temporary;
        2. a ``matmul`` of that temporary with Y into Out.

        The shapes of the temporary and of Out are reset to those inferred
        from the op's dist attr, and dist attrs are attached to both new
        ops. When Y is a parameter and the op is not recomputed, the
        parameter's startup-program sync is also set up.
        """
        dist_op_context = ctx.dist_op_context
        main_block = dist_op_context.work_block
        startup_block = dist_op_context.startup_block
        src_op = dist_op_context.cur_src_op
        rank_id = dist_op_context.rank_id
        op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
        assert op_dist_attr is not None, "forward op [{}] don't have dist attribute !".format(
            str(src_op))

        # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
        if rank_id not in op_dist_attr.process_mesh.processes:
            rank_id = _get_corresponding_rank(ctx, op_dist_attr.process_mesh,
                                              rank_id)

        # check validation of inputs / outputs
        for input_name in src_op.desc.input_names():
            assert input_name in kwargs, "input [{}] is not given".format(
                input_name)
            assert len(kwargs[input_name]) == len(
                src_op.desc.input(input_name)
            ), "number of tensor for input [{}] is not match".format(input_name)
        for output_name in src_op.desc.output_names():
            assert output_name in kwargs, "output [{}] is not given".format(
                output_name)
            assert len(kwargs[output_name]) == len(
                src_op.desc.output(output_name)
            ), "number of tensor for output [{}] is not match".format(
                output_name)

        # Both new ops inherit the serial op's role.
        op_role = src_op.attr('op_role')

        X_var = main_block.var(kwargs['X'][0])
        # Look Y up with _var_recursive, as the backward pass does, so it is
        # presumably also found when it is not defined in work_block itself.
        Weight_var = main_block._var_recursive(kwargs['Y'][0])
        Out_var = main_block.var(kwargs['Out'][0])

        # TODO infer logic comm presentation
        matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping(
            Weight_var.name)[-1]
        assert matmul_col_dim_mapping >= 0, "col_parallel_matmul's col should be divided by a specific mesh axis, but got [{}]".format(
            matmul_col_dim_mapping)
        process_mesh_shape = op_dist_attr.process_mesh.topology
        process_mesh_group = op_dist_attr.process_mesh.processes

        parallel_axis = matmul_col_dim_mapping
        group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape,
                                      parallel_axis, rank_id)
        group = new_process_group(group_ranks)

        # infer new var shape with op dist attr
        x_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(X_var)
        assert x_tensor_dist_attr is not None
        identity_var_dist_attr = op_dist_attr.get_input_dist_attr(X_var.name)
        assert identity_var_dist_attr is not None
        ref_shape_x = infer_shape(main_block, X_var, x_tensor_dist_attr,
                                  identity_var_dist_attr)
        # infer out var shape with op dist attr
        out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var)
        assert out_tensor_dist_attr is not None
        out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert out_var_dist_attr is not None
        ref_shape_out = infer_shape(main_block, Out_var, out_tensor_dist_attr,
                                    out_var_dist_attr)

        intermediate_var_0 = main_block.create_var(
            name=unique_name.generate_with_ignorable_key(".".join(
                ["c_identity", 'tmp'])),
            dtype=X_var.dtype,
            shape=X_var.shape,
            type=core.VarDesc.VarType.LOD_TENSOR,
            persistable=False,
            stop_gradient=X_var.stop_gradient)
        # set intermediate_var_0's dist_attr with X_var's dist_attr
        ctx.set_tensor_dist_attr_for_program(intermediate_var_0,
                                             identity_var_dist_attr)

        check_variable_and_dtype(
            X_var, 'tensor',
            ['float16', 'float32', 'float64', 'int32', 'int64'], '_c_identity')

        c_identity_op = main_block.append_op(
            type='c_identity',
            inputs={'X': [X_var]},
            outputs={'Out': intermediate_var_0},
            attrs={
                'ring_id': group.id,
                'use_calc_stream': True,
                'use_model_parallel': True,
                OP_ROLE_KEY: op_role
            })
        if intermediate_var_0.shape != ref_shape_x:
            intermediate_var_0.desc.set_shape(ref_shape_x)

        check_variable_and_dtype(intermediate_var_0, 'x',
                                 ['float16', 'float32', 'float64'], 'linear')
        check_dtype(intermediate_var_0.dtype, 'dtype',
                    ['float16', 'float32', 'float64'], 'linear')
        attrs = {
            'transpose_X': False,
            'transpose_Y': False,
            'alpha': 1,
            OP_ROLE_KEY: op_role
        }
        inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]}
        matmul_op = main_block.append_op(type='matmul',
                                         inputs=inputs,
                                         outputs={'Out': Out_var},
                                         attrs=attrs)
        if Out_var.shape != ref_shape_out:
            Out_var.desc.set_shape(ref_shape_out)

        # set dist op's dist_attr with serial op's dist_attr
        # c_identity
        identity_op_dist_attr = OperatorDistributedAttribute()
        identity_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        identity_op_dist_attr.impl_type = op_dist_attr.impl_type
        identity_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        # input
        input_varname = c_identity_op.desc.input_arg_names()[0]
        input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
        assert input_dist_attr is not None, "dist_attr is {}".format(
            op_dist_attr)
        identity_op_dist_attr.set_input_dist_attr(input_varname,
                                                  input_dist_attr)
        # output
        output_varname = c_identity_op.desc.output_arg_names()[0]
        identity_op_dist_attr.set_output_dist_attr(output_varname,
                                                   input_dist_attr)
        # set op dist attr
        ctx.set_op_dist_attr_for_program(c_identity_op, identity_op_dist_attr)

        # matmul
        matmul_op_dist_attr = OperatorDistributedAttribute()
        matmul_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        matmul_op_dist_attr.impl_type = op_dist_attr.impl_type
        matmul_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        # input
        for input_varname in matmul_op.desc.input_arg_names():
            if input_varname in src_op.desc.input_arg_names():
                input_dist_attr = op_dist_attr.get_input_dist_attr(
                    input_varname)
                assert input_dist_attr is not None, "dist_attr is {}".format(
                    op_dist_attr)
                matmul_op_dist_attr.set_input_dist_attr(input_varname,
                                                        input_dist_attr)
            else:
                # The c_identity temporary is not an input of the serial op;
                # use the tensor dist attr set on it above.
                input_var = main_block.var(input_varname)
                tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(
                    input_var)
                matmul_op_dist_attr.set_input_dist_attr(input_varname,
                                                        tensor_dist_attr)
        # output
        output_varname = matmul_op.desc.output_arg_names()[0]
        output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
        assert output_dist_attr is not None, "dist_attr is {}".format(
            op_dist_attr)
        matmul_op_dist_attr.set_output_dist_attr(output_varname,
                                                 output_dist_attr)
        # set op dist attr
        ctx.set_op_dist_attr_for_program(matmul_op, matmul_op_dist_attr)

        # init param sync
        if Weight_var.is_parameter and not op_dist_attr.is_recompute:
            _init_param_sync(Weight_var, dist_op_context, startup_block, ctx,
                             rank_id)

    @staticmethod
    def backward(ctx, *args, **kwargs):
        """Delegate to the shared right-operand-parameter matmul backward."""
        _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)
703

704 705 706

# RowParallel
class DistributedMatmulImpl1(DistributedOperatorImpl):
    """Row-parallel distributed implementation of ``matmul`` (v1).

    The contraction dimension is sharded: ``X`` is split along its last
    dimension and ``Y`` along its second-to-last (row) dimension.  Each rank
    multiplies its local shards into a partial result.  A ``c_allreduce_sum``
    over the mesh axis that shards ``Y``'s rows then produces the full output.
    """

    def __init__(self, name):
        super(DistributedMatmulImpl1, self).__init__(name)
        self._forward_implemented = True
        self._backward_implemented = True

    def is_input_compatible(self, dist_op):
        """Return whether the input dims mappings suit row parallelism.

        Conditions:
        - ``X``'s last dimension and ``Y``'s second-to-last dimension must
          not be replicated.
        - ``Y``'s last dimension must not be sharded.
        - Every ``X`` dimension strictly between the first and the last must
          not be sharded.
        """
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        x_name = op_desc.input('X')[0]
        y_name = op_desc.input('Y')[0]
        x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
        y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
        if is_dim_replicate(x_dims_mapping[-1]):
            return False
        if is_dim_replicate(y_dims_mapping[-2]) or is_dim_shard(
                y_dims_mapping[-1]):
            return False
        # Other dimensions must be replicate except the batch dimension
        for mapping in x_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_output_compatible(self, dist_op):
        """Return whether the output dims mapping suits row parallelism.

        ``Out``'s last dimension, and every dimension strictly between its
        first and last, must not be sharded.
        """
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
        if is_dim_shard(out_dims_mapping[-1]):
            return False
        # Other dimensions must be replicate except the batch dimension
        for mapping in out_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_auto_compatible(self, dist_op):
        """Return True only if the input, output and shared matmul
        (``_is_auto_compatible_for_matmul``) checks all pass."""
        # ``and`` short-circuits in the same order as the individual checks.
        return bool(
            self.is_input_compatible(dist_op)
            and self.is_output_compatible(dist_op)
            and _is_auto_compatible_for_matmul(dist_op))

    def update_dims_mapping(self, dist_op):
        """Update dims mappings via ``_update_dims_mapping_for_matmul``.

        Returns True if that helper reported a change, otherwise False.
        """
        return bool(_update_dims_mapping_for_matmul(dist_op))

    @staticmethod
    def forward(ctx, *args, **kwargs):
        """Emit the row-parallel forward ops into the work block.

        Appends a ``matmul`` that writes this rank's partial product into a
        temporary variable, followed by a ``c_allreduce_sum`` that writes
        into ``Out``.  It then attaches dist attributes to both new ops.

        kwargs: inputname_mapping & outputname_mapping. Each of the serial
        op's input/output slot names (e.g. 'X', 'Y', 'Out') maps to a list
        of variable names.
        """

        dist_op_context = ctx.dist_op_context
        main_block = dist_op_context.work_block
        startup_block = dist_op_context.startup_block
        src_op = dist_op_context.cur_src_op
        rank_id = dist_op_context.rank_id
        op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
        assert op_dist_attr is not None, "forward op [{}] doesn't have dist attribute !".format(
            str(src_op))

        # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
        if rank_id not in op_dist_attr.process_mesh.processes:
            rank_id = _get_corresponding_rank(ctx, op_dist_attr.process_mesh,
                                              rank_id)

        # check validation of inputs / outputs
        for input_name in src_op.desc.input_names():
            assert input_name in kwargs, "input [{}] is not given".format(
                input_name)
            assert len(kwargs[input_name]) == len(
                src_op.desc.input(input_name)
            ), "number of tensor for input [{}] is not match".format(input_name)
        for output_name in src_op.desc.output_names():
            assert output_name in kwargs, "output [{}] is not given".format(
                output_name)
            assert len(kwargs[output_name]) == len(
                src_op.desc.output(output_name)
            ), "number of tensor for output [{}] is not match".format(
                output_name)

        X_var = main_block.var(kwargs['X'][0])
        # Same lookup as the matmul_v2 implementations. NOTE(review):
        # _var_recursive presumably also searches ancestor blocks, so the
        # weight is found when work_block is a sub-block -- confirm.
        Weight_var = main_block._var_recursive(kwargs['Y'][0])
        Out_var = main_block.var(kwargs['Out'][0])

        # TODO infer logic comm presentation
        matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
            Weight_var.name)[-2]
        assert matmul_row_dim_mapping >= 0, "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
            matmul_row_dim_mapping)
        process_mesh_shape = op_dist_attr.process_mesh.topology
        process_mesh_group = op_dist_attr.process_mesh.processes

        # The allreduce group is built from the mesh axis that shards Y's rows.
        parallel_axis = matmul_row_dim_mapping
        group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape,
                                      parallel_axis, rank_id)
        group = new_process_group(group_ranks)

        check_variable_and_dtype(X_var, 'x', ['float16', 'float32', 'float64'],
                                 'linear')
        check_dtype(X_var.dtype, 'dtype', ['float16', 'float32', 'float64'],
                    'linear')
        attrs = {
            'transpose_X': False,
            'transpose_Y': False,
            'alpha': 1,
            OP_ROLE_KEY: src_op.attr('op_role')
        }
        inputs = {'X': X_var, 'Y': Weight_var}

        # infer out var shape with op dist attr
        out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var)
        assert out_tensor_dist_attr is not None
        out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert out_var_dist_attr is not None
        ref_shape = infer_shape(main_block, Out_var, out_tensor_dist_attr,
                                out_var_dist_attr)

        # Holds this rank's partial matmul result until the allreduce writes
        # the summed value into Out_var.
        intermediate_var_0 = main_block.create_var(
            name=unique_name.generate_with_ignorable_key(".".join(
                ["c_allreduce_sum", 'tmp'])),
            shape=Out_var.shape,
            dtype=Out_var.dtype,
            type=Out_var.type,
            lod_level=Out_var.lod_level,
            persistable=False,
            is_data=False,
            need_check_feed=Out_var.desc.need_check_feed())
        # set intermediate_var_0's dist_attr with Out_var's dist_attr
        ctx.set_tensor_dist_attr_for_program(intermediate_var_0,
                                             out_var_dist_attr)

        matmul_op = main_block.append_op(type='matmul',
                                         inputs=inputs,
                                         outputs={'Out': intermediate_var_0},
                                         attrs=attrs)
        if intermediate_var_0.shape != ref_shape:
            intermediate_var_0.desc.set_shape(ref_shape)

        c_allreduce_sum_op = main_block.append_op(
            type='c_allreduce_sum',
            inputs={'X': intermediate_var_0},
            outputs={'Out': Out_var},
            attrs={
                'ring_id': group.id,
                'use_calc_stream': True,
                'use_model_parallel': True,
                OP_ROLE_KEY: src_op.attr('op_role')
            })
        if Out_var.shape != ref_shape:
            Out_var.desc.set_shape(ref_shape)

        # set dist op's dist_attr with serial op's dist_attr
        # matmul
        matmul_op_dist_attr = OperatorDistributedAttribute()
        matmul_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        matmul_op_dist_attr.impl_type = op_dist_attr.impl_type
        matmul_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        for input_varname in matmul_op.desc.input_arg_names():
            input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
            assert input_dist_attr is not None, "dist_attr is {}".format(
                op_dist_attr)
            matmul_op_dist_attr.set_input_dist_attr(input_varname,
                                                    input_dist_attr)
        output_varname = matmul_op.desc.output_arg_names()[0]
        output_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert output_dist_attr is not None, "dist_attr is {}".format(
            op_dist_attr)
        matmul_op_dist_attr.set_output_dist_attr(output_varname,
                                                 output_dist_attr)
        ctx.set_op_dist_attr_for_program(matmul_op, matmul_op_dist_attr)

        # allreduce
        allreduce_op_dist_attr = OperatorDistributedAttribute()
        allreduce_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        allreduce_op_dist_attr.impl_type = op_dist_attr.impl_type
        allreduce_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        for input_varname in c_allreduce_sum_op.desc.input_arg_names():
            input_var = main_block.var(input_varname)
            tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(input_var)
            assert tensor_dist_attr is not None
            allreduce_op_dist_attr.set_input_dist_attr(input_varname,
                                                       tensor_dist_attr)
        for output_varname in c_allreduce_sum_op.desc.output_arg_names():
            output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
            assert output_dist_attr is not None, "dist_attr is {}".format(
                op_dist_attr)
            allreduce_op_dist_attr.set_output_dist_attr(output_varname,
                                                        output_dist_attr)
        ctx.set_op_dist_attr_for_program(c_allreduce_sum_op,
                                         allreduce_op_dist_attr)

        # init param sync (only for parameters outside recompute).
        # NOTE(review): presumably sets up a startup-time sync of the weight
        # across ranks -- confirm against _init_param_sync.
        if Weight_var.is_parameter and not op_dist_attr.is_recompute:
            _init_param_sync(Weight_var, dist_op_context, startup_block, ctx,
                             rank_id)

    @staticmethod
    def backward(ctx, *args, **kwargs):
        """Delegate backward construction to the shared matmul helper."""
        _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)
915

916

917
# ReplicateParallel
918
class DistributedMatmulImpl2(DistributedOperatorImpl):
    """Replicate-parallel distributed implementation of ``matmul`` (v1).

    It applies only when ``X``, ``Y`` and ``Out`` are all unsharded on their
    last two dimensions.  Forward is delegated to
    ``DistributedDefaultImpl0.forward`` and backward to
    ``_right_operand_parameter_matmul_backward``.
    """

    def __init__(self, name):
        super().__init__(name)

    @staticmethod
    def _matrix_dims_sharded(dims_mapping):
        """Return truthy if the last dim of ``dims_mapping`` is sharded.

        The second-to-last dim is also checked when ``is_valid_list_index``
        accepts index -2.
        """
        if is_dim_shard(dims_mapping[-1]):
            return True
        return is_valid_list_index(dims_mapping, -2) and is_dim_shard(
            dims_mapping[-2])

    def is_input_compatible(self, dist_op):
        """Return True when neither ``X`` nor ``Y`` is sharded on its last
        two dimensions."""
        op_desc = dist_op.serial_op.desc
        dist_attr = dist_op.dist_attr
        operand_mappings = [
            dist_attr.get_input_dims_mapping(op_desc.input(slot)[0])
            for slot in ('X', 'Y')
        ]
        # X is checked before Y, and each check stops at the first hit.
        return not any(
            self._matrix_dims_sharded(mapping) for mapping in operand_mappings)

    def is_output_compatible(self, dist_op):
        """Return True when ``Out`` is not sharded on its last two dims."""
        out_name = dist_op.serial_op.desc.output('Out')[0]
        out_mapping = dist_op.dist_attr.get_output_dims_mapping(out_name)
        return not self._matrix_dims_sharded(out_mapping)

    def is_auto_compatible(self, dist_op):
        """Return True only if the input, output and shared matmul checks
        all pass."""
        return bool(
            self.is_input_compatible(dist_op)
            and self.is_output_compatible(dist_op)
            and _is_auto_compatible_for_matmul(dist_op))

    def update_dims_mapping(self, dist_op):
        """Return True if ``_update_dims_mapping_for_matmul`` changed any
        dims mapping, otherwise False."""
        return bool(_update_dims_mapping_for_matmul(dist_op))

    @staticmethod
    def forward(ctx, *args, **kwargs):
        """Reuse the default implementation's forward unchanged."""
        DistributedDefaultImpl0.forward(ctx, *args, **kwargs)

    @staticmethod
    def backward(ctx, *args, **kwargs):
        """Delegate backward construction to the shared matmul helper."""
        _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)

984 985 986 987 988 989 990 991 992

# Register the matmul (v1) implementations in this order.
# NOTE(review): registration order presumably determines each impl's
# impl_idx -- keep the order stable.
for _impl_cls, _impl_name in (
    (DistributedMatmulImpl0, "column_parallel"),
    (DistributedMatmulImpl1, "row_parallel"),
    (DistributedMatmulImpl2, "replicate_parallel"),
):
    register_distributed_operator_impl("matmul", _impl_cls(_impl_name))
del _impl_cls, _impl_name


993
class DistributedMatmulV2(DistributedOperatorImplContainer):
    """Implementation container for the ``matmul_v2`` operator.

    Adds no behavior beyond its base class; it only forwards ``op_type``.
    """

    def __init__(self, op_type):
        super().__init__(op_type)
997 998


999
# Register the container that the matmul_v2 implementations below attach to.
register_distributed_operator_impl_container(
    DistributedMatmulV2(op_type="matmul_v2"))
1000 1001


1002 1003
# ColumnParallel
class DistributedMatmulV2Impl0(DistributedOperatorImpl):
    """Column-parallel distributed implementation of ``matmul_v2``.

    ``Y`` is sharded along its last (column) dimension, while ``X``'s last
    dimension stays unsharded.  Each rank therefore produces its own column
    slice of ``Out``.  The forward routes ``X`` through a ``c_identity`` op
    before the local ``matmul_v2``.
    """

    def __init__(self, name):
        super(DistributedMatmulV2Impl0, self).__init__(name)
        self._forward_implemented = True
        self._backward_implemented = True

    def is_input_compatible(self, dist_op):
        """Return whether the input dims mappings suit column parallelism.

        Conditions:
        - ``X``'s last dimension and ``Y``'s second-to-last dimension must
          not be sharded.
        - ``Y``'s last dimension must not be replicated.
        - Every ``X`` dimension strictly between the first and the last must
          not be sharded.
        """
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        x_name = op_desc.input('X')[0]
        y_name = op_desc.input('Y')[0]
        x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
        y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
        if is_dim_shard(x_dims_mapping[-1]):
            return False
        if is_dim_shard(y_dims_mapping[-2]) or is_dim_replicate(
                y_dims_mapping[-1]):
            return False
        for mapping in x_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_output_compatible(self, dist_op):
        """Return whether the output dims mapping suits column parallelism.

        ``Out``'s last dimension must not be replicated, and every dimension
        strictly between its first and last must not be sharded.
        """
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
        if is_dim_replicate(out_dims_mapping[-1]):
            return False
        for mapping in out_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_auto_compatible(self, dist_op):
        """Return True only if the input, output and shared matmul
        (``_is_auto_compatible_for_matmul``) checks all pass."""
        # ``and`` short-circuits in the same order as the individual checks.
        return bool(
            self.is_input_compatible(dist_op)
            and self.is_output_compatible(dist_op)
            and _is_auto_compatible_for_matmul(dist_op))

    def update_dims_mapping(self, dist_op):
        """Update dims mappings via ``_update_dims_mapping_for_matmul``.

        Returns True if that helper reported a change, otherwise False.
        """
        return bool(_update_dims_mapping_for_matmul(dist_op))

    @staticmethod
    def forward(ctx, *args, **kwargs):
        """Emit the column-parallel forward ops into the work block.

        Appends a ``c_identity`` (``X`` -> temporary variable) and a
        ``matmul_v2`` (temporary, ``Y`` -> ``Out``).  It then attaches dist
        attributes to both new ops.

        kwargs: inputname_mapping & outputname_mapping. Each of the serial
        op's input/output slot names (e.g. 'X', 'Y', 'Out') maps to a list
        of variable names.
        """

        dist_op_context = ctx.dist_op_context
        main_block = dist_op_context.work_block
        startup_block = dist_op_context.startup_block
        src_op = dist_op_context.cur_src_op
        rank_id = dist_op_context.rank_id
        op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
        assert op_dist_attr is not None, "forward op [{}] doesn't have dist attribute !".format(
            str(src_op))

        # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
        if rank_id not in op_dist_attr.process_mesh.processes:
            rank_id = _get_corresponding_rank(ctx, op_dist_attr.process_mesh,
                                              rank_id)

        # check validation of inputs / outputs
        for input_name in src_op.desc.input_names():
            assert input_name in kwargs, "input [{}] is not given".format(
                input_name)
            assert len(kwargs[input_name]) == len(
                src_op.desc.input(input_name)
            ), "number of tensor for input [{}] is not match".format(input_name)
        for output_name in src_op.desc.output_names():
            assert output_name in kwargs, "output [{}] is not given".format(
                output_name)
            assert len(kwargs[output_name]) == len(
                src_op.desc.output(output_name)
            ), "number of tensor for output [{}] is not match".format(
                output_name)

        X_var = main_block.var(kwargs['X'][0])
        Weight_var = main_block._var_recursive(kwargs['Y'][0])
        Out_var = main_block.var(kwargs['Out'][0])

        # TODO infer logic comm presentation
        matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping(
            Weight_var.name)[-1]
        assert matmul_col_dim_mapping >= 0, "col_parallel_matmul's col should be divided by a specific mesh axis, but got [{}]".format(
            matmul_col_dim_mapping)
        process_mesh_shape = op_dist_attr.process_mesh.topology
        process_mesh_group = op_dist_attr.process_mesh.processes

        # The communication group is built from the mesh axis that shards
        # Y's columns.
        parallel_axis = matmul_col_dim_mapping
        group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape,
                                      parallel_axis, rank_id)
        group = new_process_group(group_ranks)

        # infer new var shape with op dist attr
        x_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(X_var)
        assert x_tensor_dist_attr is not None
        identity_var_dist_attr = op_dist_attr.get_input_dist_attr(X_var.name)
        assert identity_var_dist_attr is not None
        ref_shape_x = infer_shape(main_block, X_var, x_tensor_dist_attr,
                                  identity_var_dist_attr)
        # infer out var shape with op dist attr
        out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var)
        assert out_tensor_dist_attr is not None
        out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert out_var_dist_attr is not None
        ref_shape_out = infer_shape(main_block, Out_var, out_tensor_dist_attr,
                                    out_var_dist_attr)

        intermediate_var_0 = main_block.create_var(
            name=unique_name.generate_with_ignorable_key(".".join(
                ["c_identity", 'tmp'])),
            dtype=X_var.dtype,
            shape=X_var.shape,
            type=core.VarDesc.VarType.LOD_TENSOR,
            persistable=False,
            stop_gradient=X_var.stop_gradient)
        # set intermediate_var_0's dist_attr with X_var's dist_attr
        ctx.set_tensor_dist_attr_for_program(intermediate_var_0,
                                             identity_var_dist_attr)

        # Route X through c_identity on the column-parallel group.
        # NOTE(review): presumably identity in forward and an allreduce of
        # X's gradient in backward -- confirm.
        check_variable_and_dtype(
            X_var, 'tensor',
            ['float16', 'float32', 'float64', 'int32', 'int64'], '_c_identity')
        c_identity_op = main_block.append_op(
            type='c_identity',
            inputs={'X': [X_var]},
            outputs={'Out': intermediate_var_0},
            attrs={
                'ring_id': group.id,
                'use_calc_stream': True,
                'use_model_parallel': True,
                OP_ROLE_KEY: src_op.attr('op_role'),
            })
        if intermediate_var_0.shape != ref_shape_x:
            intermediate_var_0.desc.set_shape(ref_shape_x)

        check_variable_and_dtype(intermediate_var_0, 'x',
                                 ['float16', 'float32', 'float64'], 'linear')
        check_dtype(intermediate_var_0.dtype, 'dtype',
                    ['float16', 'float32', 'float64'], 'linear')
        attrs = {
            'trans_x': False,
            'trans_y': False,
            OP_ROLE_KEY: src_op.attr('op_role')
        }
        inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]}
        matmul_v2_op = main_block.append_op(type='matmul_v2',
                                            inputs=inputs,
                                            outputs={'Out': Out_var},
                                            attrs=attrs)
        if Out_var.shape != ref_shape_out:
            Out_var.desc.set_shape(ref_shape_out)

        # set dist op's dist_attr with serial op's dist_attr
        # c_identity
        identity_op_dist_attr = OperatorDistributedAttribute()
        identity_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        identity_op_dist_attr.impl_type = op_dist_attr.impl_type
        identity_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        # input
        input_varname = c_identity_op.desc.input_arg_names()[0]
        input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
        assert input_dist_attr is not None, "dist_attr is {}".format(
            op_dist_attr)
        identity_op_dist_attr.set_input_dist_attr(input_varname,
                                                  input_dist_attr)
        # output: c_identity keeps X's dist_attr
        output_varname = c_identity_op.desc.output_arg_names()[0]
        identity_op_dist_attr.set_output_dist_attr(output_varname,
                                                   input_dist_attr)
        ctx.set_op_dist_attr_for_program(c_identity_op, identity_op_dist_attr)

        # matmulv2
        matmulv2_op_dist_attr = OperatorDistributedAttribute()
        matmulv2_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        matmulv2_op_dist_attr.impl_type = op_dist_attr.impl_type
        matmulv2_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        for input_varname in matmul_v2_op.desc.input_arg_names():
            if input_varname in src_op.desc.input_arg_names():
                input_dist_attr = op_dist_attr.get_input_dist_attr(
                    input_varname)
                assert input_dist_attr is not None, "dist_attr is {}".format(
                    op_dist_attr)
                matmulv2_op_dist_attr.set_input_dist_attr(
                    input_varname, input_dist_attr)
            else:
                # The c_identity output is not an input of the serial op;
                # use the tensor dist_attr registered for it above.
                input_var = main_block.var(input_varname)
                tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(
                    input_var)
                matmulv2_op_dist_attr.set_input_dist_attr(
                    input_varname, tensor_dist_attr)
        for output_varname in matmul_v2_op.desc.output_arg_names():
            output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
            assert output_dist_attr is not None, "dist_attr is {}".format(
                op_dist_attr)
            matmulv2_op_dist_attr.set_output_dist_attr(output_varname,
                                                       output_dist_attr)
        ctx.set_op_dist_attr_for_program(matmul_v2_op, matmulv2_op_dist_attr)

        # init param sync (only for parameters outside recompute).
        # NOTE(review): presumably sets up a startup-time sync of the weight
        # across ranks -- confirm against _init_param_sync.
        if Weight_var.is_parameter and not op_dist_attr.is_recompute:
            _init_param_sync(Weight_var, dist_op_context, startup_block, ctx,
                             rank_id)

    @staticmethod
    def backward(ctx, *args, **kwargs):
        """Delegate backward construction to the shared matmul helper."""
        _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)
1222 1223 1224 1225


# RowParallel
class DistributedMatmulV2Impl1(DistributedOperatorImpl):
    """Row-parallel distributed implementation of ``matmul_v2``.

    The contraction dimension is sharded: ``X`` is split along its last
    dimension and ``Y`` along its second-to-last (row) dimension.  Each rank
    multiplies its local shards into a partial result.  A ``c_allreduce_sum``
    over the mesh axis that shards ``Y``'s rows then produces the full output.
    """

    def __init__(self, name):
        super(DistributedMatmulV2Impl1, self).__init__(name)
        self._forward_implemented = True
        self._backward_implemented = True

    def is_input_compatible(self, dist_op):
        """Return whether the input dims mappings suit row parallelism.

        Conditions:
        - ``X``'s last dimension and ``Y``'s second-to-last dimension must
          not be replicated.
        - ``Y``'s last dimension must not be sharded.
        - Every ``X`` dimension strictly between the first and the last must
          not be sharded.
        """
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        x_name = op_desc.input('X')[0]
        y_name = op_desc.input('Y')[0]
        x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
        y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
        if is_dim_replicate(x_dims_mapping[-1]):
            return False
        if is_dim_replicate(y_dims_mapping[-2]) or is_dim_shard(
                y_dims_mapping[-1]):
            return False
        # Other dimensions must be replicate except the batch dimension
        for mapping in x_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_output_compatible(self, dist_op):
        """Return whether the output dims mapping suits row parallelism.

        ``Out``'s last dimension, and every dimension strictly between its
        first and last, must not be sharded.
        """
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
        if is_dim_shard(out_dims_mapping[-1]):
            return False
        # Other dimensions must be replicate except the batch dimension
        for mapping in out_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_auto_compatible(self, dist_op):
        """Return True only if the input, output and shared matmul
        (``_is_auto_compatible_for_matmul``) checks all pass."""
        # ``and`` short-circuits in the same order as the individual checks.
        return bool(
            self.is_input_compatible(dist_op)
            and self.is_output_compatible(dist_op)
            and _is_auto_compatible_for_matmul(dist_op))

    def update_dims_mapping(self, dist_op):
        """Update dims mappings via ``_update_dims_mapping_for_matmul``.

        Returns True if that helper reported a change, otherwise False.
        """
        return bool(_update_dims_mapping_for_matmul(dist_op))

    @staticmethod
    def forward(ctx, *args, **kwargs):
        """Emit the row-parallel forward ops into the work block.

        Appends a ``matmul_v2`` that writes this rank's partial product into
        a temporary variable, followed by a ``c_allreduce_sum`` that writes
        into ``Out``.  It then attaches dist attributes to both new ops.

        kwargs: inputname_mapping & outputname_mapping. Each of the serial
        op's input/output slot names (e.g. 'X', 'Y', 'Out') maps to a list
        of variable names.
        """

        dist_op_context = ctx.dist_op_context
        main_block = dist_op_context.work_block
        startup_block = dist_op_context.startup_block
        src_op = dist_op_context.cur_src_op
        rank_id = dist_op_context.rank_id
        op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
        assert op_dist_attr is not None, "forward op [{}] doesn't have dist attribute !".format(
            str(src_op))

        # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
        if rank_id not in op_dist_attr.process_mesh.processes:
            rank_id = _get_corresponding_rank(ctx, op_dist_attr.process_mesh,
                                              rank_id)

        # check validation of inputs / outputs
        for input_name in src_op.desc.input_names():
            assert input_name in kwargs, "input [{}] is not given".format(
                input_name)
            assert len(kwargs[input_name]) == len(
                src_op.desc.input(input_name)
            ), "number of tensor for input [{}] is not match".format(input_name)
        for output_name in src_op.desc.output_names():
            assert output_name in kwargs, "output [{}] is not given".format(
                output_name)
            assert len(kwargs[output_name]) == len(
                src_op.desc.output(output_name)
            ), "number of tensor for output [{}] is not match".format(
                output_name)

        X_var = main_block.var(kwargs['X'][0])
        Weight_var = main_block._var_recursive(kwargs['Y'][0])
        Out_var = main_block.var(kwargs['Out'][0])

        # TODO infer logic comm presentation
        matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
            Weight_var.name)[-2]
        assert matmul_row_dim_mapping >= 0, "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
            matmul_row_dim_mapping)
        process_mesh_shape = op_dist_attr.process_mesh.topology
        process_mesh_group = op_dist_attr.process_mesh.processes

        # The allreduce group is built from the mesh axis that shards Y's rows.
        parallel_axis = matmul_row_dim_mapping
        group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape,
                                      parallel_axis, rank_id)
        group = new_process_group(group_ranks)

        check_variable_and_dtype(X_var, 'x', ['float16', 'float32', 'float64'],
                                 'linear')
        check_dtype(X_var.dtype, 'dtype', ['float16', 'float32', 'float64'],
                    'linear')
        attrs = {
            'trans_x': False,
            'trans_y': False,
            OP_ROLE_KEY: src_op.attr('op_role')
        }
        inputs = {'X': X_var, 'Y': Weight_var}

        # infer out var shape with op dist attr
        out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var)
        assert out_tensor_dist_attr is not None
        out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert out_var_dist_attr is not None
        ref_shape = infer_shape(main_block, Out_var, out_tensor_dist_attr,
                                out_var_dist_attr)

        # Holds this rank's partial matmul result until the allreduce writes
        # the summed value into Out_var.
        intermediate_var_0 = main_block.create_var(
            name=unique_name.generate_with_ignorable_key(".".join(
                ["c_allreduce_sum", 'tmp'])),
            shape=Out_var.shape,
            dtype=Out_var.dtype,
            type=Out_var.type,
            lod_level=Out_var.lod_level,
            persistable=False,
            is_data=False,
            need_check_feed=Out_var.desc.need_check_feed())
        # set intermediate_var_0's dist_attr with Out_var's dist_attr
        ctx.set_tensor_dist_attr_for_program(intermediate_var_0,
                                             out_var_dist_attr)

        matmul_v2_op = main_block.append_op(type='matmul_v2',
                                            inputs=inputs,
                                            outputs={'Out': intermediate_var_0},
                                            attrs=attrs)
        if intermediate_var_0.shape != ref_shape:
            intermediate_var_0.desc.set_shape(ref_shape)

        c_allreduce_sum_op = main_block.append_op(
            type='c_allreduce_sum',
            inputs={'X': intermediate_var_0},
            outputs={'Out': Out_var},
            attrs={
                'ring_id': group.id,
                'use_calc_stream': True,
                'use_model_parallel': True,
                OP_ROLE_KEY: src_op.attr('op_role')
            })
        if Out_var.shape != ref_shape:
            Out_var.desc.set_shape(ref_shape)

        # set dist op's dist_attr with serial op's dist_attr
        # matmulv2
        matmulv2_op_dist_attr = OperatorDistributedAttribute()
        matmulv2_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        matmulv2_op_dist_attr.impl_type = op_dist_attr.impl_type
        matmulv2_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        for input_varname in matmul_v2_op.desc.input_arg_names():
            input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
            assert input_dist_attr is not None, "dist_attr is {}".format(
                op_dist_attr)
            matmulv2_op_dist_attr.set_input_dist_attr(input_varname,
                                                      input_dist_attr)
        output_varname = matmul_v2_op.desc.output_arg_names()[0]
        output_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert output_dist_attr is not None, "dist_attr is {}".format(
            op_dist_attr)
        matmulv2_op_dist_attr.set_output_dist_attr(output_varname,
                                                   output_dist_attr)
        ctx.set_op_dist_attr_for_program(matmul_v2_op, matmulv2_op_dist_attr)

        # allreduce
        allreduce_op_dist_attr = OperatorDistributedAttribute()
        allreduce_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        allreduce_op_dist_attr.impl_type = op_dist_attr.impl_type
        allreduce_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        for input_varname in c_allreduce_sum_op.desc.input_arg_names():
            input_var = main_block.var(input_varname)
            tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(input_var)
            assert tensor_dist_attr is not None
            allreduce_op_dist_attr.set_input_dist_attr(input_varname,
                                                       tensor_dist_attr)
        for output_varname in c_allreduce_sum_op.desc.output_arg_names():
            output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
            assert output_dist_attr is not None, "dist_attr is {}".format(
                op_dist_attr)
            allreduce_op_dist_attr.set_output_dist_attr(output_varname,
                                                        output_dist_attr)
        ctx.set_op_dist_attr_for_program(c_allreduce_sum_op,
                                         allreduce_op_dist_attr)

        # init param sync (only for parameters outside recompute).
        # NOTE(review): presumably sets up a startup-time sync of the weight
        # across ranks -- confirm against _init_param_sync.
        if Weight_var.is_parameter and not op_dist_attr.is_recompute:
            _init_param_sync(Weight_var, dist_op_context, startup_block, ctx,
                             rank_id)

    @staticmethod
    def backward(ctx, *args, **kwargs):
        """Delegate backward construction to the shared matmul helper."""
        _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)
1433 1434


# ReplicateParallel
class DistributedMatmulV2Impl2(DistributedOperatorImpl):
    """Replicate-parallel implementation of ``matmul_v2``.

    Compatible only when neither of the last two dimensions of X, Y and Out
    is sharded across the process mesh. The forward pass is delegated to
    ``DistributedDefaultImpl0.forward``.
    """

    def __init__(self, name):
        super(DistributedMatmulV2Impl2, self).__init__(name)

    def is_input_compatible(self, dist_op):
        """Return True if neither of the last two dims of X or Y is sharded."""
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        x_name = op_desc.input('X')[0]
        y_name = op_desc.input('Y')[0]
        x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
        y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)

        if is_dim_shard(x_dims_mapping[-1]):
            return False
        # The -2 dim only exists for inputs of rank >= 2.
        if is_valid_list_index(x_dims_mapping, -2) and is_dim_shard(
                x_dims_mapping[-2]):
            return False

        if is_dim_shard(y_dims_mapping[-1]):
            return False
        if is_valid_list_index(y_dims_mapping, -2) and is_dim_shard(
                y_dims_mapping[-2]):
            return False
        return True

    def is_output_compatible(self, dist_op):
        """Return True if neither of the last two dims of Out is sharded."""
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)

        if is_dim_shard(out_dims_mapping[-1]):
            return False
        if is_valid_list_index(out_dims_mapping, -2) and is_dim_shard(
                out_dims_mapping[-2]):
            return False

        return True

    def is_auto_compatible(self, dist_op):
        """Require input/output compatibility and the shared matmul check."""
        if (not self.is_input_compatible(dist_op)) or \
            (not self.is_output_compatible(dist_op)):
            return False

        if not _is_auto_compatible_for_matmul(dist_op):
            return False

        return True

    def update_dims_mapping(self, dist_op):
        """Apply the shared matmul dims-mapping rule; return True if changed."""
        return bool(_update_dims_mapping_for_matmul(dist_op))

    @staticmethod
    def forward(ctx, *args, **kwargs):
        """Delegate forward construction to ``DistributedDefaultImpl0``."""
        DistributedDefaultImpl0.forward(ctx, *args, **kwargs)

    @staticmethod
    def backward(ctx, *args, **kwargs):
        """Delegate backward construction to the shared matmul helper."""
        _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)

# Register the column / row / replicate parallel matmul_v2 implementations.
register_distributed_operator_impl("matmul_v2",
                                   DistributedMatmulV2Impl0("column_parallel"))
register_distributed_operator_impl("matmul_v2",
                                   DistributedMatmulV2Impl1("row_parallel"))
register_distributed_operator_impl(
    "matmul_v2", DistributedMatmulV2Impl2("replicate_parallel"))


class DistributedMul(DistributedOperatorImplContainer):
    """Container holding the distributed implementations of the ``mul`` op."""

    def __init__(self, op_type):
        super(DistributedMul, self).__init__(op_type)


register_distributed_operator_impl_container(DistributedMul("mul"))


# ColumnParallel
class DistributedMulImpl0(DistributedOperatorImpl):
    """Column-parallel implementation of the ``mul`` op.

    The weight Y must be sharded on its last (column) dim and replicated on
    its second-to-last dim. X must not be sharded on its last dim, and Out
    must be sharded on its last dim. The forward pass emits a ``c_identity``
    on X (over the comm group of the weight's column mesh axis) followed by a
    local ``mul``.
    """

    def __init__(self, name):
        super(DistributedMulImpl0, self).__init__(name)
        self._forward_implemented = True
        self._backward_implemented = True

    def is_input_compatible(self, dist_op):
        """Check that the X / Y dims mappings allow column parallelism."""
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        x_name = op_desc.input('X')[0]
        y_name = op_desc.input('Y')[0]
        x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
        y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
        if is_dim_shard(x_dims_mapping[-1]):
            return False
        if is_dim_shard(y_dims_mapping[-2]) or is_dim_replicate(
                y_dims_mapping[-1]):
            return False
        # Dims of X between the first and the last must be replicated.
        for mapping in x_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_output_compatible(self, dist_op):
        """Out must be sharded on its last dim; middle dims replicated."""
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
        if is_dim_replicate(out_dims_mapping[-1]):
            return False
        for mapping in out_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_auto_compatible(self, dist_op):
        """Require input/output compatibility and the shared matmul check."""
        if (not self.is_input_compatible(dist_op)) or \
            (not self.is_output_compatible(dist_op)):
            return False

        if not _is_auto_compatible_for_matmul(dist_op):
            return False

        return True

    def update_dims_mapping(self, dist_op):
        """Apply the shared matmul dims-mapping rule; return True if changed."""
        return bool(_update_dims_mapping_for_matmul(dist_op))

    @staticmethod
    def forward(ctx, *args, **kwargs):
        """
        kwargs: inputname_mapping & outputname_mapping

        Appends to the current work block:
        1. ``c_identity`` on X over the weight's column-axis comm group.
        2. A local ``mul`` of its result with the weight, writing Out.

        It then attaches dist attrs to both new ops. When the weight is a
        parameter and the op is not a recompute op, it also calls
        ``_init_param_sync``.
        """

        dist_op_context = ctx.dist_op_context
        main_block = dist_op_context.work_block
        startup_block = dist_op_context.startup_block
        src_op = dist_op_context.cur_src_op
        rank_id = dist_op_context.rank_id
        op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
        assert op_dist_attr is not None, "forward op [{}] doesn't have dist attribute !".format(
            str(src_op))

        # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
        if rank_id not in op_dist_attr.process_mesh.processes:
            rank_id = _get_corresponding_rank(ctx, op_dist_attr.process_mesh,
                                              rank_id)

        # check validation of inputs / outputs
        for input_name in src_op.desc.input_names():
            assert input_name in kwargs, "input [{}] is not given".format(
                input_name)
            assert len(kwargs[input_name]) == len(
                src_op.desc.input(input_name)
            ), "number of tensor for input [{}] is not match".format(input_name)
        for output_name in src_op.desc.output_names():
            assert output_name in kwargs, "output [{}] is not given".format(
                output_name)
            assert len(kwargs[output_name]) == len(
                src_op.desc.output(output_name)
            ), "number of tensor for output [{}] is not match".format(
                output_name)

        X_var = main_block.var(kwargs['X'][0])
        Weight_var = main_block._var_recursive(kwargs['Y'][0])
        Out_var = main_block.var(kwargs['Out'][0])

        # TODO infer logic comm presentation
        matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping(
            Weight_var.name)[-1]
        assert matmul_col_dim_mapping >= 0, "col_parallel_matmul's col should be divided by a specific mesh axis, but got [{}]".format(
            matmul_col_dim_mapping)
        process_mesh_shape = op_dist_attr.process_mesh.topology
        process_mesh_group = op_dist_attr.process_mesh.processes

        # Communicate among the ranks that share this rank's position on
        # every mesh axis except the one the weight's column is sharded on.
        parallel_axis = matmul_col_dim_mapping
        group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape,
                                      parallel_axis, rank_id)
        group = new_process_group(group_ranks)

        # infer new var shape with op dist attr
        x_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(X_var)
        assert x_tensor_dist_attr is not None
        identity_var_dist_attr = op_dist_attr.get_input_dist_attr(X_var.name)
        assert identity_var_dist_attr is not None
        ref_shape_x = infer_shape(main_block, X_var, x_tensor_dist_attr,
                                  identity_var_dist_attr)
        # infer out var shape with op dist attr
        out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var)
        assert out_tensor_dist_attr is not None
        out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert out_var_dist_attr is not None
        ref_shape_out = infer_shape(main_block, Out_var, out_tensor_dist_attr,
                                    out_var_dist_attr)

        intermediate_var_0 = main_block.create_var(
            name=unique_name.generate_with_ignorable_key(".".join(
                ["c_identity", 'tmp'])),
            dtype=X_var.dtype,
            shape=X_var.shape,
            type=core.VarDesc.VarType.LOD_TENSOR,
            persistable=False,
            stop_gradient=X_var.stop_gradient)
        # set intermediate_var_0's dist_attr with X_var's dist_attr
        ctx.set_tensor_dist_attr_for_program(intermediate_var_0,
                                             identity_var_dist_attr)

        check_variable_and_dtype(
            X_var, 'tensor',
            ['float16', 'float32', 'float64', 'int32', 'int64'], '_c_identity')
        c_identity_op = main_block.append_op(
            type='c_identity',
            inputs={'X': [X_var]},
            outputs={'Out': intermediate_var_0},
            attrs={
                'ring_id': group.id,
                'use_calc_stream': True,
                'use_model_parallel': True,
                OP_ROLE_KEY: src_op.attr('op_role')
            })
        if intermediate_var_0.shape != ref_shape_x:
            intermediate_var_0.desc.set_shape(ref_shape_x)

        check_variable_and_dtype(intermediate_var_0, 'x',
                                 ['float16', 'float32', 'float64'], 'linear')
        check_dtype(intermediate_var_0.dtype, 'dtype',
                    ['float16', 'float32', 'float64'], 'linear')
        attrs = {
            "x_num_col_dims": src_op.desc.attr("x_num_col_dims"),
            "y_num_col_dims": src_op.desc.attr("y_num_col_dims"),
            OP_ROLE_KEY: src_op.attr('op_role')
        }
        inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]}
        mul_op = main_block.append_op(type='mul',
                                      inputs=inputs,
                                      outputs={'Out': Out_var},
                                      attrs=attrs)
        if Out_var.shape != ref_shape_out:
            Out_var.desc.set_shape(ref_shape_out)

        # set dist op's dist_attr with serial op's dist_attr
        # c_identity
        identity_op_dist_attr = OperatorDistributedAttribute()
        identity_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        identity_op_dist_attr.impl_type = op_dist_attr.impl_type
        identity_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        # input
        input_varname = c_identity_op.desc.input_arg_names()[0]
        input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
        assert input_dist_attr is not None, "dist_attr is {}".format(
            op_dist_attr)
        identity_op_dist_attr.set_input_dist_attr(input_varname,
                                                  input_dist_attr)
        # output: same layout as the input, since c_identity copies it
        output_varname = c_identity_op.desc.output_arg_names()[0]
        identity_op_dist_attr.set_output_dist_attr(output_varname,
                                                   input_dist_attr)
        ctx.set_op_dist_attr_for_program(c_identity_op, identity_op_dist_attr)

        # mul
        mul_op_dist_attr = OperatorDistributedAttribute()
        mul_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        mul_op_dist_attr.impl_type = op_dist_attr.impl_type
        mul_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        for input_varname in mul_op.desc.input_arg_names():
            if input_varname in src_op.desc.input_arg_names():
                input_dist_attr = op_dist_attr.get_input_dist_attr(
                    input_varname)
                assert input_dist_attr is not None, "dist_attr is {}".format(
                    op_dist_attr)
                mul_op_dist_attr.set_input_dist_attr(input_varname,
                                                     input_dist_attr)
            else:
                # The c_identity output is not an input of the serial op;
                # take its dist attr from the tensor registered above.
                input_var = main_block.var(input_varname)
                tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(
                    input_var)
                assert tensor_dist_attr is not None
                mul_op_dist_attr.set_input_dist_attr(input_varname,
                                                     tensor_dist_attr)
        for output_varname in mul_op.desc.output_arg_names():
            output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
            assert output_dist_attr is not None, "dist_attr is {}".format(
                op_dist_attr)
            mul_op_dist_attr.set_output_dist_attr(output_varname,
                                                  output_dist_attr)
        ctx.set_op_dist_attr_for_program(mul_op, mul_op_dist_attr)

        # init param sync
        if Weight_var.is_parameter and not op_dist_attr.is_recompute:
            _init_param_sync(Weight_var, dist_op_context, startup_block, ctx,
                             rank_id)

    @staticmethod
    def backward(ctx, *args, **kwargs):
        """Delegate backward construction to the shared matmul helper."""
        _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)


# RowParallel
class DistributedMulImpl1(DistributedOperatorImpl):
    """Row-parallel implementation of the ``mul`` op.

    The weight Y must be sharded on its second-to-last (row) dim and
    replicated on its last dim. X must be sharded on its last dim, and Out
    must not be sharded on its last dim. The forward pass computes a local
    ``mul`` into a temporary var, then a ``c_allreduce_sum`` (over the comm
    group of the weight's row mesh axis) writes Out.
    """

    def __init__(self, name):
        super(DistributedMulImpl1, self).__init__(name)
        self._forward_implemented = True
        self._backward_implemented = True

    def is_input_compatible(self, dist_op):
        """Check that the X / Y dims mappings allow row parallelism."""
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        x_name = op_desc.input('X')[0]
        y_name = op_desc.input('Y')[0]
        x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
        y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
        if is_dim_replicate(x_dims_mapping[-1]):
            return False
        if is_dim_replicate(y_dims_mapping[-2]) or is_dim_shard(
                y_dims_mapping[-1]):
            return False
        # Other dimensions must be replicate except the batch dimension
        for mapping in x_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_output_compatible(self, dist_op):
        """Out must not be sharded on its last dim; middle dims replicated."""
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
        if is_dim_shard(out_dims_mapping[-1]):
            return False
        # Other dimensions must be replicate except the batch dimension
        for mapping in out_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_auto_compatible(self, dist_op):
        """Require input/output compatibility and the shared matmul check."""
        if (not self.is_input_compatible(dist_op)) or \
            (not self.is_output_compatible(dist_op)):
            return False

        if not _is_auto_compatible_for_matmul(dist_op):
            return False

        return True

    def update_dims_mapping(self, dist_op):
        """Apply the shared matmul dims-mapping rule; return True if changed."""
        return bool(_update_dims_mapping_for_matmul(dist_op))

    @staticmethod
    def forward(ctx, *args, **kwargs):
        """
        kwargs: inputname_mapping & outputname_mapping

        Appends to the current work block:
        1. A local ``mul`` of X and the weight into a temporary var.
        2. ``c_allreduce_sum`` of that var over the weight's row-axis comm
           group, writing Out.

        It then attaches dist attrs to both new ops. When the weight is a
        parameter and the op is not a recompute op, it also calls
        ``_init_param_sync``.
        """

        dist_op_context = ctx.dist_op_context
        main_block = dist_op_context.work_block
        startup_block = dist_op_context.startup_block
        src_op = dist_op_context.cur_src_op
        rank_id = dist_op_context.rank_id
        op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
        assert op_dist_attr is not None, "forward op [{}] doesn't have dist attribute !".format(
            str(src_op))

        # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
        if rank_id not in op_dist_attr.process_mesh.processes:
            rank_id = _get_corresponding_rank(ctx, op_dist_attr.process_mesh,
                                              rank_id)

        # check validation of inputs / outputs
        for input_name in src_op.desc.input_names():
            assert input_name in kwargs, "input [{}] is not given".format(
                input_name)
            assert len(kwargs[input_name]) == len(
                src_op.desc.input(input_name)
            ), "number of tensor for input [{}] is not match".format(input_name)
        for output_name in src_op.desc.output_names():
            assert output_name in kwargs, "output [{}] is not given".format(
                output_name)
            assert len(kwargs[output_name]) == len(
                src_op.desc.output(output_name)
            ), "number of tensor for output [{}] is not match".format(
                output_name)

        X_var = main_block.var(kwargs['X'][0])
        Weight_var = main_block._var_recursive(kwargs['Y'][0])
        Out_var = main_block.var(kwargs['Out'][0])

        # TODO infer logic comm presentation
        matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
            Weight_var.name)[-2]
        assert matmul_row_dim_mapping >= 0, "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
            matmul_row_dim_mapping)
        process_mesh_shape = op_dist_attr.process_mesh.topology
        process_mesh_group = op_dist_attr.process_mesh.processes

        # Communicate among the ranks that share this rank's position on
        # every mesh axis except the one the weight's row is sharded on.
        parallel_axis = matmul_row_dim_mapping
        group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape,
                                      parallel_axis, rank_id)
        group = new_process_group(group_ranks)

        check_variable_and_dtype(X_var, 'x', ['float16', 'float32', 'float64'],
                                 'linear')
        check_dtype(X_var.dtype, 'dtype', ['float16', 'float32', 'float64'],
                    'linear')
        attrs = {
            "x_num_col_dims": src_op.desc.attr("x_num_col_dims"),
            "y_num_col_dims": src_op.desc.attr("y_num_col_dims"),
            OP_ROLE_KEY: src_op.attr('op_role')
        }
        inputs = {'X': X_var, 'Y': Weight_var}

        # infer out var shape with op dist attr
        out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var)
        assert out_tensor_dist_attr is not None
        out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert out_var_dist_attr is not None
        ref_shape = infer_shape(main_block, Out_var, out_tensor_dist_attr,
                                out_var_dist_attr)

        # Holds this rank's partial product before the all-reduce.
        intermediate_var_0 = main_block.create_var(
            name=unique_name.generate_with_ignorable_key(".".join(
                ["c_allreduce_sum", 'tmp'])),
            shape=Out_var.shape,
            dtype=Out_var.dtype,
            type=Out_var.type,
            lod_level=Out_var.lod_level,
            persistable=False,
            is_data=False,
            need_check_feed=Out_var.desc.need_check_feed())
        # set intermediate_var_0's dist_attr with Out_var's dist_attr
        ctx.set_tensor_dist_attr_for_program(intermediate_var_0,
                                             out_var_dist_attr)

        mul_op = main_block.append_op(type='mul',
                                      inputs=inputs,
                                      outputs={'Out': intermediate_var_0},
                                      attrs=attrs)
        if intermediate_var_0.shape != ref_shape:
            intermediate_var_0.desc.set_shape(ref_shape)

        c_allreduce_sum_op = main_block.append_op(
            type='c_allreduce_sum',
            inputs={'X': intermediate_var_0},
            outputs={'Out': Out_var},
            attrs={
                'ring_id': group.id,
                'use_calc_stream': True,
                'use_model_parallel': True,
                OP_ROLE_KEY: src_op.attr('op_role')
            })
        if Out_var.shape != ref_shape:
            Out_var.desc.set_shape(ref_shape)

        # set dist op's dist_attr with serial op's dist_attr
        # mul
        mul_op_dist_attr = OperatorDistributedAttribute()
        mul_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        mul_op_dist_attr.impl_type = op_dist_attr.impl_type
        mul_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        for input_varname in mul_op.desc.input_arg_names():
            input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
            assert input_dist_attr is not None, "dist_attr is {}".format(
                op_dist_attr)
            mul_op_dist_attr.set_input_dist_attr(input_varname,
                                                 input_dist_attr)
        # The temporary output reuses the serial op's Out dist attr.
        output_varname = mul_op.desc.output_arg_names()[0]
        output_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert output_dist_attr is not None, "dist_attr is {}".format(
            op_dist_attr)
        mul_op_dist_attr.set_output_dist_attr(output_varname,
                                              output_dist_attr)
        ctx.set_op_dist_attr_for_program(mul_op, mul_op_dist_attr)

        # allreduce
        allreduce_op_dist_attr = OperatorDistributedAttribute()
        allreduce_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        allreduce_op_dist_attr.impl_type = op_dist_attr.impl_type
        allreduce_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        for input_varname in c_allreduce_sum_op.desc.input_arg_names():
            input_var = main_block.var(input_varname)
            tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(input_var)
            assert tensor_dist_attr is not None
            allreduce_op_dist_attr.set_input_dist_attr(input_varname,
                                                       tensor_dist_attr)
        for output_varname in c_allreduce_sum_op.desc.output_arg_names():
            output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
            assert output_dist_attr is not None, "dist_attr is {}".format(
                op_dist_attr)
            allreduce_op_dist_attr.set_output_dist_attr(output_varname,
                                                        output_dist_attr)
        ctx.set_op_dist_attr_for_program(c_allreduce_sum_op,
                                         allreduce_op_dist_attr)

        # init param sync
        if Weight_var.is_parameter and not op_dist_attr.is_recompute:
            _init_param_sync(Weight_var, dist_op_context, startup_block, ctx,
                             rank_id)

    @staticmethod
    def backward(ctx, *args, **kwargs):
        """Delegate backward construction to the shared matmul helper."""
        _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)


# ReplicateParallel
class DistributedMulImpl2(DistributedOperatorImpl):
    """Replicate-parallel implementation of the ``mul`` op.

    Compatible only when neither of the last two dimensions of X, Y and Out
    is sharded across the process mesh. The forward pass is delegated to
    ``DistributedDefaultImpl0.forward``.
    """

    def __init__(self, name):
        super(DistributedMulImpl2, self).__init__(name)

    def is_input_compatible(self, dist_op):
        """Return True if neither of the last two dims of X or Y is sharded."""
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        x_name = op_desc.input('X')[0]
        y_name = op_desc.input('Y')[0]
        x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
        y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)

        if is_dim_shard(x_dims_mapping[-1]):
            return False
        # The -2 dim only exists for inputs of rank >= 2.
        if is_valid_list_index(x_dims_mapping, -2) and is_dim_shard(
                x_dims_mapping[-2]):
            return False
        if is_dim_shard(y_dims_mapping[-1]):
            return False
        if is_valid_list_index(y_dims_mapping, -2) and is_dim_shard(
                y_dims_mapping[-2]):
            return False
        return True

    def is_output_compatible(self, dist_op):
        """Return True if neither of the last two dims of Out is sharded."""
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)

        if is_dim_shard(out_dims_mapping[-1]):
            return False
        if is_valid_list_index(out_dims_mapping, -2) and is_dim_shard(
                out_dims_mapping[-2]):
            return False

        return True

    def is_auto_compatible(self, dist_op):
        """Require input/output compatibility and the shared matmul check."""
        if (not self.is_input_compatible(dist_op)) or \
            (not self.is_output_compatible(dist_op)):
            return False

        if not _is_auto_compatible_for_matmul(dist_op):
            return False

        return True

    def update_dims_mapping(self, dist_op):
        """Apply the shared matmul dims-mapping rule; return True if changed."""
        return bool(_update_dims_mapping_for_matmul(dist_op))

    @staticmethod
    def forward(ctx, *args, **kwargs):
        """Delegate forward construction to ``DistributedDefaultImpl0``."""
        DistributedDefaultImpl0.forward(ctx, *args, **kwargs)

    @staticmethod
    def backward(ctx, *args, **kwargs):
        """Delegate backward construction to the shared matmul helper."""
        _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)


# Register the column / row / replicate parallel ``mul`` implementations.
# Each impl is constructed and registered before the next is constructed,
# in the listed order.
for _mul_impl_cls, _mul_impl_name in (
    (DistributedMulImpl0, "column_parallel"),
    (DistributedMulImpl1, "row_parallel"),
    (DistributedMulImpl2, "replicate_parallel"),
):
    register_distributed_operator_impl("mul", _mul_impl_cls(_mul_impl_name))
# Keep the module namespace free of the loop temporaries.
del _mul_impl_cls, _mul_impl_name