提交 b80fade3 编写于 作者: M Megvii Engine Team

fix(mge/core): avoid create RawTensor with zero-stride numpy ndarray

GitOrigin-RevId: a9b2940bdc48a4bfec297edeefc330821ac13797
上级 a398d4b5
......@@ -100,6 +100,8 @@ def _(data: DeviceTensorND):
@as_raw_tensor.register(np.ndarray)
def _(array: np.ndarray, dtype=None, device=None):
device = None if device is None else as_device(device).to_c()
if 0 in array.strides:
array = array.squeeze().reshape(array.shape)
return RawTensor(put(array, dtype=dtype, device=device))
......
......@@ -458,6 +458,15 @@ def test_conv_bias():
run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, True, "RELU")
def test_zero_stride_numpy_array():
inp = np.random.randn(3, 224, 224).astype(np.float32)
inp = inp[np.newaxis, :]
inp = tensor(inp, dtype=np.float32)
weight = tensor(np.random.randn(16, 3, 3, 3), dtype=np.float32)
out = F.conv2d(inp, weight, None, (2, 2), (3, 3), (1, 1), 1)
def test_condtake():
x = np.array([[1, 2, 3], [4, 5, 6]])
y = np.array([[True, False, True], [False, True, True]])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册