# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
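
# The tests below exercise the auto-parallel planner utilities on a small MLP
# program: PlanSpace's process-mesh and dist-attr enumeration, and the
# dims-mapping update helpers (update_op_dims_mapping_by_default_dist_impl and
# update_op_dims_mapping_by_elementwise_like_dist_impl).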

from __future__ import print_function

import os
import copy
import json
import unittest

import paddle
import paddle.nn as nn
import paddle.static as static
import paddle.nn.functional as F
import paddle.utils as utils
from paddle.distributed import fleet
from paddle.distributed.fleet import auto
from paddle.distributed.auto_parallel.cluster import Cluster
from paddle.distributed.auto_parallel.utils import SerialProgramInfo
from paddle.distributed.auto_parallel.planner import PlanSpace, PlanFilter
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.auto_parallel.utils import get_all_distributed_main_program
from paddle.distributed.auto_parallel.dist_attribute import TensorDistributedAttribute
from paddle.distributed.auto_parallel.dist_attribute import OperatorDistributedAttribute
from paddle.distributed.auto_parallel.utils import update_op_dims_mapping_by_default_dist_impl
from paddle.distributed.auto_parallel.utils import update_op_dims_mapping_by_elementwise_like_dist_impl

paddle.enable_static()


class MLPLayer(nn.Layer):
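    """A LayerNorm -> Linear -> GELU -> Linear block used as the serial
    program under test; the output is reshaped to [4, 1024] at the end."""
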
    def __init__(self,
                 hidden_size=1024,
                 intermediate_size=4 * 1024,
                 initializer_range=0.02):
        super(MLPLayer, self).__init__()
        d_model = hidden_size
        dim_feedforward = intermediate_size
        weight_attr = paddle.ParamAttr(
            initializer=nn.initializer.Normal(mean=0.0, std=initializer_range))
        bias_attr = None

        self.linear0 = nn.Linear(d_model,
                                 dim_feedforward,
                                 weight_attr,
                                 bias_attr=bias_attr)
        self.linear1 = nn.Linear(dim_feedforward,
                                 d_model,
                                 weight_attr,
                                 bias_attr=bias_attr)
        self.norm = nn.LayerNorm(d_model, epsilon=1e-5)

    def forward(self, input):
        out = self.norm(input)
        out = self.linear0(out)
        out = F.gelu(out, approximate=True)
        out = self.linear1(out)
        out = paddle.unsqueeze(out, axis=0)
        out = paddle.reshape(out, [4, 1024])
        return out


def mlp_forward(train_program, start_program):
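    """Build the serial MLP training graph: data layers, an MLPLayer forward
    pass, CrossEntropyLoss(reduction="none") and a mean reduction. Returns
    (loss, train_program, start_program)."""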
    with static.program_guard(train_program,
                              start_program), utils.unique_name.guard():
        batch_size = 4
        hidden_size = 1024
        sequence_len = 512
        input = static.data(name="input",
                            shape=[batch_size, hidden_size],
                            dtype='float32')
        label = static.data(name="label",
                            shape=[batch_size, 1],
                            dtype='float32')
        loss_func = paddle.nn.CrossEntropyLoss(reduction="none")
        mlp = MLPLayer(hidden_size=hidden_size,
                       intermediate_size=4 * hidden_size,
                       initializer_range=0.02)

        predict = mlp(input)
        error_cost = loss_func(predict, label)
        loss = paddle.mean(error_cost)

    return loss, train_program, start_program


def set_default_dist_attr(program, dist_context, process_mesh):
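    """Give every tensor and operator in ``program`` a default (replicated)
    distributed attribute on ``process_mesh``, i.e. all dims_mapping entries
    set to -1, and register the mesh with ``dist_context``."""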
    ops = program.global_block().ops
    vars = program.global_block().vars
    for op in ops:
        op_dist_attr = OperatorDistributedAttribute()
        op_dist_attr.process_mesh = process_mesh
        for var_name in op.input_arg_names:
            tensor_dist_attr = TensorDistributedAttribute()
            tensor_dist_attr.process_mesh = process_mesh
            tensor_dist_attr.dims_mapping = [-1 for i in vars[var_name].shape]
            dist_context.set_tensor_dist_attr_for_program(
                vars[var_name], tensor_dist_attr)
            op_dist_attr.set_input_dims_mapping(var_name,
                                                tensor_dist_attr.dims_mapping)

        for var_name in op.output_arg_names:
            tensor_dist_attr = TensorDistributedAttribute()
            tensor_dist_attr.process_mesh = process_mesh
            tensor_dist_attr.dims_mapping = [-1 for i in vars[var_name].shape]
            dist_context.set_tensor_dist_attr_for_program(
                vars[var_name], tensor_dist_attr)
            op_dist_attr.set_output_dims_mapping(var_name,
                                                 tensor_dist_attr.dims_mapping)
        dist_context.set_op_dist_attr_for_program(op, op_dist_attr)

    dist_context.add_process_mesh(process_mesh)


def check_process_meshes(processes):
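    """Return True if PlanSpace.enum_process_mesh_topology yields at least
    one candidate mesh topology for the given number of processes."""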
    result = PlanSpace.enum_process_mesh_topology(processes)
    if result:
        return True
    return False


def check_pipeline_enumerater(program, process_mesh_topology):
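    """Enumerate dist attrs with pipeline enabled: expect a non-empty attr
    dict, more than one pipeline process mesh and no global process mesh."""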
    valid_dist_attr_dict, pipeline_process_meshes, global_process_mesh = PlanSpace.enum_valid_dist_attr_for_program(
        program, process_mesh_topology, True)
    if valid_dist_attr_dict and len(
            pipeline_process_meshes) > 1 and not global_process_mesh:
        return True
    return False


def check_nonpipeline_enumerater(program, process_mesh_topology):
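    """Enumerate dist attrs with pipeline disabled: expect a non-empty attr
    dict, no pipeline process meshes and a single global process mesh."""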
    valid_dist_attr_dict, pipeline_process_meshes, global_process_mesh = PlanSpace.enum_valid_dist_attr_for_program(
        program, process_mesh_topology, False)
    if valid_dist_attr_dict and not pipeline_process_meshes and global_process_mesh:
        return True
    return False


class TestMLPSearcher(unittest.TestCase):
    def test_update(self):
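        """For ops without a registered distributed-operator implementation
        container, the dims-mapping update helpers should report no change
        for the default replicated mapping, and should report a change (or
        raise, which is tolerated) once the first output is sharded on dim 0."""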
        train_program = paddle.static.Program()
        startup_program = paddle.static.Program()
        _, train_program, startup_program = mlp_forward(train_program,
                                                        startup_program)
        global_process_mesh = auto.ProcessMesh(mesh=[0, 1])
        dist_context = DistributedContext()
        set_default_dist_attr(train_program, dist_context, global_process_mesh)
        ops = train_program.global_block().ops
        vars = train_program.global_block().vars
        from paddle.distributed.auto_parallel.operators.common import get_distributed_operator_impl_container
        from paddle.distributed.auto_parallel.operators.common import is_elementwise_op
        from paddle.distributed.auto_parallel.dist_op import DistributedOperator

        for op in ops:
            dist_op_impl_container = get_distributed_operator_impl_container(
                op.type)
            if dist_op_impl_container is None:
                op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
                dist_op = DistributedOperator(op, op_dist_attr)
                if is_elementwise_op(op.type):
                    changed = update_op_dims_mapping_by_elementwise_like_dist_impl(
                        dist_op)
                    self.assertFalse(changed)

                    dist_op.dist_attr.set_output_dims_mapping(
                        op.output_arg_names[0], [0] + [
                            -1 for i in range(
                                1, len(vars[op.output_arg_names[0]].shape))
                        ])
                    try:
                        changed = update_op_dims_mapping_by_elementwise_like_dist_impl(
                            dist_op)
                    except:
                        continue
                    self.assertTrue(changed)
                else:
                    changed = update_op_dims_mapping_by_default_dist_impl(
                        dist_op)
                    self.assertFalse(changed)

                    dist_op.dist_attr.set_output_dims_mapping(
                        op.output_arg_names[0], [0] + [
                            -1 for i in range(
                                1, len(vars[op.output_arg_names[0]].shape))
                        ])
                    try:
                        changed = update_op_dims_mapping_by_default_dist_impl(
                            dist_op)
                    except:
                        continue
                    self.assertTrue(changed)

    def test_enumerater_and_checker(self):
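        """Check process-mesh enumeration for 4 processes and both the
        pipeline and non-pipeline dist-attr enumeration for topology [4]."""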
        processes = 4
        self.assertTrue(check_process_meshes(processes))

        train_program = paddle.static.Program()
        startup_program = paddle.static.Program()
        _, train_program, startup_program = mlp_forward(train_program,
                                                        startup_program)
        process_mesh_topology = [4]
        self.assertTrue(
            check_pipeline_enumerater(train_program, process_mesh_topology))
        self.assertTrue(
            check_nonpipeline_enumerater(train_program, process_mesh_topology))


if __name__ == "__main__":
    unittest.main()