From 1d7fcecab27f8c7ccc38ca9142a37fdb541d7be9 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 1 May 2020 20:32:42 +0800 Subject: [PATCH] feat(mge/serialization): add map location GitOrigin-RevId: 4b6d83365bf70ce8cd7d35f59a64df1997251f51 --- python_module/megengine/_internal/__init__.py | 10 +++ python_module/megengine/_internal/config.py | 11 +++ python_module/megengine/core/serialization.py | 83 ++++++++++++++++++- python_module/src/swig/comp_node.i | 13 +++ python_module/src/swig/mgb.i | 1 + 5 files changed, 115 insertions(+), 3 deletions(-) diff --git a/python_module/megengine/_internal/__init__.py b/python_module/megengine/_internal/__init__.py index 13372392e..00e314e25 100644 --- a/python_module/megengine/_internal/__init__.py +++ b/python_module/megengine/_internal/__init__.py @@ -291,6 +291,16 @@ def current_grad_target(comp_graph): return _detail._current_grad_target(comp_graph) +def add_device_map(map_location): + """add map location while loading models""" + _detail.CompNode.cn_thread_local.__setattr__("map_location", map_location) + + +def del_device_map(): + """delete map location""" + _detail.CompNode.cn_thread_local.__delattr__("map_location") + + def inter_graph_trans_var(dest_graph, src): """get the corresponding var of *src* in *dest_graph*; assuming *dest_graph* is a copy of owner graph of *src*; usually used in callback of diff --git a/python_module/megengine/_internal/config.py b/python_module/megengine/_internal/config.py index 8fccc5d41..31f8ccdad 100644 --- a/python_module/megengine/_internal/config.py +++ b/python_module/megengine/_internal/config.py @@ -107,6 +107,17 @@ def get_device_count(device_type="xpu", warn=True): return _mgb.CompNode._get_device_count(device_type.upper(), warn) +def parse_locator(device_name: str) -> tuple: + """get the tensor locator expression by device name. 
+ + :param device_name: device name, like 'cpu0', 'gpu1' and 'xpux' + :type device_name: str + + :return: (device_type, dev_num, stream_num) + """ + return _mgb.CompNode._parse_locator(device_name) + + def set_mem_reserve_size(size): """set memory reserve size: diff --git a/python_module/megengine/core/serialization.py b/python_module/megengine/core/serialization.py index d76225326..8c18a5343 100644 --- a/python_module/megengine/core/serialization.py +++ b/python_module/megengine/core/serialization.py @@ -8,7 +8,10 @@ # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. import pickle +import megengine._internal as mgb + from ..utils.max_recursion_limit import max_recursion_limit +from .device import get_default_device def save(obj, f, pickle_module=pickle, pickle_protocol=pickle.HIGHEST_PROTOCOL): @@ -36,16 +39,90 @@ def save(obj, f, pickle_module=pickle, pickle_protocol=pickle.HIGHEST_PROTOCOL): pickle_module.dump(obj, f, pickle_protocol) -def load(f, pickle_module=pickle): +class dmap: + def __init__(self, map_location): + self.map_location = map_location + + def __enter__(self): + mgb.add_device_map(self.map_location) + return self + + def __exit__(self, type, value, traceback): + mgb.del_device_map() + + +def _get_callable_map_location(map_location): + if map_location is None: + + def callable_map_location(state): + return str(get_default_device()) + + elif isinstance(map_location, str): + + def callable_map_location(state): + return map_location + + elif isinstance(map_location, dict): + locator_map = {} + for key, value in map_location.items(): + locator_key = mgb.config.parse_locator(key)[:2] + locator_map[locator_key] = value + + def callable_map_location(state): + orig = mgb.config.parse_locator(state)[:2] + if orig in locator_map.keys(): + state = locator_map[orig] + return state + + else: + assert callable(map_location), "map_location should be str, dict or function" + callable_map_location = map_location + return 
callable_map_location + + +def load(f, map_location=None, pickle_module=pickle): r"""Load an object saved with save() from a file. :type f: text file object :param f: a string of file name or a text file object from which to load. + :type map_location: str, dict or a function specifying the map rules + :param map_location: Default: ``None``. + + .. note:: + + map_location will change the logical locator when loading models, + preventing tensors from being loaded onto a non-existent device. If you want to + add the mapping relationship between logical locator and physical + locator at runtime, please call :func:`mge.set_device_map()` + :type pickle_module: :param pickle_module: Default: ``pickle``. + .. note:: + + If you need to call :func:`mge.set_default_device()`, please do so + before :func:`mge.load()`. + + Examples: + + .. testcode:: + + import megengine as mge + mge.load('model.mge') + # Load all tensors based on logical location. + mge.load('model.mge', map_location='gpu0') + # Load all tensors onto the device: GPU0 + mge.load('model.mge', map_location={'gpu0':'cpu0'}) + # Load all tensors based on logical location, but 'GPU0' will be renamed to 'CPU0' + mge.load('model.mge', map_location=lambda dev: 'cpu0') + # Load all tensors onto the device: CPU0 + """ if isinstance(f, str): with open(f, "rb") as fin: - return load(fin, pickle_module=pickle_module) - return pickle_module.load(f) + return load(fin, map_location=map_location, pickle_module=pickle_module) + + map_location = _get_callable_map_location(map_location) # callable map_location + + with dmap(map_location): + return pickle_module.load(f) diff --git a/python_module/src/swig/comp_node.i b/python_module/src/swig/comp_node.i index 528ebb3ac..2d11eeef3 100644 --- a/python_module/src/swig/comp_node.i +++ b/python_module/src/swig/comp_node.i @@ -28,6 +28,12 @@ class CompNode { static CompNode load(const char* id); %extend { + static std::vector _parse_locator(const std::string &id) const { + auto logi = 
CompNode::Locator::parse(id); + return { + static_cast(logi.type), logi.device, logi.stream, + }; + } static void _set_device_map(const std::string &type, int from, int to) { CompNode::Locator::set_device_map( @@ -86,7 +92,14 @@ class CompNode { 2: 'CPU' } + cn_thread_local = threading.local() + """used to save map location when calling :func:`mge.load()`""" + def __setstate__(self, state): + """:func:`mge.load()` and :func:`deepcopy()` call this function; + the latter will not produce the map_location attribute""" + if "map_location" in CompNode.cn_thread_local.__dict__.keys(): + state = CompNode.cn_thread_local.map_location(state) self.this = CompNode_load(state).this def __eq__(self, rhs): diff --git a/python_module/src/swig/mgb.i b/python_module/src/swig/mgb.i index df261cf38..a77698076 100644 --- a/python_module/src/swig/mgb.i +++ b/python_module/src/swig/mgb.i @@ -35,6 +35,7 @@ void _init_bfloat16_types(PyObject *m); // implemented in bfloat16.cpp %pythoncode %{ import numpy as np import os +import threading intb1 = _mgb.intb1 intb2 = _mgb.intb2 intb4 = _mgb.intb4 -- GitLab