# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np

import megengine.functional as F
from megengine import tensor
from megengine.module import Elemwise


def test_module_elemwise():
    """Check that the ``Elemwise`` module agrees with the functional API.

    For each mode under test, ``Elemwise(method)(*inputs)`` is compared
    element-wise (to 6 decimal places) against the matching
    ``megengine.functional`` call on the same tensors.
    """

    def run_module(method, *inps):
        # Build a fresh Elemwise module for `method`, apply it to the inputs,
        # and return the result as a numpy array for comparison.
        elemwise = Elemwise(method)
        outputs = elemwise(*inps)
        return outputs.numpy()

    # Local, fixed-seed generator: failures are reproducible and the global
    # numpy RNG state is left untouched.
    rng = np.random.RandomState(0)
    # Sample from [-5, 5) instead of [0, 1) so h_swish (x * relu6(x + 3) / 6)
    # is exercised on all of its piecewise segments (x <= -3, -3 < x < 3,
    # x >= 3), not only the middle one.
    x = rng.uniform(-5.0, 5.0, size=100).astype("float32")
    y = rng.uniform(-5.0, 5.0, size=100).astype("float32")
    x, y = tensor(x), tensor(y)

    np.testing.assert_almost_equal(
        run_module("h_swish", x), F.hswish(x).numpy(), decimal=6
    )
    np.testing.assert_almost_equal(
        run_module("add", x, y), F.add(x, y).numpy(), decimal=6
    )