提交 8d2bbf73 编写于 作者: M Megvii Engine Team

fix(mge/functional): support scalar inputs in elemwise functions

GitOrigin-RevId: 7bce561ee1a48bfba744bde2c88939a0a3ffeabd
上级 ea00b57d
......@@ -11,6 +11,7 @@ import functools
import megengine._internal as mgb
from ..core.graph import _use_default_if_none
from ..core.tensor import Tensor, wrap_io_tensor
__all__ = [
......@@ -45,11 +46,17 @@ __all__ = [
def _elemwise(mode): # DONT export
"""Decorator helps to wrap megbrain element-wise oprs"""
def elemwise_decorator(func):
@functools.wraps(func)
@wrap_io_tensor
def elemwise_func(*inputs) -> Tensor:
if all(isinstance(i, (int,float)) for i in inputs):
device, comp_graph = _use_default_if_none(None, None)
ret = mgb.opr.elemwise(*inputs,
mode=mode,
comp_node=device,
comp_graph=comp_graph)
return ret.inferred_value[0]
return mgb.opr.elemwise(*inputs, mode=mode)
return elemwise_func
......
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np
import megengine.functional as F
from megengine import tensor
from megengine.test import assertTensorClose
def test_abs():
assertTensorClose(
F.abs(tensor([-3., -4., -5.])).numpy(),
np.abs(np.array([-3., -4., -5.], dtype=np.float32)))
assertTensorClose(F.abs(-3.), np.abs(np.float32(-3.)))
def test_multiply():
assertTensorClose(F.multiply(-3., -4.),
np.multiply(np.float32(-3.), np.float32(-4.)))
assertTensorClose(
F.multiply(tensor([3., 4.]), 4.).numpy(),
np.multiply(np.array([3., 4.], dtype=np.float32), 4.))
assertTensorClose(
F.multiply(4., tensor([3., 4.])).numpy(),
np.multiply(4., np.array([3., 4.], dtype=np.float32)))
assertTensorClose(
F.multiply(tensor([3., 4.]), tensor([3., 4.])).numpy(),
np.multiply(np.array([3., 4.], dtype=np.float32),
np.array([3., 4.], dtype=np.float32)))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册