# -*- coding: utf-8 -*-
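"""Tests for megengine.jit.trace: tracing and symbolic modes, exclude_from_trace,
graph dumping and loading, graph optimization (gopt) behavior, and trace errors."""
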
import inspect
import io
import itertools
import random
from tempfile import mkstemp

import numpy as np
import pytest

import megengine.core.tensor.megbrain_graph as G
import megengine.functional as F
import megengine.optimizer as optim
import megengine.utils.comp_graph_tools as cgtools
from megengine import Parameter, tensor
from megengine.autodiff import GradManager
from megengine.core.ops import builtin as ops
from megengine.core.ops.builtin import Elemwise
from megengine.core.tensor.utils import isscalar
from megengine.functional import exp, log
from megengine.jit import GraphOptimizationConfig, TraceError, exclude_from_trace, trace
from megengine.module import Module
from megengine.random import normal, uniform
from megengine.utils.naming import AutoNaming


@pytest.mark.parametrize("trace_mode", [False, True])
@pytest.mark.parametrize("return_mode", ["Value", "Tuple", "List", "Dict"])
def test_trace(trace_mode, return_mode):
    @trace(symbolic=trace_mode)
    def f(x):
        if return_mode == "Tuple":
            return (-x,)
        elif return_mode == "List":
            return [-x]
        elif return_mode == "Dict":
            return {"neg": -x}
        else:
            return -x

    def get_numpy(y):
        if return_mode == "Tuple" or return_mode == "List":
            return y[0].numpy()
        elif return_mode == "Dict":
            return y["neg"].numpy()
        return y.numpy()

    x = tensor([1])
    y = get_numpy(f(x))

    for i in range(3):
        np.testing.assert_equal(get_numpy(f(x)), y)


def test_output_copy_trace():
    class Simple(Module):
        def __init__(self):
            super().__init__()
            self.a = Parameter([1.0], dtype=np.float32)

        def forward(self, x):
            x = x * self.a
            # will result in a copy of the output in grad
            x = F.exp(x)
            return x

    ys = {False: [], True: []}

    for symbolic in [False, True]:
        net = Simple()
        gm = GradManager().attach(net.parameters())
        opt = optim.SGD(net.parameters(), 1e-3, momentum=0.9)
        data = tensor(np.arange(4).reshape(2, 2), dtype="float32")

        @trace(symbolic=symbolic)
        def train_func(d):
            with gm:
                loss = net(d)
                gm.backward(loss)
                opt.step().clear_grad()
            return loss

        for i in range(3):
            y = train_func(data).numpy()
            ys[symbolic].append(y)

    for i in range(3):
        np.testing.assert_equal(ys[False][i], ys[True][i])


@pytest.mark.parametrize("trace_mode", [False, True])
def test_tensor_detach(trace_mode):
    @trace(symbolic=trace_mode)
    def f(x):
        y = x.detach() ** 2
        z = y.detach() + 1
        return z.detach()

    x = tensor([1, 2, 3, 4])
    for _ in range(3):
        f(x).numpy()


@pytest.mark.parametrize("trace_mode", [False, True])
def test_exclude_from_trace(trace_mode):
    @trace(symbolic=trace_mode)
    def f(x):
        x = -x
        with exclude_from_trace():
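            # this block re-runs eagerly on every call, even after the trace
            # is compiled; `i` is the loop variable from the caller's scope below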
            if i % 2:
                x = -x
        x = -x
        return x

    x = tensor([1])

    for i in range(3):
        y = f(x).numpy()
        np.testing.assert_equal(f(x).numpy(), y)


@pytest.mark.parametrize("trace_mode", [False, True])
def test_elemwise_fuse(trace_mode):
    # explicitly declare opt_level as 2
    @trace(symbolic=trace_mode, opt_level=2)
    def f(a, b):
        base = 0
        c = b - a
        _, idx = F.topk(c, 3)
        # internally, biased_idx will be the same var as idx, since gopt
        # elides the addition of the constant 0
        biased_idx = base + idx
        return biased_idx

    a = tensor(np.ones((7, 2)), dtype=np.int32)
    b = tensor(2 * np.ones((7, 2)), dtype=np.float32)

    for i in range(3):
        y = f(a, b)
        y.numpy()


@pytest.mark.parametrize("trace_mode", [False, True])
def test_elemwise_fuse_in_grad(trace_mode):
    w = Parameter(np.ones([4, 6]), dtype="float32")

    gm = GradManager().attach(w)
    opt = optim.SGD([w], lr=0.01, momentum=0.9, weight_decay=5e-4)

    # explicitly declare opt_level as 2
    @trace(symbolic=trace_mode, opt_level=2)
    def f():
        with gm:
            wm = F.sum(w ** 2, axis=1) ** 0.5
            loss = wm.mean()
            gm.backward(loss)
            opt.step().clear_grad()
        return loss

    for i in range(3):
        y = f()
        y.numpy()


def test_print_in_trace():
    for symbolic in [False]:  # cannot read value in symbolic mode

        @trace(symbolic=symbolic)
        def f(x):
            nonlocal buf
            x = -x
            buf = x.numpy()
            x = -x
            return x

        buf = None
        x = tensor([1])

        for i in range(3):
            y = f(x).numpy()
            z = buf
            buf = None
            np.testing.assert_equal(f(x).numpy(), y)
            np.testing.assert_equal(z, buf)


@pytest.mark.parametrize(
    "dump_format",
    [
        "FBS",
    ],
)
def test_dump(dump_format):
    @trace(symbolic=True, capture_as_const=True)
    def f(a, b):
        return a + b

    # clear any naming scope left over from a previous exception test
    AutoNaming.clear()
    a = tensor([2])
    b = tensor([4])
    y = f(a, b).numpy()

    for i in range(3):
        np.testing.assert_equal(f(a, b).numpy(), y)

    file = io.BytesIO()
    dump_info = f.dump(file, dump_format=dump_format)
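    # nr_opr == 3: presumably the two input nodes (arg_0, arg_1) plus the ADD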
    assert dump_info.nr_opr == 3
    np.testing.assert_equal(dump_info.inputs, ["arg_0", "arg_1"])
    np.testing.assert_equal(dump_info.outputs, ["ADD"])
    file.seek(0)
    infer_cg = cgtools.GraphInference(file)
    result = list((infer_cg.run(a, b)).values())[0]
    np.testing.assert_equal(result[0], y)


def test_capture_dump():
    a = tensor([2])

    @trace(symbolic=True, capture_as_const=True)
    def f(x):
        return x * a

    x = tensor([3])
    y = f(x).numpy()

    for i in range(3):
        np.testing.assert_equal(f(x).numpy(), y)

    file = io.BytesIO()
    f.dump(file)
    file.seek(0)
    infer_cg = cgtools.GraphInference(file)
    result = list((infer_cg.run(x)).values())[0]
    np.testing.assert_equal(result[0], y)


def test_dump_volatile():
    p = tensor([2])

    @trace(symbolic=True, capture_as_const=True)
    def f(x):
        return x * p

    x = tensor([3])
    y = f(x).numpy()

    for i in range(3):
        np.testing.assert_equal(f(x).numpy(), y)

    file = io.BytesIO()
    f.dump(file, optimize_for_inference=False)
    file.seek(0)
    (out,) = G.load_graph(file).output_vars_list
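    # capture_as_const bakes `p` into the dumped graph as an ImmutableTensor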
    assert (
        cgtools.get_owner_opr_type(cgtools.get_owner_opr_inputs(out)[1])
        == "ImmutableTensor"
    )


def test_dump_backward_graph():
    x0 = tensor(np.random.randn(3, 4))
    x1 = tensor(np.random.randn(3, 4))

    gm = GradManager().attach(x0)

    @trace(symbolic=True, capture_as_const=True)
    def f(x0, x1):
        with gm:
            y = x0 * x1
            gm.backward(y, F.ones_like(y))
            dx0 = x0.grad
        return y, dx0

    y, dx0 = f(x0, x1)
    np.testing.assert_equal(dx0.numpy(), x1)

    file = io.BytesIO()
    f.dump(file, optimize_for_inference=False)
    file.seek(0)

    infer_cg = cgtools.GraphInference(file)
    results = list((infer_cg.run(x0, x1)).values())

    np.testing.assert_equal(results[0], y)
    np.testing.assert_equal(results[1], dx0)


def test_dump_with_testcase():
    @trace(symbolic=True, capture_as_const=True)
    def f(x):
        return exp(x)

    f(tensor(1.0))
    file = io.BytesIO()
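    # "#rand(0, 255, 1)": ask dump to attach a randomly generated testcase input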
    f.dump(file, input_data=["#rand(0, 255, 1)"])


@pytest.mark.parametrize("trace_mode", [False, True])
def test_trace_profiler(trace_mode):
    @trace(symbolic=trace_mode, profiling=True)
    def f(x):
        return -x

    x = tensor([1])
    y = f(x).numpy()

    f(x)
    f(x)  # XXX: has to run twice before profiling data is available

    out = f.get_profile()
    assert out.get("profiler")


def test_goptions():
    @trace(symbolic=True, opt_level=0, capture_as_const=True)
    def f(x):
        # returning x / x directly would not trigger gopt,
        # since there is no way to tell that the two x's are the same
        y = 2.0 * x
        return y / y

    @trace(symbolic=True, opt_level=1, capture_as_const=True)
    def g(x):
        y = 2.0 * x
        return y / y

    d = tensor(0.0)
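    # opt_level=0 executes y / y literally (0.0 / 0.0 -> nan);
    # opt_level=1 lets gopt fold y / y to 1 before execution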
    assert not np.isfinite(f(d).numpy())
    np.testing.assert_equal(g(d).numpy().item(), 1.0)


def test_goptions_log_sum_exp():
    @trace(symbolic=True, opt_level=0, capture_as_const=True)
    def f(x, y):
        return log(exp(x) + exp(y))

    @trace(symbolic=True, opt_level=1, capture_as_const=True)
    def g(x, y):
        return log(exp(x) + exp(y))

    val = 1.0e4
    d = tensor(val)
    o = tensor(0.0)
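    # at opt_level=1 gopt rewrites log(exp(x) + exp(y)) into a numerically
    # stable log-sum-exp, so exp(1e4) no longer overflows to inf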
    assert not np.isfinite(f(d, o).numpy())
    np.testing.assert_almost_equal(g(d, o), val)


def test_goptions_log_exp():
    @trace(symbolic=True, opt_level=0, capture_as_const=True)
    def f(x):
        return log(exp(x))

    @trace(symbolic=True, opt_level=1, capture_as_const=True)
    def g(x):
        return log(exp(x))

    f(tensor(1.0))
    _, out = mkstemp()
    f.dump(out, optimize_for_inference=False)
    outputs = G.load_graph(out).output_vars_list
    oprs_1 = cgtools.get_oprs_seq(outputs)

    g(tensor(1.0))
    g.dump(out, optimize_for_inference=False)
    outputs = G.load_graph(out).output_vars_list
    oprs_2 = cgtools.get_oprs_seq(outputs)

    assert len(oprs_1) - len(oprs_2) == 2
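    # the two removed oprs are the log/exp pair that gopt folds away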


def test_optimize_for_inference():
    @trace(symbolic=True, capture_as_const=True)
    def f(x):
        return exp(x)

    _, out = mkstemp()
    f(tensor(5.0))
    f.dump(out, enable_io16xc32=True)

    res = G.load_graph(out)
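    # enable_io16xc32: I/O tensors become float16 while computation stays float32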
    computing_input = res.output_vars_list[0].owner.inputs[0]
    assert computing_input.dtype == np.float16


def test_optimize_for_inference_broadcast():
    a = tensor(np.ones(1, dtype=np.float32))

    @trace(capture_as_const=True, symbolic_shape=True)
    def f():
        return a._broadcast(tensor([1, 10], dtype=np.int32))

    f()
    f.dump(io.BytesIO())


def test_trace_cvt_bool():
    x = tensor([0], dtype=np.int32)

    @trace(symbolic=True)
    def f(x):
        a = x.shape
        b = a[0]
        assert isscalar(b)
        return b == 0

    for i in range(3):
        np.testing.assert_equal(f(x).numpy(), False)


@pytest.mark.parametrize("trace_mode", [False, True])
def test_trace_reshape(trace_mode):
    x1 = tensor(np.random.randn(2, 10, 10))
    x2 = tensor(np.random.randn(4, 10, 10))
    x3 = tensor(np.random.randn(8, 10, 10))

    @trace(symbolic=trace_mode, capture_as_const=True)
    def f(x):
        y = x.reshape(x.shape[0], 100)
        return y

    f(x1)
    f(x2)
    f(x3)


def test_trace_topk():
    x = tensor([5, 2, 7, 1, 0, 3, 2])

    @trace(symbolic=True)
    def f(x):
        y = F.topk(x, 3)
        np.testing.assert_equal(y[0].shape.numpy(), np.array([3,]))
        return y

    for i in range(3):
        f(x)


def test_trace_warp_perspective():
    inp_shape = (1, 1, 4, 4)
    x = tensor(np.arange(16, dtype=np.float32).reshape(inp_shape))
    M_shape = (1, 3, 3)
    M = tensor(
        np.array(
            [[1.0, 0.0, 1.0], [0.0, 1.0, 1.0], [0.0, 0.0, 1.0]], dtype=np.float32
        ).reshape(M_shape)
    )

    @trace(symbolic=True)
    def f(x, M):
        out = F.vision.warp_perspective(x, M, (2, 2))
        np.testing.assert_equal(out.shape.numpy(), np.array([1, 1, 2, 2]))
        return out

    for i in range(3):
        f(x, M)


@pytest.mark.parametrize(
    "normal_expr, mismatch_expr, reason",
    [
        ("a + b + c", "a + b - c", "operator mismatch"),
        ("a + b + 1", "a + b + 2", "tensors not equals"),
        ("((a + b), (b + c))[0]", "a + b", "mismature end"),
        ("a + b + c", "c + (a + b)", "expect internal node, got external"),
        ("c + (a + b)", "a + b + c", "expect external node, got internal"),
        ("a + b + c", "a + b + c + c", "too many instructions"),
        ("((a + b), (b + c))[1]", "((a + b), (b + c))[0]", "data unreadable"),
        ("((a + b), (b + c))[1] + a", "((a + b), (b + c))[0] + a", "input id mismatch"),
    ],
)
def test_trace_mismatch(normal_expr, mismatch_expr, reason):
    a = tensor([1, 2, 3, 4])
    b = tensor([5, 6, 7, 8])
    c = tensor([9, 0, 1, 2])

    mismatch = False
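    # trace re-runs the Python body on every call; randomly swapping in the
    # mismatching expression must raise TraceError with the matching reason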

    @trace(symbolic=True)
    def fn(a, b, c):
        if not mismatch:
            result = eval(normal_expr)
        else:
            result = eval(mismatch_expr)
        return result

    for i in range(20):
        try:
            d = fn(a, b, c)
        except TraceError as e:
            assert mismatch
            assert str(e) == "trace error because {}".format(reason)
        except:
            pytest.fail("unexpected trace error")
        else:
            assert not mismatch
            np.testing.assert_equal(d.numpy(), eval(normal_expr).numpy())
        mismatch = random.random() > 0.8


def test_exception_in_trace():
    a = tensor([1, 2, 3, 4])
    b = tensor([5, 6, 7, 8])
    c = tensor([9, 0, 1, 2])

    mismatch = False

    exc = Exception()

    @trace(symbolic=True)
    def fn(a, b, c):
        result = a + b
        if not mismatch:
            result += c
        else:
            raise exc
        return result

    for i in range(20):
        try:
            d = fn(a, b, c)
        except TraceError as e:
            pytest.fail("unexpected trace error")
        except Exception as e:
            assert mismatch
            assert e is exc
        else:
            assert not mismatch
            np.testing.assert_equal(d.numpy(), (a + b + c).numpy())
        mismatch = random.random() > 0.8


def test_graph_error():
    a = tensor(np.arange(8).reshape((2, 4)))
    b = tensor(np.arange(8).reshape((2, 4)))

    @trace(symbolic=True)
    def fn(a, b):
        return a + b

    fn(a, b)
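    # a transposed (shape-mismatched) input must fail, and the compiled
    # trace must remain usable afterwards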
    with pytest.raises(RuntimeError):
        fn(a, b.transpose())
    fn(a, b)


@pytest.mark.parametrize("trace_mode", [False, True])
def test_trace_broadcast(trace_mode):
    x1 = tensor(np.random.randn(3, 1, 1))
    x2 = tensor(np.random.randn(1, 4, 1))
    x3 = tensor(np.random.randn(1, 1, 5))

    @trace(symbolic=trace_mode, capture_as_const=True)
    def f(x):
        y = F.broadcast_to(x, (3, 4, 5))
        return y

    f(x1)
    f(x2)
    f(x3)


def test_trace_nms():
    def make_inputs(n):
        boxes = np.zeros((n, 4))
        boxes[:, :2] = np.random.rand(n, 2) * 100
        boxes[:, 2:] = np.random.rand(n, 2) * 100 + 100

        scores = np.random.rand(n)

        return tensor(boxes), tensor(scores)

    @trace(symbolic=False)
    def f(boxes, scores):
        # with tracing, max_output must be specified
        results = F.vision.nms(boxes, scores=scores, iou_thresh=0.5, max_output=20)
        # without tracing, max_output can be inferred inside nms
        with exclude_from_trace():
            _ = F.vision.nms(boxes, scores=scores, iou_thresh=0.5)
        return results

    f(*make_inputs(10))
    f(*make_inputs(20))
    f(*make_inputs(30))


def test_trace_valid_broadcast():
    x1 = tensor(np.random.randn(1, 1))
    x2 = tensor(np.random.randn(1, 2))
    shape = (tensor([2]), tensor([2]))

    @trace(symbolic=False)
    def f(x, shape):
        y = F.broadcast_to(x, shape)
        return y

    f(x1, shape)
    f(x2, shape)


@pytest.mark.parametrize("trace_mode", [False, True])
def test_clip(trace_mode):
    x = tensor(np.random.randn(10, 10))

    @trace(symbolic=trace_mode)
    def f(x, lower, upper):
        y = F.clip(x, lower, upper)
        return y

    for i in range(3):
        f(x, tensor([0]), tensor([1]))

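    # lower > upper: clip should still trace and execute without error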
    for i in range(3):
        f(x, tensor([5]), tensor([4]))


# test returning a noncontiguous tensor from trace
def test_slice():
    @trace
    def f(x):
        return x[:, 1::2]

    x = F.arange(8).reshape(2, 4)
    f(x)
    y = f(x)
    np.testing.assert_array_equal(y.numpy(), x.numpy()[:, 1::2])
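    # arithmetic on the noncontiguous result should still work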
    y + y


@pytest.mark.parametrize("shape_mode", [False, True])
def test_random(shape_mode):
    def run_test(op):
        @trace(symbolic=True, symbolic_shape=shape_mode)
        def f():
            out = op(size=[10, 10])
            out_shape = out.shape
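            # with symbolic_shape=True, .shape is a symbolic tensor, not a tuple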
            assert out_shape is not None
            if not isinstance(out_shape, tuple):
                assert out.shape.numpy() is not None
            return out

        for _ in range(3):
            f()

    run_test(uniform)
    run_test(normal)


@pytest.mark.parametrize("shape_mode", [False, True])
def test_trace_advance_indexing(shape_mode):
    funcs = [
        lambda x, i: x[i],
        lambda x, i, j: x[i, j],
        lambda x, i, j: x[i, :, j, ...],
        lambda x, start, end: x[start:end],
        lambda x, start, end: x[:, 0, start:end, ..., 1],
        lambda x, vec: x[vec],
        lambda x, vec: x[vec, ..., 0, 1:3],
        lambda x, vec: x[vec, vec[0], vec[1]],
        # lambda x, i, start, end, vec: x[i, ..., :, vec, start:end],  # FIXME
        lambda x, mask: x[mask],
    ]

    inputs = {
        "x": np.random.randn(5, 5, 5, 5, 5).astype("float32"),
        "i": 4,
        "j": 2,
        "start": 1,
        "end": 3,
        "vec": [1, 2, 3],
        "mask": np.random.randn(5, 5, 5, 5, 5) >= 0,
    }
    for f in funcs:
        sig = inspect.signature(f)
        param_names = list(sig.parameters.keys())
        params = {}
        params_np = {}
        f_traced = trace(f, symbolic=False, symbolic_shape=shape_mode)
        for name in param_names:
            params[name] = tensor(inputs[name])
            params_np[name] = inputs[name]
        expected = f(**params_np)
        result_imperative = f(**params)
        np.testing.assert_equal(expected, result_imperative.numpy())
        for _ in range(3):
            result_trace = f_traced(**params)
            np.testing.assert_equal(expected, result_trace.numpy())


@pytest.mark.require_ngpu(1)  # nvrtc backend
def test_trace_jit_config():
    def run(fuse_dimshuffle, fuse_reduce):
        config = GraphOptimizationConfig()
        config.jit_fuse_dimshuffle = fuse_dimshuffle
        config.jit_fuse_reduce = fuse_reduce

        # set opt_level = 1 to avoid fusing dimshuffle and reduce at the same time
        @trace(opt_level=1, graph_opt_config=config)
        def func(x):
            return x + 1

        x = tensor(2)
        y = func(x)
        y = func(x)  # run again so the trace is compiled (in place of func._compile())

        options = func._trace.options
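        # the JIT config encodes each tri-state flag as an int: unset/off/on -> 0/1/2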
        mapping = {None: 0, False: 1, True: 2}
        assert options.graph_opt.jit == 0
        assert options.graph_opt.jit_config.fuse_dimshuffle == mapping[fuse_dimshuffle]
        assert options.graph_opt.jit_config.fuse_reduce == mapping[fuse_reduce]

    for fuse_dimshuffle in [None, False, True]:
        for fuse_reduce in [None, False, True]:
            run(fuse_dimshuffle, fuse_reduce)