test_tracing.py 19.1 KB
Newer Older
1
# -*- coding: utf-8 -*-
2
import inspect
M
Megvii Engine Team 已提交
3
import io
4
import itertools
5
import random
6
from tempfile import mkstemp
M
Megvii Engine Team 已提交
7

M
Megvii Engine Team 已提交
8
import numpy as np
9
import pytest
M
Megvii Engine Team 已提交
10

11
import megengine.core.tensor.megbrain_graph as G
12
import megengine.functional as F
13
import megengine.optimizer as optim
14
import megengine.utils.comp_graph_tools as cgtools
15 16
from megengine import Parameter, tensor
from megengine.autodiff import GradManager
M
Megvii Engine Team 已提交
17
from megengine.core.ops import builtin as ops
18
from megengine.core.ops.builtin import Elemwise
19
from megengine.core.tensor.utils import isscalar
20
from megengine.functional import exp, log
21
from megengine.jit import GraphOptimizationConfig, TraceError, exclude_from_trace, trace
22
from megengine.module import Module
23
from megengine.random import normal, uniform
24
from megengine.utils.naming import AutoNaming
M
Megvii Engine Team 已提交
25 26


27 28 29 30 31 32 33 34 35 36 37 38
@pytest.mark.parametrize("trace_mode", [False, True])
@pytest.mark.parametrize("return_mode", ["Value", "Tuple", "List", "Dict"])
def test_trace(trace_mode, return_mode):
    """Check that a traced negation returns the same value on repeated calls.

    The traced function wraps its result in a bare value, tuple, list or dict
    depending on ``return_mode``.
    """

    @trace(symbolic=trace_mode)
    def f(x):
        negated = -x
        if return_mode == "Tuple":
            return (negated,)
        if return_mode == "List":
            return [negated]
        if return_mode == "Dict":
            return {"neg": negated}
        return negated

    def get_numpy(y):
        # Unwrap whichever container ``f`` produced before converting.
        if return_mode in ("Tuple", "List"):
            y = y[0]
        elif return_mode == "Dict":
            y = y["neg"]
        return y.numpy()

    x = tensor([1])
    expected = get_numpy(f(x))

    for _ in range(3):
        np.testing.assert_equal(get_numpy(f(x)), expected)
M
Megvii Engine Team 已提交
53 54


55 56 57 58 59 60 61 62 63 64 65 66
def test_output_copy_trace():
    """Check that traced training steps give equal losses with symbolic on and off."""

    class Simple(Module):
        def __init__(self):
            super().__init__()
            self.a = Parameter([1.0], dtype=np.float32)

        def forward(self, x):
            scaled = x * self.a
            # will result into a copy of output in grad
            return F.exp(scaled)

    def collect_losses(symbolic, steps=3):
        # Build a fresh network/optimizer pair and record the loss of each step.
        net = Simple()
        gm = GradManager().attach(net.parameters())
        opt = optim.SGD(net.parameters(), 1e-3, momentum=0.9)
        data = tensor(np.arange(4).reshape(2, 2), dtype="float32")

        @trace(symbolic=symbolic)
        def train_func(d):
            with gm:
                loss = net(d)
                gm.backward(loss)
                opt.step().clear_grad()
            return loss

        return [train_func(data).numpy() for _ in range(steps)]

    ys = {symbolic: collect_losses(symbolic) for symbolic in [False, True]}

    for plain_loss, symbolic_loss in zip(ys[False], ys[True]):
        np.testing.assert_equal(plain_loss, symbolic_loss)
89

M
Megvii Engine Team 已提交
90

91 92 93 94 95 96 97 98 99 100 101 102 103
@pytest.mark.parametrize("trace_mode", [False, True])
def test_tensor_detach(trace_mode):
    """Run a traced chain of ``detach()`` calls three times.

    ``trace_mode`` is forwarded to ``trace(symbolic=...)`` so that the two
    parametrized cases exercise different tracing configurations.
    """

    @trace(symbolic=trace_mode)
    def f(x):
        y = x.detach() ** 2
        z = y.detach() + 1
        return z.detach()

    x = tensor([1, 2, 3, 4])
    for _ in range(3):
        # .numpy() forces the traced result to be materialised.
        f(x).numpy()


104 105 106 107 108 109 110 111 112 113
@pytest.mark.parametrize("trace_mode", [False, True])
def test_exclude_from_trace(trace_mode):
    """Check that a branch inside ``exclude_from_trace`` may vary between calls."""

    @trace(symbolic=trace_mode)
    def f(x):
        x = -x
        with exclude_from_trace():
            # ``step`` is the loop counter of the enclosing test, looked up
            # when f is called.
            if step % 2:
                x = -x
        x = -x
        return x

    x = tensor([1])

    for step in range(3):
        first = f(x).numpy()
        np.testing.assert_equal(f(x).numpy(), first)
M
Megvii Engine Team 已提交
120 121


122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163
@pytest.mark.parametrize("trace_mode", [False, True])
def test_elemwise_fuse(trace_mode):
    """Run a traced subtract + topk + add-zero function three times at opt_level 2."""

    # explicitly declare opt_level as 2
    @trace(symbolic=trace_mode, opt_level=2)
    def f(a, b):
        base = 0
        diff = b - a
        idx = F.topk(diff, 3)[1]
        # internally, the returned value will be idx as gopt will ignore the addition
        return base + idx

    a = tensor(np.ones((7, 2)), dtype=np.int32)
    b = tensor(2 * np.ones((7, 2)), dtype=np.float32)

    for _ in range(3):
        out = f(a, b)
        out.numpy()


@pytest.mark.parametrize("trace_mode", [False, True])
def test_elemwise_fuse_in_grad(trace_mode):
    """Run a traced forward/backward/SGD step three times at opt_level 2."""
    w = Parameter(np.ones([4, 6]), dtype="float32")

    gm = GradManager().attach(w)
    opt = optim.SGD([w], lr=0.01, momentum=0.9, weight_decay=5e-4)

    # explicitly declare opt_level as 2
    @trace(symbolic=trace_mode, opt_level=2)
    def f():
        with gm:
            row_norm = F.sum(w ** 2, axis=1) ** 0.5
            loss = row_norm.mean()
            gm.backward(loss)
            opt.step().clear_grad()
        return loss

    for _ in range(3):
        loss_value = f()
        loss_value.numpy()


164 165 166 167 168 169 170 171 172 173 174 175
def test_repeat_in_trace():
    """Call a non-symbolic traced ``F.repeat`` with repeat counts 1 through 4."""

    @trace(symbolic=False)
    def fun(data, repeats):
        F.repeat(data, repeats)

    data = tensor(np.random.random([1, 2, 3]).astype(np.float32))

    for count in range(1, 5):
        fun(data, tensor(count))


M
Megvii Engine Team 已提交
176 177 178 179 180 181
def test_print_in_trace():
    """Check that a value read via ``.numpy()`` inside a traced function is stable across calls."""
    for symbolic in [False]:  # cannot read value in symbolic mode
        buf = None

        @trace(symbolic=symbolic)
        def f(x):
            nonlocal buf
            x = -x
            # Stash the intermediate value so the test can inspect it.
            buf = x.numpy()
            x = -x
            return x

        x = tensor([1])

        for _ in range(3):
            first_out = f(x).numpy()
            first_buf = buf
            buf = None
            np.testing.assert_equal(f(x).numpy(), first_out)
            np.testing.assert_equal(first_buf, buf)
M
Megvii Engine Team 已提交
196 197


198 199 200 201 202 203 204
@pytest.mark.parametrize("dump_format", ["FBS"])
def test_dump(dump_format):
    """Dump a traced addition, check the dump metadata, and re-run the dumped graph."""

    @trace(symbolic=True, capture_as_const=True)
    def f(a, b):
        return a + b

    # prevent from remaining scope from exception test
    AutoNaming.clear()
    a = tensor([2])
    b = tensor([4])
    expected = f(a, b).numpy()

    for _ in range(3):
        np.testing.assert_equal(f(a, b).numpy(), expected)

    buffer = io.BytesIO()
    dump_info = f.dump(buffer, dump_format=dump_format)
    assert dump_info.nr_opr == 3
    np.testing.assert_equal(dump_info.inputs, ["arg_0", "arg_1"])
    np.testing.assert_equal(dump_info.outputs, ["ADD"])

    # Rewind so the inference helper reads the dump from the start.
    buffer.seek(0)
    infer_cg = cgtools.GraphInference(buffer)
    outputs = list(infer_cg.run(a, b).values())
    result = outputs[0]
    np.testing.assert_equal(result[0], expected)


def test_capture_dump():
    """Dump a traced function that multiplies by a tensor from the enclosing scope, then re-run it."""
    a = tensor([2])

    @trace(symbolic=True, capture_as_const=True)
    def f(x):
        return x * a

    x = tensor([3])
    expected = f(x).numpy()

    for _ in range(3):
        np.testing.assert_equal(f(x).numpy(), expected)

    buffer = io.BytesIO()
    f.dump(buffer)
    buffer.seek(0)

    infer_cg = cgtools.GraphInference(buffer)
    outputs = list(infer_cg.run(x).values())
    result = outputs[0]
    np.testing.assert_equal(result[0], expected)


def test_dump_volatile():
    """Dump without inference optimization and check an input operator type.

    The second input of the dumped output's owner operator must be an
    ``ImmutableTensor``.
    """
    p = tensor([2])

    @trace(symbolic=True, capture_as_const=True)
    def f(x):
        return x * p

    x = tensor([3])
    expected = f(x).numpy()

    for _ in range(3):
        np.testing.assert_equal(f(x).numpy(), expected)

    buffer = io.BytesIO()
    f.dump(buffer, optimize_for_inference=False)
    buffer.seek(0)

    (out,) = G.load_graph(buffer).output_vars_list
    second_input = cgtools.get_owner_opr_inputs(out)[1]
    assert cgtools.get_owner_opr_type(second_input) == "ImmutableTensor"
271 272


273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300
def test_dump_backward_graph():
    """Dump a traced function that returns a product and a gradient, then re-run the dump."""
    x0 = tensor(np.random.randn(3, 4))
    x1 = tensor(np.random.randn(3, 4))

    gm = GradManager().attach(x0)

    @trace(symbolic=True, capture_as_const=True)
    def f(x0, x1):
        with gm:
            y = x0 * x1
            gm.backward(y, F.ones_like(y))
            dx0 = x0.grad
        return y, dx0

    y, dx0 = f(x0, x1)
    # d(x0 * x1)/dx0 with an all-ones upstream gradient is x1.
    np.testing.assert_equal(dx0.numpy(), x1)

    buffer = io.BytesIO()
    f.dump(buffer, optimize_for_inference=False)
    buffer.seek(0)

    infer_cg = cgtools.GraphInference(buffer)
    replayed = list(infer_cg.run(x0, x1).values())

    np.testing.assert_equal(replayed[0], y)
    np.testing.assert_equal(replayed[1], dx0)


301 302 303 304 305 306 307 308 309 310
def test_dump_with_testcase():
    """Dump a traced ``exp`` while passing an ``input_data`` specification string."""

    @trace(symbolic=True, capture_as_const=True)
    def f(x):
        return exp(x)

    f(tensor(1.0))
    buffer = io.BytesIO()
    f.dump(buffer, input_data=["#rand(0, 255, 1)"])


311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332
def test_split_dump():
    """Trace a module whose forward pass is ``F.split`` and dump the traced function."""

    class SimpleNet(Module):
        def __init__(self, num_segments: int = 3):
            super().__init__()
            self.num_segments = num_segments

        def forward(self, x):
            return F.split(x, self.num_segments, axis=1)

    model = SimpleNet()
    model.eval()
    data = tensor(np.random.random((1, 12, 224, 224)))

    @trace(symbolic=True, capture_as_const=True)
    def fun(data, *, net):
        return net(data)

    # One call is made before dumping.
    segments = fun(data, net=model)
    fun.dump(io.BytesIO(), arg_names=["data"])


333 334 335 336 337
@pytest.mark.parametrize("trace_mode", [False, True])
def test_trace_profiler(trace_mode):
    """Check that a trace created with ``profiling=True`` reports a ``profiler`` entry."""

    @trace(symbolic=trace_mode, profiling=True)
    def f(x):
        return -x

    x = tensor([1])
    f(x).numpy()

    for _ in range(2):  # XXX: has to run twice
        f(x)

    profile = f.get_profile()
    assert profile.get("profiler")
347 348


349
def test_goptions():
    """Compare ``y / y`` at input 0.0 under opt_level 0 and opt_level 1.

    At opt_level 0 the result is expected to be non-finite; at opt_level 1 it
    is expected to equal 1.0.
    """

    def self_ratio(x):
        # directly return x / x will not trigger gopt
        # since there's no way to tell the two x are the same
        doubled = 2.0 * x
        return doubled / doubled

    @trace(symbolic=True, opt_level=0, capture_as_const=True)
    def f(x):
        return self_ratio(x)

    @trace(symbolic=True, opt_level=1, capture_as_const=True)
    def g(x):
        return self_ratio(x)

    zero = tensor(0.0)
    assert not np.isfinite(f(zero).numpy())
    np.testing.assert_equal(g(zero).numpy().item(), 1.0)
365 366 367 368 369 370 371 372 373 374 375


def test_goptions_log_sum_exp():
    """Compare ``log(exp(x) + exp(y))`` for a large x under opt_level 0 and 1.

    At opt_level 0 the result is expected to be non-finite; at opt_level 1 it
    is expected to be almost equal to x.
    """

    def log_sum_exp(x, y):
        return log(exp(x) + exp(y))

    @trace(symbolic=True, opt_level=0, capture_as_const=True)
    def f(x, y):
        return log_sum_exp(x, y)

    @trace(symbolic=True, opt_level=1, capture_as_const=True)
    def g(x, y):
        return log_sum_exp(x, y)

    val = 1.0e4
    large = tensor(val)
    zero = tensor(0.0)
    assert not np.isfinite(f(large, zero).numpy())
    np.testing.assert_almost_equal(g(large, zero), val)


def test_goptions_log_exp():
    """Check that opt_level 1 dumps two fewer operators than opt_level 0 for ``log(exp(x))``.

    Both traced functions are dumped to the same temporary file path. The
    descriptor returned by ``mkstemp`` is closed immediately and the file is
    removed afterwards so the test leaks neither.
    """
    import os

    @trace(symbolic=True, opt_level=0, capture_as_const=True)
    def f(x):
        return log(exp(x))

    @trace(symbolic=True, opt_level=1, capture_as_const=True)
    def g(x):
        return log(exp(x))

    f(tensor(1.0))
    # Only the path is handed to dump()/load_graph(); the open descriptor is
    # not needed.
    fd, out = mkstemp()
    os.close(fd)
    try:
        f.dump(out, optimize_for_inference=False)
        outputs = G.load_graph(out).output_vars_list
        oprs_1 = cgtools.get_oprs_seq(outputs)

        g(tensor(1.0))
        g.dump(out, optimize_for_inference=False)
        outputs = G.load_graph(out).output_vars_list
        oprs_2 = cgtools.get_oprs_seq(outputs)
    finally:
        os.remove(out)

    assert len(oprs_1) - len(oprs_2) == 2


def test_optimize_for_inference():
    """Check that dumping with ``enable_io16xc32=True`` yields a float16 computing input.

    The descriptor returned by ``mkstemp`` is closed immediately and the file
    is removed afterwards so the test leaks neither.
    """
    import os

    @trace(symbolic=True, capture_as_const=True)
    def f(x):
        return exp(x)

    # Only the path is handed to dump()/load_graph(); the open descriptor is
    # not needed.
    fd, out = mkstemp()
    os.close(fd)
    try:
        f(tensor(5.0))
        f.dump(out, enable_io16xc32=True)

        res = G.load_graph(out)
        computing_input = res.output_vars_list[0].owner.inputs[0]
        assert computing_input.dtype == np.float16
    finally:
        os.remove(out)
418 419


420 421 422
def test_optimize_for_inference_broadcast():
    """Trace a ``_broadcast`` call with ``symbolic_shape=True`` and dump it."""
    a = tensor(np.ones(1, dtype=np.float32))

    @trace(capture_as_const=True, symbolic_shape=True)
    def f():
        target_shape = tensor([1, 10], dtype=np.int32)
        return a._broadcast(target_shape)

    f()
    f.dump(io.BytesIO())


431 432 433 434 435
def test_trace_cvt_bool():
    """Check that comparing a traced shape element with 0 yields False for a length-1 tensor."""
    x = tensor([0], dtype=np.int32)

    @trace(symbolic=True)
    def f(x):
        shape = x.shape
        first_dim = shape[0]
        assert isscalar(first_dim)
        return first_dim == 0

    for _ in range(3):
        np.testing.assert_equal(f(x).numpy(), False)
443 444


445 446 447 448 449
@pytest.mark.parametrize("trace_mode", [False, True])
def test_trace_reshape(trace_mode):
    """Call a traced reshape with inputs whose leading dimension is 2, 4 and 8."""
    inputs = [tensor(np.random.randn(batch, 10, 10)) for batch in (2, 4, 8)]

    @trace(symbolic=trace_mode, capture_as_const=True)
    def f(x):
        return x.reshape(x.shape[0], 100)

    for sample in inputs:
        f(sample)
459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485


def test_trace_topk():
    """Check inside a symbolic trace that the first ``F.topk(x, 3)`` output has shape (3,)."""
    x = tensor([5, 2, 7, 1, 0, 3, 2])

    @trace(symbolic=True)
    def f(x):
        topk_out = F.topk(x, 3)
        np.testing.assert_equal(topk_out[0].shape.numpy(), np.array([3]))
        return topk_out

    for _ in range(3):
        f(x)


def test_trace_warp_perspective():
    """Check inside a symbolic trace that ``warp_perspective`` output has shape (1, 1, 2, 2)."""
    inp_shape = (1, 1, 4, 4)
    x = tensor(np.arange(16, dtype=np.float32).reshape(inp_shape))

    M_shape = (1, 3, 3)
    matrix_rows = [[1.0, 0.0, 1.0], [0.0, 1.0, 1.0], [0.0, 0.0, 1.0]]
    M = tensor(np.array(matrix_rows, dtype=np.float32).reshape(M_shape))

    @trace(symbolic=True)
    def f(x, M):
        out = F.vision.warp_perspective(x, M, (2, 2))
        np.testing.assert_equal(out.shape.numpy(), np.array([1, 1, 2, 2]))
        return out

    for _ in range(3):
        f(x, M)
492 493


494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533
@pytest.mark.parametrize(
    "normal_expr, mismatch_expr, reason",
    [
        ("a + b + c", "a + b - c", "operator mismatch"),
        ("a + b + 1", "a + b + 2", "tensors not equals"),
        ("((a + b), (b + c))[0]", "a + b", "mismature end"),
        ("a + b + c", "c + (a + b)", "expect internal node, got external"),
        ("c + (a + b)", "a + b + c", "expect external node, got internal"),
        ("a + b + c", "a + b + c + c", "too many instructions"),
        ("((a + b), (b + c))[1]", "((a + b), (b + c))[0]", "data unreadable"),
        ("((a + b), (b + c))[1] + a", "((a + b), (b + c))[0] + a", "input id mismatch"),
    ],
)
def test_trace_mismatch(normal_expr, mismatch_expr, reason):
    """Check that deviating from the recorded trace raises ``TraceError`` with the given reason.

    The traced function evaluates ``normal_expr`` by default and
    ``mismatch_expr`` when the ``mismatch`` flag is set. The flag is
    re-drawn at random after every call.
    """
    a = tensor([1, 2, 3, 4])
    b = tensor([5, 6, 7, 8])
    c = tensor([9, 0, 1, 2])

    mismatch = False

    @trace(symbolic=True)
    def fn(a, b, c):
        # The expressions are fixed test parameters; eval resolves a, b, c
        # from this function's arguments.
        if not mismatch:
            result = eval(normal_expr)
        else:
            result = eval(mismatch_expr)
        return result

    for i in range(20):
        try:
            d = fn(a, b, c)
        except TraceError as e:
            assert mismatch
            assert str(e) == "trace error because {}".format(reason)
        except Exception:
            # Catch only ordinary exceptions so KeyboardInterrupt/SystemExit
            # are not reported as a trace failure.
            pytest.fail("unexpected trace error")
        else:
            assert not mismatch
            np.testing.assert_equal(d.numpy(), eval(normal_expr).numpy())
        mismatch = random.random() > 0.8
534 535


536
def test_exception_in_trace():
    """Check that an exception raised inside a traced function reaches the caller unchanged.

    The traced function raises a fixed exception object when the ``mismatch``
    flag is set; the flag is re-drawn at random after every call.
    """
    a = tensor([1, 2, 3, 4])
    b = tensor([5, 6, 7, 8])
    c = tensor([9, 0, 1, 2])

    mismatch = False
    exc = Exception()

    @trace(symbolic=True)
    def fn(a, b, c):
        result = a + b
        if mismatch:
            raise exc
        result += c
        return result

    for _ in range(20):
        try:
            d = fn(a, b, c)
        except TraceError:
            pytest.fail("unexpected trace error")
        except Exception as caught:
            assert mismatch
            # It must be the very object raised inside fn, not a wrapper.
            assert caught is exc
        else:
            assert not mismatch
            np.testing.assert_equal(d.numpy(), (a + b + c).numpy())
        mismatch = random.random() > 0.8
566

567 568 569 570 571 572 573 574 575 576 577 578 579

def test_graph_error():
    """Check that a traced add still works after a call that raised ``RuntimeError``."""
    lhs = tensor(np.arange(8).reshape((2, 4)))
    rhs = tensor(np.arange(8).reshape((2, 4)))

    @trace(symbolic=True)
    def fn(a, b):
        return a + b

    fn(lhs, rhs)
    # A transposed second operand is expected to raise.
    with pytest.raises(RuntimeError):
        fn(lhs, rhs.transpose())
    fn(lhs, rhs)
580 581


582 583 584 585 586
@pytest.mark.parametrize("trace_mode", [False, True])
def test_trace_broadcast(trace_mode):
    """Call a traced ``broadcast_to((3, 4, 5))`` with three differently shaped inputs."""
    source_shapes = [(3, 1, 1), (1, 4, 1), (1, 1, 5)]
    inputs = [tensor(np.random.randn(*shape)) for shape in source_shapes]

    @trace(symbolic=trace_mode, capture_as_const=True)
    def f(x):
        return F.broadcast_to(x, (3, 4, 5))

    for sample in inputs:
        f(sample)
596 597 598 599 600 601 602 603 604 605 606 607 608 609


def test_trace_nms():
    """Call a non-symbolic traced NMS with 10, 20 and 30 random boxes."""

    def make_inputs(n):
        # The first two columns lie in [0, 100) and the last two in [100, 200).
        boxes = np.zeros((n, 4))
        boxes[:, :2] = np.random.rand(n, 2) * 100
        boxes[:, 2:] = np.random.rand(n, 2) * 100 + 100
        scores = np.random.rand(n)
        return tensor(boxes), tensor(scores)

    @trace(symbolic=False)
    def f(boxes, scores):
        # with tracing, max_output must be specified
        results = F.vision.nms(boxes, scores=scores, iou_thresh=0.5, max_output=20)
        # without tracing, max output can be inferred inside nms
        with exclude_from_trace():
            _ = F.vision.nms(boxes, scores=scores, iou_thresh=0.5)
        return results

    for box_count in (10, 20, 30):
        f(*make_inputs(box_count))
620 621 622 623 624 625 626 627 628 629 630 631 632 633


def test_trace_valid_broadcast():
    """Call a traced ``broadcast_to`` with a tuple-of-tensors shape and two input shapes."""
    inputs = [tensor(np.random.randn(1, 1)), tensor(np.random.randn(1, 2))]
    shape = (tensor([2]), tensor([2]))

    @trace(symbolic=False)
    def f(x, shape):
        return F.broadcast_to(x, shape)

    for sample in inputs:
        f(sample, shape)
634 635


636 637
@pytest.mark.parametrize("trace_mode", [False, True])
def test_clip(trace_mode):
    """Call a traced ``F.clip`` three times each with bounds (0, 1) and (5, 4)."""
    x = tensor(np.random.randn(10, 10))

    @trace(symbolic=trace_mode)
    def f(x, lower, upper):
        return F.clip(x, lower, upper)

    # The second pair has its lower bound above its upper bound.
    for low, high in ((0, 1), (5, 4)):
        for _ in range(3):
            f(x, tensor([low]), tensor([high]))

651 652 653 654 655 656 657 658 659 660 661 662

# test returning noncontiguous tensor from trace
def test_slice():
    """Check that a strided slice returned from a traced function matches NumPy and can be added."""

    @trace
    def f(x):
        return x[:, 1::2]

    x = F.arange(8).reshape(2, 4)
    f(x)
    sliced = f(x)
    np.testing.assert_array_equal(sliced.numpy(), x.numpy()[:, 1::2])
    # Use the returned tensor in a further operation.
    sliced + sliced
663 664


665 666
@pytest.mark.parametrize("shape_mode", [False, True])
def test_random(shape_mode):
    """Check that traced ``uniform`` and ``normal`` outputs expose a usable shape."""

    def run_test(op):
        @trace(symbolic=True, symbolic_shape=shape_mode)
        def f():
            out = op(size=[10, 10])
            out_shape = out.shape
            assert out_shape is not None
            # A non-tuple shape must be readable through .numpy().
            if not isinstance(out_shape, tuple):
                assert out.shape.numpy() is not None
            return out

        for _ in range(3):
            f()

    for random_op in (uniform, normal):
        run_test(random_op)
682 683 684 685 686 687


@pytest.mark.parametrize("shape_mode", [False, True])
def test_trace_advance_indexing(shape_mode):
    """Compare indexing results between NumPy, untraced tensors and traced tensors.

    Each lambda's parameter names select its arguments from ``inputs``. The
    names are read through the public ``inspect.Signature.parameters``
    mapping.
    """
    funcs = [
        lambda x, i: x[i],
        lambda x, i, j: x[i, j],
        lambda x, i, j: x[i, :, j, ...],
        lambda x, start, end: x[start:end],
        lambda x, start, end: x[:, 0, start:end, ..., 1],
        lambda x, vec: x[vec],
        lambda x, vec: x[vec, ..., 0, 1:3],
        lambda x, vec: x[vec, vec[0], vec[1]],
        # lambda x, i, start, end, vec: x[i, ..., :, vec, start:end],  # FIXME
        lambda x, mask: x[mask],
    ]

    inputs = {
        "x": np.random.randn(5, 5, 5, 5, 5).astype("float32"),
        "i": 4,
        "j": 2,
        "start": 1,
        "end": 3,
        "vec": [1, 2, 3],
        "mask": np.random.randn(5, 5, 5, 5, 5) >= 0,
    }
    for f in funcs:
        param_names = list(inspect.signature(f).parameters)
        params = {}
        params_np = {}
        f_traced = trace(f, symbolic=False, symbolic_shape=shape_mode)
        for name in param_names:
            params[name] = tensor(inputs[name])
            params_np[name] = inputs[name]
        # The result on the raw Python/NumPy inputs is the reference.
        expected = f(**params_np)
        result_imperative = f(**params)
        np.testing.assert_equal(expected, result_imperative.numpy())
        for _ in range(3):
            result_trace = f_traced(**params)
            np.testing.assert_equal(expected, result_trace.numpy())
723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738


@pytest.mark.require_ngpu(1)  # nvrtc backend
def test_trace_jit_config():
    """Check that ``GraphOptimizationConfig`` JIT flags show up in the trace's options.

    Every combination of None/False/True for the two flags is tried; the
    options are expected to hold 0, 1 and 2 respectively.
    """

    def run(fuse_dimshuffle, fuse_reduce):
        config = GraphOptimizationConfig()
        config.jit_fuse_dimshuffle = fuse_dimshuffle
        config.jit_fuse_reduce = fuse_reduce

        # set opt_level = 1 to avoid fusing dimshuffle and reduce at the same time
        @trace(opt_level=1, graph_opt_config=config)
        def func(x):
            return x + 1

        x = tensor(2)
        for _ in range(2):
            y = func(x)
        # func._compile()

        graph_opt = func._trace.options.graph_opt
        mapping = {None: 0, False: 1, True: 2}
        assert graph_opt.jit == 0
        assert graph_opt.jit_config.fuse_dimshuffle == mapping[fuse_dimshuffle]
        assert graph_opt.jit_config.fuse_reduce == mapping[fuse_reduce]

    flag_values = [None, False, True]
    for fuse_dimshuffle, fuse_reduce in itertools.product(flag_values, repeat=2):
        run(fuse_dimshuffle, fuse_reduce)