Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
a891f9b3
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
a891f9b3
编写于
5月 31, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
docs(api/lite): add megenginelite.network api doc
GitOrigin-RevId: e0b8eb207426d0907f2dd6835d8cd00a20b8d4fa
上级
5ef1ac75
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
339 addition
and
40 deletion
+339
-40
lite/pylite/megenginelite/network.py
lite/pylite/megenginelite/network.py
+339
-40
未找到文件。
lite/pylite/megenginelite/network.py
浏览文件 @
a891f9b3
...
@@ -11,7 +11,82 @@ from .tensor import *
...
@@ -11,7 +11,82 @@ from .tensor import *
class
LiteOptions
(
Structure
):
class
LiteOptions
(
Structure
):
"""
"""
the inference options will be used to config a network
the inference options which can optimize the network forwarding
performance
Attributes:
weight_preprocess: is the option which optimize the inference performance
with processing the weights of the network ahead
fuse_preprocess: fuse preprocess patten, like astype + pad_channel +
dimshuffle
fake_next_exec: whether only to perform non-computing tasks (like
memory allocation and queue initialization) for next exec. This will be
reset to false when the graph is executed.
var_sanity_check_first_run: Disable var sanity check on the first run.
Var sanity check is enabled on the first-time execution by default, and can
be used to find some potential memory access errors in the operator
const_shape: used to reduce memory usage and improve performance since some
static inference data structures can be omitted and some operators can be
compute before forwarding
force_dynamic_alloc: force dynamic allocate memory for all vars
force_output_dynamic_alloc: force dynamic allocate memory for output tensor
which are used as the input of CallbackCaller Operator
no_profiling_on_shape_change: do not re-profile to select best implement
algo when input shape changes (use previous algo)
jit_level: Execute supported operators with JIT (support MLIR,
NVRTC). Can only be used on Nvidia GPUs and X86 CPU, this value indicates JIT level:
level 1: for JIT execute with basic elemwise operator
level 2: for JIT execute elemwise and reduce operators
record_level: flags to optimize the inference performance with record the
kernel tasks in first run, hereafter the inference all need is to execute the
recorded tasks.
level = 0 means the normal inference
level = 1 means use record inference
level = 2 means record inference with free the extra memory
graph_opt_level: network optimization level:
0: disable
1: level-1: inplace arith transformations during graph construction
2: level-2: level-1, plus global optimization before graph compiling
3: also enable JIT
async_exec_level: level of dispatch on separate threads for different comp_node.
0: do not perform async dispatch
1: dispatch async if there are more than one comp node with limited queue
mask 0b10: async if there are multiple comp nodes with
mask 0b100: always async
Examples:
.. code-block::
from megenginelite import *
options = LiteOptions()
options.weight_preprocess = true
options.record_level = 1
options.fuse_preprocess = true
"""
"""
_fields_
=
[
_fields_
=
[
...
@@ -39,6 +114,7 @@ class LiteOptions(Structure):
...
@@ -39,6 +114,7 @@ class LiteOptions(Structure):
]
]
def
__init__
(
self
):
def
__init__
(
self
):
self
.
weight_preprocess
=
False
self
.
weight_preprocess
=
False
self
.
fuse_preprocess
=
False
self
.
fuse_preprocess
=
False
self
.
fake_next_exec
=
False
self
.
fake_next_exec
=
False
...
@@ -76,17 +152,34 @@ class LiteOptions(Structure):
...
@@ -76,17 +152,34 @@ class LiteOptions(Structure):
class
LiteConfig
(
Structure
):
class
LiteConfig
(
Structure
):
"""
"""
Configuration when load and compile the graph
Configuration when load and compile a network
Attributes:
has_compression: flag whether the model is compressed, the compress
method is stored in the model
bare_model_cryption_name: is the bare model cryption method name, bare
device_id: configure the device id of a network
model is not pack model info inside
use_loader_dynamic_param: when model forward with device loader of npu,
device_type: configure the device type of a network
use_loader_dynamic_param used to flag whether the loader use device input or
output, if use device input or output it will set Non-zero , else set zero
has_compression: flag whether the model is compressed, the compress
backend: configure the inference backend of a network, now only support
method will used to read the model
megengine
bare_model_cryption_name: is the bare model encryption method name, bare
model is not packed with json information, this encryption method name is
useful to decrypt the encrypted bare model
options: configuration of Options
Examples:
.. code-block::
from megenginelite import *
config = LiteConfig()
config.has_compression = false
config.device_type = LiteDeviceType.LITE_CPU
config.backend = LiteBackend.LITE_DEFAULT
config.bare_model_cryption_name = "AES_default".encode("utf-8")
"""
"""
_fields_
=
[
_fields_
=
[
...
@@ -161,23 +254,43 @@ class LiteExtraConfig(Structure):
...
@@ -161,23 +254,43 @@ class LiteExtraConfig(Structure):
class
LiteIO
(
Structure
):
class
LiteIO
(
Structure
):
"""
"""
config the network input and output item
config the network input and output item, the input and output tensor
information will describe there
Attributes:
name: the tensor name in the graph corresponding to the IO
is_host: Used to mark where the input tensor comes from and where the output
tensor will copy to, if is_host is true, the input is from host and output copy
to host, otherwise in device. Sometimes the input is from device and output no need
copy to host, default is true.
io_type: The IO type, it can be SHAPE or VALUE, when SHAPE is set, the input or
output tensor value is invaid, only shape will be set, default is VALUE
config_layout: The layout of the config from user, if other layout is set before
forward or get after forward, this layout will by pass. if no other
layout is set before forward, this layout will work. if this layout is
no set, the model will forward with its origin layout. if in output, it
will used to check.
Note:
if other layout is set to input tensor before forwarding, this layout will not work
name: the tensor name in the graph corresponding to the IO
if no layout is set before forwarding, the model will forward with its origin layout
is_host: Used to mark where the input tensor comes from and the output where copy
if layout is set in output tensor, it will used to check whether the layout computed from the network is correct
to, if is_host is true, the input is from host and output copy to host,
otherwise device. Sometimes The input is from device and output no need
copy to host, default is true.
io_type: The IO type, it can be SHAPE or VALUE, when SHAPE is set, the input or
Examples:
output tensor value is invaid, only shape will be set, default is VALUE
.. code-block::
from megenginelite import *
io = LiteIO(
"data2",
is_host=True,
io_type=LiteIOType.LITE_IO_SHAPE,
layout=LiteLayout([2, 4, 4]),
)
config_layout: The layout of the config from user, if other layout is set before
forward or get after forward, this layout will by pass. if no other
layout is set before forward, this layout will work. if this layout is
no set, the model will forward with its origin layout. if in output, it
will used to check.
"""
"""
_fields_
=
[
_fields_
=
[
...
@@ -205,10 +318,16 @@ class LiteIO(Structure):
...
@@ -205,10 +318,16 @@ class LiteIO(Structure):
@
property
@
property
def
name
(
self
):
def
name
(
self
):
"""
get the name of IO item
"""
return
self
.
_name
.
decode
(
"utf-8"
)
return
self
.
_name
.
decode
(
"utf-8"
)
@
name
.
setter
@
name
.
setter
def
name
(
self
,
name
):
def
name
(
self
,
name
):
"""
set the name of IO item
"""
if
isinstance
(
name
,
str
):
if
isinstance
(
name
,
str
):
self
.
_name
=
name
.
encode
(
"utf-8"
)
self
.
_name
=
name
.
encode
(
"utf-8"
)
else
:
else
:
...
@@ -229,9 +348,6 @@ class LiteIO(Structure):
...
@@ -229,9 +348,6 @@ class LiteIO(Structure):
class
_LiteNetworkIO
(
Structure
):
class
_LiteNetworkIO
(
Structure
):
"""
the input and output information when load the network
"""
_fields_
=
[
_fields_
=
[
(
"inputs"
,
POINTER
(
LiteIO
)),
(
"inputs"
,
POINTER
(
LiteIO
)),
...
@@ -249,7 +365,24 @@ class _LiteNetworkIO(Structure):
...
@@ -249,7 +365,24 @@ class _LiteNetworkIO(Structure):
class
LiteNetworkIO
(
object
):
class
LiteNetworkIO
(
object
):
"""
"""
the input and output information for user to construct _LiteNetWorkIO
the input and output information when load the network for user
the NetworkIO will remain in the network until the network is destroyed.
Attributes:
inputs: The all input tensors information that will configure to the network
outputs: The all output tensors information that will configure to the network
Examples:
.. code-block::
from megenginelite import *
input_io = LiteIO("data", is_host=False, io_type=LiteIOType.LITE_IO_VALUE)
io = LiteNetworkIO()
io.add_input(input_io)
output_io = LiteIO("out", is_host=True, layout=LiteLayout([1, 1000]))
io.add_output(output_io)
"""
"""
def
__init__
(
self
,
inputs
=
None
,
outputs
=
None
):
def
__init__
(
self
,
inputs
=
None
,
outputs
=
None
):
...
@@ -277,6 +410,9 @@ class LiteNetworkIO(object):
...
@@ -277,6 +410,9 @@ class LiteNetworkIO(object):
def
add_input
(
def
add_input
(
self
,
obj
,
is_host
=
True
,
io_type
=
LiteIOType
.
LITE_IO_VALUE
,
layout
=
None
self
,
obj
,
is_host
=
True
,
io_type
=
LiteIOType
.
LITE_IO_VALUE
,
layout
=
None
):
):
"""
add input information into LiteNetworkIO
"""
if
isinstance
(
obj
,
LiteIO
):
if
isinstance
(
obj
,
LiteIO
):
self
.
inputs
.
append
(
obj
)
self
.
inputs
.
append
(
obj
)
else
:
else
:
...
@@ -286,6 +422,9 @@ class LiteNetworkIO(object):
...
@@ -286,6 +422,9 @@ class LiteNetworkIO(object):
def
add_output
(
def
add_output
(
self
,
obj
,
is_host
=
True
,
io_type
=
LiteIOType
.
LITE_IO_VALUE
,
layout
=
None
self
,
obj
,
is_host
=
True
,
io_type
=
LiteIOType
.
LITE_IO_VALUE
,
layout
=
None
):
):
"""
add output information into LiteNetworkIO
"""
if
isinstance
(
obj
,
LiteIO
):
if
isinstance
(
obj
,
LiteIO
):
self
.
outputs
.
append
(
obj
)
self
.
outputs
.
append
(
obj
)
else
:
else
:
...
@@ -397,6 +536,27 @@ class _NetworkAPI(_LiteCObjBase):
...
@@ -397,6 +536,27 @@ class _NetworkAPI(_LiteCObjBase):
class
LiteNetwork
(
object
):
class
LiteNetwork
(
object
):
"""
"""
the network to load a model and forward
the network to load a model and forward
Examples:
.. code-block::
from megenginelite import *
config = LiteConfig()
config.device_type = LiteDeviceType.LITE_CPU
network = LiteNetwork(config)
network.load("model_path")
input_name = network.get_input_name(0)
input_tensor = network.get_io_tensor(input_name)
output_name = network.get_output_name(0)
output_tensor = network.get_io_tensor(output_name)
input_tensor.set_data_by_copy(input_data)
network.forward()
network.wait()
"""
"""
_api
=
_NetworkAPI
().
_lib
_api
=
_NetworkAPI
().
_lib
...
@@ -428,18 +588,33 @@ class LiteNetwork(object):
...
@@ -428,18 +588,33 @@ class LiteNetwork(object):
self
.
_api
.
LITE_destroy_network
(
self
.
_network
)
self
.
_api
.
LITE_destroy_network
(
self
.
_network
)
def
load
(
self
,
path
):
def
load
(
self
,
path
):
"""
load network from given path
"""
c_path
=
c_char_p
(
path
.
encode
(
"utf-8"
))
c_path
=
c_char_p
(
path
.
encode
(
"utf-8"
))
self
.
_api
.
LITE_load_model_from_path
(
self
.
_network
,
c_path
)
self
.
_api
.
LITE_load_model_from_path
(
self
.
_network
,
c_path
)
def
forward
(
self
):
def
forward
(
self
):
"""
forward the network with filled input data and fill the output data
to the output tensor
"""
self
.
_api
.
LITE_forward
(
self
.
_network
)
self
.
_api
.
LITE_forward
(
self
.
_network
)
def
wait
(
self
):
def
wait
(
self
):
"""
wait until forward finish in sync model
"""
self
.
_api
.
LITE_wait
(
self
.
_network
)
self
.
_api
.
LITE_wait
(
self
.
_network
)
def
is_cpu_inplace_mode
(
self
):
def
is_cpu_inplace_mode
(
self
):
"""
"""
whether the network run in cpu inpalce mode
whether the network run in cpu inpalce mode
Returns:
if use inpalce mode return True, else return False
"""
"""
inplace
=
c_int
()
inplace
=
c_int
()
self
.
_api
.
LITE_is_cpu_inplace_mode
(
self
.
_network
,
byref
(
inplace
))
self
.
_api
.
LITE_is_cpu_inplace_mode
(
self
.
_network
,
byref
(
inplace
))
...
@@ -449,13 +624,20 @@ class LiteNetwork(object):
...
@@ -449,13 +624,20 @@ class LiteNetwork(object):
"""
"""
set cpu forward in inplace mode with which cpu forward only create one
set cpu forward in inplace mode with which cpu forward only create one
thread
thread
Note: this must be set before the network loaded
Note:
this must be set before the network loaded
"""
"""
self
.
_api
.
LITE_set_cpu_inplace_mode
(
self
.
_network
)
self
.
_api
.
LITE_set_cpu_inplace_mode
(
self
.
_network
)
def
use_tensorrt
(
self
):
def
use_tensorrt
(
self
):
"""
"""
Note: this must be set before the network loaded
use TensorRT
Note:
this must be set before the network loaded
"""
"""
self
.
_api
.
LITE_use_tensorrt
(
self
.
_network
)
self
.
_api
.
LITE_use_tensorrt
(
self
.
_network
)
...
@@ -463,6 +645,9 @@ class LiteNetwork(object):
...
@@ -463,6 +645,9 @@ class LiteNetwork(object):
def
device_id
(
self
):
def
device_id
(
self
):
"""
"""
get the device id
get the device id
Returns:
the device id of current network used
"""
"""
device_id
=
c_int
()
device_id
=
c_int
()
self
.
_api
.
LITE_get_device_id
(
self
.
_network
,
byref
(
device_id
))
self
.
_api
.
LITE_get_device_id
(
self
.
_network
,
byref
(
device_id
))
...
@@ -472,7 +657,10 @@ class LiteNetwork(object):
...
@@ -472,7 +657,10 @@ class LiteNetwork(object):
def
device_id
(
self
,
device_id
):
def
device_id
(
self
,
device_id
):
"""
"""
set the device id
set the device id
Note: this must be set before the network loaded
Note:
this must be set before the network loaded
"""
"""
self
.
_api
.
LITE_set_device_id
(
self
.
_network
,
device_id
)
self
.
_api
.
LITE_set_device_id
(
self
.
_network
,
device_id
)
...
@@ -480,6 +668,9 @@ class LiteNetwork(object):
...
@@ -480,6 +668,9 @@ class LiteNetwork(object):
def
stream_id
(
self
):
def
stream_id
(
self
):
"""
"""
get the stream id
get the stream id
Returns:
the value of stream id set for detwork
"""
"""
stream_id
=
c_int
()
stream_id
=
c_int
()
self
.
_api
.
LITE_get_stream_id
(
self
.
_network
,
byref
(
stream_id
))
self
.
_api
.
LITE_get_stream_id
(
self
.
_network
,
byref
(
stream_id
))
...
@@ -489,7 +680,9 @@ class LiteNetwork(object):
...
@@ -489,7 +680,9 @@ class LiteNetwork(object):
def
stream_id
(
self
,
stream_id
):
def
stream_id
(
self
,
stream_id
):
"""
"""
set the stream id
set the stream id
Note: this must be set before the network loaded
Note:
this must be set before the network loaded
"""
"""
self
.
_api
.
LITE_set_stream_id
(
self
.
_network
,
stream_id
)
self
.
_api
.
LITE_set_stream_id
(
self
.
_network
,
stream_id
)
...
@@ -497,6 +690,9 @@ class LiteNetwork(object):
...
@@ -497,6 +690,9 @@ class LiteNetwork(object):
def
threads_number
(
self
):
def
threads_number
(
self
):
"""
"""
get the thread number of the netwrok
get the thread number of the netwrok
Returns:
the number of thread set in the network
"""
"""
nr_thread
=
c_size_t
()
nr_thread
=
c_size_t
()
self
.
_api
.
LITE_get_cpu_threads_number
(
self
.
_network
,
byref
(
nr_thread
))
self
.
_api
.
LITE_get_cpu_threads_number
(
self
.
_network
,
byref
(
nr_thread
))
...
@@ -506,13 +702,22 @@ class LiteNetwork(object):
...
@@ -506,13 +702,22 @@ class LiteNetwork(object):
def
threads_number
(
self
,
nr_threads
):
def
threads_number
(
self
,
nr_threads
):
"""
"""
set the network forward in multithread mode, and the thread number
set the network forward in multithread mode, and the thread number
Note: this must be set before the network loaded
Note:
this must be set before the network loaded
"""
"""
self
.
_api
.
LITE_set_cpu_threads_number
(
self
.
_network
,
nr_threads
)
self
.
_api
.
LITE_set_cpu_threads_number
(
self
.
_network
,
nr_threads
)
def
get_io_tensor
(
self
,
name
,
phase
=
LiteTensorPhase
.
LITE_IO
):
def
get_io_tensor
(
self
,
name
,
phase
=
LiteTensorPhase
.
LITE_IO
):
"""
"""
get input or output tensor by its name
get input or output tensor by its name
Args:
name: the name of io tensor
phase: the type of LiteTensor, this is useful to separate input or output tensor with the same name
Returns:
the tensor with given name and type
"""
"""
if
type
(
name
)
==
str
:
if
type
(
name
)
==
str
:
c_name
=
c_char_p
(
name
.
encode
(
"utf-8"
))
c_name
=
c_char_p
(
name
.
encode
(
"utf-8"
))
...
@@ -528,6 +733,12 @@ class LiteNetwork(object):
...
@@ -528,6 +733,12 @@ class LiteNetwork(object):
def
get_input_name
(
self
,
index
):
def
get_input_name
(
self
,
index
):
"""
"""
get the input name by the index in the network
get the input name by the index in the network
Args:
index: the index of the input name
Returns:
the name of input tesor with given index
"""
"""
c_name
=
c_char_p
()
c_name
=
c_char_p
()
self
.
_api
.
LITE_get_input_name
(
self
.
_network
,
index
,
byref
(
c_name
))
self
.
_api
.
LITE_get_input_name
(
self
.
_network
,
index
,
byref
(
c_name
))
...
@@ -536,6 +747,12 @@ class LiteNetwork(object):
...
@@ -536,6 +747,12 @@ class LiteNetwork(object):
def
get_output_name
(
self
,
index
):
def
get_output_name
(
self
,
index
):
"""
"""
get the output name by the index in the network
get the output name by the index in the network
Args:
index: the index of the output name
Returns:
the name of output tesor with given index
"""
"""
c_name
=
c_char_p
()
c_name
=
c_char_p
()
self
.
_api
.
LITE_get_output_name
(
self
.
_network
,
index
,
byref
(
c_name
))
self
.
_api
.
LITE_get_output_name
(
self
.
_network
,
index
,
byref
(
c_name
))
...
@@ -544,6 +761,9 @@ class LiteNetwork(object):
...
@@ -544,6 +761,9 @@ class LiteNetwork(object):
def
get_all_input_name
(
self
):
def
get_all_input_name
(
self
):
"""
"""
get all the input tensor name in the network
get all the input tensor name in the network
Returns:
the names of all input tesor in the network
"""
"""
nr_input
=
c_size_t
()
nr_input
=
c_size_t
()
self
.
_api
.
LITE_get_all_input_name
(
self
.
_network
,
byref
(
nr_input
),
None
)
self
.
_api
.
LITE_get_all_input_name
(
self
.
_network
,
byref
(
nr_input
),
None
)
...
@@ -557,6 +777,9 @@ class LiteNetwork(object):
...
@@ -557,6 +777,9 @@ class LiteNetwork(object):
def
get_all_output_name
(
self
):
def
get_all_output_name
(
self
):
"""
"""
get all the output tensor name in the network
get all the output tensor name in the network
Returns:
the names of all output tesor in the network
"""
"""
nr_output
=
c_size_t
()
nr_output
=
c_size_t
()
self
.
_api
.
LITE_get_all_output_name
(
self
.
_network
,
byref
(
nr_output
),
None
)
self
.
_api
.
LITE_get_all_output_name
(
self
.
_network
,
byref
(
nr_output
),
None
)
...
@@ -576,6 +799,9 @@ class LiteNetwork(object):
...
@@ -576,6 +799,9 @@ class LiteNetwork(object):
def
share_weights_with
(
self
,
src_network
):
def
share_weights_with
(
self
,
src_network
):
"""
"""
share weights with the loaded network
share weights with the loaded network
Args:
src_network: the network to share weights
"""
"""
assert
isinstance
(
src_network
,
LiteNetwork
)
assert
isinstance
(
src_network
,
LiteNetwork
)
self
.
_api
.
LITE_shared_weight_with_network
(
self
.
_network
,
src_network
.
_network
)
self
.
_api
.
LITE_shared_weight_with_network
(
self
.
_network
,
src_network
.
_network
)
...
@@ -583,11 +809,21 @@ class LiteNetwork(object):
...
@@ -583,11 +809,21 @@ class LiteNetwork(object):
def
share_runtime_memroy
(
self
,
src_network
):
def
share_runtime_memroy
(
self
,
src_network
):
"""
"""
share runtime memory with the srouce network
share runtime memory with the srouce network
Args:
src_network: the network to share runtime memory
"""
"""
assert
isinstance
(
src_network
,
LiteNetwork
)
assert
isinstance
(
src_network
,
LiteNetwork
)
self
.
_api
.
LITE_share_runtime_memroy
(
self
.
_network
,
src_network
.
_network
)
self
.
_api
.
LITE_share_runtime_memroy
(
self
.
_network
,
src_network
.
_network
)
def
async_with_callback
(
self
,
async_callback
):
def
async_with_callback
(
self
,
async_callback
):
"""
set the network forwarding in async mode and set the AsyncCallback callback
function
Args:
async_callback: the callback to set for network
"""
callback
=
wrap_async_callback
(
async_callback
)
callback
=
wrap_async_callback
(
async_callback
)
self
.
_api
.
LITE_set_async_callback
(
self
.
_network
,
callback
)
self
.
_api
.
LITE_set_async_callback
(
self
.
_network
,
callback
)
...
@@ -596,6 +832,9 @@ class LiteNetwork(object):
...
@@ -596,6 +832,9 @@ class LiteNetwork(object):
when the network start forward, the callback will be called,
when the network start forward, the callback will be called,
the start_callback with param mapping from LiteIO to the corresponding
the start_callback with param mapping from LiteIO to the corresponding
LiteTensor
LiteTensor
Args:
start_callback: the callback to set for network
"""
"""
callback
=
start_finish_callback
(
start_callback
)
callback
=
start_finish_callback
(
start_callback
)
self
.
_api
.
LITE_set_start_callback
(
self
.
_network
,
callback
)
self
.
_api
.
LITE_set_start_callback
(
self
.
_network
,
callback
)
...
@@ -605,28 +844,49 @@ class LiteNetwork(object):
...
@@ -605,28 +844,49 @@ class LiteNetwork(object):
when the network finish forward, the callback will be called,
when the network finish forward, the callback will be called,
the finish_callback with param mapping from LiteIO to the corresponding
the finish_callback with param mapping from LiteIO to the corresponding
LiteTensor
LiteTensor
Args:
finish_callback: the callback to set for network
"""
"""
callback
=
start_finish_callback
(
finish_callback
)
callback
=
start_finish_callback
(
finish_callback
)
self
.
_api
.
LITE_set_finish_callback
(
self
.
_network
,
callback
)
self
.
_api
.
LITE_set_finish_callback
(
self
.
_network
,
callback
)
def
enable_profile_performance
(
self
,
profile_file
):
def
enable_profile_performance
(
self
,
profile_file
):
"""
enable get the network performance profiled information and save into given file
Args:
profile_file: the file to save profile information
"""
c_file
=
profile_file
.
encode
(
"utf-8"
)
c_file
=
profile_file
.
encode
(
"utf-8"
)
self
.
_api
.
LITE_enable_profile_performance
(
self
.
_network
,
c_file
)
self
.
_api
.
LITE_enable_profile_performance
(
self
.
_network
,
c_file
)
def
set_network_algo_workspace_limit
(
self
,
size_limit
):
def
set_network_algo_workspace_limit
(
self
,
size_limit
):
"""
set the opr workspace limitation in the target network, some opr
maybe use large of workspace to get good performance, set workspace limitation
can save memory but may influence the performance
Args:
size_limit: the byte size of workspace limitation
"""
self
.
_api
.
LITE_set_network_algo_workspace_limit
(
self
.
_network
,
size_limit
)
self
.
_api
.
LITE_set_network_algo_workspace_limit
(
self
.
_network
,
size_limit
)
def
set_network_algo_policy
(
def
set_network_algo_policy
(
self
,
policy
,
shared_batch_size
=
0
,
binary_equal_between_batch
=
False
self
,
policy
,
shared_batch_size
=
0
,
binary_equal_between_batch
=
False
):
):
"""
"""
shared_batch_size: the batch size used by fastrun,
set the network algorithm search policy for fast-run
Non-zero value means that fastrun use this batch size
regardless of the batch size of the model. Zero means
Args:
fastrun use batch size of the model
shared_batch_size: the batch size used by fastrun,
binary_equal_between_batch: if the content of each input batch is
Non-zero value means that fastrun use this batch size
binary equal,whether the content of each output batch is
regardless of the batch size of the model. Zero means
promised to be equal
fastrun use batch size of the model
binary_equal_between_batch: if the content of each input batch is
binary equal,whether the content of each output batch is
promised to be equal
"""
"""
self
.
_api
.
LITE_set_network_algo_policy
(
self
.
_network
,
policy
)
self
.
_api
.
LITE_set_network_algo_policy
(
self
.
_network
,
policy
)
...
@@ -635,29 +895,68 @@ class LiteNetwork(object):
...
@@ -635,29 +895,68 @@ class LiteNetwork(object):
)
)
def
io_txt_dump
(
self
,
txt_file
):
def
io_txt_dump
(
self
,
txt_file
):
"""
dump all input/output tensor of all operators to the output file, in txt
format, user can use this function to debug compute error
Args:
txt_file: the txt file
"""
c_file
=
txt_file
.
encode
(
"utf-8"
)
c_file
=
txt_file
.
encode
(
"utf-8"
)
self
.
_api
.
LITE_enable_io_txt_dump
(
self
.
_network
,
c_file
)
self
.
_api
.
LITE_enable_io_txt_dump
(
self
.
_network
,
c_file
)
def
io_bin_dump
(
self
,
bin_dir
):
def
io_bin_dump
(
self
,
bin_dir
):
"""
dump all input/output tensor of all operators to the output file, in
binary format, user can use this function to debug compute error
Args:
bin_dir: the binary file directory
"""
c_dir
=
bin_dir
.
encode
(
"utf-8"
)
c_dir
=
bin_dir
.
encode
(
"utf-8"
)
self
.
_api
.
LITE_enable_io_bin_dump
(
self
.
_network
,
c_dir
)
self
.
_api
.
LITE_enable_io_bin_dump
(
self
.
_network
,
c_dir
)
def
get_static_memory_alloc_info
(
self
,
log_dir
=
"logs/test"
):
def
get_static_memory_alloc_info
(
self
,
log_dir
=
"logs/test"
):
"""
get static peak memory info showed by Graph visualization
Args:
log_dir: the directory to save information log
"""
c_log_dir
=
log_dir
.
encode
(
"utf-8"
)
c_log_dir
=
log_dir
.
encode
(
"utf-8"
)
self
.
_api
.
LITE_get_static_memory_alloc_info
(
self
.
_network
,
c_log_dir
)
self
.
_api
.
LITE_get_static_memory_alloc_info
(
self
.
_network
,
c_log_dir
)
def
enable_global_layout_transform
(
self
):
def
enable_global_layout_transform
(
self
):
"""
set global layout transform optimization for network, global
layout optimization can auto determine the layout of every operator in
the network by profile, thus it can improve the performance of the
network forwarding
"""
self
.
_api
.
LITE_enable_global_layout_transform
(
self
.
_network
)
self
.
_api
.
LITE_enable_global_layout_transform
(
self
.
_network
)
def
dump_layout_transform_model
(
self
,
model_file
):
def
dump_layout_transform_model
(
self
,
model_file
):
"""
dump network after global layout transform optimization to the
specific path
Args:
model_file: the file path to dump model
"""
c_file
=
model_file
.
encode
(
"utf-8"
)
c_file
=
model_file
.
encode
(
"utf-8"
)
self
.
_api
.
LITE_dump_layout_transform_model
(
self
.
_network
,
c_file
)
self
.
_api
.
LITE_dump_layout_transform_model
(
self
.
_network
,
c_file
)
def
get_model_io_info
(
model_path
,
config
=
None
):
def
get_model_io_info
(
model_path
,
config
=
None
):
"""
"""
get the model IO information before create the NetWork, this IO
get the model io information before model loaded by model path.
information can be used to configuration the NetWork.
Args:
model_path: the model path to get the model IO information
config the model configuration
Returns:
the input and output information in the network configuration
"""
"""
api
=
_NetworkAPI
().
_lib
api
=
_NetworkAPI
().
_lib
c_path
=
c_char_p
(
model_path
.
encode
(
"utf-8"
))
c_path
=
c_char_p
(
model_path
.
encode
(
"utf-8"
))
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录