Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
b84d2893
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
b84d2893
编写于
10月 13, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(lite): fix lite c thread local
GitOrigin-RevId: 36d2da7d68a71253904851b32daff61ed03738ce
上级
936bb237
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
40 addition
and
10 deletion
+40
-10
lite/lite-c/src/common.h
lite/lite-c/src/common.h
+2
-0
lite/lite-c/src/global.cpp
lite/lite-c/src/global.cpp
+5
-5
lite/lite-c/src/network.cpp
lite/lite-c/src/network.cpp
+6
-2
lite/lite-c/src/tensor.cpp
lite/lite-c/src/tensor.cpp
+9
-3
lite/test/test_tensor_c.cpp
lite/test/test_tensor_c.cpp
+18
-0
未找到文件。
lite/lite-c/src/common.h
浏览文件 @
b84d2893
...
@@ -17,8 +17,10 @@
...
@@ -17,8 +17,10 @@
#include "lite-c/tensor_c.h"
#include "lite-c/tensor_c.h"
#include "lite/network.h"
#include "lite/network.h"
#if LITE_ENABLE_EXCEPTION
#include <exception>
#include <exception>
#include <stdexcept>
#include <stdexcept>
#endif
//! convert c Layout to lite::Layout
//! convert c Layout to lite::Layout
lite
::
Layout
convert_to_layout
(
const
LiteLayout
&
layout
);
lite
::
Layout
convert_to_layout
(
const
LiteLayout
&
layout
);
...
...
lite/lite-c/src/global.cpp
浏览文件 @
b84d2893
...
@@ -13,11 +13,7 @@
...
@@ -13,11 +13,7 @@
#include "common.h"
#include "common.h"
#include "lite-c/global_c.h"
#include "lite-c/global_c.h"
#include <exception>
#include <mutex>
namespace
{
namespace
{
class
ErrorMsg
{
class
ErrorMsg
{
public:
public:
std
::
string
&
get_error_msg
()
{
return
error_msg
;
}
std
::
string
&
get_error_msg
()
{
return
error_msg
;
}
...
@@ -26,18 +22,22 @@ public:
...
@@ -26,18 +22,22 @@ public:
private:
private:
std
::
string
error_msg
;
std
::
string
error_msg
;
};
};
static
LITE_MUTEX
mtx_error
;
ErrorMsg
&
get_global_error
()
{
ErrorMsg
&
get_global_error
()
{
static
thread_local
ErrorMsg
error_msg
;
static
ErrorMsg
error_msg
;
return
error_msg
;
return
error_msg
;
}
}
}
// namespace
}
// namespace
int
LiteHandleException
(
const
std
::
exception
&
e
)
{
int
LiteHandleException
(
const
std
::
exception
&
e
)
{
LITE_LOCK_GUARD
(
mtx_error
);
get_global_error
().
set_error_msg
(
e
.
what
());
get_global_error
().
set_error_msg
(
e
.
what
());
return
-
1
;
return
-
1
;
}
}
const
char
*
LITE_get_last_error
()
{
const
char
*
LITE_get_last_error
()
{
LITE_LOCK_GUARD
(
mtx_error
);
return
get_global_error
().
get_error_msg
().
c_str
();
return
get_global_error
().
get_error_msg
().
c_str
();
}
}
...
...
lite/lite-c/src/network.cpp
浏览文件 @
b84d2893
...
@@ -72,9 +72,9 @@ LiteNetworkIO* default_network_io() {
...
@@ -72,9 +72,9 @@ LiteNetworkIO* default_network_io() {
}
}
namespace
{
namespace
{
static
LITE_MUTEX
mtx_network
;
std
::
unordered_map
<
void
*
,
std
::
shared_ptr
<
lite
::
Network
>>&
get_gloabl_network_holder
()
{
std
::
unordered_map
<
void
*
,
std
::
shared_ptr
<
lite
::
Network
>>&
get_gloabl_network_holder
()
{
static
thread_local
std
::
unordered_map
<
void
*
,
std
::
shared_ptr
<
lite
::
Network
>>
static
std
::
unordered_map
<
void
*
,
std
::
shared_ptr
<
lite
::
Network
>>
network_holder
;
network_holder
;
return
network_holder
;
return
network_holder
;
}
}
...
@@ -168,6 +168,7 @@ int LITE_make_default_network(LiteNetwork* network) {
...
@@ -168,6 +168,7 @@ int LITE_make_default_network(LiteNetwork* network) {
LITE_CAPI_BEGIN
();
LITE_CAPI_BEGIN
();
LITE_ASSERT
(
network
,
"The network pass to LITE api is null"
);
LITE_ASSERT
(
network
,
"The network pass to LITE api is null"
);
auto
lite_network
=
std
::
make_shared
<
lite
::
Network
>
();
auto
lite_network
=
std
::
make_shared
<
lite
::
Network
>
();
LITE_LOCK_GUARD
(
mtx_network
);
get_gloabl_network_holder
()[
lite_network
.
get
()]
=
lite_network
;
get_gloabl_network_holder
()[
lite_network
.
get
()]
=
lite_network
;
*
network
=
lite_network
.
get
();
*
network
=
lite_network
.
get
();
LITE_CAPI_END
();
LITE_CAPI_END
();
...
@@ -179,6 +180,7 @@ int LITE_make_network(
...
@@ -179,6 +180,7 @@ int LITE_make_network(
LITE_ASSERT
(
network
,
"The network pass to LITE api is null"
);
LITE_ASSERT
(
network
,
"The network pass to LITE api is null"
);
auto
lite_network
=
std
::
make_shared
<
lite
::
Network
>
(
auto
lite_network
=
std
::
make_shared
<
lite
::
Network
>
(
convert_to_lite_config
(
config
),
convert_to_lite_io
(
network_io
));
convert_to_lite_config
(
config
),
convert_to_lite_io
(
network_io
));
LITE_LOCK_GUARD
(
mtx_network
);
get_gloabl_network_holder
()[
lite_network
.
get
()]
=
lite_network
;
get_gloabl_network_holder
()[
lite_network
.
get
()]
=
lite_network
;
*
network
=
lite_network
.
get
();
*
network
=
lite_network
.
get
();
LITE_CAPI_END
();
LITE_CAPI_END
();
...
@@ -188,6 +190,7 @@ int LITE_make_network_config(LiteNetwork* network, const LiteConfig config) {
...
@@ -188,6 +190,7 @@ int LITE_make_network_config(LiteNetwork* network, const LiteConfig config) {
LITE_CAPI_BEGIN
();
LITE_CAPI_BEGIN
();
LITE_ASSERT
(
network
,
"The network pass to LITE api is null"
);
LITE_ASSERT
(
network
,
"The network pass to LITE api is null"
);
auto
lite_network
=
std
::
make_shared
<
lite
::
Network
>
(
convert_to_lite_config
(
config
));
auto
lite_network
=
std
::
make_shared
<
lite
::
Network
>
(
convert_to_lite_config
(
config
));
LITE_LOCK_GUARD
(
mtx_network
);
get_gloabl_network_holder
()[
lite_network
.
get
()]
=
lite_network
;
get_gloabl_network_holder
()[
lite_network
.
get
()]
=
lite_network
;
*
network
=
lite_network
.
get
();
*
network
=
lite_network
.
get
();
LITE_CAPI_END
();
LITE_CAPI_END
();
...
@@ -212,6 +215,7 @@ int LITE_load_model_from_path(LiteNetwork network, const char* model_path) {
...
@@ -212,6 +215,7 @@ int LITE_load_model_from_path(LiteNetwork network, const char* model_path) {
int
LITE_destroy_network
(
LiteNetwork
network
)
{
int
LITE_destroy_network
(
LiteNetwork
network
)
{
LITE_CAPI_BEGIN
();
LITE_CAPI_BEGIN
();
LITE_ASSERT
(
network
,
"The network pass to LITE api is null"
);
LITE_ASSERT
(
network
,
"The network pass to LITE api is null"
);
LITE_LOCK_GUARD
(
mtx_network
);
get_gloabl_network_holder
().
erase
(
network
);
get_gloabl_network_holder
().
erase
(
network
);
LITE_CAPI_END
();
LITE_CAPI_END
();
}
}
...
...
lite/lite-c/src/tensor.cpp
浏览文件 @
b84d2893
...
@@ -26,13 +26,16 @@ const LiteTensorDesc default_desc = {
...
@@ -26,13 +26,16 @@ const LiteTensorDesc default_desc = {
.
device_type
=
LiteDeviceType
::
LITE_CPU
,
.
device_type
=
LiteDeviceType
::
LITE_CPU
,
.
device_id
=
0
};
.
device_id
=
0
};
namespace
{
namespace
{
static
LITE_MUTEX
mtx_tensor
;
std
::
unordered_map
<
void
*
,
std
::
shared_ptr
<
lite
::
Tensor
>>&
get_global_tensor_holder
()
{
std
::
unordered_map
<
void
*
,
std
::
shared_ptr
<
lite
::
Tensor
>>&
get_global_tensor_holder
()
{
static
thread_local
std
::
unordered_map
<
void
*
,
std
::
shared_ptr
<
lite
::
Tensor
>>
static
std
::
unordered_map
<
void
*
,
std
::
shared_ptr
<
lite
::
Tensor
>>
global_holder
;
global_holder
;
return
global_holder
;
return
global_holder
;
}
}
static
LITE_MUTEX
mtx_attr
;
std
::
unordered_map
<
std
::
string
,
lite
::
LiteAny
>&
get_global_tensor_attr_holder
()
{
std
::
unordered_map
<
std
::
string
,
lite
::
LiteAny
>&
get_global_tensor_attr_holder
()
{
static
thread_local
std
::
unordered_map
<
std
::
string
,
lite
::
LiteAny
>
global_holder
;
static
std
::
unordered_map
<
std
::
string
,
lite
::
LiteAny
>
global_holder
;
return
global_holder
;
return
global_holder
;
}
}
}
// namespace
}
// namespace
...
@@ -68,6 +71,7 @@ int LITE_make_tensor(const LiteTensorDesc tensor_describe, LiteTensor* tensor) {
...
@@ -68,6 +71,7 @@ int LITE_make_tensor(const LiteTensorDesc tensor_describe, LiteTensor* tensor) {
auto
lite_tensor
=
std
::
make_shared
<
lite
::
Tensor
>
(
auto
lite_tensor
=
std
::
make_shared
<
lite
::
Tensor
>
(
tensor_describe
.
device_id
,
tensor_describe
.
device_type
,
layout
,
tensor_describe
.
device_id
,
tensor_describe
.
device_type
,
layout
,
tensor_describe
.
is_pinned_host
);
tensor_describe
.
is_pinned_host
);
LITE_LOCK_GUARD
(
mtx_tensor
);
get_global_tensor_holder
()[
lite_tensor
.
get
()]
=
lite_tensor
;
get_global_tensor_holder
()[
lite_tensor
.
get
()]
=
lite_tensor
;
*
tensor
=
lite_tensor
.
get
();
*
tensor
=
lite_tensor
.
get
();
LITE_CAPI_END
();
LITE_CAPI_END
();
...
@@ -76,6 +80,7 @@ int LITE_make_tensor(const LiteTensorDesc tensor_describe, LiteTensor* tensor) {
...
@@ -76,6 +80,7 @@ int LITE_make_tensor(const LiteTensorDesc tensor_describe, LiteTensor* tensor) {
int
LITE_destroy_tensor
(
LiteTensor
tensor
)
{
int
LITE_destroy_tensor
(
LiteTensor
tensor
)
{
LITE_CAPI_BEGIN
();
LITE_CAPI_BEGIN
();
LITE_ASSERT
(
tensor
,
"The tensor pass to LITE c_api is null"
);
LITE_ASSERT
(
tensor
,
"The tensor pass to LITE c_api is null"
);
LITE_LOCK_GUARD
(
mtx_tensor
);
get_global_tensor_holder
().
erase
(
tensor
);
get_global_tensor_holder
().
erase
(
tensor
);
LITE_CAPI_END
();
LITE_CAPI_END
();
}
}
...
@@ -132,6 +137,7 @@ int LITE_tensor_slice(
...
@@ -132,6 +137,7 @@ int LITE_tensor_slice(
}
}
}
}
auto
ret_tensor
=
static_cast
<
lite
::
Tensor
*>
(
tensor
)
->
slice
(
starts
,
ends
,
steps
);
auto
ret_tensor
=
static_cast
<
lite
::
Tensor
*>
(
tensor
)
->
slice
(
starts
,
ends
,
steps
);
LITE_LOCK_GUARD
(
mtx_tensor
);
get_global_tensor_holder
()[
ret_tensor
.
get
()]
=
ret_tensor
;
get_global_tensor_holder
()[
ret_tensor
.
get
()]
=
ret_tensor
;
*
slice_tensor
=
ret_tensor
.
get
();
*
slice_tensor
=
ret_tensor
.
get
();
LITE_CAPI_END
();
LITE_CAPI_END
();
...
...
lite/test/test_tensor_c.cpp
浏览文件 @
b84d2893
...
@@ -18,6 +18,7 @@
...
@@ -18,6 +18,7 @@
#include <gtest/gtest.h>
#include <gtest/gtest.h>
#include <memory>
#include <memory>
#include <thread>
TEST
(
TestCapiTensor
,
Basic
)
{
TEST
(
TestCapiTensor
,
Basic
)
{
LiteTensor
c_tensor0
,
c_tensor1
;
LiteTensor
c_tensor0
,
c_tensor1
;
...
@@ -305,6 +306,23 @@ TEST(TestCapiTensor, GetMemoryByIndex) {
...
@@ -305,6 +306,23 @@ TEST(TestCapiTensor, GetMemoryByIndex) {
LITE_destroy_tensor
(
c_tensor0
);
LITE_destroy_tensor
(
c_tensor0
);
}
}
TEST
(
TestCapiTensor
,
ThreadLocalError
)
{
LiteTensor
c_tensor0
;
LiteTensorDesc
description
=
default_desc
;
description
.
layout
=
LiteLayout
{{
20
,
20
},
2
,
LiteDataType
::
LITE_FLOAT
};
void
*
ptr0
,
*
ptr1
;
std
::
thread
thread1
([
&
]()
{
LITE_make_tensor
(
description
,
&
c_tensor0
);
LITE_get_tensor_memory
(
c_tensor0
,
&
ptr0
);
});
thread1
.
join
();
std
::
thread
thread2
([
&
]()
{
LITE_get_tensor_memory
(
c_tensor0
,
&
ptr1
);
LITE_destroy_tensor
(
c_tensor0
);
});
thread2
.
join
();
}
#endif
#endif
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录