engine.py 26.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import logging
from collections import defaultdict
18
import socket
19 20

import paddle
21
import paddle.utils as utils
22

23
from paddle import fluid, static
24
from paddle.io import Dataset
25
from paddle.jit import to_static
26
from paddle.metric import Metric
27
from paddle.static import InputSpec
28
from paddle.fluid import core
29
from paddle.fluid import program_guard
30
from paddle.fluid.layers.utils import flatten
31
from paddle.fluid.executor import global_scope, _to_name_str
32
from paddle.fluid.backward import append_backward
33
from paddle.fluid.framework import Operator, Parameter, _non_static_mode
34 35
from paddle.fluid.framework import _current_expected_place as _get_device
from paddle.fluid.dygraph.parallel import ParallelEnv
36
from paddle.distributed import fleet
37
from paddle.distributed.utils import get_logger
38
from paddle.distributed.passes import new_pass, PassContext
39

40
from .hepler import ProgramHelper
41 42
from ..collective import _get_global_env
from .cluster import Cluster, get_default_cluster
43 44
from .planner_v2 import Planner
from .parallelizer_v2 import Parallelizer
45 46 47 48 49
from .dist_op import DistributedOperator
from .dist_saver import DistributedSaver
from .dist_loader import NonIterableGeneratorLoader
from .utils import make_data_unshard, set_grad_var_shape
from .utils import print_program_with_dist_attr, to_list
50
from .process_group import new_process_group, get_all_process_groups, get_world_process_group
51
from .dist_context import DistributedContext, get_default_distributed_context
52 53 54


class Engine:
    """High-level driver for auto-parallel training, evaluation and prediction.

    An ``Engine`` wraps a user ``model`` together with input/label
    ``InputSpec`` lists. ``prepare`` builds, plans, parallelizes and
    initializes the "train" mode; ``evaluate`` and ``predict`` do the same
    for their own mode the first time they are called. Each mode keeps its
    own serial programs, per-rank distributed programs, distributed context
    and feed/fetch variables, all keyed by the mode name.
    """

    def __init__(self,
                 model=None,
                 inputs_spec=None,
                 labels_spec=None,
                 cluster=None,
                 strategy=None):
        """Create an engine.

        Args:
            model: the network to parallelize. In static mode it is called
                with the feed variables created from ``inputs_spec``; in
                dynamic-graph mode it is handed to ``ProgramHelper``.
            inputs_spec: an ``InputSpec`` or list of them; each must have a
                non-None ``name``.
            labels_spec: same as ``inputs_spec``, for labels (may be None).
            cluster: cluster description; ``get_default_cluster()`` is used
                when None.
            strategy: ``fleet.DistributedStrategy``; a default-constructed
                one is used when None.
        """
        self.model = model
        self.inputs_spec = self._validate_spec(inputs_spec)
        self.labels_spec = self._validate_spec(labels_spec)
        self.cluster = cluster
        if self.cluster is None:
            self.cluster = get_default_cluster()
        self.strategy = strategy
        if self.strategy is None:
            self.strategy = fleet.DistributedStrategy()

        self._executor = None
        self._cur_rank = paddle.distributed.get_rank()
        self._nranks = paddle.distributed.get_world_size()
        self._saver = DistributedSaver()
        self._logger = get_logger(logging.INFO)

        self._orig_main_prog = static.default_main_program()
        self._orig_startup_prog = static.default_startup_program()
        self._orig_dist_context = get_default_distributed_context()
        self._dist_contexts = {}
        self._serial_main_progs = {}
        self._serial_startup_progs = {}
        self._dist_main_progs = defaultdict(dict)  # dist main programs
        self._dist_startup_progs = defaultdict(dict)  # dist startup programs
        self._feed_vars = {}
        self._fetch_vars = {}
        self._planners = {}
        self._mode_init_states = {
            "train": False,
            "eval": False,
            "predict": False
        }
        self._dygraph_mode = False
        # Current mode, set by fit/evaluate/predict. Initialized here so
        # that `mode`, `save` and `load` do not raise AttributeError when
        # they are used before any of those methods has run.
        self._mode = None

    def prepare(self,
                optimizer=None,
                loss=None,
                gradient_scale=True,
                metrics=None,
                all_ranks=False):
        """Validate the training components and set up the "train" mode.

        Args:
            optimizer: a ``paddle.optimizer.Optimizer`` or
                ``paddle.fluid.optimizer.Optimizer`` instance, or None.
            loss: a ``paddle.nn.Layer`` or any callable, or None.
            gradient_scale (bool): stored on every ``DistributedContext``
                built afterwards.
            metrics: a ``Metric`` or a list of them, or None.
            all_ranks (bool): when True, programs are parallelized for all
                ranks (``parallel_all``) instead of only the current rank.

        Raises:
            TypeError: if ``optimizer`` or ``loss`` has an unsupported type.
            AssertionError: if an entry of ``metrics`` is not a ``Metric``.
        """
        if optimizer and not isinstance(
                optimizer,
            (paddle.optimizer.Optimizer, paddle.fluid.optimizer.Optimizer)):
            raise TypeError(
                "'optimizer' must be object of class `paddle.optimizer.Optimizer`"
                " or `paddle.fluid.optimizer.Optimizer`.")
        self._optimizer = optimizer
        self._all_ranks = all_ranks

        if loss and not isinstance(loss,
                                   paddle.nn.Layer) and not callable(loss):
            raise TypeError(
                "'loss' must be sub classes of `paddle.nn.Layer` or any callable function."
            )
        self._loss = loss

        metrics = to_list(metrics or [])
        for metric in metrics:
            assert isinstance(metric, Metric), \
                "{} is not sub class of Metric".format(
                    metric.__class__.__name__)
        self._metrics = metrics
        self._gradient_scale = gradient_scale
        # Reset so that the "train" mode becomes the reference plan.
        self._planned_mode = None
        self._prepare_single_mode("train")

    def _prepare_single_mode(self, mode):
        """Build, plan, parallelize and initialize the programs of ``mode``."""
        self._modes = [mode]
        self._build(self._modes[0])
        # Do auto parallel process
        for cur_mode in self._modes:
            # Do the planning process
            self._plan(cur_mode)
        for cur_mode in self._modes:
            # Do the parallel process
            self._parallel(cur_mode, self._all_ranks)
            # Init comm and startup program
            self._initialize(cur_mode)
            self._mode_init_states[cur_mode] = True

    def _build(self, mode):
        """Build the serial program of ``mode`` and its ``DistributedContext``.

        If the global mode is dynamic, or a previous build used dynamic
        graph, the program is traced with ``ProgramHelper``. Otherwise the
        model, loss and metrics are called on feed layers created from the
        specs, inside clones of the original default main/startup programs.
        In static mode nothing is rebuilt once a serial main program for
        ``mode`` has been recorded by ``_initialize``.
        """
        if _non_static_mode() or self._dygraph_mode:
            paddle.disable_static()
            self._dygraph_mode = True
            self._logger.info("Building model with 'to_static' method.")

            try:
                program_helper = ProgramHelper(self.model, self._loss,
                                               self._metrics, self.inputs_spec,
                                               self.labels_spec)
                # build forward main program
                program_helper.build_program(mode)

                self.concrete_program = program_helper.concrete_program
                serial_main_prog = program_helper.main_program
                serial_startup_prog = program_helper.startup_program

                inputs = program_helper.input_vars
                outputs = program_helper.output_vars
                labels = program_helper.label_vars
                losses = program_helper.loss_vars
                metrics = program_helper.metric_vars
            finally:
                # Always return to static mode, even if tracing fails.
                paddle.enable_static()
        else:
            # build program in static mode
            serial_main_prog = self._serial_main_progs.get(mode, None)
            if serial_main_prog is not None:
                return

            losses = []
            metrics = []
            serial_main_prog = self._orig_main_prog.clone()
            serial_startup_prog = self._orig_startup_prog.clone()
            with static.program_guard(serial_main_prog, serial_startup_prog), \
                utils.unique_name.guard():
                inputs_spec = self.inputs_spec
                labels_spec = self.labels_spec if self.labels_spec else []
                inputs = [s._create_feed_layer() for s in inputs_spec]
                labels = [s._create_feed_layer() for s in labels_spec]
                outputs = to_list(self.model(*inputs))
                if mode != "predict" and self._loss:
                    losses = to_list(self._loss(*(outputs + labels)))

                if mode != "predict":
                    for metric in self._metrics:
                        metrics.extend(
                            to_list(metric.compute(*(outputs + labels))))

        default_ctx = get_default_distributed_context()
        if not default_ctx.has_annotation:
            # We build the world process group because the data parallel
            # needs all ranks by default.
            new_process_group(list(range(self._nranks)))
            default_ctx.data_parallel = True

        feed_vars = {"inputs": inputs, "labels": labels}

        fetch_vars = {
            "outputs": flatten(outputs),
            "loss": losses,
            "metrics": metrics
        }

        self._dist_contexts[mode] = DistributedContext(
            serial_main_prog, serial_startup_prog, self._optimizer, losses,
            feed_vars, fetch_vars, self.cluster, self.strategy)
        self._dist_contexts[mode].gradient_scale = self._gradient_scale
        self._dist_contexts[mode]._dygraph_mode = self._dygraph_mode

    def _plan(self, mode):
        """Run the ``Planner`` for ``mode``.

        The first planned mode becomes the reference. Later modes first copy
        the reference's op dist attributes (``_init_dist_context``) so that
        all modes share one parallel strategy.
        """
        if self._planned_mode is None:
            self._planned_mode = mode
        else:
            self._init_dist_context(mode)

        self._planners[mode] = Planner(mode, self._dist_contexts[mode])
        self._planners[mode].plan()

    def _parallel(self, mode, all_ranks):
        """Parallelize the programs of ``mode`` from the planner's results.

        Only the current rank's programs are generated unless ``all_ranks``
        is True, in which case ``Parallelizer.parallel_all`` is used.
        """
        # Parallelize program based on the planner's results
        # For now, the completer has to be passed to the planner,
        # because we may use it to complete the annotation of the backward and update.
        parallelizer = Parallelizer(mode, self._planners[mode].completer,
                                    self._dist_contexts[mode])
        if not all_ranks:
            parallelizer.parallel(self._cur_rank)
        else:
            parallelizer.parallel_all()

    def _init_dist_context(self, mode):
        """Copy op dist attributes from the reference mode into ``mode``.

        Ops are matched by (block index, op index) between the two original
        serial main programs.

        Raises:
            AssertionError: if the op types at a matched position differ.
        """
        # Init dist_context['mode'] with the first planned dist_context
        # to guarantee that train/eval/predict mode have same parallel strategy
        dist_context = self._dist_contexts[mode]
        origin_main_prog = dist_context._original_serial_main_program
        ref_mode = self._planned_mode
        ref_dist_context = self._dist_contexts[ref_mode]
        ref_origin_main_prog = ref_dist_context._original_serial_main_program
        ref_blocks = ref_origin_main_prog.blocks
        for ib, block in enumerate(origin_main_prog.blocks):
            for iop, op in enumerate(block.ops):
                ref_op = ref_blocks[ib].ops[iop]
                assert op.type == ref_op.type, \
                    "'{}' mode op '{}' is different with '{}' op '{}'. ".format(mode, op.type, ref_mode, ref_op.type)
                ref_op_dist_attr = ref_dist_context.get_op_dist_attr_for_program(
                    ref_op)
                dist_context.set_op_dist_attr_for_program(op, ref_op_dist_attr)

    def _initialize(self, mode):
        """Record the programs of ``mode`` and make them ready to execute.

        This method:

        * copies the serial programs, per-rank distributed programs and
          feed/fetch variables out of the mode's ``DistributedContext``;
        * with more than one rank, instantiates every process group that
          contains the current rank;
        * resolves the device place;
        * in dygraph mode, shares sliced parameters into the global scope;
        * runs the part of this rank's distributed startup program whose
          variables are not yet initialized in the global scope.
        """
        dist_context = self._dist_contexts[mode]
        # Get the current content from the distributed context
        self._serial_main_progs[mode] = dist_context.serial_main_program
        self._serial_startup_progs[mode] = dist_context.serial_startup_program
        self._dist_main_progs[mode] = dist_context.dist_main_programs
        self._dist_startup_progs[mode] = dist_context.dist_startup_programs
        self._feed_vars[mode] = dist_context.serial_feed_vars
        self._fetch_vars[mode] = dist_context.serial_fetch_vars

        if self._nranks > 1:
            # Traverse different rank programs and traverse each op of them,
            # instantiate communication by process_mapping.
            all_process_groups = get_all_process_groups()

            # NOTE: add the comm init control in the future for auto search
            for process_group in all_process_groups:
                if self._cur_rank not in process_group.ranks:
                    continue
                process_group.instantiate()

        self._place = _get_device()
        if isinstance(self._place, fluid.CUDAPlace):
            self._place = fluid.CUDAPlace(ParallelEnv().dev_id)

        if self._dygraph_mode:
            self._share_dygraph_params(mode)

        if self._executor is None:
            self._executor = paddle.static.Executor(self._place)

        # Run the startup program for EVERY mode, not only the first one, so
        # that modes prepared later (eval/predict) get their variables
        # initialized too. Already-initialized variables (e.g. parameters
        # shared with a previously initialized mode) are skipped, so this
        # never re-initializes trained values.
        uninitialized = []
        dist_startup_prog = self._dist_startup_progs[mode][self._cur_rank]
        for var in dist_startup_prog.list_vars():
            scope_var = global_scope().find_var(var.name)
            if scope_var and scope_var.get_tensor()._is_initialized():
                continue
            uninitialized.append(var)
        if uninitialized:
            prune_startup_prog = dist_startup_prog._prune(uninitialized)
            self._executor.run(prune_startup_prog)

    def _share_dygraph_params(self, mode):
        """Share this rank's slice of each dygraph parameter into global scope.

        For every parameter of ``self.concrete_program`` that exists in the
        current rank's distributed main program of ``mode``:

        * the full value is sliced with ``Converter.slice_with_dist_attr``
          according to that variable's dist attr;
        * the global-scope tensor of the same name is made to share data
          with the sliced tensor, placed on ``self._place``.

        Dynamic-graph mode is entered for the duration of the call and
        static mode is always restored afterwards.
        """
        from .converter import Converter

        paddle.disable_static()
        try:
            main_program = self._dist_main_progs[mode][self._cur_rank]
            global_vars = main_program.global_block().vars
            dist_context = self._dist_contexts[mode]
            for param in self.concrete_program.parameters:
                # create var in scope and share parameters to scope
                if param.name not in global_vars:
                    continue
                # get param_var's dist_attr
                var_dist_attr = dist_context.get_tensor_dist_attr_for_program(
                    global_vars[param.name])
                dist_attr = {
                    "dims_mapping": var_dist_attr.dims_mapping,
                    "process_shape": var_dist_attr.process_mesh.topology,
                    "process_group": var_dist_attr.process_mesh.processes
                }
                # slice param_value with dist_attr
                # share sliced_param_value with param_tensor in global_scope
                param_tensor = global_scope().var(param.name).get_tensor()
                sliced_param = Converter.slice_with_dist_attr(
                    param.numpy(), dist_attr)
                shared_tensor = paddle.to_tensor(sliced_param,
                                                 place=self._place)
                param_tensor._share_data_with(
                    shared_tensor.value().get_tensor())
        finally:
            paddle.enable_static()

    def fit(self,
            train_data,
            batch_size=1,
            epochs=1,
            fetches=None,
            steps_per_epoch=None,
            use_program_cache=False,
            return_numpy=True):
        """Train the "train" program on ``train_data`` and log every step.

        Args:
            train_data: dataset handed to ``NonIterableGeneratorLoader``.
            batch_size (int): loader batch size.
            epochs (int): number of passes over the loader.
            fetches: optional list of variables, or dict mapping a display
                name to a variable, to fetch and log at every step.
            steps_per_epoch: forwarded to the loader.
            use_program_cache (bool): forwarded to ``Executor.run``.
            return_numpy (bool): forwarded to ``Executor.run``.

        Raises:
            Exception: if ``prepare`` has not initialized the "train" mode.
        """
        # TODO: callbacks
        # TODO: evaluate after training

        if not self._mode_init_states['train']:
            raise Exception(
                "train program is not initialized yet, please call engine.prepare() before calling fit() funtion."
            )

        self.mode = 'train'
        assert self.mode in self._dist_main_progs, \
            "train model is not ready, please call `engine.prepare()` first."
        train_dataloader = self._create_dataloader(train_data, batch_size,
                                                   epochs, steps_per_epoch)

        usr_fetch = self._validate_fetches(fetches)
        fetch_loss = self._validate_fetches(self.fetch_vars["loss"])
        fetch_list, fetch_map = self._fetch_map(fetch_loss, usr_fetch)

        for epoch in range(epochs):
            train_logs = {"epoch": epoch}
            for step, _ in enumerate(train_dataloader):
                outs = self._executor.run(self.main_program,
                                          fetch_list=fetch_list,
                                          use_program_cache=use_program_cache,
                                          return_numpy=return_numpy)
                train_logs["step"] = step
                # inner fetches
                if fetch_loss:
                    train_logs["train_loss"] = outs[0][0]
                # user fetches
                user_outs = outs[len(fetch_loss):]
                user_fetch_list = fetch_list[len(fetch_loss):]
                for i, out in enumerate(user_outs):
                    train_logs["train_" + fetch_map[user_fetch_list[i]]] = out
                self._logger.info(train_logs)

    def evaluate(self,
                 eval_data,
                 batch_size=1,
                 fetches=None,
                 use_program_cache=False,
                 return_numpy=True):
        """Run the "eval" program on ``eval_data`` and log every step.

        The "eval" mode is prepared on first use. Loss, metric results
        (updated and accumulated at every step) and user fetches are
        logged.

        Args:
            eval_data: dataset handed to ``NonIterableGeneratorLoader``.
            batch_size (int): loader batch size.
            fetches: optional list of variables, or dict mapping a display
                name to a variable, to fetch and log at every step.
            use_program_cache (bool): forwarded to ``Executor.run``.
            return_numpy (bool): forwarded to ``Executor.run``.
        """
        self.mode = 'eval'
        if not self._mode_init_states[self.mode]:
            self._prepare_single_mode(self.mode)

        assert self.mode in self._dist_main_progs, \
            "eval model is not ready, please call `engine.prepare()` first."
        eval_dataloader = self._create_dataloader(eval_data, batch_size)

        usr_fetch = self._validate_fetches(fetches)
        fetch_loss = self._validate_fetches(self.fetch_vars["loss"])
        fetch_metrics = self._validate_fetches(self.fetch_vars["metrics"])
        inner_fetch = dict(fetch_loss, **fetch_metrics)
        fetch_list, fetch_map = self._fetch_map(inner_fetch, usr_fetch)

        for step, _ in enumerate(eval_dataloader):
            eval_logs = {"step": step}
            outs = self._executor.run(self.main_program,
                                      fetch_list=fetch_list,
                                      use_program_cache=use_program_cache,
                                      return_numpy=return_numpy)
            # inner fetches
            if fetch_loss:
                eval_logs["eval_loss"] = outs[0][0]
            # Metric
            if fetch_metrics:
                metric_out = outs[len(fetch_loss):len(inner_fetch)]
                for metric in self._metrics:
                    # NOTE(review): every metric receives all metric outputs
                    # (they are concatenated in `_build`); this looks correct
                    # only for a single metric -- confirm.
                    metric.update(*metric_out)
                    results = metric.accumulate()
                    # `Metric.name()` may return either one str or a list of
                    # names; indexing a bare str would yield characters.
                    names = to_list(metric.name())
                    for i, res in enumerate(to_list(results)):
                        eval_logs["eval_" + names[i]] = res
            # usr fetches
            usr_outs = outs[len(inner_fetch):]
            usr_fetch_list = fetch_list[len(inner_fetch):]
            for i, out in enumerate(usr_outs):
                eval_logs["eval_" + fetch_map[usr_fetch_list[i]]] = out
            # logger
            self._logger.info(eval_logs)

    def predict(self,
                test_data,
                batch_size=1,
                fetches=None,
                use_program_cache=False,
                return_numpy=True):
        """Run the "predict" program on ``test_data``.

        The "predict" mode is prepared on first use.

        Args:
            test_data: dataset handed to ``NonIterableGeneratorLoader``.
            batch_size (int): loader batch size.
            fetches: optional list of variables, or dict mapping a display
                name to a variable, to fetch and log at every step.
            use_program_cache (bool): forwarded to ``Executor.run``.
            return_numpy (bool): forwarded to ``Executor.run``.

        Returns:
            list: one entry per step, holding the fetched model outputs
            (the first ``len(outputs)`` fetched values) of that step.
        """
        self.mode = 'predict'
        if not self._mode_init_states[self.mode]:
            self._prepare_single_mode(self.mode)

        assert self.mode in self._dist_main_progs, \
            "predict model is not ready, please call `engine.prepare()` first."
        test_dataloader = self._create_dataloader(test_data, batch_size)

        usr_fetch = self._validate_fetches(fetches)
        fetch_outputs = self._validate_fetches(self.fetch_vars["outputs"])
        fetch_list, fetch_map = self._fetch_map(fetch_outputs, usr_fetch)

        outputs = []
        for step, _ in enumerate(test_dataloader):
            predict_logs = {"step": step}
            outs = self._executor.run(self.main_program,
                                      fetch_list=fetch_list,
                                      use_program_cache=use_program_cache,
                                      return_numpy=return_numpy)
            outputs.append(outs[:len(fetch_outputs)])
            for i, out in enumerate(outs):
                predict_logs["pred_" + fetch_map[fetch_list[i]]] = out
            self._logger.info(predict_logs)

        return outputs

    def _create_dataloader(self,
                           dataset,
                           batch_size,
                           epochs=1,
                           steps_per_epoch=None):
        """Create a loader that feeds the current mode's distributed program.

        ``NonIterableGeneratorLoader`` appends its reader ops to the end of
        this rank's distributed main program. They are then moved to the
        front, and each moved op is registered as a ``DistributedOperator``.
        If the program already starts with a ``create_py_reader`` op, the
        first three ops are removed first; per the inline comment, these are
        left over from a previous fit/evaluate/predict run.

        Returns:
            NonIterableGeneratorLoader: the created loader.
        """
        dist_main_prog = self._dist_main_progs[self.mode][self._cur_rank]
        dist_startup_prog = self._dist_startup_progs[self.mode][self._cur_rank]
        dist_context = self._dist_contexts[self.mode]
        dist_main_block = dist_main_prog.global_block()

        # NOTE: Get feed_list from dist_program, then insert dataloader op
        # with sharded var shape. Because predict_program does not contain
        # labels var, so we will filter dataset's value with length of feed_list.
        inputs_var = self._feed_vars[self.mode]["inputs"]
        labels_var = self._feed_vars[self.mode]["labels"]
        feed_list = []
        for var in inputs_var + labels_var:
            if var.name in dist_main_block.vars:
                feed_list.append(dist_main_block.vars[var.name])
        dp_world_size, dp_rank = self._get_data_parallel_info(
            feed_list[0], dist_context)

        # remove the first three ops if multi run fit/evaluate/predict
        op_size = len(dist_main_block.ops)
        if dist_main_block.ops[0].type == 'create_py_reader':
            op_size -= 3
            for _ in range(3):
                dist_main_block._remove_op(0, sync=False)

        # insert read op at the end of program
        places = paddle.static.cuda_places()
        with static.program_guard(dist_main_prog, dist_startup_prog):
            dataloader = NonIterableGeneratorLoader(
                dataset,
                feed_list,
                places,
                batch_size,
                epochs,
                steps_per_epoch,
                data_parallel_world_size=dp_world_size,
                data_parallel_rank=dp_rank)

        # move read op from the end of program to the start of program.
        # Each iteration prepends a copy of the op at index new_op_size - 1;
        # because every prepend shifts the list right by one, this walks the
        # appended ops from last to first and preserves their order.
        new_op_size = len(dist_main_block.ops)
        for _ in range(new_op_size - 1, op_size - 1, -1):
            op = dist_main_block.ops[new_op_size - 1]
            new_op_desc = dist_main_block.desc._prepend_op()
            new_op_desc.copy_from(op.desc)
            new_op = Operator(dist_main_block,
                              new_op_desc,
                              type=new_op_desc.type())
            dist_main_block.ops.insert(0, new_op)
            dist_op = DistributedOperator(new_op)
            dist_context.add_dist_op_for_program(dist_op)
        # The original appended ops now start at index new_op_size; drop them.
        for _ in range(new_op_size - op_size):
            dist_main_block._remove_op(new_op_size, sync=False)
        dist_main_block._sync_with_cpp()
        return dataloader

    def _validate_spec(self, specs):
        """Normalize ``specs`` with ``to_list`` and check every entry.

        Raises:
            AssertionError: if an entry is not an ``InputSpec``.
            ValueError: if an entry has no ``name``.
        """
        specs = to_list(specs)
        if specs is not None:
            for i, spec in enumerate(specs):
                assert isinstance(spec, InputSpec)
                if spec.name is None:
                    raise ValueError(
                        "Requires Input[{}].name != None, but receive `None` with {}."
                        .format(i, spec))
        return specs

    def _is_local_var(self, var):
        """Return True if ``var``'s name is in the current main program's global block."""
        var_name = _to_name_str(var)
        return var_name in self.main_program.global_block().vars

    def _validate_fetches(self, fetches):
        """Turn user/inner fetches into ``{var_name: display_name}``.

        A list maps each variable name to itself; a dict maps each value's
        name to its key. Entries whose variable is not in the current main
        program's global block are dropped.

        Raises:
            TypeError: if ``fetches`` is truthy but neither dict nor list.
        """
        # 1. Check user-defined fetches type
        # 2. Prepare fetches_dict like {user_defined_name: var_name}
        if not fetches:
            return {}
        if isinstance(fetches, dict):
            fetch_var_names = list(map(_to_name_str, fetches.values()))
            fetches_dict = dict(zip(fetch_var_names, list(fetches.keys())))
        elif isinstance(fetches, list):
            fetch_var_names = list(map(_to_name_str, fetches))
            fetches_dict = dict(zip(fetch_var_names, fetch_var_names))
        else:
            raise TypeError("'fetches' only support 'dict' and 'list', "
                            "but got '{}'".format(str(type(fetches))))
        return dict(
            filter(lambda x: self._is_local_var(x[0]), fetches_dict.items()))

    def _fetch_map(self, inner_fetch, usr_fetch):
        """Merge engine-internal fetches with user fetches.

        Both arguments map a variable name to a display name. A display
        name given by the user overrides the internal one for the same
        variable. Internal variables come first, in their original order,
        followed by user-only variables. Neither argument is modified.

        Returns:
            tuple: ``(fetch_list, fetch_map)``, the variable names to fetch
            and a dict mapping each of them to its display name.
        """
        # replace inner fetch name if usr set for it
        fetches = {
            name: usr_fetch.get(name, alias)
            for name, alias in inner_fetch.items()
        }
        for name, alias in usr_fetch.items():
            fetches.setdefault(name, alias)
        return list(fetches), fetches

    def _get_data_parallel_info(self, var, dist_context):
        """Return ``(dp_world_size, dp_rank)`` for the batch dimension of ``var``.

        The batch dimension is ``var``'s dim 0. If it is mapped to a process
        mesh axis of size greater than 1, the size of the current rank's
        communication group along that axis is returned, together with the
        rank's index in that group. Otherwise ``(None, None)`` is returned.
        When the current rank is not in the mesh, the corresponding rank
        from ``_get_corresponding_rank`` is used instead.
        """
        # get data parallel world size and current data parallel rank
        from .utils import _get_comm_group, _get_corresponding_rank

        tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(var)
        process_mesh = tensor_dist_attr.process_mesh
        dims_mapping = tensor_dist_attr.dims_mapping

        if self._cur_rank not in process_mesh.processes:
            rank_id = _get_corresponding_rank(dist_context, process_mesh,
                                              self._cur_rank)
        else:
            rank_id = self._cur_rank

        batch_size_axis = dims_mapping[0]
        if batch_size_axis > -1 and process_mesh.topology[batch_size_axis] > 1:
            group_ranks = _get_comm_group(process_mesh.processes,
                                          process_mesh.topology,
                                          batch_size_axis, rank_id)
            return len(group_ranks), group_ranks.index(rank_id)

        return None, None

    def save(self, path, training=True, mode=None):
        """Save the model.

        With ``training=True`` the "train" serial program and this rank's
        distributed "train" program are saved through
        ``DistributedSaver.save``. Otherwise an inference model of ``mode``
        (default: the current mode) is saved from this rank's distributed
        main program, with the mode's inputs as feeds and its outputs as
        fetches.
        """
        if not mode:
            mode = self.mode

        if training:
            assert 'train' in self._serial_main_progs, \
                "training model is not ready, please call `engine.prepare()` first."
            serial_program = self._serial_main_progs["train"]
            dist_main_prog = self._dist_main_progs["train"][self._cur_rank]
            dist_context = self._dist_contexts["train"]
            self._saver.save(path,
                             serial_program=serial_program,
                             dist_main_program=dist_main_prog,
                             dist_context=dist_context)
        else:
            assert mode, "Please set the 'mode' you want to save."
            feed_vars = self._feed_vars[mode]['inputs']
            fetch_vars = self._fetch_vars[mode]['outputs']
            dist_main_prog = self._dist_main_progs[mode][self._cur_rank]
            self._saver.save_inference_model(path,
                                             feed_vars,
                                             fetch_vars,
                                             self._executor,
                                             program=dist_main_prog)

    def load(self, path, strict=True, load_optimizer=True, mode=None):
        """Load saved state into this rank's distributed program of ``mode``.

        ``mode`` defaults to the current mode.
        """
        if not mode:
            mode = self.mode
        assert mode, "Please set the 'mode' you want to load."

        dist_main_prog = self._dist_main_progs[mode][self._cur_rank]
        dist_context = self._dist_contexts[mode]
        self._saver.load(path, dist_main_prog, dist_context, strict,
                         load_optimizer)

    @property
    def mode(self):
        """The current mode ("train"/"eval"/"predict"), or None before use."""
        return self._mode

    @mode.setter
    def mode(self, mode):
        self._mode = mode

    @property
    def main_program(self):
        """This rank's distributed main program for the current mode."""
        return self._dist_main_progs[self.mode][self._cur_rank]

    @property
    def startup_program(self):
        """This rank's distributed startup program for the current mode."""
        return self._dist_startup_progs[self.mode][self._cur_rank]

    @property
    def dist_context(self):
        """The ``DistributedContext`` of the current mode."""
        return self._dist_contexts[self.mode]

    @property
    def serial_main_program(self):
        """The serial (non-distributed) main program of the current mode."""
        return self._serial_main_progs[self.mode]

    @property
    def serial_startup_program(self):
        """The serial (non-distributed) startup program of the current mode."""
        return self._serial_startup_progs[self.mode]

    @property
    def fetch_vars(self):
        """Serial fetch variables ("outputs"/"loss"/"metrics") of the current mode."""
        return self._fetch_vars[self.mode]