提交 1d7fceca 编写于 作者: M Megvii Engine Team

feat(mge/serialization): add map location

GitOrigin-RevId: 4b6d83365bf70ce8cd7d35f59a64df1997251f51
上级 9320bf92
......@@ -291,6 +291,16 @@ def current_grad_target(comp_graph):
return _detail._current_grad_target(comp_graph)
def add_device_map(map_location):
"""add map location while loading models"""
_detail.CompNode.cn_thread_local.__setattr__("map_location", map_location)
def del_device_map():
"""delete map location"""
_detail.CompNode.cn_thread_local.__delattr__("map_location")
def inter_graph_trans_var(dest_graph, src):
"""get the corresponding var of *src* in *dest_graph*; assuming
*dest_graph* is a copy of owner graph of *src*; usually used in callback of
......
......@@ -107,6 +107,17 @@ def get_device_count(device_type="xpu", warn=True):
return _mgb.CompNode._get_device_count(device_type.upper(), warn)
def parse_locator(device_name: str) -> tuple:
"""get the tensor locator expression by device name.
:param device_name: device name, like 'cpu0', 'gpu1' and 'xpux'
:type device_name: str
:return: (device_type, dev_num, stream_num)
"""
return _mgb.CompNode._parse_locator(device_name)
def set_mem_reserve_size(size):
"""set memory reserve size:
......
......@@ -8,7 +8,10 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import pickle
import megengine._internal as mgb
from ..utils.max_recursion_limit import max_recursion_limit
from .device import get_default_device
def save(obj, f, pickle_module=pickle, pickle_protocol=pickle.HIGHEST_PROTOCOL):
......@@ -36,16 +39,90 @@ def save(obj, f, pickle_module=pickle, pickle_protocol=pickle.HIGHEST_PROTOCOL):
pickle_module.dump(obj, f, pickle_protocol)
def load(f, pickle_module=pickle):
class dmap:
def __init__(self, map_location):
self.map_location = map_location
def __enter__(self):
mgb.add_device_map(self.map_location)
return self
def __exit__(self, type, value, traceback):
mgb.del_device_map()
def _get_callable_map_location(map_location):
if map_location is None:
def callable_map_location(state):
return str(get_default_device())
elif isinstance(map_location, str):
def callable_map_location(state):
return map_location
elif isinstance(map_location, dict):
locator_map = {}
for key, value in map_location.items():
locator_key = mgb.config.parse_locator(key)[:2]
locator_map[locator_key] = value
def callable_map_location(state):
orig = mgb.config.parse_locator(state)[:2]
if orig in locator_map.keys():
state = locator_map[orig]
return state
else:
assert callable(map_location), "map_location should be str, dict or function"
callable_map_location = map_location
return callable_map_location
def load(f, map_location=None, pickle_module=pickle):
r"""Load an object saved with save() from a file.
:type f: text file object
:param f: a string of file name or a text file object from which to load.
:type map_location: str, dict or a function specifying the map rules
:param map_location: Default: ``None``.
.. note::
map_location will change the logical locator when loading models,
avoiding tensors be loading on non-existent device. If you want to
add the mapping relationship between logical locator and physical
locator in runtime, please call :func:`mge.set_device_map()`
:type pickle_module:
:param pickle_module: Default: ``pickle``.
.. note::
If you will call :func:`mge.set_default_device()`, please do it
before :func:`mge.load()`.
Examples:
.. testcode:
import megengine as mge
mge.load('model.mge')
# Load all tensors based on logical location.
mge.load('model.mge', map_location='gpu0')
# Load all tensors onto the device: GPU0
mge.load('model.mge', map_location={'gpu0':'cpu0'})
# Load all tensors based on logical location, but 'GPU0' will be renamed to 'CPU0'
mge.load('model.mge', map_location=lambda dev: 'cpu0')
# Load all tensors onto the device" CPU0
"""
if isinstance(f, str):
with open(f, "rb") as fin:
return load(fin, pickle_module=pickle_module)
return pickle_module.load(f)
return load(fin, map_location=map_location, pickle_module=pickle_module)
map_location = _get_callable_map_location(map_location) # callable map_location
with dmap(map_location):
return pickle_module.load(f)
......@@ -28,6 +28,12 @@ class CompNode {
static CompNode load(const char* id);
%extend {
static std::vector<int> _parse_locator(const std::string &id) const {
auto logi = CompNode::Locator::parse(id);
return {
static_cast<int>(logi.type), logi.device, logi.stream,
};
}
static void _set_device_map(const std::string &type,
int from, int to) {
CompNode::Locator::set_device_map(
......@@ -86,7 +92,14 @@ class CompNode {
2: 'CPU'
}
cn_thread_local = threading.local()
"""used to save map location when calling :func:`mge.load()`"""
def __setstate__(self, state):
""":func:`mge.load()` and :func:`deepcopy()` call this function,
The latter will not produce the map_location attribute"""
if "map_location" in CompNode.cn_thread_local.__dict__.keys():
state = CompNode.cn_thread_local.map_location(state)
self.this = CompNode_load(state).this
def __eq__(self, rhs):
......
......@@ -35,6 +35,7 @@ void _init_bfloat16_types(PyObject *m); // implemented in bfloat16.cpp
%pythoncode %{
import numpy as np
import os
import threading
intb1 = _mgb.intb1
intb2 = _mgb.intb2
intb4 = _mgb.intb4
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册