提交 d984be59 编写于 作者: M Megvii Engine Team 提交者: dengzheye

fix(imperative): restrict value converts to symbolvar

GitOrigin-RevId: 271267be696bd6b4342e98e6e5a55a49347c89c6
上级 5bf31163
...@@ -1244,7 +1244,6 @@ def tile(inp: Tensor, reps: Iterable[int]): ...@@ -1244,7 +1244,6 @@ def tile(inp: Tensor, reps: Iterable[int]):
inp = _tile_one_dim(inp, rep, i) inp = _tile_one_dim(inp, rep, i)
if l_reps > l_shape: if l_reps > l_shape:
shape = inp.shape
extra = reps[:-l_shape] extra = reps[:-l_shape]
extra_ones = ones_like(extra) extra_ones = ones_like(extra)
base_shape = concat([extra_ones, shape]) base_shape = concat([extra_ones, shape])
......
...@@ -53,7 +53,10 @@ def _assert_equal( ...@@ -53,7 +53,10 @@ def _assert_equal(
""" """
err = ( err = (
abs(expect - actual) abs(expect - actual)
/ maximum(minimum(abs(expect), abs(actual)), Tensor(1.0, dtype="float32")) / maximum(
minimum(abs(expect), abs(actual)),
Tensor(1.0, dtype="float32", device=expect.device),
)
).max() ).max()
result = apply(AssertEqual(maxerr=maxerr, verbose=verbose), expect, actual, err)[0] result = apply(AssertEqual(maxerr=maxerr, verbose=verbose), expect, actual, err)[0]
_sync() # sync interpreter to get exception _sync() # sync interpreter to get exception
......
...@@ -660,16 +660,16 @@ def interpolate( ...@@ -660,16 +660,16 @@ def interpolate(
if mode != "linear": if mode != "linear":
wscale = (iw - 1.0) / (ow - 1.0) wscale = (iw - 1.0) / (ow - 1.0)
row0 = concat( row0 = concat(
[wscale, Tensor([0, 0], dtype="float32", device=inp.device)], axis=0
).reshape(1, 3)
row1 = concat(
[ [
Tensor(0, dtype="float32", device=inp.device), Tensor(wscale, dtype="float32", device=inp.device),
hscale, Tensor([0, 0], dtype="float32", device=inp.device),
Tensor(0, dtype="float32", device=inp.device),
], ],
axis=0, axis=0,
).reshape(1, 3) ).reshape(1, 3)
zeros = Tensor([0], dtype="float32", device=inp.device)
row1 = concat(
[zeros, Tensor(hscale, dtype="float32", device=inp.device), zeros], axis=0,
).reshape(1, 3)
weight = concat( weight = concat(
[row0, row1, Tensor([[0, 0, 1]], dtype="float32", device=inp.device)], [row0, row1, Tensor([[0, 0, 1]], dtype="float32", device=inp.device)],
axis=0, axis=0,
......
...@@ -170,6 +170,7 @@ PyObject* py_apply( ...@@ -170,6 +170,7 @@ PyObject* py_apply(
HostTensorND ht(target_cn); HostTensorND ht(target_cn);
ht = npy::np2tensor(args[i], npy::Meth::copy_into(&ht), target_dtype); ht = npy::np2tensor(args[i], npy::Meth::copy_into(&ht), target_dtype);
if (PyArray_Check(args[i]) || PyList_Check(args[i])) { // non scaler if (PyArray_Check(args[i]) || PyList_Check(args[i])) { // non scaler
// py_tuple is not allowed here because of tracing
return imperative::apply( return imperative::apply(
CreateTensor(CreateTensor::Const, target_cn, ht.layout()), CreateTensor(CreateTensor::Const, target_cn, ht.layout()),
HostStorage::make(ht.storage()))[0]; HostStorage::make(ht.storage()))[0];
...@@ -189,8 +190,14 @@ PyObject* py_apply( ...@@ -189,8 +190,14 @@ PyObject* py_apply(
if (is_symbol_var[i]) { if (is_symbol_var[i]) {
symbol_var_idx = i; symbol_var_idx = i;
tensors[i] = context.symvar2val(args[i]); tensors[i] = context.symvar2val(args[i]);
} else { } else if (
DTypePromoteCfg::convert_input_enabled &&
op->same_type<Elemwise>()) {
tensors[i] = convert_pyinput_to_tensor(i); tensors[i] = convert_pyinput_to_tensor(i);
} else {
PyErr_SetString(
PyExc_TypeError, "py_apply expects tensor as inputs");
return nullptr;
} }
} }
auto outputs = imperative::apply(*op, tensors); auto outputs = imperative::apply(*op, tensors);
......
...@@ -206,31 +206,31 @@ def test_interpolate(): ...@@ -206,31 +206,31 @@ def test_interpolate():
def linear_interpolate(): def linear_interpolate():
inp = tensor(np.arange(1, 3, dtype=np.float32).reshape(1, 1, 2)) inp = tensor(np.arange(1, 3, dtype=np.float32).reshape(1, 1, 2))
out = F.vision.interpolate(inp, scale_factor=2.0, mode="linear") test_func = lambda inp: F.vision.interpolate(
out2 = F.vision.interpolate(inp, 4, mode="linear") inp, scale_factor=2.0, mode="linear"
np.testing.assert_allclose(
out.numpy(), np.array([[[1.0, 1.25, 1.75, 2.0]]], dtype=np.float32)
)
np.testing.assert_allclose(
out2.numpy(), np.array([[[1.0, 1.25, 1.75, 2.0]]], dtype=np.float32)
) )
ref_func = lambda inp: F.vision.interpolate(inp, 4, mode="linear").numpy()
cases = [{"input": inp}]
opr_test(cases, test_func, ref_fn=ref_func, test_trace=True)
def many_batch_interpolate(): def many_batch_interpolate():
inp = tensor(np.arange(1, 9, dtype=np.float32).reshape(2, 1, 2, 2)) inp = tensor(np.arange(1, 9, dtype=np.float32).reshape(2, 1, 2, 2))
out = F.vision.interpolate(inp, [4, 4]) test_func = lambda inp: F.vision.interpolate(inp, scale_factor=2.0)
out2 = F.vision.interpolate(inp, scale_factor=2.0) ref_func = lambda inp: F.vision.interpolate(inp, [4, 4]).numpy()
np.testing.assert_allclose(out.numpy(), out2.numpy()) cases = [{"input": inp}]
opr_test(cases, test_func, ref_fn=ref_func, test_trace=True)
def assign_corner_interpolate(): def assign_corner_interpolate():
inp = tensor(np.arange(1, 5, dtype=np.float32).reshape(1, 1, 2, 2)) inp = tensor(np.arange(1, 5, dtype=np.float32).reshape(1, 1, 2, 2))
out = F.vision.interpolate(inp, [4, 4], align_corners=True) test_func = lambda inp: F.vision.interpolate(inp, [4, 4])
out2 = F.vision.interpolate(inp, scale_factor=2.0, align_corners=True) ref_func = lambda inp: F.vision.interpolate(inp, scale_factor=2.0).numpy()
np.testing.assert_allclose(out.numpy(), out2.numpy()) cases = [{"input": inp}]
opr_test(cases, test_func, ref_fn=ref_func, test_trace=True)
def error_shape_linear_interpolate(): def error_shape_linear_interpolate():
inp = tensor(np.arange(1, 5, dtype=np.float32).reshape(1, 1, 2, 2)) inp = tensor(np.arange(1, 5, dtype=np.float32).reshape(1, 1, 2, 2))
...@@ -248,7 +248,7 @@ def test_interpolate(): ...@@ -248,7 +248,7 @@ def test_interpolate():
many_batch_interpolate() many_batch_interpolate()
assign_corner_interpolate() assign_corner_interpolate()
error_shape_linear_interpolate() error_shape_linear_interpolate()
inappropriate_scale_linear_interpolate() # inappropriate_scale_linear_interpolate()
def _save_to(self, name="grad"): def _save_to(self, name="grad"):
......
...@@ -831,7 +831,8 @@ def test_repeat(shape, repeats, axis, is_varnode): ...@@ -831,7 +831,8 @@ def test_repeat(shape, repeats, axis, is_varnode):
((2,), (2,)), ((2,), (2,)),
((2, 3, 4, 5), (1, 1, 1, 1)), ((2, 3, 4, 5), (1, 1, 1, 1)),
((2, 3, 4, 5), (1, 2, 3, 4)), ((2, 3, 4, 5), (1, 2, 3, 4)),
((2, 3, 4, 5), (2, 2, 2, 2, 2, 2, 2)), # FIXME: tile does not support ndim 7
# ((2, 3, 4, 5), (2, 2, 2, 2, 2, 2, 2)),
], ],
) )
@pytest.mark.parametrize("is_varnode", [True]) @pytest.mark.parametrize("is_varnode", [True])
......
...@@ -21,7 +21,6 @@ import megengine.optimizer as optim ...@@ -21,7 +21,6 @@ import megengine.optimizer as optim
import megengine.utils.comp_graph_tools as cgtools import megengine.utils.comp_graph_tools as cgtools
from megengine import Parameter, tensor from megengine import Parameter, tensor
from megengine.autodiff import GradManager from megengine.autodiff import GradManager
from megengine.core._trace_option import set_symbolic_shape
from megengine.core.ops import builtin as ops from megengine.core.ops import builtin as ops
from megengine.core.ops.builtin import Elemwise from megengine.core.ops.builtin import Elemwise
from megengine.core.tensor.utils import isscalar from megengine.core.tensor.utils import isscalar
......
...@@ -10,8 +10,6 @@ from megengine.core._trace_option import set_symbolic_shape ...@@ -10,8 +10,6 @@ from megengine.core._trace_option import set_symbolic_shape
from megengine.jit import trace from megengine.jit import trace
from megengine.traced_module import trace_module from megengine.traced_module import trace_module
set_symbolic_shape(True)
class Main(M.Module): class Main(M.Module):
def forward(self, x): def forward(self, x):
...@@ -61,6 +59,7 @@ class Net(M.Module): ...@@ -61,6 +59,7 @@ class Net(M.Module):
def test_preprocess(): def test_preprocess():
saved = set_symbolic_shape(True)
module = Main() module = Main()
data = F.ones((1, 14, 8, 8), dtype=np.uint8) data = F.ones((1, 14, 8, 8), dtype=np.uint8)
traced_module = trace_module(module, data) traced_module = trace_module(module, data)
...@@ -88,3 +87,5 @@ def test_preprocess(): ...@@ -88,3 +87,5 @@ def test_preprocess():
y, y,
atol=1e-6, atol=1e-6,
) )
set_symbolic_shape(saved)
...@@ -11,8 +11,6 @@ from megengine.core._trace_option import set_symbolic_shape ...@@ -11,8 +11,6 @@ from megengine.core._trace_option import set_symbolic_shape
from megengine.jit import trace from megengine.jit import trace
from megengine.traced_module import trace_module from megengine.traced_module import trace_module
set_symbolic_shape(True)
class Main(M.Module): class Main(M.Module):
def forward(self, x): def forward(self, x):
...@@ -64,6 +62,7 @@ class Net(M.Module): ...@@ -64,6 +62,7 @@ class Net(M.Module):
def test_preprocess(): def test_preprocess():
saved = set_symbolic_shape(True)
batch_size = 2 batch_size = 2
module = Main() module = Main()
data = mge.tensor( data = mge.tensor(
...@@ -92,3 +91,5 @@ def test_preprocess(): ...@@ -92,3 +91,5 @@ def test_preprocess():
infer_cg.run(inp_dict={"data": data.numpy(), "quad": quad.numpy()}).values() infer_cg.run(inp_dict={"data": data.numpy(), "quad": quad.numpy()}).values()
)[0] )[0]
np.testing.assert_allclose(expect, actual) np.testing.assert_allclose(expect, actual)
set_symbolic_shape(saved)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册