Commit 3cd54dd6 authored by: Megvii Engine Team

docs(api/lite): add doc for lite tensor

GitOrigin-RevId: ae3799527311dd573c2540894b00064b98ec87de
Parent: a891f9b3
@@ -135,7 +135,6 @@ struct LITE_API ExtraConfig {
    bool disable_configure_by_model_info = false;
};

/**
 * @brief configure the network input and output items; the input and output
 * tensor information is described here
...
@@ -9,24 +9,38 @@
namespace lite {

/**
 * @struct Layout
 *
 * @brief Description of the way data is organized in a tensor
 */
struct LITE_API Layout {
    static constexpr uint32_t MAXDIM = 7;               ///< max dims
    size_t shapes[MAXDIM];                              ///< shape of each dim
    size_t ndim = 0;                                    ///< actual number of dims
    LiteDataType data_type = LiteDataType::LITE_FLOAT;  ///< data type

    /**
     * @brief get the number of elements of this Layout
     *
     * @return number of elements
     */
    size_t get_elem_size() const;

    /**
     * @brief compare two layouts for equality
     *
     * @param[in] other the other layout
     *
     * @return result of the comparison
     * - true this layout is equal to other
     * - false this layout is not equal to other
     */
    bool operator==(const Layout& other) const;
};
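The Python binding later in this diff exposes the same description as `LiteLayout`; as a hedged illustration that shape plus data type fully determine a layout (assuming the pylite bindings are importable as `megenginelite`):

```python
# A minimal sketch, assuming the pylite bindings from this repository are
# installed and importable as `megenginelite`.
from megenginelite import LiteDataType, LiteLayout

# ndim is derived from the shape list; it must stay below MAXDIM (7).
layout = LiteLayout([1, 3, 224, 224], LiteDataType.LITE_FLOAT)
print(layout)  # the shape and data type fully define the layout
```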
/**
 * @brief wrapper of the MegEngine Tensor
 *
 * \verbatim embed:rst:leading-asterisk
 *
@@ -59,79 +73,182 @@ public:
    /*!
     * @name Constructor
     *
     * @param[in] device_type The desired device type of created Tensor.
     * - LITE_CPU CPU Tensor
     * - LITE_CUDA CUDA Tensor
     * - LITE_OPENCL OpenCL Tensor
     * - LITE_ATLAS Atlas Tensor
     * - LITE_NPU NPU Tensor
     * - LITE_CAMBRICON Cambricon Tensor
     * - LITE_AX AX Tensor
     * - LITE_DEVICE_DEFAULT Tensor on the default device
     *
     * @param[in] device_id The desired device id of created Tensor.
     *
     * @param[in] stream_id The desired stream id of created Tensor on the desired device
     *
     * @param[in] backend The desired backend of created Tensor.
     * - LITE_DEFAULT backend is MegEngine
     * - LITE_RK_NPU backend is RKNN NPU
     *
     * @param[in] is_pinned_host Whether to use pinned memory.
     * - false use normal memory
     * - true use pinned memory (mainly on CUDA)
     *
     * @param[in] layout The desired layout of created Tensor.
     *
     */
    //@{
    //! Default constructor
    Tensor();
    //! Constructor
    Tensor(LiteDeviceType device_type, bool is_pinned_host = false);
    //! Constructor
    Tensor(LiteDeviceType device_type, const Layout& layout,
           bool is_pinned_host = false);
    //! Constructor
    Tensor(int device_id, LiteDeviceType device_type, const Layout& layout = {},
           bool is_pinned_host = false);
    //! Constructor
    Tensor(int device_id, int stream_id, LiteDeviceType device_type,
           bool is_pinned_host = false);
    //! Constructor
    Tensor(LiteBackend backend, LiteDeviceType device_type = LiteDeviceType::LITE_CPU,
           int device_id = 0, const Layout& layout = {}, bool is_pinned_host = false);
    //@}

    //! Destructor
    ~Tensor();
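A minimal sketch of these constructors through the Python binding (`megenginelite` assumed importable; argument names follow the `LiteTensor` docstring later in this diff):

```python
from megenginelite import LiteDataType, LiteDeviceType, LiteLayout, LiteTensor

layout = LiteLayout([4, 16], LiteDataType.LITE_FLOAT)

# CPU tensor with an explicit layout.
cpu_tensor = LiteTensor(layout=layout, device_type=LiteDeviceType.LITE_CPU)

# Pinned-host tensor, mainly useful for H2D/D2H copies with CUDA devices.
pinned_tensor = LiteTensor(
    layout=layout,
    device_type=LiteDeviceType.LITE_CPU,
    device_id=0,
    is_pinned_host=True,
)
```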
    /**
     * @brief Get the device type of this Tensor
     *
     * @return device type
     * - LITE_CPU CPU Tensor
     * - LITE_CUDA CUDA Tensor
     * - LITE_OPENCL OpenCL Tensor
     * - LITE_ATLAS Atlas Tensor
     * - LITE_NPU NPU Tensor
     * - LITE_CAMBRICON Cambricon Tensor
     * - LITE_AX AX Tensor
     * - LITE_DEVICE_DEFAULT Tensor on the default device
     */
    LiteDeviceType get_device_type() const { return m_device_type; };

    //! Get the device id of this Tensor
    int get_device_id() const { return m_device_id; };

    //! Get the layout of this Tensor
    Layout get_layout() const { return m_layout; };

    /**
     * @brief whether the Tensor is on pinned memory
     *
     * @return whether the Tensor is on pinned memory
     * - false normal memory
     * - true pinned memory
     */
    bool is_pinned_host() const { return m_is_pinned_host; };

    /**
     * @brief Get the memory address of the data of this Tensor
     *
     * @return address pointer
     *
     * @note this function will trigger memory allocation in the tensor implementation
     */
    void* get_memory_ptr() const;

    /**
     * @brief Get the memory at the offset described by idx of this Tensor
     *
     * @param[in] idx indices of the tensor
     *
     * @return address pointer
     */
    void* get_memory_ptr(const std::vector<size_t>& idx) const;

    //! Get the capacity of the Tensor in bytes
    size_t get_tensor_total_size_in_byte() const;

    //! Check whether the memory of the tensor is contiguous
    bool is_continue_memory() const;
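A hedged sketch of reading this metadata back via the binding; `device_type` is documented later in this diff, while `layout` and `is_pinned_host` are assumed to mirror the C++ getters in the same way:

```python
from megenginelite import LiteDataType, LiteLayout, LiteTensor

t = LiteTensor(layout=LiteLayout([2, 3], LiteDataType.LITE_FLOAT))
print(t.device_type)     # e.g. LiteDeviceType.LITE_CPU
print(t.layout)          # the Layout currently attached to the tensor
print(t.is_pinned_host)  # False for normal (non-pinned) memory
```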
    /**
     * @brief set the layout of this Tensor
     *
     * @param[in] layout the layout that will be set into this Tensor
     *
     * @note this will change the layout and reallocate the memory of the tensor
     */
    void set_layout(const Layout& layout);

    /**
     * @brief reset the memory with user-allocated memory
     *
     * @param[in] prepared_data user prepared data pointer
     *
     * @param[in] data_length_in_byte size of this memory
     *
     * @note the memory will not be managed by lite, so the user should release it later
     */
    void reset(void* prepared_data, size_t data_length_in_byte);

    /**
     * @brief reset the memory and layout with user-allocated memory and the
     * corresponding layout
     *
     * @param[in] prepared_data user prepared data pointer
     *
     * @param[in] layout desired layout
     *
     * @note the memory will not be managed by lite, so the user should release it later
     */
    void reset(void* prepared_data, const Layout& layout);

    /**
     * @brief reshape the tensor with a new shape
     *
     * @param[in] shape target shape
     *
     * @note the data type will stay unchanged
     */
    void reshape(const std::vector<int>& shape);

    /**
     * @brief get a slice from the origin tensor
     *
     * @param[in] start start index of each dim
     *
     * @param[in] end end index of each dim
     *
     * @param[in] step step of each dim
     *
     * @return ref pointer to a new Tensor
     *
     * @note if tensor = [[1, 2, 3], [4, 5, 6], [7, 8, 9]], start = {0, 0}, end = {2,
     * 2}, step = {1, 2}, then result = [[1, 3], [4, 6], [7, 9]]
     */
    std::shared_ptr<Tensor> slice(
            const std::vector<size_t>& start, const std::vector<size_t>& end,
            const std::vector<size_t>& step = {});
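A hedged sketch of `slice` through the binding, reusing the example from the note above (`megenginelite` assumed importable):

```python
import numpy as np
from megenginelite import LiteDataType, LiteLayout, LiteTensor

t = LiteTensor(layout=LiteLayout([3, 3], LiteDataType.LITE_FLOAT))
t.set_data_by_copy(np.arange(1, 10, dtype=np.float32).reshape(3, 3))

# start = [0, 0], end = [2, 2], step = [1, 2]: keeps columns 0 and 2,
# which per the note above reads [[1, 3], [4, 6], [7, 9]].
sub = t.slice([0, 0], [2, 2], [1, 2])
print(sub.layout)
```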
    //! memset the Tensor with zero
    void fill_zero();

    /**
     * @brief copy data from another tensor
     *
     * @param[in] src source tensor
     *
     * @note the best way to copy a tensor is to just set the dst device and leave the
     * layout empty; the layout will be set the same as src when copying
     */
    void copy_from(const Tensor& src);

    //! share memory with other tensor
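A minimal sketch of `copy_from` via the binding: the destination layout is left empty and, as noted above, is taken over from the source during the copy (`megenginelite` assumed importable):

```python
import numpy as np
from megenginelite import LiteDataType, LiteLayout, LiteTensor

src = LiteTensor(layout=LiteLayout([2, 2], LiteDataType.LITE_FLOAT))
src.set_data_by_copy(np.ones((2, 2), dtype=np.float32))

dst = LiteTensor()  # no layout on purpose
dst.copy_from(src)  # the layout is taken over from src
print(dst.layout)   # expected to match src: shape [2, 2], LITE_FLOAT
```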
@@ -144,24 +261,31 @@ public:
    friend class TensorHelper;

private:
    std::shared_ptr<TensorImplBase> m_tensor_impl;  ///< tensor implementation
    bool m_is_pinned_host =
            false;  ///< flag whether the storage of the tensor is pinned, this is only
                    ///< used when the compnode is not CPU
    int m_device_id = 0;  ///< device id of this Tensor
    Layout m_layout;      ///< layout of this Tensor
    LiteDeviceType m_device_type =
            LiteDeviceType::LITE_CPU;  ///< device type of this Tensor; should not
                                       ///< change after construction
};
/**
 * @class LiteAny
 *
 * @brief a class that can hold data of any type
 *
 * @note whether the visit type is valid will not be checked
 */
class LITE_API LiteAny {
public:
    /**
     * @enum Type
     *
     * @brief enum of the supported data types
     */
    enum Type {
        STRING = 0,
        INT32 = 1,
@@ -175,45 +299,128 @@ public:
        FLOAT = 9,
        NONE_SUPPORT = 10,
    };
    /**
     * @class HolderBase
     *
     * @brief base class for holding any type of data
     */
    class HolderBase {
    public:
        /**
         * @brief virtual destructor
         */
        virtual ~HolderBase() = default;

        /**
         * @brief clone the data
         *
         * @return a new ref pointer to the data
         *
         * @note pure virtual interface
         */
        virtual std::shared_ptr<HolderBase> clone() = 0;
    };

    /**
     * @class AnyHolder
     *
     * @brief template class that holds data of any type
     */
    template <class T>
    class AnyHolder : public HolderBase {
    public:
        /**
         * @brief constructor from a value
         */
        AnyHolder(const T value) : m_value(value) {}

        /**
         * @brief clone the data of this holder
         *
         * @return a ref pointer to m_value
         */
        virtual std::shared_ptr<HolderBase> clone() override {
            return std::make_shared<AnyHolder>(m_value);
        }

    public:
        T m_value;  ///< value
    };
    /**
     * @brief default constructor
     */
    LiteAny() = default;

    /**
     * @brief constructor from a value of any type
     *
     * @param[in] value data
     */
    template <class T>
    LiteAny(T value) : m_holder(new AnyHolder<T>(value)) {
        m_type = get_type<T>();
    }

    /**
     * @brief copy constructor
     *
     * @param[in] any data
     */
    LiteAny(const LiteAny& any) {
        m_holder = any.m_holder->clone();
        m_type = any.m_type;
    }

    /**
     * @brief assignment operator overload
     *
     * @param[in] any data
     */
    LiteAny& operator=(const LiteAny& any) {
        m_holder = any.m_holder->clone();
        m_type = any.m_type;
        return *this;
    }
    /**
     * @brief get the data type held by this LiteAny
     *
     * @return type of the data
     * - STRING
     * - INT32
     * - UINT32
     * - UINT8
     * - INT8
     * - INT64
     * - UINT64
     * - BOOL
     * - VOID_PTR
     * - FLOAT
     * - NONE_SUPPORT
     */
    template <class T>
    Type get_type() const;

    /**
     * @brief check for a type mismatch
     *
     * @param[in] expect expected type
     *
     * @param[in] get got type
     *
     * @note if the types mismatch, it will throw
     */
    void type_missmatch(size_t expect, size_t get) const;
    /**
     * @brief cast with type safety
     *
     * @return the casted value
     *
     * @note if the types mismatch, it will throw
     */
    template <class T>
    T safe_cast() const {
        if (get_type<T>() != m_type) {
@@ -221,6 +428,14 @@ public:
        }
        return static_cast<LiteAny::AnyHolder<T>*>(m_holder.get())->m_value;
    }
    /**
     * @brief check whether the data can be cast to a given type
     *
     * @return successful or not
     * - true successful
     * - false failed
     */
    template <class T>
    bool try_cast() const {
        if (get_type<T>() == m_type) {
@@ -229,22 +444,47 @@ public:
            return false;
        }
    }
    /**
     * @brief unsafe cast to void*
     *
     * @return pointer to the held data
     *
     * @note only the storage type and the visit type length are checked, so it is
     * not safe
     */
    void* cast_void_ptr() const {
        return &static_cast<LiteAny::AnyHolder<char>*>(m_holder.get())->m_value;
    }

private:
    std::shared_ptr<HolderBase> m_holder;  ///< holder member
    Type m_type = NONE_SUPPORT;            ///< type member
};
/**
 * @class TensorUtils
 *
 * @brief provides special tensor tool functions
 */
class LITE_API TensorUtils {
public:
    /**
     * @brief concat all the input tensors into one along the specified dim
     *
     * @param[in] tensors input tensors
     *
     * @param[in] dim the specified dim
     *
     * @param[in] dst_device device type of the output tensor
     *
     * @param[in] dst_device_id device id of the output tensor
     *
     * @return the concated tensor
     *
     * @note the result tensor resides on dst_device_id of dst_device; if dst_device is
     * LITE_DEVICE_DEFAULT, the device is taken from the first tensor
     */
    static std::shared_ptr<Tensor> concat(
            const std::vector<Tensor>& tensors, int dim,
            LiteDeviceType dst_device = LiteDeviceType::LITE_DEVICE_DEFAULT,
...
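The Python binding below wraps this as `LiteTensorConcat`; a hedged sketch (`megenginelite` assumed importable), where `LITE_DEVICE_DEFAULT` keeps the result on the first tensor's device as noted above:

```python
import numpy as np
from megenginelite import LiteDataType, LiteLayout, LiteTensor, LiteTensorConcat

def filled(value):
    t = LiteTensor(layout=LiteLayout([1, 4], LiteDataType.LITE_FLOAT))
    t.set_data_by_copy(np.full((1, 4), value, dtype=np.float32))
    return t

out = LiteTensorConcat([filled(0.0), filled(1.0)], dim=0)
print(out.layout)  # expected shape [2, 4]
```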
@@ -53,7 +53,25 @@ _lite_dtypes_to_ctype = {
class LiteLayout(Structure):
    """
    Description of the layout used in Lite. A Lite layout is fully defined
    by its shape and data type.

    Args:
        shape: the shape of the data.
        dtype: the data type.

    Note:
        The number of dims of the shape should be less than 8. The supported
        data types are defined in LiteDataType.

    Examples:

        .. code-block:: python

            import numpy as np
            layout = LiteLayout([1, 4, 8, 8], LiteDataType.LITE_FLOAT)
            assert layout.shape() == [1, 4, 8, 8]
            assert layout.dtype() == LiteDataType.LITE_FLOAT
    """

    _fields_ = [
@@ -113,10 +131,14 @@ class _LiteTensorDesc(Structure):
    """
    wrapper of the MegEngine Tensor

    Args:
        is_pinned_host: when set, the storage memory of the tensor is pinned
            memory. This is used to optimize the H2D or D2H memory copy: if the
            device or layout is not set, then when copying from another device
            (CUDA) tensor, this tensor will automatically be set to a pinned
            tensor
        layout(LiteLayout): layout of this tensor
        device_type: type of the device
        device_id: id of the device
    """

    _fields_ = [
@@ -144,7 +166,7 @@ class _LiteTensorDesc(Structure):
class _TensorAPI(_LiteCObjBase):
    """
    Get the API from the lib
    """

    _api_ = [
@@ -183,7 +205,22 @@ class _TensorAPI(_LiteCObjBase):
class LiteTensor(object):
    """
    Description of a block of data with the necessary information.

    Args:
        layout: layout of the Tensor
        device_type: device type of the Tensor
        device_id: device id of the Tensor
        is_pinned_host: when set, the storage memory of the tensor is pinned
            memory. This is used to optimize the H2D or D2H memory copy: if the
            device or layout is not set, then when copying from another device
            (CUDA) tensor, this tensor will automatically be set to a pinned
            tensor
        shapes: the shape of the data
        dtype: the data type

    Note:
        The number of dims of the shape should be less than 8. The supported
        data types are defined in LiteDataType.
    """

    _api = _TensorAPI()._lib
@@ -197,10 +234,6 @@ class LiteTensor(object):
        shapes=None,
        dtype=None,
    ):
        self._tensor = _Ctensor()
        self._layout = LiteLayout()
        if layout is not None:
@@ -232,7 +265,11 @@ class LiteTensor(object):
    def share_memory_with(self, src_tensor):
        """
        share the same memory with the ``src_tensor``; the self memory will be
        freed

        Args:
            src_tensor: the source tensor that will share memory with this tensor
        """
        assert isinstance(src_tensor, LiteTensor)
        self._api.LITE_tensor_share_memory_with(self._tensor, src_tensor._tensor)
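A minimal sketch of `share_memory_with`: after the call both tensors alias the same storage, so writes through one are visible through the other (`megenginelite` assumed importable):

```python
import numpy as np
from megenginelite import LiteDataType, LiteLayout, LiteTensor

a = LiteTensor(layout=LiteLayout([4], LiteDataType.LITE_FLOAT))
a.set_data_by_copy(np.arange(4, dtype=np.float32))

b = LiteTensor(layout=LiteLayout([4], LiteDataType.LITE_FLOAT))
b.share_memory_with(a)        # b's own memory is freed, b now aliases a
print(b.get_data_by_share())  # expected [0., 1., 2., 3.]
```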
@@ -265,7 +302,7 @@ class LiteTensor(object):
    @property
    def device_type(self):
        """
        get the device type of the tensor
        """
        device_type = c_int()
        self._api.LITE_get_tensor_device_type(self._tensor, byref(device_type))
@@ -320,6 +357,9 @@ class LiteTensor(object):
    def copy_from(self, src_tensor):
        """
        copy memory from the src_tensor

        Args:
            src_tensor: source tensor
        """
        assert isinstance(src_tensor, LiteTensor)
        self._api.LITE_tensor_copy(self._tensor, src_tensor._tensor)
@@ -327,8 +367,10 @@ class LiteTensor(object):
    def reshape(self, shape):
        """
        reshape the tensor without changing the data

        Args:
            shape: target shape
        """
        shape = list(shape)
        length = len(shape)
@@ -339,9 +381,11 @@ class LiteTensor(object):
    def slice(self, start, end, step=None):
        """
        slice the tensor with the given start, end and step

        Args:
            start: slice begin index of each dim
            end: slice end index of each dim
            step: slice step of each dim
        """
        start = list(start)
        end = list(end)
@@ -357,7 +401,7 @@ class LiteTensor(object):
        c_step = (c_size_t * length)(*step)
        slice_tensor = LiteTensor()
        self._api.LITE_tensor_slice(
            self._tensor, c_start, c_end, c_step, length, byref(slice_tensor._tensor),
        )
        slice_tensor.update()
        return slice_tensor
@@ -373,7 +417,9 @@ class LiteTensor(object):
    def set_data_by_share(self, data, length=0, layout=None):
        """
        share the data to the tensor

        Args:
            data: the data that will be shared to the tensor; it should be a
                numpy.ndarray or ctypes data
        """
        if isinstance(data, np.ndarray):
@@ -400,9 +446,13 @@ class LiteTensor(object):
    def set_data_by_copy(self, data, data_length=0, layout=None):
        """
        copy the data to the tensor

        Args:
            data: the data to copy to the tensor; it should be a list, a
                numpy.ndarray or ctypes with a length
            data_length: length of the data in bytes
            layout: layout of the data
        """
        if layout is not None:
            self.layout = layout
@@ -440,8 +490,11 @@ class LiteTensor(object):
    def get_data_by_share(self):
        """
        get the data in the tensor by sharing it with a new numpy array, and
        return that numpy array

        Note:
            Be careful: the data in the numpy array is only valid until the
            tensor memory is written again, for example when LiteNetwork
            forwards the next time.
        """
        self.update()
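A minimal sketch of the lifetime caveat above: the returned array aliases the tensor memory, so copy it if it must survive the next write (`megenginelite` assumed importable):

```python
import numpy as np
from megenginelite import LiteDataType, LiteLayout, LiteTensor

t = LiteTensor(layout=LiteLayout([2], LiteDataType.LITE_FLOAT))
t.set_data_by_copy(np.array([1.0, 2.0], dtype=np.float32))

view = t.get_data_by_share()  # aliases the tensor memory
snapshot = view.copy()        # detached copy, safe to keep around

t.set_data_by_copy(np.array([9.0, 9.0], dtype=np.float32))
print(snapshot)  # still [1., 2.]; `view` may already show the new data
```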
@@ -490,7 +543,9 @@ def LiteTensorConcat(
    tensors, dim, device_type=LiteDeviceType.LITE_DEVICE_DEFAULT, device_id=-1
):
    """
    concat the tensors along the expected dim into one tensor

    Args:
        dim: the dim along which to concat
        device_type: the device type of the result tensor
        device_id: the device id of the result tensor
@@ -515,6 +570,9 @@ def LiteTensorConcat(
def lite_dtype_2_numpy(dtype):
    """
    convert a lite dtype to the corresponding numpy dtype

    Args:
        dtype(LiteDataType): source dtype
    """
    assert isinstance(
        dtype, LiteDataType
...
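A minimal sketch of the dtype conversion; note the module path `megenginelite.tensor` is an assumption based on this diff:

```python
from megenginelite import LiteDataType
from megenginelite.tensor import lite_dtype_2_numpy  # module path assumed

print(lite_dtype_2_numpy(LiteDataType.LITE_FLOAT))  # expected: numpy float32
```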