From 17bf7278856443617b815ec923a41b843de9a3be Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 27 Mar 2023 11:30:08 +0800 Subject: [PATCH] fix(lite/pylite): fix megenginelite can not load model from file object GitOrigin-RevId: b3162f7a9690ead9913542e27d69babb3bd81906 --- lite/pylite/megenginelite/network.py | 26 +++++++++++++++++++++----- lite/pylite/test/test_network.py | 19 +++++++++++++++++++ 2 files changed, 40 insertions(+), 5 deletions(-) diff --git a/lite/pylite/megenginelite/network.py b/lite/pylite/megenginelite/network.py index 7283355a5..2620034c0 100644 --- a/lite/pylite/megenginelite/network.py +++ b/lite/pylite/megenginelite/network.py @@ -1,8 +1,10 @@ # -*- coding: utf-8 -*- +import io +import os from ctypes import * -import numpy as np +import megfile from .base import _Cnetwork, _Ctensor, _lib, _LiteCObjBase from .struct import * @@ -594,6 +596,7 @@ class LiteNetwork(object): c_network_io = self.network_io._create_network_io() self._api.LITE_make_network(byref(self._network), self.config, c_network_io) + self.model_bytes = None def __repr__(self): data = {"config": self.config, "IOs": self.network_io} @@ -602,12 +605,25 @@ class LiteNetwork(object): def __del__(self): self._api.LITE_destroy_network(self._network) - def load(self, path): + def load(self, file): """ - load network from given path + load network from given file or file object """ - c_path = c_char_p(path.encode("utf-8")) - self._api.LITE_load_model_from_path(self._network, c_path) + if isinstance(file, (str, os.PathLike)): + with megfile.smart_open(file, "rb") as f: + self.model_bytes = f.read() + else: + assert isinstance( + file, io.BufferedReader + ), "file must be BufferedReader when open!!" + self.model_bytes = file.read() + + self.model_bytes = io.BytesIO(self.model_bytes) + length = self.model_bytes.getbuffer().nbytes + self.model_bytes = c_buffer(self.model_bytes.getvalue()) + + cdata = cast(self.model_bytes, POINTER(c_void_p)) + self._api.LITE_load_model_from_mem(self._network, cdata, length) def forward(self): """ diff --git a/lite/pylite/test/test_network.py b/lite/pylite/test/test_network.py index 1847ff875..cfb4cfdf1 100644 --- a/lite/pylite/test/test_network.py +++ b/lite/pylite/test/test_network.py @@ -501,6 +501,25 @@ class TestNetwork(TestShuffleNet): os.remove(fast_run_cache) os.remove(global_layout_transform_model) + def test_network_basic_mem(self): + network = LiteNetwork() + with open(self.model_path, "rb") as file: + network.load(file) + + input_name = network.get_input_name(0) + input_tensor = network.get_io_tensor(input_name) + output_name = network.get_output_name(0) + output_tensor = network.get_io_tensor(output_name) + + assert input_tensor.layout.shapes[0] == 1 + assert input_tensor.layout.shapes[1] == 3 + assert input_tensor.layout.shapes[2] == 224 + assert input_tensor.layout.shapes[3] == 224 + assert input_tensor.layout.data_type == LiteDataType.LITE_FLOAT + assert input_tensor.layout.ndim == 4 + + self.do_forward(network) + class TestDiscreteInputNet(unittest.TestCase): source_dir = os.getenv("LITE_TEST_RESOURCE") -- GitLab