#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fleet Utils."""
"""distributed operations"""
"""basic collective operations in python"""
"""remote file system"""

import paddle
from ..utils.fs import FS
from paddle.fluid.proto import framework_pb2
from paddle.static import Program
from paddle.fluid import debugger
from google.protobuf import text_format
import paddle.framework as framework
from collections import OrderedDict
from paddle.fluid import core
import subprocess
import os
import numpy as np

__all__ = []


class UtilFactory:
    def _create_util(self, context=None):
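        # Build a UtilBase instance and, when a context dict is provided,
        # wire in its "valid_strategy" and "role_maker" entries.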
        util = UtilBase()
        if context is not None and "valid_strategy" in context:
            util._set_strategy(context["valid_strategy"])
        if context is not None and "role_maker" in context:
            util._set_role_maker(context["role_maker"])
        return util


class UtilBase:
    def __init__(self):
        self.role_maker = None
        self.dist_strategy = None

    def _set_strategy(self, dist_strategy):
        self.dist_strategy = dist_strategy

    def _set_role_maker(self, role_maker):
        self.role_maker = role_maker

    def _set_file_system(self, fs_client):
        assert isinstance(
            fs_client, FS
        ), "fs_client must be an instance of paddle.distributed.fleet.utils.FS"
        self.fs_client = fs_client

    def all_reduce(self, input, mode="sum", comm_world="worker"):
        """
        All-reduce `input` within the specified collection. This is a distributed API.

        Args:
            input (list|numpy.array): The input to be all-reduced within the specified collection.
            mode (str): "sum", "min" or "max".
            comm_world (str, optional): Collection used to execute the all_reduce operation. Supported collections include `worker`, `server` and `all`. The default is `worker`.

        Returns:
            output (numpy.array|None): A numpy array with the same shape as `input`.

        Examples:
            .. code-block:: python

                # Save the following code in `train.py`, and then execute the command `fleetrun --server_num 2 --worker_num 2 train.py`.
                import paddle.distributed.fleet as fleet
                from paddle.distributed.fleet import PaddleCloudRoleMaker
                import sys
                import numpy as np
                import os

                os.environ["PADDLE_WITH_GLOO"] = "2"

                def train():
                    role = PaddleCloudRoleMaker(
                        is_collective=False,
                        init_gloo=True,
                        path="./tmp_gloo")
                    fleet.init(role)

                    if fleet.is_server():
                        input = [1, 2]
                        output = fleet.util.all_reduce(input, "sum", "server")
                        print(output)
                        # [2, 4]
                    elif fleet.is_worker():
                        input = np.array([3, 4])
                        output = fleet.util.all_reduce(input, "sum", "worker")
                        print(output)
                        # [6, 8]
                    output = fleet.util.all_reduce(input, "sum", "all")
                    print(output)
                    # [8, 12]
                if __name__ == "__main__":
                    train()
        """
        return self.role_maker._all_reduce(input, mode, comm_world)

    def barrier(self, comm_world="worker"):
        """
        Barrier within the specified collection.

        Args:
            comm_world (str, optional): Collection used to execute the barrier operation. Supported collections include `worker`, `server` and `all`. The default is `worker`.

        Examples:

            .. code-block:: python

                # Save the following code in `train.py`, and then execute the command `fleetrun --server_num 2 --worker_num 2 train.py`.

                import paddle.distributed.fleet as fleet
                from paddle.distributed.fleet import PaddleCloudRoleMaker
                import sys
                import os

                os.environ["PADDLE_WITH_GLOO"] = "2"

                def train():
                    role = PaddleCloudRoleMaker(
                        is_collective=False,
                        init_gloo=True,
                        path="./tmp_gloo")
                    fleet.init(role)

                    if fleet.is_server():
                        fleet.util.barrier("server")
                        print("all servers arrive here")
                    elif fleet.is_worker():
                        fleet.util.barrier("worker")
                        print("all workers arrive here")
                    fleet.util.barrier("all")
                    print("all servers and workers arrive here")

                if __name__ == "__main__":
                    train()
        """
        self.role_maker._barrier(comm_world)

    def all_gather(self, input, comm_world="worker"):
        """
        All-gather `input` within the specified collection.

        Args:
            input (int|float): The input value to be gathered from each member of the specified collection.
            comm_world (str, optional): Collection used to execute the all_gather operation. Supported collections include `worker`, `server` and `all`. The default is `worker`.

        Returns:
            output (List): A list of gathered values.

        Examples:

            .. code-block:: python

                # Save the following code in `train.py`, and then execute the command `fleetrun --server_num 2 --worker_num 2 train.py`.
                import paddle.distributed.fleet as fleet
                from paddle.distributed.fleet import PaddleCloudRoleMaker
                import sys
                import os

                os.environ["PADDLE_WITH_GLOO"] = "2"

                def train():
                    role = PaddleCloudRoleMaker(
                        is_collective=False,
                        init_gloo=True,
                        path="./tmp_gloo")
                    fleet.init(role)

                    if fleet.is_server():
                        input = fleet.server_index()
                        output = fleet.util.all_gather(input, "server")
                        print(output)
                        # output = [0, 1]
                    elif fleet.is_worker():
                        input = fleet.worker_index()
                        output = fleet.util.all_gather(input, "worker")
                        # output = [0, 1]
                        print(output)
                    output = fleet.util.all_gather(input, "all")
                    print(output)
                    # output = [0, 1, 0, 1]

                if __name__ == "__main__":
                    train()
        """

        return self.role_maker._all_gather(input, comm_world)

    def _broadcast(self):
        pass

    def _scatter(self):
        pass

    def get_heter_file_shard(self, files):
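        # Shard `files` across heterogeneous trainers. Judging from the index
        # arithmetic below, heter workers appear to be numbered after the regular
        # trainers, so the local shard id is worker_index() - worker_num().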
        if not isinstance(files, list):
            raise TypeError("files should be a list of files to be read.")
        trainers = self.role_maker._worker_num()
        trainer_id = self.role_maker._worker_index() - trainers
        remainder = len(files) % trainers
        blocksize = int(len(files) / trainers)

        blocks = [blocksize] * trainers
        for i in range(remainder):
            blocks[i] += 1

        trainer_files = [[]] * trainers
        begin = 0
        for i in range(trainers):
            trainer_files[i] = files[begin : begin + blocks[i]]
            begin += blocks[i]

        return trainer_files[trainer_id]

    def get_file_shard(self, files):
        """
        Split files before distributed training, and return the file list assigned to the current trainer.

        .. code-block:: text

            example 1: files is [a, b, c, d, e] and trainer_num = 2, then trainer
                    0 gets [a, b, c] and trainer 1 gets [d, e].
            example 2: files is [a, b] and trainer_num = 3, then trainer 0 gets
                    [a], trainer 1 gets [b], trainer 2 gets [].

        Args:
            files(list): List of files to be read.

        Returns:
            List: Files belonging to this worker.

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                from paddle.distributed.fleet import UserDefinedRoleMaker

                role = UserDefinedRoleMaker(
                    is_collective=False,
                    init_gloo=False,
                    current_id=0,
                    role=fleet.Role.WORKER,
                    worker_endpoints=["127.0.0.1:6003", "127.0.0.1:6004"],
                    server_endpoints=["127.0.0.1:6001", "127.0.0.1:6002"])
                fleet.init(role)

                files = fleet.util.get_file_shard(["file1", "file2", "file3"])
                print(files)
                # files = ["file1", "file2"]
        """
        if not isinstance(files, list):
            raise TypeError("files should be a list of files to be read.")

        trainer_id = self.role_maker._worker_index()
        trainers = self.role_maker._worker_num()

        remainder = len(files) % trainers
        blocksize = int(len(files) / trainers)

        blocks = [blocksize] * trainers
        for i in range(remainder):
            blocks[i] += 1

        trainer_files = [[]] * trainers
        begin = 0
        for i in range(trainers):
            trainer_files[i] = files[begin : begin + blocks[i]]
            begin += blocks[i]

        return trainer_files[trainer_id]

    def print_on_rank(self, message, rank_id):
        """
        Worker of rank `rank_id` prints the given message.

        Args:
            message(str): Log to be printed.
            rank_id(int): trainer id.

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                from paddle.distributed.fleet import UserDefinedRoleMaker

                role = UserDefinedRoleMaker(
                    is_collective=False,
                    init_gloo=False,
                    current_id=0,
                    role=fleet.Role.WORKER,
                    worker_endpoints=["127.0.0.1:6003", "127.0.0.1:6004"],
                    server_endpoints=["127.0.0.1:6001", "127.0.0.1:6002"])
                fleet.init(role)

                fleet.util.print_on_rank("I'm worker 0", 0)
        """
        if self.role_maker._worker_index() != rank_id:
            return
        print(message)

    def _save_program(self, program, model_filename='__model__', is_text=False):
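        # Serialize `program` to `model_filename`: a human-readable text dump when
        # is_text is True, otherwise the binary protobuf representation.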
        if is_text:
            with open(model_filename, "w") as f:
                f.write(str(program))
        else:
            with open(model_filename, "wb") as f:
                f.write(program.desc.serialize_to_string())

    def _load_program(self, path, is_text):
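        # Counterpart of _save_program: rebuild a Program from either a binary
        # protobuf file or a text dump parsed with protobuf's text_format.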
        def load_program_binary(path):
            """load program from binary string file"""
            with open(path, "rb") as f:
                program_desc_str = f.read()
            return Program.parse_from_string(program_desc_str)

        def load_program_text(path):
            """load program from human-readable text file"""
            with open(path, "r") as f:
                program_desc_text = f.read()

            prog_desc = framework_pb2.ProgramDesc()
            text_format.Merge(program_desc_text, prog_desc)
            return Program.parse_from_string(prog_desc.SerializeToString())

        if is_text:
            return load_program_text(path)
        else:
            return load_program_binary(path)

    def _program_type_trans(self, prog_dir, prog_fn, is_text):
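        # Convert a dumped program between text and binary formats; the converted
        # copy is written next to the original as `<prog_fn>.bin` (from text) or
        # `<prog_fn>.pbtxt` (from binary), and its file name is returned.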
        prog = self._load_program(os.path.join(prog_dir, prog_fn), is_text)
        prog_out_fn = prog_fn + ".bin" if is_text else prog_fn + ".pbtxt"
        self._save_program(
            prog, os.path.join(prog_dir, prog_out_fn), 1 - is_text
        )
        return prog_out_fn

    def _visualize_graphviz(self, program, output_dir, output_filename):
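        # Draw the program's global block as a .dot file and render it to PDF;
        # assumes the graphviz `dot` executable is available on PATH.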
        block = program.global_block()
        dot_path = os.path.join(output_dir, output_filename + '.dot')
        pdf_path = os.path.join(output_dir, output_filename + '.pdf')
        debugger.draw_block_graphviz(block, path=dot_path)
        cmd = ["dot", "-Tpdf", dot_path, "-o", pdf_path]
        p = subprocess.Popen(
            cmd,
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        p.wait()

    def _proto_check(self, config):
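        # Verify that every persistable variable of the pruned (inference) program
        # also exists in the train program with the same shape and dtype.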
        train_prog = self._load_program(
            config.train_prog_path, config.is_text_train_program
        )
        pruned_prog = self._load_program(
            config.pruned_prog_path, config.is_text_pruned_program
        )

        is_match = True

        pruned_vars = [
            (v.name, v)
            for v in pruned_prog.list_vars()
            if paddle.static.io.is_persistable(v)
        ]
        pruned_vars = OrderedDict(pruned_vars)
        pruned_vars_name = [name for name in pruned_vars]
        print("persistable vars in pruned program: {}".format(pruned_vars_name))

        # feed and fetch ops are added to the pruned program during pruning, so they need not exist in the train program
        feed_fetch_type_list = [
            core.VarDesc.VarType.FEED_MINIBATCH,
            core.VarDesc.VarType.FETCH_LIST,
        ]

        for var_name in pruned_vars:
            var = pruned_vars[var_name]
            # feed and fetch ops are added to the pruned program during pruning, so they need not exist in the train program
            if var.type in feed_fetch_type_list:
                break
            try:
                train_prog_var = train_prog.global_block().var(var_name)
            except ValueError as e:
                print(
                    "Cannot find variable '%s' in train program, please check pruning."
                    % var_name
                )
                is_match = False
                continue
            if (
                var.shape != train_prog_var.shape
                or var.dtype != train_prog_var.dtype
            ):
                print(
                    "variable: {} does not match. in pruned program shape: {} dtype: {}, in train program shape: {} dtype: {}".format(
                        var_name,
                        var.shape,
                        var.dtype,
                        train_prog_var.shape,
                        train_prog_var.dtype,
                    )
                )
                is_match = False
        return is_match

    def _params_check(self, config):
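        # Load the dumped inference program and its saved parameters, check that
        # parameter shapes and the feed/fetch configuration are consistent, then
        # run one forward pass with random data or data from feeded_vars_filelist.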
        def feed_gen(batch_size, feeded_vars_dims, feeded_vars_filelist):
            def reader(batch_size, fn, dim):
                data = []
                if isinstance(dim, list) or isinstance(dim, tuple):
                    shape = list(dim)
                    _temp = 1
                    for x in dim:
                        _temp = _temp * x
                    dim = _temp
                else:
                    shape = [dim]

                shape = [batch_size] + shape
                dim = dim * batch_size

                for line in open(fn, 'r'):
                    fields = line.strip().split(' ')
                    fields = [float(d) for d in fields]
                    while len(fields) >= dim:
                        tmp = fields[:dim]
                        fields = fields[dim:]
                        data.append(np.array(tmp).reshape(shape))
                return data

            batch_feed = []
            for i, fn in enumerate(feeded_vars_filelist):
                batch_feed.append(reader(batch_size, fn, feeded_vars_dims[i]))
            return batch_feed

        prog = self._load_program(
            os.path.join(config.dump_model_dir, config.dump_program_filename),
            config.is_text_dump_program,
        )
        if config.is_text_dump_program:
            model_filename = self._program_type_trans(
                config.dump_model_dir,
                config.dump_program_filename,
                config.is_text_dump_program,
            )
        else:
            # the dumped program is already binary; assume it can be loaded directly
            model_filename = config.dump_program_filename

        saved_params = [
            v for v in prog.list_vars() if paddle.static.io.is_persistable(v)
        ]
        print(
            "persistable vars in dump program: {}".format(
                [v.name for v in saved_params]
            )
        )

        def check_not_expected_ops(prog, not_expected_op_types):
            op_types_set = set()
            for op in prog.global_block().ops:
                if (
                    op.type in not_expected_op_types
                    and op.type not in op_types_set
                ):
                    op_types_set.add(op.type)
            return op_types_set

        not_expected_op_types = check_not_expected_ops(prog, ["lookup_table"])
        if len(not_expected_op_types) > 0:
            print(
                "found op type '{}' in program, please check if your program is pruned correctly!".format(
                    list(not_expected_op_types)
                )
            )
            return False

        place = framework.CPUPlace()
        exe = paddle.static.Executor(place)
        scope = paddle.static.Scope()
        with paddle.static.scope_guard(scope):
            (
                inference_program,
                feed_target_names,
                fetch_targets,
            ) = paddle.fluid.io.load_inference_model(
                config.dump_model_dir,
                exe,
                model_filename=model_filename,
                params_filename=config.save_params_filename,
            )

            # check program vars and saved vars shape
            orig_para_shape = {
                each_var.name: tuple(each_var.desc.shape())
                for each_var in saved_params
            }
            for each_var in saved_params:
                var_temp = paddle.static.global_scope().find_var(each_var.name)
                assert var_temp is not None, (
                    "cannot find var: " + each_var.name
                )
                new_shape = (np.array(var_temp.get_tensor())).shape
                assert each_var.name in orig_para_shape, (
                    each_var.name + " MUST be in var list"
                )
                orig_shape = orig_para_shape.get(each_var.name)
                if new_shape != orig_shape:
                    raise RuntimeError(
                        "Shape not matching: the Program requires a parameter with a shape of ({}), "
                        "while the loaded parameter (namely [ {} ]) has a shape of ({}).".format(
                            orig_shape, each_var.name, new_shape
                        )
                    )

            # check feed/fetch vars in program and config
            feed_config = config.feed_config
            fetch_config = config.fetch_config
            fetch_targets_names = [v.name for v in fetch_targets]
            if not feed_target_names:
                print("warning! no feed targets in program.")
            if not fetch_targets_names:
                print("warning! no fetch targets in program.")
            fetch_list = fetch_targets
            feed_name_list = feed_target_names
            if (
                feed_config.feeded_vars_names is not None
                and feed_target_names != feed_config.feeded_vars_names
            ):
                print(
                    "warning! feed vars in program and config differ: feed in program: {}, feed in config: {}.".format(
                        feed_target_names, feed_config.feeded_vars_names
                    )
                )
                feed_name_list = feed_config.feeded_vars_names
                # remove feed op in inference_program. new feed op will be added in exe.run
                global_block = inference_program.global_block()
                need_to_remove_op_index = []
                for i, op in enumerate(global_block.ops):
                    op.desc.set_is_target(False)
                    if op.type == "feed":  # only remove feed op here
                        need_to_remove_op_index.append(i)
                for index in need_to_remove_op_index[::-1]:
                    global_block._remove_op(index)
            if (
                fetch_config.fetch_vars_names is not None
                and fetch_targets_names != fetch_config.fetch_vars_names
            ):
                print(
                    "warning! fetch vars in program and config differ: fetch in program: {}, fetch in config: {}.".format(
                        fetch_targets_names, fetch_config.fetch_vars_names
                    )
                )
                fetch_list = [
                    inference_program.global_block().var(i)
                    for i in fetch_config.fetch_vars_names
                ]
                # remove fetch op in inference_program. new fetch op will be added in exe.run
                global_block = inference_program.global_block()
                need_to_remove_op_index = []
                for i, op in enumerate(global_block.ops):
                    op.desc.set_is_target(False)
                    if op.type == "fetch":  # only remove fetch op here
                        need_to_remove_op_index.append(i)
                for index in need_to_remove_op_index[::-1]:
                    global_block._remove_op(index)

            # if fetch_list have lod tensor
            return_numpy = all([v.lod_level == 0 for v in fetch_list])

            # try dump fetch_targets
            feed_tensors = []
            assert (
                len(feed_config.feeded_vars_names)
                == len(feed_config.feeded_vars_dims)
                == len(feed_config.feeded_vars_types)
            )
            # check program vars and feed tensor shape in config
            for i in range(len(feed_config.feeded_vars_names)):
                var = inference_program.global_block().var(
                    feed_config.feeded_vars_names[i]
                )
                if not isinstance(
                    feed_config.feeded_vars_dims[i], (list, tuple)
                ):
                    tensor_shape = (feed_config.feeded_vars_dims[i],)
                else:
                    tensor_shape = tuple(feed_config.feeded_vars_dims[i])
                feed_config.feeded_vars_dims[i] = tensor_shape
                var_shape = var.shape[1:]
                if tensor_shape != var_shape:
                    raise RuntimeError(
                        "feed variable '{}' shape does not match. infer program shape: {}, feed tensor shape: {}".format(
                            feed_config.feeded_vars_names[i],
                            var_shape,
                            tensor_shape,
                        )
                    )

            if not feed_config.feeded_vars_filelist:
                print("generate random feed vars.")
                for i in range(len(feed_config.feeded_vars_names)):
                    var = inference_program.global_block().var(
                        feed_config.feeded_vars_names[i]
                    )
                    # create fake feed tensor. if lod_level > 1, should create_lod_tensor()
                    if var.lod_level == 0:
                        feed_tensors.append(
                            np.array(
                                np.random.random(
                                    tuple(
                                        [config.batch_size]
                                        + list(feed_config.feeded_vars_dims[i])
                                    )
                                ),
                                dtype=feed_config.feeded_vars_types[i],
                            )
                        )
                    elif var.lod_level == 1:
                        t = np.array(
                            np.random.random(
                                tuple(
                                    [config.batch_size]
                                    + list(feed_config.feeded_vars_dims[i])
                                )
                            ),
                            dtype=feed_config.feeded_vars_types[i],
                        )
                        feed_tensors.append(
                            paddle.fluid.create_lod_tensor(
                                t, [[1] * config.batch_size], place
                            )
                        )
                    else:
                        raise RuntimeError(
                            "vars with lod_level >= 2 is not supported now in this infer program check tool."
                        )
                results = exe.run(
                    inference_program,
                    feed={
                        name: feed_tensors[i]
                        for i, name in enumerate(feed_name_list)
                    },
                    fetch_list=fetch_list,
                    return_numpy=return_numpy,
                )
            else:
                print(
                    "load feed vars from files: {}.".format(
                        feed_config.feeded_vars_filelist
                    )
                )
                feed_vars = [
                    inference_program.global_block().var(
                        feed_config.feeded_vars_names[i]
                    )
                    for i in range(len(feed_config.feeded_vars_names))
                ]
                feeder = paddle.fluid.DataFeeder(
                    feed_list=feed_vars, place=place
                )
                batch_feed = feed_gen(
                    config.batch_size,
                    feed_config.feeded_vars_dims,
                    feed_config.feeded_vars_filelist,
                )
                slots = [batch_feed]
                results = exe.run(
                    inference_program,
                    feed=feeder.feed(slots),
                    fetch_list=fetch_list,
                    return_numpy=return_numpy,
                )
            for i, v in enumerate(fetch_list):
                print("fetch_targets name: %s" % v.name)
                print("fetch_targets: {}".format(results[i]))
            return results