提交 b111baf1 编写于 作者: M Megvii Engine Team

fix(mge/tensor): fix const target shape in reshape

GitOrigin-RevId: 7c04a9efbabc7c7e75d87c67a611a08a7371fe49
上级 4ab259f5
...@@ -74,9 +74,7 @@ def _reshape(x, shape): ...@@ -74,9 +74,7 @@ def _reshape(x, shape):
raise ValueError("multiple -1 in shape: {} & {}".format(unspec_axis, i)) raise ValueError("multiple -1 in shape: {} & {}".format(unspec_axis, i))
unspec_axis = i unspec_axis = i
if not isinstance(shape, (TensorBase, TensorWrapperBase)): shape = utils.astensor1d(shape, x, dtype="int32", device=x.device)
# TODO: device should be None (cpu)
(shape,) = Const(shape, dtype=np.int32, device=x.device)(x)
if unspec_axis is None: if unspec_axis is None:
op = builtin.Reshape() op = builtin.Reshape()
......
...@@ -266,3 +266,20 @@ def test_trace_cvt_bool(): ...@@ -266,3 +266,20 @@ def test_trace_cvt_bool():
for i in range(3): for i in range(3):
np.testing.assert_equal(f(x).numpy()[0], False) np.testing.assert_equal(f(x).numpy()[0], False)
def test_trace_reshape():
for symbolic in [False, True]:
set_tensor_shape(True)
x1 = tensor(np.random.randn(2, 10, 10))
x2 = tensor(np.random.randn(4, 10, 10))
x3 = tensor(np.random.randn(8, 10, 10))
@trace(symbolic=symbolic, capture_as_const=True)
def f(x):
y = x.reshape(x.shape[0], 100)
return y
f(x1)
f(x2)
f(x3)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册