Commit 88c192c8

Authored Dec 28, 2021 by Megvii Engine Team
feat(lite): add get_data_by_share in megenginelite python interface
GitOrigin-RevId: 0ddbb75e823106a61d5802d8e395db99a3e9f1d6
Parent: 8624ec22
Showing 5 changed files with 224 additions and 39 deletions (+224 -39)
lite/pylite/megenginelite/__init__.py   +1  -0
lite/pylite/megenginelite/network.py    +69 -15
lite/pylite/megenginelite/tensor.py     +71 -19
lite/pylite/test/test_network.py        +52 -5
lite/pylite/test/test_tensor.py         +31 -0
lite/pylite/megenginelite/__init__.py

@@ -8,6 +8,7 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from .base import *
+from .base import version as __version__
 from .global_setting import *
 from .network import *
 from .struct import *
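The package now also exposes the lite version string under the conventional dunder name. A minimal sketch of how a caller can rely on it (assuming megenginelite is installed and importable):

    import megenginelite

    # __version__ is re-exported from .base's `version`
    print(megenginelite.__version__)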
lite/pylite/megenginelite/network.py

@@ -69,7 +69,9 @@ class LiteOptions(Structure):
             "const_shape": bool(self.const_shape),
             "force_dynamic_alloc": bool(self.force_dynamic_alloc),
             "force_output_dynamic_alloc": bool(self.force_output_dynamic_alloc),
+            "force_output_nocopy": bool(self.force_output_nocopy),
+            "force_output_use_user_specified_memory": bool(self.force_output_use_user_specified_memory),
             "no_profiling_on_shape_change": bool(self.no_profiling_on_shape_change),
             "jit_level": self.jit_level,
             "comp_node_seq_record_level": self.comp_node_seq_record_level,
@@ -99,7 +101,7 @@ class LiteConfig(Structure):
         ("device_id", c_int),
         ("device_type", c_int),
         ("backend", c_int),
-        ("bare_model_cryption_name", c_char_p),
+        ("_bare_model_cryption_name", c_char_p),
         ("options", LiteOptions),
     ]
@@ -110,18 +112,30 @@ class LiteConfig(Structure):
         else:
             self.options = LiteOptions()

-        self.bare_model_cryption_name = c_char_p(b"")
+        self._bare_model_cryption_name = c_char_p(b"")
         self.use_loader_dynamic_param = 0
         self.has_compression = 0
         self.backend = LiteBackend.LITE_DEFAULT

+    @property
+    def bare_model_cryption_name(self):
+        return self._bare_model_cryption_name.decode("utf-8")
+
+    @bare_model_cryption_name.setter
+    def bare_model_cryption_name(self, name):
+        if isinstance(name, str):
+            self._bare_model_cryption_name = name.encode("utf-8")
+        else:
+            assert isinstance(name, bytes), "name should be str or bytes type."
+            self._bare_model_cryption_name = name
+
     def __repr__(self):
         data = {
             "has_compression": bool(self.has_compression),
             "device_id": LiteDeviceType(self.device_id),
             "device_type": LiteDeviceType(self.device_type),
             "backend": LiteBackend(self.backend),
-            "bare_model_cryption_name": self.bare_model_cryption_name.decode("utf-8"),
+            "bare_model_cryption_name": self.bare_model_cryption_name,
             "options": self.options,
         }
         return data.__repr__()
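With the raw ctypes field renamed to _bare_model_cryption_name, the public attribute becomes a str-valued property. A usage sketch mirroring the new test_config test (package-level imports assumed, as in the tests):

    from megenginelite import LiteConfig

    config = LiteConfig()
    # the setter accepts str (utf-8 encoded internally) or bytes
    config.bare_model_cryption_name = "nothing"
    assert config.bare_model_cryption_name == "nothing"
    print(config)  # __repr__ now shows the decoded name directly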
@@ -149,7 +163,7 @@ class LiteIO(Structure):
     """
     _fields_ = [
-        ("name", c_char_p),
+        ("_name", c_char_p),
         ("is_host", c_int),
         ("io_type", c_int),
         ("config_layout", LiteLayout),
@@ -159,9 +173,9 @@ class LiteIO(Structure):
         self, name, is_host=True, io_type=LiteIOType.LITE_IO_VALUE, layout=None
     ):
         if type(name) == str:
-            self.name = c_char_p(name.encode("utf-8"))
+            self._name = c_char_p(name.encode("utf-8"))
         else:
-            self.name = c_char_p(name)
+            self._name = c_char_p(name)

         if layout:
             self.config_layout = layout
@@ -171,6 +185,18 @@ class LiteIO(Structure):
         self.is_host = is_host
         self.io_type = io_type

+    @property
+    def name(self):
+        return self._name.decode("utf-8")
+
+    @name.setter
+    def name(self, name):
+        if isinstance(name, str):
+            self._name = name.encode("utf-8")
+        else:
+            assert isinstance(name, bytes), "name should be str or bytes type."
+            self._name = name
+
     def __repr__(self):
         data = {
             "name": self.name,
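LiteIO.name follows the same pattern: the ctypes field holds bytes, the property speaks str. A small sketch matching the assertions added to test_network_io:

    from megenginelite import LiteIO

    io = LiteIO("test")
    assert io.name == "test"   # getter decodes the underlying c_char_p
    io.name = "test2"          # setter accepts str or bytes
    assert io.name == "test2"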
@@ -208,17 +234,45 @@ class LiteNetworkIO(object):
     the input and output information for user to construct _LiteNetWorkIO
     """

-    def __init__(self):
+    def __init__(self, inputs=None, outputs=None):
         self.inputs = []
         self.outputs = []
+        if inputs:
+            for i in inputs:
+                if isinstance(i, list):
+                    self.inputs.append(LiteIO(*i))
+                else:
+                    assert isinstance(
+                        i, LiteIO
+                    ), "the param to construct LiteNetworkIO must be list of the LiteIO member or the LiteIO."
+                    self.inputs.append(i)
+        if outputs:
+            for i in outputs:
+                if isinstance(i, list):
+                    self.outputs.append(LiteIO(*i))
+                else:
+                    assert isinstance(
+                        i, LiteIO
+                    ), "the param to construct LiteNetworkIO must be list of the LiteIO member or the LiteIO."
+                    self.outputs.append(i)

-    def add_input(self, input_io):
-        assert isinstance(input_io, LiteIO)
-        self.inputs.append(input_io)
+    def add_input(
+        self, obj, is_host=True, io_type=LiteIOType.LITE_IO_VALUE, layout=None
+    ):
+        if isinstance(obj, LiteIO):
+            self.inputs.append(obj)
+        else:
+            name = obj
+            self.add_input(LiteIO(name, is_host, io_type, layout))

-    def add_output(self, output_io):
-        assert isinstance(output_io, LiteIO)
-        self.outputs.append(output_io)
+    def add_output(
+        self, obj, is_host=True, io_type=LiteIOType.LITE_IO_VALUE, layout=None
+    ):
+        if isinstance(obj, LiteIO):
+            self.outputs.append(obj)
+        else:
+            name = obj
+            self.add_output(LiteIO(name, is_host, io_type, layout))

     def _create_network_io(self):
         network_io = _LiteNetworkIO()
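LiteNetworkIO can now be populated three ways: from LiteIO objects, from [name, is_host, io_type, layout] lists, or by name through add_input/add_output. A sketch of the new call patterns, following the updated test_network_io (package-level imports assumed):

    from megenginelite import LiteIO, LiteIOType, LiteNetworkIO

    ins = [["data1", True], ["data2", False, LiteIOType.LITE_IO_SHAPE]]
    outs = [["out1", True], ["out2", False, LiteIOType.LITE_IO_VALUE]]
    io = LiteNetworkIO(ins, outs)                  # construct directly from lists
    io.add_input(LiteIO("data3", is_host=False))   # or from a LiteIO object
    io.add_input("data4", False)                   # or from a plain name plus options

    assert len(io.inputs) == 4 and len(io.outputs) == 2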
lite/pylite/megenginelite/tensor.py

@@ -48,6 +48,15 @@ ctype_to_lite_dtypes = {
     c_ushort: LiteDataType.LITE_UINT16,
 }

+_lite_dtypes_to_ctype = {
+    LiteDataType.LITE_INT: c_int,
+    LiteDataType.LITE_FLOAT: c_float,
+    LiteDataType.LITE_UINT8: c_ubyte,
+    LiteDataType.LITE_INT8: c_byte,
+    LiteDataType.LITE_INT16: c_short,
+    LiteDataType.LITE_UINT16: c_ushort,
+}
+

 class LiteLayout(Structure):
     """
@@ -55,7 +64,7 @@ class LiteLayout(Structure):
     """
     _fields_ = [
-        ("shapes", c_size_t * MAX_DIM),
+        ("_shapes", c_size_t * MAX_DIM),
         ("ndim", c_size_t),
         ("data_type", c_int),
     ]
@@ -64,10 +73,10 @@ class LiteLayout(Structure):
         if shape:
             shape = list(shape)
             assert len(shape) <= MAX_DIM, "Layout max dim is 7."
-            self.shapes = (c_size_t * MAX_DIM)(*shape)
+            self._shapes = (c_size_t * MAX_DIM)(*shape)
             self.ndim = len(shape)
         else:
-            self.shapes = (c_size_t * MAX_DIM)()
+            self._shapes = (c_size_t * MAX_DIM)()
             self.ndim = 0
         if not dtype:
             self.data_type = LiteDataType.LITE_FLOAT
@@ -83,9 +92,24 @@ class LiteLayout(Structure):
         else:
             raise RuntimeError("unkonw data type")

+    @property
+    def dtype(self):
+        return _lite_type_to_nptypes[LiteDataType(self.data_type)]
+
+    @property
+    def shapes(self):
+        return list(self._shapes)[0 : self.ndim]
+
+    @shapes.setter
+    def shapes(self, shape):
+        shape = list(shape)
+        assert len(shape) <= MAX_DIM, "Layout max dim is 7."
+        self._shapes = (c_size_t * MAX_DIM)(*shape)
+        self.ndim = len(shape)
+
     def __repr__(self):
         data = {
-            "shapes": list(self.shapes)[0 : self.ndim],
+            "shapes": self.shapes,
             "ndim": self.ndim,
             "data_type": _lite_type_to_nptypes[LiteDataType(self.data_type)],
         }
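The raw c_size_t array is now private; the shapes and dtype properties give a Python-friendly view of the layout. A usage sketch (dtype strings such as "int8" are already accepted by the constructor, as the tests show):

    from megenginelite import LiteLayout

    layout = LiteLayout([2, 16], "int8")
    print(layout.shapes)      # [2, 16], only the first ndim entries
    print(layout.dtype)       # numpy dtype mapped from layout.data_type

    layout.shapes = [8, 14]   # setter refills the ctypes array and updates ndim
    assert layout.shapes == [8, 14]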
@@ -177,15 +201,20 @@ class LiteTensor(object):
         device_type=LiteDeviceType.LITE_CPU,
         device_id=0,
         is_pinned_host=False,
+        shapes=None,
+        dtype=None,
     ):
         """
-        create a Tensor with layout, device, is_pinned_host param
+        create a Tensor with layout, device_type, device_id, is_pinned_host
+        params, or with shapes, dtype, device_type, device_id, is_pinned_host
+        params
         """
         self._tensor = _Ctensor()
-        if layout:
-            self._layout = layout
-        else:
-            self._layout = LiteLayout()
+        self._layout = LiteLayout()
+        if layout is not None:
+            self._layout = layout
+        elif shapes is not None:
+            shapes = list(shapes)
+            self._layout = LiteLayout(shapes, dtype)
         self._device_type = device_type
         self._device_id = device_id
         self._is_pinned_host = is_pinned_host
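A tensor can now be created from shapes and a numpy dtype without building a LiteLayout first; the layout setter below additionally accepts a plain shape list. A sketch matching the new assertions in test_tensor_make:

    import numpy as np
    from megenginelite import LiteDataType, LiteTensor

    tensor = LiteTensor(shapes=[1, 3, 224], dtype=np.int8)
    assert tensor.layout.shapes[1] == 3
    assert tensor.layout.shapes[2] == 224
    assert tensor.layout.data_type == LiteDataType.LITE_INT8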
@@ -222,9 +251,12 @@ class LiteTensor(object):
     @layout.setter
     def layout(self, layout):
-        assert isinstance(layout, LiteLayout)
-        self._layout = layout
-        self._api.LITE_set_tensor_layout(self._tensor, layout)
+        if isinstance(layout, LiteLayout):
+            self._layout = layout
+        elif isinstance(layout, list):
+            self._layout.shapes = layout
+
+        self._api.LITE_set_tensor_layout(self._tensor, self._layout)

     @property
     def is_pinned_host(self):
@@ -270,7 +302,6 @@ class LiteTensor(object):
         """
         get the length of the meomry in byte
         """
-        self.update()
         length = c_size_t()
         self._api.LITE_get_tensor_total_size_in_byte(self._tensor, byref(length))
         return length.value
@@ -336,7 +367,6 @@ class LiteTensor(object):
         """
         get the memory of the tensor, return c_void_p of the tensor memory
         """
-        self.update()
         mem = c_void_p()
         self._api.LITE_get_tensor_memory(self._tensor, byref(mem))
         return mem
@@ -347,7 +377,6 @@ class LiteTensor(object):
         param data: the data will shared to the tensor, it should be a
         numpy.ndarray or ctypes data
         """
-        self.update()
         if isinstance(data, np.ndarray):
             assert (
                 self.is_continue
@@ -356,8 +385,7 @@ class LiteTensor(object):
                 self.is_pinned_host or self.device_type == LiteDeviceType.LITE_CPU
             ), "set_data_by_share can only apply in cpu tensor or pinned tensor."

-            np_type = _lite_type_to_nptypes[LiteDataType(self._layout.data_type)]
-            c_type = np.ctypeslib.as_ctypes_type(np_type)
+            c_type = _lite_dtypes_to_ctype[LiteDataType(self._layout.data_type)]

             if self.nbytes != data.nbytes:
                 self.layout = LiteLayout(data.shape, ctype_to_lite_dtypes[c_type])
@@ -377,7 +405,6 @@ class LiteTensor(object):
         param data: the data to copy to tensor, it should be list,
         numpy.ndarraya or ctypes with length
         """
-        self.update()
         if layout is not None:
             self.layout = layout
@@ -386,8 +413,7 @@ class LiteTensor(object):
                 self.is_pinned_host or self.device_type == LiteDeviceType.LITE_CPU
             ), "set_data_by_copy can only apply in cpu tensor or pinned tensor."

-            np_type = _lite_type_to_nptypes[LiteDataType(self._layout.data_type)]
-            c_type = np.ctypeslib.as_ctypes_type(np_type)
+            c_type = _lite_dtypes_to_ctype[LiteDataType(self._layout.data_type)]

             tensor_memory = c_void_p()
@@ -415,6 +441,22 @@ class LiteTensor(object):
         self._api.LITE_get_tensor_memory(self._tensor, byref(tensor_memory))
         memmove(tensor_memory, data, data_length)

+    def get_data_by_share(self):
+        """
+        get the data in the tensor and share it with a new numpy array, then
+        return that numpy array. Be careful: the data in the numpy array is
+        only valid until the tensor memory is written again, e.g. by the next
+        LiteNetwork forward.
+        """
+        assert self.is_continue, "get_data_by_share can only apply in continue tensor."
+        assert (
+            self.is_pinned_host or self.device_type == LiteDeviceType.LITE_CPU
+        ), "get_data_by_share can only apply in CPU tensor or cpu pinned tensor."
+
+        memory = self.get_ctypes_memory()
+        c_type = _lite_dtypes_to_ctype[LiteDataType(self._layout.data_type)]
+        pnt = cast(memory, POINTER(c_type))
+        return np.ctypeslib.as_array(pnt, self._layout.shapes)
+
     def to_numpy(self):
         """
         get the buffer of the tensor
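The headline API of this commit: get_data_by_share returns a numpy view over the tensor's own memory instead of a copy. A usage sketch (CPU tensor assumed, as the asserts require; shapes and values are illustrative):

    import numpy as np
    from megenginelite import LiteLayout, LiteTensor

    tensor = LiteTensor(LiteLayout([4, 32], "int16"))
    tensor.set_data_by_copy(np.arange(128, dtype="int16").reshape(4, 32))

    view = tensor.get_data_by_share()   # no copy: view aliases the tensor memory
    tensor.set_data_by_copy(np.zeros([4, 32], dtype="int16"))
    assert int(view[0][0]) == 0         # the view reflects the later write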
@@ -475,3 +517,13 @@ def LiteTensorConcat(
     )
     result_tensor.update()
     return result_tensor
+
+
+def lite_dtype_2_numpy(dtype):
+    """
+    convert lite dtype to corresponding numpy dtype
+    """
+    assert isinstance(
+        dtype, LiteDataType
+    ), "input must be LiteDataType when using lite_dtype_2_numpy."
+    return _lite_type_to_nptypes[dtype]
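A small helper exposing the dtype mapping used throughout this file, the reverse direction of ctype_to_lite_dtypes. A sketch (importing straight from the tensor module, since package-level re-export is not shown in this diff):

    from megenginelite import LiteDataType
    from megenginelite.tensor import lite_dtype_2_numpy

    print(lite_dtype_2_numpy(LiteDataType.LITE_FLOAT))  # e.g. numpy.float32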
lite/pylite/test/test_network.py

@@ -21,6 +21,12 @@ def test_version():
     print("Lite verson: {}".format(version))


+def test_config():
+    config = LiteConfig()
+    config.bare_model_cryption_name = "nothing"
+    print(config)
+
+
 def test_network_io():
     input_io1 = LiteIO("data1", is_host=False, io_type=LiteIOType.LITE_IO_VALUE)
     input_io2 = LiteIO(
@@ -32,6 +38,7 @@ def test_network_io():
     io = LiteNetworkIO()
     io.add_input(input_io1)
     io.add_input(input_io2)
+    io.add_input("data3", False)

     output_io1 = LiteIO("out1", is_host=False)
     output_io2 = LiteIO("out2", is_host=True, layout=LiteLayout([1, 1000]))
@@ -39,7 +46,7 @@ def test_network_io():
     io.add_output(output_io1)
     io.add_output(output_io2)

-    assert len(io.inputs) == 2
+    assert len(io.inputs) == 3
     assert len(io.outputs) == 2

     assert io.inputs[0] == input_io1
@@ -47,9 +54,25 @@ def test_network_io():
     c_io = io._create_network_io()

-    assert c_io.input_size == 2
+    assert c_io.input_size == 3
     assert c_io.output_size == 2

+    ins = [["data1", True], ["data2", False, LiteIOType.LITE_IO_SHAPE]]
+    outs = [["out1", True], ["out2", False, LiteIOType.LITE_IO_VALUE]]
+    io2 = LiteNetworkIO(ins, outs)
+    assert len(io2.inputs) == 2
+    assert len(io2.outputs) == 2
+
+    io3 = LiteNetworkIO([input_io1, input_io2], [output_io1, output_io2])
+    assert len(io3.inputs) == 2
+    assert len(io3.outputs) == 2
+
+    test_io = LiteIO("test")
+    assert test_io.name == "test"
+    test_io.name = "test2"
+    assert test_io.name == "test2"
+

 class TestShuffleNet(unittest.TestCase):
     source_dir = os.getenv("LITE_TEST_RESOURCE")
@@ -319,9 +342,9 @@ class TestNetwork(TestShuffleNet):
             data = ios[key].to_numpy().flatten()
             input_data = self.input_data.flatten()
             assert data.size == input_data.size
-            assert io.name.decode("utf-8") == "data"
+            assert io.name == "data"
             for i in range(data.size):
-                assert data[i] == input_data[i]
+                assert abs(data[i] - input_data[i]) < 1e-5
             return 0

         network.set_start_callback(start_callback)
@@ -343,7 +366,7 @@ class TestNetwork(TestShuffleNet):
             output_data = self.correct_data.flatten()
             assert data.size == output_data.size
             for i in range(data.size):
-                assert data[i] == output_data[i]
+                assert abs(data[i] - output_data[i]) < 1e-5
             return 0

         network.set_finish_callback(finish_callback)
@@ -404,3 +427,27 @@ class TestNetwork(TestShuffleNet):
             binary_equal_between_batch=True,
         )
         self.do_forward(network)
+
+    def test_device_tensor_no_copy(self):
+        # construct LiteOption
+        net_config = LiteConfig()
+        net_config.options.force_output_use_user_specified_memory = True
+
+        network = LiteNetwork(config=net_config)
+        network.load(self.model_path)
+
+        input_tensor = network.get_io_tensor("data")
+        # fill input_data with device data
+        input_tensor.set_data_by_share(self.input_data)
+
+        output_tensor = network.get_io_tensor(network.get_output_name(0))
+        out_array = np.zeros(output_tensor.layout.shapes, output_tensor.layout.dtype)
+
+        output_tensor.set_data_by_share(out_array)
+
+        # inference
+        for i in range(2):
+            network.forward()
+            network.wait()
+
+        self.check_correct(out_array)
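Outside the unittest harness, the same zero-copy output flow looks roughly like this (a sketch: the model path, input name and input shape are placeholders modeled on the shufflenet test setup):

    import numpy as np
    from megenginelite import LiteConfig, LiteNetwork

    config = LiteConfig()
    config.options.force_output_use_user_specified_memory = True

    network = LiteNetwork(config=config)
    network.load("shufflenet.mge")                  # placeholder model path

    input_tensor = network.get_io_tensor("data")    # input name of that model
    input_tensor.set_data_by_share(
        np.random.rand(1, 3, 224, 224).astype(np.float32)
    )

    # hand the network a user-owned output buffer; forward() writes into it
    output_tensor = network.get_io_tensor(network.get_output_name(0))
    out_array = np.zeros(output_tensor.layout.shapes, output_tensor.layout.dtype)
    output_tensor.set_data_by_share(out_array)

    network.forward()
    network.wait()
    # out_array now holds the result without an extra host copy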
lite/pylite/test/test_tensor.py

@@ -54,6 +54,16 @@ def test_tensor_make():
     tensor = LiteTensor(layout, device_id=1)
     assert tensor.device_id == 1

+    tensor.layout = [8, 14]
+    assert tensor.layout.shapes[0] == 8
+    assert tensor.layout.shapes[1] == 14
+    assert tensor.layout.data_type == LiteDataType.LITE_FLOAT
+
+    tensor_new = LiteTensor(shapes=[1, 3, 224], dtype=np.int8)
+    assert tensor_new.layout.shapes[1] == 3
+    assert tensor_new.layout.shapes[2] == 224
+    assert tensor_new.layout.data_type == LiteDataType.LITE_INT8
+

 def test_tensor_set_data():
     layout = LiteLayout([2, 16], "int8")
@@ -292,3 +302,24 @@ def test_tensor_concat():
         for i in range(128):
             index = j * 128 + i
             assert real_data[index // 32][index % 32] == j
+
+
+def test_tensor_get_memory_by_share():
+    layout = LiteLayout([4, 32], "int16")
+    tensor = LiteTensor(layout)
+    assert tensor.nbytes == 4 * 32 * 2
+
+    arr = np.ones([4, 32], "int16")
+    for i in range(128):
+        arr[i // 32][i % 32] = i
+    tensor.set_data_by_copy(arr)
+    test_data = tensor.get_data_by_share()
+    real_data = tensor.to_numpy()
+    for i in range(128):
+        assert real_data[i // 32][i % 32] == test_data[i // 32][i % 32]
+
+    arr[1][18] = 5
+    arr[3][7] = 345
+    tensor.set_data_by_copy(arr)
+    assert test_data[1][18] == 5
+    assert test_data[3][7] == 345