# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import unittest

import paddle
from paddle.distributed.auto_parallel.cluster import Cluster
from paddle.distributed.auto_parallel.operators.common import (
    get_distributed_operator_impl_container,
    is_elementwise_op,
)
from paddle.distributed.fleet import auto
from paddle.fluid import program_guard
from paddle.fluid.backward import append_backward

paddle.enable_static()


def parallelizer(program_func, rank):
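    """Build ``program_func``'s program, complete its distributed
    annotations, generate the backward pass, and apply an Adam optimizer.

    Returns the annotated main program and its DistributedContext.
    """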
    from paddle.distributed.auto_parallel.completion import Completer
    from paddle.distributed.auto_parallel.dist_context import DistributedContext

    main_program, startup_program, loss = program_func()

    # complete forward
    dist_context = DistributedContext()
    completer = Completer(dist_context)
    completer.complete_forward_annotation(main_program)
    dist_context.block_state.parse_forward_blocks(main_program)

    # generate backward and complete backward
    with paddle.static.program_guard(main_program, startup_program):
        params_grads = append_backward(
            loss, None, None, None, distop_context=dist_context.dist_op_context
        )
    completer.complete_backward_annotation(main_program)
    dist_context.block_state.parse_backward_blocks(main_program)

    optimizer = paddle.optimizer.Adam(learning_rate=0.001)
    # generate opt and complete opt
    with program_guard(main_program, startup_program):
        optimize_ops = copy.deepcopy(optimizer).apply_gradients(params_grads)

    completer.complete_update_annotation(main_program)

    return main_program, dist_context


class TestDistOpCost(unittest.TestCase):
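    """Check that each matched distributed operator implementation reports a
    non-empty cost for small annotated programs on a two-process mesh."""
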
    def test_dist_op_cost_part1(self):
        def make_program():
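            # data-parallel toy net: linear -> gelu -> layer_norm -> MSE loss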
            main_program = paddle.static.Program()
            start_program = paddle.static.Program()
            with paddle.static.program_guard(main_program, start_program):
                x = paddle.static.data(name='x', shape=[4, 8], dtype='float32')
                x.stop_gradient = True
                label = paddle.static.data(
                    name="label", shape=[4, 1], dtype='float32'
                )
                label.stop_gradient = True
                auto.shard_tensor(
                    x, auto.ProcessMesh([0, 1], dim_names=["x"]), ["x", None]
                )
                tmp = paddle.fluid.layers.fill_constant_batch_size_like(
                    input=x, shape=[2, 8], value=1, dtype='float32'
                )
                weight_attr = paddle.ParamAttr()
                linear = paddle.nn.Linear(8, 1, weight_attr=weight_attr)
                linear_out = linear(x)
                gelu_out = paddle.nn.functional.gelu(linear_out)
                # default op with dp
                tmp = paddle.static.nn.layer_norm(gelu_out)
                error_cost = paddle.nn.functional.square_error_cost(tmp, label)
                loss = paddle.mean(error_cost)
            return main_program, start_program, loss

        main_program, dist_context = parallelizer(make_program, 0)
        ops = main_program.global_block().ops
        cluster = Cluster()
        cluster.gen_default_config_cluster(device_count=2)
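        # a default 2-device cluster supplies hardware info for the cost model
        # check the cost of every op except matmul_v2/matmul_v2_grad/sgd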
        for idx, op in enumerate(ops):
            if (
                op.type != "matmul_v2"
                and op.type != "matmul_v2_grad"
                and op.type != "sgd"
            ):
                dist_op = dist_context.get_dist_op_for_program(op)
                op_dist_attr = dist_op.dist_attr
                processes = op_dist_attr.process_mesh.process_ids
                if is_elementwise_op(op.type):
                    container = get_distributed_operator_impl_container(
                        "elementwise"
                    )
                else:
                    container = get_distributed_operator_impl_container(
                        op_dist_attr.impl_type
                    )

                dist_impl = container.impls[op_dist_attr.impl_idx]
                dist_op_cost = dist_impl.calc_cost(
                    op.attr('op_role'), dist_op, dist_context, cluster
                )
                self.assertTrue(dist_op_cost)

    def test_dist_op_cost_part2(self):
        def make_program():
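            # embedding -> transpose -> matmuls -> reshape -> softmax -> MSE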
            main_program = paddle.static.Program()
            start_program = paddle.static.Program()
            with paddle.static.program_guard(main_program, start_program):
                x = paddle.static.data(name='x', shape=[4], dtype='float32')
                x.stop_gradient = True
                label = paddle.static.data(
                    name="label", shape=[8, 1], dtype='float32'
                )
                label.stop_gradient = True
                auto.shard_tensor(
                    x, auto.ProcessMesh([0, 1], dim_names=["x"]), ["x"]
                )

                auto.shard_tensor(
                    label,
                    auto.ProcessMesh([0, 1], dim_names=["x"]),
                    ["x", None],
                )
                # embedding
                tmp = paddle.fluid.layers.fill_constant_batch_size_like(
                    input=x, shape=[4], value=1, dtype='int32'
                )
                embedding = paddle.nn.Embedding(10, 8)
                out = embedding(tmp)
                # row parallel embedding
                for op in main_program.global_block().ops:
                    if op.type == "lookup_table_v2":
                        W = main_program.global_block().vars[op.input("W")[0]]
                        auto.shard_tensor(
                            W,
                            auto.ProcessMesh([0, 1], dim_names=["x"]),
                            ["x", None],
                        )
                out = paddle.transpose(out, [1, 0])  # [8, 2] [-1, 0]

                # matmul
                param1 = paddle.create_parameter(
                    [4, 8], paddle.float32
                )  # [2, 8] [0, -1]
                auto.shard_tensor(
                    param1,
                    auto.ProcessMesh([0, 1], dim_names=["x"]),
                    ["x", None],
                )
                param2 = paddle.create_parameter(
                    [8, 8], paddle.float32
                )  # [8, 4] [-1, 0]
                auto.shard_tensor(
                    param2,
                    auto.ProcessMesh([0, 1], dim_names=["x"]),
                    [None, "x"],
                )
                out1 = paddle.matmul(out, param1)  # [8, 8] [-1, -1]
                tmp_param = paddle.create_parameter(
                    [8, 8], paddle.float32
                )  # [8, 8] [-1, -1]
                auto.shard_tensor(
                    tmp_param,
                    auto.ProcessMesh([0, 1], dim_names=["x"]),
                    [None, None],
                )
                tmp_out = paddle.matmul(out1, tmp_param)
                tmp_out = paddle.scale(tmp_out, 0.5)
                out2 = paddle.matmul(tmp_out, param2)  # [8, 4] [-1, 0]

                out8 = paddle.transpose(out2, [1, 0])  # [4, 8] [0, -1]

                # reshape
                out9 = paddle.reshape(out8, [8, 2, 4])  # [4, 2, 4] [0, -1, -1]
                tmp_reshape_out = paddle.reshape(out9, [8, 4, 2])
                out10 = paddle.reshape(
                    tmp_reshape_out, [8, 8]
                )  # [4, 8] [0, -1]

                # softmax
                softmax = paddle.nn.Softmax()
                out11 = softmax(out10)
                error_cost = paddle.nn.functional.square_error_cost(
                    out11, label
                )
                loss = paddle.mean(error_cost)
            return main_program, start_program, loss

        main_program, dist_context = parallelizer(make_program, 0)
        ops = main_program.global_block().ops
        cluster = Cluster()
        cluster.gen_default_config_cluster(device_count=2)
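        # unlike part1, every op, including matmul_v2, must yield a cost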
        for idx, op in enumerate(ops):
            dist_op = dist_context.get_dist_op_for_program(op)
            op_dist_attr = dist_op.dist_attr
            processes = op_dist_attr.process_mesh.process_ids
            if is_elementwise_op(op.type):
                container = get_distributed_operator_impl_container(
                    "elementwise"
                )
            else:
                container = get_distributed_operator_impl_container(
                    op_dist_attr.impl_type
                )

            dist_impl = container.impls[op_dist_attr.impl_idx]
            dist_op_cost = dist_impl.calc_cost(
                op.attr('op_role'), dist_op, dist_context, cluster
            )
            self.assertTrue(dist_op_cost)

    def test_dist_op_cost_part3(self):
        def make_program():
            main_program = paddle.static.Program()
            start_program = paddle.static.Program()
            with paddle.static.program_guard(main_program, start_program):
                x = paddle.static.data(name='x', shape=[4], dtype='float32')
                x.stop_gradient = True
                label = paddle.static.data(
                    name="label", shape=[8, 1], dtype='float32'
                )
                label.stop_gradient = True
                auto.shard_tensor(
                    x, auto.ProcessMesh([0, 1], dim_names=["x"]), ["x"]
                )

                auto.shard_tensor(
                    label,
                    auto.ProcessMesh([0, 1], dim_names=["x"]),
                    ["x", None],
                )
                # embedding
                tmp = paddle.fluid.layers.fill_constant_batch_size_like(
                    input=x, shape=[4], value=1, dtype='int32'
                )
                embedding = paddle.nn.Embedding(10, 8)
                out = embedding(tmp)
                # row parallel embedding
                for op in main_program.global_block().ops:
                    if op.type == "lookup_table_v2":
                        W = main_program.global_block().vars[op.input("W")[0]]
                        auto.shard_tensor(
                            W,
                            auto.ProcessMesh([0, 1], dim_names=["x"]),
                            ["x", None],
                        )
                out = paddle.transpose(out, [1, 0])  # [8, 2] [-1, 0]

                # matmul_v2
                param1 = paddle.create_parameter(
                    [4, 8], paddle.float32
                )  # [2, 8] [0, -1]
                auto.shard_tensor(
                    param1,
                    auto.ProcessMesh([0, 1], dim_names=["x"]),
                    ["x", None],
                )
                param2 = paddle.create_parameter(
                    [8, 8], paddle.float32
                )  # [8, 4] [-1, 0]
                auto.shard_tensor(
                    param2,
                    auto.ProcessMesh([0, 1], dim_names=["x"]),
                    [None, "x"],
                )
                out1 = paddle.matmul(out, param1)  # [8, 8] [-1, -1]
                tmp_param = paddle.create_parameter(
                    [8, 8], paddle.float32
                )  # [8, 8] [-1, -1]
                auto.shard_tensor(
                    tmp_param,
                    auto.ProcessMesh([0, 1], dim_names=["x"]),
                    [None, None],
                )

                tmp_out = paddle.matmul(out1, tmp_param)
                tmp_out = paddle.scale(tmp_out, 0.5)
                out2 = paddle.matmul(tmp_out, param2)  # [8, 4] [-1, 0]

                out8 = paddle.transpose(out2, [1, 0])  # [4, 8] [0, -1]

                # reshape
                out9 = paddle.reshape(out8, [8, 2, 4])  # [4, 2, 4] [0, -1, -1]
                tmp_reshape_out = paddle.reshape(out9, [8, 4, 2])
                out10 = paddle.reshape(
                    tmp_reshape_out, [8, 8]
                )  # [4, 8] [0, -1]

                # softmax
                softmax = paddle.nn.Softmax()
                out11 = softmax(out10)
                error_cost = paddle.nn.functional.square_error_cost(
                    out11, label
                )
                loss = paddle.mean(error_cost)
            return main_program, start_program, loss

        main_program, dist_context = parallelizer(make_program, 0)
        ops = main_program.global_block().ops
        cluster = Cluster()
        cluster.gen_default_config_cluster(device_count=2)
        for idx, op in enumerate(ops):
            dist_op = dist_context.get_dist_op_for_program(op)
            op_dist_attr = dist_op.dist_attr
            processes = op_dist_attr.process_mesh.process_ids
            if is_elementwise_op(op.type):
                container = get_distributed_operator_impl_container(
                    "elementwise"
                )
            else:
                container = get_distributed_operator_impl_container(
                    op_dist_attr.impl_type
                )

            dist_impl = container.impls[op_dist_attr.impl_idx]
            dist_op_cost = dist_impl.calc_cost(
                op.attr('op_role'), dist_op, dist_context, cluster
            )
            self.assertTrue(dist_op_cost)

    def test_dist_op_cost_part4(self):
        def make_program():
            main_program = paddle.static.Program()
            start_program = paddle.static.Program()
            with paddle.static.program_guard(main_program, start_program):
                x = paddle.static.data(name='x', shape=[4], dtype='float32')
                x.stop_gradient = True
                label = paddle.static.data(
                    name="label", shape=[8, 1], dtype='float32'
                )
                label.stop_gradient = True
                auto.shard_tensor(
                    x, auto.ProcessMesh([0, 1], dim_names=["x"]), ["x"]
                )
                auto.shard_tensor(
                    label,
                    auto.ProcessMesh([0, 1], dim_names=["x"]),
                    ["x", None],
                )
                # embedding
                tmp = paddle.fluid.layers.fill_constant_batch_size_like(
                    input=x, shape=[4], value=1, dtype='int32'
                )
                embedding = paddle.nn.Embedding(10, 8)
                out = embedding(tmp)
                # row parallel embedding
                for op in main_program.global_block().ops:
                    if op.type == "lookup_table_v2":
                        W = main_program.global_block().vars[op.input("W")[0]]
                        auto.shard_tensor(
                            W,
                            auto.ProcessMesh([0, 1], dim_names=["x"]),
                            ["x", None],
                        )
                out = paddle.transpose(out, [1, 0])  # [8, 2] [-1, 0]

                # mul
                param1 = paddle.create_parameter(
                    [4, 8], paddle.float32
                )  # [2, 8] [0, -1]
                auto.shard_tensor(
                    param1,
                    auto.ProcessMesh([0, 1], dim_names=["x"]),
                    ["x", None],
                )
                param2 = paddle.create_parameter(
                    [8, 8], paddle.float32
                )  # [8, 4] [-1, 0]
                auto.shard_tensor(
                    param2,
                    auto.ProcessMesh([0, 1], dim_names=["x"]),
                    [None, "x"],
                )
                out1 = paddle.matmul(out, param1)  # [8, 8] [-1, -1]
                tmp_param = paddle.create_parameter(
                    [8, 8], paddle.float32
                )  # [8, 8] [-1, -1]
                auto.shard_tensor(
                    tmp_param,
                    auto.ProcessMesh([0, 1], dim_names=["x"]),
                    [None, None],
                )
                tmp_out = paddle.matmul(out1, tmp_param)
                out2 = paddle.matmul(tmp_out, param2)  # [8, 4] [-1, 0]

                out8 = paddle.transpose(out2, [1, 0])  # [4, 8] [0, -1]

                # reshape
                out9 = paddle.reshape(out8, [8, 2, 4])  # [4, 2, 4] [0, -1, -1]
                tmp_reshape_out = paddle.reshape(out9, [8, 4, 2])
                out10 = paddle.reshape(
                    tmp_reshape_out, [8, 8]
                )  # [4, 8] [0, -1]

                # softmax
                softmax = paddle.nn.Softmax()
                out11 = softmax(out10)
                error_cost = paddle.nn.functional.square_error_cost(
                    out11, label
                )
                loss = paddle.mean(error_cost)
            return main_program, start_program, loss

        main_program, dist_context = parallelizer(make_program, 0)
        ops = main_program.global_block().ops
        cluster = Cluster()
        cluster.gen_default_config_cluster(device_count=2)
        for idx, op in enumerate(ops):
            dist_op = dist_context.get_dist_op_for_program(op)
            op_dist_attr = dist_op.dist_attr
            processes = op_dist_attr.process_mesh.process_ids
            if is_elementwise_op(op.type):
                container = get_distributed_operator_impl_container(
                    "elementwise"
                )
            else:
                container = get_distributed_operator_impl_container(
                    op_dist_attr.impl_type
                )

            dist_impl = container.impls[op_dist_attr.impl_idx]
            dist_op_cost = dist_impl.calc_cost(
                op.attr('op_role'), dist_op, dist_context, cluster
            )
            self.assertTrue(dist_op_cost)


if __name__ == "__main__":
    unittest.main()