Commit 7fadc16d authored by Megvii Engine Team

refactor(mge/functional): support tensor shape in interpolate and split

GitOrigin-RevId: 6430b64f010ea5d0ecb1caa59b6da0a1547552ae
Parent: 968f74ce
@@ -31,7 +31,9 @@ def dtype_promotion(raw_inputs):
     ]
     inputs = [i for i in raw_inputs if hasattr(i, "dtype")]
     assert len(scalar_inputs + inputs) > 0
-    dtype = np.result_type(*inputs)
+    dtype = None
+    if len(inputs) > 0:
+        dtype = np.result_type(*inputs)
     dtype_all = np.result_type(*(inputs + scalar_inputs))
     assert (
         dtype != np.float64 and dtype != np.int64
......
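The new guard exists because `np.result_type` cannot be called with an empty argument list, which is exactly what happened when every argument was a plain Python scalar. A minimal, plain-NumPy sketch of the failure mode the guard avoids (illustration only, not MegEngine code):

```python
import numpy as np

# What `np.result_type(*inputs)` becomes when `inputs` is empty: it raises,
# so the old code failed whenever no argument carried a dtype.
try:
    np.result_type()
except (TypeError, ValueError) as exc:
    print("np.result_type() with no arguments raises:", exc)

# With at least one dtype-carrying input, promotion proceeds as before.
print(np.result_type(np.zeros(2, dtype=np.float16), np.zeros(2, dtype=np.float32)))
# -> float32
```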
@@ -10,8 +10,9 @@
 import functools
 
 from ..core.ops import builtin
-from ..core.tensor import utils
+from ..core.tensor import megbrain_graph, utils
 from ..core.tensor.core import apply
+from ..device import get_default_device
 from ..tensor import Tensor
 
 __all__ = [
@@ -76,11 +77,17 @@ __all__ = [
 
 
 def _elwise(*args, mode):
     op = builtin.Elemwise(mode=mode)
+    tensor_args = list(
+        filter(lambda x: isinstance(x, (Tensor, megbrain_graph.VarNode)), args)
+    )
+    if len(tensor_args) == 0:
+        dtype = utils.dtype_promotion(args)
+        first_arg = Tensor(args[0], dtype=dtype, device=get_default_device())
+        args = utils.convert_inputs(first_arg, *args[1:])
+    else:
+        args = utils.convert_inputs(*args)
     if mode in ("true_div", "exp", "pow", "log", "expm1", "log1p"):
-        args = tuple(
-            map(lambda x: x.astype("float32") if hasattr(x, "dtype") else x, args)
-        )
-    args = utils.convert_inputs(*args)
+        args = tuple(map(lambda x: x.astype("float32"), args))
     (result,) = apply(op, *args)
     return result
......
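With the `_elwise` rework, elementwise functionals also accept calls in which no argument is already a Tensor or VarNode: the dtype is promoted from the scalar arguments and the first one is lifted to a Tensor on the default device. A small usage sketch, assuming a MegEngine build that contains this commit; the calls mirror the tests re-enabled further down:

```python
import megengine.functional as F
from megengine import tensor

# Scalar-only calls: a Tensor is created internally from the promoted dtype.
print(F.abs(-3.0).numpy())
print(F.mul(-3.0, -4.0).numpy())

# Mixed Tensor/scalar calls keep working as before.
print(F.mul(tensor([3.0, 4.0]), 4.0).numpy())
```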
@@ -1126,11 +1126,8 @@ def interpolate(
     if mode == "LINEAR":
         inp = add_axis(inp, 3)
 
-    if not isinstance(inp.shape, inp.__class__):
-        if len(inp.shape) != 4:
-            raise ValueError(
-                "shape of input tensor must correspond to the operartion mode"
-            )
+    if inp.ndim != 4:
+        raise ValueError("shape of input tensor must correspond to the operartion mode")
 
     if size is None:
         if scale_factor is None:
......
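`interpolate` now checks `inp.ndim` instead of inspecting `inp.shape`, so the validation works whether the shape is a Python tuple or itself a Tensor (tensor-shape mode), which is the "support tensor shape" part of this commit. A usage sketch, assuming a build with this commit; the call follows the `linear_interpolate` test touched below:

```python
import numpy as np
import megengine.functional as F
from megengine import tensor

# In LINEAR mode a 3-dimensional input is expanded to 4D internally, then the
# ndim check above validates it; this performs 1-D linear upsampling by 2x.
inp = tensor(np.arange(1, 3, dtype=np.float32).reshape(1, 1, 2))
out = F.interpolate(inp, scale_factor=2.0, mode="LINEAR")
print(out.numpy())  # values upsampled along the last dimension to length 4
```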
@@ -317,7 +317,7 @@ def split(inp, nsplits_or_sections, axis=0):
     def swapaxis(inp, src, dst):
         if src == dst:
             return inp
-        shape = [i for i in range(len(inp.shape))]
+        shape = [i for i in range(inp.ndim)]
         shape[src] = dst
         shape[dst] = src
         return inp.transpose(shape)
@@ -325,9 +325,11 @@ def split(inp, nsplits_or_sections, axis=0):
     inp = swapaxis(inp, 0, axis)
 
     if isinstance(nsplits_or_sections, int):
-        incr_step = math.ceil(inp.shape[0] / nsplits_or_sections)
-        while incr_step < inp.shape[0]:
-            sections.append(incr_step)
+        incr_step = ceil(inp.shape[0] / nsplits_or_sections)
+        nsplits = nsplits_or_sections
+        while nsplits > 0:
+            nsplits -= 1
+            sections.append(incr_step.astype("int32"))
             incr_step += nsplits_or_sections
     else:
         sections = nsplits_or_sections
......
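`split` now builds its section boundaries with the tensor-aware `ceil` and casts each boundary to int32, so the computation also works when `inp.shape[0]` is a Tensor rather than a Python int. A usage sketch, assuming a build with this commit; the calls are the ones exercised by `test_split` below:

```python
import numpy as np
import megengine.functional as F
from megengine import tensor

data = np.random.random((2, 3, 4, 5)).astype(np.float32)

# An int asks for that many (roughly equal) parts; a list gives boundary
# indices along the chosen axis, mirroring np.split.
mge_out1 = F.split(tensor(data), 2, axis=3)
mge_out2 = F.split(tensor(data), [3, 5], axis=3)
np_out = np.split(data, [3, 5], axis=3)
print([o.shape for o in mge_out1])
print([o.shape for o in mge_out2], [o.shape for o in np_out])
```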
@@ -19,13 +19,13 @@ def test_abs():
         np.abs(np.array([-3.0, -4.0, -5.0], dtype=np.float32)),
     )
 
-    # assertTensorClose(F.abs(-3.0), np.abs(np.float32(-3.0)))
+    assertTensorClose(F.abs(-3.0).numpy(), np.abs(np.float32(-3.0)))
 
 
 def test_multiply():
-    # assertTensorClose(
-    #     F.mul(-3.0, -4.0), np.multiply(np.float32(-3.0), np.float32(-4.0))
-    # )
+    assertTensorClose(
+        F.mul(-3.0, -4.0).numpy(), np.multiply(np.float32(-3.0), np.float32(-4.0))
+    )
 
     assertTensorClose(
         F.mul(tensor([3.0, 4.0]), 4.0).numpy(),
......
@@ -194,9 +194,6 @@ def test_matmul():
 
 
 def test_interpolate():
-    if use_tensor_shape():  # XXX: please fix me
-        return
-
     def linear_interpolate():
         inp = tensor(np.arange(1, 3, dtype=np.float32).reshape(1, 1, 2))
......
@@ -125,8 +125,6 @@ def test_stack():
 
 
 def test_split():
-    if use_tensor_shape():  # XXX: please fix me
-        return
     data = np.random.random((2, 3, 4, 5)).astype(np.float32)
     mge_out1 = F.split(tensor(data), 2, axis=3)
     mge_out2 = F.split(tensor(data), [3, 5], axis=3)
......