Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
ab6328c5
MegEngine
项目概览
MegEngine 天元
/
MegEngine
9 个月 前同步成功
通知
392
Star
4702
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
ab6328c5
编写于
10月 29, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(imperative): port persistent cache
GitOrigin-RevId: 8ca24a37cc28a0be3f659e0e8863fee1beac3a38
上级
60c6d59f
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
172 addition
and
0 deletion
+172
-0
imperative/python/megengine/__init__.py
imperative/python/megengine/__init__.py
+5
-0
imperative/python/megengine/utils/persistent_cache.py
imperative/python/megengine/utils/persistent_cache.py
+90
-0
imperative/python/src/helper.h
imperative/python/src/helper.h
+44
-0
imperative/python/src/utils.cpp
imperative/python/src/utils.cpp
+17
-0
imperative/python/test/unit/test_utils.py
imperative/python/test/unit/test_utils.py
+16
-0
未找到文件。
imperative/python/megengine/__init__.py
浏览文件 @
ab6328c5
...
...
@@ -78,13 +78,18 @@ from .logger import enable_debug_log, get_logger, set_log_file, set_log_level
from
.serialization
import
load
,
save
from
.tensor
import
Parameter
,
Tensor
,
tensor
from
.version
import
__version__
from
.utils
import
persistent_cache
,
comp_graph_tools
as
cgtools
_set_fork_exec_path_for_timed_func
(
sys
.
executable
,
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
"utils"
,
"_timed_func_fork_exec_entry.py"
),
)
_persistent_cache_impl_ins
=
persistent_cache
.
PersistentCacheOnServer
()
_persistent_cache_impl_ins
.
reg
()
atexit
.
register
(
sync
)
del
sync
del
_set_fork_exec_path_for_timed_func
del
_persistent_cache_impl_ins
imperative/python/megengine/utils/persistent_cache.py
0 → 100644
浏览文件 @
ab6328c5
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
argparse
import
getpass
import
json
import
os
import
shelve
from
..core._imperative_rt
import
PersistentCache
as
_PersistentCache
from
..logger
import
get_logger
from
..version
import
__version__
class
_FakeRedisConn
:
def
__init__
(
self
):
try
:
from
..hub.hub
import
_get_megengine_home
cache_dir
=
os
.
path
.
expanduser
(
os
.
path
.
join
(
_get_megengine_home
(),
"persistent_cache"
)
)
os
.
makedirs
(
cache_dir
,
exist_ok
=
True
)
cache_file
=
os
.
path
.
join
(
cache_dir
,
"cache"
)
self
.
_dict
=
shelve
.
open
(
cache_file
)
self
.
_is_shelve
=
True
except
:
self
.
_dict
=
{}
self
.
_is_shelve
=
False
def
get
(
self
,
key
):
if
self
.
_is_shelve
and
isinstance
(
key
,
bytes
):
key
=
key
.
decode
(
"utf-8"
)
return
self
.
_dict
.
get
(
key
)
def
set
(
self
,
key
,
val
):
if
self
.
_is_shelve
and
isinstance
(
key
,
bytes
):
key
=
key
.
decode
(
"utf-8"
)
self
.
_dict
[
key
]
=
val
def
__del__
(
self
):
if
self
.
_is_shelve
:
self
.
_dict
.
close
()
class
PersistentCacheOnServer
(
_PersistentCache
):
_cached_conn
=
None
_prefix
=
None
_prev_get_refkeep
=
None
@
property
def
_conn
(
self
):
"""get redis connection"""
if
self
.
_cached_conn
is
None
:
self
.
_cached_conn
=
_FakeRedisConn
()
self
.
_prefix
=
self
.
make_user_prefix
()
return
self
.
_cached_conn
@
classmethod
def
make_user_prefix
(
cls
):
return
"mgbcache:{}"
.
format
(
getpass
.
getuser
())
def
_make_key
(
self
,
category
,
key
):
prefix_with_version
=
"{}:MGB{}"
.
format
(
self
.
_prefix
,
__version__
)
return
b
"@"
.
join
(
(
prefix_with_version
.
encode
(
"ascii"
),
category
.
encode
(
"ascii"
),
key
)
)
def
put
(
self
,
category
,
key
,
value
):
conn
=
self
.
_conn
key
=
self
.
_make_key
(
category
,
key
)
conn
.
set
(
key
,
value
)
def
get
(
self
,
category
,
key
):
conn
=
self
.
_conn
key
=
self
.
_make_key
(
category
,
key
)
self
.
_prev_get_refkeep
=
conn
.
get
(
key
)
return
self
.
_prev_get_refkeep
imperative/python/src/helper.h
浏览文件 @
ab6328c5
...
...
@@ -12,6 +12,7 @@
#pragma once
#include "megbrain/graph.h"
#include "megbrain/utils/persistent_cache.h"
#include <Python.h>
#include <string>
...
...
@@ -328,6 +329,49 @@ namespace detail {
template
<
>
struct
type_caster
<
mgb
::
CompNode
>
:
public
from_none_caster
<
mgb
::
CompNode
>
{};
template
<
>
struct
type_caster
<
mgb
::
PersistentCache
::
Blob
>
{
PYBIND11_TYPE_CASTER
(
mgb
::
PersistentCache
::
Blob
,
_
(
"Blob"
));
public:
bool
load
(
handle
src
,
bool
convert
)
{
if
(
!
isinstance
<
bytes
>
(
src
))
{
return
false
;
}
value
.
ptr
=
PYBIND11_BYTES_AS_STRING
(
src
.
ptr
());
value
.
size
=
PYBIND11_BYTES_SIZE
(
src
.
ptr
());
return
true
;
}
static
handle
cast
(
mgb
::
PersistentCache
::
Blob
blob
,
return_value_policy
/* policy */
,
handle
/* parent */
)
{
return
bytes
((
const
char
*
)
blob
.
ptr
,
blob
.
size
);
}
};
template
<
typename
T
>
struct
type_caster
<
mgb
::
Maybe
<
T
>>
{
using
value_conv
=
make_caster
<
T
>
;
PYBIND11_TYPE_CASTER
(
mgb
::
Maybe
<
T
>
,
_
(
"Optional["
)
+
value_conv
::
name
+
_
(
"]"
));
public:
bool
load
(
handle
src
,
bool
convert
)
{
if
(
!
src
)
{
return
false
;
}
if
(
src
.
is_none
())
{
return
true
;
}
value_conv
inner_caster
;
if
(
!
inner_caster
.
load
(
src
,
convert
))
{
return
false
;
}
value
.
emplace
(
cast_op
<
T
&&>
(
std
::
move
(
inner_caster
)));
return
true
;
}
static
handle
cast
(
mgb
::
Maybe
<
T
>
src
,
return_value_policy
policy
,
handle
parent
)
{
if
(
!
src
.
valid
())
{
return
none
().
inc_ref
();
}
return
pybind11
::
cast
(
src
.
val
(),
policy
,
parent
);
}
};
}
// detail
}
// PYBIND11_NAMESPACE
...
...
imperative/python/src/utils.cpp
浏览文件 @
ab6328c5
...
...
@@ -25,6 +25,7 @@
#include "megbrain/imperative/profiler.h"
#include "megbrain/imperative/tensor_sanity_check.h"
#include "megbrain/serialization/helper.h"
#include "megbrain/utils/persistent_cache.h"
#if MGB_ENABLE_OPR_MM
#include "megbrain/opr/mm_handler.h"
...
...
@@ -262,4 +263,20 @@ void init_utils(py::module m) {
m
.
def
(
"_timed_func_exec_cb"
,
[](
const
std
::
string
&
user_data
){
mgb
::
sys
::
TimedFuncInvoker
::
ins
().
fork_exec_impl_mainloop
(
user_data
.
c_str
());
});
using
mgb
::
PersistentCache
;
class
PyPersistentCache
:
public
mgb
::
PersistentCache
{
public:
mgb
::
Maybe
<
Blob
>
get
(
const
std
::
string
&
category
,
const
Blob
&
key
)
override
{
PYBIND11_OVERLOAD_PURE
(
mgb
::
Maybe
<
Blob
>
,
PersistentCache
,
get
,
category
,
key
);
}
void
put
(
const
std
::
string
&
category
,
const
Blob
&
key
,
const
Blob
&
value
)
override
{
PYBIND11_OVERLOAD_PURE
(
void
,
PersistentCache
,
put
,
category
,
key
,
value
);
}
};
py
::
class_
<
PersistentCache
,
PyPersistentCache
,
std
::
shared_ptr
<
PersistentCache
>>
(
m
,
"PersistentCache"
)
.
def
(
py
::
init
<>
())
.
def
(
"get"
,
&
PersistentCache
::
get
)
.
def
(
"put"
,
&
PersistentCache
::
put
)
.
def
(
"reg"
,
&
PersistentCache
::
set_impl
);
}
imperative/python/test/unit/test_utils.py
0 → 100644
浏览文件 @
ab6328c5
import
pytest
import
megengine
from
megengine.utils.persistent_cache
import
PersistentCacheOnServer
def
test_persistent_cache
():
pc
=
PersistentCacheOnServer
()
k0
=
b
"
\x00\x00
"
k1
=
b
"
\x00\x01
"
cat
=
"test"
pc
.
put
(
cat
,
k0
,
k1
)
pc
.
put
(
cat
,
k1
,
k0
)
assert
k1
==
pc
.
get
(
cat
,
k0
)
assert
k0
==
pc
.
get
(
cat
,
k1
)
assert
pc
.
get
(
"test1"
,
k0
)
==
None
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录