# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import copy

import numpy as np
import pytest

import megengine as mge
import megengine.functional as F
from megengine import Parameter, Tensor
from megengine.module import Conv2d


# TODO: delete this test after deleting set_value
def test_set_value():
    """Check that ``Parameter.set_value`` replaces the stored data.

    A fresh (2, 3) float32 array is written into a Parameter built from a
    different (2, 3) array, and the Parameter's contents must then match the
    new array.
    """
    v0 = np.random.random((2, 3)).astype(np.float32)
    param = Parameter(v0)
    v1 = np.random.random((2, 3)).astype(np.float32)
    param.set_value(v1)
    np.testing.assert_allclose(param.numpy(), v1, atol=5e-6)

    # TODO: enable this check. The intent is that a shape-mismatched
    # set_value raises ValueError and leaves the parameter holding v1.
    # v2 = np.random.random((3, 3)).astype(np.float32)
    # with pytest.raises(ValueError):
    #     param.set_value(v2)
    # np.testing.assert_allclose(param.numpy(), v1, atol=5e-6)


@pytest.mark.skip(reason="fill unsupported")
def test_fill():
    """Check that ``Tensor.fill`` sets every element to the given scalar.

    Skipped because ``fill`` is marked unsupported.
    """
    a = Tensor(np.zeros((2, 3), dtype=np.float32))

    a.fill(3)
    np.testing.assert_allclose(a.numpy(), np.full((2, 3), 3, dtype=np.float32))

    # Non-integral value: the expected array is built in float32 so both sides
    # carry the same rounding of 124.568.
    a.fill(124.568)
    np.testing.assert_allclose(a.numpy(), np.full((2, 3), 124.568, dtype=np.float32))


# TODO: remove or rewrite following test
# def test_attach():
#     p_ = np.random.random((2, 3)).astype(np.float32)

#     with Graph() as g:
#         g.set_option('eager_evaluation', False)
#         p = Parameter(p_)
#         v = p * 2
#         f = compile(v, None)

#     out, = f()
#     np.testing.assert_allclose(out, p_ * 2)

#     F.add_update(p, p)
#     out, = f()
#     np.testing.assert_allclose(out, p_ * 4)


# TODO: remove or rewrite following test
# def test_module_attach():
#     v = np.random.random((1, 3, 64, 64)).astype(np.float32)
#     net = Conv2d(3, 16, 3)

#     with Graph() as g:
#         g.set_option('eager_evaluation', False)

#         data0 = Input("data")
#         f = compile(net(data0), None)

#     out0, = f(data=v)

#     data1 = Input("data", value=v)
#     out1 = net(data1)

#     np.testing.assert_allclose(out0, out1.numpy())


# NOTE: pytest.warns(None) is deprecated since pytest 7; if this test is
# revived, record warnings with warnings.catch_warnings(record=True) instead.
# def test_shape_warning():
#     with Graph() as cg:
#         cg.set_option("eager_evaluation", False)
#         b = Tensor(np.ones((2, 3)).astype(np.float32))
#         with pytest.warns(None) as record:
#             print(b.shape)
#         if len(record) != 0:
#             raise ValueError(
#                 "Getting the shape of a constant Tensor should throw no Warning"
#             )