# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
import contextlib
import functools
import itertools
import json
import os
import pickle
from typing import Any

import numpy as np

from ..core._imperative_rt import GraphProfiler, GraphProfiler2, SerializationMetadata
from ..core._imperative_rt.core2 import Tensor as RawTensor
from ..core._imperative_rt.core2 import (
    TensorWeakRef,
    apply,
    set_tracing,
    skip_tracing,
    unset_tracing,
)
from ..core._imperative_rt.ops import (
    AssertEqual,
    CollectiveComm,
    RemoteRecv,
    RemoteSend,
)
from ..core._trace_option import set_symbolic_shape
from ..core._wrap import as_device
from ..core.ops.builtin import BatchNorm, OpDef
from ..core.ops.special import Const
from ..core.tensor import megbrain_graph as G
from ..core.tensor.utils import setscalar
from ..utils.naming import AutoNaming
from ..utils.profiler import is_profiling
from .dtr_config import DTRConfig
from .graph_opt_config import GraphOptimizationConfig
from .sublinear_memory_config import SublinearMemoryConfig


def _input_node_use_static_shape():
    # opt-in via environment variable: when it is set (to any value), graph
    # input nodes are built with a static shape
    return os.environ.get("MEGENGINE_INPUT_NODE_USE_STATIC_SHAPE") is not None


class TraceMismatchError(RuntimeError):
    pass


active_trace = None


def is_tracing():
    if active_trace is None:
        return False
    else:
        return not skip_tracing


@contextlib.contextmanager
def exclude_from_trace():
    global skip_tracing
    if skip_tracing or (active_trace is None):
        yield
        return
    try:
        skip_tracing = True
        unset_tracing()
        if active_trace is not None:
            active_trace._begin_excluded_region()
        yield
    finally:
        skip_tracing = False
        set_tracing()
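
# A minimal sketch of how exclude_from_trace is typically used inside a traced
# function (``step`` is an illustrative variable, not part of this module):
#
#     @trace
#     def f(x, step):
#         x = x + 1
#         with exclude_from_trace():  # code here runs eagerly and is not recorded
#             if step % 2 == 0:
#                 x = x + 1
#         return x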


class TensorInfo:
    __slots__ = (
        # collected attributes
        "name",
        "external",
        "data_read",
        "shape_read",
        "value_read",
        "exported",
        "device",
        "dtype",
        "shape",
        "is_const",
        "bound_data",
        # resources for execution
        "varnode",
        "data_setter",
        "shape_reader",
        "value_reader",
        "data_reader",
    )

    def __init__(self):
        self.name = None
        self.exported = None
        self.data_read = None
        self.shape_read = None
        self.value_read = None
        self.bound_data = None

        self.data_setter = None
        self.shape_reader = None
        self.value_reader = None
        self.data_reader = None


# ops with side effects (assertions, collective/remote communication); they
# need explicit dependency links to keep their execution order in compiled graphs
_io_op_types = {AssertEqual, CollectiveComm, RemoteSend, RemoteRecv}


class trace:
    """Wraps a callable and provides:

    * tracing via :meth:`.trace` and :meth:`.dump`
    * accelerated evaluation via :meth:`.__call__`

    Args:
        function: the function to be traced.
        symbolic: whether to apply symbolic execution for tracing. Default: False
        capture_as_const: capture global vars or closures as const value. Default: False
        sublinear_memory_config: configuration for sublinear memory optimization.
            If not None, it enables sublinear memory optimization with the given setting.
        dtr_config: configuration for DTR (Dynamic Tensor Rematerialization) memory
            optimization. If not None, it enables DTR optimization with the given setting.
        profiling: whether to profile compiled trace. Default: False
        opt_level: optimization level for compiling trace. Default: 2
        graph_opt_config: configuration for graph optimization. Default: None
        symbolic_shape: whether to use symbolic shape for tracing. Default: True
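
    Example: a minimal sketch of typical usage (the traced function ``fwd``
    and its input are illustrative, not part of this module):

    .. code-block:: python

        import megengine as mge
        from megengine.jit import trace

        @trace(symbolic=True)
        def fwd(x):
            return x * 2

        x = mge.tensor([1.0, 2.0])
        y = fwd(x)  # the first call runs eagerly and records the op sequence
        y = fwd(x)  # later calls replay the compiled graph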
    """

    def __new__(cls, *args, **kwargs):
        if not args:
            return functools.partial(cls, **kwargs)
        return super().__new__(cls)

    def __init__(
        self,
        function,
        symbolic=False,
        capture_as_const=False,
        sublinear_memory_config: SublinearMemoryConfig = None,
        dtr_config: DTRConfig = None,
        profiling: bool = False,
        opt_level: int = 2,
        graph_opt_config: GraphOptimizationConfig = None,
        symbolic_shape: bool = True,
    ):
        self.__wrapped__ = function
        self._symbolic = symbolic
        self._capture_as_const = capture_as_const
        self._sublinear_memory_config = sublinear_memory_config
        self._dtr_config = dtr_config
        self._profiling = profiling
        self._profiler = None
        self._profiler2 = None
        self._graph_opt_level = opt_level
        self._graph_opt_config = graph_opt_config
        self._symbolic_shape = symbolic_shape
        self._output_handles = set()

        self._reset()

    def _reset(self):
        self._untraced = True
        self._tinfo = []  # handle -> TensorInfo
        self._seq = []
        self._pc = 0
        self._graph = None
        self._need_reset_nodes = None
        self._lazy_eval_graph = None
        self._lazy_eval_tensors = set()
        self._lazy_eval_links = None
        self._active_tensors = set()
        self._tensor_remaps = None
        self._inputs_to_restore = None
        self._arg_bindings = None
        self._kwarg_bindings = None
        self._output_bindings = None
        self._output_names = None

    def _new_handle(self):
        handle = len(self._tinfo)
        info = TensorInfo()
        self._tinfo.append(info)
        return handle, info

    def _apply_op(self, op, args):
        assert not self._untraced
        # check against trace
        if self._pc >= len(self._seq):
            raise TraceMismatchError("trace should end here, but more op observed")
        record = self._seq[self._pc]
        op_, ihandles, ohandles = record
        if (isinstance(op_, str) and op_ == "Const") or (op != op_):
            raise TraceMismatchError("op different from last time")
        if len(ihandles) != len(args):
            raise TraceMismatchError("op input size different from last time")

        # check all inputs of current op
        for h, x in zip(ihandles, args):
            info = self._tinfo[h]
            if info.external:
                if (
                    x._compiled_info is not None
                    and not self._tinfo[x._mixin_handle].exported
                ):
                    raise TraceMismatchError(
                        "failed to capture: input was an external tensor "
                        "last time, got an internal tensor this time"
                    )
                if info.bound_data:
                    if x._compiled_info is not None:
                        raise TraceMismatchError(
                            "const capture violated: was an external tensor "
                            "last time, got an internal tensor this time"
                        )
                    if x._handle != info.bound_data._handle:
                        if not np.array_equal(x.numpy(), info.bound_data.numpy()):
                            raise TraceMismatchError(
                                "const capture violated: got "
                                "a different tensor this time"
                            )
                else:
                    if info.dtype != x.dtype:
                        raise TraceMismatchError(
                            "failed to capture: different dtype from last time"
                        )
                    if info.device != x.device:
                        raise TraceMismatchError(
                            "failed to capture: different device from last time"
                        )
                    info.data_setter.set_value(x._dev_tensor())
            else:
                if x._mixin_handle == -1:
                    if x._handle not in self._tensor_remaps:
                        raise TraceMismatchError(
                            "unexpected capture: trying to use an external tensor as "
                            "input, but that input was an internal tensor last time"
                        )
                    else:
                        x._mixin_handle = self._tensor_remaps[
                            x._handle
                        ]._CompiledTensorProxy__handle
                if x._mixin_handle != h:
                    raise TraceMismatchError(
                        "mis-wiring: input edge to a data flow "
                        "graph node is different from last time"
                    )

        self._pc += 1
        outputs = []
        for h in ohandles:
            info = self._tinfo[h]
            # generate output tensor and create compiled info
            y = RawTensor(info.varnode)
            y._compiled_info = CompiledTensorProxy(h)
            y._mixin_handle = h
            outputs += [y]
            self._active_tensors.add(TensorWeakRef(y))
        self._output_handles.update(ohandles)
        return outputs

    def _apply_const(self, value, dtype, device):
        assert not self._untraced
        # check against trace
        if self._pc >= len(self._seq):
            raise TraceMismatchError("trace should end here, but more op observed")
        record = self._seq[self._pc]
        op_, ihandles, ohandles = record
        # Const op is represented by a str
        assert isinstance(op_, str) and op_ == "Const"

        expected = self._tinfo[ohandles[0]].bound_data.numpy()
        shape = value.shape
        if shape != expected.shape or dtype != expected.dtype:
            eq = False
        elif shape == ():
            eq = expected.item() == value.item()
        elif shape == (1,):
            eq = expected[0] == value[0]
        else:
            eq = np.all(value == expected)
        if not eq:
            raise TraceMismatchError(
                "const tensor violated: got a different tensor this time"
            )

        self._pc += 1
        (h,) = ohandles
        outputs = [self._tinfo[h].bound_data]
        return outputs

    # runs on the first call, recording information for the trace
    def _record_op(self, op, inputs, outputs):
        if skip_tracing:
            for x in inputs:
                h = getattr(x, "_mixin_handle", -1)
                if h >= 0:
                    self._tinfo[h].data_read = True
            return

        ihandles = []
        for x in inputs:
            h = getattr(x, "_mixin_handle", -1)
            if h < 0 or (not self._capture_as_const and self._tinfo[h].exported):
                h, info = self._new_handle()
                name = AutoNaming.gen_name(x)
                info.name = name
                info.external = True
                info.device = x.device
                info.dtype = x.dtype
                info.shape = x.shape
                if self._capture_as_const:
                    info.bound_data = RawTensor(
                        x.numpy(), x.dtype, x.device, False, name
                    )

            ihandles.append(h)

        ohandles = []
        for x in outputs:
            h, info = self._new_handle()
            ohandles.append(h)
            info.external = False
            x._mixin_handle = h
            x._recording = True
            x._trace_mixin_info = info
            self._active_tensors.add(TensorWeakRef(x))
            if self._symbolic:
                self._lazy_eval_tensors.add(TensorWeakRef(x))

        self._seq.append((op, tuple(ihandles), tuple(ohandles)))

    def _record_const(self, outputs):
        if skip_tracing:
            (x,) = outputs
            h = getattr(x, "_mixin_handle", -1)
            if h >= 0:
                self._tinfo[h].data_read = True
            return

        (x,) = outputs
        h, info = self._new_handle()
        ohandles = [h]
        info.external = True
        info.device = x.device
        info.dtype = x.dtype
        info.shape = x.shape
        info.bound_data = x
        info.is_const = True
        x._mixin_handle = h
        x._recording = True
        x._trace_mixin_info = info
        if self._symbolic:
            self._lazy_eval_tensors.add(TensorWeakRef(x))
        self._seq.append(("Const", tuple(), tuple(ohandles)))

    def _set_active(self, active: bool):
        global active_trace
        if active:
            if active_trace:
                raise NotImplementedError("sorry, not implemented: nested trace")
            active_trace = self
        else:
            assert active_trace is self
            active_trace = None

    def _init_trace(self, symbolic: bool):
        if symbolic:
            self._lazy_eval_graph = G.Graph()
            self._apply_graph_options(self._lazy_eval_graph)
            self._lazy_eval_links = ()

    def _take_escaped_tensors(self):
        escaped_tensors = tuple(filter(lambda x: x() is not None, self._active_tensors))
        self._active_tensors.clear()
        return escaped_tensors

    def _lazy_eval(self, lazy_eval_graph, lazy_eval_tensors, lazy_eval_links):
        lazy_eval_tensors = [x() for x in lazy_eval_tensors]
        lazy_eval_tensors = [x for x in lazy_eval_tensors if x is not None]
        readers = [G.OutputNode(x._varnode).outputs[0] for x in lazy_eval_tensors]
        self._apply_graph_options(lazy_eval_graph)
        lazy_eval_graph.options.graph_opt_level = self._graph_opt_level
        lazy_eval_graph._set_priority_to_id([*lazy_eval_links, *readers])
        lazy_eval_graph.compile(*lazy_eval_links, *readers)
        self._execute_graph(lazy_eval_graph)
        lazy_eval_graph.wait()
        for r, x in zip(readers, lazy_eval_tensors):
            # get values from lazy_eval_graph and assign to lazy_eval tensor
            x._handle = RawTensor(r.op.get_value())._handle
            x._reset_varnode()

    @contextlib.contextmanager
    def _setup(self):
        interrupted = False

        def do_enter():
            set_tracing()
            self._save_symbolic_shape = set_symbolic_shape(self._symbolic_shape)
            self._set_active(True)
            if self._untraced:
                self._init_trace(self._symbolic)
            else:
                if self._graph is None:
                    self._compile()
                self._execute_graph(self._graph)

        def do_finalize():
            escaped_tensors = self._take_escaped_tensors()
            if self._untraced:
                for x in escaped_tensors:
                    if x():
                        info = self._tinfo[x()._mixin_handle]
                        info.data_read = True
                        x()._mixin_handle = -1
                        x()._recording = False
                if self._inputs_to_restore:
                    for x in self._inputs_to_restore:
                        x._mixin_handle = -1
                        x._recording = False
                if self._symbolic and (
                    self._lazy_eval_tensors or self._lazy_eval_links
                ):
                    # eval lazy eval tensors
                    self._lazy_eval(
                        self._lazy_eval_graph,
                        self._lazy_eval_tensors,
                        self._lazy_eval_links,
                    )
                    self._lazy_eval_graph = None
                    self._lazy_eval_tensors = None
                    self._lazy_eval_links = None
                self._untraced = False
            else:
                # compiled_tensor leaks
                if self._pc == len(self._seq):
                    for x in escaped_tensors:
                        try:
                            assign_raw_tensor(x(), RawTensor(x()._dev_tensor()))
                        except RuntimeError:
                            # TraceMismatchError thrown in do_exit
                            pass
                    self._graph.wait()
                    self._reset_exec_env()

            # reset status
            self._pc = 0
            self._tensor_remaps = None
            self._set_active(False)
            set_symbolic_shape(self._save_symbolic_shape)
            unset_tracing()

        def do_exit():
            unset_tracing()
            if not self._untraced and self._pc != len(self._seq):
                raise TraceMismatchError("premature end")
            if not self._symbolic or not self._untraced:
                # reset output tensors
                for x in self._active_tensors.copy():
                    strong_x = x()
                    if strong_x is not None:
                        strong_x._dev_tensor()
                        strong_x._reset_varnode()
                        strong_x._mixin_handle = -1
                        strong_x._recording = False
                        strong_x._trace_mixin_info = None

        try:
            do_enter()
            yield
            do_exit()
        except:
            interrupted = True
            raise
        finally:
            do_finalize()
            if interrupted:
                self._reset()

    def _begin_excluded_region(self):
        if self._capture_as_const:
            raise RuntimeError(
                "exclude_from_trace cannot be used with capture_as_const"
            )
        if self._untraced:
            # conditionally reading a compiled tensor in excluded region
            # is permitted, so we have to assume every tensor might be read
            for x in self._active_tensors:
                strong_x = x()
                if strong_x:
                    info = self._tinfo[strong_x._mixin_handle]
                    info.exported = True
                    info.data_read = True
        else:
            for x in self._active_tensors:
                strong_x = x()
                if strong_x:
                    strong_x._dev_tensor()

    def _apply_graph_options(self, graph):

        graph.options.no_force_inplace = True
        graph.options.seq_opt.enable_seq_comp_node_opt = False
        graph.options.graph_opt_level = self._graph_opt_level
        if self._dtr_config is not None:
            graph.options.enable_dtr_memory_opt = True
            graph.options.dtr_config.eviction_threshold = (
                self._dtr_config.eviction_threshold
            )
            graph.options.dtr_config.evictee_minimum_size = (
                self._dtr_config.evictee_minimum_size
            )
        # graph optimization
        if self._graph_opt_config is not None:
            mapping = {None: 0, False: 1, True: 2}
            jit_config = graph.options.graph_opt.jit_config
            jit_config.fuse_dimshuffle = mapping[
                self._graph_opt_config.jit_fuse_dimshuffle
            ]
            jit_config.fuse_reduce = mapping[self._graph_opt_config.jit_fuse_reduce]
        # sublinear
        if self._sublinear_memory_config is not None:
            graph.options.enable_sublinear_memory_opt = True
            sublinear_config = graph.options.sublinear_mem_config
            sublinear_config.lb_memory_mb = self._sublinear_memory_config.lb_memory_mb
            sublinear_config.genetic_nr_iter = (
                self._sublinear_memory_config.genetic_nr_iter
            )
            sublinear_config.genetic_pool_size = (
                self._sublinear_memory_config.genetic_pool_size
            )
            sublinear_config.thresh_nr_try = self._sublinear_memory_config.thresh_nr_try
            sublinear_config.num_worker = self._sublinear_memory_config.num_worker
        # profile
        if self._profiling:
            self._profiler = GraphProfiler(graph)
        self._profiler2 = None
        if int(os.getenv("MEGENGINE_INPLACE_UPDATE", "0")):
            graph.options.var_sanity_check_first_run = False

    def _execute_graph(self, graph: G.Graph, *args):
        if is_profiling() and (self._profiler2 is None):
            self._profiler2 = GraphProfiler2(graph)
        elif not is_profiling() and (self._profiler2 is not None):
            self._profiler2 = None
        graph.execute(*args)

    def _compile(self):
        graph = self._graph = G.Graph()
        graph.options.async_exec_level = 0b100
        self._apply_graph_options(graph)
        need_reset_nodes = self._need_reset_nodes = []
        # links enforce ordering of I/O nodes
        in_out_links = ()
        io_links = ()
        readers = []

        if self._capture_as_const:
            for h in itertools.chain(self._arg_bindings, self._kwarg_bindings.values()):
                info = self._tinfo[h]
                opnode = info.data_setter = G.InputNode(
                    device=info.device,
                    dtype=info.dtype,
                    shape=info.shape or (1,),
                    graph=graph,
                    use_static_shape=_input_node_use_static_shape(),
                )
                need_reset_nodes.append(opnode)
                info.varnode = opnode.outputs[0]
                in_out_links += opnode.outputs[1:]

        for op, ihandles, ohandles in self._seq:
            if isinstance(op, str) and op == "Const":
                assert len(ihandles) == 0
                (h,) = ohandles
                info = self._tinfo[h]
                if not hasattr(info, "varnode"):
                    assert info.external
                    assert info.bound_data
                    info.varnode = graph.make_const(
                        info.bound_data.numpy(),
                        info.bound_data.dtype,
                        info.bound_data.device,
                    )
                continue

            require_links = type(op) in _io_op_types
            ivars = []
            for i, h in enumerate(ihandles):
                info = self._tinfo[h]
                if not hasattr(info, "varnode"):
                    assert info.external
                    if info.bound_data:
                        if getattr(info, "is_const", False):
                            info.varnode = graph.make_const(
                                info.bound_data.numpy(),
                                info.bound_data.dtype,
                                info.bound_data.device,
                            )
                        else:
                            info.varnode = graph.make_const(
                                info.bound_data._dev_tensor()
                                # info.bound_data.numpy()
                            )
                    else:
                        opnode = info.data_setter = G.InputNode(
                            *in_out_links,
                            device=info.device,
                            dtype=info.dtype,
                            shape=info.shape or (1,),
                            graph=graph,
                            use_static_shape=_input_node_use_static_shape(),
                        )
                        need_reset_nodes.append(opnode)
                        info.varnode, *in_out_links = opnode.outputs
                if require_links and i == 0 and len(io_links) > 0:
                    opnode = G.VirtualDepNode(
                        [info.varnode, *io_links], str(io_links[0].device)
                    )
                    info.varnode = opnode.outputs[0]
                    io_links = (info.varnode,)

                ivars.append(info.varnode)

            ovars = G.apply_normal_varnode(op, *ivars)

            if require_links and len(ovars) > 0:
                io_links = (ovars[0],)
            assert len(ovars) == len(ohandles)
            for h, v in zip(ohandles, ovars):
                info = self._tinfo[h]
                info.varnode = v

                def add_reader(opnode):
                    nonlocal in_out_links
                    need_reset_nodes.append(opnode)
                    readers.append(opnode.outputs[0])
                    in_out_links = opnode.outputs

                if info.data_read:
                    # Shape can be obtained from data so doesn't need its own
                    # output node. On the other hand, value is read separately
                    # to leverage eager d2h copy
                    info.shape_read = False
                    opnode = info.data_reader = G.OutputNode(v, *in_out_links)
                    add_reader(opnode)
                if info.value_read:
                    opnode = info.value_reader = G.ValueOutputNode(v, *in_out_links)
                    add_reader(opnode)
                if info.shape_read:
                    opnode = info.shape_reader = G.AttrOutputNode(v, *in_out_links)
                    add_reader(opnode)

        graph.options.graph_opt_level = self._graph_opt_level
        graph._set_priority_to_id([*readers, *in_out_links, *io_links])
        graph.compile(*readers, *in_out_links, *io_links)

    def _reset_exec_env(self):
        for opnode in self._need_reset_nodes:
            opnode.reset()

    def __call__(self, *args, **kwargs):
        with self._setup():
            if self._capture_as_const:
                self._process_inputs(*args, **kwargs)
            outputs = self.__wrapped__(*args, **kwargs)
            if self._capture_as_const:
                self._process_outputs(outputs)
            return outputs

    def dump(
        self,
        file,
        *,
        arg_names=None,
        output_names=None,
        append=False,
        keep_var_name: int = 1,
        keep_opr_name: bool = False,
        keep_param_name: bool = False,
        keep_opr_priority: bool = False,
        strip_info_file=None,
        append_json=False,
        optimize_for_inference=True,
        user_info: Any = None,
        enable_metadata: bool = True,
        **kwargs
    ):
        r"""Serializes trace to file system.

        Args:
            file: output file, could be file object or filename.
            arg_names: names of the input tensors in the traced function.
            output_names: names of the output tensors in the traced function,
                use the default name if not specified.
            append: whether output is appended to ``file``.
                Only works when ``file`` is str.
            keep_var_name: level for keeping variable names:

                * 0: none of the names are kept
                * 1: (default) keep names of output vars
                * 2: keep names of all (output and internal) vars

            keep_opr_name: whether to keep operator names.
            keep_param_name: whether to keep param names, so param values can be
                easily manipulated after loading model
            keep_opr_priority: whether to keep priority setting for operators
            strip_info_file: a path string or a file handler. If not None,
                the dump information for code strip will be written to ``strip_info_file``.
            append_json: only checked when ``strip_info_file`` is not None. If set
                to True, the information for code strip will be appended to
                ``strip_info_file``; if set to False, ``strip_info_file`` will be overwritten.
            optimize_for_inference: enable optimizations;
                all optimize options are skipped if this is False. Default: True
            user_info: any type object, which will be pickled to bytes.
            enable_metadata: whether to save metadata into output file.

        Keyword Arguments:

        * enable_io16xc32 --
          whether to use float16 for I/O between oprs and use
          float32 as internal computation precision. Note the output var would be
          changed to float16.
        * enable_ioc16 --
          whether to use float16 for both I/O and computation
          precision.
        * enable_hwcd4 --
          whether to use NHWCD4 data layout. This is faster on some
          OpenCL backends.
        * enable_nchw88 --
          whether to use NCHW88 data layout, currently
          used in X86 AVX backend.
        * enable_nchw44 --
          whether to use NCHW44 data layout, currently
          used in arm backend.
        * enable_nchw44_dot --
          whether to use NCHW44_dot data layout, currently
          used in armv8.2+dotprod backend.
        * enable_nchw4 --
          whether to use NCHW4 data layout, currently
          used in nvidia backend (based on cudnn).
        * enable_nchw32 --
          whether to use NCHW32 data layout, currently
          used in nvidia backend with tensorcore (based on cudnn).
        * enable_chwn4 --
          whether to use CHWN4 data layout, currently
          used in nvidia backend with tensorcore.
        * enable_nchw64 --
          whether to use NCHW64 data layout, used for fast int4
          support on Nvidia GPU.
        * enable_fuse_conv_bias_nonlinearity --
          whether to fuse conv+bias+nonlinearity into one opr.
        * enable_fuse_conv_bias_with_z --
          whether to fuse conv_bias with z input for inference on nvidia
          backend (this optimization pass will result in a mismatch between
          the output precision of training and inference)
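
        Example: a minimal sketch of a typical dump call (``net``, ``sample``
        and the output path ``"model.mge"`` are illustrative):

        .. code-block:: python

            @trace(symbolic=True, capture_as_const=True)
            def pred(data):
                return net(data)

            pred(sample)  # run once so the graph is traced and outputs are bound
            pred.dump("model.mge", arg_names=["data"])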
        """
        if not self._capture_as_const:
            raise ValueError(
                "you must specify capture_as_const=True at __init__ to use dump"
            )
        if self._untraced:
            raise RuntimeError("should run at least once before calling dump")
        if self._output_names and output_names:
            raise TypeError(
                "cannot specify output_names when output is already in dict format"
            )
        if output_names and not isinstance(output_names, collections.abc.Sequence):
            output_names = (output_names,)
        if output_names and len(output_names) != len(self._output_bindings):
            raise ValueError(
                "wrong number of output_names, should be {} values".format(
                    len(self._output_bindings)
                )
            )
        without_arg_names = arg_names is None
        if without_arg_names:
            arg_names = ["arg_%d" % i for i in range(len(self._arg_bindings))]
        if arg_names and not isinstance(arg_names, collections.abc.Sequence):
            arg_names = (arg_names,)
        if arg_names and len(arg_names) != len(self._arg_bindings):
            raise ValueError(
                "wrong number of arg_names, should be {} values".format(
                    len(self._arg_bindings)
                )
            )
        output_names = output_names or self._output_names

        def dumped_device(info):
            device_name = info.device.logical_name
            if device_name[:3] in ("cpu", "gpu", "xpu"):
                return as_device("xpux")
            return info.device

        h2v = {}
        graph = G.Graph()

        # apply graph_opt_level in dump
        if self._graph_opt_level is not None:
            graph.options.graph_opt_level = self._graph_opt_level
        for i, h in enumerate(self._arg_bindings):
            info = self._tinfo[h]
            h2v[h] = graph.make_h2d(
                dtype=info.dtype,
                device=dumped_device(info),
                shape=info.shape or (1,),
                name=info.name if without_arg_names and info.name else arg_names[i],
            )
        for k, h in self._kwarg_bindings.items():
            info = self._tinfo[h]
            h2v[h] = graph.make_h2d(
                dtype=info.dtype,
                device=dumped_device(info),
                shape=info.shape or (1,),
                name=k,
            )

        for op, ihandles, ohandles in self._seq:
            if isinstance(op, str) and op == "Const":
                assert len(ihandles) == 0
                (h,) = ohandles
                info = self._tinfo[h]
                if h not in h2v:
                    assert info.external
                    assert info.bound_data
                    h2v[h] = graph.make_const(
                        info.bound_data.numpy(),
                        dtype=info.dtype,
                        device=dumped_device(info),
                        name=info.name,
                    )
                continue
            ivars = []
            for h in ihandles:
                info = self._tinfo[h]
                if h not in h2v:
                    assert info.external
                    assert info.bound_data
                    h2v[h] = graph.make_const(
                        info.bound_data.numpy(),
                        dtype=info.dtype,
                        device=dumped_device(info),
                        name=info.name,
                    )
                ivars.append(h2v[h])
            if isinstance(op, BatchNorm):
                assert (
                    op.fwd_mode == BatchNorm.FwdMode.INFERENCE
                ), "can not dump BatchNorm in training mode, maybe you forget to do model.eval()?"
            ovars = G.apply_normal_varnode(op, *ivars)

            AutoNaming.record_opnode(ovars[0].op)

            assert len(ovars) == len(ohandles)
            h2v.update(zip(ohandles, ovars))

            for i in ohandles:
                name = AutoNaming.get_var_name(i)
                if name is not None:
                    h2v[i].name = name

        AutoNaming.remove_duplicate_names()

        dest_vars = []
        for i, h in enumerate(self._output_bindings):
            v = h2v[h]
            if output_names:
                v.name = output_names[i]
            dest_vars.append(v)

        if optimize_for_inference:
            dest_vars, optimize_options = G.optimize_for_inference(dest_vars, **kwargs)

        metadata = SerializationMetadata()
        if enable_metadata:
            metadata.user_info = pickle.dumps(user_info)
            metadata.is_valid = True
            metadata.graph_modified = False
            if optimize_for_inference:
                metadata.optimize_options = optimize_options

        if isinstance(file, str):
            permission = "ab" if append else "wb"
            file = open(file, permission)

        if keep_opr_priority:
            graph._set_priority_to_id(dest_vars)

        dump_content, dump_info = G.dump_graph(
            dest_vars,
            keep_var_name=keep_var_name,
            keep_opr_name=keep_opr_name,
            keep_param_name=keep_param_name,
            keep_opr_priority=keep_opr_priority,
            strip_info_file=strip_info_file,
            append_json=append_json,
            metadata=metadata,
        )
        file.write(dump_content)
        return dump_info

    def _process_inputs(self, *args, **kwargs):
        if self._untraced:
            self._inputs_to_restore = []

            def record_input(x):
                if x is None:
                    return
                h, info = self._new_handle()
                info.external = False
                info.name = x.c_name
                info.device = x.device
                info.dtype = x.dtype
                info.shape = x.numpy().shape
                x._mixin_handle = h
                x._recording = True
                x._trace_mixin_info = info
                self._inputs_to_restore.append(x)
                return h

            self._arg_bindings = []
            for i, x in enumerate(args):
                if not isinstance(x, RawTensor):
                    raise TypeError(
                        "positional arguments should all be tensor "
                        "but args[%d] cannot be recognized as one" % i
                    )
                self._arg_bindings.append(record_input(x))

            self._kwarg_bindings = {}
            for k, x in kwargs.items():
                if isinstance(x, RawTensor):
                    self._kwarg_bindings[k] = record_input(x)
        else:
            if len(args) != len(self._arg_bindings):
                raise TraceMismatchError("positional argument length mismatch")

            self._tensor_remaps = {}

            for i, (h, x) in enumerate(zip(self._arg_bindings, args)):
                if not isinstance(x, RawTensor):
                    raise TypeError(
                        "positional arguments should all be tensor "
                        "but args[%d] cannot be recognized as one" % i
                    )
                info = self._tinfo[h]
                if x.dtype != info.dtype:
                    raise TypeError("args[%d].dtype different from last time" % i)
                if x.device != info.device:
                    raise TypeError("args[%d].device different from last time" % i)
                info.data_setter.set_value(x._dev_tensor())
                self._tensor_remaps[x._handle] = CompiledTensorProxy(h)

            kwargs_tensors = {}
            for k, x in kwargs.items():
                if isinstance(x, RawTensor):
                    kwargs_tensors[k] = x
            if set(kwargs_tensors) != set(self._kwarg_bindings):
                too_many = set(kwargs_tensors) - set(self._kwarg_bindings)
                too_few = set(self._kwarg_bindings) - set(kwargs_tensors)
                if too_many:
                    raise TraceMismatchError(
                        "keyword arguments found to be tensor this time "
                        "but were non-tensor previously: %s" % " ".join(too_many)
                    )
                if too_few:
                    raise TraceMismatchError(
                        "keyword arguments found to be non-tensor this time "
                        "but were tensor previously: %s" % " ".join(too_few)
                    )
            for k, h in self._kwarg_bindings.items():
                x = kwargs_tensors[k]
                info = self._tinfo[h]
                if x.dtype != info.dtype:
                    raise TypeError("kwargs[%s].dtype different from last time" % k)
                if x.device != info.device:
                    raise TypeError("kwargs[%s].device different from last time" % k)
                info.data_setter.set_value(x._dev_tensor())
                self._tensor_remaps[x._handle] = CompiledTensorProxy(h)

    def _process_outputs(self, outputs):
        output_names = None
        if isinstance(outputs, collections.abc.Mapping):
            output_names, outputs = zip(*sorted(outputs.items()))
        elif not isinstance(outputs, collections.abc.Sequence):
            outputs = (outputs,)

        if not self._untraced:
            if output_names != self._output_names:
                too_many = set(output_names) - set(self._output_names)
                too_few = set(self._output_names) - set(output_names)
                if too_many:
                    raise TraceMismatchError(
                        "output has more keys than last time: %s" % " ".join(too_many)
                    )
                if too_few:
                    raise TraceMismatchError(
                        "output has less keys than last time: %s" % " ".join(too_few)
                    )
            if len(outputs) != len(self._output_bindings):
                raise TraceMismatchError("output size differs from last time")
        else:
            self._output_names = output_names
            self._output_bindings = []

        for i, x in enumerate(outputs):
            if not isinstance(x, RawTensor):
                raise TypeError("every item of return value should be tensor")
            if self._untraced:
                h = x._mixin_handle
                if h < 0:
                    raise RuntimeError("output is not computed from inputs")
                self._output_bindings.append(h)
            else:
                h = x._mixin_handle
                if h not in self._output_handles:
                    raise RuntimeError("output is not computed from inputs")
                if h != self._output_bindings[i]:
                    raise TraceMismatchError(
                        "retval[%s] is a different tensor than last time"
                        % (output_names and output_names[i] or i)
                    )

    def get_profile(self):
        r"""Get profiling result for compiled trace.

        Returns:
            a json compatible object.
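
        A sketch of the intended workflow (``f`` and ``x`` are illustrative):

        .. code-block:: python

            traced = trace(f, profiling=True)
            traced(x)  # first call records the op sequence
            traced(x)  # second call compiles and executes the graph
            with open("profile.json", "w") as fout:
                json.dump(traced.get_profile(), fout)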
        """
        if not self._profiler:
            raise RuntimeError("trace is not set with profiling=True")
        return json.loads(self._profiler.get())


class CompiledTensorProxy:
    r"""Duck-typed RawTensor"""

    def __init__(self, handle):
        self.__handle = handle
        self._isscalar = False
        self.__info = active_trace._tinfo[handle]
        self.__shape = None
        self.__data = None
        self.__value = None

    @property
    def dtype(self):
        return self.__info.varnode.dtype

    @property
    def device(self):
        return self.__info.varnode.device

    @property
    def shape(self):
        if self._isscalar:
            return ()
        if self.__shape is None:
            if self.__info.shape_read:
                self.__shape = self.__info.shape_reader.get_value().shape
            elif self.__info.data_read:
                self.__shape = self._dev_tensor().shape
            else:
                # c++ will throw TraceReadError
                return None
        return self.__shape

    def numpy(self):
        if self.__value is None:
            if self.__info.value_read:
                self.__value = self.__info.value_reader.get_value()
            elif self.__info.data_read:
                self.__value = self._dev_tensor().numpy()
            else:
                # c++ will throw TraceReadError
                return None
        # c++ side will handle scalar case
        return self.__value

    def _dev_tensor(self):
        if self.__data is None:
            if not self.__info.data_read:
                # c++ will throw TraceReadError
                return None
            self.__data = self.__info.data_reader.get_value()
        return self.__data

    def __del__(self):
        if self.__info.shape_read and self.__shape is not None:
            self.__info.shape_reader.drop_value()
        if self.__info.value_read and self.__value is not None:
            self.__info.value_reader.drop_value()
        if self.__info.data_read and self.__data is not None:
            self.__info.data_reader.drop_value()


def assign_raw_tensor(lhs, rhs):
    lhs.__init__(rhs)


def apply_symbolic_mode(op: OpDef, *args: RawTensor):
    graph = active_trace._lazy_eval_graph
    ivars = []
    for x in args:
        var = getattr(x, "_varnode", None)
        if var:
            ivars.append(var)
        else:
            data_setter = G.InputNode(
                device=x.device,
                dtype=x.dtype,
                shape=x.numpy().shape or (1,),
                graph=graph,
                use_static_shape=True,
            )
            var = data_setter.outputs[0]
            ivars.append(var)
            data_setter.set_value(x._dev_tensor())

    require_links = type(op) in _io_op_types

    if require_links and active_trace._lazy_eval_links:
        assert len(ivars) > 0, "op should have at least one input"
        opnode = G.VirtualDepNode(
            [ivars[0], *active_trace._lazy_eval_links],
            str(active_trace._lazy_eval_links[0].device),
        )
        ivars[0] = opnode.outputs[0]
        active_trace._lazy_eval_links = (ivars[0],)

    ovars = G.apply_normal_varnode(op, *ivars)
    outputs = [RawTensor(o) for o in ovars]

    if require_links:
        active_trace._lazy_eval_links = (G.VarNode(outputs[0]._varnode),)

    return outputs


def apply_const_symbolic_mode(value, dtype, device, name):
    graph = active_trace._lazy_eval_graph
    # don't need to unset tracing
    # because varnode construction will ignore tracing flag
    ret = RawTensor(graph.make_const(value, dtype=dtype, device=device, name=name))
    if np.array(value).ndim == 0:
        setscalar(ret)
    return (ret,)


def apply_compiled_mode(op: OpDef, *args: RawTensor):
    if skip_tracing:
        args = [
            RawTensor(x._dev_tensor()) if x.__class__ is CompiledTensorProxy else x
            for x in args
        ]
        unset_tracing()
        ret = apply(op, *args)
        set_tracing()
        return ret
    return active_trace._apply_op(op, args)


def apply_const_compiled_mode(value, dtype, device, is_const, no_cache, name):
    if skip_tracing:
        unset_tracing()
        ret = RawTensor(value, dtype, device, False, name)
        set_tracing()
        return ret
    return active_trace._apply_const(value, dtype, device)


def apply_with_tracing(op: OpDef, *args: RawTensor):
    if active_trace._graph:
        # if the member _graph exists, the trace is compiled
        return apply_compiled_mode(op, *args)
    if hasattr(op, "scope"):
        op.scope = AutoNaming.get_scope()
    if active_trace._symbolic:
        outputs = apply_symbolic_mode(op, *args)
    else:
        unset_tracing()
        outputs = apply(op, *args)
        set_tracing()

    active_trace._record_op(op, args, outputs)
    return list(outputs)


def apply_const_with_tracing(value, dtype, device, is_const, no_cache, name):
    if active_trace._graph:
        return apply_const_compiled_mode(value, dtype, device, is_const, no_cache, name)
    if active_trace._symbolic:
        outputs = apply_const_symbolic_mode(value, dtype, device, name)
    else:
        unset_tracing()
        outputs = RawTensor(value, dtype, device, False, name)
        if np.array(value).ndim == 0:
            setscalar(outputs)
        outputs = (outputs,)
        set_tracing()
    active_trace._record_const(outputs)
    return list(outputs)