提交 cbff4d7c 编写于 作者: M Megvii Engine Team

fix(imperative/ops): fix infer_output_attrs_fallible for reshape

GitOrigin-RevId: a93567d79abe501ac86fb6d8384019a9c0c34d06
上级 d1be3127
......@@ -122,6 +122,33 @@ def test_reshape():
np.testing.assert_equal(yy.numpy(), y)
def test_reshape_shape_inference():
x_shape_known = tensor([1, 2, 3, 4], dtype="float32")
x_shape_unknown = F.broadcast_to(tensor([1.0]), shape=tensor([1, 1, 1, 1]).sum())
tshp_unknown = astensor1d((tensor([2]), tensor([2])), x_shape_known)
tshp_known = astensor1d((2, 2), x_shape_known)
tshp_known_unspec = astensor1d((2, -1), x_shape_known)
def check_shape(output, target):
source = output.shape
if isinstance(source, tensor):
source = source.numpy()
np.testing.assert_equal(source, target)
def func(x, target_shape):
return x.reshape(target_shape)
cases = [
{"input": [x_shape_known, tshp_unknown], "output": [(2, 2),]},
{"input": [x_shape_unknown, tshp_unknown], "output": [(2, 2),]},
{"input": [x_shape_known, tshp_known], "output": [(2, 2),]},
{"input": [x_shape_known, tshp_known_unspec], "output": [(2, 2),]},
{"input": [x_shape_unknown, tshp_known], "output": [(2, 2),]},
{"input": [x_shape_unknown, tshp_known_unspec], "output": [(2, 2),]},
]
opr_test(cases, func, compare_fn=check_shape, test_trace=True)
def test_squeeze():
x = np.arange(6, dtype="float32").reshape(1, 2, 3, 1)
xx = tensor(x)
......
......@@ -115,9 +115,13 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
}
mgb_assert(
tshp.layout.ndim == 1,
"target shape of Broadcast expects ndim=1; got ndim=%lu actually",
"target shape of Reshape expects ndim=1; got ndim=%lu actually",
tshp.layout.ndim);
if (src.layout.ndim == 0 && op.axis != opr::Reshape::Param::INVALID_AXIS) {
return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, false};
}
size_t target_ndim = tshp.layout.shape[0];
out_shape.ndim = target_ndim;
auto* ptr = tshp.value.ptr<dt_int32>();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册