Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
1657b8e8
MegEngine
项目概览
MegEngine 天元
/
MegEngine
大约 1 年 前同步成功
通知
399
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
1657b8e8
编写于
12月 30, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(fastrun): fix persistent_cache in redis
GitOrigin-RevId: ada5862b057dd7310e63a535874b00da882b21ba
上级
a404cd7d
变更
13
隐藏空白更改
内联
并排
Showing
13 changed file
with
288 addition
and
149 deletion
+288
-149
imperative/python/megengine/__init__.py
imperative/python/megengine/__init__.py
+9
-6
imperative/python/megengine/utils/persistent_cache.py
imperative/python/megengine/utils/persistent_cache.py
+86
-59
imperative/python/requires.txt
imperative/python/requires.txt
+2
-2
imperative/python/src/utils.cpp
imperative/python/src/utils.cpp
+91
-34
imperative/src/impl/ops/collective_comm.cpp
imperative/src/impl/ops/collective_comm.cpp
+1
-1
imperative/src/impl/ops/io_remote.cpp
imperative/src/impl/ops/io_remote.cpp
+2
-2
imperative/src/impl/persistent_cache.cpp
imperative/src/impl/persistent_cache.cpp
+81
-33
imperative/src/include/megbrain/imperative/persistent_cache.h
...rative/src/include/megbrain/imperative/persistent_cache.h
+5
-5
imperative/src/test/collective_comm.cpp
imperative/src/test/collective_comm.cpp
+1
-1
imperative/src/test/io_remote.cpp
imperative/src/test/io_remote.cpp
+1
-1
src/opr-mm/impl/mm_handler.cpp
src/opr-mm/impl/mm_handler.cpp
+4
-1
src/opr-mm/include/megbrain/opr/mm_handler.h
src/opr-mm/include/megbrain/opr/mm_handler.h
+5
-2
src/version.ld
src/version.ld
+0
-2
未找到文件。
imperative/python/megengine/__init__.py
浏览文件 @
1657b8e8
...
...
@@ -84,7 +84,7 @@ from .logger import enable_debug_log, get_logger, set_log_file, set_log_level
from
.serialization
import
load
,
save
from
.tensor
import
Parameter
,
Tensor
,
tensor
from
.utils
import
comp_graph_tools
as
cgtools
from
.utils
import
persistent_cache
from
.utils
.persistent_cache
import
PersistentCacheOnServer
as
_PersistentCacheOnServer
from
.version
import
__version__
_set_fork_exec_path_for_timed_func
(
...
...
@@ -92,15 +92,13 @@ _set_fork_exec_path_for_timed_func(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
"utils"
,
"_timed_func_fork_exec_entry.py"
),
)
atexit
.
register
(
_close
)
del
_set_fork_exec_path_for_timed_func
_exit_handlers
=
[]
def
_run_exit_handlers
():
for
handler
in
_exit_handlers
:
for
handler
in
reversed
(
_exit_handlers
)
:
handler
()
_exit_handlers
.
clear
()
...
...
@@ -117,6 +115,13 @@ def _atexit(handler):
_exit_handlers
.
append
(
handler
)
_atexit
(
_close
)
_persistent_cache
=
_PersistentCacheOnServer
()
_persistent_cache
.
reg
()
_atexit
(
_persistent_cache
.
flush
)
# subpackages
import
megengine.amp
import
megengine.autodiff
...
...
@@ -132,5 +137,3 @@ import megengine.quantization
import
megengine.random
import
megengine.utils
import
megengine.traced_module
persistent_cache
.
get_manager
()
imperative/python/megengine/utils/persistent_cache.py
浏览文件 @
1657b8e8
...
...
@@ -8,87 +8,114 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
argparse
import
contextlib
import
getpass
import
os
import
sys
import
urllib.parse
from
..core._imperative_rt
import
PersistentCacheManager
as
_PersistentCacheManager
import
filelock
from
..core._imperative_rt
import
PersistentCache
as
_PersistentCache
from
..logger
import
get_logger
from
..version
import
__version__
,
git_version
class
PersistentCache
Manager
(
_PersistentCacheManager
):
class
PersistentCache
OnServer
(
_PersistentCache
):
def
__init__
(
self
):
super
().
__init__
()
if
os
.
getenv
(
"MGE_FASTRUN_CACHE_TYPE"
)
==
"MEMORY"
:
get_logger
().
info
(
"fastrun use in-memory cache"
)
self
.
open_memory
()
elif
os
.
getenv
(
"MGE_FASTRUN_CACHE_TYPE"
)
==
"FILE"
:
self
.
open_file
()
else
:
self
.
open_redis
()
def
open_memory
(
self
):
pass
cache_type
=
os
.
getenv
(
"MGE_FASTRUN_CACHE_TYPE"
)
if
cache_type
not
in
(
"FILE"
,
"MEMORY"
):
try
:
redis_config
=
self
.
get_redis_config
()
except
Exception
as
exc
:
get_logger
().
error
(
"failed to connect to cache server {!r}; try fallback to "
"in-file cache"
.
format
(
exc
)
)
else
:
self
.
add_config
(
"redis"
,
redis_config
,
"fastrun use redis cache"
,
"failed to connect to cache server"
,
)
if
cache_type
!=
"MEMORY"
:
path
=
self
.
get_cache_file
(
self
.
get_cache_dir
())
self
.
add_config
(
"in-file"
,
{
"path"
:
path
},
"fastrun use in-file cache in {}"
.
format
(
path
),
"failed to create cache file in {}"
.
format
(
path
),
)
self
.
add_config
(
"in-memory"
,
{},
"fastrun use in-memory cache"
,
"failed to create in-memory cache"
,
)
def
open_file
(
self
):
def
get_cache_dir
(
self
):
cache_dir
=
os
.
getenv
(
"MGE_FASTRUN_CACHE_DIR"
)
try
:
if
not
cache_dir
:
from
..hub.hub
import
_get_megengine_home
if
not
cache_dir
:
from
..hub.hub
import
_get_megengine_home
cache_dir
=
os
.
path
.
expanduser
(
os
.
path
.
join
(
_get_megengine_home
(),
"persistent_cache.bin"
)
)
os
.
makedirs
(
cache_dir
,
exist_ok
=
True
)
cache_file
=
os
.
path
.
join
(
cache_dir
,
"cache"
)
with
open
(
cache_file
,
"a"
):
pass
assert
self
.
try_open_file
(
cache_file
),
"cannot create file"
get_logger
().
info
(
"fastrun use in-file cache in {}"
.
format
(
cache_dir
))
except
Exception
as
exc
:
get_logger
().
error
(
"failed to create cache file in {} {!r}; fallback to "
"in-memory cache"
.
format
(
cache_dir
,
exc
)
cache_dir
=
os
.
path
.
expanduser
(
os
.
path
.
join
(
_get_megengine_home
(),
"persistent_cache"
)
)
self
.
open_memory
()
def
open_redis
(
self
):
os
.
makedirs
(
cache_dir
,
exist_ok
=
True
)
return
cache_dir
def
get_cache_file
(
self
,
cache_dir
):
cache_file
=
os
.
path
.
join
(
cache_dir
,
"cache.bin"
)
with
open
(
cache_file
,
"a"
):
pass
return
cache_file
@
contextlib
.
contextmanager
def
lock_cache_file
(
self
,
cache_dir
):
lock_file
=
os
.
path
.
join
(
cache_dir
,
"cache.lock"
)
with
filelock
.
FileLock
(
lock_file
):
yield
def
get_redis_config
(
self
):
url
=
os
.
getenv
(
"MGE_FASTRUN_CACHE_URL"
)
if
url
is
None
:
return
None
assert
sys
.
platform
!=
"win32"
,
"redis cache on windows not tested"
prefix
=
"mgbcache:{}:MGB{}:GIT:{}"
.
format
(
getpass
.
getuser
(),
__version__
,
git_version
)
url
=
os
.
getenv
(
"MGE_FASTRUN_CACHE_URL"
)
if
url
is
None
:
self
.
open_file
()
try
:
assert
sys
.
platform
!=
"win32"
,
"redis cache on windows not tested"
parse_result
=
urllib
.
parse
.
urlparse
(
url
,
scheme
=
"redis"
)
assert
parse_result
.
scheme
==
"redis"
,
"unsupported scheme"
assert
not
parse_result
.
username
,
"redis conn with username unsupported"
assert
self
.
try_open_redis
(
parse_result
.
hostname
,
parse_result
.
port
,
parse_result
.
password
,
prefix
),
"connect failed"
except
Exception
as
exc
:
get_logger
().
error
(
"
failed to connect to cache server {!r}; try fallback to "
"in-file cache"
.
format
(
exc
)
)
self
.
open_file
()
_manager
=
None
parse_result
=
urllib
.
parse
.
urlparse
(
url
)
assert
not
parse_result
.
username
,
"redis conn with username unsupported"
if
parse_result
.
scheme
==
"redis"
:
assert
parse_result
.
hostname
and
parse_result
.
port
,
"invalid url"
assert
not
parse_result
.
path
config
=
{
"hostname"
:
parse_result
.
hostname
,
"port"
:
str
(
parse_result
.
port
),
}
elif
parse_result
.
scheme
==
"redis+socket"
:
assert
not
(
parse_result
.
hostname
or
parse_result
.
port
)
assert
parse_result
.
path
config
=
{
"
unixsocket"
:
parse_result
.
path
,
}
else
:
assert
False
,
"unsupported scheme"
if
parse_result
.
password
is
not
None
:
config
[
"password"
]
=
parse_result
.
password
config
[
"prefix"
]
=
prefix
return
config
def
get_manager
():
global
_manager
if
_manager
is
None
:
_manager
=
PersistentCacheManager
()
return
_manager
def
flush
(
self
):
if
self
.
config
is
not
None
and
self
.
config
.
type
==
"in-file"
:
with
self
.
lock_cache_file
(
self
.
get_cache_dir
()):
super
().
flush
()
def
_clean
():
nr_del
=
get_manag
er
().
clean
()
nr_del
=
PersistentCacheOnServ
er
().
clean
()
if
nr_del
is
not
None
:
print
(
"{} cache entries deleted"
.
format
(
nr_del
))
...
...
imperative/python/requires.txt
浏览文件 @
1657b8e8
...
...
@@ -4,8 +4,8 @@ pyarrow
requests
tabulate
tqdm
redispy
deprecated
mprop
wheel
megfile>=0.0.10
\ No newline at end of file
megfile>=0.0.10
filelock
imperative/python/src/utils.cpp
浏览文件 @
1657b8e8
...
...
@@ -210,7 +210,7 @@ void init_utils(py::module m) {
.
def
(
"disable"
,
[](
TensorSanityCheck
&
checker
)
{
checker
.
disable
();
});
#if MGB_ENABLE_OPR_MM
m
.
def
(
"create_mm_server"
,
&
create_zmqrpc_server
,
py
::
arg
(
"addr"
),
m
.
def
(
"create_mm_server"
,
&
mgb
::
opr
::
create_zmqrpc_server
,
py
::
arg
(
"addr"
),
py
::
arg
(
"port"
)
=
0
);
#else
m
.
def
(
"create_mm_server"
,
[]()
{});
...
...
@@ -234,51 +234,108 @@ void init_utils(py::module m) {
using
ExtendedPersistentCache
=
mgb
::
imperative
::
persistent_cache
::
ExtendedPersistentCache
;
struct
PersistentCacheManager
{
std
::
shared_ptr
<
ExtendedPersistentCache
>
instance
;
struct
ConfigurablePersistentCache
:
mgb
::
PersistentCache
{
struct
Config
{
std
::
string
type
;
std
::
unordered_map
<
std
::
string
,
std
::
string
>
args
;
std
::
string
on_success
;
std
::
string
on_fail
;
};
bool
try_reg
(
std
::
shared_ptr
<
ExtendedPersistentCache
>
cache
)
{
if
(
cache
)
{
instance
=
cache
;
PersistentCache
::
set_impl
(
cache
);
return
true
;
}
return
false
;
}
bool
open_redis
(
std
::
string
ip
,
size_t
port
,
std
::
string
password
,
std
::
string
prefix
)
{
return
try_reg
(
mgb
::
imperative
::
persistent_cache
::
make_redis
(
ip
,
port
,
password
,
prefix
));
std
::
shared_ptr
<
ExtendedPersistentCache
>
impl
;
std
::
optional
<
Config
>
impl_config
;
std
::
vector
<
Config
>
configs
;
void
add_config
(
std
::
string
type
,
std
::
unordered_map
<
std
::
string
,
std
::
string
>
args
,
std
::
string
on_success
,
std
::
string
on_fail
)
{
configs
.
push_back
({
type
,
args
,
on_success
,
on_fail
});
}
bool
open_file
(
std
::
string
path
)
{
return
try_reg
(
mgb
::
imperative
::
persistent_cache
::
make_in_file
(
path
));
std
::
optional
<
size_t
>
clean
()
{
return
get_impl
()
->
clear
();
}
void
load_config
()
{
std
::
optional
<
std
::
string
>
err_msg
;
for
(
size_t
i
=
0
;
i
<
configs
.
size
();
++
i
)
{
auto
&
config
=
configs
[
i
];
if
(
err_msg
)
{
mgb_log_warn
(
"try fallback to %s cache"
,
config
.
type
.
c_str
());
}
else
{
err_msg
.
emplace
();
}
auto
cache
=
ExtendedPersistentCache
::
make_from_config
(
config
.
type
,
config
.
args
,
*
err_msg
);
if
(
!
cache
)
{
mgb_log_warn
(
"%s %s"
,
config
.
on_fail
.
c_str
(),
err_msg
->
c_str
());
}
else
{
impl
=
cache
;
impl_config
=
config
;
break
;
}
}
mgb_assert
(
impl_config
.
has_value
(),
"not valid config"
);
}
std
::
optional
<
size_t
>
clean
()
{
if
(
instance
)
{
return
instance
->
clear
();
std
::
shared_ptr
<
ExtendedPersistentCache
>
get_impl
()
{
if
(
!
impl
)
{
load_config
();
}
return
{}
;
return
impl
;
}
void
put
(
std
::
string
category
,
std
::
string
key
,
std
::
string
value
)
{
PersistentCache
::
inst
().
put
(
category
,
{
key
.
data
(),
key
.
size
()},
{
value
.
data
(),
value
.
size
()});
virtual
mgb
::
Maybe
<
Blob
>
get
(
const
std
::
string
&
category
,
const
Blob
&
key
)
{
return
get_impl
()
->
get
(
category
,
key
);
}
virtual
void
put
(
const
std
::
string
&
category
,
const
Blob
&
key
,
const
Blob
&
value
)
{
return
get_impl
()
->
put
(
category
,
key
,
value
);
}
py
::
object
get
(
std
::
string
category
,
std
::
string
key
)
{
auto
value
=
PersistentCache
::
inst
().
get
(
category
,
{
key
.
data
(),
key
.
size
()});
virtual
bool
support_dump_cache
()
{
return
get_impl
()
->
support_dump_cache
();
}
py
::
object
py_get
(
std
::
string
category
,
std
::
string
key
)
{
auto
value
=
get_impl
()
->
get
(
category
,
{
key
.
data
(),
key
.
size
()});
if
(
value
.
valid
())
{
return
py
::
bytes
(
std
::
string
((
const
char
*
)
value
->
ptr
,
value
->
size
));
}
else
{
return
py
::
none
();
}
}
void
py_put
(
std
::
string
category
,
std
::
string
key
,
std
::
string
value
)
{
get_impl
()
->
put
(
category
,
{
key
.
data
(),
key
.
size
()},
{
value
.
data
(),
value
.
size
()});
}
void
flush
()
{
if
(
impl
)
{
impl
->
flush
();
}
}
};
py
::
class_
<
PersistentCacheManager
>
(
m
,
"PersistentCacheManager"
)
.
def
(
py
::
init
<>
())
.
def
(
"try_open_redis"
,
&
PersistentCacheManager
::
open_redis
)
.
def
(
"try_open_file"
,
&
PersistentCacheManager
::
open_file
)
.
def
(
"clean"
,
&
PersistentCacheManager
::
clean
)
.
def
(
"put"
,
&
PersistentCacheManager
::
put
)
.
def
(
"get"
,
&
PersistentCacheManager
::
get
);
auto
PyConfigurablePersistentCache
=
py
::
class_
<
ConfigurablePersistentCache
,
std
::
shared_ptr
<
ConfigurablePersistentCache
>>
(
m
,
"PersistentCache"
)
.
def
(
py
::
init
<>
())
.
def
(
"add_config"
,
&
ConfigurablePersistentCache
::
add_config
)
.
def
(
"reg"
,
[](
std
::
shared_ptr
<
ConfigurablePersistentCache
>
inst
)
{
PersistentCache
::
set_impl
(
inst
);
})
.
def
(
"clean"
,
&
ConfigurablePersistentCache
::
clean
)
.
def
(
"get"
,
&
ConfigurablePersistentCache
::
py_get
)
.
def
(
"put"
,
&
ConfigurablePersistentCache
::
py_put
)
.
def_readonly
(
"config"
,
&
ConfigurablePersistentCache
::
impl_config
)
.
def
(
"flush"
,
&
ConfigurablePersistentCache
::
flush
);
py
::
class_
<
ConfigurablePersistentCache
::
Config
>
(
PyConfigurablePersistentCache
,
"Config"
)
.
def_readwrite
(
"type"
,
&
ConfigurablePersistentCache
::
Config
::
type
)
.
def_readwrite
(
"args"
,
&
ConfigurablePersistentCache
::
Config
::
args
)
.
def_readwrite
(
"on_fail"
,
&
ConfigurablePersistentCache
::
Config
::
on_fail
)
.
def_readwrite
(
"on_success"
,
&
ConfigurablePersistentCache
::
Config
::
on_success
);
}
imperative/src/impl/ops/collective_comm.cpp
浏览文件 @
1657b8e8
...
...
@@ -27,7 +27,7 @@ namespace imperative {
namespace
{
cg
::
OperatorNodeBase
*
apply_on_var_node
(
const
OpDef
&
def
,
const
VarNodeArray
&
inputs
)
{
auto
&&
comm
=
def
.
cast_final_safe
<
CollectiveComm
>
();
auto
group_client
=
std
::
make_shared
<
GroupClientProxy
>
(
auto
group_client
=
std
::
make_shared
<
opr
::
GroupClientProxy
>
(
ssprintf
(
"%s:%d"
,
comm
.
addr
.
data
(),
comm
.
port
));
SmallVector
<
std
::
shared_ptr
<
mgb
::
DeviceTensorND
>>
dev_buffer_arr
(
1
,
nullptr
);
auto
disable
=
std
::
make_shared
<
DTypeScalar
>
();
...
...
imperative/src/impl/ops/io_remote.cpp
浏览文件 @
1657b8e8
...
...
@@ -28,7 +28,7 @@ namespace {
cg
::
OperatorNodeBase
*
apply_on_var_node_remote_send
(
const
OpDef
&
def
,
const
VarNodeArray
&
inputs
)
{
auto
&&
send
=
def
.
cast_final_safe
<
RemoteSend
>
();
auto
group_client
=
std
::
make_shared
<
GroupClientProxy
>
(
auto
group_client
=
std
::
make_shared
<
opr
::
GroupClientProxy
>
(
ssprintf
(
"%s:%d"
,
send
.
addr
.
data
(),
send
.
port
));
auto
&&
graph
=
inputs
[
0
]
->
owner_graph
();
...
...
@@ -44,7 +44,7 @@ cg::OperatorNodeBase* apply_on_var_node_remote_recv(
auto
&&
recv
=
def
.
cast_final_safe
<
RemoteRecv
>
();
OperatorNodeConfig
config
{
recv
.
cn
};
config
.
name
(
recv
.
make_name
());
auto
group_client
=
std
::
make_shared
<
GroupClientProxy
>
(
auto
group_client
=
std
::
make_shared
<
opr
::
GroupClientProxy
>
(
ssprintf
(
"%s:%d"
,
recv
.
addr
.
data
(),
recv
.
port
));
auto
&&
graph
=
inputs
[
0
]
->
owner_graph
();
return
graph
->
insert_opr
(
std
::
make_unique
<
mgb
::
opr
::
RemoteRecv
>
(
...
...
imperative/src/impl/persistent_cache.cpp
浏览文件 @
1657b8e8
...
...
@@ -27,8 +27,10 @@ public:
m_local
=
std
::
make_shared
<
mgb
::
InMemoryPersistentCache
>
();
}
bool
connect
(
std
::
string
ip
,
size_t
port
,
std
::
string
password
)
{
m_client
.
auth
(
password
);
void
connect
(
std
::
string
ip
,
size_t
port
,
std
::
optional
<
std
::
string
>
password
)
{
if
(
password
)
{
m_client
.
auth
(
*
password
);
}
m_client
.
connect
(
ip
,
port
,
[](
const
std
::
string
&
host
,
std
::
size_t
port
,
...
...
@@ -40,16 +42,32 @@ public:
}
},
std
::
uint32_t
(
200
));
if
(
!
m_client
.
is_connected
())
{
return
false
;
}
mgb_assert
(
m_client
.
is_connected
(),
"connect failed"
);
auto
flag
=
m_client
.
get
(
"mgb-cache-flag"
);
sync
();
return
flag
.
get
().
ok
();
auto
is_valid
=
[](
const
cpp_redis
::
reply
&
reply
)
{
switch
(
reply
.
get_type
())
{
case
cpp_redis
::
reply
::
type
::
error
:
case
cpp_redis
::
reply
::
type
::
null
:
return
false
;
case
cpp_redis
::
reply
::
type
::
integer
:
return
reply
.
as_integer
()
!=
0
;
case
cpp_redis
::
reply
::
type
::
simple_string
:
case
cpp_redis
::
reply
::
type
::
bulk_string
:
return
!
reply
.
as_string
().
empty
();
case
cpp_redis
::
reply
::
type
::
array
:
return
!
reply
.
as_array
().
empty
();
default:
mgb_assert
(
false
,
"unknown reply type %d"
,
(
int
)
reply
.
get_type
());
}
};
mgb_assert
(
is_valid
(
flag
.
get
()),
"invalid mgb-cache-flag"
);
}
bool
valid
()
const
override
{
return
m_client
.
is_connected
();
}
void
flush
()
override
{}
mgb
::
Maybe
<
Blob
>
get
(
const
std
::
string
&
category
,
const
Blob
&
key
)
override
{
MGB_LOCK_GUARD
(
m_mtx
);
auto
mem_result
=
m_local
->
get
(
category
,
key
);
...
...
@@ -75,7 +93,7 @@ public:
MGB_LOCK_GUARD
(
m_mtx
);
std
::
string
key_str
(
static_cast
<
const
char
*>
(
key
.
ptr
),
key
.
size
);
std
::
string
redis_key_str
;
encode
(
category
+
'@'
+
key_str
,
redis_key_str
);
encode
(
category
+
'@'
+
key_str
,
redis_key_str
,
24
);
std
::
string
value_str
(
static_cast
<
const
char
*>
(
value
.
ptr
),
value
.
size
);
std
::
string
redis_value_str
;
encode
(
value_str
,
redis_value_str
);
...
...
@@ -118,18 +136,16 @@ private:
class
ExtendedInFilePersistentCache
final
:
public
ExtendedPersistentCache
{
private:
std
::
string
m_path
;
std
::
optional
<
std
::
string
>
m_path
;
std
::
unique_ptr
<
mgb
::
InFilePersistentCache
>
m_impl
;
public:
ExtendedInFilePersistentCache
()
=
default
;
bool
open
(
std
::
string
path
)
{
void
open
(
std
::
string
path
)
{
std
::
fstream
file
;
file
.
open
(
path
,
std
::
ios
::
in
|
std
::
ios
::
binary
);
if
(
!
file
.
is_open
())
{
return
false
;
}
mgb_assert
(
file
.
is_open
(),
"can't open file in %s"
,
path
.
c_str
());
std
::
vector
<
char
>
bytes
=
{
std
::
istreambuf_iterator
<
char
>
(
file
),
std
::
istreambuf_iterator
<
char
>
()};
if
(
bytes
.
size
())
{
...
...
@@ -139,14 +155,11 @@ public:
m_impl
=
std
::
make_unique
<
mgb
::
InFilePersistentCache
>
();
}
m_path
=
path
;
return
true
;
}
~
ExtendedInFilePersistentCache
()
{
if
(
m_impl
)
{
m_impl
->
dump_cache
(
m_path
.
c_str
());
}
}
void
open
()
{
m_impl
=
std
::
make_unique
<
mgb
::
InFilePersistentCache
>
();
}
~
ExtendedInFilePersistentCache
()
{
flush
();
}
mgb
::
Maybe
<
Blob
>
get
(
const
std
::
string
&
category
,
const
Blob
&
key
)
override
{
return
m_impl
->
get
(
category
,
key
);
...
...
@@ -157,29 +170,64 @@ public:
}
std
::
optional
<
size_t
>
clear
()
override
{
m_impl
=
std
::
make_unique
<
mgb
::
InFilePersistentCache
>
();
m_impl
->
dump_cache
(
m_path
.
c_str
());
if
(
m_impl
)
{
m_impl
=
std
::
make_unique
<
mgb
::
InFilePersistentCache
>
();
if
(
m_path
)
{
m_impl
->
dump_cache
(
m_path
->
c_str
());
}
}
return
{};
}
bool
valid
()
const
override
{
return
m_impl
!=
nullptr
;
}
};
std
::
shared_ptr
<
ExtendedPersistentCache
>
make_redis
(
std
::
string
ip
,
size_t
port
,
std
::
string
password
,
std
::
string
prefix
)
{
auto
cache
=
std
::
make_shared
<
RedisCache
>
(
prefix
,
100
);
if
(
!
cache
->
connect
(
ip
,
port
,
password
))
{
return
nullptr
;
void
flush
()
override
{
if
(
m_impl
&&
m_path
)
{
m_impl
->
dump_cache
(
m_path
->
c_str
());
}
}
return
cache
;
}
};
std
::
shared_ptr
<
ExtendedPersistentCache
>
make_in_file
(
std
::
string
path
)
{
auto
cache
=
std
::
make_shared
<
ExtendedInFilePersistentCache
>
();
if
(
!
cache
->
open
(
path
))
{
return
nullptr
;
std
::
shared_ptr
<
ExtendedPersistentCache
>
ExtendedPersistentCache
::
make_from_config
(
std
::
string
type
,
std
::
unordered_map
<
std
::
string
,
std
::
string
>
args
,
std
::
string
&
err_msg
)
{
try
{
if
(
type
==
"redis"
)
{
std
::
string
prefix
=
args
.
at
(
"prefix"
);
std
::
optional
<
std
::
string
>
password
=
args
.
count
(
"password"
)
?
args
.
at
(
"password"
)
:
std
::
optional
<
std
::
string
>
();
auto
cache
=
std
::
make_shared
<
RedisCache
>
(
prefix
,
100
);
if
(
args
.
count
(
"unixsocket"
))
{
std
::
string
unixsocket
=
args
.
at
(
"unixsocket"
);
cache
->
connect
(
unixsocket
,
0
,
password
);
}
else
{
std
::
string
ip
=
args
.
at
(
"hostname"
);
int
port
=
atoi
(
args
.
at
(
"port"
).
c_str
());
std
::
optional
<
std
::
string
>
password
=
args
.
count
(
"password"
)
?
args
.
at
(
"password"
)
:
std
::
optional
<
std
::
string
>
();
cache
->
connect
(
ip
,
port
,
password
);
}
return
cache
;
}
else
if
(
type
==
"in-file"
)
{
std
::
string
path
=
args
.
at
(
"path"
);
auto
cache
=
std
::
make_shared
<
ExtendedInFilePersistentCache
>
();
cache
->
open
(
path
);
return
cache
;
}
else
if
(
type
==
"in-memory"
)
{
auto
cache
=
std
::
make_shared
<
ExtendedInFilePersistentCache
>
();
cache
->
open
();
return
cache
;
}
else
{
mgb_assert
(
false
,
"persistent cache type %s unsupported"
,
type
.
c_str
());
}
}
catch
(
const
std
::
exception
&
exc
)
{
err_msg
=
exc
.
what
();
}
catch
(...)
{
err_msg
=
"unknown exception"
;
}
return
cache
;
return
nullptr
;
}
}
// namespace mgb::imperative::persistent_cache
...
...
imperative/src/include/megbrain/imperative/persistent_cache.h
浏览文件 @
1657b8e8
...
...
@@ -20,12 +20,12 @@ class ExtendedPersistentCache : public mgb::PersistentCache {
public:
virtual
bool
valid
()
const
=
0
;
virtual
std
::
optional
<
size_t
>
clear
()
=
0
;
};
std
::
shared_ptr
<
ExtendedPersistentCache
>
make_redis
(
std
::
string
ip
,
size_t
port
,
std
::
string
password
,
std
::
string
prefix
);
virtual
void
flush
()
=
0
;
std
::
shared_ptr
<
ExtendedPersistentCache
>
make_in_file
(
std
::
string
path
);
static
std
::
shared_ptr
<
ExtendedPersistentCache
>
make_from_config
(
std
::
string
type
,
std
::
unordered_map
<
std
::
string
,
std
::
string
>
args
,
std
::
string
&
err_msg
);
};
}
// namespace mgb::imperative::persistent_cache
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
imperative/src/test/collective_comm.cpp
浏览文件 @
1657b8e8
...
...
@@ -20,7 +20,7 @@ TEST(TestImperative, AllReduceBasic) {
REQUIRE_GPU
(
2
);
const
char
*
server_addr
=
"127.0.0.1"
;
uint32_t
port
=
3456
;
mgb_assert
(
create_zmqrpc_server
(
server_addr
,
port
)
>
0
);
mgb_assert
(
opr
::
create_zmqrpc_server
(
server_addr
,
port
)
>
0
);
HostTensorGenerator
<>
gen
;
CompNode
cn0
=
CompNode
::
load
(
"gpu0"
),
cn1
=
CompNode
::
load
(
"gpu1"
);
...
...
imperative/src/test/io_remote.cpp
浏览文件 @
1657b8e8
...
...
@@ -20,7 +20,7 @@ TEST(TestImperative, IORemote) {
REQUIRE_GPU
(
2
);
const
char
*
server_addr
=
"127.0.0.1"
;
uint32_t
port
=
4567
;
mgb_assert
(
create_zmqrpc_server
(
server_addr
,
port
)
>
0
);
mgb_assert
(
opr
::
create_zmqrpc_server
(
server_addr
,
port
)
>
0
);
HostTensorGenerator
<>
gen
;
CompNode
cn0
=
CompNode
::
load
(
"gpu0"
),
cn1
=
CompNode
::
load
(
"gpu1"
);
...
...
src/opr-mm/impl/mm_handler.cpp
浏览文件 @
1657b8e8
...
...
@@ -17,6 +17,9 @@
#include "megbrain/opr/zmq_rpc.h"
#include "mm_handler.pb.h"
using
namespace
mgb
;
using
namespace
opr
;
/* ======================== GroupServerProxy ========================== */
/*!
* A proxy that receives zmqrpc call, direct call to NCCL Manager
...
...
@@ -213,7 +216,7 @@ struct ServerInfo {
std
::
unique_ptr
<
ZmqRpc
::
ZmqRpcServer
>
server
;
};
int
create_zmqrpc_server
(
const
std
::
string
&
server_addr
,
int
port
)
{
int
mgb
::
opr
::
create_zmqrpc_server
(
const
std
::
string
&
server_addr
,
int
port
)
{
static
std
::
unordered_map
<
std
::
string
,
ServerInfo
>
addr2server
;
static
std
::
mutex
mtx
;
MGB_LOCK_GUARD
(
mtx
);
...
...
src/opr-mm/include/megbrain/opr/mm_handler.h
浏览文件 @
1657b8e8
...
...
@@ -16,8 +16,8 @@
#include "megbrain/opr/collective_comm.h"
#include "megbrain/opr/group_manager.h"
using
namespace
mgb
;
using
namespace
opr
;
namespace
mgb
{
namespace
opr
{
/*!
* Comm MM Client Proxy.
...
...
@@ -56,6 +56,9 @@ private:
int
create_zmqrpc_server
(
const
std
::
string
&
server_addr
,
int
port
);
}
// namespace opr
}
// namespace mgb
#endif
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
src/version.ld
浏览文件 @
1657b8e8
...
...
@@ -13,8 +13,6 @@ global:
base_exceptions*;
};
megcore*;
*GroupClientProxy*;
*create_zmqrpc_server*;
*custom*;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录