提交 a398d4b5 编写于 作者: M Megvii Engine Team

chore(mge/functional): remove duplicated code

GitOrigin-RevId: f9efea46cb996b1583d8b8d038a9edf21d7ac83c
上级 477820fe
...@@ -210,7 +210,8 @@ def _todo(*_): ...@@ -210,7 +210,8 @@ def _todo(*_):
def _expand_args(args): def _expand_args(args):
if len(args) == 1: if len(args) == 1:
if isinstance( if isinstance(
args[0], (collections.abc.Sequence, TensorBase, TensorWrapperBase) args[0],
(collections.abc.Sequence, TensorBase, TensorWrapperBase, np.ndarray),
): ):
args = args[0] args = args[0]
return args return args
......
...@@ -88,13 +88,6 @@ def _elwise(*args, mode): ...@@ -88,13 +88,6 @@ def _elwise(*args, mode):
return result return result
def _logical(*args, mode):
op = builtin.CondExecPredLogical(mode=mode)
args = utils.convert_inputs(*args)
(result,) = apply(op, *args)
return result
def _elemwise_multi_type(*args, mode, **kwargs): def _elemwise_multi_type(*args, mode, **kwargs):
op = builtin.ElemwiseMultiType(mode=mode, **kwargs) op = builtin.ElemwiseMultiType(mode=mode, **kwargs)
args = utils.convert_inputs(*args) args = utils.convert_inputs(*args)
......
...@@ -19,6 +19,7 @@ from ..core.ops import builtin ...@@ -19,6 +19,7 @@ from ..core.ops import builtin
from ..core.ops._internal import param_defs as P from ..core.ops._internal import param_defs as P
from ..core.ops.special import Const from ..core.ops.special import Const
from ..core.tensor.core import TensorBase, TensorWrapperBase, apply from ..core.tensor.core import TensorBase, TensorWrapperBase, apply
from ..core.tensor.tensor_wrapper import _remove_axis
from ..core.tensor.utils import ( from ..core.tensor.utils import (
astensor1d, astensor1d,
convert_inputs, convert_inputs,
...@@ -231,9 +232,7 @@ def broadcast(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor: ...@@ -231,9 +232,7 @@ def broadcast(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor:
[3. 4. 5.]]] [3. 4. 5.]]]
""" """
shape = astensor1d(shape, inp, dtype="int32", device=inp.device) return inp.broadcast(shape)
(result,) = apply(builtin.Broadcast(), inp, shape)
return result
def concat(inps: Iterable[Tensor], axis: int = 0, device=None) -> Tensor: def concat(inps: Iterable[Tensor], axis: int = 0, device=None) -> Tensor:
...@@ -730,10 +729,7 @@ def transpose(inp: Tensor, pattern: Iterable[int]) -> Tensor: ...@@ -730,10 +729,7 @@ def transpose(inp: Tensor, pattern: Iterable[int]) -> Tensor:
[1 0]] [1 0]]
""" """
op = builtin.Dimshuffle(pattern) return inp.transpose(pattern)
(inp,) = convert_inputs(inp)
(result,) = apply(op, inp)
return result
dimshuffle = transpose dimshuffle = transpose
...@@ -773,26 +769,7 @@ def reshape(inp: Tensor, target_shape: Iterable[int]) -> Tensor: ...@@ -773,26 +769,7 @@ def reshape(inp: Tensor, target_shape: Iterable[int]) -> Tensor:
[10 11]]] [10 11]]]
""" """
if isinstance(target_shape, (TensorBase, TensorWrapperBase)): return inp.reshape(target_shape)
target_shape = target_shape.numpy()
target_shape = tuple(map(int, target_shape))
unspec_axis = None
for i, s in enumerate(target_shape):
if s < 0:
if s != -1:
raise ValueError("expect shape[{}] >= -1, got {}".format(i, s))
if unspec_axis is not None:
raise ValueError("multiple -1 in shape: {} & {}".format(unspec_axis, i))
unspec_axis = i
# TODO: device should be None (cpu)
(target_shape,) = Const(target_shape, dtype="int32", device=inp.device)(inp)
if unspec_axis is None:
op = builtin.Reshape()
else:
op = builtin.Reshape(unspec_axis=unspec_axis)
(x,) = apply(op, inp, target_shape)
return x
AxisAddRemove = builtin.AxisAddRemove AxisAddRemove = builtin.AxisAddRemove
...@@ -915,25 +892,7 @@ def remove_axis( ...@@ -915,25 +892,7 @@ def remove_axis(
(1, 1, 2) (1, 1, 2)
""" """
Param = builtin.AxisAddRemove.Param return _remove_axis(inp, axis)
def get_axes():
if axis is None:
return [i for i, s in enumerate(inp.shape) if s == 1]
try:
return [int(axis)]
except (TypeError, ValueError):
pass
return list(map(int, axis))
axis = get_axes()
axis = sorted(i + inp.ndim if i < 0 else i for i in axis)
axis = [a - i for i, a in enumerate(axis)]
param = Param(*map(builtin.AxisAddRemove.AxisDesc.make_remove, axis))
op = builtin.AxisAddRemove(param=param)
(result,) = apply(op, inp)
return result
def linspace( def linspace(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册