# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import paddle
from paddle.distributed.fleet import auto
from paddle.fluid import program_guard
from paddle.fluid.backward import append_backward

paddle.enable_static()

mesh = auto.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"])
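# 2 x 2 process mesh: in the dims_mappings below, each entry names the mesh
# dim a tensor axis is sharded along ("x" or "y"); None / -1 means that axis
# is replicated. Mesh dim "x" serves as the data-parallel axis and "y" as the
# model-parallel axis (the "dp2mp2" in the program builders' names).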


def init_x_row(trans_x):
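    """Input x for the row-parallel case: the batch axis (size 10) is
    sharded along mesh dim "x" and the reduction axis (size 6) along "y"."""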
    if trans_x:
        x = paddle.static.data(name='x', shape=[10, 6, 8], dtype='float32')
        auto.shard_tensor(x, mesh, ["x", "y", None])

        return x
    else:
        x = paddle.static.data(name='x', shape=[10, 8, 6], dtype='float32')
        auto.shard_tensor(x, mesh, ["x", None, "y"])

        return x


def init_x_col(trans_x):
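    """Input x for the column-parallel case: the non-contracted axis
    (size 8) is sharded along mesh dim "x"; the reduction axis (size 6)
    is replicated."""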
    if trans_x:
        x = paddle.static.data(name='x', shape=[6, 8], dtype='float32')
        auto.shard_tensor(x, mesh, [None, "x"])

        return x
    else:
        x = paddle.static.data(name='x', shape=[8, 6], dtype='float32')
        auto.shard_tensor(x, mesh, ["x", None])

        return x


def init_y_row(trans_y):
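    """Weight y for the row-parallel case: the reduction axis (size 6) is
    sharded along mesh dim "y", so each rank's matmul yields a partial sum."""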
    if trans_y:
        y = paddle.static.data(name='y', shape=[4, 6], dtype='float32')
        auto.shard_tensor(y, mesh, [None, "y"])

        return y
    else:
        y = paddle.static.data(name='y', shape=[6, 4], dtype='float32')
        auto.shard_tensor(y, mesh, ["y", None])

        return y


def init_y_col(trans_y):
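    """Weight y for the column-parallel case: the output axis (size 4) is
    sharded along mesh dim "y"; the reduction axis (size 6) is replicated."""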
    if trans_y:
        y = paddle.static.data(name='y', shape=[4, 6], dtype='float32')
        auto.shard_tensor(y, mesh, ["y", None])

        return y
    else:
        y = paddle.static.data(name='y', shape=[6, 4], dtype='float32')
        auto.shard_tensor(y, mesh, [None, "y"])

        return y


def matmul_dp2mp2(init_x, init_y, trans_x, trans_y):
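    """Build a serial program computing mean(matmul(x, y)) with inputs
    annotated for 2-way data parallelism and 2-way model parallelism."""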
    main_program = paddle.fluid.Program()
    start_program = paddle.fluid.Program()
    with paddle.static.program_guard(main_program, start_program):
        x = init_x(trans_x)
        y = init_y(trans_y)
        x.stop_gradient = False
        y.stop_gradient = False
        out = paddle.matmul(x, y, transpose_x=trans_x, transpose_y=trans_y)
        loss = paddle.mean(out)
    return main_program, start_program, loss


def matmulv2_dp2mp2(init_x, init_y, trans_x, trans_y):
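    """Identical to matmul_dp2mp2 (paddle.matmul lowers to a matmul_v2 op in
    static mode); kept as a separate builder for the TestDistMatmulV2 cases."""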
    main_program = paddle.fluid.Program()
    start_program = paddle.fluid.Program()
    with paddle.static.program_guard(main_program, start_program):
        x = init_x(trans_x)
        y = init_y(trans_y)
        x.stop_gradient = False
        y.stop_gradient = False
        out = paddle.matmul(x, y, transpose_x=trans_x, transpose_y=trans_y)
        loss = paddle.mean(out)
    return main_program, start_program, loss


def parallelizer(program_func, *args, **kwargs):
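    """Complete the distributed attributes of the serial program, append its
    backward ops, and partition it for rank 0, returning the distributed
    main program together with its DistributedContext."""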
    from paddle.distributed.auto_parallel.static.completion import Completer
    from paddle.distributed.auto_parallel.static.dist_context import (
        DistributedContext,
    )
    from paddle.distributed.auto_parallel.static.partitioner import Partitioner

    main_program, start_program, loss = program_func(*args, **kwargs)

    dist_context = DistributedContext()
    completer = Completer(dist_context)
    completer.complete_forward_annotation(main_program)
    dist_context.block_state.parse_forward_blocks(main_program)

    with program_guard(main_program, start_program):
        append_backward(loss, distop_context=dist_context.dist_op_context)
    completer.complete_backward_annotation(main_program)
    dist_context.block_state.parse_backward_blocks(main_program)

    partitioner = Partitioner(dist_context, 0)
    dist_main_prog, _, _ = partitioner.partition(
        main_program, start_program, []
    )

    return dist_main_prog, dist_context


class TestDistMatmul(unittest.TestCase):
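    """Check the partitioned op list and dims_mappings produced by the
    column- and row-parallel matmul implementations."""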
    def check_col_program(self, main_program, dist_ctx):
        # [0, -1] * [-1, 1] --> [0, 1]
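        # Column parallel (impl_idx 0): a c_identity op is inserted in front
        # of the sharded matmul and the output stays sharded as [0, 1].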
        ref_ops = [
            "c_identity",
            "matmul_v2",
            "reduce_mean",
            "fill_constant",
            "reduce_mean_grad",
            "matmul_v2_grad",
        ]
        ops = []
        block = main_program.global_block()
        for op in block.ops:
            ops.append(op.type)
            if op.type == "matmul_v2":
                out_name = op.output('Out')[0]
                out_var = block.vars[out_name]
                op_dist_attr = dist_ctx.get_op_dist_attr_for_program(op)
                assert op_dist_attr.impl_idx == 0
                assert op_dist_attr.impl_type == "matmul_v2"
                out_dims_mapping = op_dist_attr.get_output_dims_mapping(
                    out_name
                )
                assert out_dims_mapping == [0, 1]
                tensor_dist_attr = dist_ctx.get_tensor_dist_attr_for_program(
                    out_var
                )
                assert tensor_dist_attr.dims_mapping == [0, 1]
            if op.type == "matmul_v2_grad":
                op_dist_attr = dist_ctx.get_op_dist_attr_for_program(op)
                assert op_dist_attr.impl_idx == 0
                assert op_dist_attr.impl_type == "matmul_v2"

        assert ops == ref_ops

    def check_row_program(self, main_program, dist_ctx):
        # [0, -1, 1] * [1, -1] --> [0, -1, -1]
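        # Row parallel (impl_idx 1): each rank computes a partial matmul and
        # c_allreduce_sum combines the partials, leaving the output
        # replicated along "y" ([0, -1, -1]).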
        ref_ops = [
            "matmul_v2",
            "c_allreduce_sum",
            "reduce_mean",
            "fill_constant",
            "reduce_mean_grad",
            "matmul_v2_grad",
        ]
        ops = []
        block = main_program.global_block()
        for op in block.ops:
            ops.append(op.type)
            if op.type == "matmul_v2":
                out_name = op.output('Out')[0]
                out_var = block.vars[out_name]
                op_dist_attr = dist_ctx.get_op_dist_attr_for_program(op)
                assert op_dist_attr.impl_idx == 1
                assert op_dist_attr.impl_type == "matmul_v2"
                out_dims_mapping = op_dist_attr.get_output_dims_mapping(
                    out_name
                )
                assert out_dims_mapping == [0, -1, -1]
                tensor_dist_attr = dist_ctx.get_tensor_dist_attr_for_program(
                    out_var
                )
                assert tensor_dist_attr.dims_mapping == [0, -1, -1]
            if op.type == "matmul_v2_grad":
                op_dist_attr = dist_ctx.get_op_dist_attr_for_program(op)
                assert op_dist_attr.impl_idx == 1
                assert op_dist_attr.impl_type == "matmul_v2"
        assert ops == ref_ops


class TestDistMatmulCol(TestDistMatmul):
    def init(self, trans_x, trans_y):
        dist_main_prog, dist_ctx = parallelizer(
            matmul_dp2mp2, init_x_col, init_y_col, trans_x, trans_y
        )
        return dist_main_prog, dist_ctx

    def test_matmul_col(self):
        dist_main_prog, dist_ctx = self.init(False, False)
        self.check_col_program(dist_main_prog, dist_ctx)

    def test_trans_x(self):
        dist_main_prog, dist_ctx = self.init(True, False)
        self.check_col_program(dist_main_prog, dist_ctx)

    def test_trans_y(self):
        dist_main_prog, dist_ctx = self.init(False, True)
        self.check_col_program(dist_main_prog, dist_ctx)

    def test_trans_x_trans_y(self):
        dist_main_prog, dist_ctx = self.init(True, True)
        self.check_col_program(dist_main_prog, dist_ctx)


class TestDistMatmulRow(TestDistMatmul):
    def init(self, trans_x, trans_y):
        dist_main_prog, dist_ctx = parallelizer(
            matmul_dp2mp2, init_x_row, init_y_row, trans_x, trans_y
        )
        return dist_main_prog, dist_ctx

    def test_matmul_row(self):
        dist_main_prog, dist_ctx = self.init(False, False)
        self.check_row_program(dist_main_prog, dist_ctx)

    def test_trans_x(self):
        dist_main_prog, dist_ctx = self.init(True, False)
        self.check_row_program(dist_main_prog, dist_ctx)

    def test_trans_y(self):
        dist_main_prog, dist_ctx = self.init(False, True)
        self.check_row_program(dist_main_prog, dist_ctx)

    def test_trans_x_trans_y(self):
        dist_main_prog, dist_ctx = self.init(True, True)
        self.check_row_program(dist_main_prog, dist_ctx)


class TestDistMatmulV2(unittest.TestCase):
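    """Same checks as TestDistMatmul, driven by the matmulv2_dp2mp2 builder."""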
    def check_col_program(self, main_program, dist_ctx):
        # [0, -1] * [-1, 1] --> [0, 1]
        ref_ops = [
            "c_identity",
            "matmul_v2",
            "reduce_mean",
            "fill_constant",
            "reduce_mean_grad",
            "matmul_v2_grad",
        ]
        ops = []
        block = main_program.global_block()
        for op in block.ops:
            ops.append(op.type)
            if op.type == "matmul_v2":
                out_name = op.output('Out')[0]
                out_var = block.vars[out_name]
                op_dist_attr = dist_ctx.get_op_dist_attr_for_program(op)
                assert op_dist_attr.impl_idx == 0
                assert op_dist_attr.impl_type == "matmul_v2"
                out_dims_mapping = op_dist_attr.get_output_dims_mapping(
                    out_name
                )
                assert out_dims_mapping == [0, 1]
                tensor_dist_attr = dist_ctx.get_tensor_dist_attr_for_program(
                    out_var
                )
                assert tensor_dist_attr.dims_mapping == [0, 1]
            if op.type == "matmul_v2_grad":
                op_dist_attr = dist_ctx.get_op_dist_attr_for_program(op)
                assert op_dist_attr.impl_idx == 0
                assert op_dist_attr.impl_type == "matmul_v2"

        assert ops == ref_ops

    def check_row_program(self, main_program, dist_ctx):
        # [0, -1, 1] * [1, -1] --> [0, -1, -1]
        ref_ops = [
            "matmul_v2",
            "c_allreduce_sum",
            "reduce_mean",
            "fill_constant",
            "reduce_mean_grad",
            "matmul_v2_grad",
        ]
        ops = []
        block = main_program.global_block()
        for op in block.ops:
            ops.append(op.type)
            if op.type == "matmul_v2":
                out_name = op.output('Out')[0]
                out_var = block.vars[out_name]
                op_dist_attr = dist_ctx.get_op_dist_attr_for_program(op)
                assert op_dist_attr.impl_idx == 1
                assert op_dist_attr.impl_type == "matmul_v2"
                out_dims_mapping = op_dist_attr.get_output_dims_mapping(
                    out_name
                )
                assert out_dims_mapping == [0, -1, -1]
                tensor_dist_attr = dist_ctx.get_tensor_dist_attr_for_program(
                    out_var
                )
                assert tensor_dist_attr.dims_mapping == [0, -1, -1]
            if op.type == "matmul_v2_grad":
                op_dist_attr = dist_ctx.get_op_dist_attr_for_program(op)
                assert op_dist_attr.impl_idx == 1
                assert op_dist_attr.impl_type == "matmul_v2"
        assert ops == ref_ops


class TestDistMatmulV2Col(TestDistMatmulV2):
    def init(self, trans_x, trans_y):
        dist_main_prog, dist_ctx = parallelizer(
            matmulv2_dp2mp2, init_x_col, init_y_col, trans_x, trans_y
        )
        return dist_main_prog, dist_ctx

    def test_matmul_col(self):
        dist_main_prog, dist_ctx = self.init(False, False)
        self.check_col_program(dist_main_prog, dist_ctx)

    def test_trans_x(self):
        dist_main_prog, dist_ctx = self.init(True, False)
        self.check_col_program(dist_main_prog, dist_ctx)

    def test_trans_y(self):
        dist_main_prog, dist_ctx = self.init(False, True)
        self.check_col_program(dist_main_prog, dist_ctx)

    def test_trans_x_trans_y(self):
        dist_main_prog, dist_ctx = self.init(True, True)
        self.check_col_program(dist_main_prog, dist_ctx)


class TestDistMatmulV2Row(TestDistMatmulV2):
    def init(self, trans_x, trans_y):
        dist_main_prog, dist_ctx = parallelizer(
            matmulv2_dp2mp2, init_x_row, init_y_row, trans_x, trans_y
        )
        return dist_main_prog, dist_ctx

    def test_matmul_row(self):
        dist_main_prog, dist_ctx = self.init(False, False)
        self.check_row_program(dist_main_prog, dist_ctx)

    def test_trans_x(self):
        dist_main_prog, dist_ctx = self.init(True, False)
        self.check_row_program(dist_main_prog, dist_ctx)

    def test_trans_y(self):
        dist_main_prog, dist_ctx = self.init(False, True)
        self.check_row_program(dist_main_prog, dist_ctx)

    def test_trans_x_trans_y(self):
        dist_main_prog, dist_ctx = self.init(True, True)
        self.check_row_program(dist_main_prog, dist_ctx)


if __name__ == "__main__":
    unittest.main()