提交 a398d4b5 编写于 作者: M Megvii Engine Team

chore(mge/functional): remove duplicated code

GitOrigin-RevId: f9efea46cb996b1583d8b8d038a9edf21d7ac83c
上级 477820fe
......@@ -210,7 +210,8 @@ def _todo(*_):
def _expand_args(args):
if len(args) == 1:
if isinstance(
args[0], (collections.abc.Sequence, TensorBase, TensorWrapperBase)
args[0],
(collections.abc.Sequence, TensorBase, TensorWrapperBase, np.ndarray),
):
args = args[0]
return args
......
......@@ -88,13 +88,6 @@ def _elwise(*args, mode):
return result
def _logical(*args, mode):
op = builtin.CondExecPredLogical(mode=mode)
args = utils.convert_inputs(*args)
(result,) = apply(op, *args)
return result
def _elemwise_multi_type(*args, mode, **kwargs):
op = builtin.ElemwiseMultiType(mode=mode, **kwargs)
args = utils.convert_inputs(*args)
......
......@@ -19,6 +19,7 @@ from ..core.ops import builtin
from ..core.ops._internal import param_defs as P
from ..core.ops.special import Const
from ..core.tensor.core import TensorBase, TensorWrapperBase, apply
from ..core.tensor.tensor_wrapper import _remove_axis
from ..core.tensor.utils import (
astensor1d,
convert_inputs,
......@@ -231,9 +232,7 @@ def broadcast(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor:
[3. 4. 5.]]]
"""
shape = astensor1d(shape, inp, dtype="int32", device=inp.device)
(result,) = apply(builtin.Broadcast(), inp, shape)
return result
return inp.broadcast(shape)
def concat(inps: Iterable[Tensor], axis: int = 0, device=None) -> Tensor:
......@@ -730,10 +729,7 @@ def transpose(inp: Tensor, pattern: Iterable[int]) -> Tensor:
[1 0]]
"""
op = builtin.Dimshuffle(pattern)
(inp,) = convert_inputs(inp)
(result,) = apply(op, inp)
return result
return inp.transpose(pattern)
dimshuffle = transpose
......@@ -773,26 +769,7 @@ def reshape(inp: Tensor, target_shape: Iterable[int]) -> Tensor:
[10 11]]]
"""
if isinstance(target_shape, (TensorBase, TensorWrapperBase)):
target_shape = target_shape.numpy()
target_shape = tuple(map(int, target_shape))
unspec_axis = None
for i, s in enumerate(target_shape):
if s < 0:
if s != -1:
raise ValueError("expect shape[{}] >= -1, got {}".format(i, s))
if unspec_axis is not None:
raise ValueError("multiple -1 in shape: {} & {}".format(unspec_axis, i))
unspec_axis = i
# TODO: device should be None (cpu)
(target_shape,) = Const(target_shape, dtype="int32", device=inp.device)(inp)
if unspec_axis is None:
op = builtin.Reshape()
else:
op = builtin.Reshape(unspec_axis=unspec_axis)
(x,) = apply(op, inp, target_shape)
return x
return inp.reshape(target_shape)
AxisAddRemove = builtin.AxisAddRemove
......@@ -915,25 +892,7 @@ def remove_axis(
(1, 1, 2)
"""
Param = builtin.AxisAddRemove.Param
def get_axes():
if axis is None:
return [i for i, s in enumerate(inp.shape) if s == 1]
try:
return [int(axis)]
except (TypeError, ValueError):
pass
return list(map(int, axis))
axis = get_axes()
axis = sorted(i + inp.ndim if i < 0 else i for i in axis)
axis = [a - i for i, a in enumerate(axis)]
param = Param(*map(builtin.AxisAddRemove.AxisDesc.make_remove, axis))
op = builtin.AxisAddRemove(param=param)
(result,) = apply(op, inp)
return result
return _remove_axis(inp, axis)
def linspace(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册