compiler.py 51.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import multiprocessing
import os
X
polish  
Xin Pan 已提交
17
import sys
18
import warnings
X
Xin Pan 已提交
19
from . import framework
20
from .framework import _get_paddle_place, _get_paddle_place_list
21
from .framework import cuda_places, cpu_places, xpu_places
22 23
from . import core

J
jianghaicheng 已提交
24 25 26 27
# Public names exported by this module.
__all__ = [
    'CompiledProgram', 'ExecutionStrategy', 'BuildStrategy',
    'IpuCompiledProgram', 'IpuStrategy'
]

# Module-level aliases of the underlying core (C++ binding) types.
ExecutionStrategy = core.ParallelExecutor.ExecutionStrategy
BuildStrategy = core.ParallelExecutor.BuildStrategy

InferNativeConfig = core.NativeConfig
InferAnalysisConfig = core.AnalysisConfig

DeviceType = core.DeviceType
34 35 36 37 38 39 40 41


def _place_obj(place):
    """Wrap a concrete place object into a generic ``core.Place``.

    Args:
        place: a concrete place (e.g. a CPU/CUDA/XPU place) accepted by
            ``core.Place.set_place``.

    Returns:
        core.Place: a new generic place holding ``place``.
    """
    generic_place = core.Place()
    generic_place.set_place(place)
    return generic_place


42 43
def _is_pserver_mode(main_program):
    """Report whether a program runs in parameter-server mode.

    A program is treated as parameter-server mode when its global block
    contains a ``send`` or ``recv`` op.

    Args:
        main_program (Program|None): the program to inspect. When it is
            falsy, ``framework.default_main_program()`` is inspected instead.

    Returns:
        bool: True if any global-block op is ``send`` or ``recv``.
    """
    main = main_program if main_program \
        else framework.default_main_program()
    return any(op.type in ("send", "recv") for op in main.global_block().ops)


C
chengduo 已提交
51 52 53 54 55 56 57 58
def _has_backward_op(graph):
    """Tell whether ``graph`` contains a gradient operator.

    Args:
        graph: a graph whose ``nodes()`` yields nodes exposing ``is_op()``
            and ``op()``.

    Returns:
        bool: True as soon as an op node whose op type ends with ``_grad``
        is found, otherwise False.
    """
    return any(
        node.is_op() and node.op() is not None
        and node.op().type().endswith("_grad")
        for node in graph.nodes())


59 60 61 62 63 64 65 66 67
def _prune_feed_ops(program):
    """Remove every ``feed`` op from ``program``'s global block, in place.

    Args:
        program: the program to modify.
    """
    global_block = program.global_block()
    feed_indices = [
        idx for idx, op in enumerate(global_block.ops) if op.type == "feed"
    ]
    # Delete back-to-front so the remaining indices stay valid.
    for idx in reversed(feed_indices):
        global_block._remove_op(idx)


68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89
def _has_optimize_op(block):
    """Return True if ``block`` contains an optimizer op.

    An op qualifies when it has the op-role-var attribute and its op-role
    attribute equals ``OpRole.Optimize``.

    Args:
        block: a program block whose ``ops`` are inspected.

    Returns:
        bool: True on the first qualifying op, otherwise False.
    """
    # These lookups do not depend on the op, so resolve them once.
    op_maker = core.op_proto_and_checker_maker
    role_var_attr = op_maker.kOpRoleVarAttrName()
    role_attr = op_maker.kOpRoleAttrName()
    optimize = int(op_maker.OpRole.Optimize)
    for op in block.ops:
        if role_var_attr in op.attr_names and \
                int(op.all_attrs()[role_attr]) == optimize:
            return True
    return False


def _has_optimizer_in_control_flow(program):
    """Check whether an optimizer op lives inside a conditional block.

    Args:
        program (Program|None): the program to inspect; when falsy,
            ``framework.default_main_program()`` is used.

    Returns:
        bool: True if the sub-block of any ``conditional_block_grad`` op in
        the global block contains an optimizer op.
    """
    prog = program if program else framework.default_main_program()
    for op in prog.global_block().ops:
        if op.type != "conditional_block_grad":
            continue
        if _has_optimize_op(prog.block(op._block_attr_id("sub_block"))):
            return True
    return False


90 91 92 93 94 95 96 97 98 99
def _should_broadcast_or_not_exists(program, var_name):
    """Decide whether ``var_name`` should be broadcast.

    Args:
        program: the program whose global block is searched.
        var_name (str): the variable name to look up.

    Returns:
        bool: True if the variable is absent from the global block, or if it
        has neither ``_is_distributed`` nor ``is_distributed`` set to a truthy
        value; False otherwise.
    """
    var = program.global_block().vars.get(var_name, None)
    if var is None:
        return True
    for flag in ('_is_distributed', 'is_distributed'):
        if getattr(var, flag, False):
            return False
    return True


X
polish  
Xin Pan 已提交
100
class CompiledProgram(object):
    """
    :api_attr: Static Graph

    The CompiledProgram is used to transform a program or graph for
    various optimizations according to the configuration of build_strategy,
    for example, the operators' fusion in the computation graph, memory
    optimization during the execution of the computation graph, etc.
    For more information about build_strategy, please refer to
    :code:`paddle.static.BuildStrategy`.

    Args:
        program_or_graph (Graph|Program): This argument is the Program or Graph
            being executed.
        build_strategy(BuildStrategy): This argument is used to compile the
            program or graph with the specified options, such as operators' fusion
            in the computational graph and memory optimization during the execution
            of the computational graph. For more information about build_strategy,
            please refer to :code:`paddle.static.BuildStrategy`. The default is None.

    Returns:
        CompiledProgram

    Example:
        .. code-block:: python

            import numpy
            import paddle
            import paddle.static as static

            paddle.enable_static()

            place = paddle.CUDAPlace(0) # paddle.CPUPlace()
            exe = static.Executor(place)

            data = static.data(name='X', shape=[None, 1], dtype='float32')
            hidden = static.nn.fc(x=data, size=10)
            loss = paddle.mean(hidden)
            paddle.optimizer.SGD(learning_rate=0.01).minimize(loss)

            exe.run(static.default_startup_program())
            compiled_prog = static.CompiledProgram(
                static.default_main_program())

            x = numpy.random.random(size=(10, 1)).astype('float32')
            loss_data, = exe.run(compiled_prog,
                                feed={"X": x},
                                fetch_list=[loss.name])
    """

    def __init__(self, program_or_graph, build_strategy=None):
        if isinstance(program_or_graph, core.Graph):
            self._graph = program_or_graph
            # Do not create a new program here; the graph is used as-is.
            self._program = None
        elif isinstance(program_or_graph, framework.Program):
            # Removes existing feed ops from the caller's program in place
            # before the graph is built from its desc.
            _prune_feed_ops(program_or_graph)
            self._graph = core.Graph(program_or_graph.desc)
            self._program = program_or_graph
        else:
            raise TypeError(
                "The type of program_or_graph parameter is wrong, expected Graph or Program, but received %s"
                % type(program_or_graph))

        self._scope = None
        self._place = None
        self._executor = None
        self._compiled = False
        self._is_data_parallel = False
        self._is_inference = False
        self._loss_name = None
        self._share_vars_from = None
        self._places = None
        self._build_strategy = build_strategy
        self._exec_strategy = None

    def with_data_parallel(self,
                           loss_name=None,
                           build_strategy=None,
                           exec_strategy=None,
                           share_vars_from=None,
                           places=None):
        """
        This interface is used to transform the input Program or Graph to a multi-graph
        to run the model in data parallel mode. Users can use the build_strategy and
        exec_strategy to set some optimizations that can be applied during the construction
        and computation of the Graph, such as reducing the number of AllReduce operations,
        specifying the size of the thread pool used in the computation Graph running the model,
        and so on.

        .. note::
            If build_strategy is specified when building CompiledProgram and calling
            with_data_parallel, build_strategy in CompiledProgram will be overwritten, therefore,
            if it is data parallel training, it is recommended to set build_strategy when calling
            with_data_parallel interface.

        Args:
            loss_name (str): This parameter is the name of the loss Tensor of the model.
                **Note: If it is model training, you must set loss_name, otherwise the
                result may be problematic**. The default is None.
            build_strategy(BuildStrategy): This parameter is used to compile the
                program or graph with the specified options, such as operators' fusion
                in the computational graph and memory optimization during the execution
                of the computational graph. For more information about build_strategy,
                please refer to :code:`fluid.BuildStrategy`. The default is None.
            exec_strategy(ExecutionStrategy): exec_strategy specifies the options that can
                be changed when running the current model, such as the thread pool size.
                For more information about exec_strategy, please refer to :code:`fluid.ExecutionStrategy`.
                The default is None.
            share_vars_from(CompiledProgram): If share_vars_from is set, the current
                CompiledProgram will share the parameter value with the CompiledProgram
                specified by share_vars_from. This parameter needs to be set when model testing
                is required during model training, and the data parallel mode is used for
                training and testing. Since CompiledProgram will only distribute parameter
                Tensors to other devices when it is first executed, the CompiledProgram
                specified by share_vars_from must be run before the current CompiledProgram.
                The default is None.
            places(list(CUDAPlace)|list(CPUPlace)|list(str)|None): This parameter specifies the device
                on which the model is running. If you want to run on GPU0 and GPU1, places are
                [fluid.CUDAPlace(0), fluid.CUDAPlace(1)]; if you want to run with 2 CPUs, places are
                [fluid.CPUPlace()] * 2. If the parameter is not set, i.e. the parameter is None,
                the available device will be obtained from the environment variable when the model
                is executed: If the GPU is used, the currently available device ID is obtained
                from the environment variable FLAGS_selected_gpus or CUDA_VISIBLE_DEVICES when
                the model is executed; CPU, when the model is executed, the currently available
                CPU number is obtained from the environment variable CPU_NUM. For example,
                export CPU_NUM=4, if the environment variable is not set, the executor will
                add the variable to the environment variable and set its value to 1.
                The default is None. If ``places`` is the list of string, the string in the list
                can be ``cpu``, ``gpu:x``, where ``x`` is the index of the GPUs.

        Returns:
            CompiledProgram

        Example:
            .. code-block:: python

                import numpy
                import os
                import paddle
                import paddle.static as static

                paddle.enable_static()

                use_cuda = True
                place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()
                parallel_places = [paddle.CUDAPlace(0), paddle.CUDAPlace(1)] if use_cuda else [paddle.CPUPlace()] * 2

                # NOTE: If you use CPU to run the program, you need
                # to specify the CPU_NUM, otherwise, paddle will use
                # all the number of the logic core as the CPU_NUM,
                # in that case, the batch size of the input should be
                # greater than CPU_NUM, if not, the process will be
                # failed by an exception.
                if not use_cuda:
                    os.environ['CPU_NUM'] = str(2)

                exe = static.Executor(place)

                data = static.data(name='X', shape=[None, 1], dtype='float32')
                hidden = static.nn.fc(x=data, size=10)
                loss = paddle.mean(hidden)

                test_program = static.default_main_program().clone(for_test=True)
                paddle.optimizer.SGD(learning_rate=0.01).minimize(loss)

                exe.run(static.default_startup_program())
                compiled_train_prog = static.CompiledProgram(
                    static.default_main_program()).with_data_parallel(
                            loss_name=loss.name, places=parallel_places)
                # NOTE: if not set share_vars_from=compiled_train_prog,
                # the parameters used in test process are different with
                # the parameters used by train process
                compiled_test_prog = static.CompiledProgram(
                    test_program).with_data_parallel(
                            share_vars_from=compiled_train_prog,
                            places=parallel_places)

                train_data = numpy.random.random(size=(10, 1)).astype('float32')
                loss_data, = exe.run(compiled_train_prog,
                                feed={"X": train_data},
                                fetch_list=[loss.name])

                test_data = numpy.random.random(size=(10, 1)).astype('float32')
                loss_data, = exe.run(compiled_test_prog,
                                feed={"X": test_data},
                                fetch_list=[loss.name])
        """
        assert not self._is_data_parallel, "Already compiled with parallel, cannot be recompiled."
        assert not self._is_inference, "Cannot compile with both data parallel and inference."
        self._is_data_parallel = True
        # FIXME(zcd): Currently, the build_strategy can be set during creating
        # CompiledProgram or calling with_data_parallel, and it may be confusing,
        # but in the long run, we should set up build_strategy only when creating
        # CompiledProgram, and exec_strategy should be deprecated.
        if build_strategy is not None:
            self._build_strategy = build_strategy
        self._exec_strategy = exec_strategy
        self._loss_name = loss_name
        self._share_vars_from = share_vars_from
        if isinstance(places, (list, tuple)):
            self._places = _get_paddle_place_list(places)
        else:
            self._places = _get_paddle_place(places)

        if _has_backward_op(self._graph):
            assert self._loss_name is not None, "The loss name of CompiledProgram is None. The loss name should be set if CompiledProgram contains backward part."

        # Normalize a single place into a one-element list.
        if self._places is not None and \
                not isinstance(self._places, (list, tuple)):
            self._places = [self._places]

        return self

    def _with_inference_optimize(self, config):
        """ Add inference optimize

        Args:
            config: instance of `NativeConfig` or `AnalysisConfig` to create predictor
        Returns:
            self
        """
        assert not self._is_data_parallel, "Cannot compile with both data parallel and inference"
        assert not self._is_inference, "Already compiled with inference, cannot be recompiled."

        assert isinstance(config, (InferNativeConfig, InferAnalysisConfig))
        self._is_inference = True
        self._infer_config = config
        return self

    def _with_distributed(self):
        raise NotImplementedError(
            "Subclass of CompiledProgram should implement _with_distributed method."
        )

    def _compile_data_parallel(self, places, use_device, scope=None):
        """Configure the strategies and build a ``core.ParallelExecutor``.

        Args:
            places (list|tuple): the places to run on.
            use_device (DeviceType): the device kind selected by ``_compile``.
            scope: the scope to use; it must be given unless
                ``share_vars_from`` is set, in which case it is ignored.

        Returns:
            core.ParallelExecutor: the constructed executor.

        Raises:
            ValueError: if ``share_vars_from`` is not data parallel or has not
                been compiled and executed yet.
        """
        if self._share_vars_from:
            if scope:
                sys.stderr.write("share_vars_from is set, scope is ignored.\n")
            if not self._share_vars_from._is_data_parallel:
                raise ValueError(
                    "The shared Program is not data parallel, cannot "
                    "share variables from it.")
            if self._share_vars_from._executor is None:
                raise ValueError(
                    "The shared Program is not compiled and executed, so there is no "
                    "variables to share.")
            self._local_scopes = self._share_vars_from._executor.local_scopes()
        else:
            assert scope is not None, \
                "scope must be provided when share_vars_from is not set."
            self._local_scopes = []

        assert isinstance(places, (list, tuple)), \
            "Currently, the places type can only be list or tuple, but the input type is {}.".format(type(places))

        if self._build_strategy is None:
            self._build_strategy = BuildStrategy()
        self._build_strategy.is_distribution = _is_pserver_mode(self._program)

        if self._exec_strategy is None:
            self._exec_strategy = ExecutionStrategy()
        self._exec_strategy._use_device = use_device

        if self._exec_strategy.num_threads == 0:
            if self._exec_strategy._use_device == DeviceType.CUDA:
                # Experiments on se-resnext shows that too many threads hurt
                # performance. Worth tuning for other models in the future.
                self._exec_strategy.num_threads = len(places) * 4
            elif self._exec_strategy._use_device == DeviceType.XPU:
                # Currently only single thread is supported in Kunlun XPU.
                self._exec_strategy.num_threads = 1
            else:
                self._exec_strategy.num_threads = len(places) * 2

        if "FLAGS_use_cinn" in core.globals() and core.globals(
        )["FLAGS_use_cinn"] and self._exec_strategy.num_threads != 1:
            warnings.warn(
                "At present, when CINN is turned on, each process can "
                "only contain one thread, so reset the number of threads to 1 here."
            )
            self._exec_strategy.num_threads = 1

        if self._build_strategy.num_trainers > 1:
            assert self._is_data_parallel, \
                "If you use multi-trainer to train the model, you should use "\
                "the data parallel model, i.e. calling with_data_parallel function."

        # TODO(wuyi): trainer endpoings should be passed in through
        # build_strategy, not program.xxx.
        # TODO(gongwb): let user to set them once.
        if self._program and self._build_strategy.num_trainers > 1 and \
                self._program._trainers_endpoints:
            tps = self._program._trainers_endpoints

            assert self._build_strategy.num_trainers == len(
                tps), "The trainer numbers is not equal to endpoint numbers."
            self._build_strategy.trainers_endpoints = tps

        if self._program:
            self._build_strategy.nccl_comm_num = self._program._nccl_comm_num
            self._build_strategy.use_hierarchical_allreduce = self._program._use_hierarchical_allreduce
            self._build_strategy.hierarchical_allreduce_inter_nranks = self._program._hierarchical_allreduce_inter_nranks

        if self._build_strategy.sync_batch_norm:
            self._build_strategy.enable_sequential_execution = True

        if self._program is not None and self._program._enable_dgc:
            assert self._exec_strategy._use_device == DeviceType.CUDA, "DGC only used under CUDA environment."
            assert self._build_strategy.num_trainers * len(
                places) > 1, "DGC is not available for single card training."
            assert self._build_strategy.reduce_strategy == BuildStrategy.ReduceStrategy.AllReduce, \
                "DGC only can be used for AllReduce BuildStrategy."

            # DGC doesn't support fuse for now, close fuse.
            self._build_strategy.fuse_all_reduce_ops = False

        self._persistable_vars = []
        for node in self._graph.nodes():
            if node.is_var() and node.var() is not None and node.var().persistable() and \
                    node.var().type() != core.VarDesc.VarType.RAW:
                name = node.name()
                # NOTE(review): when this CompiledProgram was built from a
                # Graph (self._program is None) nothing is collected here, so
                # no persistable var is passed for broadcast — confirm intended.
                if self._program is not None and _should_broadcast_or_not_exists(
                        self._program, name):
                    self._persistable_vars.append(name)

        places = list(map(_place_obj, places))

        # ParallelExecutor would broadcast all the parameters during initializing.
        # The parameters of each process should be in the same ordered for the data-parallelism
        # distributed training to keep the broadcast correct.
        self._persistable_vars = sorted(set(self._persistable_vars))

        return core.ParallelExecutor(places, self._persistable_vars,
                                     self._loss_name or '',
                                     self._scope, self._local_scopes,
                                     self._exec_strategy, self._build_strategy,
                                     self._graph)

    def _compile_inference(self):
        return core.create_paddle_predictor(self._infer_config)

    def _compile(self, scope, place):
        """Compile the program based on the configs.

        Args:
            scope: The variables (resources) that are associated with
               this compiled program.
            place: The location that the compiled program will be run on.

        Returns:
            self

        Raises:
            ValueError: if already compiled and a different scope or place
                is requested.
            NotImplementedError: if more than one place is used while an
                optimizer lives inside control flow.
        """
        if self._compiled:
            if scope and self._scope != scope:
                raise ValueError("Cannot compile program with different scope.")
            if place and not self._place._equals(place):
                raise ValueError("Cannot compile program with different place.")
            return self
        self._compiled = True

        self._scope = scope
        self._place = place

        if self._is_inference:
            self._executor = self._compile_inference()
        else:
            if self._is_data_parallel:
                self._places = self._get_places(self._place, self._places)
            else:
                self._places = [self._place]

            # Todo(liym27):If optimizer is used in control flow,
            #  training on multi-places is not supported now, will
            #  be supported later.
            if len(self._places) > 1 and \
                    _has_optimizer_in_control_flow(self._program):
                raise NotImplementedError(
                    "If optimizer is used in control flow, "
                    "training on multi-places is not supported now.")
            if isinstance(self._place, core.CUDAPlace):
                use_device = DeviceType.CUDA
            elif isinstance(self._place, core.XPUPlace):
                use_device = DeviceType.XPU
            else:
                use_device = DeviceType.CPU
            self._executor = self._compile_data_parallel(use_device=use_device,
                                                         scope=self._scope,
                                                         places=self._places)
        return self

    def _get_places(self, place, place_list):
        """Resolve the list of places to run on.

        Args:
            place: the reference place passed to ``_compile``.
            place_list (list|None): user-provided places. When given, each
                must have the same ``_type()`` as ``place``. When None, all
                available places of ``place``'s kind are used.

        Returns:
            list: a non-empty list of places.
        """
        if place_list is not None:
            for p in place_list:
                assert p._type() == place._type(), \
                    "Place type not match. You may set wrong type of places."
        elif isinstance(place, core.CUDAPlace):
            place_list = cuda_places()
        elif isinstance(place, core.XPUPlace):
            place_list = xpu_places()
        else:
            place_list = cpu_places()
        assert place_list, "No places for execution."
        return place_list
J
jianghaicheng 已提交
505 506


507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635
class IpuDynamicPatcher(object):
    """
    Patcher for IPU dynamic2static support.

    The patches replace ``ProgramCache.__getitem__`` and ``LRScheduler.step``.
    Every replaced attribute is recorded in ``patcher_cache`` so that
    ``release_patch`` can put the originals back.
    """

    # Entries are ``[owner, attribute_name, original_value]`` in patch order.
    patcher_cache = []

    def __init__(self):
        pass

    @staticmethod
    def convert_concrete_program(ipu_strategy,
                                 concrete_program,
                                 class_instance=None):
        """
        Convert the ConcreteProgram to IPUConcreteProgram.

        Args:
            ipu_strategy: the IpuStrategy providing ``enable_fp16``,
                ``is_training`` and ``_optimizer``.
            concrete_program: the traced ConcreteProgram; it is modified in
                place.
            class_instance: the instance the traced function belongs to
                (``item.class_instance`` of the cache key), or None.

        Returns:
            The same ``concrete_program`` with ``main_program`` replaced by
            the result of ``IpuCompiledProgram(...).compile(...)``.
        """
        from ..fluid.dygraph.base import switch_to_static_graph
        from ..fluid import backward
        from ..fluid.initializer import Constant
        from ..fluid.framework import device_guard
        import paddle

        inputs = concrete_program.inputs
        outputs = concrete_program.outputs

        scope = paddle.static.global_scope()

        @switch_to_static_graph
        def append_backward_desc():
            program = concrete_program.main_program

            # backward with optimizer to add backward graph to program
            backward.gradients_with_optimizer(program, ipu_strategy._optimizer)

            # initialize backward parameters
            exe = paddle.static.Executor(paddle.CPUPlace())
            startup_program = paddle.static.default_startup_program()
            exe.run(startup_program)

            return program

        if ipu_strategy.enable_fp16:
            # NOTE(review): assumes class_instance is not None when fp16 is
            # enabled; a free function would fail here — confirm.
            class_instance.to(dtype="float16")

        # Share each parameter/buffer's data with the same-named variable in
        # the global scope (no copy is made).
        for param_or_buffer in concrete_program.parameters:
            param_or_buffer_tensor = scope.var(
                param_or_buffer.name).get_tensor()
            src_tensor = param_or_buffer.value().get_tensor()
            param_or_buffer_tensor._share_data_with(src_tensor)

        # TODO(czr): feed and fetch list needs to consider more type
        # When bound to an instance, inputs[0] is skipped (presumably `self`).
        if class_instance:
            feed_list = [elem.name for elem in inputs[1:] if elem is not None]
        else:
            feed_list = [elem.name for elem in inputs if elem is not None]
        fetch_list = [elem.name for elem in outputs]

        if ipu_strategy.is_training:
            concrete_program.main_program = append_backward_desc()
            # Recreate each optimizer accumulator as a persistable global
            # variable, share its data with the scope variable of the same
            # name, and store the new variable back into the accumulators.
            optimizer = ipu_strategy._optimizer
            for k, v in optimizer._accumulators.items():
                for param_name, var_tmp in v.items():
                    var = optimizer.helper.create_global_variable(
                        name=var_tmp.name,
                        persistable=True,
                        dtype=var_tmp.dtype,
                        type=var_tmp.type,
                        shape=var_tmp.shape,
                        belong_to_optimizer=True)
                    device = optimizer._get_device_for_param(param_name)
                    with device_guard(device):
                        optimizer.helper.set_variable_initializer(
                            var, initializer=Constant(value=0.0))
                    param_or_lr_tensor = scope.find_var(
                        var_tmp.name).get_tensor()
                    optim_tensor = var.value().get_tensor()
                    param_or_lr_tensor._share_data_with(optim_tensor)
                    optimizer._accumulators[k][param_name] = var

        @switch_to_static_graph
        def func_compile():
            if ipu_strategy.enable_fp16:
                amp_list = paddle.static.amp.CustomOpLists()
                amp_list.unsupported_list = {"cumsum"}
                to_fp16_var_names = paddle.static.amp.cast_model_to_fp16(
                    concrete_program.main_program,
                    amp_list,
                    use_fp16_guard=False)
                paddle.static.amp.cast_parameters_to_fp16(
                    paddle.CPUPlace(),
                    concrete_program.main_program,
                    to_fp16_var_names=to_fp16_var_names)

            program = IpuCompiledProgram(concrete_program.main_program,
                                         ipu_strategy=ipu_strategy,
                                         scope=scope).compile(
                                             feed_list, fetch_list)
            return program

        main_program = func_compile()
        concrete_program.main_program = main_program
        return concrete_program

    @staticmethod
    def patch_program_cache(ipu_strategy):
        """ Monkey patch ProgramCache descriptor to support dynamic2static in IPU.

        The patched ``ProgramCache.__getitem__`` rebuilds and IPU-compiles the
        program when the key is new or ``ipu_strategy.need_compile`` is set.
        The original getter is recorded in ``patcher_cache``.

        Args:
            ipu_strategy: The ipu_strategy used in dynamic graph.

        Returns:
            None
        """
        from ..fluid.dygraph.dygraph_to_static.program_translator import ProgramCache
        from ..fluid.dygraph.dygraph_to_static.program_translator import CacheKey
        from ..fluid.dygraph.dygraph_to_static import logging_utils
        from ..fluid.dygraph.dygraph_to_static.program_translator import MAX_TRACED_PROGRAM_COUNT
        from ..fluid.dygraph.dygraph_to_static.partial_program import partial_program_from

        old_getter = ProgramCache.__getitem__

        def patch_getter(self, item):
            if not isinstance(item, CacheKey):
                raise ValueError(
                    'type(item) should be CacheKey, but received %s' %
                    type(item).__name__)
            item_id = hash(item)
            self._recent_key = item_id
            if item_id not in self._caches or ipu_strategy.need_compile:
                if item_id in self._caches:
                    logging_utils.warn(
                        "ipu_strategy changes detected. Please sync weights.")
                if self._caches and not ipu_strategy.need_compile:
                    logging_utils.warn(
                        "dynamic2static on IPU doesn't support multiple caches. Please make sure "
                        "dynamic inputs are not used.")
                concrete_program, _ = self._build_once(item)
                concrete_program = IpuDynamicPatcher.convert_concrete_program(
                    ipu_strategy, concrete_program, item.class_instance)

                self._caches[item_id] = (concrete_program,
                                         partial_program_from(concrete_program))
                # Note: raise warnings if number of traced program is more than `max_tracing_count`
                current_tracing_count = len(self._caches)
                if current_tracing_count > MAX_TRACED_PROGRAM_COUNT:
                    logging_utils.warn(
                        "Current traced program number: {} > `max_tracing_count`:{}. Too much cached programs will bring expensive overhead. "
                        "The reason may be: (1) passing tensors with different shapes, (2) passing python objects instead of tensors."
                        .format(current_tracing_count,
                                MAX_TRACED_PROGRAM_COUNT))

            return self._caches[item_id]

        setattr(ProgramCache, '__getitem__', patch_getter)
        IpuDynamicPatcher.patcher_cache.append(
            [ProgramCache, '__getitem__', old_getter])

    @staticmethod
    def patch_lr_scheduler(ipu_strategy):
        """Patch ``LRScheduler.step`` so each step pushes ``last_lr`` to IPU.

        Args:
            ipu_strategy: the strategy whose ``set_options`` receives the lr.
        """
        from paddle.optimizer.lr import LRScheduler
        # For IPU dynamic graph usage, lr_var is not synced in executor as static mode do.
        # Manually set lr to ipu_strategy to update the lr.
        old_step = LRScheduler.step

        def patch_step(self, epoch=None):
            old_step(self, epoch)
            ipu_strategy.set_options({"lr": self.last_lr})

        setattr(LRScheduler, 'step', patch_step)
        IpuDynamicPatcher.patcher_cache.append([LRScheduler, 'step', old_step])

    @staticmethod
    def register_patch(ipu_strategy):
        """Install both the ProgramCache and LRScheduler patches."""
        IpuDynamicPatcher.patch_program_cache(ipu_strategy)
        IpuDynamicPatcher.patch_lr_scheduler(ipu_strategy)

    @staticmethod
    def release_patch():
        """Undo every recorded patch and empty ``patcher_cache``.

        Entries are restored newest-first. If the same attribute was patched
        more than once, the value that existed before the first patch is the
        one left in place.
        """
        for module, key, attr in reversed(IpuDynamicPatcher.patcher_cache):
            setattr(module, key, attr)
        # Clear in place so later register/release cycles do not replay
        # stale entries.
        del IpuDynamicPatcher.patcher_cache[:]


J
jianghaicheng 已提交
693 694 695 696 697 698 699 700 701
class IpuStrategy(object):
    """
    Help users precisely control the graph building in :code:`paddle.static.IpuCompiledProgram` .

    Returns:
        The IpuStrategy instance.

    Examples:
        .. code-block:: python

            # required: ipu

            import paddle
            import paddle.static as static

            paddle.enable_static()

            ipu_strategy = static.IpuStrategy()
    """

    def __init__(self):
        if core.is_compiled_with_ipu():
            self._ipu_strategy = core.IpuStrategy()
            # Defaults are pushed straight to the native strategy object.
            default_options = {
                'location_optimizer': {
                    'on_chip': 0,
                    'use_replicated_tensor_sharding': 1,
                },  # set optimizer location
                'accumulation_and_replication_reduction_type':
                1,  # popart::ReductionType::Mean
                'mean_accumulation_and_replication_reduction_strategy':
                1,  # popart::MeanReductionStrategy::Post
            }
            self._ipu_strategy.set_options(default_options)
            self.has_custom_ops = False
            self.custom_op_names = []
            # A freshly created strategy always requires a compile.
            self.need_compile = True
        else:
            raise RuntimeError(
                "Can not use IpuStrategy in non IPU compiled environment, please re-compile with WITH_IPU=ON."
            )
        from paddle import in_dynamic_mode
        # In dynamic mode the dy2static patches are installed automatically.
        if in_dynamic_mode():
            self.register_patch()

    def register_patch(self):
        """
        Register patch functions to support dynamic to static on IPU. This operation would break the dy2static functionality on CPU.
        Use `release_patch` to release the patch.

        Examples:
            .. code-block:: python

                # required: ipu

                import paddle
                import paddle.static as static

                ipu_strategy = static.IpuStrategy()

                ipu_strategy.register_patch()
        """
        IpuDynamicPatcher.register_patch(self)

    def release_patch(self):
        """
        Release the registered IPU functions.

        Examples:
            .. code-block:: python

                # required: ipu

                import paddle
                import paddle.static as static

                ipu_strategy = static.IpuStrategy()

                ipu_strategy.release_patch()
        """
        IpuDynamicPatcher.release_patch()

    def set_optimizer(self, optimizer):
        """
        Set optimizer to ipu_strategy in dynamic mode.

        Args:
            optimizer (Optimizer): Optimizer to be used in training.

        Returns:
            None.

        Raises:
            RuntimeError: If called outside dynamic mode.

        Examples:
            .. code-block:: python

                # required: ipu

                import paddle
                import paddle.static as static

                linear = paddle.nn.Linear(10, 10)
                optimizer = paddle.optimizer.SGD(learning_rate=0.01,
                                                 parameters=linear.parameters())
                ipu_strategy = static.IpuStrategy()
                ipu_strategy.set_optimizer(optimizer)
        """
        from paddle import in_dynamic_mode
        if in_dynamic_mode():
            self._optimizer = optimizer
            optimizer_attrs = self.parse_optimizer(optimizer)
            self._ipu_strategy.set_options(optimizer_attrs)
        else:
            raise RuntimeError("Only needs to set optimizer in dynamic mode.")

    def parse_optimizer(self, optimizer):
        """
        Parse optimizer attributes for IPU dynamic to static support. Currently only the learning rate is parsed.

        Args:
            optimizer (Optimizer): Optimizer to be parsed. Its ``_learning_rate``
                must be a float or an ``LRScheduler``.

        Returns:
            Dict. Always contains ``"is_dynamic": True`` and ``"lr"``.

        Raises:
            TypeError: If the optimizer's learning rate is neither a float nor an ``LRScheduler``.

        Examples:
            .. code-block:: python

                # required: ipu

                import paddle
                import paddle.static as static

                linear = paddle.nn.Linear(10, 10)
                optimizer = paddle.optimizer.SGD(learning_rate=0.01,
                                                 parameters=linear.parameters())
                ipu_strategy = static.IpuStrategy()
                attrs = ipu_strategy.parse_optimizer(optimizer)
        """

        def get_lr():
            from paddle.optimizer.lr import LRScheduler
            learning_rate = optimizer._learning_rate
            if isinstance(learning_rate, float):
                return {"lr": learning_rate}
            if isinstance(learning_rate, LRScheduler):
                return {"lr": learning_rate()}
            # Previously this fell through to ``dict.update(None)``, which
            # raised an unhelpful TypeError; report the actual problem.
            raise TypeError(
                "Unsupported learning rate type for IPU: {}, expected float "
                "or LRScheduler.".format(type(learning_rate)))

        attr_fn = [get_lr]
        optimizer_attrs = {"is_dynamic": True}
        for fn in attr_fn:
            optimizer_attrs.update(fn())
        return optimizer_attrs

    def set_graph_config(self,
                         num_ipus=1,
                         is_training=True,
                         micro_batch_size=1,
                         enable_manual_shard=False):
        """
        Set graph configuration to the IpuStrategy instance.

        Args:
            num_ipus (int, optional): Number of IPU devices. Default 1, which means only use 1 IPU.
            is_training (bool, optional): True is training graph, False is inference graph. Default True, which means is training mode.
            micro_batch_size (int, optional): The batch-size in the graph. Used to make the graph batch-size fixed,
                if the batch-size in the graph is dynamic. Default 1, which means the batch-size would be set 1, if the batch-size is dynamic.
            enable_manual_shard (bool, optional): Enable graph sharding or not. Only if num_ipus > 1, enable_manual_shard is able to be set True.
                Default False, which means disabled.

        Returns:
            None.

        Raises:
            RuntimeError: If ``num_ipus == 1`` and ``enable_manual_shard`` is True.

        Examples:
            .. code-block:: python

                # required: ipu

                import paddle
                import paddle.static as static

                paddle.enable_static()

                ipu_strategy = static.IpuStrategy()
                ipu_strategy.set_graph_config(num_ipus=1,
                                            is_training=True,
                                            micro_batch_size=1,
                                            enable_manual_shard=False)
        """
        if num_ipus == 1 and enable_manual_shard:
            raise RuntimeError(
                "Only if num_ipus > 1, enable_manual_shard is able to be set True."
            )
        options = {
            'num_ipus': num_ipus,
            'is_training': is_training,
            'micro_batch_size': micro_batch_size,
            'enable_manual_shard': enable_manual_shard,
        }
        self.set_options(options)

    def set_pipelining_config(self,
                              enable_pipelining=False,
                              batches_per_step=1,
                              enable_gradient_accumulation=False,
                              accumulation_factor=1):
        """
        Set pipelining configuration to the IpuStrategy instance. Used to optimize the throughput performance.

        Args:
            enable_pipelining (bool, optional): Enable data pipelining between subgraphs. Only if enable_manual_shard=True, enable_pipelining is able to be set True.
                Default False, which means disabled.
            batches_per_step (int, optional): Set the batches per run in data pipelining mode. Only if enable_pipelining=True, batches_per_step is able to be set > 1.
                Default 1, which means no data pipelining.
            enable_gradient_accumulation (bool, optional): Enable to accumulate gradients before updating the weights in training mode. Only if enable_pipelining=True,
                enable_gradient_accumulation is able to be set True. Default False, which means no gradient accumulation.
            accumulation_factor (int, optional): Specify the number of micro-batches to accumulate
                before applying the varUpdate. Default 1, which means disable the accumulation.

        Returns:
            None.

        Raises:
            RuntimeError: If ``enable_pipelining`` is True while the currently stored
                ``enable_manual_shard`` option is False.

        Examples:
            .. code-block:: python

                # required: ipu

                import paddle
                import paddle.static as static

                paddle.enable_static()

                ipu_strategy = static.IpuStrategy()
                ipu_strategy.set_pipelining_config(enable_pipelining=False,
                                                    batches_per_step=1,
                                                    enable_gradient_accumulation=False,
                                                    accumulation_factor=1)
        """
        # Only the pipelining/manual-shard dependency is validated here.
        enable_manual_shard = self.get_option('enable_manual_shard')
        if not enable_manual_shard and enable_pipelining:
            raise RuntimeError(
                "Only if enable_manual_shard=True, enable_pipelining is able to be set True."
            )
        options = {
            'enable_pipelining': enable_pipelining,
            'batches_per_step': batches_per_step,
            'enable_gradient_accumulation': enable_gradient_accumulation,
            'accumulation_factor': accumulation_factor,
        }
        self.set_options(options)

    def set_precision_config(self, enable_fp16=False):
        """
        Set half computation configuration to the IpuStrategy instance. Used to optimize the performance.

        Args:
            enable_fp16 (bool, optional): Enable FLOAT16 mode and transform FLOAT32 to FLOAT16. Default False, which means disable FLOAT16 mode.

        Returns:
            None.

        Examples:
            .. code-block:: python

                # required: ipu

                import paddle
                import paddle.static as static

                paddle.enable_static()

                ipu_strategy = static.IpuStrategy()
                ipu_strategy.set_precision_config(enable_fp16=False)
        """
        options = {
            'enable_fp16': enable_fp16,
        }
        self.set_options(options)

    def add_custom_op(self,
                      paddle_op,
                      popart_op=None,
                      domain='custom.ops',
                      version=1):
        """
        Add a mapping to use popart custom ops running on the IPU.

        Args:
            paddle_op(str): the name of custom op in paddle.

            popart_op(str): the name of custom op in popart. Defaults to ``paddle_op`` when None.

            domain(str): domain name of custom op in popart.

            version(int): version of custom op in popart.

        Returns:
            None.

        Examples:
            .. code-block:: python

                # required: ipu

                import paddle
                import paddle.static as static

                paddle.enable_static()

                ipu_strategy = static.IpuStrategy()
                ipu_strategy.add_custom_op('paddle_relu', 'popart_relu')
        """
        if popart_op is None:
            popart_op = paddle_op
        custom_op = {
            'paddle_op': paddle_op,
            'popart_op': popart_op,
            'domain': domain,
            'version': version,
        }
        self.set_options({'custom_op': custom_op})
        # Recorded so IpuCompiledProgram can pass the names to the
        # popart canonicalization pass.
        self.custom_op_names.append(paddle_op)
        self.has_custom_ops = True

    def set_options(self, options):
        """
        Set options from dict.

        Any option other than ``lr`` marks the strategy as needing a recompile.

        Args:
            options(dict): dict of options.

        Returns:
            None.

        Examples:
            .. code-block:: python

                # required: ipu

                import paddle
                import paddle.static as static

                paddle.enable_static()

                ipu_strategy = static.IpuStrategy()
                options = {'num_ipus':1, 'enable_fp16': True}
                ipu_strategy.set_options(options)
        """
        self._ipu_strategy.set_options(options)
        # check whether to recompile program with updated ipu options.
        recompile_white_list = {'lr'}
        if options.keys() - recompile_white_list:
            self.need_compile = True

    def get_option(self, option):
        """
        Get option.

        Args:
            option(str): name of option.

        Returns:
            option value.

        Examples:
            .. code-block:: python

                # required: ipu

                import paddle
                import paddle.static as static

                paddle.enable_static()

                ipu_strategy = static.IpuStrategy()
                num_ipus = ipu_strategy.get_option('num_ipus')
        """
        return self._ipu_strategy.get_option(option)['value']

    def enable_pattern(self, pattern):
        """
        Enable PopART pattern to optimize the graph.

        Args:
            pattern(string): the name of the pattern.

        Returns:
            None.

        Examples:
            .. code-block:: python

                # required: ipu

                import paddle
                import paddle.static as static

                paddle.enable_static()

                ipu_strategy = static.IpuStrategy()
                ipu_strategy.enable_pattern("ViewSimplifyPattern")
        """
        self._ipu_strategy.enable_pattern(pattern)

    def disable_pattern(self, pattern):
        """
        Disable PopART pattern.

        Args:
            pattern(string): the name of the pattern.

        Returns:
            None.

        Examples:
            .. code-block:: python

                # required: ipu

                import paddle
                import paddle.static as static

                paddle.enable_static()

                ipu_strategy = static.IpuStrategy()
                ipu_strategy.disable_pattern("ViewSimplifyPattern")
        """
        self._ipu_strategy.disable_pattern(pattern)

    @property
    def num_ipus(self):
        """
        Get the number of IPU devices from IpuStrategy instance.
        """
        return self.get_option('num_ipus')

    @property
    def is_training(self):
        """
        Get the boolean of training or inference from IpuStrategy instance.
        """
        return self.get_option('is_training')

    @property
    def enable_pipelining(self):
        """
        Get the boolean of enable pipelining or not from IpuStrategy instance.
        """
        return self.get_option('enable_pipelining')

    @property
    def enable_fp16(self):
        """
        Get the boolean of float16 mode or not from IpuStrategy instance.
        """
        return self.get_option('enable_fp16')


class IpuCompiledProgram(object):
    """
    The IpuCompiledProgram is used to transform a program to a ipu-target program,
    such as forward graph extraction, computing graph transformation, useless scale Ops clean, etc.

    Args:
        program(Program, optional): This parameter represents the :code:`Program`
            to be executed. Default is None, which means the program will be set to
            the default program :code:`paddle.static.default_main_program()` .
        scope(Scope, optional): The scope used to run this program, you can switch
            it to different scope. Default is None, which means use the global
            scope :code:`paddle.static.global_scope()` .
        ipu_strategy(IpuStrategy, optional): This argument is used to build the program with the
            specified options, such as half computation, training or inference session, the number of IPUs, etc.
            Default is None, which means build the program based on a default-constructed `IpuStrategy`.

    Returns:
        IpuCompiledProgram

    Example:
        .. code-block:: python

            # required: ipu

            import paddle
            import paddle.static as static

            paddle.enable_static()

            a = static.data(name='data', shape=[None, 1], dtype='int32')
            b = a + 1
            main_prog = static.default_main_program()

            ipu_strategy = static.IpuStrategy()
            ipu_strategy.set_graph_config(num_ipus=1, is_training=True, micro_batch_size=1)
            ipu_strategy.set_pipelining_config(enable_pipelining=False, batches_per_step=1, enable_gradient_accumulation=False, accumulation_factor=1)
            ipu_strategy.set_precision_config(enable_fp16=False)

            ipu_compiled_program = static.IpuCompiledProgram(
                main_prog,
                ipu_strategy=ipu_strategy)
    """

    def __init__(self, program=None, scope=None, ipu_strategy=None):
        if not core.is_compiled_with_ipu():
            raise ValueError(
                "Can not use this function since PaddlePaddle is not compiled with IPU"
            )

        if program is None:
            program = framework.default_main_program()

        if not isinstance(program, framework.Program):
            raise TypeError(
                "The type of program is wrong, expected Program, but got %s" %
                type(program))

        self._program = program
        self._compiled = False

        if scope is not None:
            self._scope = scope
        else:
            # Local import, kept from the original; presumably avoids an
            # import cycle with the top-level paddle package.
            import paddle
            self._scope = paddle.static.global_scope()

        if ipu_strategy is not None:
            self._ipu_strategy = ipu_strategy
        else:
            self._ipu_strategy = IpuStrategy()

        # Read from self._ipu_strategy, not the raw argument: the argument may
        # be None, in which case a default IpuStrategy was created above.
        if self._ipu_strategy.has_custom_ops:
            self._custom_op_names = set(self._ipu_strategy.custom_op_names)
        else:
            self._custom_op_names = ()

        self._backend = core.IpuBackend.get_instance()

    def compile(self, feed_list, fetch_list):
        """
        This interface is used to compile the input Program to a program
        to run the model on the ipu.

        Note: the input Program is modified in place (its feed/fetch ops and
        variables are removed) before it is converted to a graph.

        Args:
            feed_list(list): This parameter represents the input Tensors of the model.

            fetch_list(list): This parameter represents the Tensors that need to be returned
                after the model.

        Returns:
            Program

        Example:
            .. code-block:: python

                # required: ipu

                import paddle
                import paddle.static as static

                paddle.enable_static()

                a = static.data(name='data', shape=[None, 1], dtype='int32')
                b = a + 1
                main_prog = static.default_main_program()

                ipu_strategy = static.IpuStrategy()
                ipu_strategy.set_graph_config(num_ipus=1, is_training=True, micro_batch_size=1)
                ipu_strategy.set_pipelining_config(enable_pipelining=False, batches_per_step=1, enable_gradient_accumulation=False, accumulation_factor=1)
                ipu_strategy.set_precision_config(enable_fp16=False)

                program = static.IpuCompiledProgram(
                    main_prog,
                    ipu_strategy=ipu_strategy).compile([a.name], [b.name])
        """
        # Hand the scope and the native strategy to the backend instance
        # obtained in __init__ via IpuBackend.get_instance().
        self._backend.set_scope(self._scope)
        self._backend.set_ipu_strategy(self._ipu_strategy._ipu_strategy)

        # feed and fetch doesn't have corresponding popart op, so we rm both here
        global_block = self._program.global_block()
        need_to_remove_op_index = []
        for i, op in enumerate(global_block.ops):
            op.desc.set_is_target(False)
            if op.type == 'feed' or op.type == 'fetch':
                need_to_remove_op_index.append(i)

        # Remove from the back so earlier indices stay valid.
        for index in need_to_remove_op_index[::-1]:
            global_block._remove_op(index)

        for var in ['feed', 'fetch']:
            if global_block.has_var(var):
                global_block._remove_var(var)

        self._program.desc.flush()
        self._graph = core.Graph(self._program.desc)

        # Optimizer-related passes only run for training graphs.
        if self._ipu_strategy.is_training:
            passes = [
                'optimizer_extract_pass',
                'optimizer_state_align_pass',
            ]
            for pass_name in passes:
                a_pass = core.get_pass(pass_name)
                a_pass.apply(self._graph)

        passes = [
            'forward_graph_extract_pass',
            'infer_shape_pass',
            'avg_shard_pass',
            'delete_scale_op_pass',
        ]
        for pass_name in passes:
            a_pass = core.get_pass(pass_name)
            # Only infer_shape_pass takes the feed list in this group.
            if pass_name == 'infer_shape_pass':
                a_pass.set('feed_list', feed_list)
            a_pass.apply(self._graph)

        # Custom op names are forwarded only when any were registered.
        a_pass = core.get_pass('popart_canonicalization_pass')
        if self._custom_op_names:
            a_pass.set('custom_ops', self._custom_op_names)
        a_pass.apply(self._graph)

        passes = [
            'ipu_inplace_pass',
            'ipu_graph_builder_pass',
            'ipu_runtime_replacer_pass',
        ]
        for pass_name in passes:
            a_pass = core.get_pass(pass_name)
            a_pass.set('feed_list', feed_list)
            a_pass.set('fetch_list', fetch_list)
            a_pass.apply(self._graph)

        # Convert the transformed graph back into a new Program.
        convert_pass = core.get_pass('graph_to_program_pass')
        desc = core.ProgramDesc()
        convert_pass.set_not_owned('program', desc)
        convert_pass.apply(self._graph)
        program = framework.Program._construct_from_desc(desc)

        # NOTE: 'lr_sheduler' (sic) is the attribute name used on Program; keep the spelling.
        if hasattr(self._program, 'lr_sheduler'):
            # how to share var between two different block ?
            lr_var_name = self._program.lr_sheduler._var_name

            program.lr_sheduler = self._program.lr_sheduler
            # Program.clone will clone lr_sheduler, so i set lr_var as
            # lr_sheduler attribute
            global_block = self._program.global_block()
            program.lr_sheduler.lr_var = global_block.vars[lr_var_name]

        # with popart, we need to support batches_per_step, what means
        # the shape of feed_var and feed_tensor(maybe numpy array) will
        # mismatch, so we set need_check_feed to False. Thus we can avoid
        # modify logic of run.
        program_global_block = program.global_block()
        for feed_name in feed_list:
            feed_var = program_global_block.var(feed_name)
            feed_var.desc.set_need_check_feed(False)

        if not hasattr(program, 'org_program'):
            program.org_program = self._program

        # The strategy's current options are now compiled in.
        self._ipu_strategy.need_compile = False

        return program