MegEngine 天元 / MegEngine
Commit 5ef1ac75
Authored May 26, 2022 by Megvii Engine Team
docs(api/lite): add lite network api doc
GitOrigin-RevId: 5d416cc5af9595240dbf71bdc819d989ec2d5dbc
Parent: c47f48ef
Showing 1 changed file with 256 additions and 111 deletions

lite/include/lite/network.h (+256, -111)
...
@@ -18,56 +18,56 @@ LITE_API inline LiteAlgoSelectStrategy operator|(
 }

 /*!
- * \brief the inference options which will be translated to megenine
+ * @brief the inference options which can optimize the network forwarding
+ * performance
  *
- * \param weight_preprocess is the option wich optimize the inferece performance
- * with preprocess the const weights
+ * @param weight_preprocess is the option which optimize the inference performance
+ * with processing the weights of the network ahead
  *
- * \param fuse_preprocess fuse preprocess patten, like astype + pad_channel +
+ * @param fuse_preprocess fuse preprocess patten, like astype + pad_channel +
  * dimshuffle
  *
- * \param fake_next_exec whether only to perform non-computing tasks (like
- * memory allocation and queue initialization) for next exec. This would be
+ * @param fake_next_exec whether only to perform non-computing tasks (like
+ * memory allocation and queue initialization) for next exec. This will be
  * reset to false when the graph is executed.
  *
- * \param var_sanity_check_first_run Disable var sanity check on the first run.
+ * @param var_sanity_check_first_run Disable var sanity check on the first run.
  * Var sanity check is enabled on the first-time execution by default, and can
  * be used to find some potential memory access errors in the operator
- * implementation.
  *
- * \param const_shape This can be used to reduce memory usage since some
- * static inference data structures can be omitted.
+ * @param const_shape used to reduce memory usage and improve performance since some
+ * static inference data structures can be omitted and some operators can be
+ * compute before forwarding
  *
- * \param force_dynamic_alloc force dynamic memory alloc for all vars
+ * @param force_dynamic_alloc force dynamic allocate memory for all vars
  *
- * \param force_output_dynamic_alloc force dynamic memory alloc for output vars
- * which are used as CallbackCaller input when call compile() function
+ * @param force_output_dynamic_alloc force dynamic allocate memory for output tensor
+ * which are used as the input of CallbackCaller Operator
  *
- * \param no_profiling_on_shape_change do not re-profile to select best impl
+ * @param no_profiling_on_shape_change do not re-profile to select best implement
  * algo when input shape changes (use previous algo)
  *
- * \param jit_level Execute supported operators with JIT (support MLIR,
- * NVRTC). Can only be used on Nvidia GPUs, this value indicates JIT level:
- * 1 for basic elemwise opr;
- * 2 for including reduce operator
+ * @param jit_level Execute supported operators with JIT (support MLIR,
+ * NVRTC). Can only be used on Nvidia GPUs and X86 CPU, this value indicates JIT level:
+ * level 1: for JIT execute with basic elemwise operator
+ * level 2: for JIT execute elemwise and reduce operators
  *
- * \param record_level flag optimize the inference performace with record the
- * kernel tasks in first run, hereafter the inference all need to execute the
+ * @param record_level flags to optimize the inference performance with record the
+ * kernel tasks in first run, hereafter the inference all need is to execute the
  * recorded tasks.
  * level = 0 means the normal inference,
  * level = 1 means use record inference,
  * level = 2 means record inference with free the extra memory
  *
- * \param graph_opt_level optimization level:
+ * @param graph_opt_level network optimization level:
  * 0: disable
  * 1: level-1: inplace arith transformations during graph
  * construction
  * 2: level-2: level-1, plus global optimization before graph
  * compiling
  * 3: also enable JIT
- * <0: corresponding level, with result check for debug
  *
- * \param async_exec_level exec: dispatch on separate threads for different
+ * @param async_exec_level level of dispatch on separate threads for different
  * comp_node.
  * 0: do not perform async dispatch
  * 1: dispatch async if there are more than one comp node with limited queue
...
@@ -99,14 +99,21 @@ struct LITE_API Options {
     bool enable_nchw64 = false;
 };

-/*!
- * \brief Configuration when load and compile the graph
+/**
+ * @brief Configuration when load and compile a network
+ *
+ * @param has_compression flag whether the model is compressed, the compress
+ * method is stored in the model
+ *
+ * @param device_id configure the device id of a network
+ * @param device_type configure the device type of a network
+ * @param backend configure the inference backend of a network, now only support
+ * megengine
  *
- * \param bare_model_cryption_name is the bare model cryption method name, bare
- * model is not pack json info inside
+ * @param bare_model_cryption_name is the bare model encryption method name, bare
+ * model is not pack json information data inside
  *
- *\param has_compression flag whether the model is compressed, the compress
- *method will read form the model
+ * @param options configuration of Options
  */
 struct LITE_API Config {
     bool has_compression = false;
...
@@ -118,9 +125,9 @@ struct LITE_API Config {
 };

 /*!
- * \brief Extra Configuration for a network
+ * @brief Extra Configuration for a network
  *
- * \param disable_configure_by_model_info disable the configuration dumped with model,
+ * @param disable_configure_by_model_info disable the configuration dumped with model,
  * if set true, all configuration in the model will not apply, users should configure
  * the network.
  */
...
@@ -128,90 +135,136 @@ struct LITE_API ExtraConfig {
     bool disable_configure_by_model_info = false;
 };

-/*!
- * \brief config the network input and output item
+/**
+ * @brief config the network input and output item, the input and output tensor
+ * information will describe there
+ *
+ * @param name the input/output tensor name
+ *
+ * @param is_host Used to mark where the input tensor comes from and where the output
+ * tensor will copy to, if is_host is true, the input is from host and output copy
+ * to host, otherwise in device. Sometimes the input is from device and output no need
+ * copy to host, default is true.
+ *
+ * @param io_type The IO type, it can be SHAPE or VALUE, when SHAPE is set, the input or
+ * output tensor value is invaid, only shape will be set, default is VALUE
+ *
+ * @param config_layout The layout of input or output tensor
+ *
+ *
+ * \verbatim embed:rst:leading-asterisk
+ *
+ *  .. note::
+ *
+ *     * if other layout is set to input tensor before forwarding, this layout will not
+ *       work
+ *     * if no layout is set before forwarding, the model will forward with its origin
+ *       layout
+ *     * if layout is set in output tensor, it will used to check whether the
+ *       layout computed from the network is correct
+ *
+ * \endverbatim
  */
 struct LITE_API IO {
-    //! the tensor name in the graph corresponding to the IO
     std::string name;

-    //! Used to mark where the input tensor comes from and the output where copy
-    //! to, if is_host is true, the input is from host and output copy to host,
-    //! otherwise device. Sometimes The input is from device and output no need
-    //! copy to host, default is true.
     bool is_host = true;

-    //! The IO type, it can be SHAPE or VALUE, when SHAPE is set, the input or
-    //! output tensor value is invaid, only shape will be set, default is VALUE
     LiteIOType io_type = LiteIOType::LITE_IO_VALUE;

-    //! The layout of the config from user, if other layout is set before
-    //! forward or get after forward by input tensor reset, this layout will by
-    //! pass. if no other layout is set before forward, this layout will work.
-    //! if this layout is no set, the model will forward with its origin layout.
-    //! if in output, it will used to check.
     Layout config_layout = {};
 };

-/*!
- * \brief the input and output information when load the network
- * the NetworkIO will remain in the network until the network is destroyed
+/**
+ * @brief the input and output information when load the network
+ * the NetworkIO will remain in the network until the network is destroyed.
+ *
+ * @param inputs The all input tensors information that will configure to the network
+ * @param outputs The all output tensors information that will configure to the network
  */
 struct LITE_API NetworkIO {
     std::vector<IO> inputs = {};
     std::vector<IO> outputs = {};
 };
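Note on the `IO` / `NetworkIO` documentation above: a minimal sketch of how these fields might be filled in, assuming local stand-in definitions rather than the real `lite/network.h`. The `Layout` placeholder, the enumerator values of `LiteIOType`, and the helper `make_device_io` are hypothetical and exist only to keep the example self-contained; the field names and defaults mirror the diff.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Stand-ins for the types in the diff; a real program would include the Lite
// headers instead. The enumerator values and the empty Layout are assumptions.
enum LiteIOType { LITE_IO_VALUE = 0, LITE_IO_SHAPE = 1 };
struct Layout {};

struct IO {
    std::string name;
    bool is_host = true;                 // default documented in the diff
    LiteIOType io_type = LITE_IO_VALUE;  // default documented in the diff
    Layout config_layout = {};
};

struct NetworkIO {
    std::vector<IO> inputs = {};
    std::vector<IO> outputs = {};
};

// Hypothetical helper: the input already lives on the device (is_host = false)
// and only the shape of the output is requested (LITE_IO_SHAPE).
NetworkIO make_device_io(
        const std::string& input_name, const std::string& output_name) {
    assert(!input_name.empty() && !output_name.empty());
    NetworkIO io;

    IO in;
    in.name = input_name;
    in.is_host = false;
    io.inputs.push_back(in);

    IO out;
    out.name = output_name;
    out.io_type = LITE_IO_SHAPE;
    io.outputs.push_back(out);
    return io;
}
```

With the real header, the resulting `NetworkIO` would presumably be passed to one of the `Network` constructors shown later in this diff.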

-/*!
- * \brief A user-implemented allocator interface
+/**
+ * @brief A user-implemented allocator interface, user can register an allocator
+ * to the megengine, then all the runtime memory will allocate by this allocator
  */
 class LITE_API Allocator {
 public:
     virtual ~Allocator() = default;

-    //! allocate memory of size in the given device with the given align
+    /** @brief allocate memory of size in the given device with the given align
+     *
+     * @param device_type the device type the memory will allocate from
+     * @param device_id the device id the memory will allocate from
+     * @param size the byte size of memory will be allocated
+     * @param align the align size require when allocate the memory
+     */
     virtual void* allocate(
             LiteDeviceType device_type, int device_id, size_t size, size_t align) = 0;

-    //! free the memory pointed by ptr in the given device
+    /** @brief free the memory pointed by ptr in the given device
+     *
+     * @param device_type the device type the memory will allocate from
+     * @param device_id the device id the memory will allocate from
+     * @param ptr the memory pointer to be free
+     */
     virtual void free(LiteDeviceType device_type, int device_id, void* ptr) = 0;
 };
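Note on the `Allocator` interface above: one way a user-defined allocator could look, sketched against a local copy of the interface so the example compiles on its own. The `LiteDeviceType` stand-in, the class name `CountingAllocator`, and the power-of-two assumption on `align` are mine, not something the commit states. With the real library, an instance would presumably be registered through `Runtime::set_memory_allocator`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdlib>

// Stand-ins mirroring the declarations in the diff (assumed, not the real header).
enum LiteDeviceType { LITE_CPU = 0 };

class Allocator {
public:
    virtual ~Allocator() = default;
    virtual void* allocate(
            LiteDeviceType device_type, int device_id, size_t size, size_t align) = 0;
    virtual void free(LiteDeviceType device_type, int device_id, void* ptr) = 0;
};

// Hypothetical host-only allocator: over-allocates with malloc, returns an
// aligned address, and stores the raw pointer just in front of it for free().
class CountingAllocator : public Allocator {
public:
    void* allocate(LiteDeviceType, int, size_t size, size_t align) override {
        if (align < sizeof(void*))
            align = sizeof(void*);
        assert((align & (align - 1)) == 0);  // assumption: power-of-two alignment
        void* raw = std::malloc(size + align + sizeof(void*));
        if (raw == nullptr)
            return nullptr;
        std::uintptr_t start = reinterpret_cast<std::uintptr_t>(raw) + sizeof(void*);
        std::uintptr_t aligned = (start + align - 1) & ~(std::uintptr_t(align) - 1);
        reinterpret_cast<void**>(aligned)[-1] = raw;  // remember the malloc result
        ++live_blocks_;
        return reinterpret_cast<void*>(aligned);
    }

    void free(LiteDeviceType, int, void* ptr) override {
        if (ptr == nullptr)
            return;
        std::free(reinterpret_cast<void**>(ptr)[-1]);
        --live_blocks_;
    }

    size_t live_blocks() const { return live_blocks_; }

private:
    size_t live_blocks_ = 0;
};
```

The block counter is only there to make leaks visible while debugging; a memory pool would replace the `malloc`/`free` pair with its own bookkeeping.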

-/*!
- * \brief the thread affinith callback type
- * \param thread_id thread_id is the a number begin from 0 to (nr_threads - 1),
- * thread_id of (nr_threads - 1) is the main worker thread.
+/**
+ * @brief the thread affinith callback function type
+ *
+ * @param thread_id the id of the current thread, the id is a number begin from 0 to
+ * (nr_threads - 1), thread id of (nr_threads - 1) is the main worker thread.
  */
 using ThreadAffinityCallback = std::function<void(int thread_id)>;

+/**
+ * @brief the network async callback function type
+ */
 using AsyncCallback = std::function<void(void)>;

-/*!
- * \brief the start/finish callback function
- * \param unordered_map map from the io tensor name to the pair of which is the
- * corresponding IO of user config and the realy input or output tensor.
+/**
+ * @brief the start/finish callback function type
+ *
+ * @param unordered_map map from the io tensor name to the pair of the
+ * user configuration information and the really input or output tensor.
  */
+//@{
 using StartCallback =
         std::function<void(const std::unordered_map<
                            std::string, std::pair<IO, std::shared_ptr<Tensor>>>&)>;
 using FinishCallback =
         std::function<void(const std::unordered_map<
                            std::string, std::pair<IO, std::shared_ptr<Tensor>>>&)>;
+//@}
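Note on the callback types above: a small sketch of a `StartCallback` that records which IO tensor names it was handed, assuming stub `IO` and `Tensor` types so it compiles without the Lite headers. The helper `make_name_recorder` is hypothetical, and the non-null check on the tensor pointer is an assumption about what the runtime passes in. The alias itself is copied from the diff.

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Stubs standing in for the real Lite types (assumed for illustration only).
struct IO {
    std::string name;
};
struct Tensor {};

using IOMap =
        std::unordered_map<std::string, std::pair<IO, std::shared_ptr<Tensor>>>;
using StartCallback = std::function<void(const IOMap&)>;

// Hypothetical helper: returns a callback that stores the tensor names it sees,
// sorted, because unordered_map iteration order is unspecified.
StartCallback make_name_recorder(std::vector<std::string>& seen) {
    return [&seen](const IOMap& ios) {
        seen.clear();
        for (const auto& kv : ios) {
            assert(kv.second.second != nullptr);  // assumed: tensor is valid
            seen.push_back(kv.first);
        }
        std::sort(seen.begin(), seen.end());
    };
}
```

With the real library, such a callback would presumably be installed through `Network::set_start_callback`; the vector it captures must outlive the network.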

-/*!
- * \brief The network is construct form a model, implement model load, init,
- * forward, and display some model information
+/**
+ * @brief The network is the main class to perform forwarding, which is construct form a
+ * model, and implement model load, init, forward, and display some model information
  */
 class LITE_API Network {
 public:
     class NetworkImplBase;
+    friend class NetworkHelper;

     ~Network();

+    /*! @brief Construct a network with given configuration and IO information
+     *
+     * @name Constructor
+     *
+     * @param config The configuration to create the network
+     * @param networkio The NetworkIO to describe the input and output
+     * tensor of the network
+     */
+    //@{
     Network(const Config& config = {}, const NetworkIO& networkio = {});
     Network(const NetworkIO& networkio, const Config& config = {});
+    //@}

     //! load the model form memory
     void load_model(void* model_mem, size_t size);
...
@@ -219,32 +272,37 @@ public:
     //! load the model from a model path
     void load_model(std::string model_path);

-    //! only compute the output tensor in user configured
+    //! only compute the output tensor configured by the IO information
     void compute_only_configured_output();

-    //! get the network input and output tensor, the layout of which is
-    //! sync from mge tensor, when the name of input and output tensor are the
-    //! same, use LiteTensorPhase to separate
+    /** @brief get the network input and output tensor, the layout of which is
+     * sync from megengine tensor, when the name of input and output tensor are the
+     * same, use LiteTensorPhase to separate them
+     *
+     * @param io_name the name of the tensor
+     * @param phase indicate whether the tensor is input tensor or output tensor,
+     * maybe the input tensor name is the same with the output tensor name
+     */
     std::shared_ptr<Tensor> get_io_tensor(
             std::string io_name, LiteTensorPhase phase = LiteTensorPhase::LITE_IO);

-    //! get the network input by index
+    //! get the network input tensor by index
     std::shared_ptr<Tensor> get_input_tensor(size_t index);

     //! get the network output tensor by index
     std::shared_ptr<Tensor> get_output_tensor(size_t index);

-    //! set the network forward in async mode and set the async callback
+    //! set the network forwarding in async mode and set the AsyncCallback callback
     //! function
     Network& set_async_callback(const AsyncCallback& async_callback);

-    //! set the start forward callback function, which will be execute before
-    //! forward. this can be used to check network input or dump model inputs
-    //! for debug
+    //! set the start forwarding callback function of type StartCallback, which will be
+    //! execute before forward. this can be used to check network input or dump model
+    //! inputs for debug
     Network& set_start_callback(const StartCallback& start_callback);

-    //! set the finish forward callback function, which will be execute after
-    //! forward. this can be used to dump model outputs for debug
+    //! set the finish forwarding callback function of type FinishCallback, which will
+    //! be execute after forward. this can be used to dump model outputs for debug
     Network& set_finish_callback(const FinishCallback& finish_callback);

     //! forward the network with filled input data and fill the output data
...
@@ -254,33 +312,37 @@ public:
     //! waite until forward finish in sync model
     void wait();

-    //! get the input tensor name in the order in load return
+    //! get the input tensor name by index
     std::string get_input_name(size_t index) const;

-    //! get the output tensor name in the order in load return
+    //! get the output tensor name by index
     std::string get_output_name(size_t index) const;

-    //! get all the input tensor name in the order in load return
+    //! get all the input tensor names
     std::vector<std::string> get_all_input_name() const;

-    //! get all the output tensor name in the order in load return
+    //! get all the output tensor names
     std::vector<std::string> get_all_output_name() const;

-    //! set/get device id, default device id = 0
+    //! set the network forwarding device id, default device id = 0
     Network& set_device_id(int device_id);
+
+    //! get the network forwarding device id
     int get_device_id() const;

-    //! set/get stream id, default stream id = 0
+    //! set the network stream id, default stream id = 0
     Network& set_stream_id(int stream_id);
+
+    //! get the network stream id
     int get_stream_id() const;

-    //! enable profile the network, a file will be generated
+    //! enable profile the network, a file will be generated to the given path
     void enable_profile_performance(std::string profile_file_path);

-    //! get model extra info
+    //! get model extra info, the extra information is packed into model by user
     const std::string& get_model_extra_info();

-    //! get device type
+    //! get the network device type
     LiteDeviceType get_device_type() const;

     //! get static peak memory info showed by Graph visualization
...
@@ -312,80 +374,163 @@ private:
 };

 /*********************** MGE special network function ***************/
+/*!
+ * @brief All the runtime configuration function is define in Runtime class, as
+ * a static member function
+ */
 class LITE_API Runtime {
 public:
-    //! When device is CPU, this interface will set the to be loaded model
-    //! run in multi thread mode with the given thread number.
+    /** @brief The multithread number setter and getter interface
+     * When device is CPU, this interface will set the network
+     * running in multi thread mode with the given thread number.
+     *
+     * @param dst_network the target network to set/get the thread number
+     * @param nr_threads the thread number set to the target network
+     */
+    //@{
     static void set_cpu_threads_number(
             std::shared_ptr<Network> dst_network, size_t nr_threads);
     static size_t get_cpu_threads_number(std::shared_ptr<Network> dst_network);
+    //@}

-    //! set threads affinity callback;
+    /** @brief set threads affinity callback
+     *
+     * @param dst_network the target network to set the thread affinity callback
+     * @param thread_affinity_callback the ThreadAffinityCallback callback to set the
+     * thread affinity
+     */
     static void set_runtime_thread_affinity(
             std::shared_ptr<Network> network,
             const ThreadAffinityCallback& thread_affinity_callback);

-    //! Set cpu default mode when device is CPU, in some low computation
-    //! device or single core device, this mode will get good performace
+    /** @brief Set cpu default mode when device is CPU, in some low computation
+     * device or single core device, this mode will get good performace
+     *
+     * @param dst_network the target network to set/get cpu inplace model
+     */
+    //@{
     static void set_cpu_inplace_mode(std::shared_ptr<Network> dst_network);
     static bool is_cpu_inplace_mode(std::shared_ptr<Network> dst_network);
+    //@}

-    //! Set use tensorrt forward
+    //! Set the network forwarding use tensorrt
     static void use_tensorrt(std::shared_ptr<Network> dst_network);

-    //! set opr algorithm selection strategy in the network
-    //! shared_batch_size: the batch size used by fastrun,
-    //! Non-zero value means that fastrun use this batch size
-    //! regardless of the batch size of the model. Zero means
-    //! fastrun use batch size of the model
-    //! binary_equal_between_batch: if the content of each input batch is binary
-    //! equal,whether the content of each output
-    //! batch is promised to be equal
+    /** @brief set opr algorithm selection strategy in the target network
+     *
+     * @param dst_network the target network to set the algorithm strategy
+     * @param strategy the algorithm strategy will set to the network, if multi
+     * strategy should set, use | operator can pack them together
+     * @param shared_batch_size the batch size used by fast-run, Non-zero value means
+     * that fast-run use this batch size regardless of the batch size of the model, if
+     * set to zero means fast-run use batch size of the model
+     *
+     * @param binary_equal_between_batch if set true means if the content of each input
+     * batch is binary equal, whether the content of each output batch is promised to be
+     * equal, otherwise not
+     */
     static void set_network_algo_policy(
             std::shared_ptr<Network> dst_network, LiteAlgoSelectStrategy strategy,
             uint32_t shared_batch_size = 0, bool binary_equal_between_batch = false);

-    //! set workspace_limit for oprs with multiple algorithms, set
-    //! workspace limitation can save memory but may influence the performance
+    /** @brief set the opr workspace limitation in the target network, some opr
+     * maybe use large of workspace to get good performance, set workspace limitation
+     * can save memory but may influence the performance
+     *
+     * @param dst_network the target network to set/get workspace limitation
+     * @param workspace_limit the byte size of workspace limitation
+     */
     static void set_network_algo_workspace_limit(
             std::shared_ptr<Network> dst_network, size_t workspace_limit);

-    //! set the network memroy allocator, the allocator is defined by user
+    /** @brief set the network runtime memory Allocator, the Allocator is defined by
+     * user, through this method, user can implement a memory pool for network
+     * forwarding
+     *
+     * @param dst_network the target network
+     * @param user_allocator the user defined Allocator
+     */
     static void set_memory_allocator(
             std::shared_ptr<Network> dst_network,
             std::shared_ptr<Allocator> user_allocator);

-    //! share the runtime memory with other network, the weights is not shared
+    /** @brief share the runtime memory with other network, the weights is not shared
+     *
+     * \verbatim embed:rst:leading-asterisk
+     *
+     *  .. warning::
+     *
+     *     the src network and the dst network can not execute in simultaneous
+     *
+     * \endverbatim
+     *
+     * @param dst_network the target network to share the runtime memory from
+     * src_network
+     * @param src_network the source network to shared runtime memory to dst_network
+     */
     static void share_runtime_memory_with(
             std::shared_ptr<Network> dst_network,
             std::shared_ptr<Network> src_network);

-    //! Dump input/output values of all internal variables to output
-    //! file, in txt format
+    /** @brief dump all input/output tensor of all operators to the output file, in txt
+     * format, user can use this function to debug compute error
+     *
+     * @param dst_network the target network to dump its tensors
+     * @param io_txt_out_file the txt file
+     */
     static void enable_io_txt_dump(
             std::shared_ptr<Network> dst_network, std::string io_txt_out_file);

-    //! Dump input/output values of all internal variables to output
-    //! directory, in binary format
+    /** @brief dump all input/output tensor of all operators to the output file, in
+     * binary format, user can use this function to debug compute error
+     *
+     * @param dst_network the target network to dump its tensors
+     * @param io_bin_out_dir the binary file director
+     */
     static void enable_io_bin_dump(
             std::shared_ptr<Network> dst_network, std::string io_bin_out_dir);

-    //! load a new network which will share weights with src network
+    /** @brief load a new network which will share weights with src network,
+     * this can reduce memory usage when user want to load the same model multi
+     * times
+     *
+     * @param dst_network the target network to share weights from src_network
+     * @param src_network the source network to shared weights to dst_network
+     */
     static void shared_weight_with_network(
             std::shared_ptr<Network> dst_network,
             const std::shared_ptr<Network> src_network);

-    //! set global layout transform optimization for network
+    /** @brief set global layout transform optimization for network, global
+     * layout optimization can auto determine the layout of every operator in
+     * the network by profile, thus it can improve the performance of the
+     * network forwarding
+     */
     static void enable_global_layout_transform(std::shared_ptr<Network> network);

-    //! dump network after global layout transform optimization
+    /** @brief dump network after global layout transform optimization to the
+     * specific path
+     */
     static void dump_layout_transform_model(
             std::shared_ptr<Network> network, std::string optimized_model_path);

-    //! get the model io information before model loaded by model path.
+    /** @brief get the model io information before model loaded by model path.
+     *
+     * @param model_path the model path to get the model IO information
+     * @param config the model configuration
+     *
+     * @return the model NetworkIO information
+     */
     static NetworkIO get_model_io_info(
             const std::string& model_path, const Config& config = {});

-    //! get the model io information before model loaded by model memory.
+    /** @brief get the model io information before model loaded by model memory.
+     *
+     * @param model_mem the model memory to get the model IO information
+     * @param size model memory size in byte
+     * @param config the model configuration
+     *
+     * @return the model NetworkIO information
+     */
     static NetworkIO get_model_io_info(
             const void* model_mem, size_t size, const Config& config = {});
 };
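Note on `set_network_algo_policy` above: the new comment says several strategies can be packed together with the `|` operator. A minimal sketch of that flag-packing idea follows, assuming a local stand-in enum. The enumerator values are illustrative assumptions and the helper `has_strategy` is hypothetical; only the `operator|` signature is visible in this diff (in the first hunk header).

```cpp
#include <cassert>
#include <cstdint>

// Stand-in for the real enum; the bit values are assumptions for illustration.
enum LiteAlgoSelectStrategy {
    LITE_ALGO_HEURISTIC = 1 << 0,
    LITE_ALGO_PROFILE = 1 << 1,
    LITE_ALGO_REPRODUCIBLE = 1 << 2,
    LITE_ALGO_OPTIMIZED = 1 << 3,
};

// Pack two strategies into one value by OR-ing their bits.
inline LiteAlgoSelectStrategy operator|(
        LiteAlgoSelectStrategy x, LiteAlgoSelectStrategy y) {
    return static_cast<LiteAlgoSelectStrategy>(
            static_cast<uint32_t>(x) | static_cast<uint32_t>(y));
}

// Hypothetical helper: test whether a packed value contains a given flag.
inline bool has_strategy(LiteAlgoSelectStrategy packed, LiteAlgoSelectStrategy flag) {
    assert(static_cast<uint32_t>(flag) != 0);
    return (static_cast<uint32_t>(packed) & static_cast<uint32_t>(flag)) != 0;
}
```

A packed value such as `LITE_ALGO_PROFILE | LITE_ALGO_REPRODUCIBLE` would then be passed as the `strategy` argument.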
...