Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
a2a46b56
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
a2a46b56
编写于
10月 13, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(lite): fix rknn error in lite
GitOrigin-RevId: b66aa1bf73af8c2993c66f52cc45b991a102d0fa
上级
849f0ece
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
59 addition
and
20 deletion
+59
-20
lite/include/lite/global.h
lite/include/lite/global.h
+1
-1
lite/include/lite/tensor.h
lite/include/lite/tensor.h
+31
-11
lite/lite-c/src/global.cpp
lite/lite-c/src/global.cpp
+2
-1
lite/src/mge/network_impl.cpp
lite/src/mge/network_impl.cpp
+5
-5
lite/src/parse_info/default_parse.h
lite/src/parse_info/default_parse.h
+1
-1
lite/src/tensor.cpp
lite/src/tensor.cpp
+19
-1
未找到文件。
lite/include/lite/global.h
浏览文件 @
a2a46b56
...
...
@@ -77,7 +77,7 @@ LITE_API bool update_decryption_or_key(
* other config not inclue in config and networkIO, ParseInfoFunc can fill it
* with the information in json, now support:
* "device_id" : int, default 0
* "number_threads" :
size
_t, default 1
* "number_threads" :
uint32
_t, default 1
* "is_inplace_model" : bool, default false
* "use_tensorrt" : bool, default false
*/
...
...
lite/include/lite/tensor.h
浏览文件 @
a2a46b56
...
...
@@ -149,28 +149,42 @@ private:
*/
class
LITE_API
LiteAny
{
public:
enum
Type
{
STRING
=
0
,
INT32
=
1
,
UINT32
=
2
,
UINT8
=
3
,
INT8
=
4
,
INT64
=
5
,
UINT64
=
6
,
BOOL
=
7
,
VOID_PTR
=
8
,
FLOAT
=
9
,
NONE_SUPPORT
=
10
,
};
LiteAny
()
=
default
;
template
<
class
T
>
LiteAny
(
T
value
)
:
m_holder
(
new
AnyHolder
<
T
>
(
value
))
{
m_
is_string
=
std
::
is_same
<
std
::
string
,
T
>
();
m_
type
=
get_type
<
T
>
();
}
LiteAny
(
const
LiteAny
&
any
)
{
m_holder
=
any
.
m_holder
->
clone
();
m_
is_string
=
any
.
is_string
()
;
m_
type
=
any
.
m_type
;
}
LiteAny
&
operator
=
(
const
LiteAny
&
any
)
{
m_holder
=
any
.
m_holder
->
clone
();
m_
is_string
=
any
.
is_string
()
;
m_
type
=
any
.
m_type
;
return
*
this
;
}
bool
is_string
()
const
{
return
m_is_string
;
}
template
<
class
T
>
Type
get_type
()
const
;
class
HolderBase
{
public:
virtual
~
HolderBase
()
=
default
;
virtual
std
::
shared_ptr
<
HolderBase
>
clone
()
=
0
;
virtual
size_t
type_length
()
const
=
0
;
};
template
<
class
T
>
...
...
@@ -180,7 +194,6 @@ public:
virtual
std
::
shared_ptr
<
HolderBase
>
clone
()
override
{
return
std
::
make_shared
<
AnyHolder
>
(
m_value
);
}
virtual
size_t
type_length
()
const
override
{
return
sizeof
(
T
);
}
public:
T
m_value
;
...
...
@@ -188,14 +201,21 @@ public:
//! if type is miss matching, it will throw
void
type_missmatch
(
size_t
expect
,
size_t
get
)
const
;
//! only check the storage type and the visit type length, so it's not safe
template
<
class
T
>
T
un
safe_cast
()
const
{
if
(
sizeof
(
T
)
!=
m_holder
->
type_length
()
)
{
type_missmatch
(
m_
holder
->
type_length
(),
sizeof
(
T
));
T
safe_cast
()
const
{
if
(
get_type
<
T
>
()
!=
m_type
)
{
type_missmatch
(
m_
type
,
get_type
<
T
>
(
));
}
return
static_cast
<
LiteAny
::
AnyHolder
<
T
>*>
(
m_holder
.
get
())
->
m_value
;
}
template
<
class
T
>
bool
try_cast
()
const
{
if
(
get_type
<
T
>
()
==
m_type
)
{
return
true
;
}
else
{
return
false
;
}
}
//! only check the storage type and the visit type length, so it's not safe
void
*
cast_void_ptr
()
const
{
return
&
static_cast
<
LiteAny
::
AnyHolder
<
char
>*>
(
m_holder
.
get
())
->
m_value
;
...
...
@@ -203,7 +223,7 @@ public:
private:
std
::
shared_ptr
<
HolderBase
>
m_holder
;
bool
m_is_string
=
false
;
Type
m_type
=
NONE_SUPPORT
;
};
/*********************** special tensor function ***************/
...
...
lite/lite-c/src/global.cpp
浏览文件 @
a2a46b56
...
...
@@ -127,7 +127,8 @@ int LITE_register_parse_info_func(
separate_config_map
[
"device_id"
]
=
device_id
;
}
if
(
nr_threads
!=
1
)
{
separate_config_map
[
"nr_threads"
]
=
nr_threads
;
separate_config_map
[
"nr_threads"
]
=
static_cast
<
uint32_t
>
(
nr_threads
);
}
if
(
is_cpu_inplace_mode
!=
false
)
{
separate_config_map
[
"is_inplace_mode"
]
=
is_cpu_inplace_mode
;
...
...
lite/src/mge/network_impl.cpp
浏览文件 @
a2a46b56
...
...
@@ -352,19 +352,19 @@ void NetworkImplDft::load_model(
//! config some flag get from json config file
if
(
separate_config_map
.
find
(
"device_id"
)
!=
separate_config_map
.
end
())
{
set_device_id
(
separate_config_map
[
"device_id"
].
un
safe_cast
<
int
>
());
set_device_id
(
separate_config_map
[
"device_id"
].
safe_cast
<
int
>
());
}
if
(
separate_config_map
.
find
(
"number_threads"
)
!=
separate_config_map
.
end
()
&&
separate_config_map
[
"number_threads"
].
unsafe_cast
<
size
_t
>
()
>
1
)
{
separate_config_map
[
"number_threads"
].
safe_cast
<
uint32
_t
>
()
>
1
)
{
set_cpu_threads_number
(
separate_config_map
[
"number_threads"
].
unsafe_cast
<
size
_t
>
());
separate_config_map
[
"number_threads"
].
safe_cast
<
uint32
_t
>
());
}
if
(
separate_config_map
.
find
(
"enable_inplace_model"
)
!=
separate_config_map
.
end
()
&&
separate_config_map
[
"enable_inplace_model"
].
un
safe_cast
<
bool
>
())
{
separate_config_map
[
"enable_inplace_model"
].
safe_cast
<
bool
>
())
{
set_cpu_inplace_mode
();
}
if
(
separate_config_map
.
find
(
"use_tensorrt"
)
!=
separate_config_map
.
end
()
&&
separate_config_map
[
"use_tensorrt"
].
un
safe_cast
<
bool
>
())
{
separate_config_map
[
"use_tensorrt"
].
safe_cast
<
bool
>
())
{
use_tensorrt
();
}
...
...
lite/src/parse_info/default_parse.h
浏览文件 @
a2a46b56
...
...
@@ -84,7 +84,7 @@ bool default_parse_info(
}
if
(
device_json
.
contains
(
"number_threads"
))
{
separate_config_map
[
"number_threads"
]
=
static_cast
<
size
_t
>
(
device_json
[
"number_threads"
]);
static_cast
<
uint32
_t
>
(
device_json
[
"number_threads"
]);
}
if
(
device_json
.
contains
(
"enable_inplace_model"
))
{
separate_config_map
[
"enable_inplace_model"
]
=
...
...
lite/src/tensor.cpp
浏览文件 @
a2a46b56
...
...
@@ -277,10 +277,28 @@ void Tensor::update_from_implement() {
void
LiteAny
::
type_missmatch
(
size_t
expect
,
size_t
get
)
const
{
LITE_THROW
(
ssprintf
(
"The type store in LiteAny is not match the visit type, type of "
"storage
length is %zu, type of visit length
is %zu."
,
"storage
enum is %zu, type of visit enum
is %zu."
,
expect
,
get
));
}
namespace
lite
{
#define GET_TYPE(ctype, ENUM) \
template <> \
LiteAny::Type LiteAny::get_type<ctype>() const { \
return ENUM; \
}
GET_TYPE
(
std
::
string
,
STRING
)
GET_TYPE
(
int32_t
,
INT32
)
GET_TYPE
(
uint32_t
,
UINT32
)
GET_TYPE
(
int8_t
,
INT8
)
GET_TYPE
(
uint8_t
,
UINT8
)
GET_TYPE
(
int64_t
,
INT64
)
GET_TYPE
(
uint64_t
,
UINT64
)
GET_TYPE
(
float
,
FLOAT
)
GET_TYPE
(
bool
,
BOOL
)
GET_TYPE
(
void
*
,
VOID_PTR
)
}
// namespace lite
std
::
shared_ptr
<
Tensor
>
TensorUtils
::
concat
(
const
std
::
vector
<
Tensor
>&
tensors
,
int
dim
,
LiteDeviceType
dst_device
,
int
dst_device_id
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录