# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import unittest

import paddle
from paddle.distributed.auto_parallel.static.cluster import Cluster
from paddle.distributed.auto_parallel.static.operators.common import (
    get_distributed_operator_impl_container,
    is_elementwise_op,
)
from paddle.distributed.fleet import auto
from paddle.fluid import program_guard
from paddle.fluid.backward import append_backward

paddle.enable_static()


def parallelizer(program_func, rank):
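    """Run the static auto-parallel completion passes over ``program_func``.

    The forward program is annotated first, then the backward ops are
    generated with ``append_backward`` (sharing ``dist_op_context`` so grad
    ops stay linked to their forward ops) and annotated, and finally the
    optimizer update ops are appended and annotated. Returns the completed
    main program together with its ``DistributedContext``.
    """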
    from paddle.distributed.auto_parallel.static.completion import Completer
    from paddle.distributed.auto_parallel.static.dist_context import (
        DistributedContext,
    )

    main_program, startup_program, loss = program_func()

    # complete forward
    dist_context = DistributedContext()
    completer = Completer(dist_context)
    completer.complete_forward_annotation(main_program)
    dist_context.block_state.parse_forward_blocks(main_program)

    # generate backward and complete backward
    with paddle.static.program_guard(main_program, startup_program):
        params_grads = append_backward(
            loss, None, None, None, distop_context=dist_context.dist_op_context
        )
    completer.complete_backward_annotation(main_program)
    dist_context.block_state.parse_backward_blocks(main_program)

    optimizer = paddle.optimizer.Adam(learning_rate=0.001)
    # generate opt and complete opt
    with program_guard(main_program, startup_program):
        optimize_ops = copy.deepcopy(optimizer).apply_gradients(params_grads)

    completer.complete_update_annotation(main_program)

    return main_program, dist_context


class TestDistOpCost(unittest.TestCase):
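    """Each test builds a small annotated program, runs the completion passes
    through ``parallelizer`` and asserts that the matched distributed operator
    implementation returns a non-empty cost via ``calc_cost``."""
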
    def test_dist_op_cost_part1(self):
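        # Data-parallel program: a small linear/gelu/layer_norm network on an
        # input sharded along the batch dimension; every op except matmul_v2,
        # matmul_v2_grad and sgd is expected to report a cost.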
        def make_program():
            main_program = paddle.static.Program()
            start_program = paddle.static.Program()
            with paddle.static.program_guard(main_program, start_program):
                x = paddle.static.data(name='x', shape=[4, 8], dtype='float32')
                x.stop_gradient = True
                label = paddle.static.data(
                    name="label", shape=[4, 1], dtype='float32'
                )
                label.stop_gradient = True
                auto.shard_tensor(
                    x, auto.ProcessMesh([0, 1], dim_names=["x"]), ["x", None]
                )
                tmp = paddle.fluid.layers.fill_constant_batch_size_like(
                    input=x, shape=[2, 8], value=1, dtype='float32'
                )
                weight_attr = paddle.ParamAttr()
                linear = paddle.nn.Linear(8, 1, weight_attr=weight_attr)
                linear_out = linear(x)
                gelu_out = paddle.nn.functional.gelu(linear_out)
                # default op with dp
                tmp = paddle.static.nn.layer_norm(gelu_out)
                error_cost = paddle.nn.functional.square_error_cost(tmp, label)
                loss = paddle.mean(error_cost)
            return main_program, start_program, loss

        main_program, dist_context = parallelizer(make_program, 0)
        ops = main_program.global_block().ops
        cluster = Cluster()
        cluster.gen_default_config_cluster(device_count=2)
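        # Each remaining op should resolve to a distributed impl container
        # (the shared "elementwise" container for elementwise ops) whose
        # selected impl can compute a non-empty cost.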
        for idx, op in enumerate(ops):
            if (
                op.type != "matmul_v2"
                and op.type != "matmul_v2_grad"
                and op.type != "sgd"
            ):
                dist_op = dist_context.get_dist_op_for_program(op)
                op_dist_attr = dist_op.dist_attr
                processes = op_dist_attr.process_mesh.process_ids
                if is_elementwise_op(op.type):
                    container = get_distributed_operator_impl_container(
                        "elementwise"
                    )
                else:
                    container = get_distributed_operator_impl_container(
                        op_dist_attr.impl_type
                    )

                dist_impl = container.impls[op_dist_attr.impl_idx]
                dist_op_cost = dist_impl.calc_cost(
                    op.attr('op_role'), dist_op, dist_context, cluster
                )
                self.assertTrue(dist_op_cost)

    def test_dist_op_cost_part2(self):
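        # Row-parallel embedding followed by a chain of matmul, scale,
        # transpose, reshape and softmax ops; every dist op in the completed
        # program (forward, backward and update) must report a non-empty cost.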
        def make_program():
            main_program = paddle.static.Program()
            start_program = paddle.static.Program()
            with paddle.static.program_guard(main_program, start_program):
                x = paddle.static.data(name='x', shape=[4], dtype='float32')
                x.stop_gradient = True
                label = paddle.static.data(
                    name="label", shape=[8, 1], dtype='float32'
                )
                label.stop_gradient = True
                auto.shard_tensor(
                    x, auto.ProcessMesh([0, 1], dim_names=["x"]), ["x"]
                )

                auto.shard_tensor(
                    label,
                    auto.ProcessMesh([0, 1], dim_names=["x"]),
                    ["x", None],
                )
                # embedding
                tmp = paddle.fluid.layers.fill_constant_batch_size_like(
                    input=x, shape=[4], value=1, dtype='int32'
                )
                embedding = paddle.nn.Embedding(10, 8)
                out = embedding(tmp)
                # row parallel embedding
                for op in main_program.global_block().ops:
                    if op.type == "lookup_table_v2":
                        W = main_program.global_block().vars[op.input("W")[0]]
                        auto.shard_tensor(
                            W,
                            auto.ProcessMesh([0, 1], dim_names=["x"]),
                            ["x", None],
                        )
                out = paddle.transpose(out, [1, 0])  # [8, 2] [-1, 0]

                # matmul
                param1 = paddle.create_parameter(
                    [4, 8], paddle.float32
                )  # [2, 8] [0, -1]
                auto.shard_tensor(
                    param1,
                    auto.ProcessMesh([0, 1], dim_names=["x"]),
                    ["x", None],
                )
                param2 = paddle.create_parameter(
                    [8, 8], paddle.float32
                )  # [8, 4] [-1, 0]
                auto.shard_tensor(
                    param2,
                    auto.ProcessMesh([0, 1], dim_names=["x"]),
                    [None, "x"],
                )
                out1 = paddle.matmul(out, param1)  # [8, 8] [-1, -1]
                tmp_param = paddle.create_parameter(
                    [8, 8], paddle.float32
                )  # [8, 8] [-1, -1]
                auto.shard_tensor(
                    tmp_param,
                    auto.ProcessMesh([0, 1], dim_names=["x"]),
                    [None, None],
                )
                tmp_out = paddle.matmul(out1, tmp_param)
                tmp_out = paddle.scale(tmp_out, 0.5)
                out2 = paddle.matmul(tmp_out, param2)  # [8, 4] [-1, 0]

                out8 = paddle.transpose(out2, [1, 0])  # [4, 8] [0, -1]

                # reshape
                out9 = paddle.reshape(out8, [8, 2, 4])  # [4, 2, 4] [0, -1, -1]
                tmp_reshape_out = paddle.reshape(out9, [8, 4, 2])
                out10 = paddle.reshape(
                    tmp_reshape_out, [8, 8]
                )  # [4, 8] [0, -1]

                # softmax
                softmax = paddle.nn.Softmax()
                out11 = softmax(out10)
                error_cost = paddle.nn.functional.square_error_cost(
                    out11, label
                )
                loss = paddle.mean(error_cost)
            return main_program, start_program, loss

        main_program, dist_context = parallelizer(make_program, 0)
        ops = main_program.global_block().ops
        cluster = Cluster()
        cluster.gen_default_config_cluster(device_count=2)
        for idx, op in enumerate(ops):
            dist_op = dist_context.get_dist_op_for_program(op)
            op_dist_attr = dist_op.dist_attr
            processes = op_dist_attr.process_mesh.process_ids
            if is_elementwise_op(op.type):
                container = get_distributed_operator_impl_container(
                    "elementwise"
                )
            else:
                container = get_distributed_operator_impl_container(
                    op_dist_attr.impl_type
                )

            dist_impl = container.impls[op_dist_attr.impl_idx]
            dist_op_cost = dist_impl.calc_cost(
                op.attr('op_role'), dist_op, dist_context, cluster
            )
            self.assertTrue(dist_op_cost)

    def test_dist_op_cost_part3(self):
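        # Same network shape as part2; checks the cost interface again over
        # the full set of forward, backward and optimizer ops.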
        def make_program():
            main_program = paddle.static.Program()
            start_program = paddle.static.Program()
            with paddle.static.program_guard(main_program, start_program):
                x = paddle.static.data(name='x', shape=[4], dtype='float32')
                x.stop_gradient = True
                label = paddle.static.data(
                    name="label", shape=[8, 1], dtype='float32'
                )
                label.stop_gradient = True
                auto.shard_tensor(
                    x, auto.ProcessMesh([0, 1], dim_names=["x"]), ["x"]
                )

                auto.shard_tensor(
                    label,
                    auto.ProcessMesh([0, 1], dim_names=["x"]),
                    ["x", None],
                )
                # embedding
                tmp = paddle.fluid.layers.fill_constant_batch_size_like(
                    input=x, shape=[4], value=1, dtype='int32'
                )
                embedding = paddle.nn.Embedding(10, 8)
                out = embedding(tmp)
                # row parallel embedding
                for op in main_program.global_block().ops:
                    if op.type == "lookup_table_v2":
                        W = main_program.global_block().vars[op.input("W")[0]]
                        auto.shard_tensor(
                            W,
                            auto.ProcessMesh([0, 1], dim_names=["x"]),
                            ["x", None],
                        )
                out = paddle.transpose(out, [1, 0])  # [8, 2] [-1, 0]

                # matmul_v2
                param1 = paddle.create_parameter(
                    [4, 8], paddle.float32
                )  # [2, 8] [0, -1]
                auto.shard_tensor(
                    param1,
                    auto.ProcessMesh([0, 1], dim_names=["x"]),
                    ["x", None],
                )
                param2 = paddle.create_parameter(
                    [8, 8], paddle.float32
                )  # [8, 4] [-1, 0]
                auto.shard_tensor(
                    param2,
                    auto.ProcessMesh([0, 1], dim_names=["x"]),
                    [None, "x"],
                )
                out1 = paddle.matmul(out, param1)  # [8, 8] [-1, -1]
                tmp_param = paddle.create_parameter(
                    [8, 8], paddle.float32
                )  # [8, 8] [-1, -1]
                auto.shard_tensor(
                    tmp_param,
                    auto.ProcessMesh([0, 1], dim_names=["x"]),
                    [None, None],
                )

                tmp_out = paddle.matmul(out1, tmp_param)
                tmp_out = paddle.scale(tmp_out, 0.5)
                out2 = paddle.matmul(tmp_out, param2)  # [8, 4] [-1, 0]

                out8 = paddle.transpose(out2, [1, 0])  # [4, 8] [0, -1]

                # reshape
                out9 = paddle.reshape(out8, [8, 2, 4])  # [4, 2, 4] [0, -1, -1]
                tmp_reshape_out = paddle.reshape(out9, [8, 4, 2])
                out10 = paddle.reshape(
                    tmp_reshape_out, [8, 8]
                )  # [4, 8] [0, -1]

                # softmax
                softmax = paddle.nn.Softmax()
                out11 = softmax(out10)
                error_cost = paddle.nn.functional.square_error_cost(
                    out11, label
                )
                loss = paddle.mean(error_cost)
            return main_program, start_program, loss

        main_program, dist_context = parallelizer(make_program, 0)
        ops = main_program.global_block().ops
        cluster = Cluster()
        cluster.gen_default_config_cluster(device_count=2)
        for idx, op in enumerate(ops):
            dist_op = dist_context.get_dist_op_for_program(op)
            op_dist_attr = dist_op.dist_attr
            processes = op_dist_attr.process_mesh.process_ids
            if is_elementwise_op(op.type):
                container = get_distributed_operator_impl_container(
                    "elementwise"
                )
            else:
                container = get_distributed_operator_impl_container(
                    op_dist_attr.impl_type
                )

            dist_impl = container.impls[op_dist_attr.impl_idx]
            dist_op_cost = dist_impl.calc_cost(
                op.attr('op_role'), dist_op, dist_context, cluster
            )
            self.assertTrue(dist_op_cost)

    def test_dist_op_cost_part4(self):
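        # Variant of part2/part3 without the scale op; again every dist op in
        # the completed program must produce a non-empty cost.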
        def make_program():
            main_program = paddle.static.Program()
            start_program = paddle.static.Program()
            with paddle.static.program_guard(main_program, start_program):
                x = paddle.static.data(name='x', shape=[4], dtype='float32')
                x.stop_gradient = True
                label = paddle.static.data(
                    name="label", shape=[8, 1], dtype='float32'
                )
                label.stop_gradient = True
                auto.shard_tensor(
                    x, auto.ProcessMesh([0, 1], dim_names=["x"]), ["x"]
                )
                auto.shard_tensor(
                    label,
                    auto.ProcessMesh([0, 1], dim_names=["x"]),
                    ["x", None],
                )
                # embedding
                tmp = paddle.fluid.layers.fill_constant_batch_size_like(
                    input=x, shape=[4], value=1, dtype='int32'
                )
                embedding = paddle.nn.Embedding(10, 8)
                out = embedding(tmp)
                # row parallel embedding
                for op in main_program.global_block().ops:
                    if op.type == "lookup_table_v2":
                        W = main_program.global_block().vars[op.input("W")[0]]
                        auto.shard_tensor(
                            W,
                            auto.ProcessMesh([0, 1], dim_names=["x"]),
                            ["x", None],
                        )
                out = paddle.transpose(out, [1, 0])  # [8, 2] [-1, 0]

                # mul
                param1 = paddle.create_parameter(
                    [4, 8], paddle.float32
                )  # [2, 8] [0, -1]
                auto.shard_tensor(
                    param1,
                    auto.ProcessMesh([0, 1], dim_names=["x"]),
                    ["x", None],
                )
                param2 = paddle.create_parameter(
                    [8, 8], paddle.float32
                )  # [8, 4] [-1, 0]
                auto.shard_tensor(
                    param2,
                    auto.ProcessMesh([0, 1], dim_names=["x"]),
                    [None, "x"],
                )

                out1 = paddle.matmul(out, param1)  # [8, 8] [-1, -1]
                tmp_param = paddle.create_parameter(
                    [8, 8], paddle.float32
                )  # [8, 8] [-1, -1]
                auto.shard_tensor(
                    tmp_param,
                    auto.ProcessMesh([0, 1], dim_names=["x"]),
                    [None, None],
                )

                tmp_out = paddle.matmul(out1, tmp_param)
                out2 = paddle.matmul(tmp_out, param2)  # [8, 4] [-1, 0]

                out8 = paddle.transpose(out2, [1, 0])  # [4, 8] [0, -1]

                # reshape
                out9 = paddle.reshape(out8, [8, 2, 4])  # [4, 2, 4] [0, -1, -1]
                tmp_reshape_out = paddle.reshape(out9, [8, 4, 2])
                out10 = paddle.reshape(
                    tmp_reshape_out, [8, 8]
                )  # [4, 8] [0, -1]

                # softmax
                softmax = paddle.nn.Softmax()
                out11 = softmax(out10)
                error_cost = paddle.nn.functional.square_error_cost(
                    out11, label
                )
                loss = paddle.mean(error_cost)
            return main_program, start_program, loss

        main_program, dist_context = parallelizer(make_program, 0)
        ops = main_program.global_block().ops
        cluster = Cluster()
        cluster.gen_default_config_cluster(device_count=2)
        for idx, op in enumerate(ops):
            dist_op = dist_context.get_dist_op_for_program(op)
            op_dist_attr = dist_op.dist_attr
            processes = op_dist_attr.process_mesh.process_ids
            if is_elementwise_op(op.type):
                container = get_distributed_operator_impl_container(
                    "elementwise"
                )
            else:
                container = get_distributed_operator_impl_container(
                    op_dist_attr.impl_type
                )

            dist_impl = container.impls[op_dist_attr.impl_idx]
            dist_op_cost = dist_impl.calc_cost(
                op.attr('op_role'), dist_op, dist_context, cluster
            )
            self.assertTrue(dist_op_cost)


if __name__ == "__main__":
    unittest.main()