Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
1d7fceca
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
1d7fceca
编写于
5月 01, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(mge/serialization): add map location
GitOrigin-RevId: 4b6d83365bf70ce8cd7d35f59a64df1997251f51
上级
9320bf92
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
115 addition
and
3 deletion
+115
-3
python_module/megengine/_internal/__init__.py
python_module/megengine/_internal/__init__.py
+10
-0
python_module/megengine/_internal/config.py
python_module/megengine/_internal/config.py
+11
-0
python_module/megengine/core/serialization.py
python_module/megengine/core/serialization.py
+80
-3
python_module/src/swig/comp_node.i
python_module/src/swig/comp_node.i
+13
-0
python_module/src/swig/mgb.i
python_module/src/swig/mgb.i
+1
-0
未找到文件。
python_module/megengine/_internal/__init__.py
浏览文件 @
1d7fceca
...
...
@@ -291,6 +291,16 @@ def current_grad_target(comp_graph):
return
_detail
.
_current_grad_target
(
comp_graph
)
def
add_device_map
(
map_location
):
"""add map location while loading models"""
_detail
.
CompNode
.
cn_thread_local
.
__setattr__
(
"map_location"
,
map_location
)
def
del_device_map
():
"""delete map location"""
_detail
.
CompNode
.
cn_thread_local
.
__delattr__
(
"map_location"
)
def
inter_graph_trans_var
(
dest_graph
,
src
):
"""get the corresponding var of *src* in *dest_graph*; assuming
*dest_graph* is a copy of owner graph of *src*; usually used in callback of
...
...
python_module/megengine/_internal/config.py
浏览文件 @
1d7fceca
...
...
@@ -107,6 +107,17 @@ def get_device_count(device_type="xpu", warn=True):
return
_mgb
.
CompNode
.
_get_device_count
(
device_type
.
upper
(),
warn
)
def
parse_locator
(
device_name
:
str
)
->
tuple
:
"""get the tensor locator expression by device name.
:param device_name: device name, like 'cpu0', 'gpu1' and 'xpux'
:type device_name: str
:return: (device_type, dev_num, stream_num)
"""
return
_mgb
.
CompNode
.
_parse_locator
(
device_name
)
def
set_mem_reserve_size
(
size
):
"""set memory reserve size:
...
...
python_module/megengine/core/serialization.py
浏览文件 @
1d7fceca
...
...
@@ -8,7 +8,10 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
pickle
import
megengine._internal
as
mgb
from
..utils.max_recursion_limit
import
max_recursion_limit
from
.device
import
get_default_device
def
save
(
obj
,
f
,
pickle_module
=
pickle
,
pickle_protocol
=
pickle
.
HIGHEST_PROTOCOL
):
...
...
@@ -36,16 +39,90 @@ def save(obj, f, pickle_module=pickle, pickle_protocol=pickle.HIGHEST_PROTOCOL):
pickle_module
.
dump
(
obj
,
f
,
pickle_protocol
)
def
load
(
f
,
pickle_module
=
pickle
):
class
dmap
:
def
__init__
(
self
,
map_location
):
self
.
map_location
=
map_location
def
__enter__
(
self
):
mgb
.
add_device_map
(
self
.
map_location
)
return
self
def
__exit__
(
self
,
type
,
value
,
traceback
):
mgb
.
del_device_map
()
def
_get_callable_map_location
(
map_location
):
if
map_location
is
None
:
def
callable_map_location
(
state
):
return
str
(
get_default_device
())
elif
isinstance
(
map_location
,
str
):
def
callable_map_location
(
state
):
return
map_location
elif
isinstance
(
map_location
,
dict
):
locator_map
=
{}
for
key
,
value
in
map_location
.
items
():
locator_key
=
mgb
.
config
.
parse_locator
(
key
)[:
2
]
locator_map
[
locator_key
]
=
value
def
callable_map_location
(
state
):
orig
=
mgb
.
config
.
parse_locator
(
state
)[:
2
]
if
orig
in
locator_map
.
keys
():
state
=
locator_map
[
orig
]
return
state
else
:
assert
callable
(
map_location
),
"map_location should be str, dict or function"
callable_map_location
=
map_location
return
callable_map_location
def
load
(
f
,
map_location
=
None
,
pickle_module
=
pickle
):
r
"""Load an object saved with save() from a file.
:type f: text file object
:param f: a string of file name or a text file object from which to load.
:type map_location: str, dict or a function specifying the map rules
:param map_location: Default: ``None``.
.. note::
map_location will change the logical locator when loading models,
avoiding tensors be loading on non-existent device. If you want to
add the mapping relationship between logical locator and physical
locator in runtime, please call :func:`mge.set_device_map()`
:type pickle_module:
:param pickle_module: Default: ``pickle``.
.. note::
If you will call :func:`mge.set_default_device()`, please do it
before :func:`mge.load()`.
Examples:
.. testcode:
import megengine as mge
mge.load('model.mge')
# Load all tensors based on logical location.
mge.load('model.mge', map_location='gpu0')
# Load all tensors onto the device: GPU0
mge.load('model.mge', map_location={'gpu0':'cpu0'})
# Load all tensors based on logical location, but 'GPU0' will be renamed to 'CPU0'
mge.load('model.mge', map_location=lambda dev: 'cpu0')
# Load all tensors onto the device" CPU0
"""
if
isinstance
(
f
,
str
):
with
open
(
f
,
"rb"
)
as
fin
:
return
load
(
fin
,
pickle_module
=
pickle_module
)
return
pickle_module
.
load
(
f
)
return
load
(
fin
,
map_location
=
map_location
,
pickle_module
=
pickle_module
)
map_location
=
_get_callable_map_location
(
map_location
)
# callable map_location
with
dmap
(
map_location
):
return
pickle_module
.
load
(
f
)
python_module/src/swig/comp_node.i
浏览文件 @
1d7fceca
...
...
@@ -28,6 +28,12 @@ class CompNode {
static
CompNode
load
(
const
char
*
id
)
;
%
extend
{
static
std
::
vector
<
int
>
_parse_locator
(
const
std
::
string
&
id)
const
{
auto
logi
=
CompNode
::
Locator
::
parse
(
id
)
;
return
{
static_cast
<
int
>
(
logi
.
type
),
logi
.
device
,
logi
.
stream
,
}
;
}
static
void
_set_device_map
(
const
std
::
string
&
type,
int
from
,
int
to
)
{
CompNode
::
Locator
::
set_device_map
(
...
...
@@ -86,7 +92,14 @@ class CompNode {
2
:
'
CPU
'
}
cn_thread_local
=
threading
.
local
()
"""used to save map location when calling :func:`mge.load()`"""
def
__setstate__
(
self
,
state
)
:
""":func:`mge.load()` and :func:`deepcopy()` call this function,
The latter will not produce the map_location attribute"""
if
"map_location"
in
CompNode
.
cn_thread_local
.
__dict__
.
keys
()
:
state
=
CompNode
.
cn_thread_local
.
map_location
(
state
)
self
.
this
=
CompNode_load
(
state
)
.
this
def
__eq__
(
self
,
rhs
)
:
...
...
python_module/src/swig/mgb.i
浏览文件 @
1d7fceca
...
...
@@ -35,6 +35,7 @@ void _init_bfloat16_types(PyObject *m); // implemented in bfloat16.cpp
%
pythoncode
%
{
import
numpy
as
np
import
os
import
threading
intb1
=
_mgb
.
intb1
intb2
=
_mgb
.
intb2
intb4
=
_mgb
.
intb4
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录