From 9cb3c07cd4f8e33d5bfef96bb792a28f90a97b99 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 12 Nov 2020 11:26:17 +0800 Subject: [PATCH] feat(mge/functional): add elemwise mode support string input GitOrigin-RevId: 57be5cec7bd00510ee98868939942d15077a440b --- .../python/megengine/functional/elemwise.py | 20 +++++++++++++ .../python/megengine/module/elemwise.py | 4 +-- .../python/megengine/module/qat/elemwise.py | 2 +- .../megengine/module/quantized/elemwise.py | 2 +- .../test/unit/functional/test_elemwise.py | 12 ++++++++ .../python/test/unit/module/test_elemwise.py | 30 +++++++++++++++++++ 6 files changed, 65 insertions(+), 5 deletions(-) create mode 100644 imperative/python/test/unit/module/test_elemwise.py diff --git a/imperative/python/megengine/functional/elemwise.py b/imperative/python/megengine/functional/elemwise.py index b9b477da1..98abebce9 100644 --- a/imperative/python/megengine/functional/elemwise.py +++ b/imperative/python/megengine/functional/elemwise.py @@ -72,7 +72,27 @@ __all__ = [ ] +class _ElemwiseMode(Elemwise.Mode): + @classmethod + def __normalize(cls, val): + if isinstance(val, str): + if not hasattr(cls, "__member_upper_dict__"): + cls.__member_upper_dict__ = { + k.upper(): v for k, v in cls.__members__.items() + } + val = cls.__member_upper_dict__.get(val.upper(), val) + return val + + @classmethod + def convert(cls, val): + val = cls.__normalize(val) + if isinstance(val, cls): + return val + return cls(val) + + def _elwise(*args, mode): + mode = _ElemwiseMode.convert(mode) op = builtin.Elemwise(mode) tensor_args = list( filter(lambda x: isinstance(x, (Tensor, megbrain_graph.VarNode)), args) diff --git a/imperative/python/megengine/module/elemwise.py b/imperative/python/megengine/module/elemwise.py index 9bc05fbfc..dfc697251 100644 --- a/imperative/python/megengine/module/elemwise.py +++ b/imperative/python/megengine/module/elemwise.py @@ -73,11 +73,9 @@ class Elemwise(Module): * "NOT": bool unary: ~x """ - _elemwise_mode_type = P.Elemwise.Mode - def __init__(self, method): super().__init__() - self.method = self._elemwise_mode_type.convert(method) + self.method = method def forward(self, *inps): return _elwise(*inps, mode=self.method) diff --git a/imperative/python/megengine/module/qat/elemwise.py b/imperative/python/megengine/module/qat/elemwise.py index f99583bde..162952819 100644 --- a/imperative/python/megengine/module/qat/elemwise.py +++ b/imperative/python/megengine/module/qat/elemwise.py @@ -28,4 +28,4 @@ class Elemwise(Float.Elemwise, QATModule): Return a :class:`~.QATModule` instance converted from a float :class:`~.Module` instance. """ - return cls(float_module.method.name) + return cls(float_module.method) diff --git a/imperative/python/megengine/module/quantized/elemwise.py b/imperative/python/megengine/module/quantized/elemwise.py index 5021be1aa..3b16f8cf3 100644 --- a/imperative/python/megengine/module/quantized/elemwise.py +++ b/imperative/python/megengine/module/quantized/elemwise.py @@ -33,4 +33,4 @@ class Elemwise(QuantizedModule): Return a :class:`~.QuantizedModule` instance converted from a :class:`~.QATModule` instance. """ - return cls(qat_module.method.name, qat_module.get_activation_dtype()) + return cls(qat_module.method, qat_module.get_activation_dtype()) diff --git a/imperative/python/test/unit/functional/test_elemwise.py b/imperative/python/test/unit/functional/test_elemwise.py index 30421dd8c..3436f145b 100644 --- a/imperative/python/test/unit/functional/test_elemwise.py +++ b/imperative/python/test/unit/functional/test_elemwise.py @@ -10,6 +10,7 @@ import numpy as np import megengine.functional as F from megengine import tensor +from megengine.functional.elemwise import _elwise def test_abs(): @@ -21,6 +22,17 @@ def test_abs(): np.testing.assert_allclose(F.abs(-3.0).numpy(), np.abs(np.float32(-3.0))) +def test_elemwise_mode_string(): + np.testing.assert_allclose( + _elwise(tensor([-3.0, -4.0, -5.0]), mode="ABS").numpy(), + np.abs(np.array([-3.0, -4.0, -5.0], dtype=np.float32)), + ) + + np.testing.assert_allclose( + _elwise(-3.0, mode="ABS").numpy(), np.abs(np.float32(-3.0)) + ) + + def test_multiply(): np.testing.assert_allclose( F.mul(-3.0, -4.0).numpy(), np.multiply(np.float32(-3.0), np.float32(-4.0)) diff --git a/imperative/python/test/unit/module/test_elemwise.py b/imperative/python/test/unit/module/test_elemwise.py new file mode 100644 index 000000000..f9b5094fd --- /dev/null +++ b/imperative/python/test/unit/module/test_elemwise.py @@ -0,0 +1,30 @@ +# -*- coding: utf-8 -*- +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +import numpy as np + +import megengine.functional as F +from megengine import tensor +from megengine.module import Elemwise + + +def test_module_elemwise(): + def test_func(method, *inps): + elemwise = Elemwise(method) + outputs = elemwise(*inps) + return outputs.numpy() + + x = np.random.rand(100).astype("float32") + y = np.random.rand(100).astype("float32") + x, y = tensor(x), tensor(y) + np.testing.assert_almost_equal( + test_func("H_SWISH", x), F.hswish(x).numpy(), decimal=6 + ) + np.testing.assert_almost_equal( + test_func("ADD", x, y), F.add(x, y).numpy(), decimal=6 + ) -- GitLab