Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
3cd54dd6
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
3cd54dd6
编写于
5月 24, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
docs(api/lite): add doc for lite tensor
GitOrigin-RevId: ae3799527311dd573c2540894b00064b98ec87de
上级
a891f9b3
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
401 addition
and
104 deletion
+401
-104
lite/include/lite/network.h
lite/include/lite/network.h
+0
-1
lite/include/lite/tensor.h
lite/include/lite/tensor.h
+313
-73
lite/pylite/megenginelite/tensor.py
lite/pylite/megenginelite/tensor.py
+88
-30
未找到文件。
lite/include/lite/network.h
浏览文件 @
3cd54dd6
...
@@ -135,7 +135,6 @@ struct LITE_API ExtraConfig {
...
@@ -135,7 +135,6 @@ struct LITE_API ExtraConfig {
bool
disable_configure_by_model_info
=
false
;
bool
disable_configure_by_model_info
=
false
;
};
};
/**
/**
* @brief config the network input and output item, the input and output tensor
* @brief config the network input and output item, the input and output tensor
* information will describe there
* information will describe there
...
...
lite/include/lite/tensor.h
浏览文件 @
3cd54dd6
...
@@ -9,24 +9,38 @@
...
@@ -9,24 +9,38 @@
namespace
lite
{
namespace
lite
{
/*!
/**
* \brief the simple layout description
* @struct Layout
*
* @brief Description of the way of data organized in a tensor
*/
*/
struct
LITE_API
Layout
{
struct
LITE_API
Layout
{
static
constexpr
uint32_t
MAXDIM
=
7
;
static
constexpr
uint32_t
MAXDIM
=
7
;
///< max dims
size_t
shapes
[
MAXDIM
];
size_t
shapes
[
MAXDIM
];
///< shape of each dim
size_t
ndim
=
0
;
size_t
ndim
=
0
;
///< actual number of dims
LiteDataType
data_type
=
LiteDataType
::
LITE_FLOAT
;
LiteDataType
data_type
=
LiteDataType
::
LITE_FLOAT
;
///< date type
//! get the total byte of a layout
/**
* @brief get number of elements of this Layout
*
* @return number of elements
*/
size_t
get_elem_size
()
const
;
size_t
get_elem_size
()
const
;
//! compare whether the two layout is equal
/**
* @brief compare equality of two layouts
*
* @param[in] other other layout
*
* @return result of comparation
* - true this layout is equal to other
* - flase this layout is not equal to other
*/
bool
operator
==
(
const
Layout
&
other
)
const
;
bool
operator
==
(
const
Layout
&
other
)
const
;
};
};
/*
!
/*
*
*
\
brief warpper of the MegEngine Tensor
*
@
brief warpper of the MegEngine Tensor
*
*
* \verbatim embed:rst:leading-asterisk
* \verbatim embed:rst:leading-asterisk
*
*
...
@@ -59,79 +73,182 @@ public:
...
@@ -59,79 +73,182 @@ public:
/*!
/*!
* @name Constructor
* @name Constructor
*
*
* @param device_type The desired device type of created Tensor.
* @param[in] device_type The desired device type of created Tensor.
* @param device_id The desired device id of created Tensor.
* - LITE_CPU CPU Tensor
* @param is_pinned_host Whether to use pinned memory.
* - LITE_CUDA CUDA Tensor
* @param layout The desired layout of created Tensor.
* - LITE_OPENCL OpenCL Tensor
* - LITE_ATLAS Atlas Tensor
* - LITE_NPU NPU Tensor
* - LITE_CAMBRICON Cambricon Tensor
* - LITE_AX AX Tensor
* - LITE_DEVICE_DEFAULT Tensor on default device
*
* @param[in] device_id The desired device id of created Tensor.
*
* @param[in] stream_id The desired stream id of created Tensor on disired device
*
* @param[in] backend desired backend of created Tensor.
* - LITE_DEFAULT backend is MegEngine
* - LITE_RK_NPU backend is RKNN NPU
*
* @param[in] is_pinned_host Whether to use pinned memory.
* - false use nornal memory
* - true use pinned memory[main on CUDA]
*
* @param[in] layout The desired layout of created Tensor.
*
*
*/
*/
//@{
//@{
//! Default constructor
Tensor
();
Tensor
();
//! Constructor
Tensor
(
LiteDeviceType
device_type
,
bool
is_pinned_host
=
false
);
Tensor
(
LiteDeviceType
device_type
,
bool
is_pinned_host
=
false
);
//! Constructor
Tensor
(
LiteDeviceType
device_type
,
const
Layout
&
layout
,
Tensor
(
LiteDeviceType
device_type
,
const
Layout
&
layout
,
bool
is_pinned_host
=
false
);
bool
is_pinned_host
=
false
);
//! Constructor
Tensor
(
int
device_id
,
LiteDeviceType
device_type
,
const
Layout
&
layout
=
{},
Tensor
(
int
device_id
,
LiteDeviceType
device_type
,
const
Layout
&
layout
=
{},
bool
is_pinned_host
=
false
);
bool
is_pinned_host
=
false
);
//! Constructor
Tensor
(
int
device_id
,
int
stream_id
,
LiteDeviceType
device_type
,
Tensor
(
int
device_id
,
int
stream_id
,
LiteDeviceType
device_type
,
bool
is_pinned_host
=
false
);
bool
is_pinned_host
=
false
);
//! Constructor
Tensor
(
LiteBackend
backend
,
LiteDeviceType
device_type
=
LiteDeviceType
::
LITE_CPU
,
Tensor
(
LiteBackend
backend
,
LiteDeviceType
device_type
=
LiteDeviceType
::
LITE_CPU
,
int
device_id
=
0
,
const
Layout
&
layout
=
{},
bool
is_pinned_host
=
false
);
int
device_id
=
0
,
const
Layout
&
layout
=
{},
bool
is_pinned_host
=
false
);
//@}
//@}
//! Deconstructor
~
Tensor
();
~
Tensor
();
/*!
/**
* @name Getter
* @brief Get device type of this Tensor
*
* @return device type
* - LITE_CPU CPU Tensor
* - LITE_CUDA CUDA Tensor
* - LITE_OPENCL OpenCL Tensor
* - LITE_ATLAS Atlas Tensor
* - LITE_NPU NPU Tensor
* - LITE_CAMBRICON Cambricon Tensor
* - LITE_AX AX Tensor
* - LITE_DEVICE_DEFAULT Tensor on default device
*/
*/
//@{
LiteDeviceType
get_device_type
()
const
{
return
m_device_type
;
};
LiteDeviceType
get_device_type
()
const
{
return
m_device_type
;
};
//! Get device id of this Tensor
int
get_device_id
()
const
{
return
m_device_id
;
};
int
get_device_id
()
const
{
return
m_device_id
;
};
//! Get layout of this Tensor
Layout
get_layout
()
const
{
return
m_layout
;
};
Layout
get_layout
()
const
{
return
m_layout
;
};
/**
* @brief whether Tensor is on pinned memory
*
* @return whether Tensor is on pinned memory
* - false nornal memory
* - true pinned memory
*/
bool
is_pinned_host
()
const
{
return
m_is_pinned_host
;
};
bool
is_pinned_host
()
const
{
return
m_is_pinned_host
;
};
//! which will trigger memory alloc in tensor implement
/**
* @brief Get memory address of data of this Tensor
*
* @return address pointer
*
* @note this function will trigger memory alloc in tensor implement
*/
void
*
get_memory_ptr
()
const
;
void
*
get_memory_ptr
()
const
;
//! get the memory with the offset describe in idx
/**
* @brief Get the memory with the offset describe in idx of this Tensor
*
* @param[in] idx indeces of tensor
*
* @return address pointer
*/
void
*
get_memory_ptr
(
const
std
::
vector
<
size_t
>&
idx
)
const
;
void
*
get_memory_ptr
(
const
std
::
vector
<
size_t
>&
idx
)
const
;
//!
get the tensor capacity in byte
//!
Get capacity of the Tenosr in bytes
size_t
get_tensor_total_size_in_byte
()
const
;
size_t
get_tensor_total_size_in_byte
()
const
;
//!
whether the memory of tensor is continue
//!
Check whether the memory of tensor is contigous
bool
is_continue_memory
()
const
;
bool
is_continue_memory
()
const
;
//@}
//! set layout will change the layout and reallocate memory of the tensor
/**
* @brief set layout to this Tensor
*
* @param[in] layout layout that will set into this Tensor
*
* @note this will change the layout and reallocate memory of the tensor
*/
void
set_layout
(
const
Layout
&
layout
);
void
set_layout
(
const
Layout
&
layout
);
//! use the user allocated data to reset the memory of the tensor, the
/**
//! memory will not be managed by the lite, later, the user should delete
* @brief reset layout with user alloced memory
//! it.
*
* @param[in] prepared_data user prepared data pointer
*
* @param[in] data_length_in_byte size of this memory
*
* @note the memory will not be managed by the lite, later, the user should delete it
*/
void
reset
(
void
*
prepared_data
,
size_t
data_length_in_byte
);
void
reset
(
void
*
prepared_data
,
size_t
data_length_in_byte
);
//! use the user allocated data and corresponding layout to reset the data
/**
//! and layout of the tensor, the memory will not be managed by lite, later,
* @brief reset layout with user alloced memory and corresponding layout
//! the user should delete it.
*
* @param[in] prepared_data user prepared data pointer
*
* @param[in] layout desired layout
*
* @note the memory will not be managed by the lite, later, the user should delete it
*/
void
reset
(
void
*
prepared_data
,
const
Layout
&
layout
);
void
reset
(
void
*
prepared_data
,
const
Layout
&
layout
);
//! reshape the tensor with new shape, keep the data_type the same
/**
* @brief reshape the tensor with new shape
*
* @param[in] shape target shape
*
* @note the data type will keep unchanged
*/
void
reshape
(
const
std
::
vector
<
int
>&
shape
);
void
reshape
(
const
std
::
vector
<
int
>&
shape
);
//! get a new tensor slice from the origin tensor
/**
* @brief get a slice from the origin tensor
*
* @param[in] start start idx of each dim
*
* @param[in] end end idx of each dim
*
* @param[in] step step of each dim
*
* @return ref pointer of a new Tensor
*
* @note if tensor = [[1, 2, 3], [4, 5, 6], [7, 8, 9]], start = {0, 0}, end = {2,
* 2}, step = {1, 2}. Then result = [[1, 3], [4, 6], [7, 9]]
*/
std
::
shared_ptr
<
Tensor
>
slice
(
std
::
shared_ptr
<
Tensor
>
slice
(
const
std
::
vector
<
size_t
>&
start
,
const
std
::
vector
<
size_t
>&
end
,
const
std
::
vector
<
size_t
>&
start
,
const
std
::
vector
<
size_t
>&
end
,
const
std
::
vector
<
size_t
>&
step
=
{});
const
std
::
vector
<
size_t
>&
step
=
{});
//!
set the tensor memory
with zero
//!
memset Tensor
with zero
void
fill_zero
();
void
fill_zero
();
//! copy tensor form other tensor
/**
//! @note the best way for tensor copy is just set the dst device, left
* @brief copy data from another tensor
//! layout empty, when copying the dst layout will be set the same with
*
//! src
* @param[in] src source tensor
*
* @note the best way for tensor copy is just set the dst device left layout empty.
* Layout will be set the same as src when copying
*/
void
copy_from
(
const
Tensor
&
src
);
void
copy_from
(
const
Tensor
&
src
);
//! share memory with other tensor
//! share memory with other tensor
...
@@ -144,24 +261,31 @@ public:
...
@@ -144,24 +261,31 @@ public:
friend
class
TensorHelper
;
friend
class
TensorHelper
;
private:
private:
std
::
shared_ptr
<
TensorImplBase
>
m_tensor_impl
;
std
::
shared_ptr
<
TensorImplBase
>
m_tensor_impl
;
///< tensor implementation.
bool
m_is_pinned_host
=
//! flag whether the storage of the tensor is pinned, this is only used
false
;
///< flag whether the storage of the tensor is pinned, this is only
//! when the compnode is not in CPU
///< used when the compnode is not in CPU.
bool
m_is_pinned_host
=
false
;
int
m_device_id
=
0
;
///< device id of this Tensor.
int
m_device_id
=
0
;
Layout
m_layout
;
///< layout of this Tensor.
Layout
m_layout
;
LiteDeviceType
m_device_type
=
//! the device of the tensor should not be changed after the tensor has
LiteDeviceType
::
LITE_CPU
;
///< devie type of this Tensor. should not change
//! constructed
///< after constructing.
LiteDeviceType
m_device_type
=
LiteDeviceType
::
LITE_CPU
;
};
};
/**
/**
* \brief a class can hold any type data, but not check whether the visit type
* @class LiteAny
* is valid
*
* @brief a class can hold any type data
*
* @note the visit type is valide will not be checked
*/
*/
class
LITE_API
LiteAny
{
class
LITE_API
LiteAny
{
public:
public:
/**
* @enum Type
*
* @brief enum for data type
*/
enum
Type
{
enum
Type
{
STRING
=
0
,
STRING
=
0
,
INT32
=
1
,
INT32
=
1
,
...
@@ -175,45 +299,128 @@ public:
...
@@ -175,45 +299,128 @@ public:
FLOAT
=
9
,
FLOAT
=
9
,
NONE_SUPPORT
=
10
,
NONE_SUPPORT
=
10
,
};
};
/**
* @class HolderBase
*
* @brief Base class for holding any type of data
*/
class
HolderBase
{
public:
/**
* @brief virtual deconstructor
*/
virtual
~
HolderBase
()
=
default
;
/**
* @brief clone data
*
* @return a new ref pointer of the data
*
* @note pure virtual interface
*/
virtual
std
::
shared_ptr
<
HolderBase
>
clone
()
=
0
;
};
/**
* @class AnyHolder
*
* @brief template class that holds any type of data
*/
template
<
class
T
>
class
AnyHolder
:
public
HolderBase
{
public:
/**
* @brief default constructor
*/
AnyHolder
(
const
T
value
)
:
m_value
(
value
)
{}
/**
* @brief clone data of this holder
*
* @return a ref pointer of m_value
*/
virtual
std
::
shared_ptr
<
HolderBase
>
clone
()
override
{
return
std
::
make_shared
<
AnyHolder
>
(
m_value
);
}
public:
T
m_value
;
///< value
};
/**
* @brief default constructor
*/
LiteAny
()
=
default
;
LiteAny
()
=
default
;
/**
* @brief constructor with value of any type
*
* @param[in] value data
*/
template
<
class
T
>
template
<
class
T
>
LiteAny
(
T
value
)
:
m_holder
(
new
AnyHolder
<
T
>
(
value
))
{
LiteAny
(
T
value
)
:
m_holder
(
new
AnyHolder
<
T
>
(
value
))
{
m_type
=
get_type
<
T
>
();
m_type
=
get_type
<
T
>
();
}
}
/**
* @brief copy constructor
*
* @param[in] any data
*/
LiteAny
(
const
LiteAny
&
any
)
{
LiteAny
(
const
LiteAny
&
any
)
{
m_holder
=
any
.
m_holder
->
clone
();
m_holder
=
any
.
m_holder
->
clone
();
m_type
=
any
.
m_type
;
m_type
=
any
.
m_type
;
}
}
/**
* @brief assign operator overloading
*
* @param[in] any data
*/
LiteAny
&
operator
=
(
const
LiteAny
&
any
)
{
LiteAny
&
operator
=
(
const
LiteAny
&
any
)
{
m_holder
=
any
.
m_holder
->
clone
();
m_holder
=
any
.
m_holder
->
clone
();
m_type
=
any
.
m_type
;
m_type
=
any
.
m_type
;
return
*
this
;
return
*
this
;
}
}
/**
* @brief get data type of this hold
*
* @return type of data
* - STRING
* - INT32
* - UINT32
* - UINT8
* - INT8
* - INT64
* - UINT64
* - BOOL
* - VOID_PTR
* - FLOAT
* - NONE_SUPPORT
*/
template
<
class
T
>
template
<
class
T
>
Type
get_type
()
const
;
Type
get_type
()
const
;
class
HolderBase
{
/**
public:
* @brief check whether type mismatch
virtual
~
HolderBase
()
=
default
;
*
virtual
std
::
shared_ptr
<
HolderBase
>
clone
()
=
0
;
* @param[in] expect expected type
};
*
* @param[in] get got type
template
<
class
T
>
*
class
AnyHolder
:
public
HolderBase
{
* @note if type is miss matching, it will throw
public:
*/
AnyHolder
(
const
T
value
)
:
m_value
(
value
)
{}
virtual
std
::
shared_ptr
<
HolderBase
>
clone
()
override
{
return
std
::
make_shared
<
AnyHolder
>
(
m_value
);
}
public:
T
m_value
;
};
//! if type is miss matching, it will throw
void
type_missmatch
(
size_t
expect
,
size_t
get
)
const
;
void
type_missmatch
(
size_t
expect
,
size_t
get
)
const
;
/**
* @brief cast with type safty
*
* @return casted type
*
* @note if type is miss matching, it will throw
*/
template
<
class
T
>
template
<
class
T
>
T
safe_cast
()
const
{
T
safe_cast
()
const
{
if
(
get_type
<
T
>
()
!=
m_type
)
{
if
(
get_type
<
T
>
()
!=
m_type
)
{
...
@@ -221,6 +428,14 @@ public:
...
@@ -221,6 +428,14 @@ public:
}
}
return
static_cast
<
LiteAny
::
AnyHolder
<
T
>*>
(
m_holder
.
get
())
->
m_value
;
return
static_cast
<
LiteAny
::
AnyHolder
<
T
>*>
(
m_holder
.
get
())
->
m_value
;
}
}
/**
* @brief check whether can cast to one kind of type
*
* @return successful or not
* - true successful
* - false failed
*/
template
<
class
T
>
template
<
class
T
>
bool
try_cast
()
const
{
bool
try_cast
()
const
{
if
(
get_type
<
T
>
()
==
m_type
)
{
if
(
get_type
<
T
>
()
==
m_type
)
{
...
@@ -229,22 +444,47 @@ public:
...
@@ -229,22 +444,47 @@ public:
return
false
;
return
false
;
}
}
}
}
//! only check the storage type and the visit type length, so it's not safe
/**
* @brief unsafe cast to void*
*
* @return pointer to hold data
*
* @note only check the storage type and the visit type length, so it's not safe
*/
void
*
cast_void_ptr
()
const
{
void
*
cast_void_ptr
()
const
{
return
&
static_cast
<
LiteAny
::
AnyHolder
<
char
>*>
(
m_holder
.
get
())
->
m_value
;
return
&
static_cast
<
LiteAny
::
AnyHolder
<
char
>*>
(
m_holder
.
get
())
->
m_value
;
}
}
private:
private:
std
::
shared_ptr
<
HolderBase
>
m_holder
;
std
::
shared_ptr
<
HolderBase
>
m_holder
;
///< holder member
Type
m_type
=
NONE_SUPPORT
;
Type
m_type
=
NONE_SUPPORT
;
///< type member
};
};
/*********************** special tensor function ***************/
/**
* @class TensorUtils
*
* @brief provide special tensor tool functions
*/
class
LITE_API
TensorUtils
{
class
LITE_API
TensorUtils
{
public:
public:
//! concat all the input tensor to one on the specified dim, the result
//! tensor reside in dst_device_id of dst_device, if dst_device is
/**
//! LITE_DEVICE_DEFAULT, the device will get from the first tensor
* @brief concat all the input tensor to one on the specified dim.
*
* @param[in] tensors input tensors
*
* @param[in] dim specified dim
*
* @param[in] dst_device type of output tensor
*
* @param[in] dst_device_id id of output tensor
*
* @return concated tensor
*
* @note the result tensor reside in dst_device_id of dst_device, if dst_device is
* LITE_DEVICE_DEFAULT, the device will get from the first tensor
*/
static
std
::
shared_ptr
<
Tensor
>
concat
(
static
std
::
shared_ptr
<
Tensor
>
concat
(
const
std
::
vector
<
Tensor
>&
tensors
,
int
dim
,
const
std
::
vector
<
Tensor
>&
tensors
,
int
dim
,
LiteDeviceType
dst_device
=
LiteDeviceType
::
LITE_DEVICE_DEFAULT
,
LiteDeviceType
dst_device
=
LiteDeviceType
::
LITE_DEVICE_DEFAULT
,
...
...
lite/pylite/megenginelite/tensor.py
浏览文件 @
3cd54dd6
...
@@ -53,7 +53,25 @@ _lite_dtypes_to_ctype = {
...
@@ -53,7 +53,25 @@ _lite_dtypes_to_ctype = {
class
LiteLayout
(
Structure
):
class
LiteLayout
(
Structure
):
"""
"""
the simple layout description
Description of layout using in Lite. A Lite layout will be totally defined
by shape and data type.
Args:
shape: the shape of data.
dtype: data type.
Note:
Dims of shape should be less than 8. The supported data type defines at
LiteDataType
Examples:
.. code-block:: python
import numpy as np
layout = LiteLayout([1, 4, 8, 8], LiteDataType.LITE_FLOAT)
assert(layout.shape()) == [1, 4, 8, 8]
assert(layout.dtype()) == LiteDataType.LITE_FLOAT
"""
"""
_fields_
=
[
_fields_
=
[
...
@@ -113,10 +131,14 @@ class _LiteTensorDesc(Structure):
...
@@ -113,10 +131,14 @@ class _LiteTensorDesc(Structure):
"""
"""
warpper of the MegEngine Tensor
warpper of the MegEngine Tensor
:is_pinned_host: when set, the storage memory of the tensor is pinned memory,
Args:
this is used to Optimize the H2D or D2H memory copy, if the device or layout
is_pinned_host: when set, the storage memory of the tensor is pinned
is not set, when copy form other device(CUDA) tensor, this tensor
memory. This is used to Optimize the H2D or D2H memory copy, if the
will be automatically set to pinned tensor
device or layout is not set, when copy form other device(CUDA)
tensor, this tensor will be automatically set to pinned tensor
layout(LiteLayout): layout of this tensor
device_type: type of device
device_id: id of device
"""
"""
_fields_
=
[
_fields_
=
[
...
@@ -144,7 +166,7 @@ class _LiteTensorDesc(Structure):
...
@@ -144,7 +166,7 @@ class _LiteTensorDesc(Structure):
class
_TensorAPI
(
_LiteCObjBase
):
class
_TensorAPI
(
_LiteCObjBase
):
"""
"""
get the api
from the lib
Get the API
from the lib
"""
"""
_api_
=
[
_api_
=
[
...
@@ -183,7 +205,22 @@ class _TensorAPI(_LiteCObjBase):
...
@@ -183,7 +205,22 @@ class _TensorAPI(_LiteCObjBase):
class
LiteTensor
(
object
):
class
LiteTensor
(
object
):
"""
"""
the tensor to hold a block of data
Description of a block of data with neccessary information.
Args:
layout: layout of Tensor
device_type: device type of Tensor
device_id: device id of Tensor
is_pinned_host: when set, the storage memory of the tensor is pinned
memory. This is used to Optimize the H2D or D2H memory copy, if the
device or layout is not set, when copy form other device(CUDA)
tensor, this tensor will be automatically set to pinned tensor
shapes: the shape of data
dtype: data type
Note:
Dims of shape should be less than 8. The supported data type defines at
LiteDataType
"""
"""
_api
=
_TensorAPI
().
_lib
_api
=
_TensorAPI
().
_lib
...
@@ -197,10 +234,6 @@ class LiteTensor(object):
...
@@ -197,10 +234,6 @@ class LiteTensor(object):
shapes
=
None
,
shapes
=
None
,
dtype
=
None
,
dtype
=
None
,
):
):
"""
create a Tensor with layout, device, is_pinned_host or shapes, dtype,
device_type, device_id, is_pinned_host param
"""
self
.
_tensor
=
_Ctensor
()
self
.
_tensor
=
_Ctensor
()
self
.
_layout
=
LiteLayout
()
self
.
_layout
=
LiteLayout
()
if
layout
is
not
None
:
if
layout
is
not
None
:
...
@@ -232,7 +265,11 @@ class LiteTensor(object):
...
@@ -232,7 +265,11 @@ class LiteTensor(object):
def
share_memory_with
(
self
,
src_tensor
):
def
share_memory_with
(
self
,
src_tensor
):
"""
"""
share the same memory with the src_tensor, the self memory will be freed
share the same memory with the ``src_tensor``, the self memory will be
freed
Args:
src_tensor: the source tensor that will share memory with this tensor
"""
"""
assert
isinstance
(
src_tensor
,
LiteTensor
)
assert
isinstance
(
src_tensor
,
LiteTensor
)
self
.
_api
.
LITE_tensor_share_memory_with
(
self
.
_tensor
,
src_tensor
.
_tensor
)
self
.
_api
.
LITE_tensor_share_memory_with
(
self
.
_tensor
,
src_tensor
.
_tensor
)
...
@@ -265,7 +302,7 @@ class LiteTensor(object):
...
@@ -265,7 +302,7 @@ class LiteTensor(object):
@
property
@
property
def
device_type
(
self
):
def
device_type
(
self
):
"""
"""
get device of the tensor
get device
type
of the tensor
"""
"""
device_type
=
c_int
()
device_type
=
c_int
()
self
.
_api
.
LITE_get_tensor_device_type
(
self
.
_tensor
,
byref
(
device_type
))
self
.
_api
.
LITE_get_tensor_device_type
(
self
.
_tensor
,
byref
(
device_type
))
...
@@ -320,6 +357,9 @@ class LiteTensor(object):
...
@@ -320,6 +357,9 @@ class LiteTensor(object):
def
copy_from
(
self
,
src_tensor
):
def
copy_from
(
self
,
src_tensor
):
"""
"""
copy memory form the src_tensor
copy memory form the src_tensor
Args:
src_tensor: source tensor
"""
"""
assert
isinstance
(
src_tensor
,
LiteTensor
)
assert
isinstance
(
src_tensor
,
LiteTensor
)
self
.
_api
.
LITE_tensor_copy
(
self
.
_tensor
,
src_tensor
.
_tensor
)
self
.
_api
.
LITE_tensor_copy
(
self
.
_tensor
,
src_tensor
.
_tensor
)
...
@@ -327,8 +367,10 @@ class LiteTensor(object):
...
@@ -327,8 +367,10 @@ class LiteTensor(object):
def
reshape
(
self
,
shape
):
def
reshape
(
self
,
shape
):
"""
"""
reshape the tensor with data not change, only change the shape
reshape the tensor with data not change.
:param shape: int arrary of dst_shape
Args:
shape: target shape
"""
"""
shape
=
list
(
shape
)
shape
=
list
(
shape
)
length
=
len
(
shape
)
length
=
len
(
shape
)
...
@@ -339,9 +381,11 @@ class LiteTensor(object):
...
@@ -339,9 +381,11 @@ class LiteTensor(object):
def
slice
(
self
,
start
,
end
,
step
=
None
):
def
slice
(
self
,
start
,
end
,
step
=
None
):
"""
"""
slice the tensor with gaven start, end, step
slice the tensor with gaven start, end, step
:param start: silce begin index of each dim
:param end: silce end index of each dim
Args:
:param step: silce step of each dim
start: silce begin index of each dim
end: silce end index of each dim
step: silce step of each dim
"""
"""
start
=
list
(
start
)
start
=
list
(
start
)
end
=
list
(
end
)
end
=
list
(
end
)
...
@@ -357,7 +401,7 @@ class LiteTensor(object):
...
@@ -357,7 +401,7 @@ class LiteTensor(object):
c_step
=
(
c_size_t
*
length
)(
*
step
)
c_step
=
(
c_size_t
*
length
)(
*
step
)
slice_tensor
=
LiteTensor
()
slice_tensor
=
LiteTensor
()
self
.
_api
.
LITE_tensor_slice
(
self
.
_api
.
LITE_tensor_slice
(
self
.
_tensor
,
c_start
,
c_end
,
c_step
,
length
,
byref
(
slice_tensor
.
_tensor
)
self
.
_tensor
,
c_start
,
c_end
,
c_step
,
length
,
byref
(
slice_tensor
.
_tensor
)
,
)
)
slice_tensor
.
update
()
slice_tensor
.
update
()
return
slice_tensor
return
slice_tensor
...
@@ -373,7 +417,9 @@ class LiteTensor(object):
...
@@ -373,7 +417,9 @@ class LiteTensor(object):
def
set_data_by_share
(
self
,
data
,
length
=
0
,
layout
=
None
):
def
set_data_by_share
(
self
,
data
,
length
=
0
,
layout
=
None
):
"""
"""
share the data to the tensor
share the data to the tensor
param data: the data will shared to the tensor, it should be a
Args:
data: the data will shared to the tensor, it should be a
numpy.ndarray or ctypes data
numpy.ndarray or ctypes data
"""
"""
if
isinstance
(
data
,
np
.
ndarray
):
if
isinstance
(
data
,
np
.
ndarray
):
...
@@ -400,9 +446,13 @@ class LiteTensor(object):
...
@@ -400,9 +446,13 @@ class LiteTensor(object):
def
set_data_by_copy
(
self
,
data
,
data_length
=
0
,
layout
=
None
):
def
set_data_by_copy
(
self
,
data
,
data_length
=
0
,
layout
=
None
):
"""
"""
copy the data to the tensor, the memory of the tensor must be continue
copy the data to the tensor
param data: the data to copy to tensor, it should be list,
numpy.ndarraya or ctypes with length
Args:
data: the data to copy to tensor, it should be list, numpy.ndarraya
or ctypes with length
data_length: length of data in bytes
layout: layout of data
"""
"""
if
layout
is
not
None
:
if
layout
is
not
None
:
self
.
layout
=
layout
self
.
layout
=
layout
...
@@ -440,8 +490,11 @@ class LiteTensor(object):
...
@@ -440,8 +490,11 @@ class LiteTensor(object):
def
get_data_by_share
(
self
):
def
get_data_by_share
(
self
):
"""
"""
get the data in the tensor, add share the data with a new numpy, and
get the data in the tensor, add share the data with a new numpy, and
return the numpy arrray, be careful, the data in numpy is valid before
return the numpy arrray
the tensor memory is write again, such as LiteNetwok forward next time.
Note:
Be careful, the data in numpy is valid before the tensor memory is
write again, such as LiteNetwok forward next time.
"""
"""
self
.
update
()
self
.
update
()
...
@@ -490,7 +543,9 @@ def LiteTensorConcat(
...
@@ -490,7 +543,9 @@ def LiteTensorConcat(
tensors
,
dim
,
device_type
=
LiteDeviceType
.
LITE_DEVICE_DEFAULT
,
device_id
=-
1
tensors
,
dim
,
device_type
=
LiteDeviceType
.
LITE_DEVICE_DEFAULT
,
device_id
=-
1
):
):
"""
"""
concat tensor in input dim to one tensor
concat tensors at expected dim to one tensor
Args:
dim : the dim to act concat
dim : the dim to act concat
device_type: the result tensor device type
device_type: the result tensor device type
device_id: the result tensor device id
device_id: the result tensor device id
...
@@ -515,6 +570,9 @@ def LiteTensorConcat(
...
@@ -515,6 +570,9 @@ def LiteTensorConcat(
def
lite_dtype_2_numpy
(
dtype
):
def
lite_dtype_2_numpy
(
dtype
):
"""
"""
convert lite dtype to corresponding numpy dtype
convert lite dtype to corresponding numpy dtype
Args:
dtype(LiteDataType): source dtype
"""
"""
assert
isinstance
(
assert
isinstance
(
dtype
,
LiteDataType
dtype
,
LiteDataType
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录