提交 0dac2a79 编写于 作者: M Megvii Engine Team

fix(lite/pylite): fix megenginelite can not load model from file object

GitOrigin-RevId: b3162f7a9690ead9913542e27d69babb3bd81906
上级 f70d644a
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import io
import os
from ctypes import * from ctypes import *
import numpy as np import megfile
from .base import _Cnetwork, _Ctensor, _lib, _LiteCObjBase from .base import _Cnetwork, _Ctensor, _lib, _LiteCObjBase
from .struct import * from .struct import *
...@@ -594,6 +596,7 @@ class LiteNetwork(object): ...@@ -594,6 +596,7 @@ class LiteNetwork(object):
c_network_io = self.network_io._create_network_io() c_network_io = self.network_io._create_network_io()
self._api.LITE_make_network(byref(self._network), self.config, c_network_io) self._api.LITE_make_network(byref(self._network), self.config, c_network_io)
self.model_bytes = None
def __repr__(self): def __repr__(self):
data = {"config": self.config, "IOs": self.network_io} data = {"config": self.config, "IOs": self.network_io}
...@@ -602,12 +605,25 @@ class LiteNetwork(object): ...@@ -602,12 +605,25 @@ class LiteNetwork(object):
def __del__(self): def __del__(self):
self._api.LITE_destroy_network(self._network) self._api.LITE_destroy_network(self._network)
def load(self, path): def load(self, file):
""" """
load network from given path load network from given file or file object
""" """
c_path = c_char_p(path.encode("utf-8")) if isinstance(file, (str, os.PathLike)):
self._api.LITE_load_model_from_path(self._network, c_path) with megfile.smart_open(file, "rb") as f:
self.model_bytes = f.read()
else:
assert isinstance(
file, io.BufferedReader
), "file must be BufferedReader when open!!"
self.model_bytes = file.read()
self.model_bytes = io.BytesIO(self.model_bytes)
length = self.model_bytes.getbuffer().nbytes
self.model_bytes = c_buffer(self.model_bytes.getvalue())
cdata = cast(self.model_bytes, POINTER(c_void_p))
self._api.LITE_load_model_from_mem(self._network, cdata, length)
def forward(self): def forward(self):
""" """
......
...@@ -501,6 +501,25 @@ class TestNetwork(TestShuffleNet): ...@@ -501,6 +501,25 @@ class TestNetwork(TestShuffleNet):
os.remove(fast_run_cache) os.remove(fast_run_cache)
os.remove(global_layout_transform_model) os.remove(global_layout_transform_model)
def test_network_basic_mem(self):
network = LiteNetwork()
with open(self.model_path, "rb") as file:
network.load(file)
input_name = network.get_input_name(0)
input_tensor = network.get_io_tensor(input_name)
output_name = network.get_output_name(0)
output_tensor = network.get_io_tensor(output_name)
assert input_tensor.layout.shapes[0] == 1
assert input_tensor.layout.shapes[1] == 3
assert input_tensor.layout.shapes[2] == 224
assert input_tensor.layout.shapes[3] == 224
assert input_tensor.layout.data_type == LiteDataType.LITE_FLOAT
assert input_tensor.layout.ndim == 4
self.do_forward(network)
class TestDiscreteInputNet(unittest.TestCase): class TestDiscreteInputNet(unittest.TestCase):
source_dir = os.getenv("LITE_TEST_RESOURCE") source_dir = os.getenv("LITE_TEST_RESOURCE")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册