提交 7fadc16d 编写于 作者: M Megvii Engine Team

refactor(mge/functional): support tensor shape in interpolate and split

GitOrigin-RevId: 6430b64f010ea5d0ecb1caa59b6da0a1547552ae
上级 968f74ce
......@@ -31,7 +31,9 @@ def dtype_promotion(raw_inputs):
]
inputs = [i for i in raw_inputs if hasattr(i, "dtype")]
assert len(scalar_inputs + inputs) > 0
dtype = np.result_type(*inputs)
dtype = None
if len(inputs) > 0:
dtype = np.result_type(*inputs)
dtype_all = np.result_type(*(inputs + scalar_inputs))
assert (
dtype != np.float64 and dtype != np.int64
......
......@@ -10,8 +10,9 @@
import functools
from ..core.ops import builtin
from ..core.tensor import utils
from ..core.tensor import megbrain_graph, utils
from ..core.tensor.core import apply
from ..device import get_default_device
from ..tensor import Tensor
__all__ = [
......@@ -76,11 +77,17 @@ __all__ = [
def _elwise(*args, mode):
op = builtin.Elemwise(mode=mode)
tensor_args = list(
filter(lambda x: isinstance(x, (Tensor, megbrain_graph.VarNode)), args)
)
if len(tensor_args) == 0:
dtype = utils.dtype_promotion(args)
first_arg = Tensor(args[0], dtype=dtype, device=get_default_device())
args = utils.convert_inputs(first_arg, *args[1:])
else:
args = utils.convert_inputs(*args)
if mode in ("true_div", "exp", "pow", "log", "expm1", "log1p"):
args = tuple(
map(lambda x: x.astype("float32") if hasattr(x, "dtype") else x, args)
)
args = utils.convert_inputs(*args)
args = tuple(map(lambda x: x.astype("float32"), args))
(result,) = apply(op, *args)
return result
......
......@@ -1126,11 +1126,8 @@ def interpolate(
if mode == "LINEAR":
inp = add_axis(inp, 3)
if not isinstance(inp.shape, inp.__class__):
if len(inp.shape) != 4:
raise ValueError(
"shape of input tensor must correspond to the operartion mode"
)
if inp.ndim != 4:
raise ValueError("shape of input tensor must correspond to the operartion mode")
if size is None:
if scale_factor is None:
......
......@@ -317,7 +317,7 @@ def split(inp, nsplits_or_sections, axis=0):
def swapaxis(inp, src, dst):
if src == dst:
return inp
shape = [i for i in range(len(inp.shape))]
shape = [i for i in range(inp.ndim)]
shape[src] = dst
shape[dst] = src
return inp.transpose(shape)
......@@ -325,9 +325,11 @@ def split(inp, nsplits_or_sections, axis=0):
inp = swapaxis(inp, 0, axis)
if isinstance(nsplits_or_sections, int):
incr_step = math.ceil(inp.shape[0] / nsplits_or_sections)
while incr_step < inp.shape[0]:
sections.append(incr_step)
incr_step = ceil(inp.shape[0] / nsplits_or_sections)
nsplits = nsplits_or_sections
while nsplits > 0:
nsplits -= 1
sections.append(incr_step.astype("int32"))
incr_step += nsplits_or_sections
else:
sections = nsplits_or_sections
......
......@@ -19,13 +19,13 @@ def test_abs():
np.abs(np.array([-3.0, -4.0, -5.0], dtype=np.float32)),
)
# assertTensorClose(F.abs(-3.0), np.abs(np.float32(-3.0)))
assertTensorClose(F.abs(-3.0).numpy(), np.abs(np.float32(-3.0)))
def test_multiply():
# assertTensorClose(
# F.mul(-3.0, -4.0), np.multiply(np.float32(-3.0), np.float32(-4.0))
# )
assertTensorClose(
F.mul(-3.0, -4.0).numpy(), np.multiply(np.float32(-3.0), np.float32(-4.0))
)
assertTensorClose(
F.mul(tensor([3.0, 4.0]), 4.0).numpy(),
......
......@@ -194,9 +194,6 @@ def test_matmul():
def test_interpolate():
if use_tensor_shape(): # XXX: please fix me
return
def linear_interpolate():
inp = tensor(np.arange(1, 3, dtype=np.float32).reshape(1, 1, 2))
......
......@@ -125,8 +125,6 @@ def test_stack():
def test_split():
if use_tensor_shape(): # XXX: please fix me
return
data = np.random.random((2, 3, 4, 5)).astype(np.float32)
mge_out1 = F.split(tensor(data), 2, axis=3)
mge_out2 = F.split(tensor(data), [3, 5], axis=3)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册