test_tensor.py 25.1 KB
Newer Older
1 2 3
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
4
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
5 6 7 8
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
import os
10 11
import platform

12 13
import numpy as np
import pytest
14
from utils import get_var_value, make_tensor, opr_test
15 16

import megengine.functional as F
M
Megvii Engine Team 已提交
17
from megengine import tensor
18
from megengine.core._trace_option import use_symbolic_shape
19
from megengine.core.tensor import megbrain_graph as G
20
from megengine.core.tensor.utils import astensor1d
21
from megengine.jit import trace
22
from megengine.utils.network import Network, set_symbolic_shape
23
from megengine.utils.network_node import VarNode
24 25 26


def test_eye():
    """Compare ``F.eye`` with ``np.eye`` for a float and a boolean dtype.

    The shape is handed to ``F.eye`` in three forms -- as a list, as separate
    positional arguments and as a tensor -- and each result is checked
    against the NumPy reference.
    """
    # ``np.bool`` was only an alias of the builtin ``bool``; it was deprecated in
    # NumPy 1.20 and removed in 1.24, so use ``bool`` itself (same dtype).
    dtypes = [np.float32, bool]
    cases = [{"input": [10, 20]}, {"input": [30]}]
    for dtype in dtypes:
        for case in cases:
            shape = case["input"]
            expected = np.eye(*shape).astype(dtype)
            np.testing.assert_allclose(
                F.eye(shape, dtype=dtype).numpy(), expected,
            )
            np.testing.assert_allclose(
                F.eye(*shape, dtype=dtype).numpy(), expected,
            )
            np.testing.assert_allclose(
                F.eye(tensor(shape), dtype=dtype).numpy(), expected,
            )
43 44


45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64
@pytest.mark.parametrize("is_varnode", [False, True])
def test_diag(is_varnode):
    """Compare ``F.diag`` with ``np.diag`` for every offset ``k`` in ``[-2, 2]``.

    Inputs cover a square, a wide and a tall matrix plus a 1-D vector.
    """
    network = Network() if is_varnode else None

    shapes = [(10, 10), (6, 9), (8, 7), (8,)]
    cases = [{"input": [np.random.random(shp).astype("float32")]} for shp in shapes]

    for k in range(-2, 3):
        # ``opr_test`` runs immediately, so both closures see the current ``k``.
        def run(data):
            return F.diag(data, k=k)

        opr_test(cases, run, ref_fn=lambda x: np.diag(x, k), network=network)


65 66 67 68 69 70 71 72
def test_full():
    """``F.full`` must match ``np.full`` in value, and its dtype must equal the
    dtype ``tensor`` infers for the same fill value (bool, int and float)."""
    shape = (2, 3)
    for fill in (True, 4, 5.0):
        np.testing.assert_allclose(
            F.full(shape, fill).numpy(), np.full(shape, fill)
        )
        assert F.full(shape, fill).dtype == tensor(fill).dtype


73 74 75 76 77 78 79
@pytest.mark.parametrize("is_varnode", [True, False])
def test_concat(is_varnode):
    """Concatenate pairs of arrays that differ only in their leading dimension."""
    network = Network() if is_varnode else None

    def rand_block(rows: int):
        return np.random.random((rows, 2, 3)).astype("float32")

    first, second, third = rand_block(5), rand_block(6), rand_block(7)

    def run(lhs, rhs):
        return F.concat([lhs, rhs])

    cases = [{"input": [first, second]}, {"input": [first, third]}]
    opr_test(cases, run, ref_fn=lambda x, y: np.concatenate([x, y]), network=network)
92 93


94 95 96 97 98 99 100 101 102 103 104 105
@pytest.mark.parametrize("is_varnode", [True, False])
def test_condtake(is_varnode):
    """``F.cond_take`` must return the masked values and their flat indices."""
    network = Network() if is_varnode else None

    data = np.array([[1, 2, 3], [4, 5, 6]]).astype("float32")
    mask = np.array([[True, False, True], [False, True, True]])
    data_t = make_tensor(data, network)
    mask_t = make_tensor(mask, network)
    val, idx = F.cond_take(mask_t, data_t)

    # VarNodes are read back through the helper; eager tensors via ``.numpy()``.
    to_numpy = get_var_value if is_varnode else (lambda t: t.numpy())
    np.testing.assert_equal(to_numpy(val), data[mask])
    np.testing.assert_equal(to_numpy(idx), np.where(mask.reshape(-1))[0])
112 113


114 115 116 117 118 119 120 121 122
@pytest.mark.parametrize("is_varnode", [True, False])
def test_concat_device(is_varnode):
    """Inputs on ``cpu0`` and ``cpu1`` concatenated with ``device="cpu0"``
    must produce a result placed on ``cpu0``."""
    network = Network() if is_varnode else None

    lhs = make_tensor(np.random.random((3, 2, 2)).astype("float32"), network, "cpu0")
    rhs = make_tensor(np.random.random((2, 2, 2)).astype("float32"), network, "cpu1")

    out = F.concat([lhs, rhs], device="cpu0")
    device_name, *_ = str(out.device).split(":")
    assert device_name == "cpu0"


128 129 130 131 132 133 134
@pytest.mark.parametrize("is_varnode", [True, False])
def test_stack(is_varnode):
    """``F.stack`` of two (3, 2, 2) arrays must match ``np.stack`` on every axis."""
    network = Network() if is_varnode else None

    a, b, c = (np.random.random((3, 2, 2)).astype("float32") for _ in range(3))
    cases = [{"input": [a, b]}, {"input": [a, c]}]

    for axis in range(3):
        # ``opr_test`` runs immediately, so both closures see the current axis.
        def run(lhs, rhs):
            return F.stack([lhs, rhs], axis=axis)

        opr_test(
            cases,
            run,
            ref_fn=lambda x, y: np.stack([x, y], axis=axis),
            network=network,
        )

149

150
@pytest.mark.parametrize("is_varnode", [True, False])
def test_split_basic(is_varnode):
    """Check ``F.split`` by section count and by split indices, plus its errors.

    In varnode mode the symbolic-shape setting is switched off for the test and
    the previous value is restored in ``finally``, so a failing assertion can no
    longer leak the modified global setting into subsequent tests.
    """
    if is_varnode:
        network = Network()
        saved_symbolic_shape = set_symbolic_shape(False)
    else:
        network = None

    try:
        data = np.random.random((2, 3, 4, 5)).astype(np.float32)
        inp = make_tensor(data, network)

        mge_out0 = F.split(inp, 2, axis=3)
        mge_out1 = F.split(inp, [3], axis=3)

        # Only the first two NumPy pieces are compared; the third is empty.
        np_out = np.split(data, [3, 5], axis=3)

        assert len(mge_out0) == 2
        assert len(mge_out1) == 2

        np.testing.assert_equal(mge_out0[0].numpy(), np_out[0])
        np.testing.assert_equal(mge_out1[0].numpy(), np_out[0])

        np.testing.assert_equal(mge_out0[1].numpy(), np_out[1])
        np.testing.assert_equal(mge_out1[1].numpy(), np_out[1])

        # Requesting 4 sections along the default axis must be rejected.
        with pytest.raises(ValueError):
            F.split(inp, 4)

        # A decreasing index list must be rejected with this exact message
        # (the "secions" spelling is what the library emits).
        with pytest.raises(ValueError) as exc_info:
            F.split(inp, [3, 2, 5], axis=3)
        assert str(exc_info.value) == "Invalid nsplits_or_secions: [3, 2, 5]"
    finally:
        if is_varnode:
            set_symbolic_shape(saved_symbolic_shape)

190

191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223
@pytest.mark.parametrize("symbolic", [None, False, True])
def test_split(symbolic):
    """``F.split`` must match ``np.split`` eagerly and under tracing.

    Both a regular input and one whose leading dimension is zero are split by
    section count and by index lists along axis 3.
    """
    full = np.random.random((3, 4, 5, 6)).astype(np.float32)
    empty = np.random.random((0, 4, 5, 6)).astype(np.float32)

    def ref(inp, nsplits_or_sections, axis):
        return np.split(inp, nsplits_or_sections, axis)

    def func(inp, nsplits_or_sections, axis):
        return F.split(inp, nsplits_or_sections, axis)

    cases = [
        (arr, spec, 3) for arr in (full, empty) for spec in (2, [3], [3, 3, 5])
    ]

    # Traced functions are called repeatedly; eager mode needs a single pass.
    repeats = 1 if symbolic is None else 3
    for arr, spec, axis in cases:
        fn = func if symbolic is None else trace(symbolic=symbolic)(func)
        for _ in range(repeats):
            expected = ref(arr, spec, axis)
            actual = fn(tensor(arr), spec, axis)
            assert len(expected) == len(actual)
            for want, got in zip(expected, actual):
                np.testing.assert_equal(want, got.numpy())


224 225 226 227 228 229 230
@pytest.mark.parametrize("is_varnode", [True, False])
def test_reshape(is_varnode):
    """Reshape a length-6 vector to (1, 2, 3) using tuple, ``-1``, tensor-valued,
    ndarray and tensor target shapes."""
    network = Network() if is_varnode else None

    src = np.arange(6, dtype="float32")
    src_t = make_tensor(src, network)
    expected = src.reshape(1, 2, 3)

    target_shapes = [
        (1, 2, 3),
        (1, -1, 3),
        (1, make_tensor(-1, network), 3),
        np.array([1, -1, 3], dtype="int32"),
        make_tensor([1, -1, 3], network),
    ]
    for target in target_shapes:
        np.testing.assert_equal(F.reshape(src_t, target).numpy(), expected)


246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269
@pytest.mark.parametrize("is_varnode", [True, False])
def test_broadcast_auto_infer(is_varnode):
    """``broadcast_to`` accepts ``None`` entries in the target shape.

    A ``-1`` entry and a ``None`` in a new leading dimension must raise
    ``ValueError``.
    """
    network = Network() if is_varnode else None

    src = np.random.random((1, 2, 3)).astype(np.float32)
    src_t = make_tensor(src, network)

    for target in [(1, 2, 3), (1, None, 3)]:
        np.testing.assert_equal(F.broadcast_to(src_t, target).numpy(), src)

    for bad_target in [(1, -1, 3), (None, 1, 2, 3)]:
        with pytest.raises(ValueError):
            F.broadcast_to(src_t, bad_target)

    # These only have to succeed; their results are not inspected.
    F.broadcast_to(src_t, (1, None, 2, 3))
    F.broadcast_to(src_t, (make_tensor(2, network), None, 2, 3))


274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315
@pytest.mark.parametrize("is_trace", [True, False])
def test_reshape_on_empty_tensor(is_trace):
    """Reshape zero-size tensors and check the resulting static shape.

    With ``is_trace`` the reshape is wrapped in ``trace`` (non-symbolic, then
    symbolic) and each traced function is called three times per input.
    """

    def zero_size_input(shape):
        return tensor(np.random.random(shape).astype(np.float32))

    # [input tensor, target shape]
    cases = [
        [zero_size_input((100, 0, 1)), (100, 0, 10)],
        [zero_size_input((10, 0)), (0,)],
        [zero_size_input((10, 0, 10)), (0, 1, 2, 3)],
    ]

    def func(x, shp):
        return F.reshape(x, shp)

    def run_and_check(fn, inp, target_shp):
        out = fn(inp, target_shp)
        assert out._tuple_shape == target_shp

    if not is_trace:
        for inp, target_shp in cases:
            run_and_check(func, inp, target_shp)
        return

    for symbolic in [False, True]:
        for inp, target_shp in cases:
            traced = trace(symbolic=symbolic)(func)
            for _ in range(3):
                run_and_check(traced, inp, target_shp)


316 317 318 319
@pytest.mark.parametrize("is_varnode", [True, False])
def test_reshape_shape_inference(is_varnode):
    """Check the output shape of ``reshape`` for every combination of input and
    target shape listed in ``cases``.

    Per the variable names, the inputs have a known or unknown static shape and
    the targets are known, unknown, or partially specified with ``-1``. In
    varnode mode the symbolic-shape setting is switched off for the test. The
    previous value is restored in ``finally`` so that a failure cannot leak the
    modified global setting into later tests.
    """
    if is_varnode:
        network = Network()
        saved_symbolic_shape = set_symbolic_shape(False)
    else:
        network = None

    try:
        x_shape_known = make_tensor([1, 2, 3, 4], network)
        # The target shape is a runtime tensor (``.sum()``), presumably so the
        # result's static shape is not known in advance -- TODO confirm.
        x_shape_unknown = F.broadcast_to(
            make_tensor([1.0], network), shape=make_tensor([1, 1, 1, 1], network).sum()
        )
        tshp_unknown = astensor1d(
            (make_tensor([2], network), make_tensor([2], network)), x_shape_known
        )
        tshp_known = astensor1d((2, 2), x_shape_known)
        tshp_known_unspec = astensor1d((2, -1), x_shape_known)

        def check_shape(output, target):
            source = output.shape
            if isinstance(source, tensor):
                source = source.numpy()
            np.testing.assert_equal(source, target.shape)

        def func(x, target_shape):
            return x.reshape(target_shape)

        cases = [
            {"input": [x_shape_known, tshp_unknown], "output": [np.zeros((2, 2)),]},
            {"input": [x_shape_unknown, tshp_unknown], "output": [np.zeros((2, 2)),]},
            {"input": [x_shape_known, tshp_known], "output": [np.zeros((2, 2)),]},
            {"input": [x_shape_known, tshp_known_unspec], "output": [np.zeros((2, 2)),]},
            {"input": [x_shape_unknown, tshp_known], "output": [np.zeros((2, 2)),]},
            {"input": [x_shape_unknown, tshp_known_unspec], "output": [np.zeros((2, 2)),]},
        ]
        opr_test(cases, func, compare_fn=check_shape, test_trace=True, network=network)
    finally:
        if is_varnode:
            set_symbolic_shape(saved_symbolic_shape)
354

355

356 357 358 359
@pytest.mark.parametrize("is_varnode", [True, False])
def test_squeeze(is_varnode):
    """``F.squeeze`` must match ``np.squeeze`` for ``axis=None`` and for
    positive, negative and tuple axes.

    In varnode mode the symbolic-shape setting is switched off for the test and
    restored in ``finally``, so a failing assertion cannot leak the modified
    global setting into subsequent tests.
    """
    if is_varnode:
        network = Network()
        saved_symbolic_shape = set_symbolic_shape(False)
    else:
        network = None

    try:
        x = np.arange(6, dtype="float32").reshape(1, 2, 3, 1)
        xx = make_tensor(x, network)

        for axis in [None, 3, -4, (3, -4)]:
            y = np.squeeze(x, axis)
            yy = F.squeeze(xx, axis)
            np.testing.assert_equal(y, yy.numpy())
    finally:
        if is_varnode:
            set_symbolic_shape(saved_symbolic_shape)

375

376 377 378 379 380 381 382
@pytest.mark.parametrize("is_varnode", [True, False])
def test_expand_dims(is_varnode):
    """``F.expand_dims`` on a (2, 3) array must match ``np.expand_dims`` for
    single and tuple axes, positive and negative."""
    network = Network() if is_varnode else None

    src = np.arange(6, dtype="float32").reshape(2, 3)
    src_t = make_tensor(src, network)

    for axis in (2, -3, (3, -4), (1, -4)):
        np.testing.assert_equal(
            np.expand_dims(src, axis), F.expand_dims(src_t, axis).numpy()
        )


392 393 394 395 396 397 398 399 400 401
def test_expand_dims_for_scalar():
    """``F.expand_dims`` on a 0-d tensor must mirror ``np.expand_dims``.

    For valid axes the results must agree. For each axis that NumPy rejects
    with ``AxisError``, MegEngine must raise ``RuntimeError``.
    """
    # ``AxisError`` is documented under ``np.exceptions`` since NumPy 1.25; the
    # top-level ``np.AxisError`` is the legacy spelling. Older NumPy releases
    # only provide the top-level name, so fall back to it.
    axis_error = getattr(np, "exceptions", np).AxisError

    x = np.array(1, dtype="float32")
    xx = make_tensor(x, None)
    for axis in [0, -1, (0, 1), (-1, -2), (0, -1)]:
        y = np.expand_dims(x, axis)
        yy = F.expand_dims(xx, axis)
        np.testing.assert_equal(y, yy.numpy())

    for axis in [1, -2, (1, 2), (-2, -3)]:
        np.testing.assert_raises(axis_error, np.expand_dims, x, axis)
        np.testing.assert_raises(RuntimeError, F.expand_dims, xx, axis)
403 404


405 406 407 408 409 410 411
@pytest.mark.parametrize("is_varnode", [True, False])
def test_elemwise_dtype_promotion(is_varnode):
    """Mixed float32/float16 elementwise ops must match NumPy, whether both
    operands are tensors or a tensor is combined with an ndarray on either side."""
    network = Network() if is_varnode else None

    a = np.random.rand(2, 3).astype("float32")
    b = np.random.rand(1, 3).astype("float16")
    a_t = make_tensor(a, network)
    b_t = make_tensor(b, network)

    def check(result, expected):
        np.testing.assert_equal(result.numpy(), expected)

    check(a_t * b_t, a * b)  # tensor * tensor
    check(a_t + b, a + b)  # tensor + ndarray
    check(a - b_t, a - b)  # ndarray - tensor


426 427 428 429 430 431 432
@pytest.mark.parametrize("is_varnode", [True, False])
def test_linspace(is_varnode):
    """``F.linspace`` must match ``np.linspace`` (float32) for increasing and
    decreasing ranges, and when some arguments are tensors."""
    network = Network() if is_varnode else None

    def np_linspace(start, end, step):
        return np.linspace(start, end, step, dtype=np.float32)

    for cases in (
        [{"input": [1, 9, 9]}, {"input": [3, 10, 8]}],
        [{"input": [9, 1, 9]}, {"input": [10, 3, 8]}],
    ):
        opr_test(cases, F.linspace, ref_fn=np_linspace, network=network)

    # Every case below encodes linspace(1, 9, 9) with some tensor-valued
    # arguments, so the reference ignores its inputs and uses those constants.
    cases = [
        {"input": [1, make_tensor(9, network), 9]},
        {"input": [make_tensor(1, network), 9, make_tensor(9, network)]},
    ]
    opr_test(
        cases,
        F.linspace,
        ref_fn=lambda start, end, step: np.linspace(1, 9, 9, dtype=np.float32),
        network=network,
    )


467 468 469 470 471 472 473
@pytest.mark.parametrize("is_varnode", [True, False])
def test_arange(is_varnode):
    """``F.arange`` must match ``np.arange`` (float32) for positive, negative
    and fractional steps."""
    network = Network() if is_varnode else None

    def np_arange(start, end, step):
        return np.arange(start, end, step, dtype=np.float32)

    case_groups = [
        # increasing, integer step
        [{"input": [1, 9, 1]}, {"input": [2, 10, 2]}],
        # decreasing, negative integer step
        [{"input": [9, 1, -1]}, {"input": [10, 2, -2]}],
        # decreasing, fractional negative step
        [{"input": [9.3, 1.2, -0.5]}, {"input": [10.3, 2.1, -1.7]}],
    ]
    for cases in case_groups:
        opr_test(cases, F.arange, ref_fn=np_arange, network=network)


508 509 510 511 512 513 514
@pytest.mark.parametrize("is_varnode", [True, False])
def test_round(is_varnode):
    """``F.round`` must match ``np.round`` on random float32 vectors."""
    network = Network() if is_varnode else None

    cases = [
        {"input": np.random.random((length,)).astype(np.float32)}
        for length in (15, 25)
    ]
    opr_test(cases, F.round, ref_fn=np.round, network=network)
522 523


524 525 526 527 528 529 530
@pytest.mark.parametrize("is_varnode", [True, False])
def test_flatten(is_varnode):
    """``F.flatten`` must collapse the requested axis range for 4-D inputs.

    Covers the default (all axes), ``start_axis`` alone, and a
    ``start_axis``/``end_axis`` pair.
    """
    network = Network() if is_varnode else None

    data0 = np.random.random((2, 3, 4, 5)).astype(np.float32)
    data1 = np.random.random((4, 5, 6, 7)).astype(np.float32)

    # (keyword arguments for F.flatten, expected shape for data0, for data1)
    specs = [
        ({}, (-1,), (-1,)),
        ({"start_axis": 1}, (2, -1), (4, -1)),
        ({"start_axis": 2}, (2, 3, -1), (4, 5, -1)),
        ({"start_axis": 1, "end_axis": 2}, (2, -1, 5), (4, -1, 7)),
    ]
    for kwargs, shape0, shape1 in specs:
        cases = [
            {"input": data0, "output": data0.reshape(shape0)},
            {"input": data1, "output": data1.reshape(shape1)},
        ]
        opr_test(cases, F.flatten, network=network, **kwargs)

562

563 564 565 566 567 568
@pytest.mark.parametrize("is_varnode", [True, False])
def test_broadcast(is_varnode):
    """``F.broadcast_to`` must match ``np.broadcast_to`` for valid targets and
    raise ``RuntimeError`` for incompatible ones."""
    network = Network() if is_varnode else None

    # (input shape, broadcast target)
    shape_pairs = [
        ((20, 30), (30, 20, 30)),
        ((10, 1), (20, 10, 20)),
        ((10, 10), (10, 10)),
    ]
    cases = []
    for in_shape, out_shape in shape_pairs:
        data = np.random.random(in_shape).astype(np.float32)
        cases.append(
            {
                "input": [data, out_shape],
                "output": np.broadcast_to(data, out_shape),
            }
        )
    opr_test(cases, F.broadcast_to, network=network)

    ones = F.ones((2, 1, 3))
    for bad_shape in [(2, 3, 4), (4, 1, 3), (1, 3)]:
        with pytest.raises(RuntimeError):
            F.broadcast_to(ones, bad_shape)
608

609

610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651
@pytest.mark.parametrize("is_trace", [True, False])
def test_broadcast_on_empty_tensor(is_trace):
    """Broadcast zero-size tensors and check the resulting static shape.

    With ``is_trace`` the op is wrapped in ``trace`` (non-symbolic, then
    symbolic) and each traced function is called three times per input.
    """

    def zero_size_input(shape):
        return tensor(np.random.random(shape).astype(np.float32))

    # [input tensor, target shape]
    cases = [
        [zero_size_input((100, 0, 1)), (100, 0, 10)],
        [zero_size_input((10, 0)), (10, 10, 0)],
        [zero_size_input((0, 0, 1, 10)), (10, 0, 0, 10, 10)],
    ]

    def func(x, shp):
        return F.broadcast_to(x, shp)

    def run_and_check(fn, inp, target_shp):
        out = fn(inp, target_shp)
        assert out._tuple_shape == target_shp

    if not is_trace:
        for inp, target_shp in cases:
            run_and_check(func, inp, target_shp)
        return

    for symbolic in [False, True]:
        for inp, target_shp in cases:
            traced = trace(symbolic=symbolic)(func)
            for _ in range(3):
                run_and_check(traced, inp, target_shp)


652 653 654 655 656 657 658 659
@pytest.mark.parametrize("is_varnode", [True, False])
def test_utils_astensor1d(is_varnode):
    """``astensor1d`` must turn literals, ndarrays, tensors and mixed lists into
    an object of the same type as ``reference``, optionally casting dtype."""
    network = Network() if is_varnode else None
    reference = make_tensor(0, network)

    def check(inp, expected):
        # ``expected`` maps the requested dtype to the values to compare with.
        for dtype in [None, "float32"]:
            out = astensor1d(inp, reference, dtype=dtype)
            assert isinstance(out, type(reference))
            np.testing.assert_equal(out.numpy(), expected(dtype))

    # plain Python list
    literal = [1, 2, 3]
    check(literal, lambda dtype: literal)

    # numpy array (cast when a dtype is requested)
    arr = np.asarray([1, 2, 3], dtype="int32")
    check(arr, lambda dtype: arr.astype(dtype) if dtype else arr)

    # tensor / varnode
    tsr = make_tensor([1, 2, 3], network)
    check(tsr, lambda dtype: tsr.numpy())

    # list mixing Python ints with a tensor
    check([1, make_tensor(2, network), 3], lambda dtype: [1, 2, 3])


def test_device():
    """Factory results must agree whether ``device`` is omitted, ``None``,
    ``"xpux"`` or the device of an existing tensor."""
    x = tensor([1, 2, 3], dtype="float32")

    implicit = F.eye(x.shape, dtype="float32")
    explicit_none = F.eye(x.shape, dtype="float32", device=None)
    np.testing.assert_almost_equal(implicit.numpy(), explicit_none.numpy())

    on_xpux = F.eye(x.shape, dtype="float32", device="xpux")
    on_x_device = F.eye(x.shape, dtype="float32", device=x.device)
    np.testing.assert_almost_equal(on_xpux.numpy(), on_x_device.numpy())

    full_on_x_device = F.full((3, 2), 4, device=x.device)
    full_on_xpux = F.full((3, 2), 4, device="xpux")
    np.testing.assert_almost_equal(full_on_x_device.numpy(), full_on_xpux.numpy())
704 705


706 707 708 709 710 711 712 713
@pytest.mark.parametrize("is_varnode", [True, False])
def test_identity(is_varnode):
    """``F.copy`` without a target device must reproduce its input's values."""
    network = Network() if is_varnode else None

    data = np.random.random((5, 10)).astype(np.float32)
    x = make_tensor(data, network)
    y = F.copy(x)
    # Compare with the host array the tensor was built from, not with the
    # tensor/VarNode object itself, so the check does not depend on how that
    # object is implicitly coerced to an ndarray.
    np.testing.assert_equal(y.numpy(), data)


718
def copy_test(dst, src, network):
    """Copy random (2, 3) float32 data created on ``src`` to ``dst`` and check it.

    ``F.copy`` is always checked. ``Tensor.to`` is additionally checked when
    ``network`` is ``None`` (eager mode).
    """
    host = np.random.random((2, 3)).astype(np.float32)
    x = make_tensor(host, device=src, network=network)

    copied = F.copy(x, dst)
    assert np.allclose(host, copied.numpy())

    if network is not None:
        return
    moved = x.to(dst)
    assert np.allclose(host, moved.numpy())
726 727


728
@pytest.mark.require_ngpu(1)
@pytest.mark.parametrize("is_varnode", [True, False])
def test_copy_h2d(is_varnode):
    """Copy data from the host (``cpu0``) to a device (``gpu0``)."""
    network = Network() if is_varnode else None
    # ``copy_test`` takes the destination first. The old positional call
    # ``copy_test("cpu0", "gpu0")`` therefore copied gpu0 -> cpu0, the reverse
    # of this test's name. Keywords make the direction explicit.
    copy_test(dst="gpu0", src="cpu0", network=network)


@pytest.mark.require_ngpu(1)
@pytest.mark.parametrize("is_varnode", [True, False])
def test_copy_d2h(is_varnode):
    """Copy data from a device (``gpu0``) back to the host (``cpu0``)."""
    network = Network() if is_varnode else None
    # Keywords keep the direction explicit (destination comes first positionally).
    copy_test(dst="cpu0", src="gpu0", network=network)
748 749


750
@pytest.mark.require_ngpu(2)
@pytest.mark.parametrize("is_varnode", [True, False])
def test_copy_d2d(is_varnode):
    """Copy between two GPUs and between two streams of the same GPU."""
    network = Network() if is_varnode else None

    for src, dst in (("gpu1", "gpu0"), ("gpu0:1", "gpu0:0")):
        copy_test(dst=dst, src=src, network=network)
760 761


762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790
@pytest.mark.require_ngpu(2)
@pytest.mark.parametrize(
    "shape, device_src, device_dst",
    [
        ((0,), "cpu0", "cpu0"),
        ((10, 0), "cpu0", "cpu1"),
        ((2, 0, 3), "cpu0", "gpu0"),
        ((1, 0, 1, 0), "gpu0", "cpu0"),
        ((2, 3, 4, 5, 0), "gpu0", "gpu1"),
    ],
)
@pytest.mark.parametrize("is_symbolic", [None, True, False])
def test_copy_empty(shape, device_src, device_dst, is_symbolic):
    """Copying a zero-size tensor across devices keeps its shape and lands on
    the requested device, eagerly and under tracing."""
    inp = tensor(np.random.randn(*shape).astype("float32"), device=device_src)

    def copy_to_dst(x):
        return F.copy(x, device_dst)

    if is_symbolic is None:
        fn, iterations = copy_to_dst, 1
    else:
        fn, iterations = trace(symbolic=is_symbolic)(copy_to_dst), 3

    for _ in range(iterations):
        out = fn(inp)
        assert out.numpy().shape == shape
        assert out.device == device_dst


791 792 793 794 795 796 797 798 799 800 801 802
@pytest.mark.parametrize(
    "shape, repeats, axis",
    [
        ((2,), 2, 0),
        ((2, 3, 4, 5), 3, 0),
        ((2, 3, 4, 5), 4, 3),
        ((2,), 2, None),
        ((2, 3, 4, 5), 3, None),
        ((), 1, None),
        ((), 10, None),
    ],
)
@pytest.mark.parametrize("is_varnode", [True, False])
def test_repeat(shape, repeats, axis, is_varnode):
    """``F.repeat`` must match ``np.repeat`` along an axis, over the flattened
    input (``axis=None``), and for 0-d inputs."""
    network = Network() if is_varnode else None

    def repeat_func(inp):
        return F.repeat(inp=inp, repeats=repeats, axis=axis)

    # The 0-d case uses a fixed scalar array instead of random data.
    if shape == ():
        data = np.array(1.23)
    else:
        data = np.random.randn(*shape).astype("float32")

    opr_test(
        [{"input": data}],
        repeat_func,
        ref_fn=lambda inp: np.repeat(inp, repeats, axis),
        network=network,
    )


@pytest.mark.parametrize(
    "shape, reps",
    [
        ((2,), (2,)),
        ((2, 3, 4, 5), (1, 1, 1, 1)),
        ((2, 3, 4, 5), (1, 2, 3, 4)),
        # FIXME: tile does not support ndim 7
        # ((2, 3, 4, 5), (2, 2, 2, 2, 2, 2, 2)),
    ],
)
@pytest.mark.parametrize("is_varnode", [True])
def test_tile(shape, reps, is_varnode):
    """``F.tile`` must match ``np.tile`` (currently exercised in varnode mode only)."""
    network = Network() if is_varnode else None

    def tile_func(inp):
        return F.tile(inp=inp, reps=reps)

    data = np.random.randn(*shape).astype("float32")
    opr_test(
        [{"input": data}],
        tile_func,
        ref_fn=lambda inp: np.tile(inp, reps),
        network=network,
    )
851 852 853 854 855 856 857


@pytest.mark.parametrize(
    "shape, shifts, axis",
    [
        ((2, 3), 0, None),
        ((2, 3), 1, 0),
        ((2, 3), 100, 0),
        ((2, 3), -100, 0),
        ((2, 3, 4, 5), (-1, 1), (0, 1)),
        ((2, 3, 4, 5), (-2, 1, 2), (1, 2, 3)),
    ],
)
@pytest.mark.parametrize("is_varnode", [True, False])
def test_roll(shape, shifts, axis, is_varnode):
    """``F.roll`` must match ``np.roll`` for ``axis=None``, shifts larger than the
    axis length, and per-axis shift tuples."""
    network = Network() if is_varnode else None

    data = np.random.randn(*shape).astype("float32")

    def roll_func(inp):
        return F.roll(inp, shifts, axis)

    opr_test(
        [{"input": data}],
        roll_func,
        ref_fn=lambda inp: np.roll(inp, shifts, axis),
        network=network,
    )
883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903


@pytest.mark.parametrize(
    "shape, shifts, axis", [((10, 0), 5, 1), ((10, 0), -10, 1),],
)
@pytest.mark.parametrize("is_symbolic", [None, True, False])
def test_roll_empty_tensor(shape, shifts, axis, is_symbolic):
    """Rolling a zero-size tensor must match ``np.roll``, eagerly and traced.

    When ``is_symbolic`` is not ``None`` the traced function is called three
    times; eager mode needs only one call.
    """
    inp = tensor(np.random.randn(*shape).astype("float32"))

    def func(inp):
        return F.roll(inp, shifts, axis)

    if is_symbolic is not None:
        func = trace(symbolic=is_symbolic)(func)

    out_ref = np.roll(inp.numpy(), shifts, axis)
    for _ in range(3):
        # Call ``func`` (possibly traced) rather than ``F.roll`` directly;
        # otherwise the ``is_symbolic`` parametrization never exercises tracing.
        out = func(inp)
        np.testing.assert_equal(out.numpy(), out_ref)
        if is_symbolic is None:
            break