engine.py 23.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import logging
from collections import defaultdict

import paddle
20
import paddle.utils as utils
21 22
import paddle.distributed.auto_parallel as auto

23
from paddle import fluid, static
24
from paddle.io import Dataset
25
from paddle.metric import Metric
26
from paddle.static import InputSpec
27
from paddle.fluid import core
28
from paddle.fluid import program_guard
29
from paddle.fluid.layers.utils import flatten
30
from paddle.fluid.executor import global_scope, _to_name_str
31
from paddle.fluid.backward import append_backward
32
from paddle.fluid.framework import Operator
33 34
from paddle.fluid.framework import _current_expected_place as _get_device
from paddle.fluid.dygraph.parallel import ParallelEnv
35
from paddle.distributed import fleet
36
from paddle.distributed.utils import get_logger
37
from paddle.distributed.passes import new_pass, PassContext
38

39
# from .cluster import Cluster, get_default_cluster
40 41
from .planner_v2 import Planner
from .parallelizer_v2 import Parallelizer
42 43 44 45 46
from .dist_op import DistributedOperator
from .dist_saver import DistributedSaver
from .dist_loader import NonIterableGeneratorLoader
from .utils import make_data_unshard, set_grad_var_shape
from .utils import print_program_with_dist_attr, to_list
47
from .process_group import new_process_group, get_all_process_groups, get_world_process_group
48
from .dist_context import DistributedContext, get_default_distributed_context
49 50 51


class Engine:
    def __init__(self,
                 model=None,
                 inputs_spec=None,
                 labels_spec=None,
                 cluster=None,
                 strategy=None):
        """Create an auto-parallel engine around a serial model.

        Args:
            model: network called as ``model(*inputs)`` while the programs
                are built in ``prepare()``.
            inputs_spec: ``InputSpec`` or list of ``InputSpec`` describing the
                model inputs; every spec must carry a name.
            labels_spec: ``InputSpec`` or list of ``InputSpec`` describing the
                labels, or ``None``.
            cluster: forwarded to each mode's ``DistributedContext``.
            strategy: ``fleet.DistributedStrategy``; a default one is created
                when ``None``.
        """
        self.model = model
        self.inputs_spec = self._validate_spec(inputs_spec)
        self.labels_spec = self._validate_spec(labels_spec)
        self.cluster = cluster
        # if self.cluster is None:
        #     self.cluster = get_default_cluster()
        self.strategy = strategy
        if self.strategy is None:
            self.strategy = fleet.DistributedStrategy()

        self._executor = None
        self._cur_rank = paddle.distributed.get_rank()
        self._nranks = paddle.distributed.get_world_size()
        self._saver = DistributedSaver()
        self._logger = get_logger(logging.INFO)

        # Set to 'serial' or 'dp' by `_set_data_parallel`.
        self._default_strategy = None
        # Current run mode ('train' / 'eval' / 'predict'). Initialized here so
        # the `mode` property and `save()` / `load()` do not raise
        # AttributeError before `fit()` / `evaluate()` / `predict()` ran.
        self._mode = None

        self._orig_main_prog = static.default_main_program()
        self._orig_startup_prog = static.default_startup_program()
        self._orig_dist_context = get_default_distributed_context()
        # Per-mode state below is keyed by 'train' / 'eval' / 'predict'.
        self._dist_contexts = {}
        self._serial_main_progs = {}
        self._serial_startup_progs = {}
        self._dist_main_progs = defaultdict(dict)  # dist main programs
        self._dist_startup_progs = defaultdict(dict)  # dist startup programs
        self._feed_vars = {}
        self._fetch_vars = {}
        self._planners = {}

    def prepare(self,
                optimizer=None,
                loss=None,
                gradient_scale=True,
                metrics=None,
                all_ranks=False):
        """Validate the training components, then build, plan and parallelize
        the programs of every mode ('train', 'eval', 'predict').

        Args:
            optimizer: ``paddle.optimizer.Optimizer`` or
                ``paddle.fluid.optimizer.Optimizer`` instance, or ``None``.
            loss: ``paddle.nn.Layer`` or any callable, or ``None``.
            gradient_scale: copied onto every mode's ``DistributedContext``.
            metrics: ``Metric`` or list of ``Metric``, or ``None``.
            all_ranks: when true, parallelize for all ranks instead of only
                the current one.

        Raises:
            TypeError: ``optimizer`` or ``loss`` has an unsupported type.
            AssertionError: an entry of ``metrics`` is not a ``Metric``.
        """
        if optimizer and not isinstance(
                optimizer, (paddle.optimizer.Optimizer,
                            paddle.fluid.optimizer.Optimizer)):
            raise TypeError(
                "'optimizer' must be object of class "
                "`paddle.optimizer.Optimizer` or "
                "`paddle.fluid.optimizer.Optimizer`.")
        self._optimizer = optimizer

        if loss and not (isinstance(loss, paddle.nn.Layer) or callable(loss)):
            raise TypeError(
                "'loss' must be sub classes of `paddle.nn.Layer` or any callable function."
            )
        self._loss = loss

        metric_list = to_list(metrics or [])
        for metric in metric_list:
            assert isinstance(metric, Metric), \
                "{} is not sub class of Metric".format(
                    metric.__class__.__name__)
        self._metrics = metric_list
        self._gradient_scale = gradient_scale

        self._planned_mode = None
        self._modes = ['train', 'eval', 'predict']
        self._build()

        # Plan every mode first: the first planned mode becomes the reference
        # whose op dist attrs later modes copy (see `_plan`).
        for mode in self._modes:
            self._plan(mode)

        # Then parallelize each mode and set up its communication groups and
        # startup program.
        for mode in self._modes:
            self._parallel(mode, all_ranks)
            self._initialize(mode)

    def _build(self):
        """Build the serial programs and a ``DistributedContext`` per mode.

        For every mode in ``self._modes`` whose serial main program does not
        exist yet:

        1. Clone the original default main/startup programs.
        2. Create feed layers from the input/label specs and trace the model.
        3. Except in 'predict' mode, attach the loss and metric outputs.

        The result is stored in ``self._dist_contexts[mode]``.
        """
        for mode in self._modes:
            # Skip only this mode when it is already built. This used to be a
            # `return`, which silently skipped every later mode as well.
            if self._serial_main_progs.get(mode, None) is not None:
                continue

            losses = []
            metrics = []
            serial_main_prog = self._orig_main_prog.clone()
            serial_startup_prog = self._orig_startup_prog.clone()
            with static.program_guard(serial_main_prog, serial_startup_prog), \
                utils.unique_name.guard():
                # Treat a missing spec list as empty, like `labels_spec`.
                inputs_spec = self.inputs_spec if self.inputs_spec else []
                labels_spec = self.labels_spec if self.labels_spec else []
                inputs = [s._create_feed_layer() for s in inputs_spec]
                labels = [s._create_feed_layer() for s in labels_spec]
                outputs = to_list(self.model(*inputs))
                if mode != "predict" and self._loss:
                    losses = to_list(self._loss(*(outputs + labels)))

                if mode != "predict":
                    for metric in self._metrics:
                        metrics.extend(
                            to_list(metric.compute(*(outputs + labels))))

            default_ctx = get_default_distributed_context()
            if not default_ctx.has_annotation or self._default_strategy:
                # We build the world process group because the data parallel
                # needs all ranks by default.
                new_process_group(list(range(self._nranks)))
                default_ctx.data_parallel = True

            feed_vars = {"inputs": inputs, "labels": labels}
            fetch_vars = {
                "outputs": flatten(outputs),
                "loss": losses,
                "metrics": metrics
            }

            self._dist_contexts[mode] = DistributedContext(
                serial_main_prog, serial_startup_prog, self._optimizer, losses,
                feed_vars, fetch_vars, self.cluster, self.strategy)
            self._dist_contexts[mode].gradient_scale = self._gradient_scale

    def _plan(self, mode):
        """Run the planner for ``mode``.

        The first mode planned becomes the reference mode. Each later mode
        first copies the reference mode's op dist attrs via
        ``_init_dist_context`` before planning.
        """
        if self._planned_mode is not None:
            self._init_dist_context(mode)
        else:
            self._planned_mode = mode

        planner = Planner(mode, self._dist_contexts[mode])
        self._planners[mode] = planner
        planner.plan()

    def _parallel(self, mode, all_ranks):
        """Parallelize the program of ``mode`` based on the planner's results.

        For now the planner's completer has to be handed to the parallelizer,
        because it may be used to complete the annotation of the backward and
        update parts.

        Args:
            mode: 'train', 'eval' or 'predict'.
            all_ranks: if true call ``parallel_all()``; otherwise call
                ``parallel(self._cur_rank)``.
        """
        completer = self._planners[mode].completer
        parallelizer = Parallelizer(mode, completer, self._dist_contexts[mode])
        if all_ranks:
            parallelizer.parallel_all()
        else:
            parallelizer.parallel(self._cur_rank)

    def _init_dist_context(self, mode):
204
        # Init dist_context['mode'] with the first planned dist_context
205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221
        # to guarantee that train/eval/predict mode have same parallel strategy
        dist_context = self._dist_contexts[mode]
        origin_main_prog = dist_context._original_serial_main_program
        ref_mode = self._planned_mode
        ref_dist_context = self._dist_contexts[ref_mode]
        ref_origin_main_prog = ref_dist_context._original_serial_main_program
        ref_blocks = ref_origin_main_prog.blocks
        for ib, block in enumerate(origin_main_prog.blocks):
            for iop, op in enumerate(block.ops):
                ref_op = ref_blocks[ib].ops[iop]
                assert op.type == ref_op.type, \
                    "'{}' mode op '{}' is different with '{}' op '{}'. ".format(mode, op.type, ref_mode, ref_op.type)
                ref_op_dist_attr = ref_dist_context.get_op_dist_attr_for_program(
                    ref_op)
                dist_context.set_op_dist_attr_for_program(op, ref_op_dist_attr)

    def _initialize(self, mode):
        """Finish preparing ``mode`` after it has been parallelized.

        Steps:

        1. Cache the serial/distributed programs and feed/fetch vars from the
           mode's distributed context.
        2. With more than one rank, instantiate the process groups that
           include the current rank.
        3. Create the executor on first use.
        4. Run the part of this mode's distributed startup program whose
           variables are not yet initialized in the global scope.
        """
        dist_context = self._dist_contexts[mode]
        # Get the current content from the distributed context
        self._serial_main_progs[mode] = dist_context.serial_main_program
        self._serial_startup_progs[mode] = dist_context.serial_startup_program
        self._dist_main_progs[mode] = dist_context.dist_main_programs
        self._dist_startup_progs[mode] = dist_context.dist_startup_programs
        self._feed_vars[mode] = dist_context.serial_feed_vars
        self._fetch_vars[mode] = dist_context.serial_fetch_vars

        if self._nranks > 1:
            # Traverse different rank programs and traverse each op of them,
            # instantiate communication by process_mapping.
            for process_group in get_all_process_groups():
                if self._cur_rank in process_group.ranks:
                    process_group.instantiate()

        # initialize
        self._place = _get_device()
        if isinstance(self._place, fluid.CUDAPlace):
            self._place = fluid.CUDAPlace(ParallelEnv().dev_id)
        if self._executor is None:
            self._executor = paddle.static.Executor(self._place)

        # Initialize startup variables for EVERY mode. This block used to sit
        # inside `if self._executor is None`, so only the first mode's startup
        # program was ever run. Variables already initialized in the global
        # scope (e.g. by an earlier mode) are skipped, so nothing is
        # re-initialized.
        dist_startup_prog = self._dist_startup_progs[mode][self._cur_rank]
        uninitialized = []
        for var in dist_startup_prog.list_vars():
            scope_var = global_scope().find_var(var.name)
            if scope_var and scope_var.get_tensor()._is_initialized():
                continue
            uninitialized.append(var)
        if uninitialized:
            prune_startup_prog = dist_startup_prog._prune(uninitialized)
            self._executor.run(prune_startup_prog)

    def fit(self,
            train_data,
            batch_size=1,
            epochs=1,
            fetches=None,
            steps_per_epoch=None,
            use_program_cache=False,
            return_numpy=True):
        """Train the prepared 'train' program on ``train_data``.

        Args:
            train_data: dataset wrapped by ``_create_dataloader``.
            batch_size, epochs, steps_per_epoch: forwarded to the dataloader.
            fetches: extra variables to fetch every step, either a list
                (logged under the variable name) or a dict
                ``{display_name: variable}``.
            use_program_cache, return_numpy: forwarded to ``Executor.run``.

        Every step logs a dict containing ``epoch``, ``step``, ``train_loss``
        (first element of the first loss fetch) and ``train_<name>`` for each
        user fetch.
        """
        # TODO: callbacks
        # TODO: evaluate after training
        self.mode = 'train'
        assert self.mode in self._dist_main_progs, \
            "train model is not ready, please call `engine.prepare()` first."
        train_dataloader = self._create_dataloader(train_data, batch_size,
                                                   epochs, steps_per_epoch)

        usr_fetch = self._to_map_fetch(fetches)
        fetch_loss = self._inner_fetch(self.fetch_vars["loss"])
        fetch_list, fetch_map = self._fetch_map(fetch_loss, usr_fetch)
        num_loss = len(fetch_loss)
        user_fetch_list = fetch_list[num_loss:]

        for epoch in range(epochs):
            train_logs = {"epoch": epoch}
            for step, _ in enumerate(train_dataloader):
                outs = self._executor.run(self.main_program,
                                          fetch_list=fetch_list,
                                          use_program_cache=use_program_cache,
                                          return_numpy=return_numpy)
                train_logs["step"] = step
                # inner fetches
                if fetch_loss:
                    train_logs["train_loss"] = outs[0][0]
                # user fetches
                for var_name, out in zip(user_fetch_list, outs[num_loss:]):
                    train_logs["train_" + fetch_map[var_name]] = out[0]
                self._logger.info(train_logs)

    def evaluate(self,
                 eval_data,
                 batch_size=1,
                 fetches=None,
                 use_program_cache=False,
                 return_numpy=True):
        """Run the prepared 'eval' program over ``eval_data`` and log results.

        Args:
            eval_data: dataset wrapped by ``_create_dataloader``.
            batch_size: forwarded to the dataloader.
            fetches: extra variables to fetch every step, either a list or a
                dict ``{display_name: variable}``.
            use_program_cache, return_numpy: forwarded to ``Executor.run``.

        Every step logs a dict containing ``step``, ``eval_loss``,
        ``eval_<metric name>`` for each accumulated metric result and
        ``eval_<name>`` for each user fetch. Metrics are reset before the
        first step.
        """
        self.mode = 'eval'
        assert self.mode in self._dist_main_progs, \
            "eval model is not ready, please call `engine.prepare()` first."
        eval_dataloader = self._create_dataloader(eval_data, batch_size)

        usr_fetch = self._to_map_fetch(fetches)
        fetch_loss = self._inner_fetch(self.fetch_vars["loss"])
        fetch_metrics = self._inner_fetch(self.fetch_vars["metrics"])
        inner_fetch = dict(fetch_loss, **fetch_metrics)
        fetch_list, fetch_map = self._fetch_map(inner_fetch, usr_fetch)

        # Start each evaluation from a clean metric state; otherwise the
        # statistics accumulated by a previous `evaluate()` call leak in.
        for metric in self._metrics:
            metric.reset()

        for step, _ in enumerate(eval_dataloader):
            eval_logs = {"step": step}
            outs = self._executor.run(self.main_program,
                                      fetch_list=fetch_list,
                                      use_program_cache=use_program_cache,
                                      return_numpy=return_numpy)
            # inner fetches
            if fetch_loss:
                eval_logs["eval_loss"] = outs[0]
            # Metric
            if fetch_metrics:
                # NOTE(review): every metric is updated with the outputs of
                # *all* metrics' `compute()`. This is only correct when a
                # single metric is configured; per-metric output counts are
                # not tracked here.
                metric_out = outs[len(fetch_loss):len(inner_fetch)]
                for metric in self._metrics:
                    metric.update(*metric_out)
                    results = to_list(metric.accumulate())
                    # `name()` may return a single str instead of a list;
                    # normalize it so a str name is not indexed per character.
                    names = to_list(metric.name())
                    for i, res in enumerate(results):
                        eval_logs["eval_" + names[i]] = res
            # usr fetches
            usr_out = outs[len(inner_fetch):]
            usr_fetch_list = fetch_list[len(inner_fetch):]
            for i, out in enumerate(usr_out):
                eval_logs["eval_" + fetch_map[usr_fetch_list[i]]] = out
            # logger
            self._logger.info(eval_logs)

    def predict(self,
                test_data,
                batch_size=1,
                fetches=None,
                use_program_cache=False,
                return_numpy=True):
        """Run the prepared 'predict' program over ``test_data``.

        Args:
            test_data: dataset wrapped by ``_create_dataloader``.
            batch_size: forwarded to the dataloader.
            fetches: extra variables to fetch every step, either a list or a
                dict ``{display_name: variable}``.
            use_program_cache, return_numpy: forwarded to ``Executor.run``.

        Returns:
            One entry per step: the fetched model-output values of that step
            (user fetches excluded).
        """
        self.mode = 'predict'
        assert self.mode in self._dist_main_progs, \
            "predict model is not ready, please call `engine.prepare()` first."
        test_dataloader = self._create_dataloader(test_data, batch_size)

        usr_fetch = self._to_map_fetch(fetches)
        fetch_outputs = self._inner_fetch(self.fetch_vars["outputs"])
        fetch_list, fetch_map = self._fetch_map(fetch_outputs, usr_fetch)
        num_outputs = len(fetch_outputs)

        outputs = []
        for step, _ in enumerate(test_dataloader):
            outs = self._executor.run(self.main_program,
                                      fetch_list=fetch_list,
                                      use_program_cache=use_program_cache,
                                      return_numpy=return_numpy)
            outputs.append(outs[:num_outputs])
            predict_logs = {"step": step}
            for var_name, out in zip(fetch_list, outs):
                predict_logs["pred_" + fetch_map[var_name]] = out[0]
            self._logger.info(predict_logs)

        return outputs

    def _local_var(self, var):
        """Tell whether ``var`` (a variable or a name) is defined in the global
        block of the current rank's main program for the current mode."""
        name = _to_name_str(var)
        local_vars = self.main_program.global_block().vars
        return name in local_vars

    def _to_map_fetch(self, fetches):
        if not fetches:
            return {}
        if isinstance(fetches, dict):
            fetch_var_names = list(map(_to_name_str, fetches.values()))
            usr_fetches = dict(zip(fetch_var_names, list(fetches.keys())))
        elif isinstance(fetches, list):
            fetch_var_names = list(map(_to_name_str, fetches))
            usr_fetches = dict(zip(fetch_var_names, fetch_var_names))
        return dict(filter(lambda x: self._local_var(x[0]),
                           usr_fetches.items()))

    def _inner_fetch(self, fetch_vars):
        fetch_list = list(
            map(lambda x: x.name, list(filter(self._local_var, fetch_vars))))
        inner_fetches = dict(zip(fetch_list, fetch_list))
        return inner_fetches

    def _fetch_map(self, inner_fetch, usr_fetch):
        # replace inner fetch name if usr set for it
        for iname in inner_fetch:
            if iname in usr_fetch:
                inner_fetch[iname] = usr_fetch[iname]
                usr_fetch.pop(iname)
        fetches = dict(inner_fetch, **usr_fetch)
        return list(fetches.keys()), fetches
400

401 402 403 404
    def _create_dataloader(self,
                           dataset,
                           batch_size,
                           epochs=1,
                           steps_per_epoch=None):
        """Create a ``NonIterableGeneratorLoader`` for the current mode.

        The loader's ops are created at the end of the current rank's main
        program and then moved to its front. If the program already starts with
        a 'create_py_reader' op (a previous fit/evaluate/predict call), its
        first three ops are removed first. NOTE(review): this assumes the
        loader always contributes exactly three leading ops — TODO confirm.

        Returns:
            The created dataloader.
        """
        mode = self.mode
        dist_main_prog = self._dist_main_progs[mode][self._cur_rank]
        dist_startup_prog = self._dist_startup_progs[mode][self._cur_rank]
        dist_context = self._dist_contexts[mode]
        dist_main_block = dist_main_prog.global_block()

        # NOTE: Get feed_list from dist_program, then insert dataloader op
        # with sharded var shape. Because predict_program does not contain
        # labels var, so we will filter dataset's value with length of feed_list.
        mode_feeds = self._feed_vars[mode]
        block_vars = dist_main_block.vars
        feed_list = [
            block_vars[var.name]
            for var in mode_feeds["inputs"] + mode_feeds["labels"]
            if var.name in block_vars
        ]
        dp_world_size, dp_rank = self._get_data_parallel_info(
            feed_list[0], dist_context)

        # remove the first three ops if multi run fit/evaluate/predict
        op_size = len(dist_main_block.ops)
        if dist_main_block.ops[0].type == 'create_py_reader':
            op_size -= 3
            for _ in range(3):
                dist_main_block._remove_op(0, sync=False)

        # insert read op at the end of program
        places = paddle.static.cuda_places()
        with static.program_guard(dist_main_prog, dist_startup_prog):
            dataloader = NonIterableGeneratorLoader(
                dataset,
                feed_list,
                places,
                batch_size,
                epochs,
                steps_per_epoch,
                data_parallel_world_size=dp_world_size,
                data_parallel_rank=dp_rank)

        # move read op from the end of program to the start of program
        new_op_size = len(dist_main_block.ops)
        num_added = new_op_size - op_size
        for _ in range(num_added):
            # Every prepend shifts the tail right by one, so this fixed index
            # walks the appended ops from last to first and they end up at the
            # front in their original order.
            src_op = dist_main_block.ops[new_op_size - 1]
            new_op_desc = dist_main_block.desc._prepend_op()
            new_op_desc.copy_from(src_op.desc)
            new_op = Operator(dist_main_block,
                              new_op_desc,
                              type=new_op_desc.type())
            dist_main_block.ops.insert(0, new_op)
            dist_context.add_dist_op_for_program(DistributedOperator(new_op))
        # Drop the now-duplicated originals from the tail of the block.
        for _ in range(num_added):
            dist_main_block._remove_op(new_op_size, sync=False)
        dist_main_block._sync_with_cpp()
        return dataloader

    def _validate_spec(self, specs):
        """Normalize ``specs`` with ``to_list`` and check every entry.

        Returns:
            The normalized specs (``None`` passes through unchanged).

        Raises:
            AssertionError: an entry is not an ``InputSpec``.
            ValueError: an ``InputSpec`` has no name.
        """
        specs = to_list(specs)
        if specs is None:
            return specs
        for idx, spec in enumerate(specs):
            assert isinstance(spec, InputSpec)
            if spec.name is None:
                raise ValueError(
                    "Requires Input[{}].name != None, but receive `None` with {}."
                    .format(idx, spec))
        return specs

    def _set_data_parallel(self, var):
        """Annotate ``var`` with a default dist attr and record the strategy.

        With one rank: process mesh ``[0]``, every dim mapped to -1, strategy
        ``'serial'``. With several ranks: process mesh over all ranks, dim 0
        mapped to mesh axis 0 and the rest to -1, strategy ``'dp'``.

        Returns:
            ``var`` itself.
        """
        if self._nranks == 1:
            self._default_strategy = 'serial'
            process_mesh = [0]
            dims_mapping = [-1] * len(var.shape)
        else:
            self._default_strategy = 'dp'
            process_mesh = list(range(self._nranks))
            dims_mapping = [0] + [-1] * (len(var.shape) - 1)

        auto.shard_tensor(var,
                          dist_attr={
                              "process_mesh": process_mesh,
                              "dims_mapping": dims_mapping
                          })
        return var

    def _get_data_parallel_info(self, var, dist_context):
        """Return ``(dp_world_size, dp_rank)`` for the first dim of ``var``.

        Reads ``var``'s tensor dist attr from ``dist_context``. If this rank is
        not in the process mesh, the corresponding rank from
        ``_get_corresponding_rank`` is used instead. When dim 0 is mapped to a
        mesh axis of size > 1, returns the length of the group returned by
        ``_get_comm_group`` for that axis and this rank's index in it.
        Otherwise returns ``(None, None)``.
        """
        # get data parallel world size and current data parallel rank
        from .utils import _get_comm_group, _get_corresponding_rank

        dist_attr = dist_context.get_tensor_dist_attr_for_program(var)
        mesh = dist_attr.process_mesh
        dims_mapping = dist_attr.dims_mapping

        rank_id = self._cur_rank
        if rank_id not in mesh.processes:
            rank_id = _get_corresponding_rank(dist_context, mesh,
                                              self._cur_rank)

        batch_axis = dims_mapping[0]
        if batch_axis <= -1 or mesh.topology[batch_axis] <= 1:
            return None, None

        group_ranks = _get_comm_group(mesh.processes, mesh.topology,
                                      batch_axis, rank_id)
        return len(group_ranks), group_ranks.index(rank_id)

    def save(self, path, training=True, mode=None):
        if not mode:
            mode = self.mode

        if training:
520 521
            assert 'train' in self._serial_main_progs, \
                "training model is not ready, please call `engine.prepare()` first."
522 523 524
            serial_program = self._serial_main_progs["train"]
            dist_main_prog = self._dist_main_progs["train"][self._cur_rank]
            dist_context = self._dist_contexts["train"]
525 526 527 528
            self._saver.save(path,
                             serial_program=serial_program,
                             dist_main_program=dist_main_prog,
                             dist_context=dist_context)
529 530 531 532 533
        else:
            assert mode, "Please set the 'mode' you want to save."
            feed_vars = self._feed_vars[mode]['inputs']
            fetch_vars = self._fetch_vars[mode]['outputs']
            dist_main_prog = self._dist_main_progs[mode][self._cur_rank]
534 535 536 537 538
            self._saver.save_inference_model(path,
                                             feed_vars,
                                             fetch_vars,
                                             self._executor,
                                             program=dist_main_prog)
539

540 541 542 543
    def load(self, path, strict=True, load_optimizer=True, mode=None):
        if not mode:
            mode = self.mode
        assert mode, "Please set the 'mode' you want to load."
544

545 546 547 548
        dist_main_prog = self._dist_main_progs[mode][self._cur_rank]
        dist_context = self._dist_contexts[mode]
        self._saver.load(path, dist_main_prog, dist_context, strict,
                         load_optimizer)
549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576

    @property
    def mode(self):
        return self._mode

    @mode.setter
    def mode(self, mode):
        self._mode = mode

    @property
    def main_program(self):
        return self._dist_main_progs[self.mode][self._cur_rank]

    @property
    def startup_program(self):
        return self._dist_startup_progs[self.mode][self._cur_rank]

    @property
    def dist_context(self):
        return self._dist_contexts[self.mode]

    @property
    def serial_main_program(self):
        return self._serial_main_progs[self.mode]

    @property
    def serial_startup_program(self):
        return self._serial_startup_progs[self.mode]
577 578 579 580

    @property
    def fetch_vars(self):
        return self._fetch_vars[self.mode]