test_functional.py 14.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import itertools

import numpy as np
import pytest
13
from utils import opr_test
14

15
import megengine.core.ops.builtin as builtin
16 17
import megengine.core.tensor.dtype as dtype
import megengine.functional as F
M
Megvii Engine Team 已提交
18
from megengine import Parameter, Tensor, is_cuda_available, tensor
19
from megengine.core._trace_option import use_tensor_shape
20
from megengine.core.autodiff.grad import Grad
21
from megengine.core.tensor.utils import make_shape_tuple
22 23 24
from megengine.test import assertTensorClose


25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52
def test_where():
    """Check ``F.where`` against ``np.where`` on 2-D and 1-D inputs."""
    # 2-D inputs; the "x" operands deliberately contain inf and nan entries.
    mask_2x2 = np.array([[1, 0], [0, 1]], dtype=np.bool_)
    x_2x2 = np.array([[1, np.inf], [np.nan, 4]], dtype=np.float32)
    y_2x2 = np.array([[5, 6], [7, 8]], dtype=np.float32)

    mask_3x3 = np.array([[1, 0, 1], [1, 0, 0], [1, 1, 0]], dtype=np.bool_)
    x_3x3 = np.array([[1, np.inf, 2], [0, np.nan, 4], [1, 5, 7]], dtype=np.float32)
    y_3x3 = np.array([[5, 6, 9], [2, 7, 8], [2, 1, 9]], dtype=np.float32)

    opr_test(
        [{"input": [mask_2x2, x_2x2, y_2x2]}, {"input": [mask_3x3, x_3x3, y_3x3]}],
        F.where,
        ref_fn=np.where,
    )

    # 1-D inputs with an all-true mask and an all-false mask.
    all_true = np.array([1, 1, 1], dtype=np.bool_)
    x_for_true = np.array([1, 3, 2], dtype=np.float32)
    y_for_true = np.array([5, 6, 9], dtype=np.float32)

    all_false = np.array([0, 0, 0], dtype=np.bool_)
    x_for_false = np.array([1, 3, 2], dtype=np.float32)
    y_for_false = np.array([5, 6, 9], dtype=np.float32)

    opr_test(
        [
            {"input": [all_true, x_for_true, y_for_true]},
            {"input": [all_false, x_for_false, y_for_false]},
        ],
        F.where,
        ref_fn=np.where,
    )
53 54


55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100
def test_dropout():
    """Run ``F.dropout`` with ``training=False`` and check the output sum is non-negative."""
    ones = np.ones(10, dtype=np.float32)
    result = F.dropout(tensor(ones), 1.0 / 3.0, training=False)

    total = result.numpy().sum()
    assert total >= 0.0


def test_matmul():
    """Check ``F.matmul`` against ``np.matmul``.

    The first group covers vector and plain matrix operands. The second group
    covers batched operands (including a broadcast over leading dimensions)
    and verifies the result one leading-batch slice at a time.
    """
    shape1 = 3
    shape2 = 3
    shape3 = (3, 5)
    shape4 = (5, 6)
    data1 = np.random.random(shape1).astype("float32")
    data2 = np.random.random(shape2).astype("float32")
    data3 = np.random.random(shape3).astype("float32")
    data4 = np.random.random(shape4).astype("float32")

    cases = [
        {"input": [data1, data2]},
        {"input": [data2, data3]},
        {"input": [data3, data4]},
    ]
    opr_test(cases, F.matmul, ref_fn=np.matmul)

    batch_size = 10
    shape1 = (batch_size, 2, 3)
    shape2 = (batch_size, 3, 4)
    shape3 = (batch_size, 10, 4, 5)
    data1 = np.random.random(shape1).astype("float32")
    data2 = np.random.random(shape2).astype("float32")
    data3 = np.random.random(shape3).astype("float32")

    cases = [{"input": [data1, data2]}, {"input": [data2, data3]}]
    for i in range(batch_size):

        def compare_fn(x, y):
            # Actually assert: a bare ``==`` expression would compute the
            # comparison and silently discard it.
            assertTensorClose(x.numpy()[i, ...], y)

        opr_test(
            cases,
            F.matmul,
            compare_fn=compare_fn,
            # Slice the full NumPy result rather than multiplying x[i] by
            # y[i]: when the operands have different ranks the leading
            # dimensions are broadcast, so slice i of the product is not
            # x[i] @ y[i] in general.
            ref_fn=lambda x, y: np.matmul(x, y)[i, ...],
        )


101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182
def test_interpolate():
    """Exercise ``F.interpolate`` with linear and default modes, plus two error cases."""

    def check_linear():
        # A scale factor of 2 and an explicit size of 4 must both produce the
        # same expected values.
        source = tensor(np.arange(1, 3, dtype=np.float32).reshape(1, 1, 2))

        by_scale = F.interpolate(source, scale_factor=2.0, mode="LINEAR")
        by_size = F.interpolate(source, 4, mode="LINEAR")

        expected = np.array([[[1.0, 1.25, 1.75, 2.0]]], dtype=np.float32)
        assertTensorClose(by_scale.numpy(), expected)
        assertTensorClose(by_size.numpy(), expected)

    def check_many_batch():
        # Two samples in the batch: explicit size and scale factor must agree.
        source = tensor(np.arange(1, 9, dtype=np.float32).reshape(2, 1, 2, 2))

        by_size = F.interpolate(source, [4, 4])
        by_scale = F.interpolate(source, scale_factor=2.0)

        assertTensorClose(by_size.numpy(), by_scale.numpy())

    def check_align_corners():
        # Same agreement check with ``align_corners=True``.
        source = tensor(np.arange(1, 5, dtype=np.float32).reshape(1, 1, 2, 2))

        by_size = F.interpolate(source, [4, 4], align_corners=True)
        by_scale = F.interpolate(source, scale_factor=2.0, align_corners=True)

        assertTensorClose(by_size.numpy(), by_scale.numpy())

    def check_linear_rejects_4d_input():
        # "LINEAR" mode on a 4-D input is expected to raise ValueError.
        source = tensor(np.arange(1, 5, dtype=np.float32).reshape(1, 1, 2, 2))

        with pytest.raises(ValueError):
            F.interpolate(source, scale_factor=2.0, mode="LINEAR")

    def check_linear_rejects_list_scale():
        # "LINEAR" mode with a two-element scale factor is expected to raise ValueError.
        source = tensor(np.arange(1, 3, dtype=np.float32).reshape(1, 1, 2))

        with pytest.raises(ValueError):
            F.interpolate(source, scale_factor=[2.0, 3.0], mode="LINEAR")

    check_linear()
    check_many_batch()
    check_align_corners()
    check_linear_rejects_4d_input()
    check_linear_rejects_list_scale()


def _save_to(self, name="grad"):
    def callback(tensor, grad):
        setattr(self, name, grad)

    return callback


def _gen_roi_inp():
    """Build a random feature map and four random ROI rows, both wrapped as tensors.

    The feature array has shape (2, 32, 256, 256). Each of the four ROI rows
    has five columns: column 0 is taken from ``[0, 0, 1, 1]`` (presumably the
    batch index of each box -- confirm against the ROI ops), columns 1-2 are
    drawn from [0, 100) and columns 3-4 from [150, 250).
    """
    features = np.random.randn(2, 32, 256, 256)

    boxes = np.zeros((4, 5))
    boxes[:, 0] = [0, 0, 1, 1]
    boxes[:, 1:3] = np.random.rand(4, 2) * 100
    boxes[:, 3:] = np.random.rand(4, 2) * 100 + 150

    return tensor(features), tensor(boxes)


def test_roi_align():
    """Check the output shape of ``F.roi_align`` and the shape of the input gradient."""
    features, boxes = _gen_roi_inp()
    # The callback stores the gradient on the input tensor as ``features.grad``.
    grad = Grad().wrt(features, callback=_save_to(features))

    pooled_size = (7, 7)
    pooled = F.roi_align(
        features,
        boxes,
        output_shape=pooled_size,
        mode="average",
        spatial_scale=1.0 / 4,
        sample_points=2,
        aligned=True,
    )
    # Expected layout: one entry per ROI, the input's channel count, then the pooled size.
    expected_shape = (boxes.shape[0], features.shape[1], *pooled_size)
    assert make_shape_tuple(pooled.shape) == expected_shape

    grad(pooled, tensor(F.ones_like(pooled)))
    assert make_shape_tuple(features.grad.shape) == make_shape_tuple(features.shape)
191 192 193 194 195 196 197 198 199


def test_roi_pooling():
    """Check the output shape of ``F.roi_pooling`` and the shape of the input gradient."""
    features, boxes = _gen_roi_inp()
    # The callback stores the gradient on the input tensor as ``features.grad``.
    grad = Grad().wrt(features, callback=_save_to(features))

    pooled_size = (7, 7)
    pooled = F.roi_pooling(
        features, boxes, output_shape=pooled_size, mode="max", scale=1.0 / 4,
    )
    # Expected layout: one entry per ROI, the input's channel count, then the pooled size.
    expected_shape = (boxes.shape[0], features.shape[1], *pooled_size)
    assert make_shape_tuple(pooled.shape) == expected_shape

    grad(pooled, tensor(F.ones_like(pooled)))
    assert make_shape_tuple(features.grad.shape) == make_shape_tuple(features.shape)
208 209


210 211 212 213
def test_one_hot():
    """Check ``F.one_hot`` against rows of an identity matrix for 1-D and 2-D indices."""

    def check_1d_indices():
        indices = np.arange(1, 4, dtype=np.int32)
        encoded = F.one_hot(tensor(np.arange(1, 4, dtype=np.int32)), num_classes=4)

        # Indexing an identity matrix by the labels gives the reference encoding.
        assertTensorClose(encoded.numpy(), np.eye(4, dtype=np.int32)[indices])

    def check_2d_indices():
        indices = np.array(
            [[3, 2, 4, 4, 2, 4, 0, 4, 4, 1], [4, 1, 1, 3, 2, 2, 4, 2, 4, 3]],
            dtype=np.int32,
        )

        encoded = F.one_hot(tensor(indices), 10)

        assertTensorClose(encoded.numpy(), np.eye(10, dtype=np.int32)[indices])

    check_1d_indices()
    check_2d_indices()
232 233 234 235 236


def test_add_update():
    """Check ``F.add_update`` with a scalar delta and with explicit alpha/beta/bias."""
    initial = np.random.random((2, 3)).astype(np.float32)
    target = Tensor(initial)

    # Two successive updates by 1 are expected to accumulate on the same tensor.
    for step in (1, 2):
        updated = F.add_update(target, 1)
        assertTensorClose(updated.numpy(), initial + step)

    # Weighted form: the expected value is dest * alpha + delta * beta + bias.
    dest_values = np.ones((2, 2), dtype=np.float32)
    delta_values = dest_values * 0.5
    weighted = F.add_update(
        tensor(dest_values), tensor(delta_values), alpha=0.9, beta=0.1, bias=0.1
    )
    assertTensorClose(weighted.numpy(), dest_values * 0.9 + delta_values * 0.1 + 0.1)


def test_add_update_params():
    """Check ``F.add_update`` results when the destination is captured by a closure."""
    base = np.random.random((2, 3)).astype(np.float32)
    y = Tensor(base)

    # NOTE: the original test has a commented-out ``jit.trace`` decorator here;
    # the helper runs without tracing.
    def add_to_y(x):
        return F.add_update(y, x)

    # Adding zeros first.
    add_to_y(np.zeros((2, 3)).astype(np.float32))

    # A direct update with a zero delta and a non-default ``beta``.
    zero_delta = Tensor(np.zeros((2, 3)).astype(np.float32))
    F.add_update(y, zero_delta, beta=0.1)

    # Adding ones through the helper must yield the original values plus one.
    outcome = add_to_y(np.ones((2, 3)).astype(np.float32))
    assertTensorClose(outcome.numpy(), base + 1)


def test_binary_cross_entropy():
    """Check ``F.binary_cross_entropy`` against precomputed values for two seeded inputs."""

    def sigmoid(x):
        return 1 / (1 + np.exp(-x))

    def compare_fn(x, y):
        # The expected values carry four decimals, hence the loose tolerance.
        assertTensorClose(x.numpy(), y, max_err=5e-4)

    def seeded_case(shape, expected_loss):
        # Re-seed before each case so the hard-coded expectation stays valid;
        # data is drawn first, then the label.
        np.random.seed(123)
        data = sigmoid(np.random.uniform(size=shape).astype(np.float32))
        label = np.random.uniform(size=shape).astype(np.float32)
        expect = np.array([expected_loss], dtype=np.float32)
        return {"input": [data, label], "output": expect}

    cases = [
        seeded_case((2, 2), 0.6361),
        seeded_case((2, 3), 0.6750),
    ]
    opr_test(cases, F.binary_cross_entropy, compare_fn=compare_fn)


def test_hinge_loss():
    """Check ``F.hinge_loss`` (default norm and ``"L2"``) against a NumPy reference.

    The reference is ``max(0, 1 - data * label)``, optionally squared, summed
    over axis 1 and averaged over the batch. Labels are drawn from {-1, +1}.
    """
    np.random.seed(123)

    def make_case(shape, squared):
        data = np.random.uniform(size=shape).astype(np.float32)
        # randint's upper bound is exclusive, so (0, 2) is required to get
        # both 0 and 1, i.e. labels of both -1 and +1.
        label = 2 * np.random.randint(0, 2, size=shape).astype(np.float32) - 1
        # np.clip takes (array, min, max): the margin is what gets clipped.
        margin = np.clip(1 - data * label, 0, np.inf)
        if squared:
            margin = margin ** 2
        expect = margin.sum(axis=1).mean()
        return {"input": [data, label], "output": expect}

    # case with L1 norm
    cases = [make_case(shape, squared=False) for shape in [(2, 2), (2, 3)]]
    opr_test(cases, F.hinge_loss)

    # cases with L2 norm
    cases = [make_case(shape, squared=True) for shape in [(2, 2), (2, 3)]]

    def hinge_loss_with_l2_norm(pred, label):
        return F.hinge_loss(pred, label, "L2")

    opr_test(cases, hinge_loss_with_l2_norm)


def test_nms():
    """Check the indices returned by ``F.nms`` for four boxes at an IoU threshold of 0.5."""
    boxes = np.array(
        [
            [0, 0, 100, 100],
            [10, 10, 100, 100],
            [50, 50, 100, 100],
            [100, 100, 150, 150],
        ],
        dtype=np.float32,
    )
    box_scores = tensor([0.5, 0.8, 0.9, 0.6], dtype=np.float32)

    kept = F.nms(tensor(boxes), scores=box_scores, iou_thresh=0.5)

    expected = np.array([2, 1, 3], dtype=np.int32)
    np.testing.assert_equal(kept.numpy(), expected)


def test_batched_nms():
    """Check the indices returned by ``F.batched_nms`` for six boxes in two groups."""
    boxes = np.array(
        [
            [0, 0, 100, 100],
            [0.5, 0.5, 1.5, 1.5],
            [20, 20, 100, 100],
            [0.5, 0.5, 1.0, 1.0],
            [10, 10, 100, 100],
            [0.5, 0.5, 1.0, 1.0],
        ],
        dtype=np.float32,
    )
    inp = tensor(boxes)
    box_scores = tensor([0.6, 0.9, 0.5, 0.6, 0.8, 0.7], dtype=np.float32)
    # Group id per box, passed as ``idxs``.
    group_ids = tensor([0, 1, 0, 1, 0, 1], dtype=np.int32)

    kept = F.batched_nms(inp, scores=box_scores, idxs=group_ids, iou_thresh=0.5)

    expected = np.array([1, 4, 5], dtype=np.int32)
    np.testing.assert_equal(kept.numpy(), expected)


359
@pytest.mark.skip(reason="cuda does not support nchw int8")
def test_conv_bias():
    """Compare quantized ``F.conv_bias_activation`` with a float32 ``F.conv2d`` reference.

    The reference result is cast to the quantized output dtype and back to
    float32 before comparison; the tolerance is one output quantization step
    (``max_err=outp_scale``).
    """
    inp_scale = 1.5
    w_scale = 2.5
    outp_scale = 1.5
    inp_dtype = dtype.qint8(inp_scale)
    w_dtype = dtype.qint8(w_scale)
    b_dtype = dtype.qint32(inp_scale * w_scale)
    out_dtype = dtype.qint8(outp_scale)

    def run(
        N,
        IC,
        OC,
        IH,
        IW,
        KH,
        KW,
        PH,
        PW,
        SH,
        SW,
        has_bias=True,
        nonlinear_mode="IDENTITY",
    ):
        """Run one configuration.

        Positional arguments are batch size, input/output channels, input
        height/width, kernel height/width, padding height/width and stride
        height/width, in that order.
        """
        inp_v = np.random.normal(size=(N, IC, IH, IW))
        # The kernel must honour both KH and KW; using KW twice would silently
        # ignore the requested kernel height for non-square kernels.
        w_v = np.random.normal(size=(OC, IC, KH, KW))
        b_v = np.random.normal(size=(1, OC, 1, 1))
        inp_scale = dtype.get_scale(inp_dtype)
        w_scale = dtype.get_scale(w_dtype)
        b_scale = dtype.get_scale(b_dtype)

        inpv = dtype.convert_to_qint8(inp_v * inp_scale, inp_dtype)
        wv = dtype.convert_to_qint8(w_v * w_scale, w_dtype)
        bv = dtype.convert_to_qint32(b_v * b_scale, b_dtype)

        inp_int8 = tensor(inpv, dtype=inp_dtype)
        w_int8 = Parameter(wv, dtype=w_dtype)
        b_int32 = Parameter(bv, dtype=b_dtype)

        inp_fp32 = inp_int8.astype("float32")
        w_fp32 = w_int8.astype("float32")
        b_fp32 = b_int32.astype("float32")

        def convert_to_nchw4(var):
            # Split the channel axis into groups of 4 and move the group of 4
            # to the last axis.
            var = F.reshape(
                var, (var.shape[0], var.shape[1] // 4, 4, var.shape[2], var.shape[3])
            )
            var = F.transpose(var, (0, 1, 3, 4, 2))
            return var

        def run_conv2d(inp, w, b):
            # Float reference path: plain convolution plus optional ReLU.
            O = F.conv2d(
                inp, w, b if has_bias else None, stride=(SH, SW), padding=(PH, PW),
            )
            if nonlinear_mode == "RELU":
                return F.relu(O)
            return O

        def run_conv_bias(inp, w, b, layout="NCHW"):
            # Without a bias, an all-zero bias parameter is passed instead.
            b = b if has_bias else Parameter(np.zeros_like(b.numpy()))
            if layout == "NCHW4":
                inp = convert_to_nchw4(inp)
                w = convert_to_nchw4(w)
                b = convert_to_nchw4(b)
            return F.conv_bias_activation(
                inp,
                w,
                b,
                stride=(SH, SW),
                padding=(PH, PW),
                format=layout,
                dtype=out_dtype,
                nonlinear_mode=nonlinear_mode,
            )

        layout = "NCHW4" if is_cuda_available() else "NCHW"

        expected = run_conv2d(inp_fp32, w_fp32, b_fp32)
        expected = expected.astype(out_dtype).astype("float32")
        result = run_conv_bias(inp_int8, w_int8, b_int32, layout=layout).astype(
            "float32"
        )
        if layout == "NCHW4":
            # Move the group-of-4 axis back next to the channel-group axis so
            # the flattened order matches the NCHW reference.
            result = F.transpose(result, (0, 1, 4, 2, 3))
        expected = F.flatten(expected)
        result = F.flatten(result)
        assertTensorClose(result.numpy(), expected.numpy(), max_err=outp_scale)

    run(1, 4, 4, 24, 33, 1, 1, 2, 3, 1, 1, False)
    run(10, 12, 24, 46, 46, 1, 1, 2, 1, 3, 1, False)
    run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, False)

    run(1, 4, 4, 24, 33, 1, 1, 2, 3, 1, 1)
    run(10, 12, 24, 46, 46, 1, 1, 2, 1, 3, 1)
    run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2)

    run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, False, "RELU")
    run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, True, "RELU")


461 462 463 464 465 466 467 468 469
def test_zero_stride_numpy_array():
    """Smoke test: run ``F.conv2d`` on an input built from an ``np.newaxis`` view.

    There is no assertion on the result; the test passes if nothing raises.
    """
    image = np.random.randn(3, 224, 224).astype(np.float32)
    # Adding the leading axis by indexing yields a view rather than a fresh array.
    batched = image[np.newaxis, :]

    inp = tensor(batched, dtype=np.float32)
    weight = tensor(np.random.randn(16, 3, 3, 3), dtype=np.float32)
    result = F.conv2d(inp, weight, None, (2, 2), (3, 3), (1, 1), 1)


470 471 472 473 474 475 476 477
def test_condtake():
    """Check ``F.cond_take`` values and indices against NumPy boolean indexing."""
    data = np.array([[1, 2, 3], [4, 5, 6]])
    mask = np.array([[True, False, True], [False, True, True]])

    values, indices = F.cond_take(tensor(mask), tensor(data))

    # Values match boolean indexing; indices refer to the flattened mask.
    np.testing.assert_equal(values.numpy(), data[mask])
    np.testing.assert_equal(indices.numpy(), np.where(mask.reshape(-1))[0])
478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494


def test_condtake_is_same():
    """Two independently constructed ``CondTake`` ops must compare equal."""
    first = builtin.CondTake()
    second = builtin.CondTake()
    assert first == second


def test_nms_is_same():
    """``NMSKeep`` ops compare equal only when both constructor arguments match."""
    baseline = builtin.NMSKeep(0.7, 100)
    same_args = builtin.NMSKeep(0.7, 100)
    other_first_arg = builtin.NMSKeep(0.8, 100)
    other_second_arg = builtin.NMSKeep(0.7, 200)

    assert baseline == same_args
    assert baseline != other_first_arg
    assert baseline != other_second_arg
    assert other_first_arg != other_second_arg