#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy
import copy
import paddle
import paddle.fluid.core as core
from paddle.fluid.framework import Variable
from paddle.fluid.framework import _non_static_mode
from .dist_context import get_default_distributed_context
from .dist_tensor import DistributedTensor
from .dist_op import DistributedModule
from .dist_attribute import TensorDistributedAttribute
from .dist_attribute import OperatorDistributedAttribute


def _static_mode_check():
    if _non_static_mode():
        raise RuntimeError("Auto-parallel only supports static mode for now, "
                           "please use paddle.enable_static() first.")


def shard_tensor(x, dist_attr=None):
    """
    Add distributed attributes for a tensor.

    Args:
        x (Tensor): the tensor to be sharded.
        dist_attr (dict): the tensor distributed attributes. The accepted attributes are as follows:
            "process_mesh": a nested list to describe the mesh topology of logical processes.
            "dims_mapping": a list to describe the mapping between `x` and `process_mesh`: the dimension
                `i` of `x` is split across the dimension `dims_mapping[i]` of `process_mesh`,
                where -1 means that tensor dimension is not split.
            Both process_mesh and dims_mapping are optional and users can specify them as needed.

    Returns:
        Tensor: the tensor `x` annotated with distributed attributes.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.distributed as dist

            paddle.enable_static()

            x = paddle.ones([4, 6])
            dist.shard_tensor(x, dist_attr={"process_mesh": [[0, 1], [2, 3]],
                                            "dims_mapping": [0, -1]})

    """
    _static_mode_check()
    assert dist_attr is None or isinstance(dist_attr, (dict, TensorDistributedAttribute)), \
        "The type of dist_attr must be None, dict or TensorDistributedAttribute."
    dist_tensor = DistributedTensor(x, dist_attr)
    dist_tensor.dist_attr.mark_annotated_as(dist_attr)
    default_dist_ctx = get_default_distributed_context()
    default_dist_ctx.add_dist_tensor_for_program(dist_tensor)
    return x
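

# Illustrative sketch (not part of this module's API): how `dims_mapping`
# relates a tensor's global shape to the shard held by each process, assuming
# a 2-D `process_mesh` given as a nested list of process ids. The helper name
# below is hypothetical and exists only to make the semantics concrete.
def _example_local_shard_shape(global_shape, process_mesh, dims_mapping):
    # Number of processes along each dimension of the 2-D mesh.
    mesh_dim_sizes = [len(process_mesh), len(process_mesh[0])]
    local_shape = []
    for i, dim_size in enumerate(global_shape):
        mesh_dim = dims_mapping[i]
        if mesh_dim == -1:
            # -1 means this tensor dimension is not split.
            local_shape.append(dim_size)
        else:
            # Otherwise the dimension is split evenly across that mesh dimension.
            local_shape.append(dim_size // mesh_dim_sizes[mesh_dim])
    return local_shape
# For the docstring example above, a [4, 6] tensor with process_mesh
# [[0, 1], [2, 3]] and dims_mapping [0, -1] yields [2, 6] local shards:
# _example_local_shard_shape([4, 6], [[0, 1], [2, 3]], [0, -1]) == [2, 6]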


def shard_op(op_fn, dist_attr=None):
    """
    Call a function and add distributed attributes for the ops added by the function.

    Args:
        op_fn (callable): a callable operator or module to be sharded.
        dist_attr (dict): the operator distributed attributes. The accepted attributes are classified into
            two categories. The first category describes the distributed attributes shared by all inputs and
            outputs, and only `process_mesh` can be specified now. The second category describes the distributed
            attributes for individual inputs or outputs, in the same form as the `dist_attr` of `shard_tensor`.
            All of them are optional and users can specify them as needed. Note that the `process_mesh` of the
            operator must be the same as the `process_mesh` of its inputs and outputs.

    Returns:
        DistributedModule: a callable wrapping `op_fn`; calling it runs `op_fn` and annotates the resulting ops with the given distributed attributes.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.distributed as dist

            paddle.enable_static()

            x = paddle.ones([4, 6])
            y = paddle.zeros([4, 6])
            dist_add = dist.shard_op(paddle.add,
                                     dist_attr={
                                         "process_mesh": [[2, 3, 1], [0, 4, 5]],
                                         x: {"dims_mapping": [-1, 0]},
                                         y: {"dims_mapping": [0, -1]}
                                     })
            dist_add(x, y)

    """
    _static_mode_check()
    assert dist_attr is None or isinstance(dist_attr, (dict, OperatorDistributedAttribute)), \
        "The type of dist_attr must be dict or OperatorDistributedAttribute."
    dist_module = DistributedModule(op_fn, dist_attr)
    return dist_module
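

# Illustrative sketch (hypothetical, not part of the public API): composing
# shard_tensor and shard_op inside a static-graph program, mirroring the
# docstring examples above. The name _example_shard_usage is made up for
# illustration and is not called anywhere in this module.
def _example_shard_usage():
    import paddle
    import paddle.distributed as dist

    paddle.enable_static()

    x = paddle.ones([4, 6])
    y = paddle.zeros([4, 6])
    # Annotate x: split its first dimension across mesh dimension 0.
    dist.shard_tensor(x, dist_attr={"process_mesh": [[0, 1], [2, 3]],
                                    "dims_mapping": [0, -1]})
    # Wrap paddle.add so the ops it adds carry the same process mesh.
    dist_add = dist.shard_op(paddle.add,
                             dist_attr={"process_mesh": [[0, 1], [2, 3]]})
    return dist_add(x, y)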