提交 10a0349e 编写于 作者: M Megvii Engine Team

feat(lite): add assert log for set_data_by_share

and set_data_by_copy. pylite network input is not
correct when input np is not continuous

GitOrigin-RevId: 1bdeae970a1c3ab723c7bfec8fb4ee53f40e1cfb
上级 f5597d9a
......@@ -423,6 +423,9 @@ class LiteTensor(object):
numpy.ndarray or ctypes data
"""
if isinstance(data, np.ndarray):
assert data.flags[
"C_CONTIGUOUS"
], "input numpy is not continuous, please call input = np.ascontiguousarray(input) before call set_data_by_share"
assert (
self.is_continue
), "set_data_by_share can only apply in continue tensor."
......@@ -474,6 +477,9 @@ class LiteTensor(object):
self.copy_from(cpu_tensor)
elif type(data) == np.ndarray:
assert data.flags[
"C_CONTIGUOUS"
], "input numpy is not continuous, please call input = np.ascontiguousarray(input) before call set_data_by_copy"
self.layout = LiteLayout(data.shape, data.dtype)
cpu_tensor.layout = LiteLayout(data.shape, data.dtype)
cdata = data.ctypes.data_as(POINTER(c_type))
......
......@@ -4,6 +4,7 @@ import functools
import numpy as np
import pytest
from megenginelite import *
......@@ -89,6 +90,32 @@ def test_tensor_set_data():
assert real_data[1][3] == 20
def test_set_data_by_copy_not_continuous():
layout = LiteLayout()
tensor = LiteTensor(layout)
arr = np.arange(6).reshape(2, 3).astype(np.uint8).transpose(1, 0)
with pytest.raises(AssertionError):
tensor.set_data_by_copy(arr)
arr = np.ascontiguousarray(arr)
tensor.set_data_by_copy(arr)
def test_set_data_by_share_not_continuous():
layout = LiteLayout([2, 3], "int8")
tensor = LiteTensor(layout)
arr = np.arange(6).reshape(2, 3).astype(np.uint8).transpose(1, 0)
with pytest.raises(AssertionError):
tensor.set_data_by_share(arr, 2 * 3)
arr = np.ascontiguousarray(arr)
tensor.set_data_by_share(arr.ctypes.data, 2 * 3)
def test_fill_zero():
layout = LiteLayout([4, 8], "int16")
tensor1 = LiteTensor(layout)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册