Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
0dac2a79
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
0dac2a79
编写于
3月 27, 2023
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(lite/pylite): fix megenginelite can not load model from file object
GitOrigin-RevId: b3162f7a9690ead9913542e27d69babb3bd81906
上级
f70d644a
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
40 addition
and
5 deletion
+40
-5
lite/pylite/megenginelite/network.py
lite/pylite/megenginelite/network.py
+21
-5
lite/pylite/test/test_network.py
lite/pylite/test/test_network.py
+19
-0
未找到文件。
lite/pylite/megenginelite/network.py
浏览文件 @
0dac2a79
# -*- coding: utf-8 -*-
# -*- coding: utf-8 -*-
import
io
import
os
from
ctypes
import
*
from
ctypes
import
*
import
numpy
as
np
import
megfile
from
.base
import
_Cnetwork
,
_Ctensor
,
_lib
,
_LiteCObjBase
from
.base
import
_Cnetwork
,
_Ctensor
,
_lib
,
_LiteCObjBase
from
.struct
import
*
from
.struct
import
*
...
@@ -594,6 +596,7 @@ class LiteNetwork(object):
...
@@ -594,6 +596,7 @@ class LiteNetwork(object):
c_network_io
=
self
.
network_io
.
_create_network_io
()
c_network_io
=
self
.
network_io
.
_create_network_io
()
self
.
_api
.
LITE_make_network
(
byref
(
self
.
_network
),
self
.
config
,
c_network_io
)
self
.
_api
.
LITE_make_network
(
byref
(
self
.
_network
),
self
.
config
,
c_network_io
)
self
.
model_bytes
=
None
def
__repr__
(
self
):
def
__repr__
(
self
):
data
=
{
"config"
:
self
.
config
,
"IOs"
:
self
.
network_io
}
data
=
{
"config"
:
self
.
config
,
"IOs"
:
self
.
network_io
}
...
@@ -602,12 +605,25 @@ class LiteNetwork(object):
...
@@ -602,12 +605,25 @@ class LiteNetwork(object):
def
__del__
(
self
):
def
__del__
(
self
):
self
.
_api
.
LITE_destroy_network
(
self
.
_network
)
self
.
_api
.
LITE_destroy_network
(
self
.
_network
)
def
load
(
self
,
path
):
def
load
(
self
,
file
):
"""
"""
load network from given
path
load network from given
file or file object
"""
"""
c_path
=
c_char_p
(
path
.
encode
(
"utf-8"
))
if
isinstance
(
file
,
(
str
,
os
.
PathLike
)):
self
.
_api
.
LITE_load_model_from_path
(
self
.
_network
,
c_path
)
with
megfile
.
smart_open
(
file
,
"rb"
)
as
f
:
self
.
model_bytes
=
f
.
read
()
else
:
assert
isinstance
(
file
,
io
.
BufferedReader
),
"file must be BufferedReader when open!!"
self
.
model_bytes
=
file
.
read
()
self
.
model_bytes
=
io
.
BytesIO
(
self
.
model_bytes
)
length
=
self
.
model_bytes
.
getbuffer
().
nbytes
self
.
model_bytes
=
c_buffer
(
self
.
model_bytes
.
getvalue
())
cdata
=
cast
(
self
.
model_bytes
,
POINTER
(
c_void_p
))
self
.
_api
.
LITE_load_model_from_mem
(
self
.
_network
,
cdata
,
length
)
def
forward
(
self
):
def
forward
(
self
):
"""
"""
...
...
lite/pylite/test/test_network.py
浏览文件 @
0dac2a79
...
@@ -501,6 +501,25 @@ class TestNetwork(TestShuffleNet):
...
@@ -501,6 +501,25 @@ class TestNetwork(TestShuffleNet):
os
.
remove
(
fast_run_cache
)
os
.
remove
(
fast_run_cache
)
os
.
remove
(
global_layout_transform_model
)
os
.
remove
(
global_layout_transform_model
)
def
test_network_basic_mem
(
self
):
network
=
LiteNetwork
()
with
open
(
self
.
model_path
,
"rb"
)
as
file
:
network
.
load
(
file
)
input_name
=
network
.
get_input_name
(
0
)
input_tensor
=
network
.
get_io_tensor
(
input_name
)
output_name
=
network
.
get_output_name
(
0
)
output_tensor
=
network
.
get_io_tensor
(
output_name
)
assert
input_tensor
.
layout
.
shapes
[
0
]
==
1
assert
input_tensor
.
layout
.
shapes
[
1
]
==
3
assert
input_tensor
.
layout
.
shapes
[
2
]
==
224
assert
input_tensor
.
layout
.
shapes
[
3
]
==
224
assert
input_tensor
.
layout
.
data_type
==
LiteDataType
.
LITE_FLOAT
assert
input_tensor
.
layout
.
ndim
==
4
self
.
do_forward
(
network
)
class
TestDiscreteInputNet
(
unittest
.
TestCase
):
class
TestDiscreteInputNet
(
unittest
.
TestCase
):
source_dir
=
os
.
getenv
(
"LITE_TEST_RESOURCE"
)
source_dir
=
os
.
getenv
(
"LITE_TEST_RESOURCE"
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录