提交 9cb3c07c 编写于 作者: M Megvii Engine Team

feat(mge/functional): add elemwise mode support string input

GitOrigin-RevId: 57be5cec7bd00510ee98868939942d15077a440b
上级 8fad00a1
......@@ -72,7 +72,27 @@ __all__ = [
]
class _ElemwiseMode(Elemwise.Mode):
@classmethod
def __normalize(cls, val):
if isinstance(val, str):
if not hasattr(cls, "__member_upper_dict__"):
cls.__member_upper_dict__ = {
k.upper(): v for k, v in cls.__members__.items()
}
val = cls.__member_upper_dict__.get(val.upper(), val)
return val
@classmethod
def convert(cls, val):
val = cls.__normalize(val)
if isinstance(val, cls):
return val
return cls(val)
def _elwise(*args, mode):
mode = _ElemwiseMode.convert(mode)
op = builtin.Elemwise(mode)
tensor_args = list(
filter(lambda x: isinstance(x, (Tensor, megbrain_graph.VarNode)), args)
......
......@@ -73,11 +73,9 @@ class Elemwise(Module):
* "NOT": bool unary: ~x
"""
_elemwise_mode_type = P.Elemwise.Mode
def __init__(self, method):
super().__init__()
self.method = self._elemwise_mode_type.convert(method)
self.method = method
def forward(self, *inps):
return _elwise(*inps, mode=self.method)
......@@ -28,4 +28,4 @@ class Elemwise(Float.Elemwise, QATModule):
Return a :class:`~.QATModule` instance converted from
a float :class:`~.Module` instance.
"""
return cls(float_module.method.name)
return cls(float_module.method)
......@@ -33,4 +33,4 @@ class Elemwise(QuantizedModule):
Return a :class:`~.QuantizedModule` instance converted from a
:class:`~.QATModule` instance.
"""
return cls(qat_module.method.name, qat_module.get_activation_dtype())
return cls(qat_module.method, qat_module.get_activation_dtype())
......@@ -10,6 +10,7 @@ import numpy as np
import megengine.functional as F
from megengine import tensor
from megengine.functional.elemwise import _elwise
def test_abs():
......@@ -21,6 +22,17 @@ def test_abs():
np.testing.assert_allclose(F.abs(-3.0).numpy(), np.abs(np.float32(-3.0)))
def test_elemwise_mode_string():
np.testing.assert_allclose(
_elwise(tensor([-3.0, -4.0, -5.0]), mode="ABS").numpy(),
np.abs(np.array([-3.0, -4.0, -5.0], dtype=np.float32)),
)
np.testing.assert_allclose(
_elwise(-3.0, mode="ABS").numpy(), np.abs(np.float32(-3.0))
)
def test_multiply():
np.testing.assert_allclose(
F.mul(-3.0, -4.0).numpy(), np.multiply(np.float32(-3.0), np.float32(-4.0))
......
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np
import megengine.functional as F
from megengine import tensor
from megengine.module import Elemwise
def test_module_elemwise():
def test_func(method, *inps):
elemwise = Elemwise(method)
outputs = elemwise(*inps)
return outputs.numpy()
x = np.random.rand(100).astype("float32")
y = np.random.rand(100).astype("float32")
x, y = tensor(x), tensor(y)
np.testing.assert_almost_equal(
test_func("H_SWISH", x), F.hswish(x).numpy(), decimal=6
)
np.testing.assert_almost_equal(
test_func("ADD", x, y), F.add(x, y).numpy(), decimal=6
)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册